Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2024-10-01)

commit 7037eb04d1
parent 221b9cff5a

Proposed modification to the prompt method.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
@@ -137,6 +137,16 @@ public:
         float contextErase = 0.5f; // percent of context to erase if we exceed the context window
     };
 
+    struct Message {
+        std::string content;
+        std::string role;
+    };
+
+    struct MessageFrame {
+        std::string before;
+        std::string after;
+    };
+
     using ProgressCallback = std::function<bool(float progress)>;
 
     explicit LLModel() {}
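For orientation, here is a minimal, self-contained sketch of how a caller could use the two new types: a conversation as role-tagged messages plus a framing callback that surrounds each message with control strings. The ChatML-style markers are purely illustrative; nothing in this commit mandates a particular template.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Standalone stand-ins for the nested LLModel::Message/MessageFrame types.
struct Message {
    std::string content;
    std::string role;
};

struct MessageFrame {
    std::string before;
    std::string after;
};

int main() {
    // A conversation as role-tagged messages, replacing the old single
    // pre-templated prompt string.
    std::vector<Message> messages = {
        {"You are a helpful assistant.", "system"},
        {"What is the capital of France?", "user"},
    };

    // The framing callback maps each message to the control strings that
    // should surround it (example template only).
    std::function<MessageFrame(const Message &)> framing =
        [](const Message &msg) -> MessageFrame {
            return {"<|im_start|>" + msg.role + "\n", "<|im_end|>\n"};
        };

    // Rendering before + content + after per message reproduces what the
    // model ultimately sees.
    for (const Message &msg : messages) {
        MessageFrame f = framing(msg);
        std::cout << f.before << msg.content << f.after;
    }
}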
@@ -155,8 +165,8 @@ public:
 
     // This method requires the model to return true from supportsCompletion otherwise it will throw
     // an error
-    virtual void prompt(const std::string &prompt,
-                        const std::string &promptTemplate,
+    virtual void prompt(const std::vector<Message> &messages,
+                        std::function<MessageFrame(const Message &)> framingCallback,
                         std::function<bool(int32_t)> promptCallback,
                         std::function<bool(int32_t, const std::string&)> responseCallback,
                         bool allowContextShift,
@@ -102,8 +102,28 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
     return wrapper->llModel->restoreState(src);
 }
 
-void llmodel_prompt(llmodel_model model, const char *prompt,
-                    const char *prompt_template,
+LLModel::Message convertToMessage(const llmodel_message &msg) {
+    return {std::string(msg.content), std::string(msg.role)};
+}
+
+LLModel::MessageFrame convertToMessageFrame(const llmodel_message_frame &frame)
+{
+    return {std::string(frame.before), std::string(frame.after)};
+}
+
+auto wrapFramingCallback(llmodel_framing_callback c_callback)
+{
+    return [c_callback](const LLModel::Message &msg) -> LLModel::MessageFrame {
+        llmodel_message c_message = {msg.content.c_str(), msg.role.c_str()};
+        llmodel_message_frame c_frame = c_callback(c_message);
+        return convertToMessageFrame(c_frame);
+    };
+}
+
+void llmodel_prompt(llmodel_model model,
+                    llmodel_message *messages,
+                    size_t n_messages,
+                    llmodel_framing_callback framing_callback,
                     llmodel_prompt_callback prompt_callback,
                     llmodel_response_callback response_callback,
                     bool allow_context_shift,
@@ -135,8 +155,14 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;
 
     // Call the C++ prompt method
-    wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
-                             wrapper->promptContext, special, fake_reply_p);
+    std::vector<LLModel::Message> messagesVec;
+    for (size_t i = 0; i < n_messages; ++i)
+        messagesVec.push_back(convertToMessage(messages[i]));
+
+    auto cpp_framing_callback = wrapFramingCallback(framing_callback);
+
+    wrapper->llModel->prompt({ messagesVec }, cpp_framing_callback, prompt_callback, response_func,
+                             allow_context_shift, wrapper->promptContext, special, fake_reply_p);
 
     // Update the C context by giving access to the wrappers raw pointers to std::vector data
     // which involves no copies
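The wrapper above establishes a clear copy boundary: the `llmodel_message` handed to the C callback borrows `c_str()` pointers from the C++ `Message`, and the returned frame is deep-copied into `std::string` by `convertToMessageFrame` before the lambda returns. A hypothetical standalone exercise of it, assuming the declarations from this diff are in scope:

// Hypothetical C-style callback a binding might pass in. Returning string
// literals sidesteps lifetime questions; anything dynamically built must stay
// valid until the wrapper's convertToMessageFrame() has copied it.
static llmodel_message_frame fixed_framing(const llmodel_message message) {
    (void)message; // same framing for every role in this sketch
    return {"### Prompt:\n", "\n### Response:\n"};
}

void demo() {
    auto cpp_cb = wrapFramingCallback(fixed_framing);
    LLModel::Message m{"hi", "user"};
    // The llmodel_message handed to fixed_framing borrows c_str() pointers
    // from m; the returned frame is deep-copied into std::string here.
    LLModel::MessageFrame f = cpp_cb(m);
    (void)f; // f.before/f.after now own their bytes independently
}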
@@ -54,11 +54,28 @@ struct llmodel_gpu_device {
     const char * vendor;
 };
 
+struct llmodel_message {
+    const char * content;
+    const char * role;
+};
+
+struct llmodel_message_frame {
+    const char * before;
+    const char * after;
+};
+
 #ifndef __cplusplus
 typedef struct llmodel_prompt_context llmodel_prompt_context;
 typedef struct llmodel_gpu_device llmodel_gpu_device;
 #endif
 
+/**
+ * Callback type for framing strings.
+ * @param message The message.
+ * @return a message frame with framing strings.
+ */
+typedef llmodel_message_frame (*llmodel_framing_callback)(const llmodel_message message);
+
 /**
  * Callback type for prompt processing.
  * @param token_id The token id of the prompt.
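For concreteness, a framing callback written against these C types might look like the following; the role dispatch and Alpaca-like strings are an assumption for illustration, not part of this commit.

#include <cstring>

// Role-dispatching framing callback (illustrative template strings).
static llmodel_message_frame example_framing(const llmodel_message message) {
    llmodel_message_frame frame;
    if (std::strcmp(message.role, "assistant") == 0) {
        frame.before = "### Response:\n";
        frame.after  = "\n";
    } else {
        frame.before = "### Prompt:\n";
        frame.after  = "\n";
    }
    return frame; // string literals, so no lifetime concerns
}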
@@ -164,8 +181,9 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
 /**
  * Generate a response using the model.
  * @param model A pointer to the llmodel_model instance.
- * @param prompt A string representing the input prompt.
- * @param prompt_template A string representing the input prompt template.
+ * @param messages An array of messages.
+ * @param n_messages The number of messages.
+ * @param framing_callback A callback function for retrieving the message framing strings.
  * @param prompt_callback A callback function for handling the processing of prompt.
  * @param response_callback A callback function for handling the generated response.
  * @param allow_context_shift Whether to allow shifting of context to make room for more input.
@@ -173,8 +191,10 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
 * @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
 * @param ctx A pointer to the llmodel_prompt_context structure.
 */
-void llmodel_prompt(llmodel_model model, const char *prompt,
-                    const char *prompt_template,
+void llmodel_prompt(llmodel_model model,
+                    llmodel_message *messages,
+                    size_t n_messages,
+                    llmodel_framing_callback framing_callback,
                     llmodel_prompt_callback prompt_callback,
                     llmodel_response_callback response_callback,
                     bool allow_context_shift,
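Putting the pieces together, a caller of the revised entry point might look like the sketch below. It assumes the trailing parameters (ctx, special, fake_reply) keep their existing meaning, which the visible diff does not change, and it reuses the hypothetical example_framing callback from above.

#include <cstdint>
#include <cstdio>

static bool on_prompt(int32_t token_id) {
    (void)token_id;
    return true; // keep processing
}

static bool on_response(int32_t token_id, const char *response) {
    (void)token_id;
    std::fputs(response, stdout); // stream tokens as they arrive
    return true;
}

void run(llmodel_model model, llmodel_prompt_context *ctx) {
    llmodel_message messages[] = {
        {"You are a terse assistant.", "system"},
        {"Name the largest planet.", "user"},
    };
    llmodel_prompt(model, messages, /*n_messages=*/2, example_framing,
                   on_prompt, on_response,
                   /*allow_context_shift=*/true, ctx,
                   /*special=*/false, /*fake_reply=*/nullptr);
}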
@ -15,31 +15,8 @@
|
|||||||
|
|
||||||
namespace ranges = std::ranges;
|
namespace ranges = std::ranges;
|
||||||
|
|
||||||
static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err)
|
void LLModel::prompt(const std::vector<Message> &messages,
|
||||||
{
|
std::function<MessageFrame(const Message &)> framingCallback,
|
||||||
static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");
|
|
||||||
|
|
||||||
auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex);
|
|
||||||
placeholders.clear();
|
|
||||||
placeholders.insert(placeholders.end(), it, std::sregex_iterator());
|
|
||||||
|
|
||||||
if (placeholders.size() > 2) {
|
|
||||||
err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (placeholders.size() >= 1 && placeholders[0].str() != "%1") {
|
|
||||||
err = "ERROR: first placeholder must be %1, got " + placeholders[0].str();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (placeholders.size() >= 2 && placeholders[1].str() != "%2") {
|
|
||||||
err = "ERROR: second placeholder must be %2, got " + placeholders[1].str();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void LLModel::prompt(const std::string &prompt,
|
|
||||||
const std::string &promptTemplate,
|
|
||||||
std::function<bool(int32_t)> promptCallback,
|
std::function<bool(int32_t)> promptCallback,
|
||||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||||
bool allowContextShift,
|
bool allowContextShift,
|
||||||
@@ -78,45 +55,24 @@ void LLModel::prompt(const std::string &prompt,
     promptCtx.tokens.resize(promptCtx.n_past);
     m_tokenize_last_token = promptCtx.tokens.empty() ? -1 : promptCtx.tokens.back(); // not serialized
 
-    // parse the prompt template
-    std::vector<std::smatch> placeholders;
-    {
-        std::string err;
-        if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
-            responseCallback(-1, err);
-            std::cerr << err << "\n";
-            return;
-        }
-    }
-
-    auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
-
-    // tokenize the user prompt
+    const auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
     std::vector<Token> embd_inp;
-    if (placeholders.empty()) {
-        // this is unusual, but well-defined
-        std::cerr << __func__ << ": prompt template has no placeholder\n";
-        embd_inp = tokenize(promptCtx, promptTemplate, true);
-    } else {
-        // template: beginning of user prompt
-        const auto &phUser = placeholders[0];
-        std::string userPrefix(phUser.prefix());
-        if (!userPrefix.empty()) {
-            embd_inp = tokenize(promptCtx, userPrefix, true);
+    for (const Message &msg : messages) {
+        const MessageFrame msgFrame = framingCallback(msg);
+        if (!msgFrame.before.empty()) {
+            auto tokens = tokenize(promptCtx, msgFrame.before, true);
+            embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
             promptCtx.n_past += embd_inp.size();
         }
 
-        // user input (shouldn't have special token processing)
-        auto tokens = tokenize(promptCtx, prompt, special);
+        // message content (shouldn't have special token processing)
+        auto tokens = tokenize(promptCtx, msg.content, special);
         embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
         promptCtx.n_past += tokens.size();
 
-        // template: end of user prompt + start of assistant prompt
-        size_t start = phUser.position() + phUser.length();
-        size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
-        auto userToAsst = promptTemplate.substr(start, end - start);
-        if (!userToAsst.empty()) {
-            tokens = tokenize(promptCtx, userToAsst, true);
+        if (!msgFrame.after.empty()) {
+            tokens = tokenize(promptCtx, msgFrame.after, true);
             embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
             promptCtx.n_past += tokens.size();
         }
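The loop above deliberately keeps two tokenization modes: framing strings are tokenized with special-token processing forced on, while message content uses the caller's `special` flag, so user text cannot smuggle in control tokens. A toy model of the per-message concatenation, with a whitespace splitter standing in for the real tokenizer (ChatML-like framing shown for illustration only):

#include <sstream>
#include <string>
#include <vector>

static std::vector<std::string> stub_tokenize(const std::string &s) {
    std::istringstream in(s);
    std::vector<std::string> out;
    for (std::string tok; in >> tok; )
        out.push_back(tok);
    return out;
}

int main() {
    // before + content + after, in order, exactly as the message loop builds
    // embd_inp for each message.
    std::vector<std::string> embd_inp;
    for (const std::string &piece : {std::string("<|im_start|>user"),
                                     std::string("what is 2+2?"),
                                     std::string("<|im_end|>")}) {
        auto toks = stub_tokenize(piece);
        embd_inp.insert(embd_inp.end(), toks.begin(), toks.end());
    }
    // embd_inp == {"<|im_start|>user", "what", "is", "2+2?", "<|im_end|>"}
}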
@@ -128,6 +84,24 @@ void LLModel::prompt(const std::string &prompt,
     if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
         return; // error
 
+    // Nothing more to be done if we're not being asked to predict any tokens
+    if (promptCtx.n_predict < 1)
+        return;
+
+    // If we're being asked to predict then we create an assistant message with framing strings
+    Message asstMsg;
+    asstMsg.role = "assistant";
+    MessageFrame asstMsgFrame = framingCallback(asstMsg);
+
+    // Tokenize and decode the prefixed assistant framing string if any
+    if (!asstMsgFrame.before.empty()) {
+        const auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
+        auto tokens = tokenize(promptCtx, asstMsgFrame.before, true);
+        promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
+        if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, tokens))
+            return; // error
+    }
+
     // decode the assistant's reply, either generated or spoofed
     if (fakeReply == nullptr) {
         generateResponse(responseCallback, allowContextShift, promptCtx);
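For a ChatML-style callback like the earlier sketches, the effect of this new block is that generation is preceded by the assistant message's opening frame and followed by its closing frame (handled in the next hunk). A brief illustration of the handshake, reusing the hypothetical `framing` lambda from the first example:

void frameAssistant(const std::function<LLModel::MessageFrame(const LLModel::Message &)> &framing) {
    LLModel::Message asstMsg;
    asstMsg.role = "assistant"; // content left empty on purpose
    LLModel::MessageFrame f = framing(asstMsg);
    // f.before (e.g. "<|im_start|>assistant\n") is decoded before generation
    // begins; f.after (e.g. "<|im_end|>\n") is decoded afterwards, so the
    // transcript stays well-formed whether the reply is generated or spoofed
    // via fake_reply.
    (void)f;
}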
@@ -137,18 +111,13 @@ void LLModel::prompt(const std::string &prompt,
             return; // error
     }
 
-    // decode the rest of the prompt template
-    // template: end of assistant prompt
-    std::string asstSuffix;
-    if (placeholders.size() >= 2) {
-        size_t start = placeholders[1].position() + placeholders[1].length();
-        asstSuffix = promptTemplate.substr(start);
-    } else {
-        asstSuffix = "\n\n"; // default to a blank link, good for e.g. Alpaca
-    }
-    if (!asstSuffix.empty()) {
-        embd_inp = tokenize(promptCtx, asstSuffix, true);
-        decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp);
+    // Tokenize and decode the suffixed assistant framing string if any
+    if (!asstMsgFrame.after.empty()) {
+        const auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
+        auto tokens = tokenize(promptCtx, asstMsgFrame.after, true);
+        promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
+        if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, tokens))
+            return; // error
     }
 }
@@ -86,8 +86,8 @@ size_t ChatAPI::restoreState(const uint8_t *src)
     return 0;
 }
 
-void ChatAPI::prompt(const std::string &prompt,
-                     const std::string &promptTemplate,
+void ChatAPI::prompt(const std::vector<Message> &messages,
+                     std::function<MessageFrame(const Message &)> framingCallback,
                      std::function<bool(int32_t)> promptCallback,
                      std::function<bool(int32_t, const std::string&)> responseCallback,
                      bool allowContextShift,
@@ -65,8 +65,8 @@ public:
     size_t stateSize() const override;
     size_t saveState(uint8_t *dest) const override;
     size_t restoreState(const uint8_t *src) override;
-    void prompt(const std::string &prompt,
-                const std::string &promptTemplate,
+    void prompt(const std::vector<Message> &,
+                std::function<MessageFrame(const Message &)> framingCallback,
                 std::function<bool(int32_t)> promptCallback,
                 std::function<bool(int32_t, const std::string&)> responseCallback,
                 bool allowContextShift,
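Finally, an end-to-end sketch of how a host application might drive the proposed C++ interface. LLModel::PromptContext and the trailing parameters are assumed to keep their pre-existing meaning (this diff leaves them out of view), and the framing template is again only an example.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

void chat(LLModel &model, LLModel::PromptContext &ctx) {
    std::vector<LLModel::Message> messages = {
        {"You are a helpful assistant.", "system"},
        {"Summarize Hamlet in one sentence.", "user"},
    };

    // Example template only; any framing scheme can be supplied per model.
    auto framing = [](const LLModel::Message &msg) -> LLModel::MessageFrame {
        return {"<|im_start|>" + msg.role + "\n", "<|im_end|>\n"};
    };

    model.prompt(messages, framing,
                 [](int32_t) { return true; },            // prompt progress
                 [](int32_t, const std::string &piece) {  // streamed reply
                     std::cout << piece << std::flush;
                     return true;
                 },
                 /*allowContextShift=*/true, ctx,
                 /*special=*/false, /*fakeReply=*/nullptr);
}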