diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index fce8d0e8..a25a484d 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -317,9 +317,9 @@ jobs: wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb sudo dpkg -i cuda-keyring_1.1-1_all.deb packages=( - bison build-essential ccache cuda-compiler-11-8 flex gperf libcublas-dev-11-8 libfontconfig1 libfreetype6 - libgl1-mesa-dev libmysqlclient21 libnvidia-compute-550-server libodbc2 libpq5 libwayland-dev libx11-6 - libx11-xcb1 libxcb-cursor0 libxcb-glx0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 + bison build-essential ccache cuda-compiler-11-8 flex g++-12 gperf libcublas-dev-11-8 libfontconfig1 + libfreetype6 libgl1-mesa-dev libmysqlclient21 libnvidia-compute-550-server libodbc2 libpq5 libwayland-dev + libx11-6 libx11-xcb1 libxcb-cursor0 libxcb-glx0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-shape0 libxcb-shm0 libxcb-sync1 libxcb-util1 libxcb-xfixes0 libxcb-xinerama0 libxcb-xkb1 libxcb1 libxext6 libxfixes3 libxi6 libxkbcommon-x11-0 libxkbcommon0 libxrender1 patchelf python3 vulkan-sdk @@ -352,6 +352,8 @@ jobs: ~/Qt/Tools/CMake/bin/cmake \ -S ../gpt4all-chat -B . \ -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=gcc-12 \ + -DCMAKE_CXX_COMPILER=g++-12 \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \ @@ -391,9 +393,9 @@ jobs: wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb sudo dpkg -i cuda-keyring_1.1-1_all.deb packages=( - bison build-essential ccache cuda-compiler-11-8 flex gperf libcublas-dev-11-8 libfontconfig1 libfreetype6 - libgl1-mesa-dev libmysqlclient21 libnvidia-compute-550-server libodbc2 libpq5 libwayland-dev libx11-6 - libx11-xcb1 libxcb-cursor0 libxcb-glx0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 + bison build-essential ccache cuda-compiler-11-8 flex g++-12 gperf libcublas-dev-11-8 libfontconfig1 + libfreetype6 libgl1-mesa-dev libmysqlclient21 libnvidia-compute-550-server libodbc2 libpq5 libwayland-dev + libx11-6 libx11-xcb1 libxcb-cursor0 libxcb-glx0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-shape0 libxcb-shm0 libxcb-sync1 libxcb-util1 libxcb-xfixes0 libxcb-xinerama0 libxcb-xkb1 libxcb1 libxext6 libxfixes3 libxi6 libxkbcommon-x11-0 libxkbcommon0 libxrender1 patchelf python3 vulkan-sdk @@ -426,6 +428,8 @@ jobs: ~/Qt/Tools/CMake/bin/cmake \ -S ../gpt4all-chat -B . 
\ -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=gcc-12 \ + -DCMAKE_CXX_COMPILER=g++-12 \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \ @@ -447,7 +451,7 @@ jobs: build-offline-chat-installer-windows: machine: - image: 'windows-server-2019-vs2019:2022.08.1' + image: windows-server-2022-gui:current resource_class: windows.large shell: powershell.exe -ExecutionPolicy Bypass steps: @@ -538,7 +542,7 @@ jobs: sign-offline-chat-installer-windows: machine: - image: 'windows-server-2019-vs2019:2022.08.1' + image: windows-server-2022-gui:current resource_class: windows.large shell: powershell.exe -ExecutionPolicy Bypass steps: @@ -568,7 +572,7 @@ jobs: build-online-chat-installer-windows: machine: - image: 'windows-server-2019-vs2019:2022.08.1' + image: windows-server-2022-gui:current resource_class: windows.large shell: powershell.exe -ExecutionPolicy Bypass steps: @@ -666,7 +670,7 @@ jobs: sign-online-chat-installer-windows: machine: - image: 'windows-server-2019-vs2019:2022.08.1' + image: windows-server-2022-gui:current resource_class: windows.large shell: powershell.exe -ExecutionPolicy Bypass steps: @@ -720,9 +724,9 @@ jobs: wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb sudo dpkg -i cuda-keyring_1.1-1_all.deb packages=( - bison build-essential ccache cuda-compiler-11-8 flex gperf libcublas-dev-11-8 libfontconfig1 libfreetype6 - libgl1-mesa-dev libmysqlclient21 libnvidia-compute-550-server libodbc2 libpq5 libwayland-dev libx11-6 - libx11-xcb1 libxcb-cursor0 libxcb-glx0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 + bison build-essential ccache cuda-compiler-11-8 flex g++-12 gperf libcublas-dev-11-8 libfontconfig1 + libfreetype6 libgl1-mesa-dev libmysqlclient21 libnvidia-compute-550-server libodbc2 libpq5 libwayland-dev + libx11-6 libx11-xcb1 libxcb-cursor0 libxcb-glx0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-shape0 libxcb-shm0 libxcb-sync1 libxcb-util1 libxcb-xfixes0 libxcb-xinerama0 libxcb-xkb1 libxcb1 libxext6 libxfixes3 libxi6 libxkbcommon-x11-0 libxkbcommon0 libxrender1 python3 vulkan-sdk @@ -744,6 +748,8 @@ jobs: ~/Qt/Tools/CMake/bin/cmake \ -S gpt4all-chat -B build \ -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=gcc-12 \ + -DCMAKE_CXX_COMPILER=g++-12 \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \ @@ -758,7 +764,7 @@ jobs: build-gpt4all-chat-windows: machine: - image: 'windows-server-2019-vs2019:2022.08.1' + image: windows-server-2022-gui:current resource_class: windows.large shell: powershell.exe -ExecutionPolicy Bypass steps: @@ -864,8 +870,8 @@ jobs: paths: - ../.ccache - build-ts-docs: - docker: + build-ts-docs: + docker: - image: cimg/base:stable steps: - checkout @@ -887,7 +893,7 @@ jobs: docker: - image: circleci/python:3.8 steps: - - checkout + - checkout - run: name: Install dependencies command: | @@ -928,7 +934,8 @@ jobs: wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb sudo dpkg -i cuda-keyring_1.1-1_all.deb packages=( - build-essential ccache cmake cuda-compiler-11-8 libcublas-dev-11-8 libnvidia-compute-550-server vulkan-sdk + build-essential ccache cmake cuda-compiler-11-8 g++-12 libcublas-dev-11-8 libnvidia-compute-550-server + vulkan-sdk ) sudo apt-get update sudo apt-get install -y "${packages[@]}" @@ -942,6 +949,8 @@ jobs: cd gpt4all-backend cmake -B 
build \ -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=gcc-12 \ + -DCMAKE_CXX_COMPILER=g++-12 \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \ @@ -1014,7 +1023,7 @@ jobs: build-py-windows: machine: - image: 'windows-server-2019-vs2019:2022.08.1' + image: windows-server-2022-gui:current resource_class: windows.large shell: powershell.exe -ExecutionPolicy Bypass steps: @@ -1118,11 +1127,12 @@ jobs: name: Install dependencies command: | wget -qO- https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo tee /etc/apt/trusted.gpg.d/lunarg.asc - sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list http://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list + sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list http://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb sudo dpkg -i cuda-keyring_1.1-1_all.deb packages=( - build-essential ccache cmake cuda-compiler-11-8 libcublas-dev-11-8 libnvidia-compute-550-server vulkan-sdk + build-essential ccache cmake cuda-compiler-11-8 g++-12 libcublas-dev-11-8 libnvidia-compute-550-server + vulkan-sdk ) sudo apt-get update sudo apt-get install -y "${packages[@]}" @@ -1135,6 +1145,9 @@ jobs: mkdir -p runtimes/build cd runtimes/build cmake ../.. \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=gcc-12 \ + -DCMAKE_CXX_COMPILER=g++-12 \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ @@ -1204,7 +1217,7 @@ jobs: build-bindings-backend-windows: machine: - image: 'windows-server-2022-gui:2023.03.1' + image: windows-server-2022-gui:current resource_class: windows.large shell: powershell.exe -ExecutionPolicy Bypass steps: @@ -1230,7 +1243,7 @@ jobs: - run: name: Install dependencies command: | - choco install -y ccache cmake ninja --installargs 'ADD_CMAKE_TO_PATH=System' + choco install -y ccache cmake ninja --installargs 'ADD_CMAKE_TO_PATH=System' - run: name: Build Libraries command: | @@ -1263,8 +1276,8 @@ jobs: paths: - runtimes/win-x64_msvc/*.dll - build-nodejs-linux: - docker: + build-nodejs-linux: + docker: - image: cimg/base:stable steps: - checkout @@ -1280,10 +1293,10 @@ jobs: pkg-manager: yarn override-ci-command: yarn install - run: - command: | + command: | cd gpt4all-bindings/typescript yarn prebuildify -t 18.16.0 --napi - - run: + - run: command: | mkdir -p gpt4all-backend/prebuilds/linux-x64 mkdir -p gpt4all-backend/runtimes/linux-x64 @@ -1292,10 +1305,10 @@ jobs: - persist_to_workspace: root: gpt4all-backend paths: - - prebuilds/linux-x64/*.node + - prebuilds/linux-x64/*.node - runtimes/linux-x64/*-*.so - build-nodejs-macos: + build-nodejs-macos: macos: xcode: 15.4.0 steps: @@ -1312,12 +1325,12 @@ jobs: pkg-manager: yarn override-ci-command: yarn install - run: - command: | + command: | cd gpt4all-bindings/typescript yarn prebuildify -t 18.16.0 --napi - - run: + - run: name: "Persisting all necessary things to workspace" - command: | + command: | mkdir -p gpt4all-backend/prebuilds/darwin-x64 mkdir -p gpt4all-backend/runtimes/darwin cp /tmp/gpt4all-backend/runtimes/osx-x64/*-*.* gpt4all-backend/runtimes/darwin @@ -1328,7 +1341,7 @@ jobs: - prebuilds/darwin-x64/*.node - runtimes/darwin/*-*.* - build-nodejs-windows: + build-nodejs-windows: executor: name: win/default size: large @@ -1342,29 +1355,29 @@ jobs: command: wget https://nodejs.org/dist/v18.16.0/node-v18.16.0-x86.msi -P
C:\Users\circleci\Downloads\ shell: cmd.exe - run: MsiExec.exe /i C:\Users\circleci\Downloads\node-v18.16.0-x86.msi /qn - - run: + - run: command: | Start-Process powershell -verb runAs -Args "-start GeneralProfile" nvm install 18.16.0 nvm use 18.16.0 - - run: node --version + - run: node --version - run: corepack enable - - run: + - run: command: | npm install -g yarn cd gpt4all-bindings/typescript yarn install - run: - command: | + command: | cd gpt4all-bindings/typescript - yarn prebuildify -t 18.16.0 --napi - - run: + yarn prebuildify -t 18.16.0 --napi + - run: command: | mkdir -p gpt4all-backend/prebuilds/win32-x64 mkdir -p gpt4all-backend/runtimes/win32-x64 cp /tmp/gpt4all-backend/runtimes/win-x64_msvc/*-*.dll gpt4all-backend/runtimes/win32-x64 cp gpt4all-bindings/typescript/prebuilds/win32-x64/*.node gpt4all-backend/prebuilds/win32-x64 - + - persist_to_workspace: root: gpt4all-backend paths: @@ -1372,7 +1385,7 @@ jobs: - runtimes/win32-x64/*-*.dll prepare-npm-pkg: - docker: + docker: - image: cimg/base:stable steps: - attach_workspace: @@ -1383,19 +1396,19 @@ jobs: node-version: "18.16" - run: node --version - run: corepack enable - - run: + - run: command: | cd gpt4all-bindings/typescript # excluding llmodel. nodejs bindings dont need llmodel.dll mkdir -p runtimes/win32-x64/native mkdir -p prebuilds/win32-x64/ - cp /tmp/gpt4all-backend/runtimes/win-x64_msvc/*-*.dll runtimes/win32-x64/native/ - cp /tmp/gpt4all-backend/prebuilds/win32-x64/*.node prebuilds/win32-x64/ + cp /tmp/gpt4all-backend/runtimes/win-x64_msvc/*-*.dll runtimes/win32-x64/native/ + cp /tmp/gpt4all-backend/prebuilds/win32-x64/*.node prebuilds/win32-x64/ - mkdir -p runtimes/linux-x64/native + mkdir -p runtimes/linux-x64/native mkdir -p prebuilds/linux-x64/ - cp /tmp/gpt4all-backend/runtimes/linux-x64/*-*.so runtimes/linux-x64/native/ - cp /tmp/gpt4all-backend/prebuilds/linux-x64/*.node prebuilds/linux-x64/ + cp /tmp/gpt4all-backend/runtimes/linux-x64/*-*.so runtimes/linux-x64/native/ + cp /tmp/gpt4all-backend/prebuilds/linux-x64/*.node prebuilds/linux-x64/ # darwin has univeral runtime libraries mkdir -p runtimes/darwin/native @@ -1403,22 +1416,22 @@ jobs: cp /tmp/gpt4all-backend/runtimes/darwin/*-*.* runtimes/darwin/native/ - cp /tmp/gpt4all-backend/prebuilds/darwin-x64/*.node prebuilds/darwin-x64/ - + cp /tmp/gpt4all-backend/prebuilds/darwin-x64/*.node prebuilds/darwin-x64/ + # Fallback build if user is not on above prebuilds mv -f binding.ci.gyp binding.gyp mkdir gpt4all-backend cd ../../gpt4all-backend mv llmodel.h llmodel.cpp llmodel_c.cpp llmodel_c.h sysinfo.h dlhandle.h ../gpt4all-bindings/typescript/gpt4all-backend/ - + # Test install - node/install-packages: app-dir: gpt4all-bindings/typescript pkg-manager: yarn override-ci-command: yarn install - - run: - command: | + - run: + command: | cd gpt4all-bindings/typescript yarn run test - run: @@ -1552,7 +1565,7 @@ workflows: - build-py-linux - build-py-macos build-bindings: - when: + when: or: - << pipeline.parameters.run-all-workflows >> - << pipeline.parameters.run-python-workflow >> @@ -1585,8 +1598,8 @@ workflows: requires: - hold - # NodeJs Jobs - - prepare-npm-pkg: + # NodeJs Jobs + - prepare-npm-pkg: filters: branches: only: diff --git a/.gitmodules b/.gitmodules index b59d07fd..6ed4b266 100644 --- a/.gitmodules +++ b/.gitmodules @@ -8,3 +8,6 @@ [submodule "gpt4all-chat/deps/SingleApplication"] path = gpt4all-chat/deps/SingleApplication url = https://github.com/nomic-ai/SingleApplication.git +[submodule "gpt4all-chat/deps/fmt"] + path = 
gpt4all-chat/deps/fmt + url = https://github.com/fmtlib/fmt.git diff --git a/gpt4all-backend/CMakeLists.txt b/gpt4all-backend/CMakeLists.txt index 2c1fbb46..fb5937aa 100644 --- a/gpt4all-backend/CMakeLists.txt +++ b/gpt4all-backend/CMakeLists.txt @@ -33,7 +33,7 @@ set(LLMODEL_VERSION_PATCH 0) set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}") project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C) -set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD 23) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) set(BUILD_SHARED_LIBS ON) diff --git a/gpt4all-backend/deps/llama.cpp-mainline b/gpt4all-backend/deps/llama.cpp-mainline index 443665ae..ced74fba 160000 --- a/gpt4all-backend/deps/llama.cpp-mainline +++ b/gpt4all-backend/deps/llama.cpp-mainline @@ -1 +1 @@ -Subproject commit 443665aec4721ecf57df8162e7e093a0cd674a76 +Subproject commit ced74fbad4b258507f3ec06e77eec9445583511a diff --git a/gpt4all-backend/include/gpt4all-backend/llmodel.h b/gpt4all-backend/include/gpt4all-backend/llmodel.h index 04a510dc..d18584eb 100644 --- a/gpt4all-backend/include/gpt4all-backend/llmodel.h +++ b/gpt4all-backend/include/gpt4all-backend/llmodel.h @@ -162,7 +162,7 @@ public: bool allowContextShift, PromptContext &ctx, bool special = false, - std::string *fakeReply = nullptr); + std::optional fakeReply = {}); using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend); @@ -212,7 +212,7 @@ public: protected: // These are pure virtual because subclasses need to implement as the default implementation of // 'prompt' above calls these functions - virtual std::vector tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0; + virtual std::vector tokenize(PromptContext &ctx, std::string_view str, bool special = false) = 0; virtual bool isSpecialToken(Token id) const = 0; virtual std::string tokenToString(Token id) const = 0; virtual Token sampleToken(PromptContext &ctx) const = 0; @@ -249,7 +249,8 @@ protected: std::function responseCallback, bool allowContextShift, PromptContext &promptCtx, - std::vector embd_inp); + std::vector embd_inp, + bool isResponse = false); void generateResponse(std::function responseCallback, bool allowContextShift, PromptContext &promptCtx); diff --git a/gpt4all-backend/src/llamamodel.cpp b/gpt4all-backend/src/llamamodel.cpp index e2bbd0ac..8c92b025 100644 --- a/gpt4all-backend/src/llamamodel.cpp +++ b/gpt4all-backend/src/llamamodel.cpp @@ -536,13 +536,13 @@ size_t LLamaModel::restoreState(const uint8_t *src) return llama_set_state_data(d_ptr->ctx, const_cast(src)); } -std::vector LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special) +std::vector LLamaModel::tokenize(PromptContext &ctx, std::string_view str, bool special) { bool atStart = m_tokenize_last_token == -1; bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token); std::vector fres(str.length() + 4); int32_t fres_len = llama_tokenize_gpt4all( - d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart, + d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart, /*parse_special*/ special, /*insert_space*/ insertSpace ); fres.resize(fres_len); diff --git a/gpt4all-backend/src/llamamodel_impl.h b/gpt4all-backend/src/llamamodel_impl.h index 7c698ffa..5189b9b3 100644 --- a/gpt4all-backend/src/llamamodel_impl.h +++ b/gpt4all-backend/src/llamamodel_impl.h @@ -8,6 +8,7 @@ #include 
#include +#include #include struct LLamaPrivate; @@ -52,7 +53,7 @@ private: bool m_supportsCompletion = false; protected: - std::vector tokenize(PromptContext &ctx, const std::string &str, bool special) override; + std::vector tokenize(PromptContext &ctx, std::string_view str, bool special) override; bool isSpecialToken(Token id) const override; std::string tokenToString(Token id) const override; Token sampleToken(PromptContext &ctx) const override; diff --git a/gpt4all-backend/src/llmodel_c.cpp b/gpt4all-backend/src/llmodel_c.cpp index f3fd68ff..b0974223 100644 --- a/gpt4all-backend/src/llmodel_c.cpp +++ b/gpt4all-backend/src/llmodel_c.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include struct LLModelWrapper { @@ -130,13 +131,10 @@ void llmodel_prompt(llmodel_model model, const char *prompt, wrapper->promptContext.repeat_last_n = ctx->repeat_last_n; wrapper->promptContext.contextErase = ctx->context_erase; - std::string fake_reply_str; - if (fake_reply) { fake_reply_str = fake_reply; } - auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr; - // Call the C++ prompt method wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift, - wrapper->promptContext, special, fake_reply_p); + wrapper->promptContext, special, + fake_reply ? std::make_optional(fake_reply) : std::nullopt); // Update the C context by giving access to the wrappers raw pointers to std::vector data // which involves no copies diff --git a/gpt4all-backend/src/llmodel_shared.cpp b/gpt4all-backend/src/llmodel_shared.cpp index 570f62c6..b4d5accc 100644 --- a/gpt4all-backend/src/llmodel_shared.cpp +++ b/gpt4all-backend/src/llmodel_shared.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace ranges = std::ranges; @@ -45,7 +46,7 @@ void LLModel::prompt(const std::string &prompt, bool allowContextShift, PromptContext &promptCtx, bool special, - std::string *fakeReply) + std::optional fakeReply) { if (!isModelLoaded()) { std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n"; @@ -129,11 +130,11 @@ void LLModel::prompt(const std::string &prompt, return; // error // decode the assistant's reply, either generated or spoofed - if (fakeReply == nullptr) { + if (!fakeReply) { generateResponse(responseCallback, allowContextShift, promptCtx); } else { embd_inp = tokenize(promptCtx, *fakeReply, false); - if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp)) + if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, true)) return; // error } @@ -157,7 +158,8 @@ bool LLModel::decodePrompt(std::function promptCallback, std::function responseCallback, bool allowContextShift, PromptContext &promptCtx, - std::vector embd_inp) { + std::vector embd_inp, + bool isResponse) { if ((int) embd_inp.size() > promptCtx.n_ctx - 4) { responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed."); std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() << @@ -196,7 +198,9 @@ bool LLModel::decodePrompt(std::function promptCallback, for (size_t t = 0; t < tokens; ++t) { promptCtx.tokens.push_back(batch.at(t)); promptCtx.n_past += 1; - if (!promptCallback(batch.at(t))) + Token tok = batch.at(t); + bool res = isResponse ? 
responseCallback(tok, tokenToString(tok)) : promptCallback(tok); + if (!res) return false; } i = batch_end; diff --git a/gpt4all-chat/CHANGELOG.md b/gpt4all-chat/CHANGELOG.md index c774a040..91b42f51 100644 --- a/gpt4all-chat/CHANGELOG.md +++ b/gpt4all-chat/CHANGELOG.md @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). - Fix a typo in Model Settings (by [@3Simplex](https://github.com/3Simplex) in [#2916](https://github.com/nomic-ai/gpt4all/pull/2916)) - Fix the antenna icon tooltip when using the local server ([#2922](https://github.com/nomic-ai/gpt4all/pull/2922)) - Fix a few issues with locating files and handling errors when loading remote models on startup ([#2875](https://github.com/nomic-ai/gpt4all/pull/2875)) +- Significantly improve API server request parsing and response correctness ([#2929](https://github.com/nomic-ai/gpt4all/pull/2929)) ## [3.2.1] - 2024-08-13 diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index fa70a793..bbc7d9b2 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.16) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD 23) set(CMAKE_CXX_STANDARD_REQUIRED ON) if(APPLE) @@ -64,6 +64,12 @@ message(STATUS "Qt 6 root directory: ${Qt6_ROOT_DIR}") set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(FMT_INSTALL OFF) +set(BUILD_SHARED_LIBS_SAVED "${BUILD_SHARED_LIBS}") +set(BUILD_SHARED_LIBS OFF) +add_subdirectory(deps/fmt) +set(BUILD_SHARED_LIBS "${BUILD_SHARED_LIBS_SAVED}") + add_subdirectory(../gpt4all-backend llmodel) set(CHAT_EXE_RESOURCES) @@ -240,7 +246,7 @@ else() PRIVATE Qt6::Quick Qt6::Svg Qt6::HttpServer Qt6::Sql Qt6::Pdf) endif() target_link_libraries(chat - PRIVATE llmodel SingleApplication) + PRIVATE llmodel SingleApplication fmt::fmt) # -- install -- diff --git a/gpt4all-chat/deps/fmt b/gpt4all-chat/deps/fmt new file mode 160000 index 00000000..0c9fce2f --- /dev/null +++ b/gpt4all-chat/deps/fmt @@ -0,0 +1 @@ +Subproject commit 0c9fce2ffefecfdce794e1859584e25877b7b592 diff --git a/gpt4all-chat/src/chat.cpp b/gpt4all-chat/src/chat.cpp index d9a66091..dd0bf1ec 100644 --- a/gpt4all-chat/src/chat.cpp +++ b/gpt4all-chat/src/chat.cpp @@ -239,16 +239,17 @@ void Chat::newPromptResponsePair(const QString &prompt) resetResponseState(); m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false); m_chatModel->appendPrompt("Prompt: ", prompt); - m_chatModel->appendResponse("Response: ", prompt); + m_chatModel->appendResponse("Response: ", QString()); emit resetResponseRequested(); } +// the server needs to block until response is reset, so it calls resetResponse on its own m_llmThread void Chat::serverNewPromptResponsePair(const QString &prompt) { resetResponseState(); m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false); m_chatModel->appendPrompt("Prompt: ", prompt); - m_chatModel->appendResponse("Response: ", prompt); + m_chatModel->appendResponse("Response: ", QString()); } bool Chat::restoringFromText() const diff --git a/gpt4all-chat/src/chatapi.cpp b/gpt4all-chat/src/chatapi.cpp index 06594a32..5634e8b2 100644 --- a/gpt4all-chat/src/chatapi.cpp +++ b/gpt4all-chat/src/chatapi.cpp @@ -93,7 +93,7 @@ void ChatAPI::prompt(const std::string &prompt, bool allowContextShift, PromptContext &promptCtx, bool special, - std::string *fakeReply) { + std::optional fakeReply) { Q_UNUSED(promptCallback); Q_UNUSED(allowContextShift); @@ -121,7 +121,7 @@ void 
ChatAPI::prompt(const std::string &prompt, if (fakeReply) { promptCtx.n_past += 1; m_context.append(formattedPrompt); - m_context.append(QString::fromStdString(*fakeReply)); + m_context.append(QString::fromUtf8(fakeReply->data(), fakeReply->size())); return; } diff --git a/gpt4all-chat/src/chatapi.h b/gpt4all-chat/src/chatapi.h index 724178de..a5e1ad58 100644 --- a/gpt4all-chat/src/chatapi.h +++ b/gpt4all-chat/src/chatapi.h @@ -12,9 +12,10 @@ #include #include -#include #include +#include #include +#include #include class QNetworkAccessManager; @@ -72,7 +73,7 @@ public: bool allowContextShift, PromptContext &ctx, bool special, - std::string *fakeReply) override; + std::optional fakeReply) override; void setThreadCount(int32_t n_threads) override; int32_t threadCount() const override; @@ -97,7 +98,7 @@ protected: // them as they are only called from the default implementation of 'prompt' which we override and // completely replace - std::vector tokenize(PromptContext &ctx, const std::string &str, bool special) override + std::vector tokenize(PromptContext &ctx, std::string_view str, bool special) override { (void)ctx; (void)str; diff --git a/gpt4all-chat/src/chatllm.cpp b/gpt4all-chat/src/chatllm.cpp index fd9316f5..a81d49bd 100644 --- a/gpt4all-chat/src/chatllm.cpp +++ b/gpt4all-chat/src/chatllm.cpp @@ -626,16 +626,16 @@ void ChatLLM::regenerateResponse() m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end()); m_promptResponseTokens = 0; m_promptTokens = 0; - m_response = std::string(); - emit responseChanged(QString::fromStdString(m_response)); + m_response = m_trimmedResponse = std::string(); + emit responseChanged(QString::fromStdString(m_trimmedResponse)); } void ChatLLM::resetResponse() { m_promptTokens = 0; m_promptResponseTokens = 0; - m_response = std::string(); - emit responseChanged(QString::fromStdString(m_response)); + m_response = m_trimmedResponse = std::string(); + emit responseChanged(QString::fromStdString(m_trimmedResponse)); } void ChatLLM::resetContext() @@ -645,9 +645,12 @@ void ChatLLM::resetContext() m_ctx = LLModel::PromptContext(); } -QString ChatLLM::response() const +QString ChatLLM::response(bool trim) const { - return QString::fromStdString(remove_leading_whitespace(m_response)); + std::string resp = m_response; + if (trim) + resp = remove_leading_whitespace(resp); + return QString::fromStdString(resp); } ModelInfo ChatLLM::modelInfo() const @@ -705,7 +708,8 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response) // check for error if (token < 0) { m_response.append(response); - emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response))); + m_trimmedResponse = remove_leading_whitespace(m_response); + emit responseChanged(QString::fromStdString(m_trimmedResponse)); return false; } @@ -715,7 +719,8 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response) m_timer->inc(); Q_ASSERT(!response.empty()); m_response.append(response); - emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response))); + m_trimmedResponse = remove_leading_whitespace(m_response); + emit responseChanged(QString::fromStdString(m_trimmedResponse)); return !m_stopGenerating; } @@ -741,7 +746,7 @@ bool ChatLLM::prompt(const QList &collectionList, const QString &prompt bool ChatLLM::promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float 
repeat_penalty, - int32_t repeat_penalty_tokens) + int32_t repeat_penalty_tokens, std::optional fakeReply) { if (!isModelLoaded()) return false; @@ -751,7 +756,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString QList databaseResults; const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize(); - if (!collectionList.isEmpty()) { + if (!fakeReply && !collectionList.isEmpty()) { emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks emit databaseResultsChanged(databaseResults); } @@ -797,7 +802,8 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString m_ctx.n_predict = old_n_predict; // now we are ready for a response } m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, - /*allowContextShift*/ true, m_ctx); + /*allowContextShift*/ true, m_ctx, false, + fakeReply.transform(std::mem_fn(&QString::toStdString))); #if defined(DEBUG) printf("\n"); fflush(stdout); @@ -805,9 +811,9 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString m_timer->stop(); qint64 elapsed = totalTime.elapsed(); std::string trimmed = trim_whitespace(m_response); - if (trimmed != m_response) { - m_response = trimmed; - emit responseChanged(QString::fromStdString(m_response)); + if (trimmed != m_trimmedResponse) { + m_trimmedResponse = trimmed; + emit responseChanged(QString::fromStdString(m_trimmedResponse)); } SuggestionMode mode = MySettings::globalInstance()->suggestionMode(); @@ -1078,6 +1084,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, QString response; stream >> response; m_response = response.toStdString(); + m_trimmedResponse = trim_whitespace(m_response); QString nameResponse; stream >> nameResponse; m_nameResponse = nameResponse.toStdString(); @@ -1306,10 +1313,9 @@ void ChatLLM::processRestoreStateFromText() auto &response = *it++; Q_ASSERT(response.first != "Prompt: "); - auto responseText = response.second.toStdString(); m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr, - /*allowContextShift*/ true, m_ctx, false, &responseText); + /*allowContextShift*/ true, m_ctx, false, response.second.toUtf8().constData()); } if (!m_stopGenerating) { diff --git a/gpt4all-chat/src/chatllm.h b/gpt4all-chat/src/chatllm.h index eb8d044f..62d83753 100644 --- a/gpt4all-chat/src/chatllm.h +++ b/gpt4all-chat/src/chatllm.h @@ -116,7 +116,7 @@ public: void setForceUnloadModel(bool b) { m_forceUnloadModel = b; } void setMarkedForDeletion(bool b) { m_markedForDeletion = b; } - QString response() const; + QString response(bool trim = true) const; ModelInfo modelInfo() const; void setModelInfo(const ModelInfo &info); @@ -198,7 +198,7 @@ Q_SIGNALS: protected: bool promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, - int32_t repeat_penalty_tokens); + int32_t repeat_penalty_tokens, std::optional fakeReply = {}); bool handlePrompt(int32_t token); bool handleResponse(int32_t token, const std::string &response); bool handleNamePrompt(int32_t token); @@ -221,6 +221,7 @@ private: bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps); std::string m_response; + std::string m_trimmedResponse; std::string m_nameResponse; QString m_questionResponse; LLModelInfo m_llModelInfo; diff --git 
a/gpt4all-chat/src/localdocsmodel.h b/gpt4all-chat/src/localdocsmodel.h index 82b5f882..ddce8963 100644 --- a/gpt4all-chat/src/localdocsmodel.h +++ b/gpt4all-chat/src/localdocsmodel.h @@ -20,24 +20,25 @@ class LocalDocsCollectionsModel : public QSortFilterProxyModel Q_OBJECT Q_PROPERTY(int count READ count NOTIFY countChanged) Q_PROPERTY(int updatingCount READ updatingCount NOTIFY updatingCountChanged) + public: explicit LocalDocsCollectionsModel(QObject *parent); + int count() const { return rowCount(); } + int updatingCount() const; public Q_SLOTS: - int count() const { return rowCount(); } void setCollections(const QList &collections); - int updatingCount() const; Q_SIGNALS: void countChanged(); void updatingCountChanged(); -private Q_SLOT: - void maybeTriggerUpdatingCountChanged(); - protected: bool filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const override; +private Q_SLOTS: + void maybeTriggerUpdatingCountChanged(); + private: QList m_collections; int m_updatingCount = 0; diff --git a/gpt4all-chat/src/modellist.h b/gpt4all-chat/src/modellist.h index 21d9aeef..6123dde8 100644 --- a/gpt4all-chat/src/modellist.h +++ b/gpt4all-chat/src/modellist.h @@ -18,10 +18,12 @@ #include #include #include -#include + +#include using namespace Qt::Literals::StringLiterals; + struct ModelInfo { Q_GADGET Q_PROPERTY(QString id READ id WRITE setId) @@ -523,7 +525,7 @@ private: protected: explicit ModelList(); - ~ModelList() { for (auto *model: m_models) { delete model; } } + ~ModelList() override { for (auto *model: std::as_const(m_models)) { delete model; } } friend class MyModelList; }; diff --git a/gpt4all-chat/src/mysettings.h b/gpt4all-chat/src/mysettings.h index 3db8b234..85335f0b 100644 --- a/gpt4all-chat/src/mysettings.h +++ b/gpt4all-chat/src/mysettings.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include diff --git a/gpt4all-chat/src/server.cpp b/gpt4all-chat/src/server.cpp index 1da962f5..9d5c9583 100644 --- a/gpt4all-chat/src/server.cpp +++ b/gpt4all-chat/src/server.cpp @@ -4,7 +4,13 @@ #include "modellist.h" #include "mysettings.h" +#include +#include + #include +#include +#include +#include #include #include #include @@ -14,19 +20,67 @@ #include #include #include +#include #include +#include #include +#include #include +#include +#include #include +#include +#include #include #include +#include #include +namespace ranges = std::ranges; +using namespace std::string_literals; using namespace Qt::Literals::StringLiterals; //#define DEBUG + +#define MAKE_FORMATTER(type, conversion) \ + template <> \ + struct fmt::formatter: fmt::formatter { \ + template \ + FmtContext::iterator format(const type &value, FmtContext &ctx) const \ + { \ + return formatter::format(conversion, ctx); \ + } \ + } + +MAKE_FORMATTER(QString, value.toStdString() ); +MAKE_FORMATTER(QVariant, value.toString().toStdString()); + +namespace { + +class InvalidRequestError: public std::invalid_argument { + using std::invalid_argument::invalid_argument; + +public: + QHttpServerResponse asResponse() const + { + QJsonObject error { + { "message", what(), }, + { "type", u"invalid_request_error"_s, }, + { "param", QJsonValue::Null }, + { "code", QJsonValue::Null }, + }; + return { QJsonObject {{ "error", error }}, + QHttpServerResponder::StatusCode::BadRequest }; + } + +private: + Q_DISABLE_COPY_MOVE(InvalidRequestError) +}; + +} // namespace + static inline QJsonObject modelToJson(const ModelInfo &info) { QJsonObject model; @@ -39,7 +93,7 @@ static inline QJsonObject modelToJson(const 
ModelInfo &info) QJsonArray permissions; QJsonObject permissionObj; - permissionObj.insert("id", "foobarbaz"); + permissionObj.insert("id", "placeholder"); permissionObj.insert("object", "model_permission"); permissionObj.insert("created", 0); permissionObj.insert("allow_create_engine", false); @@ -70,6 +124,328 @@ static inline QJsonObject resultToJson(const ResultInfo &info) return result; } +class BaseCompletionRequest { +public: + QString model; // required + // NB: some parameters are not supported yet + int32_t max_tokens = 16; + qint64 n = 1; + float temperature = 1.f; + float top_p = 1.f; + float min_p = 0.f; + + BaseCompletionRequest() = default; + virtual ~BaseCompletionRequest() = default; + + virtual BaseCompletionRequest &parse(QCborMap request) + { + parseImpl(request); + if (!request.isEmpty()) + throw InvalidRequestError(fmt::format( + "Unrecognized request argument supplied: {}", request.keys().constFirst().toString() + )); + return *this; + } + +protected: + virtual void parseImpl(QCborMap &request) + { + using enum Type; + + auto reqValue = [&request](auto &&...args) { return takeValue(request, args...); }; + QCborValue value; + + this->model = reqValue("model", String, /*required*/ true).toString(); + + value = reqValue("frequency_penalty", Number, false, /*min*/ -2, /*max*/ 2); + if (value.isDouble() || value.toInteger() != 0) + throw InvalidRequestError("'frequency_penalty' is not supported"); + + value = reqValue("max_tokens", Integer, false, /*min*/ 1); + if (!value.isNull()) + this->max_tokens = int32_t(qMin(value.toInteger(), INT32_MAX)); + + value = reqValue("n", Integer, false, /*min*/ 1); + if (!value.isNull()) + this->n = value.toInteger(); + + value = reqValue("presence_penalty", Number); + if (value.isDouble() || value.toInteger() != 0) + throw InvalidRequestError("'presence_penalty' is not supported"); + + value = reqValue("seed", Integer); + if (!value.isNull()) + throw InvalidRequestError("'seed' is not supported"); + + value = reqValue("stop"); + if (!value.isNull()) + throw InvalidRequestError("'stop' is not supported"); + + value = reqValue("stream", Boolean); + if (value.isTrue()) + throw InvalidRequestError("'stream' is not supported"); + + value = reqValue("stream_options", Object); + if (!value.isNull()) + throw InvalidRequestError("'stream_options' is not supported"); + + value = reqValue("temperature", Number, false, /*min*/ 0, /*max*/ 2); + if (!value.isNull()) + this->temperature = float(value.toDouble()); + + value = reqValue("top_p", Number, /*min*/ 0, /*max*/ 1); + if (!value.isNull()) + this->top_p = float(value.toDouble()); + + value = reqValue("min_p", Number, /*min*/ 0, /*max*/ 1); + if (!value.isNull()) + this->min_p = float(value.toDouble()); + + reqValue("user", String); // validate but don't use + } + + enum class Type : uint8_t { + Boolean, + Integer, + Number, + String, + Array, + Object, + }; + + static const std::unordered_map s_typeNames; + + static bool typeMatches(const QCborValue &value, Type type) noexcept { + using enum Type; + switch (type) { + case Boolean: return value.isBool(); + case Integer: return value.isInteger(); + case Number: return value.isInteger() || value.isDouble(); + case String: return value.isString(); + case Array: return value.isArray(); + case Object: return value.isMap(); + } + Q_UNREACHABLE(); + } + + static QCborValue takeValue( + QCborMap &obj, const char *key, std::optional type = {}, bool required = false, + std::optional min = {}, std::optional max = {} + ) { + auto value = 
obj.take(QLatin1StringView(key)); + if (value.isUndefined()) + value = QCborValue(QCborSimpleType::Null); + if (required && value.isNull()) + throw InvalidRequestError(fmt::format("you must provide a {} parameter", key)); + if (type && !value.isNull() && !typeMatches(value, *type)) + throw InvalidRequestError(fmt::format("'{}' is not of type '{}' - '{}'", + value.toVariant(), s_typeNames.at(*type), key)); + if (!value.isNull()) { + double num = value.toDouble(); + if (min && num < double(*min)) + throw InvalidRequestError(fmt::format("{} is less than the minimum of {} - '{}'", num, *min, key)); + if (max && num > double(*max)) + throw InvalidRequestError(fmt::format("{} is greater than the maximum of {} - '{}'", num, *max, key)); + } + return value; + } + +private: + Q_DISABLE_COPY_MOVE(BaseCompletionRequest) +}; + +class CompletionRequest : public BaseCompletionRequest { +public: + QString prompt; // required + // some parameters are not supported yet - these ones are + bool echo = false; + + CompletionRequest &parse(QCborMap request) override + { + BaseCompletionRequest::parse(std::move(request)); + return *this; + } + +protected: + void parseImpl(QCborMap &request) override + { + using enum Type; + + auto reqValue = [&request](auto &&...args) { return takeValue(request, args...); }; + QCborValue value; + + BaseCompletionRequest::parseImpl(request); + + this->prompt = reqValue("prompt", String, /*required*/ true).toString(); + + value = reqValue("best_of", Integer); + { + qint64 bof = value.toInteger(1); + if (this->n > bof) + throw InvalidRequestError(fmt::format( + "You requested that the server return more choices than it will generate (HINT: you must set 'n' " + "(currently {}) to be at most 'best_of' (currently {}), or omit either parameter if you don't " + "specifically want to use them.)", + this->n, bof + )); + if (bof > this->n) + throw InvalidRequestError("'best_of' is not supported"); + } + + value = reqValue("echo", Boolean); + if (value.isBool()) + this->echo = value.toBool(); + + // we don't bother deeply typechecking unsupported subobjects for now + value = reqValue("logit_bias", Object); + if (!value.isNull()) + throw InvalidRequestError("'logit_bias' is not supported"); + + value = reqValue("logprobs", Integer, false, /*min*/ 0); + if (!value.isNull()) + throw InvalidRequestError("'logprobs' is not supported"); + + value = reqValue("suffix", String); + if (!value.isNull() && !value.toString().isEmpty()) + throw InvalidRequestError("'suffix' is not supported"); + } +}; + +const std::unordered_map BaseCompletionRequest::s_typeNames = { + { BaseCompletionRequest::Type::Boolean, "boolean" }, + { BaseCompletionRequest::Type::Integer, "integer" }, + { BaseCompletionRequest::Type::Number, "number" }, + { BaseCompletionRequest::Type::String, "string" }, + { BaseCompletionRequest::Type::Array, "array" }, + { BaseCompletionRequest::Type::Object, "object" }, +}; + +class ChatRequest : public BaseCompletionRequest { +public: + struct Message { + enum class Role : uint8_t { + User, + Assistant, + }; + Role role; + QString content; + }; + + QList messages; // required + + ChatRequest &parse(QCborMap request) override + { + BaseCompletionRequest::parse(std::move(request)); + return *this; + } + +protected: + void parseImpl(QCborMap &request) override + { + using enum Type; + + auto reqValue = [&request](auto &&...args) { return takeValue(request, args...); }; + QCborValue value; + + BaseCompletionRequest::parseImpl(request); + + value = reqValue("messages", std::nullopt, /*required*/ 
true); + if (!value.isArray() || value.toArray().isEmpty()) + throw InvalidRequestError(fmt::format( + "Invalid type for 'messages': expected a non-empty array of objects, but got '{}' instead.", + value.toVariant() + )); + + this->messages.clear(); + { + QCborArray arr = value.toArray(); + Message::Role nextRole = Message::Role::User; + for (qsizetype i = 0; i < arr.size(); i++) { + const auto &elem = arr[i]; + if (!elem.isMap()) + throw InvalidRequestError(fmt::format( + "Invalid type for 'messages[{}]': expected an object, but got '{}' instead.", + i, elem.toVariant() + )); + QCborMap msg = elem.toMap(); + Message res; + QString role = takeValue(msg, "role", String, /*required*/ true).toString(); + if (role == u"system"_s) + continue; // FIXME(jared): don't ignore these + if (role == u"user"_s) { + res.role = Message::Role::User; + } else if (role == u"assistant"_s) { + res.role = Message::Role::Assistant; + } else { + throw InvalidRequestError(fmt::format( + "Invalid 'messages[{}].role': expected one of 'system', 'assistant', or 'user', but got '{}'" + " instead.", + i, role.toStdString() + )); + } + res.content = takeValue(msg, "content", String, /*required*/ true).toString(); + if (res.role != nextRole) + throw InvalidRequestError(fmt::format( + "Invalid 'messages[{}].role': did not expect '{}' here", i, role + )); + this->messages.append(res); + nextRole = res.role == Message::Role::User ? Message::Role::Assistant + : Message::Role::User; + + if (!msg.isEmpty()) + throw InvalidRequestError(fmt::format( + "Invalid 'messages[{}]': unrecognized key: '{}'", i, msg.keys().constFirst().toString() + )); + } + } + + // we don't bother deeply typechecking unsupported subobjects for now + value = reqValue("logit_bias", Object); + if (!value.isNull()) + throw InvalidRequestError("'logit_bias' is not supported"); + + value = reqValue("logprobs", Boolean); + if (value.isTrue()) + throw InvalidRequestError("'logprobs' is not supported"); + + value = reqValue("top_logprobs", Integer, false, /*min*/ 0); + if (!value.isNull()) + throw InvalidRequestError("The 'top_logprobs' parameter is only allowed when 'logprobs' is enabled."); + + value = reqValue("response_format", Object); + if (!value.isNull()) + throw InvalidRequestError("'response_format' is not supported"); + + reqValue("service_tier", String); // validate but don't use + + value = reqValue("tools", Array); + if (!value.isNull()) + throw InvalidRequestError("'tools' is not supported"); + + value = reqValue("tool_choice"); + if (!value.isNull()) + throw InvalidRequestError("'tool_choice' is not supported"); + + // validate but don't use + reqValue("parallel_tool_calls", Boolean); + + value = reqValue("function_call"); + if (!value.isNull()) + throw InvalidRequestError("'function_call' is not supported"); + + value = reqValue("functions", Array); + if (!value.isNull()) + throw InvalidRequestError("'functions' is not supported"); + } +}; + +template +T &parseRequest(T &request, QJsonObject &&obj) +{ + // lossless conversion to CBOR exposes more type information + return request.parse(QCborMap::fromJsonObject(obj)); +} + Server::Server(Chat *chat) : ChatLLM(chat, true /*isServer*/) , m_chat(chat) @@ -80,20 +456,28 @@ Server::Server(Chat *chat) connect(chat, &Chat::collectionListChanged, this, &Server::handleCollectionListChanged, Qt::QueuedConnection); } -Server::~Server() +static QJsonObject requestFromJson(const QByteArray &request) { + QJsonParseError err; + const QJsonDocument document = QJsonDocument::fromJson(request, &err); + if 
(err.error || !document.isObject()) + throw InvalidRequestError(fmt::format( + "error parsing request JSON: {}", + err.error ? err.errorString().toStdString() : "not an object"s + )); + return document.object(); } void Server::start() { - m_server = new QHttpServer(this); + m_server = std::make_unique(this); if (!m_server->listen(QHostAddress::LocalHost, MySettings::globalInstance()->networkPort())) { qWarning() << "ERROR: Unable to start the server"; return; } m_server->route("/v1/models", QHttpServerRequest::Method::Get, - [](const QHttpServerRequest &request) { + [](const QHttpServerRequest &) { if (!MySettings::globalInstance()->serverChat()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); @@ -113,7 +497,7 @@ void Server::start() ); m_server->route("/v1/models/", QHttpServerRequest::Method::Get, - [](const QString &model, const QHttpServerRequest &request) { + [](const QString &model, const QHttpServerRequest &) { if (!MySettings::globalInstance()->serverChat()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); @@ -137,7 +521,23 @@ void Server::start() [this](const QHttpServerRequest &request) { if (!MySettings::globalInstance()->serverChat()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); - return handleCompletionRequest(request, false); + + try { + auto reqObj = requestFromJson(request.body()); +#if defined(DEBUG) + qDebug().noquote() << "/v1/completions request" << QJsonDocument(reqObj).toJson(QJsonDocument::Indented); +#endif + CompletionRequest req; + parseRequest(req, std::move(reqObj)); + auto [resp, respObj] = handleCompletionRequest(req); +#if defined(DEBUG) + if (respObj) + qDebug().noquote() << "/v1/completions reply" << QJsonDocument(*respObj).toJson(QJsonDocument::Indented); +#endif + return std::move(resp); + } catch (const InvalidRequestError &e) { + return e.asResponse(); + } } ); @@ -145,13 +545,30 @@ void Server::start() [this](const QHttpServerRequest &request) { if (!MySettings::globalInstance()->serverChat()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); - return handleCompletionRequest(request, true); + + try { + auto reqObj = requestFromJson(request.body()); +#if defined(DEBUG) + qDebug().noquote() << "/v1/chat/completions request" << QJsonDocument(reqObj).toJson(QJsonDocument::Indented); +#endif + ChatRequest req; + parseRequest(req, std::move(reqObj)); + auto [resp, respObj] = handleChatRequest(req); + (void)respObj; +#if defined(DEBUG) + if (respObj) + qDebug().noquote() << "/v1/chat/completions reply" << QJsonDocument(*respObj).toJson(QJsonDocument::Indented); +#endif + return std::move(resp); + } catch (const InvalidRequestError &e) { + return e.asResponse(); + } } ); // Respond with code 405 to wrong HTTP methods: m_server->route("/v1/models", QHttpServerRequest::Method::Post, - [](const QHttpServerRequest &request) { + [] { if (!MySettings::globalInstance()->serverChat()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); return QHttpServerResponse( @@ -163,7 +580,8 @@ void Server::start() ); m_server->route("/v1/models/", QHttpServerRequest::Method::Post, - [](const QString &model, const QHttpServerRequest &request) { + [](const QString &model) { + (void)model; if (!MySettings::globalInstance()->serverChat()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); return QHttpServerResponse( @@ -175,7 +593,7 @@ void Server::start() ); m_server->route("/v1/completions", 
QHttpServerRequest::Method::Get, - [](const QHttpServerRequest &request) { + [] { if (!MySettings::globalInstance()->serverChat()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); return QHttpServerResponse( @@ -186,7 +604,7 @@ void Server::start() ); m_server->route("/v1/chat/completions", QHttpServerRequest::Method::Get, - [](const QHttpServerRequest &request) { + [] { if (!MySettings::globalInstance()->serverChat()) return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized); return QHttpServerResponse( @@ -205,268 +623,261 @@ void Server::start() &Chat::serverNewPromptResponsePair, Qt::BlockingQueuedConnection); } -QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &request, bool isChat) +static auto makeError(auto &&...args) -> std::pair> { - // We've been asked to do a completion... - QJsonParseError err; - const QJsonDocument document = QJsonDocument::fromJson(request.body(), &err); - if (err.error || !document.isObject()) { - std::cerr << "ERROR: invalid json in completions body" << std::endl; - return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent); - } -#if defined(DEBUG) - printf("/v1/completions %s\n", qPrintable(document.toJson(QJsonDocument::Indented))); - fflush(stdout); -#endif - const QJsonObject body = document.object(); - if (!body.contains("model")) { // required - std::cerr << "ERROR: completions contains no model" << std::endl; - return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent); - } - QJsonArray messages; - if (isChat) { - if (!body.contains("messages")) { - std::cerr << "ERROR: chat completions contains no messages" << std::endl; - return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent); - } - messages = body["messages"].toArray(); - } + return {QHttpServerResponse(args...), std::nullopt}; +} - const QString modelRequested = body["model"].toString(); +auto Server::handleCompletionRequest(const CompletionRequest &request) + -> std::pair> +{ ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo(); const QList modelList = ModelList::globalInstance()->selectableModelList(); for (const ModelInfo &info : modelList) { Q_ASSERT(info.installed); if (!info.installed) continue; - if (modelRequested == info.name() || modelRequested == info.filename()) { + if (request.model == info.name() || request.model == info.filename()) { modelInfo = info; break; } } - // We only support one prompt for now - QList prompts; - if (body.contains("prompt")) { - QJsonValue promptValue = body["prompt"]; - if (promptValue.isString()) - prompts.append(promptValue.toString()); - else { - QJsonArray array = promptValue.toArray(); - for (const QJsonValue &v : array) - prompts.append(v.toString()); - } - } else - prompts.append(" "); - - int max_tokens = 16; - if (body.contains("max_tokens")) - max_tokens = body["max_tokens"].toInt(); - - float temperature = 1.f; - if (body.contains("temperature")) - temperature = body["temperature"].toDouble(); - - float top_p = 1.f; - if (body.contains("top_p")) - top_p = body["top_p"].toDouble(); - - float min_p = 0.f; - if (body.contains("min_p")) - min_p = body["min_p"].toDouble(); - - int n = 1; - if (body.contains("n")) - n = body["n"].toInt(); - - int logprobs = -1; // supposed to be null by default?? - if (body.contains("logprobs")) - logprobs = body["logprobs"].toInt(); - - bool echo = false; - if (body.contains("echo")) - echo = body["echo"].toBool(); - - // We currently don't support any of the following... 
-#if 0
-    // FIXME: Need configurable reverse prompts
-    QList<QString> stop;
-    if (body.contains("stop")) {
-        QJsonValue stopValue = body["stop"];
-        if (stopValue.isString())
-            stop.append(stopValue.toString());
-        else {
-            QJsonArray array = stopValue.toArray();
-            for (QJsonValue v : array)
-                stop.append(v.toString());
-        }
-    }
-
-    // FIXME: QHttpServer doesn't support server-sent events
-    bool stream = false;
-    if (body.contains("stream"))
-        stream = body["stream"].toBool();
-
-    // FIXME: What does this do?
-    QString suffix;
-    if (body.contains("suffix"))
-        suffix = body["suffix"].toString();
-
-    // FIXME: We don't support
-    float presence_penalty = 0.f;
-    if (body.contains("presence_penalty"))
-        top_p = body["presence_penalty"].toDouble();
-
-    // FIXME: We don't support
-    float frequency_penalty = 0.f;
-    if (body.contains("frequency_penalty"))
-        top_p = body["frequency_penalty"].toDouble();
-
-    // FIXME: We don't support
-    int best_of = 1;
-    if (body.contains("best_of"))
-        logprobs = body["best_of"].toInt();
-
-    // FIXME: We don't need
-    QString user;
-    if (body.contains("user"))
-        suffix = body["user"].toString();
-#endif
-
-    QString actualPrompt = prompts.first();
-
-    // if we're a chat completion we have messages which means we need to prepend these to the prompt
-    if (!messages.isEmpty()) {
-        QList<QString> chats;
-        for (int i = 0; i < messages.count(); ++i) {
-            QJsonValue v = messages.at(i);
-            // FIXME: Deal with system messages correctly
-            QString role = v.toObject()["role"].toString();
-            if (role != "user")
-                continue;
-            QString content = v.toObject()["content"].toString();
-            if (!content.endsWith("\n") && i < messages.count() - 1)
-                content += "\n";
-            chats.append(content);
-        }
-        actualPrompt.prepend(chats.join("\n"));
-    }
-
     // adds prompt/response items to GUI
-    emit requestServerNewPromptResponsePair(actualPrompt); // blocks
+    emit requestServerNewPromptResponsePair(request.prompt); // blocks
+    resetResponse();

     // load the new model if necessary
     setShouldBeLoaded(true);

     if (modelInfo.filename().isEmpty()) {
-        std::cerr << "ERROR: couldn't load default model " << modelRequested.toStdString() << std::endl;
-        return QHttpServerResponse(QHttpServerResponder::StatusCode::BadRequest);
+        std::cerr << "ERROR: couldn't load default model " << request.model.toStdString() << std::endl;
+        return makeError(QHttpServerResponder::StatusCode::InternalServerError);
     }

     // NB: this resets the context, regardless of whether this model is already loaded
     if (!loadModel(modelInfo)) {
         std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
-        return QHttpServerResponse(QHttpServerResponder::StatusCode::InternalServerError);
+        return makeError(QHttpServerResponder::StatusCode::InternalServerError);
     }

-    const QString promptTemplate = modelInfo.promptTemplate();
-    const float top_k = modelInfo.topK();
-    const int n_batch = modelInfo.promptBatchSize();
-    const float repeat_penalty = modelInfo.repeatPenalty();
-    const int repeat_last_n = modelInfo.repeatPenaltyTokens();
+    // FIXME(jared): taking parameters from the UI inhibits reproducibility of results
+    const int top_k = modelInfo.topK();
+    const int n_batch = modelInfo.promptBatchSize();
+    const auto repeat_penalty = float(modelInfo.repeatPenalty());
+    const int repeat_last_n = modelInfo.repeatPenaltyTokens();

     int promptTokens = 0;
     int responseTokens = 0;
     QList<QPair<QString, QList<ResultInfo>>> responses;
-    for (int i = 0; i < n; ++i) {
+    for (int i = 0; i < request.n; ++i) {
         if (!promptInternal(
                 m_collections,
-                actualPrompt,
-                promptTemplate,
-                max_tokens /*n_predict*/,
+                request.prompt,
+                /*promptTemplate*/ u"%1"_s,
+                request.max_tokens,
                 top_k,
-                top_p,
-                min_p,
-                temperature,
+                request.top_p,
+                request.min_p,
+                request.temperature,
                 n_batch,
                 repeat_penalty,
                 repeat_last_n)) {
             std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
-            return QHttpServerResponse(QHttpServerResponder::StatusCode::InternalServerError);
+            return makeError(QHttpServerResponder::StatusCode::InternalServerError);
         }

-        QString echoedPrompt = actualPrompt;
-        if (!echoedPrompt.endsWith("\n"))
-            echoedPrompt += "\n";
-        responses.append(qMakePair((echo ? u"%1\n"_s.arg(actualPrompt) : QString()) + response(), m_databaseResults));
+        QString resp = response(/*trim*/ false);
+        if (request.echo)
+            resp = request.prompt + resp;
+        responses.append({resp, m_databaseResults});
         if (!promptTokens)
-            promptTokens += m_promptTokens;
+            promptTokens = m_promptTokens;
         responseTokens += m_promptResponseTokens - m_promptTokens;
-        if (i != n - 1)
+        if (i < request.n - 1)
             resetResponse();
     }

-    QJsonObject responseObject;
-    responseObject.insert("id", "foobarbaz");
-    responseObject.insert("object", "text_completion");
-    responseObject.insert("created", QDateTime::currentSecsSinceEpoch());
-    responseObject.insert("model", modelInfo.name());
+    QJsonObject responseObject {
+        { "id", "placeholder" },
+        { "object", "text_completion" },
+        { "created", QDateTime::currentSecsSinceEpoch() },
+        { "model", modelInfo.name() },
+    };

     QJsonArray choices;
-
-    if (isChat) {
+    {
         int index = 0;
         for (const auto &r : responses) {
             QString result = r.first;
             QList<ResultInfo> infos = r.second;
-            QJsonObject choice;
-            choice.insert("index", index++);
-            choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
-            QJsonObject message;
-            message.insert("role", "assistant");
-            message.insert("content", result);
-            choice.insert("message", message);
+            QJsonObject choice {
+                { "text", result },
+                { "index", index++ },
+                { "logprobs", QJsonValue::Null },
+                { "finish_reason", responseTokens == request.max_tokens ? "length" : "stop" },
+            };
             if (MySettings::globalInstance()->localDocsShowReferences()) {
                 QJsonArray references;
                 for (const auto &ref : infos)
                     references.append(resultToJson(ref));
-                choice.insert("references", references);
-            }
-            choices.append(choice);
-        }
-    } else {
-        int index = 0;
-        for (const auto &r : responses) {
-            QString result = r.first;
-            QList<ResultInfo> infos = r.second;
-            QJsonObject choice;
-            choice.insert("text", result);
-            choice.insert("index", index++);
-            choice.insert("logprobs", QJsonValue::Null); // We don't support
-            choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
-            if (MySettings::globalInstance()->localDocsShowReferences()) {
-                QJsonArray references;
-                for (const auto &ref : infos)
-                    references.append(resultToJson(ref));
-                choice.insert("references", references);
+                choice.insert("references", references.isEmpty() ? QJsonValue::Null : QJsonValue(references));
             }
             choices.append(choice);
         }
     }

     responseObject.insert("choices", choices);
+    responseObject.insert("usage", QJsonObject {
+        { "prompt_tokens", promptTokens },
+        { "completion_tokens", responseTokens },
+        { "total_tokens", promptTokens + responseTokens },
+    });

-    QJsonObject usage;
-    usage.insert("prompt_tokens", int(promptTokens));
-    usage.insert("completion_tokens", int(responseTokens));
-    usage.insert("total_tokens", int(promptTokens + responseTokens));
-    responseObject.insert("usage", usage);
-
-#if defined(DEBUG)
-    QJsonDocument newDoc(responseObject);
-    printf("/v1/completions %s\n", qPrintable(newDoc.toJson(QJsonDocument::Indented)));
-    fflush(stdout);
-#endif
-
-    return QHttpServerResponse(responseObject);
+    return {QHttpServerResponse(responseObject), responseObject};
+}
+
+auto Server::handleChatRequest(const ChatRequest &request)
+    -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
+{
+    ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
+    const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
+    for (const ModelInfo &info : modelList) {
+        Q_ASSERT(info.installed);
+        if (!info.installed)
+            continue;
+        if (request.model == info.name() || request.model == info.filename()) {
+            modelInfo = info;
+            break;
+        }
+    }
+
+    // load the new model if necessary
+    setShouldBeLoaded(true);
+
+    if (modelInfo.filename().isEmpty()) {
+        std::cerr << "ERROR: couldn't load default model " << request.model.toStdString() << std::endl;
+        return makeError(QHttpServerResponder::StatusCode::InternalServerError);
+    }
+
+    // NB: this resets the context, regardless of whether this model is already loaded
+    if (!loadModel(modelInfo)) {
+        std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
+        return makeError(QHttpServerResponder::StatusCode::InternalServerError);
+    }
+
+    const QString promptTemplate = modelInfo.promptTemplate();
+    const int top_k = modelInfo.topK();
+    const int n_batch = modelInfo.promptBatchSize();
+    const auto repeat_penalty = float(modelInfo.repeatPenalty());
+    const int repeat_last_n = modelInfo.repeatPenaltyTokens();
+
+    int promptTokens = 0;
+    int responseTokens = 0;
+    QList<QPair<QString, QList<ResultInfo>>> responses;
+    Q_ASSERT(!request.messages.isEmpty());
+    Q_ASSERT(request.messages.size() % 2 == 1);
+    for (int i = 0; i < request.messages.size() - 2; i += 2) {
+        using enum ChatRequest::Message::Role;
+        auto &user = request.messages[i];
+        auto &assistant = request.messages[i + 1];
+        Q_ASSERT(user.role == User);
+        Q_ASSERT(assistant.role == Assistant);
+
+        // adds prompt/response items to GUI
+        emit requestServerNewPromptResponsePair(user.content); // blocks
+        resetResponse();
+
+        if (!promptInternal(
+                {},
+                user.content,
+                promptTemplate,
+                request.max_tokens,
+                top_k,
+                request.top_p,
+                request.min_p,
+                request.temperature,
+                n_batch,
+                repeat_penalty,
+                repeat_last_n,
+                assistant.content)
+        ) {
+            std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
+            return makeError(QHttpServerResponder::StatusCode::InternalServerError);
+        }
+        promptTokens += m_promptResponseTokens; // previous responses are part of current prompt
+    }
+
+    QString lastMessage = request.messages.last().content;
+    // adds prompt/response items to GUI
+    emit requestServerNewPromptResponsePair(lastMessage); // blocks
+    resetResponse();
+
+    for (int i = 0; i < request.n; ++i) {
+        if (!promptInternal(
+                m_collections,
+                lastMessage,
+                promptTemplate,
+                request.max_tokens,
+                top_k,
+                request.top_p,
+                request.min_p,
+                request.temperature,
+                n_batch,
+                repeat_penalty,
+                repeat_last_n)
+        ) {
+            std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
+            return makeError(QHttpServerResponder::StatusCode::InternalServerError);
+        }
+        responses.append({response(), m_databaseResults});
+        // FIXME(jared): these are UI counts and do not include framing tokens, which they should
+        if (i == 0)
+            promptTokens += m_promptTokens;
+        responseTokens += m_promptResponseTokens - m_promptTokens;
+        if (i != request.n - 1)
+            resetResponse();
+    }
+
+    QJsonObject responseObject {
+        { "id", "placeholder" },
+        { "object", "chat.completion" },
+        { "created", QDateTime::currentSecsSinceEpoch() },
+        { "model", modelInfo.name() },
+    };
+
+    QJsonArray choices;
+    {
+        int index = 0;
+        for (const auto &r : responses) {
+            QString result = r.first;
+            QList<ResultInfo> infos = r.second;
+            QJsonObject message {
+                { "role", "assistant" },
+                { "content", result },
+            };
+            QJsonObject choice {
+                { "index", index++ },
+                { "message", message },
+                { "finish_reason", responseTokens == request.max_tokens ? "length" : "stop" },
+                { "logprobs", QJsonValue::Null },
+            };
+            if (MySettings::globalInstance()->localDocsShowReferences()) {
+                QJsonArray references;
+                for (const auto &ref : infos)
+                    references.append(resultToJson(ref));
+                choice.insert("references", references.isEmpty() ? QJsonValue::Null : QJsonValue(references));
+            }
+            choices.append(choice);
+        }
+    }
+
+    responseObject.insert("choices", choices);
+    responseObject.insert("usage", QJsonObject {
+        { "prompt_tokens", promptTokens },
+        { "completion_tokens", responseTokens },
+        { "total_tokens", promptTokens + responseTokens },
+    });
+
+    return {QHttpServerResponse(responseObject), responseObject};
 }
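For reference, the rewritten handleCompletionRequest() above emits an OpenAI-style /v1/completions body: a top-level object with id/object/created/model, a choices array whose entries carry text, index, logprobs and finish_reason, and a usage block of token counts. The following standalone Qt Core sketch is not part of the patch and uses placeholder values throughout; it only reproduces that response shape so it can be compiled and inspected on its own.

// Standalone sketch only: mirrors the response shape built by the patched
// handleCompletionRequest(); every concrete value below is a placeholder.
#include <QDateTime>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
#include <QJsonValue>
#include <QString>

#include <cstdio>

int main()
{
    const QString result = QStringLiteral("Hello from the model.");
    const int promptTokens = 4, responseTokens = 6, maxTokens = 16;

    QJsonObject responseObject {
        { "id", "placeholder" },
        { "object", "text_completion" },
        { "created", QDateTime::currentSecsSinceEpoch() },
        { "model", "example-model" },
    };

    QJsonArray choices;
    choices.append(QJsonObject {
        { "text", result },
        { "index", 0 },
        { "logprobs", QJsonValue::Null },
        // same rule as the server: hitting max_tokens reports "length", otherwise "stop"
        { "finish_reason", responseTokens == maxTokens ? "length" : "stop" },
    });

    responseObject.insert("choices", choices);
    responseObject.insert("usage", QJsonObject {
        { "prompt_tokens", promptTokens },
        { "completion_tokens", responseTokens },
        { "total_tokens", promptTokens + responseTokens },
    });

    std::printf("%s", QJsonDocument(responseObject).toJson(QJsonDocument::Indented).constData());
    return 0;
}

The chat variant differs only in the object field ("chat.completion") and in wrapping the generated text in a message object with role and content, as handleChatRequest() above shows.
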
diff --git a/gpt4all-chat/src/server.h b/gpt4all-chat/src/server.h
index 689f0b60..a1d46264 100644
--- a/gpt4all-chat/src/server.h
+++ b/gpt4all-chat/src/server.h
@@ -4,22 +4,29 @@
 #include "chatllm.h"
 #include "database.h"

-#include
+#include
 #include
-#include
+#include
 #include
+#include
 #include
+#include
+#include
+#include
+
 class Chat;
-class QHttpServer;
+class ChatRequest;
+class CompletionRequest;
+
 class Server : public ChatLLM
 {
     Q_OBJECT

 public:
-    Server(Chat *parent);
-    virtual ~Server();
+    explicit Server(Chat *chat);
+    ~Server() override = default;

 public Q_SLOTS:
     void start();
@@ -27,14 +34,17 @@ public Q_SLOTS:
 Q_SIGNALS:
     void requestServerNewPromptResponsePair(const QString &prompt);

+private:
+    auto handleCompletionRequest(const CompletionRequest &request) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>;
+    auto handleChatRequest(const ChatRequest &request) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>;
+
 private Q_SLOTS:
-    QHttpServerResponse handleCompletionRequest(const QHttpServerRequest &request, bool isChat);
     void handleDatabaseResultsChanged(const QList<ResultInfo> &results) { m_databaseResults = results; }
     void handleCollectionListChanged(const QList<QString> &collectionList) { m_collections = collectionList; }

 private:
     Chat *m_chat;
-    QHttpServer *m_server;
+    std::unique_ptr<QHttpServer> m_server;
     QList<ResultInfo> m_databaseResults;
     QList<QString> m_collections;
 };
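On the header side, the old handleCompletionRequest(const QHttpServerRequest &, bool isChat) slot is gone; both handlers are now private members that return a std::pair of the QHttpServerResponse to send and an optional QJsonObject copy of its body (empty for error responses such as those produced by makeError). Below is a hedged sketch of how a caller might consume that pair; the helper name logAndUnwrap and the wrapper itself are illustrative rather than taken from the patch, but they recreate the role of the "#if defined(DEBUG)" dump that was removed from server.cpp.

// Illustrative only, not code from this patch. Shows one way a route handler
// could unwrap {response, optional JSON} and keep the old DEBUG dump behaviour.
#include <QHttpServerResponse>
#include <QJsonDocument>
#include <QJsonObject>

#include <cstdio>
#include <optional>
#include <utility>

using HandlerResult = std::pair<QHttpServerResponse, std::optional<QJsonObject>>;

static QHttpServerResponse logAndUnwrap(HandlerResult result, const char *endpoint)
{
    auto &[response, body] = result;
#if defined(DEBUG)
    if (body) {
        // same output the removed block produced inside handleCompletionRequest()
        std::printf("%s %s\n", endpoint,
                    QJsonDocument(*body).toJson(QJsonDocument::Indented).constData());
        std::fflush(stdout);
    }
#else
    (void)endpoint; // unused when DEBUG is not defined
#endif
    return std::move(response); // QHttpServerResponse is move-only
}

Returning the JSON body alongside the response also keeps it available for logging or tests without re-parsing the QHttpServerResponse.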