2023-06-02 10:47:12 -04:00
|
|
|
#include "llmodel.h"
|
|
|
|
|
|
|
|
#include <cassert>
|
|
|
|
#include <iostream>
|
2024-02-21 15:45:32 -05:00
|
|
|
#include <regex>
|
2024-03-13 18:09:24 -04:00
|
|
|
#include <string>
|
2023-06-04 08:59:24 -04:00
|
|
|
#include <unordered_set>
|
2023-06-02 10:47:12 -04:00
|
|
|
|
2024-02-21 15:45:32 -05:00
|
|
|
// TODO(cebtenzzre): replace this with llama_kv_cache_seq_shift for llamamodel (GPT-J needs this as-is)
|
2023-06-02 10:47:12 -04:00
|
|
|
// Make room in a full context window: drop a slice of the oldest tokens
// (everything after the first `shouldAddBOS()` tokens, sized by
// promptCtx.contextErase) and re-evaluate what remains in batches of at most
// promptCtx.n_batch tokens.
//
// `recalculate(true)` is invoked after every successfully evaluated batch; a
// false return aborts the re-evaluation. `recalculate(false)` is always
// invoked exactly once at the end, whether or not the re-evaluation finished.
void LLModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate) {
    const int keepCount = shouldAddBOS();
    const int32_t discardCount = (promptCtx.n_ctx - keepCount) * promptCtx.contextErase;

    // Erase the first percentage of context from the tokens
    std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
    const auto eraseBegin = promptCtx.tokens.begin() + keepCount;
    promptCtx.tokens.erase(eraseBegin, eraseBegin + discardCount);

    // Evaluation restarts right after the kept tokens.
    promptCtx.n_past = keepCount;

    // Returns true only if every remaining token was re-evaluated.
    const auto reevaluate = [&]() -> bool {
        size_t pos = keepCount;
        while (pos < promptCtx.tokens.size()) {
            const size_t stop = std::min(pos + promptCtx.n_batch, promptCtx.tokens.size());
            std::vector<int32_t> chunk(promptCtx.tokens.begin() + pos, promptCtx.tokens.begin() + stop);

            assert(promptCtx.n_past + int32_t(chunk.size()) <= promptCtx.n_ctx);

            if (!evalTokens(promptCtx, chunk)) {
                std::cerr << "LLModel ERROR: Failed to process prompt\n";
                return false;
            }
            promptCtx.n_past += chunk.size();

            // Give the caller a chance to cancel between batches.
            if (!recalculate(true))
                return false;

            pos = stop;
        }
        return true;
    };

    if (reevaluate())
        assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));

    // Signal that recalculation is over, on success and on failure alike.
    recalculate(false);
}
|
2023-06-04 08:59:24 -04:00
|
|
|
|
2024-02-21 15:45:32 -05:00
|
|
|
// Locate the %1 / %2 placeholders in a prompt template.
//
// On return `placeholders` holds every match in order of appearance (its
// previous contents are discarded). The matches refer into `tmpl`, so `tmpl`
// must outlive them. A template is valid when it has at most two
// placeholders, the first (if any) is %1 and the second (if any) is %2.
//
// Returns true for a valid template and leaves `err` untouched; otherwise
// returns false and stores a human-readable message in `err`.
static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err) {
    // "%1" or "%2" that is not immediately followed by another digit
    static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))");

    placeholders.assign(std::sregex_iterator(tmpl.begin(), tmpl.end(), placeholderRegex),
                        std::sregex_iterator());

    const std::size_t found = placeholders.size();
    if (found > 2) {
        err = "ERROR: expected at most two placeholders, got " + std::to_string(found);
        return false;
    }

    // Verify that the placeholder at `slot`, when present, is spelled `wanted`.
    const auto slotMatches = [&](std::size_t slot, const char *wanted, const char *complaint) {
        if (slot >= found)
            return true;
        const std::string actual = placeholders[slot].str();
        if (actual == wanted)
            return true;
        err = complaint + actual;
        return false;
    };

    return slotMatches(0, "%1", "ERROR: first placeholder must be %1, got ")
        && slotMatches(1, "%2", "ERROR: second placeholder must be %2, got ");
}
|
|
|
|
|
2023-06-04 08:59:24 -04:00
|
|
|
void LLModel::prompt(const std::string &prompt,
|
2024-02-21 15:45:32 -05:00
|
|
|
const std::string &promptTemplate,
|
2023-06-04 08:59:24 -04:00
|
|
|
std::function<bool(int32_t)> promptCallback,
|
|
|
|
std::function<bool(int32_t, const std::string&)> responseCallback,
|
|
|
|
std::function<bool(bool)> recalculateCallback,
|
2024-02-21 15:45:32 -05:00
|
|
|
PromptContext &promptCtx,
|
|
|
|
bool special,
|
|
|
|
std::string *fakeReply)
|
2023-06-04 08:59:24 -04:00
|
|
|
{
|
|
|
|
if (!isModelLoaded()) {
|
2023-07-08 10:04:38 -04:00
|
|
|
std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
|
2023-06-04 08:59:24 -04:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
2023-07-09 11:32:51 -04:00
|
|
|
if (!supportsCompletion()) {
|
2024-02-21 15:45:32 -05:00
|
|
|
std::string errorMessage = "ERROR: this model does not support text completion or chat!";
|
2023-07-09 11:32:51 -04:00
|
|
|
responseCallback(-1, errorMessage);
|
2024-02-21 15:45:32 -05:00
|
|
|
std::cerr << implementation().modelType() << " " << errorMessage << "\n";
|
2023-07-09 11:32:51 -04:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
2024-02-21 15:45:32 -05:00
|
|
|
// parse the prompt template
|
|
|
|
std::vector<std::smatch> placeholders;
|
|
|
|
{
|
|
|
|
std::string err;
|
|
|
|
if (!parsePromptTemplate(promptTemplate, placeholders, err)) {
|
|
|
|
responseCallback(-1, err);
|
|
|
|
std::cerr << err << "\n";
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
|
2023-06-04 08:59:24 -04:00
|
|
|
|
2024-02-21 15:45:32 -05:00
|
|
|
// tokenize the user prompt
|
|
|
|
std::vector<Token> embd_inp;
|
|
|
|
if (placeholders.empty()) {
|
|
|
|
// this is unusual, but well-defined
|
|
|
|
std::cerr << __func__ << ": prompt template has no placeholder\n";
|
|
|
|
embd_inp = tokenize(promptCtx, promptTemplate, true);
|
|
|
|
} else {
|
|
|
|
// template: beginning of user prompt
|
|
|
|
const auto &phUser = placeholders[0];
|
|
|
|
std::string userPrefix(phUser.prefix());
|
|
|
|
if (!userPrefix.empty()) {
|
|
|
|
embd_inp = tokenize(promptCtx, userPrefix, true);
|
|
|
|
promptCtx.n_past += embd_inp.size();
|
|
|
|
}
|
|
|
|
|
|
|
|
// user input (shouldn't have special token processing)
|
|
|
|
auto tokens = tokenize(promptCtx, prompt, special);
|
|
|
|
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
|
|
|
|
promptCtx.n_past += tokens.size();
|
|
|
|
|
|
|
|
// template: end of user prompt + start of assistant prompt
|
|
|
|
size_t start = phUser.position() + phUser.length();
|
|
|
|
size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
|
|
|
|
auto userToAsst = promptTemplate.substr(start, end - start);
|
|
|
|
if (!userToAsst.empty()) {
|
|
|
|
tokens = tokenize(promptCtx, userToAsst, true);
|
|
|
|
embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
|
|
|
|
promptCtx.n_past += tokens.size();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
|
|
|
|
|
|
|
|
// decode the user prompt
|
|
|
|
decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
|
|
|
|
|
|
|
|
// decode the assistant's reply, either generated or spoofed
|
|
|
|
if (fakeReply == nullptr) {
|
|
|
|
generateResponse(responseCallback, recalculateCallback, promptCtx);
|
|
|
|
} else {
|
|
|
|
embd_inp = tokenize(promptCtx, *fakeReply, false);
|
|
|
|
decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
|
|
|
|
}
|
|
|
|
|
|
|
|
// decode the rest of the prompt template
|
2024-02-26 13:11:15 -05:00
|
|
|
// template: end of assistant prompt
|
|
|
|
std::string asstSuffix;
|
2024-02-21 15:45:32 -05:00
|
|
|
if (placeholders.size() >= 2) {
|
|
|
|
size_t start = placeholders[1].position() + placeholders[1].length();
|
2024-02-26 13:11:15 -05:00
|
|
|
asstSuffix = promptTemplate.substr(start);
|
|
|
|
} else {
|
|
|
|
asstSuffix = "\n\n"; // default to a blank link, good for e.g. Alpaca
|
|
|
|
}
|
|
|
|
if (!asstSuffix.empty()) {
|
|
|
|
embd_inp = tokenize(promptCtx, asstSuffix, true);
|
|
|
|
decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
|
2024-02-21 15:45:32 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
|
|
|
|
std::function<bool(int32_t, const std::string&)> responseCallback,
|
|
|
|
std::function<bool(bool)> recalculateCallback,
|
|
|
|
PromptContext &promptCtx,
|
|
|
|
std::vector<Token> embd_inp) {
|
2023-06-04 08:59:24 -04:00
|
|
|
// save the context size
|
|
|
|
promptCtx.n_ctx = contextLength();
|
|
|
|
|
|
|
|
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
|
|
|
|
responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
|
2023-07-09 11:00:20 -04:00
|
|
|
std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
|
|
|
|
" tokens and the context window is " << promptCtx.n_ctx << "!\n";
|
2023-06-04 08:59:24 -04:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
|
|
|
|
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);
|
2023-06-30 19:13:25 -04:00
|
|
|
promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
|
2023-06-04 08:59:24 -04:00
|
|
|
|
|
|
|
// process the prompt in batches
|
|
|
|
size_t i = 0;
|
|
|
|
while (i < embd_inp.size()) {
|
|
|
|
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
|
|
|
|
std::vector<Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
|
|
|
|
|
|
|
|
// Check if the context has run out...
|
|
|
|
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
|
|
|
|
recalculateContext(promptCtx, recalculateCallback);
|
|
|
|
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!evalTokens(promptCtx, batch)) {
|
2023-07-08 10:04:38 -04:00
|
|
|
std::cerr << implementation().modelType() << " ERROR: Failed to process prompt\n";
|
2023-06-04 08:59:24 -04:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
size_t tokens = batch_end - i;
|
|
|
|
for (size_t t = 0; t < tokens; ++t) {
|
|
|
|
if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
|
|
|
|
promptCtx.tokens.erase(promptCtx.tokens.begin());
|
|
|
|
promptCtx.tokens.push_back(batch.at(t));
|
2023-10-03 12:42:31 -04:00
|
|
|
promptCtx.n_past += 1;
|
2023-06-04 08:59:24 -04:00
|
|
|
if (!promptCallback(batch.at(t)))
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
i = batch_end;
|
|
|
|
}
|
2024-02-21 15:45:32 -05:00
|
|
|
}
|
2023-06-04 08:59:24 -04:00
|
|
|
|
2024-02-21 15:45:32 -05:00
|
|
|
void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
|
|
|
|
std::function<bool(bool)> recalculateCallback,
|
|
|
|
PromptContext &promptCtx) {
|
2023-06-04 08:59:24 -04:00
|
|
|
std::string cachedResponse;
|
|
|
|
std::vector<Token> cachedTokens;
|
|
|
|
std::unordered_set<std::string> reversePrompts
|
|
|
|
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" };
|
|
|
|
|
|
|
|
// predict next tokens
|
|
|
|
for (int i = 0; i < promptCtx.n_predict; i++) {
|
|
|
|
|
|
|
|
// sample next token
|
|
|
|
auto id = sampleToken(promptCtx);
|
|
|
|
|
|
|
|
// Check if the context has run out...
|
|
|
|
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
|
|
|
|
recalculateContext(promptCtx, recalculateCallback);
|
|
|
|
assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!evalTokens(promptCtx, { id })) {
|
2023-07-08 10:04:38 -04:00
|
|
|
std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
|
2023-06-04 08:59:24 -04:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// display text
|
|
|
|
for (const auto token : endTokens()) {
|
|
|
|
if (id == token) return;
|
|
|
|
}
|
|
|
|
|
2023-06-13 07:14:02 -04:00
|
|
|
const std::string str = tokenToString(id);
|
2023-06-04 08:59:24 -04:00
|
|
|
|
|
|
|
// Check if the provided str is part of our reverse prompts
|
|
|
|
bool foundPartialReversePrompt = false;
|
|
|
|
const std::string completed = cachedResponse + std::string(str);
|
|
|
|
if (reversePrompts.find(completed) != reversePrompts.end())
|
|
|
|
return;
|
|
|
|
|
|
|
|
// Check if it partially matches our reverse prompts and if so, cache
|
|
|
|
for (const auto& s : reversePrompts) {
|
|
|
|
if (s.compare(0, completed.size(), completed) == 0) {
|
|
|
|
foundPartialReversePrompt = true;
|
|
|
|
cachedResponse = completed;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Regardless the token gets added to our cache
|
|
|
|
cachedTokens.push_back(id);
|
|
|
|
|
|
|
|
// Continue if we have found a partial match
|
|
|
|
if (foundPartialReversePrompt)
|
|
|
|
continue;
|
|
|
|
|
|
|
|
// Empty the cache
|
|
|
|
for (auto t : cachedTokens) {
|
|
|
|
if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
|
|
|
|
promptCtx.tokens.erase(promptCtx.tokens.begin());
|
|
|
|
promptCtx.tokens.push_back(t);
|
2023-10-03 12:42:31 -04:00
|
|
|
promptCtx.n_past += 1;
|
2023-06-04 08:59:24 -04:00
|
|
|
//TODO: Conversion to std::string can be avoided here...
|
|
|
|
if (!responseCallback(t, std::string(tokenToString(t))))
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
cachedTokens.clear();
|
|
|
|
}
|
|
|
|
}
|
2023-07-09 11:32:51 -04:00
|
|
|
|
2024-03-13 18:09:24 -04:00
|
|
|
void LLModel::embed(
|
|
|
|
const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix, int dimensionality,
|
2024-04-12 16:00:39 -04:00
|
|
|
size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb
|
2024-03-13 18:09:24 -04:00
|
|
|
) {
|
|
|
|
(void)texts;
|
|
|
|
(void)embeddings;
|
|
|
|
(void)prefix;
|
|
|
|
(void)dimensionality;
|
2024-03-20 11:24:02 -04:00
|
|
|
(void)tokenCount;
|
2024-03-13 18:09:24 -04:00
|
|
|
(void)doMean;
|
|
|
|
(void)atlas;
|
2024-04-12 16:00:39 -04:00
|
|
|
(void)cancelCb;
|
2024-03-13 18:09:24 -04:00
|
|
|
throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
|
|
|
|
}
|
|
|
|
|
|
|
|
void LLModel::embed(
|
2024-03-20 11:24:02 -04:00
|
|
|
const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality, size_t *tokenCount,
|
|
|
|
bool doMean, bool atlas
|
2024-03-13 18:09:24 -04:00
|
|
|
) {
|
|
|
|
(void)texts;
|
|
|
|
(void)embeddings;
|
|
|
|
(void)isRetrieval;
|
|
|
|
(void)dimensionality;
|
2024-03-20 11:24:02 -04:00
|
|
|
(void)tokenCount;
|
2024-03-13 18:09:24 -04:00
|
|
|
(void)doMean;
|
|
|
|
(void)atlas;
|
|
|
|
throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
|
2023-07-09 11:32:51 -04:00
|
|
|
}
|