diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp index 966367b3..f8e41faf 100644 --- a/gpt4all-backend/llamamodel.cpp +++ b/gpt4all-backend/llamamodel.cpp @@ -531,10 +531,10 @@ size_t LLamaModel::restoreState(const uint8_t *src) std::vector LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special) const { - const bool wantBOS = ctx.n_past == 0 && ctx.tokens.empty(); - const bool useBOS = wantBOS && shouldAddBOS(); + const bool atStart = ctx.n_past == 0 && ctx.tokens.empty(); std::vector fres(str.length() + 4); - auto fres_len = llama_tokenize(d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), useBOS, special); + int32_t fres_len = llama_tokenize(d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), + /*add_special*/ atStart, /*parse_special*/ special); fres.resize(fres_len); return fres; } @@ -601,10 +601,7 @@ const std::vector &LLamaModel::endTokens() const bool LLamaModel::shouldAddBOS() const { - int add_bos = llama_add_bos_token(d_ptr->model); - if (add_bos != -1) { return add_bos; } - auto vocab_type = llama_vocab_type(d_ptr->model); - return vocab_type == LLAMA_VOCAB_TYPE_SPM || vocab_type == LLAMA_VOCAB_TYPE_WPM; + return llama_add_bos_token(d_ptr->model); } int32_t LLamaModel::maxContextLength(std::string const &modelPath) const @@ -946,7 +943,7 @@ void LLamaModel::embedInternal( const llama_token bos_token = llama_token_bos(d_ptr->model); const llama_token eos_token = llama_token_eos(d_ptr->model); - bool useBOS = shouldAddBOS(); + bool useBOS = llama_add_bos_token(d_ptr->model); bool useEOS = llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_WPM; // no EOS, optional BOS @@ -954,13 +951,13 @@ void LLamaModel::embedInternal( if (!text.empty() && text[0] != ' ') { text = ' ' + text; // normalize for SPM - our fork of llama.cpp doesn't add a space prefix } - wantBOS &= useBOS; tokens.resize(text.length()+4); int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), wantBOS, false); if (n_tokens) { (void)eos_token; - assert((useEOS && wantBOS) == (eos_token != -1 && tokens[n_tokens - 1] == eos_token)); + (void)useBOS; + assert((useEOS && wantBOS && useBOS) == (eos_token != -1 && tokens[n_tokens - 1] == eos_token)); if (useEOS && wantBOS) n_tokens--; // erase EOS/SEP }