Mirror of https://github.com/nomic-ai/gpt4all.git
Add reverse prompts for llama models.
commit 82c1d08b33
parent 01accf9e33
@@ -16,6 +16,7 @@
 #include <unistd.h>
 #include <random>
 #include <thread>
+#include <unordered_set>
 
 struct LLamaPrivate {
     const std::string modelPath;
@@ -144,6 +145,11 @@ void LLamaModel::prompt(const std::string &prompt,
         i = batch_end;
     }
 
+    std::string cachedResponse;
+    std::vector<llama_token> cachedTokens;
+    std::unordered_set<std::string> reversePrompts
+        = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" };
+
     // predict next tokens
     int32_t totalPredictions = 0;
     for (int i = 0; i < promptCtx.n_predict; i++) {
@@ -175,11 +181,40 @@ void LLamaModel::prompt(const std::string &prompt,
         if (id == llama_token_eos())
             return;
 
-        if (promptCtx.tokens.size() == promptCtx.n_ctx)
-            promptCtx.tokens.erase(promptCtx.tokens.begin());
-        promptCtx.tokens.push_back(id);
-        if (!responseCallback(id, llama_token_to_str(d_ptr->ctx, id)))
+        const std::string str = llama_token_to_str(d_ptr->ctx, id);
+
+        // Check if the provided str is part of our reverse prompts
+        bool foundPartialReversePrompt = false;
+        const std::string completed = cachedResponse + str;
+        if (reversePrompts.find(completed) != reversePrompts.end()) {
             return;
+        }
+
+        // Check if it partially matches our reverse prompts and if so, cache
+        for (auto s : reversePrompts) {
+            if (s.compare(0, completed.size(), completed) == 0) {
+                foundPartialReversePrompt = true;
+                cachedResponse = completed;
+                break;
+            }
+        }
+
+        // Regardless the token gets added to our cache
+        cachedTokens.push_back(id);
+
+        // Continue if we have found a partial match
+        if (foundPartialReversePrompt)
+            continue;
+
+        // Empty the cache
+        for (auto t : cachedTokens) {
+            if (promptCtx.tokens.size() == promptCtx.n_ctx)
+                promptCtx.tokens.erase(promptCtx.tokens.begin());
+            promptCtx.tokens.push_back(t);
+            if (!responseCallback(t, llama_token_to_str(d_ptr->ctx, t)))
+                return;
+        }
+        cachedTokens.clear();
     }
 }
 
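
The hunk above withholds generated tokens while they could still turn into one of the reverse prompts, suppresses them entirely on a full match, and flushes them to the response callback once the partial match breaks. The standalone program below is a minimal sketch of that same prefix-match-and-flush idea, not code from the repository; the sample `pieces` stream and all variable names are invented for illustration.

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
    // Same stop strings as in the diff above.
    const std::unordered_set<std::string> reversePrompts
        = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" };

    // Hypothetical detokenized pieces as they might stream out of a model;
    // "###" followed by " Human" completes the reverse prompt "### Human".
    const std::vector<std::string> pieces = { "Sure", ",", " here", "###", " Human" };

    std::string cachedResponse;            // partial reverse-prompt text seen so far
    std::vector<std::string> cachedPieces; // pieces withheld from the consumer

    for (const std::string &piece : pieces) {
        const std::string completed = cachedResponse + piece;

        // Full match: a reverse prompt was generated, so stop emitting output.
        if (reversePrompts.find(completed) != reversePrompts.end()) {
            std::cout << "\n[stopped on reverse prompt: " << completed << "]\n";
            return 0;
        }

        // Partial match: withhold the piece and wait for more tokens.
        bool foundPartialReversePrompt = false;
        for (const std::string &rp : reversePrompts) {
            if (rp.compare(0, completed.size(), completed) == 0) {
                foundPartialReversePrompt = true;
                cachedResponse = completed;
                break;
            }
        }
        cachedPieces.push_back(piece);
        if (foundPartialReversePrompt)
            continue;

        // No match after all: flush everything that was withheld and reset.
        for (const std::string &p : cachedPieces)
            std::cout << p;
        cachedPieces.clear();
        cachedResponse.clear();
    }
    std::cout << '\n';
    return 0;
}

Built with any C++11 compiler, this prints the visible text up to, but not including, the reverse prompt ("Sure, here" for the sample stream), which mirrors what the callback-driven code in the commit is meant to achieve.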