diff --git a/gptj.cpp b/gptj.cpp
index 07b5b3cc..e4ceacfe 100644
--- a/gptj.cpp
+++ b/gptj.cpp
@@ -683,8 +683,9 @@ bool GPTJ::isModelLoaded() const
     return d_ptr->modelLoaded;
 }
 
-void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        PromptContext &promptCtx, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) {
+void GPTJ::prompt(const std::string &prompt,
+        std::function<bool(int32_t, const std::string&)> response,
+        PromptContext &promptCtx) {
 
     if (!isModelLoaded()) {
         std::cerr << "GPT-J ERROR: prompt won't work with an unloaded model!\n";
@@ -700,10 +701,11 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt);
 
-    const int n_ctx = d_ptr->model.hparams.n_ctx;
+    // save the context size
+    promptCtx.n_ctx = d_ptr->model.hparams.n_ctx;
 
-    n_predict = std::min(n_predict, n_ctx - (int) embd_inp.size());
-    promptCtx.n_past = std::min(promptCtx.n_past, n_ctx);
+    promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
+    promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);
 
     // determine the required inference memory per token:
     static bool initialized = false;
@@ -719,13 +721,13 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
         std::vector<gpt_vocab::id> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
 
         // Check if the context has run out...
-        if (promptCtx.n_past + batch.size() > n_ctx) {
+        if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
             // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, int(n_ctx - batch.size()));
+            promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
             std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
         }
 
@@ -736,7 +738,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
         const int n_vocab = d_ptr->model.hparams.n_vocab;
         gpt_vocab::id id = 0;
 
         {
             const int64_t t_start_sample_us = ggml_time_us();
-            id = gpt_sample_top_k_top_p(d_ptr->vocab, promptCtx.logits.data() + (promptCtx.logits.size() - n_vocab),
-                    top_k, top_p, temp, d_ptr->rng);
+            id = gpt_sample_top_k_top_p(d_ptr->vocab,
+                    promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
+                    promptCtx.n_ctx,
+                    promptCtx.logits,
+                    promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
+                    promptCtx.repeat_penalty,
+                    d_ptr->rng);
+
             t_sample_us += ggml_time_us() - t_start_sample_us;
         }
 
         // Check if the context has run out...
-        if (promptCtx.n_past + 1 > n_ctx) {
+        if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
             // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, n_ctx - 1);
+            promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
             std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
         }
 
@@ -777,7 +785,7 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        if (id == 50256 /*end of text*/ || !response(d_ptr->vocab.id_to_token[id]))
+        if (id == 50256 /*end of text*/ || !response(id, d_ptr->vocab.id_to_token[id]))
             goto stop_generating;
     }
 
diff --git a/gptj.h b/gptj.h
index 72fc4109..a6a0b8dc 100644
--- a/gptj.h
+++ b/gptj.h
@@ -15,9 +15,9 @@ public:
     bool loadModel(const std::string &modelPath) override;
     bool loadModel(const std::string &modelPath, std::istream &fin) override;
     bool isModelLoaded() const override;
-    void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
-        float temp = 0.0f, int32_t n_batch = 9) override;
+    void prompt(const std::string &prompt,
+        std::function<bool(int32_t, const std::string&)> response,
+        PromptContext &ctx) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() override;
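The callback type changes in both files above: the handler now receives the sampled token id alongside its decoded string. A minimal sketch of a handler matching the new `std::function<bool(int32_t, const std::string&)>` contract (illustrative only; `onToken` is not part of the patch):

```cpp
#include <cstdint>
#include <iostream>
#include <string>

// A response handler with the new two-argument shape. Returning false asks the
// backend to stop generating (see the `!response(...)` checks in gptj.cpp).
bool onToken(int32_t token, const std::string &piece) {
    std::cout << piece << std::flush; // stream the decoded text as it arrives
    (void)token;                      // the raw id is now available for bookkeeping
    return true;
}
```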
diff --git a/llamamodel.cpp b/llamamodel.cpp
index 561ec5c8..693c05ea 100644
--- a/llamamodel.cpp
+++ b/llamamodel.cpp
@@ -78,8 +78,9 @@ bool LLamaModel::isModelLoaded() const
     return d_ptr->modelLoaded;
 }
 
-void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        PromptContext &promptCtx, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch) {
+void LLamaModel::prompt(const std::string &prompt,
+        std::function<bool(int32_t, const std::string&)> response,
+        PromptContext &promptCtx) {
 
     if (!isModelLoaded()) {
         std::cerr << "LLAMA ERROR: prompt won't work with an unloaded model!\n";
@@ -94,15 +95,17 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
     auto embd_inp = ::llama_tokenize(d_ptr->ctx, params.prompt, false);
 
-    const int n_ctx = llama_n_ctx(d_ptr->ctx);
-    if ((int) embd_inp.size() > n_ctx - 4) {
+    // save the context size
+    promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx);
+
+    if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
         std::cerr << "LLAMA ERROR: prompt is too long\n";
         return;
     }
 
-    n_predict = std::min(n_predict, n_ctx - (int) embd_inp.size());
-    promptCtx.n_past = std::min(promptCtx.n_past, n_ctx);
+    promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
+    promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);
 
     // number of tokens to keep when resetting context
     params.n_keep = (int)embd_inp.size();
@@ -111,13 +114,13 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
         std::vector<llama_token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
 
         // Check if the context has run out...
-        if (promptCtx.n_past + batch.size() > n_ctx) {
+        if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
             // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, int(n_ctx - batch.size()));
+            promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
             std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
         }
 
@@ -129,7 +132,7 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        llama_token id = llama_sample_top_p_top_k(d_ptr->ctx, {}, 0, top_k, top_p, temp, 1.0f);
+        llama_token id = llama_sample_top_p_top_k(d_ptr->ctx,
+            promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.repeat_last_n,
+            promptCtx.repeat_last_n, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
+            promptCtx.repeat_penalty);
 
         // Check if the context has run out...
-        if (promptCtx.n_past + 1 > n_ctx) {
+        if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
             // FIXME: will produce gibberish after this
-            promptCtx.n_past = std::min(promptCtx.n_past, n_ctx - 1);
+            promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
             std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
         }
 
@@ -156,7 +162,7 @@ void LLamaModel::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        if (id == llama_token_eos() || !response(llama_token_to_str(d_ptr->ctx, id)))
+        if (id == llama_token_eos() || !response(id, llama_token_to_str(d_ptr->ctx, id)))
             return;
     }
 }
 
diff --git a/llamamodel.h b/llamamodel.h
index 9ed73d6d..57eb4194 100644
--- a/llamamodel.h
+++ b/llamamodel.h
@@ -15,9 +15,9 @@ public:
     bool loadModel(const std::string &modelPath) override;
     bool loadModel(const std::string &modelPath, std::istream &fin) override;
     bool isModelLoaded() const override;
-    void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
-        float temp = 0.0f, int32_t n_batch = 9) override;
+    void prompt(const std::string &prompt,
+        std::function<bool(int32_t, const std::string&)> response,
+        PromptContext &ctx) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() override;
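The pointer arithmetic handed to `llama_sample_top_p_top_k` assumes `promptCtx.tokens` has been filled out to `n_ctx` entries, so the most recent `repeat_last_n` ids sit at the tail of the buffer. A standalone sketch of that slice (not part of the patch; `penaltyWindow` is an illustrative name):

```cpp
#include <cstdint>
#include <vector>

// Start of the last `repeat_last_n` token ids, assuming the window is kept at
// `n_ctx` entries so the tail of the buffer holds the newest history.
const int32_t *penaltyWindow(const std::vector<int32_t> &tokens,
                             int32_t n_ctx, int32_t repeat_last_n) {
    return tokens.data() + n_ctx - repeat_last_n;
}
```

For comparison, the GPT-J call above passes `tokens.data() + n_ctx - n_ctx` with a size of `n_ctx`, i.e. it penalizes over the entire window rather than only the last `repeat_last_n` tokens.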
diff --git a/llm.cpp b/llm.cpp
index 59fdeab4..332b1e85 100644
--- a/llm.cpp
+++ b/llm.cpp
@@ -124,6 +124,7 @@ void LLMObject::regenerateResponse()
     s_ctx.n_past = std::max(0, s_ctx.n_past);
     // FIXME: This does not seem to be needed in my testing and llama models don't to it. Remove?
     s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end());
+    s_ctx.tokens.erase(s_ctx.tokens.end() -= m_responseTokens, s_ctx.tokens.end());
     m_responseTokens = 0;
     m_responseLogits = 0;
     m_response = std::string();
@@ -243,12 +244,20 @@ QList<QString> LLMObject::modelList() const
     return list;
 }
 
-bool LLMObject::handleResponse(const std::string &response)
+bool LLMObject::handleResponse(int32_t token, const std::string &response)
 {
 #if 0
     printf("%s", response.c_str());
     fflush(stdout);
 #endif
+
+    // Save the token to our prompt ctxt
+    if (s_ctx.tokens.size() == s_ctx.n_ctx)
+        s_ctx.tokens.erase(s_ctx.tokens.begin());
+    s_ctx.tokens.push_back(token);
+
+    // m_responseTokens and m_responseLogits are related to last prompt/response not
+    // the entire context window which we can reset on regenerate prompt
     ++m_responseTokens;
     if (!response.empty()) {
         m_response.append(response);
@@ -271,10 +280,15 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
     QString instructPrompt = prompt_template.arg(prompt);
 
     m_stopGenerating = false;
-    auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1);
+    auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1, std::placeholders::_2);
     emit responseStarted();
     qint32 logitsBefore = s_ctx.logits.size();
-    m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx, n_predict, top_k, top_p, temp, n_batch);
+    s_ctx.n_predict = n_predict;
+    s_ctx.top_k = top_k;
+    s_ctx.top_p = top_p;
+    s_ctx.temp = temp;
+    s_ctx.n_batch = n_batch;
+    m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx);
     m_responseLogits += s_ctx.logits.size() - logitsBefore;
     std::string trimmed = trim_whitespace(m_response);
     if (trimmed != m_response) {
 
diff --git a/llm.h b/llm.h
index 93987e14..bf95a348 100644
--- a/llm.h
+++ b/llm.h
@@ -50,7 +50,7 @@ Q_SIGNALS:
 
 private:
     bool loadModelPrivate(const QString &modelName);
-    bool handleResponse(const std::string &response);
+    bool handleResponse(int32_t token, const std::string &response);
 
 private:
     LLModel *m_llmodel;
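With the sampling settings folded into the shared context, a caller populates `PromptContext` once and hands it to `prompt()` together with the two-argument callback. A minimal sketch of driving the interface directly, outside the Qt wrapper, using a lambda in place of the `std::bind` above (the field values mirror the ones llm.cpp forwards):

```cpp
#include <cstdint>
#include <cstdio>
#include <string>

#include "llmodel.h"

void generate(LLModel *model, const std::string &instructPrompt) {
    LLModel::PromptContext ctx;   // logits/tokens persist here across calls if reused
    ctx.n_predict = 200;
    ctx.top_k = 40;
    ctx.top_p = 0.9f;
    ctx.temp = 0.9f;
    ctx.n_batch = 9;

    model->prompt(instructPrompt,
                  [](int32_t /*token*/, const std::string &piece) {
                      fputs(piece.c_str(), stdout);
                      return true;  // returning false stops generation early
                  },
                  ctx);
}
```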
diff --git a/llmodel.h b/llmodel.h
index 829e4145..cacd23aa 100644
--- a/llmodel.h
+++ b/llmodel.h
@@ -14,12 +14,22 @@ public:
     virtual bool loadModel(const std::string &modelPath, std::istream &fin) = 0;
     virtual bool isModelLoaded() const = 0;
     struct PromptContext {
-        std::vector<float> logits;
-        int32_t n_past = 0; // number of tokens in past conversation
+        std::vector<float> logits;    // logits of current context
+        std::vector<int32_t> tokens;  // current tokens in the context window
+        int32_t n_past = 0;           // number of tokens in past conversation
+        int32_t n_ctx = 0;            // number of tokens possible in context window
+        int32_t n_predict = 200;
+        int32_t top_k = 40;
+        float   top_p = 0.9f;
+        float   temp = 0.9f;
+        int32_t n_batch = 9;
+        float   repeat_penalty = 1.10f;
+        int32_t repeat_last_n = 64;   // last n tokens to penalize
+
     };
-    virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
-        float temp = 0.9f, int32_t n_batch = 9) = 0;
+    virtual void prompt(const std::string &prompt,
+        std::function<bool(int32_t, const std::string&)> response,
+        PromptContext &ctx) = 0;
     virtual void setThreadCount(int32_t n_threads) {}
     virtual int32_t threadCount() { return 1; }
 };
diff --git a/utils.cpp b/utils.cpp
index a77fb7a3..b9b653f5 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -178,20 +178,37 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
 
 gpt_vocab::id gpt_sample_top_k_top_p(
         const gpt_vocab & vocab,
-        const float * logits,
+        const int32_t * last_n_tokens_data,
+        int   last_n_tokens_size,
+        const std::vector<float> logits,
         int    top_k,
         double top_p,
         double temp,
+        float repeat_penalty,
         std::mt19937 & rng) {
     int n_logits = vocab.id_to_token.size();
 
+    const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
+    const auto * plogits = logits.data() + logits.size() - n_logits;
+
     std::vector<std::pair<double, gpt_vocab::id>> logits_id;
     logits_id.reserve(n_logits);
 
     {
-        const double scale = 1.0/temp;
+        const float scale = 1.0f/temp;
         for (int i = 0; i < n_logits; ++i) {
-            logits_id.push_back(std::make_pair(logits[i]*scale, i));
+            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
+            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
+            if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
+                // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
+                if (plogits[i] < 0.0f) {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
+                } else {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
+                }
+            } else {
+                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
+            }
         }
     }
diff --git a/utils.h b/utils.h
index b61173ff..90cfdd97 100644
--- a/utils.h
+++ b/utils.h
@@ -72,12 +72,14 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
 //   - from them, consider only the top tokens with cumulative probability > P
 //
 // TODO: not sure if this implementation is correct
-// TODO: temperature is not implemented
 //
 gpt_vocab::id gpt_sample_top_k_top_p(
         const gpt_vocab & vocab,
-        const float * logits,
+        const int32_t * last_n_tokens_data,
+        int   last_n_tokens_size,
+        const std::vector<float> logits,
         int    top_k,
         double top_p,
         double temp,
+        float repeat_penalty,
         std::mt19937 & rng);
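The new branch in `gpt_sample_top_k_top_p` applies the CTRL-style repetition penalty before the usual top-k/top-p filtering: a recently seen token's logit is divided by `repeat_penalty` when positive and multiplied by it when negative, so the repeat always becomes less likely. In isolation the scoring rule looks roughly like this (a sketch of the formula only, not the patched function):

```cpp
// CTRL repetition penalty (https://arxiv.org/abs/1909.05858) as applied above:
// `scale` is 1/temp, and `repeat_penalty` > 1 discourages recently seen tokens.
float penalizedScore(float logit, float scale, float repeat_penalty, bool recentlySeen) {
    if (!recentlySeen)
        return logit * scale;
    // shrinking a positive logit, or amplifying a negative one, lowers the score
    return logit < 0.0f ? logit * scale * repeat_penalty
                        : logit * scale / repeat_penalty;
}
```

With the new `repeat_penalty = 1.10f` default, a repeated token's positive logit is reduced by roughly 9% before sampling.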