Add reverse prompts for llama models.

Adam Treat 2023-05-03 11:58:26 -04:00
parent 01accf9e33
commit 82c1d08b33


@@ -16,6 +16,7 @@
 #include <unistd.h>
 #include <random>
 #include <thread>
+#include <unordered_set>
 
 struct LLamaPrivate {
     const std::string modelPath;
@@ -144,6 +145,11 @@ void LLamaModel::prompt(const std::string &prompt,
         i = batch_end;
     }
 
+    std::string cachedResponse;
+    std::vector<llama_token> cachedTokens;
+    std::unordered_set<std::string> reversePrompts
+        = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" };
+
     // predict next tokens
     int32_t totalPredictions = 0;
     for (int i = 0; i < promptCtx.n_predict; i++) {
@@ -175,11 +181,40 @@ void LLamaModel::prompt(const std::string &prompt,
         if (id == llama_token_eos())
             return;
 
-        if (promptCtx.tokens.size() == promptCtx.n_ctx)
-            promptCtx.tokens.erase(promptCtx.tokens.begin());
-        promptCtx.tokens.push_back(id);
-        if (!responseCallback(id, llama_token_to_str(d_ptr->ctx, id)))
+        const std::string str = llama_token_to_str(d_ptr->ctx, id);
+
+        // Check if the provided str is part of our reverse prompts
+        bool foundPartialReversePrompt = false;
+        const std::string completed = cachedResponse + str;
+        if (reversePrompts.find(completed) != reversePrompts.end()) {
             return;
+        }
+
+        // Check if it partially matches our reverse prompts and if so, cache
+        for (auto s : reversePrompts) {
+            if (s.compare(0, completed.size(), completed) == 0) {
+                foundPartialReversePrompt = true;
+                cachedResponse = completed;
+                break;
+            }
+        }
+
+        // Regardless the token gets added to our cache
+        cachedTokens.push_back(id);
+
+        // Continue if we have found a partial match
+        if (foundPartialReversePrompt)
+            continue;
+
+        // Empty the cache
+        for (auto t : cachedTokens) {
+            if (promptCtx.tokens.size() == promptCtx.n_ctx)
+                promptCtx.tokens.erase(promptCtx.tokens.begin());
+            promptCtx.tokens.push_back(t);
+            if (!responseCallback(t, llama_token_to_str(d_ptr->ctx, t)))
+                return;
+        }
+        cachedTokens.clear();
     }
 }
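
The buffering idea in the last hunk can be shown in isolation. The sketch below is not part of the commit and assumes nothing from llama.cpp: tokens are plain strings instead of llama_token ids, and feedToken and emit are hypothetical names standing in for the generation loop and responseCallback.

// Standalone sketch (not part of the commit): buffer tokens that form a prefix
// of a reverse prompt, drop them on a full match, flush them otherwise.
#include <functional>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Returns false once a full reverse prompt has been matched, i.e. generation should stop.
static bool feedToken(const std::string &str,
                      const std::unordered_set<std::string> &reversePrompts,
                      std::string &cachedResponse,
                      std::vector<std::string> &cachedTokens,
                      const std::function<bool(const std::string &)> &emit)
{
    const std::string completed = cachedResponse + str;

    // Full reverse prompt: swallow the buffered tokens and stop.
    if (reversePrompts.count(completed))
        return false;

    cachedTokens.push_back(str);

    // Partial match: hold the buffered tokens back and wait for more input.
    for (const auto &s : reversePrompts) {
        if (s.compare(0, completed.size(), completed) == 0) {
            cachedResponse = completed;
            return true;
        }
    }

    // No match: flush everything that was held back.
    for (const auto &t : cachedTokens)
        if (!emit(t))
            return false;
    cachedTokens.clear();
    cachedResponse.clear();
    return true;
}

int main()
{
    const std::unordered_set<std::string> reversePrompts = { "### Human" };
    std::string cachedResponse;
    std::vector<std::string> cachedTokens;
    const auto emit = [](const std::string &t) { std::cout << t; return true; };

    // "###" is cached as a partial match; " Human" completes it, so neither is printed.
    const std::vector<std::string> stream = { "Hello", "!", " ", "###", " Human", " ignored" };
    for (const auto &t : stream)
        if (!feedToken(t, reversePrompts, cachedResponse, cachedTokens, emit))
            break;
    std::cout << std::endl;   // prints "Hello! "
}

Comparing whole decoded strings rather than token ids keeps the matching simple, at the cost of holding back any output that happens to share a prefix with a reverse prompt until the match either completes or breaks.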