diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index 04a510dc..0c5f43c3 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -137,6 +137,16 @@ public:
         float contextErase = 0.5f; // percent of context to erase if we exceed the context window
     };
 
+    struct Message {
+        std::string content;
+        std::string role;
+    };
+
+    struct MessageFrame {
+        std::string before;
+        std::string after;
+    };
+
     using ProgressCallback = std::function<bool(float progress)>;
 
     explicit LLModel() {}
@@ -155,8 +165,8 @@ public:
 
     // This method requires the model to return true from supportsCompletion otherwise it will throw
     // an error
-    virtual void prompt(const std::string &prompt,
-                        const std::string &promptTemplate,
+    virtual void prompt(const std::vector<Message> &messages,
+                        std::function<MessageFrame(const Message &)> framingCallback,
                         std::function<bool(int32_t)> promptCallback,
                         std::function<bool(int32_t, const std::string&)> responseCallback,
                         bool allowContextShift,
diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp
index f3fd68ff..28bb2d0e 100644
--- a/gpt4all-backend/llmodel_c.cpp
+++ b/gpt4all-backend/llmodel_c.cpp
@@ -102,8 +102,28 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
     return wrapper->llModel->restoreState(src);
 }
 
-void llmodel_prompt(llmodel_model model, const char *prompt,
-                    const char *prompt_template,
+LLModel::Message convertToMessage(const llmodel_message &msg) {
+    return {std::string(msg.content), std::string(msg.role)};
+}
+
+LLModel::MessageFrame convertToMessageFrame(const llmodel_message_frame &frame)
+{
+    return {std::string(frame.before), std::string(frame.after)};
+}
+
+auto wrapFramingCallback(llmodel_framing_callback c_callback)
+{
+    return [c_callback](const LLModel::Message &msg) -> LLModel::MessageFrame {
+        llmodel_message c_message = {msg.content.c_str(), msg.role.c_str()};
+        llmodel_message_frame c_frame = c_callback(c_message);
+        return convertToMessageFrame(c_frame);
+    };
+}
+
+void llmodel_prompt(llmodel_model model,
+                    llmodel_message *messages,
+                    size_t n_messages,
+                    llmodel_framing_callback framing_callback,
                     llmodel_prompt_callback prompt_callback,
                     llmodel_response_callback response_callback,
                     bool allow_context_shift,
@@ -135,8 +155,14 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;
 
     // Call the C++ prompt method
-    wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
-                             wrapper->promptContext, special, fake_reply_p);
+    std::vector<LLModel::Message> messagesVec;
+    for (size_t i = 0; i < n_messages; ++i)
+        messagesVec.push_back(convertToMessage(messages[i]));
+
+    auto cpp_framing_callback = wrapFramingCallback(framing_callback);
+
+    wrapper->llModel->prompt({ messagesVec }, cpp_framing_callback, prompt_callback, response_func,
+                             allow_context_shift, wrapper->promptContext, special, fake_reply_p);
 
     // Update the C context by giving access to the wrappers raw pointers to std::vector data
     // which involves no copies
diff --git a/gpt4all-backend/llmodel_c.h b/gpt4all-backend/llmodel_c.h
index 327bea2e..0dacb0eb 100644
--- a/gpt4all-backend/llmodel_c.h
+++ b/gpt4all-backend/llmodel_c.h
@@ -54,11 +54,28 @@ struct llmodel_gpu_device {
     const char * vendor;
 };
 
+struct llmodel_message {
+    const char * content;
+    const char * role;
+};
+
+struct llmodel_message_frame {
+    const char * before;
+    const char * after;
+};
+
 #ifndef __cplusplus
 typedef struct llmodel_prompt_context llmodel_prompt_context;
 typedef struct llmodel_gpu_device llmodel_gpu_device;
 #endif
 
+/**
+ * Callback type for framing strings.
+ * @param message The message.
+ * @return a message frame with framing strings.
+ */
+typedef llmodel_message_frame (*llmodel_framing_callback)(const llmodel_message message);
+
 /**
  * Callback type for prompt processing.
  * @param token_id The token id of the prompt.
@@ -164,8 +181,9 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
 /**
  * Generate a response using the model.
  * @param model A pointer to the llmodel_model instance.
- * @param prompt A string representing the input prompt.
- * @param prompt_template A string representing the input prompt template.
+ * @param messages An array of messages.
+ * @param n_messages The number of messages.
+ * @param framing_callback A callback function for retrieving the message framing strings.
  * @param prompt_callback A callback function for handling the processing of prompt.
  * @param response_callback A callback function for handling the generated response.
  * @param allow_context_shift Whether to allow shifting of context to make room for more input.
@@ -173,8 +191,10 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
 * @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
 * @param ctx A pointer to the llmodel_prompt_context structure.
 */
-void llmodel_prompt(llmodel_model model, const char *prompt,
-                    const char *prompt_template,
+void llmodel_prompt(llmodel_model model,
+                    llmodel_message *messages,
+                    size_t n_messages,
+                    llmodel_framing_callback framing_callback,
                     llmodel_prompt_callback prompt_callback,
                     llmodel_response_callback response_callback,
                     bool allow_context_shift,
diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llmodel_shared.cpp
index 570f62c6..01553902 100644
--- a/gpt4all-backend/llmodel_shared.cpp
+++ b/gpt4all-backend/llmodel_shared.cpp
@@ -15,31 +15,8 @@
 
 namespace ranges = std::ranges;
 
-static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err)
-{
-    static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");
-
-    auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex);
-    placeholders.clear();
-    placeholders.insert(placeholders.end(), it, std::sregex_iterator());
-
-    if (placeholders.size() > 2) {
-        err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size());
-        return false;
-    }
-    if (placeholders.size() >= 1 && placeholders[0].str() != "%1") {
-        err = "ERROR: first placeholder must be %1, got " + placeholders[0].str();
-        return false;
-    }
-    if (placeholders.size() >= 2 && placeholders[1].str() != "%2") {
-        err = "ERROR: second placeholder must be %2, got " + placeholders[1].str();
-        return false;
-    }
-    return true;
-}
-
-void LLModel::prompt(const std::string &prompt,
-                     const std::string &promptTemplate,
+void LLModel::prompt(const std::vector<Message> &messages,
+                     std::function<MessageFrame(const Message &)> framingCallback,
                      std::function<bool(int32_t)> promptCallback,
                      std::function<bool(int32_t, const std::string&)> responseCallback,
                      bool allowContextShift,
@@ -78,45 +55,24 @@ void LLModel::prompt(const std::string &prompt,
     promptCtx.tokens.resize(promptCtx.n_past);
     m_tokenize_last_token = promptCtx.tokens.empty() ? -1 : promptCtx.tokens.back(); // not serialized
 
-    // parse the prompt template
-    std::vector<std::smatch> placeholders;
-    {
-        std::string err;
-        if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
-            responseCallback(-1, err);
-            std::cerr << err << "\n";
-            return;
-        }
-    }
+    const auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
 
-    auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
-
-    // tokenize the user prompt
     std::vector<Token> embd_inp;
-    if (placeholders.empty()) {
-        // this is unusual, but well-defined
-        std::cerr << __func__ << ": prompt template has no placeholder\n";
-        embd_inp = tokenize(promptCtx, promptTemplate, true);
-    } else {
-        // template: beginning of user prompt
-        const auto &phUser = placeholders[0];
-        std::string userPrefix(phUser.prefix());
-        if (!userPrefix.empty()) {
-            embd_inp = tokenize(promptCtx, userPrefix, true);
+    for (const Message &msg : messages) {
+        const MessageFrame msgFrame = framingCallback(msg);
+        if (!msgFrame.before.empty()) {
+            auto tokens = tokenize(promptCtx, msgFrame.before, true);
+            embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
             promptCtx.n_past += embd_inp.size();
         }
 
-        // user input (shouldn't have special token processing)
-        auto tokens = tokenize(promptCtx, prompt, special);
+        // message content (shouldn't have special token processing)
+        auto tokens = tokenize(promptCtx, msg.content, special);
         embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
         promptCtx.n_past += tokens.size();
 
-        // template: end of user prompt + start of assistant prompt
-        size_t start = phUser.position() + phUser.length();
-        size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
-        auto userToAsst = promptTemplate.substr(start, end - start);
-        if (!userToAsst.empty()) {
-            tokens = tokenize(promptCtx, userToAsst, true);
+        if (!msgFrame.after.empty()) {
+            tokens = tokenize(promptCtx, msgFrame.after, true);
             embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
             promptCtx.n_past += tokens.size();
         }
@@ -128,6 +84,24 @@ void LLModel::prompt(const std::string &prompt,
     if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
         return; // error
 
+    // Nothing more to be done if we're not being asked to predict any tokens
+    if (promptCtx.n_predict < 1)
+        return;
+
+    // If we're being asked to predict then we create an assistant message with framing strings
+    Message asstMsg;
+    asstMsg.role = "assistant";
+    MessageFrame asstMsgFrame = framingCallback(asstMsg);
+
+    // Tokenize and decode the prefixed assistant framing string if any
+    if (!asstMsgFrame.before.empty()) {
+        const auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
+        auto tokens = tokenize(promptCtx, asstMsgFrame.before, true);
+        promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
+        if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, tokens))
+            return; // error
+    }
+
     // decode the assistant's reply, either generated or spoofed
     if (fakeReply == nullptr) {
         generateResponse(responseCallback, allowContextShift, promptCtx);
@@ -137,18 +111,13 @@ void LLModel::prompt(const std::string &prompt,
             return; // error
     }
 
-    // decode the rest of the prompt template
-    // template: end of assistant prompt
-    std::string asstSuffix;
-    if (placeholders.size() >= 2) {
-        size_t start = placeholders[1].position() + placeholders[1].length();
-        asstSuffix = promptTemplate.substr(start);
-    } else {
-        asstSuffix = "\n\n"; // default to a blank link, good for e.g. Alpaca
-    }
-    if (!asstSuffix.empty()) {
-        embd_inp = tokenize(promptCtx, asstSuffix, true);
-        decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp);
+    // Tokenize and decode the suffixed assistant framing string if any
+    if (!asstMsgFrame.after.empty()) {
+        const auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
+        auto tokens = tokenize(promptCtx, asstMsgFrame.after, true);
+        promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
+        if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, tokens))
+            return; // error
     }
 }
 
diff --git a/gpt4all-chat/chatapi.cpp b/gpt4all-chat/chatapi.cpp
index b443f24c..58badb3c 100644
--- a/gpt4all-chat/chatapi.cpp
+++ b/gpt4all-chat/chatapi.cpp
@@ -86,8 +86,8 @@ size_t ChatAPI::restoreState(const uint8_t *src)
     return 0;
 }
 
-void ChatAPI::prompt(const std::string &prompt,
-                     const std::string &promptTemplate,
+void ChatAPI::prompt(const std::vector<LLModel::Message> &messages,
+                     std::function<LLModel::MessageFrame(const LLModel::Message &)> framingCallback,
                      std::function<bool(int32_t)> promptCallback,
                      std::function<bool(int32_t, const std::string&)> responseCallback,
                      bool allowContextShift,
diff --git a/gpt4all-chat/chatapi.h b/gpt4all-chat/chatapi.h
index 59b68f58..4802cb9f 100644
--- a/gpt4all-chat/chatapi.h
+++ b/gpt4all-chat/chatapi.h
@@ -65,8 +65,8 @@ public:
     size_t stateSize() const override;
     size_t saveState(uint8_t *dest) const override;
    size_t restoreState(const uint8_t *src) override;
-    void prompt(const std::string &prompt,
-                const std::string &promptTemplate,
+    void prompt(const std::vector<Message> &,
+                std::function<MessageFrame(const Message &)> framingCallback,
                 std::function<bool(int32_t)> promptCallback,
                 std::function<bool(int32_t, const std::string&)> responseCallback,
                 bool allowContextShift,
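
Not part of the patch: a minimal caller-side sketch of the new prompt API, assuming only the declarations introduced above (LLModel::Message, LLModel::MessageFrame, and the new LLModel::prompt signature). The framing callback returns the text to emit before and after each message's content; the backend also invokes it once with an empty assistant message so the assistant prefix is decoded before generation and the suffix afterwards. The ChatML-style marker strings and the helper names (frameMessage, runChat) are illustrative only, and the call relies on special/fakeReply keeping their existing defaults.

    #include "llmodel.h"

    #include <iostream>
    #include <string>
    #include <vector>

    // Hypothetical framing callback: wraps each message in ChatML-style markers.
    static LLModel::MessageFrame frameMessage(const LLModel::Message &msg)
    {
        return {
            "<|im_start|>" + msg.role + "\n", // before: opens the block for this role
            "<|im_end|>\n",                   // after: closes the block
        };
    }

    // Hypothetical driver: sends a two-message conversation and streams the reply.
    static void runChat(LLModel &model, LLModel::PromptContext &promptCtx)
    {
        std::vector<LLModel::Message> messages {
            {"You are a helpful assistant.", "system"},   // {content, role}
            {"What is the capital of France?", "user"},
        };
        model.prompt(messages, frameMessage,
                     [](int32_t) { return true; },            // prompt tokens processed
                     [](int32_t, const std::string &piece) {  // response pieces streamed
                         std::cout << piece;
                         return true;
                     },
                     /*allowContextShift*/ true, promptCtx);
    }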