Proposed modification to the prompt method.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
Adam Treat 2024-08-25 09:34:57 -04:00
parent 221b9cff5a
commit 7037eb04d1
6 changed files with 107 additions and 82 deletions

View File

@ -137,6 +137,16 @@ public:
float contextErase = 0.5f; // percent of context to erase if we exceed the context window
};
struct Message {
std::string content;
std::string role;
};
struct MessageFrame {
std::string before;
std::string after;
};
using ProgressCallback = std::function<bool(float progress)>;
explicit LLModel() {}
@ -155,8 +165,8 @@ public:
// This method requires the model to return true from supportsCompletion otherwise it will throw
// an error
virtual void prompt(const std::string &prompt,
const std::string &promptTemplate,
virtual void prompt(const std::vector<Message> &messages,
std::function<MessageFrame(const Message &)> framingCallback,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,

View File

@ -102,8 +102,28 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
return wrapper->llModel->restoreState(src);
}
void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
LLModel::Message convertToMessage(const llmodel_message &msg) {
return {std::string(msg.content), std::string(msg.role)};
}
LLModel::MessageFrame convertToMessageFrame(const llmodel_message_frame &frame)
{
return {std::string(frame.before), std::string(frame.after)};
}
auto wrapFramingCallback(llmodel_framing_callback c_callback)
{
return [c_callback](const LLModel::Message &msg) -> LLModel::MessageFrame {
llmodel_message c_message = {msg.content.c_str(), msg.role.c_str()};
llmodel_message_frame c_frame = c_callback(c_message);
return convertToMessageFrame(c_frame);
};
}
void llmodel_prompt(llmodel_model model,
llmodel_message *messages,
size_t n_messages,
llmodel_framing_callback framing_callback,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
bool allow_context_shift,
@ -135,8 +155,14 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;
// Call the C++ prompt method
wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
wrapper->promptContext, special, fake_reply_p);
std::vector<LLModel::Message> messagesVec;
for (size_t i = 0; i < n_messages; ++i)
messagesVec.push_back(convertToMessage(messages[i]));
auto cpp_framing_callback = wrapFramingCallback(framing_callback);
wrapper->llModel->prompt({ messagesVec }, cpp_framing_callback, prompt_callback, response_func,
allow_context_shift, wrapper->promptContext, special, fake_reply_p);
// Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies

View File

@ -54,11 +54,28 @@ struct llmodel_gpu_device {
const char * vendor;
};
struct llmodel_message {
const char * content;
const char * role;
};
struct llmodel_message_frame {
const char * before;
const char * after;
};
#ifndef __cplusplus
typedef struct llmodel_prompt_context llmodel_prompt_context;
typedef struct llmodel_gpu_device llmodel_gpu_device;
#endif
/**
* Callback type for framing strings.
* @param message The message.
* @return a message frame with framing strings.
*/
typedef llmodel_message_frame (*llmodel_framing_callback)(const llmodel_message message);
/**
* Callback type for prompt processing.
* @param token_id The token id of the prompt.
@ -164,8 +181,9 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
/**
* Generate a response using the model.
* @param model A pointer to the llmodel_model instance.
* @param prompt A string representing the input prompt.
* @param prompt_template A string representing the input prompt template.
* @param messages An array of messages.
* @param n_messages The number of messages.
* @param framing_callback A callback function for retrieving the message framing strings.
* @param prompt_callback A callback function for handling the processing of prompt.
* @param response_callback A callback function for handling the generated response.
* @param allow_context_shift Whether to allow shifting of context to make room for more input.
@ -173,8 +191,10 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
* @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
* @param ctx A pointer to the llmodel_prompt_context structure.
*/
void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
void llmodel_prompt(llmodel_model model,
llmodel_message *messages,
size_t n_messages,
llmodel_framing_callback framing_callback,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
bool allow_context_shift,

View File

@ -15,31 +15,8 @@
namespace ranges = std::ranges;
static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err)
{
static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");
auto it = std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex);
placeholders.clear();
placeholders.insert(placeholders.end(), it, std::sregex_iterator());
if (placeholders.size() > 2) {
err = "ERROR: expected at most two placeholders, got " + std::to_string(placeholders.size());
return false;
}
if (placeholders.size() >= 1 && placeholders[0].str() != "%1") {
err = "ERROR: first placeholder must be %1, got " + placeholders[0].str();
return false;
}
if (placeholders.size() >= 2 && placeholders[1].str() != "%2") {
err = "ERROR: second placeholder must be %2, got " + placeholders[1].str();
return false;
}
return true;
}
void LLModel::prompt(const std::string &prompt,
const std::string &promptTemplate,
void LLModel::prompt(const std::vector<Message> &messages,
std::function<MessageFrame(const Message &)> framingCallback,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
@ -78,45 +55,24 @@ void LLModel::prompt(const std::string &prompt,
promptCtx.tokens.resize(promptCtx.n_past);
m_tokenize_last_token = promptCtx.tokens.empty() ? -1 : promptCtx.tokens.back(); // not serialized
// parse the prompt template
std::vector<std::smatch> placeholders;
{
std::string err;
if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
responseCallback(-1, err);
std::cerr << err << "\n";
return;
}
}
const auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
// tokenize the user prompt
std::vector<Token> embd_inp;
if (placeholders.empty()) {
// this is unusual, but well-defined
std::cerr << __func__ << ": prompt template has no placeholder\n";
embd_inp = tokenize(promptCtx, promptTemplate, true);
} else {
// template: beginning of user prompt
const auto &phUser = placeholders[0];
std::string userPrefix(phUser.prefix());
if (!userPrefix.empty()) {
embd_inp = tokenize(promptCtx, userPrefix, true);
for (const Message &msg : messages) {
const MessageFrame msgFrame = framingCallback(msg);
if (!msgFrame.before.empty()) {
auto tokens = tokenize(promptCtx, msgFrame.before, true);
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
promptCtx.n_past += embd_inp.size();
}
// user input (shouldn't have special token processing)
auto tokens = tokenize(promptCtx, prompt, special);
// message content (shouldn't have special token processing)
auto tokens = tokenize(promptCtx, msg.content, special);
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
promptCtx.n_past += tokens.size();
// template: end of user prompt + start of assistant prompt
size_t start = phUser.position() + phUser.length();
size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
auto userToAsst = promptTemplate.substr(start, end - start);
if (!userToAsst.empty()) {
tokens = tokenize(promptCtx, userToAsst, true);
if (!msgFrame.after.empty()) {
tokens = tokenize(promptCtx, msgFrame.after, true);
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
promptCtx.n_past += tokens.size();
}
@ -128,6 +84,24 @@ void LLModel::prompt(const std::string &prompt,
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
return; // error
// Nothing more to be done if we're not being asked to predict any tokens
if (promptCtx.n_predict < 1)
return;
// If we're being asked to predict then we create an assistant message with framing strings
Message asstMsg;
asstMsg.role = "assistant";
MessageFrame asstMsgFrame = framingCallback(asstMsg);
// Tokenize and decode the prefixed assistant framing string if any
if (!asstMsgFrame.before.empty()) {
const auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
auto tokens = tokenize(promptCtx, asstMsgFrame.before, true);
promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, tokens))
return; // error
}
// decode the assistant's reply, either generated or spoofed
if (fakeReply == nullptr) {
generateResponse(responseCallback, allowContextShift, promptCtx);
@ -137,18 +111,13 @@ void LLModel::prompt(const std::string &prompt,
return; // error
}
// decode the rest of the prompt template
// template: end of assistant prompt
std::string asstSuffix;
if (placeholders.size() >= 2) {
size_t start = placeholders[1].position() + placeholders[1].length();
asstSuffix = promptTemplate.substr(start);
} else {
asstSuffix = "\n\n"; // default to a blank link, good for e.g. Alpaca
}
if (!asstSuffix.empty()) {
embd_inp = tokenize(promptCtx, asstSuffix, true);
decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp);
// Tokenize and decode the suffixed assistant framing string if any
if (!asstMsgFrame.after.empty()) {
const auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
auto tokens = tokenize(promptCtx, asstMsgFrame.after, true);
promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, tokens))
return; // error
}
}

View File

@ -86,8 +86,8 @@ size_t ChatAPI::restoreState(const uint8_t *src)
return 0;
}
void ChatAPI::prompt(const std::string &prompt,
const std::string &promptTemplate,
void ChatAPI::prompt(const std::vector<Message> &messages,
std::function<MessageFrame(const Message &)> framingCallback,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,

View File

@ -65,8 +65,8 @@ public:
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
void prompt(const std::string &prompt,
const std::string &promptTemplate,
void prompt(const std::vector<Message> &,
std::function<MessageFrame(const Message &)> framingCallback,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,