Check for "### Prompt:" or "### Response:" and stop generating when either appears, and modify the default prompt template a little bit.

Adam Treat 2023-04-16 11:12:22 -04:00
parent d4767478fc
commit 185dc2460e
2 changed files with 54 additions and 14 deletions
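The change in GPTJ::prompt implements incremental stop-sequence matching: each newly sampled token id is compared against the tokenizations of "### Prompt:" and "### Response:", and generation halts as soon as either sequence has been produced in full. Below is a minimal sketch of that matching in plain standard C++; the StopMatcher type and all names in it are illustrative, not part of the gpt4all sources.

// Minimal sketch of incremental stop-sequence matching over a stream of token ids.
// Token ids are plain ints here; StopMatcher and feed() are illustrative names.
#include <cstdio>
#include <vector>

struct StopMatcher {
    std::vector<int> sequence; // tokenized stop string, e.g. "### Prompt:"
    size_t matched = 0;        // how many leading tokens of the sequence currently match

    // Feed one generated token; returns true once the whole stop sequence was produced.
    bool feed(int id) {
        if (sequence[matched] == id) {
            if (++matched == sequence.size())
                return true;   // full match: the caller should stop generating
        } else {
            matched = 0;       // mismatch: start matching from the beginning again
        }
        return false;
    }
};

int main() {
    StopMatcher promptStop{{10, 20, 30}};                  // stand-in token ids for "### Prompt:"
    const std::vector<int> generated{5, 10, 20, 7, 10, 20, 30};
    for (int id : generated) {
        if (promptStop.feed(id)) {
            std::fprintf(stderr, "stop sequence generated, halting\n");
            break;
        }
    }
    return 0;
}

Like the code in the commit, a mismatch simply resets the counter rather than doing a full KMP-style match with a failure function, which is usually adequate for short stop strings.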


@@ -691,9 +691,13 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
// determine the required inference memory per token:
static bool initialized = false;
static std::vector<gpt_vocab::id> p_instruct;
static std::vector<gpt_vocab::id> r_instruct;
size_t mem_per_token = 0;
if (!initialized) {
gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, ctx.logits, mem_per_token);
p_instruct = ::gpt_tokenize(d_ptr->vocab, "### Prompt:");
r_instruct = ::gpt_tokenize(d_ptr->vocab, "### Response:");
initialized = true;
}
@@ -717,6 +721,11 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
}
t_prompt_us += ggml_time_us() - t_start_prompt_us;
int p_instructFound = 0;
int r_instructFound = 0;
std::vector<gpt_vocab::id> cachedTokens;
// predict next tokens
int32_t totalPredictions = 0;
for (int i = 0; i < n_predict; i++) {
@@ -736,15 +745,46 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
std::cerr << "GPT-J ERROR: Failed to predict next token\n";
return;
}
t_predict_us += ggml_time_us() - t_start_predict_us;
ctx.n_past += 1;
// display text
++totalPredictions;
if (id == 50256 /*end of text*/ || !response(d_ptr->vocab.id_to_token[id]))
break;
cachedTokens.emplace_back(id);
// Check if this token is next token for p_instruct or r_instruct
if (p_instruct.at(p_instructFound) == id) {
++p_instructFound;
if (p_instructFound == p_instruct.size()) {
fprintf(stderr, "Warning: Tried to generate \"### Prompt:\" stopping.\n");
fflush(stderr);
goto stop_generating;
}
continue;
} else
p_instructFound = 0;
if (r_instruct.at(r_instructFound) == id) {
++r_instructFound;
if (r_instructFound == r_instruct.size()) {
fprintf(stderr, "Warning: Tried to generate \"### Response:\" stopping.\n");
fflush(stderr);
goto stop_generating;
}
continue;
} else
r_instructFound = 0;
t_predict_us += ggml_time_us() - t_start_predict_us;
for (int j = 0; j < cachedTokens.size(); ++j) {
gpt_vocab::id cachedToken = cachedTokens.at(j);
ctx.n_past += 1;
// display text
++totalPredictions;
if (id == 50256 /*end of text*/ || !response(d_ptr->vocab.id_to_token[cachedToken]))
goto stop_generating;
}
cachedTokens.clear();
}
stop_generating:
#if 0
// report timing
{
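A detail worth noting in the hunk above: tokens that extend a partial match are pushed into cachedTokens instead of being streamed to the response callback, so a half-generated "### Prompt:" is never shown to the user; the cache is only flushed once the match breaks, and is discarded entirely when the stop sequence completes. A hedged sketch of that buffer-then-flush behaviour, again with illustrative names and plain standard C++ rather than the gpt4all types:

// Sketch: buffer tokens that might begin a stop sequence, flush them to the
// consumer when the match breaks, drop them when the stop sequence completes.
// Names and the callback shape are illustrative, not the gpt4all-chat API.
#include <functional>
#include <vector>

// Returns false when generation should stop (stop sequence hit or consumer declined).
bool emitWithStopCheck(int id,
                       const std::vector<int> &stop,
                       size_t &matched,
                       std::vector<int> &cache,
                       const std::function<bool(int)> &emit) {
    cache.push_back(id);
    if (stop[matched] == id) {
        if (++matched == stop.size()) {
            cache.clear();          // the whole stop sequence was generated: drop it
            return false;           // and tell the caller to stop
        }
        return true;                // partial match: keep the tokens buffered
    }
    matched = 0;                    // match broken: release everything buffered so far
    for (int cached : cache)
        if (!emit(cached))
            return false;           // the consumer asked to stop
    cache.clear();
    return true;
}

The commit applies the same idea to both "### Prompt:" and "### Response:", tracking a separate match counter for each.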


@@ -59,12 +59,13 @@ Window {
property int topK: 40
property int maxLength: 4096
property int promptBatchSize: 9
property string promptTemplate: "Below is a prompt for either a task to complete or a piece of conversation. Decide which and write an appropriate response to the prompt.
property string defaultPromptTemplate: "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
### Prompt:
%1
### Response:
"
### Response:\n"
property string promptTemplate: ""
function restoreDefaults() {
temperature = 0.9;
@@ -72,12 +73,11 @@ Window {
topK = 40;
maxLength = 4096;
promptBatchSize = 9;
promptTemplate = "Below is a prompt for either a task to complete or a piece of conversation. Decide which and write an appropriate response to the prompt.
promptTemplate = defaultPromptTemplate;
}
### Prompt:
%1
### Response:
";
Component.onCompleted: {
promptTemplate = defaultPromptTemplate;
}
GridLayout {
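On the UI side, the new defaultPromptTemplate keeps the %1 placeholder, so the user's text is still substituted into the template before it reaches the model; restoreDefaults() and Component.onCompleted both simply reset promptTemplate to that default. A hedged illustration of the %1 substitution using Qt's QString::arg; the function name and wiring are illustrative, not the actual gpt4all-chat code path.

// Illustrative only: fills the %1 placeholder of the new default template with
// the user's prompt, the way a Qt string template is normally expanded.
#include <QString>

QString applyPromptTemplate(const QString &userPrompt) {
    const QString defaultPromptTemplate = QStringLiteral(
        "The prompt below is a question to answer, a task to complete, or a conversation "
        "to respond to; decide which and write an appropriate response.\n"
        "### Prompt:\n"
        "%1\n"
        "### Response:\n");
    return defaultPromptTemplate.arg(userPrompt); // %1 -> userPrompt
}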