From d14936bfd6023409af22c3136e8c6bea8e54f2f6 Mon Sep 17 00:00:00 2001
From: Aaron Miller
Date: Mon, 15 May 2023 17:42:20 -0700
Subject: [PATCH] backend: dedupe tokenizing code in mpt/gptj

---
 gpt4all-backend/gptj.cpp  |   2 +-
 gpt4all-backend/mpt.cpp   | 100 +-------------------------------------
 gpt4all-backend/utils.cpp |   5 +-
 gpt4all-backend/utils.h   |   1 +
 4 files changed, 6 insertions(+), 102 deletions(-)

diff --git a/gpt4all-backend/gptj.cpp b/gpt4all-backend/gptj.cpp
index 79129747..f0eb0eaa 100644
--- a/gpt4all-backend/gptj.cpp
+++ b/gpt4all-backend/gptj.cpp
@@ -983,7 +983,7 @@ void GPTJ::prompt(const std::string &prompt,
         gpt_vocab::id id = 0;
         {
             const int64_t t_start_sample_us = ggml_time_us();
-            id = gpt_sample_top_k_top_p(d_ptr->vocab,
+            id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
                 promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
                 promptCtx.n_ctx,
                 promptCtx.logits,
diff --git a/gpt4all-backend/mpt.cpp b/gpt4all-backend/mpt.cpp
index d316c03b..e57cc3c6 100644
--- a/gpt4all-backend/mpt.cpp
+++ b/gpt4all-backend/mpt.cpp
@@ -691,104 +691,6 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
     return written;
 }
 
-gpt_vocab::id mpt_sample_top_k_top_p(
-        const gpt_vocab & vocab,
-        const size_t actualVocabSize,
-        const int32_t * last_n_tokens_data,
-        int last_n_tokens_size,
-        const std::vector<float> logits,
-        int top_k,
-        double top_p,
-        double temp,
-        float repeat_penalty,
-        std::mt19937 & rng) {
-    int n_logits = actualVocabSize;
-
-    const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
-    const auto * plogits = logits.data() + logits.size() - n_logits;
-
-    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
-    logits_id.reserve(n_logits);
-
-    {
-        const float scale = 1.0f/temp;
-        for (int i = 0; i < n_logits; ++i) {
-            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
-            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
-            if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
-                // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
-                if (plogits[i] < 0.0f) {
-                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
-                } else {
-                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
-                }
-            } else {
-                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
-            }
-        }
-    }
-
-    // find the top K tokens
-    std::partial_sort(
-            logits_id.begin(),
-            logits_id.begin() + top_k, logits_id.end(),
-            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
-        return a.first > b.first;
-    });
-
-    logits_id.resize(top_k);
-
-    double maxl = -INFINITY;
-    for (const auto & kv : logits_id) {
-        maxl = std::max(maxl, kv.first);
-    }
-
-    // compute probs for the top K tokens
-    std::vector<double> probs;
-    probs.reserve(logits_id.size());
-
-    double sum = 0.0;
-    for (const auto & kv : logits_id) {
-        double p = exp(kv.first - maxl);
-        probs.push_back(p);
-        sum += p;
-    }
-
-    // normalize the probs
-    for (auto & p : probs) {
-        p /= sum;
-    }
-
-    if (top_p < 1.0f) {
-        double cumsum = 0.0f;
-        for (int i = 0; i < top_k; i++) {
-            cumsum += probs[i];
-            if (cumsum >= top_p) {
-                top_k = i + 1;
-                probs.resize(top_k);
-                logits_id.resize(top_k);
-                break;
-            }
-        }
-
-        cumsum = 1.0/cumsum;
-        for (int i = 0; i < (int) probs.size(); i++) {
-            probs[i] *= cumsum;
-        }
-    }
-
-    //printf("\n");
-    //for (int i = 0; i < (int) probs.size(); i++) {
-    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
-    //}
-    //exit(0);
-
-    std::discrete_distribution<> dist(probs.begin(), probs.end());
-    int idx = dist(rng);
-
-    return logits_id[idx].second;
-}
-
 size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *src) {
     const uint8_t * in = src;
 
@@ -1006,7 +908,7 @@ void MPT::prompt(const std::string &prompt,
             int id = 0;
             {
                 const int64_t t_start_sample_us = ggml_time_us();
-                id = mpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
+                id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
                     promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
                     promptCtx.n_ctx,
                     promptCtx.logits,
diff --git a/gpt4all-backend/utils.cpp b/gpt4all-backend/utils.cpp
index b1cb113b..783054f5 100644
--- a/gpt4all-backend/utils.cpp
+++ b/gpt4all-backend/utils.cpp
@@ -219,6 +219,7 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
 
 gpt_vocab::id gpt_sample_top_k_top_p(
         const gpt_vocab & vocab,
+        const size_t actualVocabSize,
         const int32_t * last_n_tokens_data,
         int last_n_tokens_size,
         const std::vector<float> logits,
@@ -227,7 +228,7 @@ gpt_vocab::id gpt_sample_top_k_top_p(
         double temp,
         float repeat_penalty,
         std::mt19937 & rng) {
-    int n_logits = vocab.id_to_token.size();
+    int n_logits = actualVocabSize;
 
     const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
     const auto * plogits = logits.data() + logits.size() - n_logits;
@@ -312,4 +313,4 @@ gpt_vocab::id gpt_sample_top_k_top_p(
     int idx = dist(rng);
 
     return logits_id[idx].second;
-}
+}
\ No newline at end of file
diff --git a/gpt4all-backend/utils.h b/gpt4all-backend/utils.h
index e51b66b0..9c9f5c60 100644
--- a/gpt4all-backend/utils.h
+++ b/gpt4all-backend/utils.h
@@ -80,6 +80,7 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
 //
 gpt_vocab::id gpt_sample_top_k_top_p(
         const gpt_vocab & vocab,
+        const size_t actualVocabSize,
        const int32_t * last_n_tokens_data,
         int last_n_tokens_size,
         const std::vector<float> logits,
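
Reviewer note (not part of the patch): both backends now route sampling through the single gpt_sample_top_k_top_p() in utils.cpp, and the new actualVocabSize parameter tells the sampler how many entries at the tail of the logits buffer are real, instead of inferring that from vocab.id_to_token.size(). Below is a minimal sketch of the unified call; it mirrors the call sites patched above, but the sampler settings (top_k, top_p, temp, repeat_penalty) and the rng variable are illustrative assumptions, not values taken from this patch.

    // Sketch only: vocab, n_vocab and promptCtx come from the surrounding backend
    // code; the last-token window expression is copied from the patched call sites.
    // The sampler settings and rng seed are placeholder values for illustration.
    std::mt19937 rng(0);
    gpt_vocab::id id = gpt_sample_top_k_top_p(
        vocab, n_vocab,
        promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
        promptCtx.n_ctx,
        promptCtx.logits,
        /*top_k=*/40, /*top_p=*/0.9, /*temp=*/0.8, /*repeat_penalty=*/1.1f,
        rng);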