Add reverse prompt support for gptj too.

This commit is contained in:
Adam Treat 2023-05-05 11:16:24 -04:00
parent 06bb6960d4
commit d0d5d84e06

View File

@ -14,6 +14,7 @@
#include <iostream>
#include <unistd.h>
#include <sstream>
#include <unordered_set>
// default hparams (GPT-J 6B)
static const size_t MB = 1024*1024;
@ -968,6 +969,11 @@ void GPTJ::prompt(const std::string &prompt,
int p_instructFound = 0;
int r_instructFound = 0;
std::string cachedResponse;
std::vector<gpt_vocab::id> cachedTokens;
std::unordered_set<std::string> reversePrompts
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" };
// predict next tokens
int32_t totalPredictions = 0;
for (int i = 0; i < promptCtx.n_predict; i++) {
@ -1014,11 +1020,40 @@ void GPTJ::prompt(const std::string &prompt,
if (id == 50256 /*end of text*/)
goto stop_generating;
if (promptCtx.tokens.size() == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(id);
if (!responseCallback(id, d_ptr->vocab.id_to_token[id]))
const std::string str = d_ptr->vocab.id_to_token[id];
// Check if the provided str is part of our reverse prompts
bool foundPartialReversePrompt = false;
const std::string completed = cachedResponse + str;
if (reversePrompts.find(completed) != reversePrompts.end()) {
goto stop_generating;
}
// Check if it partially matches our reverse prompts and if so, cache
for (auto s : reversePrompts) {
if (s.compare(0, completed.size(), completed) == 0) {
foundPartialReversePrompt = true;
cachedResponse = completed;
break;
}
}
// Regardless the token gets added to our cache
cachedTokens.push_back(id);
// Continue if we have found a partial match
if (foundPartialReversePrompt)
continue;
// Empty the cache
for (auto t : cachedTokens) {
if (promptCtx.tokens.size() == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(t);
if (!responseCallback(t, d_ptr->vocab.id_to_token[t]))
goto stop_generating;
}
cachedTokens.clear();
}
stop_generating: