Mirror of https://github.com/nomic-ai/gpt4all.git
Proposed modification to the prompt method.
Signed-off-by: Adam Treat <treat.adam@gmail.com>
Parent: 221b9cff5a
Commit: 7037eb04d1
@@ -137,6 +137,16 @@ public:
        float contextErase = 0.5f; // percent of context to erase if we exceed the context window
    };

    struct Message {
        std::string content;
        std::string role;
    };

    struct MessageFrame {
        std::string before;
        std::string after;
    };

    using ProgressCallback = std::function<bool(float progress)>;

    explicit LLModel() {}
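To make the proposed types concrete, here is a minimal sketch of how a caller might assemble a conversation; the roles and contents are illustrative assumptions, not part of this commit:

// Hypothetical conversation built from the proposed Message struct
// (field order per the definition above: content first, then role).
std::vector<LLModel::Message> conversation {
    { "You are a helpful assistant.",   "system" },
    { "What is the capital of France?", "user"   },
};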
@@ -155,8 +165,8 @@ public:

    // This method requires the model to return true from supportsCompletion otherwise it will throw
    // an error
    virtual void prompt(const std::string &prompt,
                        const std::string &promptTemplate,
    virtual void prompt(const std::vector<Message> &messages,
                        std::function<MessageFrame(const Message &)> framingCallback,
                        std::function<bool(int32_t)> promptCallback,
                        std::function<bool(int32_t, const std::string&)> responseCallback,
                        bool allowContextShift,
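The framing callback takes over the role of the old %1/%2 prompt template: for each message it returns the text to emit before and after the message content. A minimal sketch, assuming ChatML-style tags (the tag strings are an assumption for illustration, not something this commit defines):

// Hypothetical framing callback producing ChatML-style framing.
LLModel::MessageFrame chatMlFraming(const LLModel::Message &msg)
{
    return { "<|im_start|>" + msg.role + "\n",  // MessageFrame::before
             "<|im_end|>\n" };                  // MessageFrame::after
}

The remaining parameters of prompt() are unchanged; the wrapper call later in this diff shows the full argument list, and the implementation below decodes only the before string of the internally created assistant message ahead of generation and its after string once the reply is complete.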
@@ -102,8 +102,28 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
    return wrapper->llModel->restoreState(src);
}

void llmodel_prompt(llmodel_model model, const char *prompt,
                    const char *prompt_template,
LLModel::Message convertToMessage(const llmodel_message &msg) {
    return {std::string(msg.content), std::string(msg.role)};
}

LLModel::MessageFrame convertToMessageFrame(const llmodel_message_frame &frame)
{
    return {std::string(frame.before), std::string(frame.after)};
}

auto wrapFramingCallback(llmodel_framing_callback c_callback)
{
    return [c_callback](const LLModel::Message &msg) -> LLModel::MessageFrame {
        llmodel_message c_message = {msg.content.c_str(), msg.role.c_str()};
        llmodel_message_frame c_frame = c_callback(c_message);
        return convertToMessageFrame(c_frame);
    };
}

void llmodel_prompt(llmodel_model model,
                    llmodel_message *messages,
                    size_t n_messages,
                    llmodel_framing_callback framing_callback,
                    llmodel_prompt_callback prompt_callback,
                    llmodel_response_callback response_callback,
                    bool allow_context_shift,
@@ -135,8 +155,14 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
    auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;

    // Call the C++ prompt method
    wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
                             wrapper->promptContext, special, fake_reply_p);
    std::vector<LLModel::Message> messagesVec;
    for (size_t i = 0; i < n_messages; ++i)
        messagesVec.push_back(convertToMessage(messages[i]));

    auto cpp_framing_callback = wrapFramingCallback(framing_callback);

    wrapper->llModel->prompt({ messagesVec }, cpp_framing_callback, prompt_callback, response_func,
                             allow_context_shift, wrapper->promptContext, special, fake_reply_p);

    // Update the C context by giving access to the wrappers raw pointers to std::vector data
    // which involves no copies
@@ -54,11 +54,28 @@ struct llmodel_gpu_device {
    const char * vendor;
};

struct llmodel_message {
    const char * content;
    const char * role;
};

struct llmodel_message_frame {
    const char * before;
    const char * after;
};

#ifndef __cplusplus
typedef struct llmodel_prompt_context llmodel_prompt_context;
typedef struct llmodel_gpu_device llmodel_gpu_device;
#endif

/**
 * Callback type for framing strings.
 * @param message The message.
 * @return a message frame with framing strings.
 */
typedef llmodel_message_frame (*llmodel_framing_callback)(const llmodel_message message);

/**
 * Callback type for prompt processing.
 * @param token_id The token id of the prompt.
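As an illustration of this callback type at the C boundary, the sketch below returns pointers to string literals, which have static storage duration and therefore stay valid after the callback returns; the framing strings themselves are assumptions for illustration, not defined by this commit:

#include <cstring>  // std::strcmp

// Hypothetical C-compatible framing callback (sketch); the Instruction/Response
// strings are illustrative. String literals remain valid after the call.
static llmodel_message_frame example_framing(const llmodel_message message)
{
    llmodel_message_frame frame;
    if (std::strcmp(message.role, "assistant") == 0) {
        frame.before = "### Response:\n";
        frame.after  = "\n\n";
    } else {
        frame.before = "### Instruction:\n";
        frame.after  = "\n\n";
    }
    return frame;
}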
@@ -164,8 +181,9 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
/**
 * Generate a response using the model.
 * @param model A pointer to the llmodel_model instance.
 * @param prompt A string representing the input prompt.
 * @param prompt_template A string representing the input prompt template.
 * @param messages An array of messages.
 * @param n_messages The number of messages.
 * @param framing_callback A callback function for retrieving the message framing strings.
 * @param prompt_callback A callback function for handling the processing of prompt.
 * @param response_callback A callback function for handling the generated response.
 * @param allow_context_shift Whether to allow shifting of context to make room for more input.
@@ -173,8 +191,10 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
 * @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
 * @param ctx A pointer to the llmodel_prompt_context structure.
 */
void llmodel_prompt(llmodel_model model, const char *prompt,
                    const char *prompt_template,
void llmodel_prompt(llmodel_model model,
                    llmodel_message *messages,
                    size_t n_messages,
                    llmodel_framing_callback framing_callback,
                    llmodel_prompt_callback prompt_callback,
                    llmodel_response_callback response_callback,
                    bool allow_context_shift,
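A hedged sketch of how the new entry point would be fed: the conversation becomes a plain array of llmodel_message with its length passed separately. Only the array construction is shown; the model handle, the prompt/response callbacks, and the trailing parameters of llmodel_prompt (truncated in the hunk above) are assumed to be set up as before:

// Hypothetical conversation expressed with the new C struct (sketch);
// field order per the definition above: content first, then role.
llmodel_message messages[] = {
    { "You are a helpful assistant.",   "system" },
    { "What is the capital of France?", "user"   },
};
size_t n_messages = sizeof(messages) / sizeof(messages[0]);

These, together with a framing callback such as example_framing above, replace the prompt and prompt_template arguments of the old signature.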
@@ -15,31 +15,8 @@

namespace ranges = std::ranges;

static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err)
{
    static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");

    auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex);
    placeholders.clear();
    placeholders.insert(placeholders.end(), it, std::sregex_iterator());

    if (placeholders.size() > 2) {
        err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size());
        return false;
    }
    if (placeholders.size() >= 1 && placeholders[0].str() != "%1") {
        err = "ERROR: first placeholder must be %1, got " + placeholders[0].str();
        return false;
    }
    if (placeholders.size() >= 2 && placeholders[1].str() != "%2") {
        err = "ERROR: second placeholder must be %2, got " + placeholders[1].str();
        return false;
    }
    return true;
}

void LLModel::prompt(const std::string &prompt,
                     const std::string &promptTemplate,
void LLModel::prompt(const std::vector<Message> &messages,
                     std::function<MessageFrame(const Message &)> framingCallback,
                     std::function<bool(int32_t)> promptCallback,
                     std::function<bool(int32_t, const std::string&)> responseCallback,
                     bool allowContextShift,
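The parser removed here accepted templates containing %1 (the user's input) and %2 (the assistant's reply). Under the proposed API the same information is carried by framing strings; one possible correspondence for an Alpaca-style template, matching the illustrative strings used in the C callback sketch above (an assumption, not prescribed by this commit):

// Old-style template (illustrative): "### Instruction:\n%1\n\n### Response:\n%2\n\n"
// One possible framing equivalent under the new API:
//   user message:      before = "### Instruction:\n",  after = "\n\n"
//   assistant message: before = "### Response:\n",     after = "\n\n"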
@@ -78,45 +55,24 @@ void LLModel::prompt(const std::string &prompt,
    promptCtx.tokens.resize(promptCtx.n_past);
    m_tokenize_last_token = promptCtx.tokens.empty() ? -1 : promptCtx.tokens.back(); // not serialized

    // parse the prompt template
    std::vector<std::smatch> placeholders;
    {
        std::string err;
        if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
            responseCallback(-1, err);
            std::cerr << err << "\n";
            return;
        }
    }
    const auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize

    auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize

    // tokenize the user prompt
    std::vector<Token> embd_inp;
    if (placeholders.empty()) {
        // this is unusual, but well-defined
        std::cerr << __func__ << ": prompt template has no placeholder\n";
        embd_inp = tokenize(promptCtx, promptTemplate, true);
    } else {
        // template: beginning of user prompt
        const auto &phUser = placeholders[0];
        std::string userPrefix(phUser.prefix());
        if (!userPrefix.empty()) {
            embd_inp = tokenize(promptCtx, userPrefix, true);
    for (const Message &msg : messages) {
        const MessageFrame msgFrame = framingCallback(msg);
        if (!msgFrame.before.empty()) {
            auto tokens = tokenize(promptCtx, msgFrame.before, true);
            embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
            promptCtx.n_past += embd_inp.size();
        }

        // user input (shouldn't have special token processing)
        auto tokens = tokenize(promptCtx, prompt, special);
        // message content (shouldn't have special token processing)
        auto tokens = tokenize(promptCtx, msg.content, special);
        embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
        promptCtx.n_past += tokens.size();

        // template: end of user prompt + start of assistant prompt
        size_t start = phUser.position() + phUser.length();
        size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
        auto userToAsst = promptTemplate.substr(start, end - start);
        if (!userToAsst.empty()) {
            tokens = tokenize(promptCtx, userToAsst, true);
        if (!msgFrame.after.empty()) {
            tokens = tokenize(promptCtx, msgFrame.after, true);
            embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
            promptCtx.n_past += tokens.size();
        }
@@ -128,6 +84,24 @@ void LLModel::prompt(const std::string &prompt,
    if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
        return; // error

    // Nothing more to be done if we're not being asked to predict any tokens
    if (promptCtx.n_predict < 1)
        return;

    // If we're being asked to predict then we create an assistant message with framing strings
    Message asstMsg;
    asstMsg.role = "assistant";
    MessageFrame asstMsgFrame = framingCallback(asstMsg);

    // Tokenize and decode the prefixed assistant framing string if any
    if (!asstMsgFrame.before.empty()) {
        const auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
        auto tokens = tokenize(promptCtx, asstMsgFrame.before, true);
        promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
        if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, tokens))
            return; // error
    }

    // decode the assistant's reply, either generated or spoofed
    if (fakeReply == nullptr) {
        generateResponse(responseCallback, allowContextShift, promptCtx);
@@ -137,18 +111,13 @@ void LLModel::prompt(const std::string &prompt,
            return; // error
    }

    // decode the rest of the prompt template
    // template: end of assistant prompt
    std::string asstSuffix;
    if (placeholders.size() >= 2) {
        size_t start = placeholders[1].position() + placeholders[1].length();
        asstSuffix = promptTemplate.substr(start);
    } else {
        asstSuffix = "\n\n"; // default to a blank link, good for e.g. Alpaca
    }
    if (!asstSuffix.empty()) {
        embd_inp = tokenize(promptCtx, asstSuffix, true);
        decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp);
    // Tokenize and decode the suffixed assistant framing string if any
    if (!asstMsgFrame.after.empty()) {
        const auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
        auto tokens = tokenize(promptCtx, asstMsgFrame.after, true);
        promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
        if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, tokens))
            return; // error
    }
}
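Taken together, and assuming promptCtx.n_predict is at least 1 (otherwise the method returns right after decoding the messages), the new flow decodes text in this order; the ChatML-style tags are the same illustrative assumption used earlier, not something the commit defines:

before(msg) + content(msg) + after(msg)   for each supplied message, e.g. "<|im_start|>user\n...<|im_end|>\n"
before(assistant)                         e.g. "<|im_start|>assistant\n", decoded before generation
generated reply, or fakeReply if one was supplied
after(assistant)                          e.g. "<|im_end|>\n", decoded after the reply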
@@ -86,8 +86,8 @@ size_t ChatAPI::restoreState(const uint8_t *src)
    return 0;
}

void ChatAPI::prompt(const std::string &prompt,
                     const std::string &promptTemplate,
void ChatAPI::prompt(const std::vector<Message> &messages,
                     std::function<MessageFrame(const Message &)> framingCallback,
                     std::function<bool(int32_t)> promptCallback,
                     std::function<bool(int32_t, const std::string&)> responseCallback,
                     bool allowContextShift,
@@ -65,8 +65,8 @@ public:
    size_t stateSize() const override;
    size_t saveState(uint8_t *dest) const override;
    size_t restoreState(const uint8_t *src) override;
    void prompt(const std::string &prompt,
                const std::string &promptTemplate,
    void prompt(const std::vector<Message> &,
                std::function<MessageFrame(const Message &)> framingCallback,
                std::function<bool(int32_t)> promptCallback,
                std::function<bool(int32_t, const std::string&)> responseCallback,
                bool allowContextShift,