Implement repeat penalty for both llama and gptj in gui.

Repository: nomic-ai/gpt4all (https://github.com/nomic-ai/gpt4all.git)
Commit: a79bc4233c
Parent: a02b0c14ca

gptj.cpp (38 changed lines)

@@ -683,8 +683,9 @@ bool GPTJ::isModelLoaded() const
     return d_ptr->modelLoaded;
 }

-void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        PromptContext &promptCtx, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) {
+void GPTJ::prompt(const std::string &prompt,
+        std::function<bool(int32_t, const std::string&)> response,
+        PromptContext &promptCtx) {

     if (!isModelLoaded()) {
         std::cerr << "GPT-J ERROR: prompt won't work with an unloaded model!\n";
@@ -700,10 +701,11 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
     // tokenize the prompt
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt);

-    const int n_ctx = d_ptr->model.hparams.n_ctx;
+    // save the context size
+    promptCtx.n_ctx = d_ptr->model.hparams.n_ctx;

-    n_predict = std::min(n_predict, n_ctx - (int) embd_inp.size());
-    promptCtx.n_past = std::min(promptCtx.n_past, n_ctx);
+    promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
+    promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);

     // determine the required inference memory per token:
     static bool initialized = false;
@@ -719,13 +721,13 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
     size_t i = 0;
     const int64_t t_start_prompt_us = ggml_time_us();
     while (i < embd_inp.size()) {
-        size_t batch_end = std::min(i + n_batch, embd_inp.size());
+        size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
         std::vector<gpt_vocab::id> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);

         // Check if the context has run out...
-        if (promptCtx.n_past + batch.size() > n_ctx) {
+        if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
             // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, int(n_ctx - batch.size()));
+            promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
             std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
         }

@@ -736,7 +738,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
         // We pass a null string for each token to see if the user has asked us to stop...
         size_t tokens = batch_end - i;
         for (size_t t = 0; t < tokens; ++t)
-            if (!response(""))
+            if (!response(batch.at(t), ""))
                 return;
         promptCtx.n_past += batch.size();
         i = batch_end;
@@ -748,22 +750,28 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin

     // predict next tokens
     int32_t totalPredictions = 0;
-    for (int i = 0; i < n_predict; i++) {
+    for (int i = 0; i < promptCtx.n_predict; i++) {

         // sample next token
         const int n_vocab = d_ptr->model.hparams.n_vocab;
         gpt_vocab::id id = 0;
         {
             const int64_t t_start_sample_us = ggml_time_us();
-            id = gpt_sample_top_k_top_p(d_ptr->vocab, promptCtx.logits.data() + (promptCtx.logits.size() - n_vocab),
-                top_k, top_p, temp, d_ptr->rng);
+            id = gpt_sample_top_k_top_p(d_ptr->vocab,
+                promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
+                promptCtx.n_ctx,
+                promptCtx.logits,
+                promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
+                promptCtx.repeat_penalty,
+                d_ptr->rng);

             t_sample_us += ggml_time_us() - t_start_sample_us;
         }

         // Check if the context has run out...
-        if (promptCtx.n_past + 1 > n_ctx) {
+        if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
             // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, n_ctx - 1);
+            promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
             std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
         }

@@ -777,7 +785,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::strin
         promptCtx.n_past += 1;
         // display text
         ++totalPredictions;
-        if (id == 50256 /*end of text*/ || !response(d_ptr->vocab.id_to_token[id]))
+        if (id == 50256 /*end of text*/ || !response(id, d_ptr->vocab.id_to_token[id]))
             goto stop_generating;
     }
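
The response callback changes from bool(const std::string&) to bool(int32_t, const std::string&), so the GUI now receives the sampled token id alongside each piece of text (prompt tokens are reported with an empty string). A minimal caller-side sketch of a callback matching the new signature; the collector below is illustrative only, not part of this commit:

    // Sketch: collect token ids and generated text through the new callback type.
    // `history`, `text` and `stopFlag` are illustrative names.
    #include <cstdint>
    #include <functional>
    #include <string>
    #include <vector>

    std::function<bool(int32_t, const std::string&)> makeCollector(
            std::vector<int32_t> &history, std::string &text, const bool &stopFlag) {
        return [&history, &text, &stopFlag](int32_t token, const std::string &piece) {
            history.push_back(token); // every token is reported, including prompt tokens
            text.append(piece);       // empty for prompt tokens, model output otherwise
            return !stopFlag;         // returning false asks the backend to stop generating
        };
    }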

gptj.h (6 changed lines)

@@ -15,9 +15,9 @@ public:
     bool loadModel(const std::string &modelPath) override;
     bool loadModel(const std::string &modelPath, std::istream &fin) override;
     bool isModelLoaded() const override;
-    void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
-        float temp = 0.0f, int32_t n_batch = 9) override;
+    void prompt(const std::string &prompt,
+        std::function<bool(int32_t, const std::string&)> response,
+        PromptContext &ctx) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() override;


llamamodel.cpp

@@ -78,8 +78,9 @@ bool LLamaModel::isModelLoaded() const
     return d_ptr->modelLoaded;
 }

-void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        PromptContext &promptCtx, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) {
+void LLamaModel::prompt(const std::string &prompt,
+        std::function<bool(int32_t, const std::string&)> response,
+        PromptContext &promptCtx) {

     if (!isModelLoaded()) {
         std::cerr << "LLAMA ERROR: prompt won't work with an unloaded model!\n";
@@ -94,15 +95,17 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:

     // tokenize the prompt
     auto embd_inp = ::llama_tokenize(d_ptr->ctx, params.prompt, false);
-    const int n_ctx = llama_n_ctx(d_ptr->ctx);

-    if ((int) embd_inp.size() > n_ctx - 4) {
+    // save the context size
+    promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx);
+
+    if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
         std::cerr << "LLAMA ERROR: prompt is too long\n";
         return;
     }

-    n_predict = std::min(n_predict, n_ctx - (int) embd_inp.size());
-    promptCtx.n_past = std::min(promptCtx.n_past, n_ctx);
+    promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
+    promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);

     // number of tokens to keep when resetting context
     params.n_keep = (int)embd_inp.size();
@@ -111,13 +114,13 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:
     size_t i = 0;
     const int64_t t_start_prompt_us = ggml_time_us();
     while (i < embd_inp.size()) {
-        size_t batch_end = std::min(i + n_batch, embd_inp.size());
+        size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
         std::vector<llama_token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);

         // Check if the context has run out...
-        if (promptCtx.n_past + batch.size() > n_ctx) {
+        if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
             // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, int(n_ctx - batch.size()));
+            promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
             std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
         }

@@ -129,7 +132,7 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:
         // We pass a null string for each token to see if the user has asked us to stop...
         size_t tokens = batch_end - i;
         for (size_t t = 0; t < tokens; ++t)
-            if (!response(""))
+            if (!response(batch.at(t), ""))
                 return;
         promptCtx.n_past += batch.size();
         i = batch_end;
@@ -137,14 +140,17 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:

     // predict next tokens
     int32_t totalPredictions = 0;
-    for (int i = 0; i < n_predict; i++) {
+    for (int i = 0; i < promptCtx.n_predict; i++) {
         // sample next token
-        llama_token id = llama_sample_top_p_top_k(d_ptr->ctx, {}, 0, top_k, top_p, temp, 1.0f);
+        llama_token id = llama_sample_top_p_top_k(d_ptr->ctx,
+            promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.repeat_last_n,
+            promptCtx.repeat_last_n, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
+            promptCtx.repeat_penalty);

         // Check if the context has run out...
-        if (promptCtx.n_past + 1 > n_ctx) {
+        if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
             // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, n_ctx - 1);
+            promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
             std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
         }

@@ -156,7 +162,7 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std:
         promptCtx.n_past += 1;
         // display text
         ++totalPredictions;
-        if (id == llama_token_eos() || !response(llama_token_to_str(d_ptr->ctx, id)))
+        if (id == llama_token_eos() || !response(id, llama_token_to_str(d_ptr->ctx, id)))
             return;
     }
 }
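
For the LLaMA backend the penalty window passed to llama_sample_top_p_top_k is the tail of the context's token buffer: promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.repeat_last_n, i.e. the last repeat_last_n entries of a buffer that llm.cpp caps at n_ctx tokens (see the handleResponse change below). A sketch of that pointer arithmetic on a plain vector, assuming the buffer already holds n_ctx entries; with fewer entries the window would start before the buffer:

    // Sketch: the last `repeat_last_n` slots of an n_ctx-sized token buffer.
    #include <cstdint>
    #include <vector>

    const int32_t *penaltyWindow(const std::vector<int32_t> &tokens,
                                 int32_t n_ctx, int32_t repeat_last_n) {
        // assumes tokens.size() == n_ctx and repeat_last_n <= n_ctx
        return tokens.data() + n_ctx - repeat_last_n;
    }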

llamamodel.h

@@ -15,9 +15,9 @@ public:
     bool loadModel(const std::string &modelPath) override;
     bool loadModel(const std::string &modelPath, std::istream &fin) override;
     bool isModelLoaded() const override;
-    void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
-        float temp = 0.0f, int32_t n_batch = 9) override;
+    void prompt(const std::string &prompt,
+        std::function<bool(int32_t, const std::string&)> response,
+        PromptContext &ctx) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() override;


llm.cpp (20 changed lines)

@@ -124,6 +124,7 @@ void LLMObject::regenerateResponse()
     s_ctx.n_past = std::max(0, s_ctx.n_past);
     // FIXME: This does not seem to be needed in my testing and llama models don't to it. Remove?
     s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end());
+    s_ctx.tokens.erase(s_ctx.tokens.end() -= m_responseTokens, s_ctx.tokens.end());
     m_responseTokens = 0;
     m_responseLogits = 0;
     m_response = std::string();
@@ -243,12 +244,20 @@ QList<QString> LLMObject::modelList() const
     return list;
 }

-bool LLMObject::handleResponse(const std::string &response)
+bool LLMObject::handleResponse(int32_t token, const std::string &response)
 {
 #if 0
     printf("%s", response.c_str());
     fflush(stdout);
 #endif

+    // Save the token to our prompt ctxt
+    if (s_ctx.tokens.size() == s_ctx.n_ctx)
+        s_ctx.tokens.erase(s_ctx.tokens.begin());
+    s_ctx.tokens.push_back(token);
+
+    // m_responseTokens and m_responseLogits are related to last prompt/response not
+    // the entire context window which we can reset on regenerate prompt
     ++m_responseTokens;
     if (!response.empty()) {
         m_response.append(response);
@@ -271,10 +280,15 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
     QString instructPrompt = prompt_template.arg(prompt);

     m_stopGenerating = false;
-    auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1);
+    auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1, std::placeholders::_2);
     emit responseStarted();
     qint32 logitsBefore = s_ctx.logits.size();
-    m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx, n_predict, top_k, top_p, temp, n_batch);
+    s_ctx.n_predict = n_predict;
+    s_ctx.top_k = top_k;
+    s_ctx.top_p = top_p;
+    s_ctx.temp = temp;
+    s_ctx.n_batch = n_batch;
+    m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx);
     m_responseLogits += s_ctx.logits.size() - logitsBefore;
     std::string trimmed = trim_whitespace(m_response);
     if (trimmed != m_response) {
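
handleResponse now records every streamed token into s_ctx.tokens and drops the oldest entry once the buffer reaches n_ctx, so the context never tracks more than one window of history; regenerateResponse likewise erases the tokens belonging to the last response. A standalone sketch of the bounded-history update, with an illustrative struct standing in for the PromptContext fields used here:

    // Sketch: bounded token history as maintained in handleResponse.
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    struct PromptHistory {           // illustrative stand-in for the PromptContext fields
        std::vector<int32_t> tokens;
        int32_t n_ctx = 0;
    };

    void recordToken(PromptHistory &ctx, int32_t token) {
        if (ctx.tokens.size() == (size_t)ctx.n_ctx)
            ctx.tokens.erase(ctx.tokens.begin()); // window full: drop the oldest token
        ctx.tokens.push_back(token);              // newest token goes at the back
    }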

llm.h (2 changed lines)

@@ -50,7 +50,7 @@ Q_SIGNALS:

 private:
     bool loadModelPrivate(const QString &modelName);
-    bool handleResponse(const std::string &response);
+    bool handleResponse(int32_t token, const std::string &response);

 private:
     LLModel *m_llmodel;

llmodel.h (20 changed lines)

@@ -14,12 +14,22 @@ public:
     virtual bool loadModel(const std::string &modelPath, std::istream &fin) = 0;
     virtual bool isModelLoaded() const = 0;
     struct PromptContext {
-        std::vector<float> logits;
-        int32_t n_past = 0; // number of tokens in past conversation
+        std::vector<float> logits;   // logits of current context
+        std::vector<int32_t> tokens; // current tokens in the context window
+        int32_t n_past = 0;          // number of tokens in past conversation
+        int32_t n_ctx = 0;           // number of tokens possible in context window
+        int32_t n_predict = 200;
+        int32_t top_k = 40;
+        float   top_p = 0.9f;
+        float   temp = 0.9f;
+        int32_t n_batch = 9;
+        float   repeat_penalty = 1.10f;
+        int32_t repeat_last_n = 64;  // last n tokens to penalize
+
     };
-    virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
-        float temp = 0.9f, int32_t n_batch = 9) = 0;
+    virtual void prompt(const std::string &prompt,
+        std::function<bool(int32_t, const std::string&)> response,
+        PromptContext &ctx) = 0;
     virtual void setThreadCount(int32_t n_threads) {}
     virtual int32_t threadCount() { return 1; }
 };
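
With the sampling parameters folded into PromptContext, a caller configures the struct once and passes it to prompt() together with the two-argument callback. A minimal usage sketch against the new interface; the parameter values, include path and helper names are illustrative, and the model is assumed to be already loaded:

    // Sketch: driving the new PromptContext-based prompt() interface.
    #include <cstdint>
    #include <functional>
    #include <string>
    #include "llmodel.h" // declares LLModel and LLModel::PromptContext

    void runPrompt(LLModel *model,
                   std::function<bool(int32_t, const std::string&)> onToken) {
        LLModel::PromptContext ctx;
        ctx.n_predict      = 256;   // tokens to generate
        ctx.top_k          = 40;
        ctx.top_p          = 0.9f;
        ctx.temp           = 0.7f;
        ctx.n_batch        = 9;     // prompt tokens fed to the model per batch
        ctx.repeat_penalty = 1.10f; // > 1.0 discourages repeating recent tokens
        ctx.repeat_last_n  = 64;    // how many recent tokens the penalty considers
        model->prompt("Write a haiku about spring.", onToken, ctx);
    }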

utils.cpp (23 changed lines)

@@ -178,20 +178,37 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {

 gpt_vocab::id gpt_sample_top_k_top_p(
         const gpt_vocab & vocab,
-        const float * logits,
+        const int32_t * last_n_tokens_data,
+        int last_n_tokens_size,
+        const std::vector<float> logits,
         int top_k,
         double top_p,
         double temp,
+        float repeat_penalty,
         std::mt19937 & rng) {
     int n_logits = vocab.id_to_token.size();

+    const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
+    const auto * plogits = logits.data() + logits.size() - n_logits;
+
     std::vector<std::pair<double, gpt_vocab::id>> logits_id;
     logits_id.reserve(n_logits);

     {
-        const double scale = 1.0/temp;
+        const float scale = 1.0f/temp;
         for (int i = 0; i < n_logits; ++i) {
-            logits_id.push_back(std::make_pair(logits[i]*scale, i));
+            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
+            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
+            if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
+                // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
+                if (plogits[i] < 0.0f) {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
+                } else {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
+                }
+            } else {
+                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
+            }
         }
     }

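
The GPT-J sampler now applies the CTRL-style repetition penalty (https://arxiv.org/abs/1909.05858) while building the scored logit list: if a token id appears among the last_n_tokens, its logit is divided by repeat_penalty when positive and multiplied by it when negative, so a penalized token always becomes less likely. A standalone sketch of just that scoring rule, isolated from the temperature scaling and top-k/top-p steps; names are illustrative:

    // Sketch: CTRL-style repetition penalty applied to raw logits.
    #include <algorithm>
    #include <cstdint>
    #include <vector>

    std::vector<float> penalizeRepeats(const std::vector<float> &logits,
                                       const std::vector<int32_t> &last_n_tokens,
                                       float repeat_penalty) {
        std::vector<float> out = logits;
        for (int32_t id = 0; id < (int32_t)out.size(); ++id) {
            bool seen = std::find(last_n_tokens.begin(), last_n_tokens.end(), id)
                        != last_n_tokens.end();
            if (!seen)
                continue;                  // not recently used: leave the logit alone
            if (out[id] < 0.0f)
                out[id] *= repeat_penalty; // negative logit: multiply to push it further down
            else
                out[id] /= repeat_penalty; // positive logit: divide to shrink it
        }
        return out;
    }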

utils.h (6 changed lines)

@@ -72,12 +72,14 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
 // - from them, consider only the top tokens with cumulative probability > P
 //
 // TODO: not sure if this implementation is correct
-// TODO: temperature is not implemented
 //
 gpt_vocab::id gpt_sample_top_k_top_p(
         const gpt_vocab & vocab,
-        const float * logits,
+        const int32_t * last_n_tokens_data,
+        int last_n_tokens_size,
+        const std::vector<float> logits,
         int top_k,
         double top_p,
         double temp,
+        float repeat_penalty,
         std::mt19937 & rng);