diff --git a/gpt4all-backend/llama.cpp-mainline b/gpt4all-backend/llama.cpp-mainline
index c6546b05..add38785 160000
--- a/gpt4all-backend/llama.cpp-mainline
+++ b/gpt4all-backend/llama.cpp-mainline
@@ -1 +1 @@
-Subproject commit c6546b0544ad2c01e8a1630b101e92336a68b036
+Subproject commit add387854ea73d83770a62282089dea666fa266f
diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index 966367b3..cab0e75a 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -145,9 +145,8 @@ static int llama_sample_top_p_top_k(
         float top_p,
         float min_p,
         float temp,
-        float repeat_penalty,
-        int32_t pos) {
-    auto logits = llama_get_logits_ith(ctx, pos);
+        float repeat_penalty) {
+    auto logits = llama_get_logits_ith(ctx, -1);
     auto n_vocab = llama_n_vocab(llama_get_model(ctx));
     // Populate initial list of all candidates
     std::vector<llama_token_data> candidates;
@@ -529,13 +528,21 @@ size_t LLamaModel::restoreState(const uint8_t *src)
     return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
 }
 
-std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special) const
+std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special)
 {
-    const bool wantBOS = ctx.n_past == 0 && ctx.tokens.empty();
-    const bool useBOS = wantBOS && shouldAddBOS();
+    bool atStart = m_tokenize_last_token == -1;
+    bool insertSpace = atStart || (
+        llama_token_get_attr(d_ptr->model, m_tokenize_last_token)
+            & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)
+    );
     std::vector<LLModel::Token> fres(str.length() + 4);
-    auto fres_len = llama_tokenize(d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), useBOS, special);
+    int32_t fres_len = llama_tokenize_gpt4all(
+        d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
+        /*parse_special*/ special, /*insert_space*/ insertSpace
+    );
     fres.resize(fres_len);
+    if (fres_len)
+        m_tokenize_last_token = fres.back();
     return fres;
 }
 
@@ -561,7 +568,7 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
     return llama_sample_top_p_top_k(d_ptr->ctx,
         promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks, n_prev_toks,
         promptCtx.top_k, promptCtx.top_p, promptCtx.min_p, promptCtx.temp,
-        promptCtx.repeat_penalty, promptCtx.n_last_batch_tokens - 1);
+        promptCtx.repeat_penalty);
 }
 
 bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
@@ -571,7 +578,6 @@
     llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
 
     batch.n_tokens = tokens.size();
-    ctx.n_last_batch_tokens = tokens.size();
 
     for (int32_t i = 0; i < batch.n_tokens; i++) {
         batch.token [i] = tokens[i];
@@ -601,10 +607,7 @@ const std::vector<LLModel::Token> &LLamaModel::endTokens() const
 
 bool LLamaModel::shouldAddBOS() const
 {
-    int add_bos = llama_add_bos_token(d_ptr->model);
-    if (add_bos != -1) { return add_bos; }
-    auto vocab_type = llama_vocab_type(d_ptr->model);
-    return vocab_type == LLAMA_VOCAB_TYPE_SPM || vocab_type == LLAMA_VOCAB_TYPE_WPM;
+    return llama_add_bos_token(d_ptr->model);
 }
 
 int32_t LLamaModel::maxContextLength(std::string const &modelPath) const
@@ -946,7 +949,7 @@ void LLamaModel::embedInternal(
     const llama_token bos_token = llama_token_bos(d_ptr->model);
     const llama_token eos_token = llama_token_eos(d_ptr->model);
 
-    bool useBOS = shouldAddBOS();
+    bool useBOS = llama_add_bos_token(d_ptr->model);
     bool useEOS = llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_WPM;
 
     // no EOS, optional BOS
@@ -954,13 +957,16 @@ void LLamaModel::embedInternal(
         if (!text.empty() && text[0] != ' ') {
             text = ' ' + text; // normalize for SPM - our fork of llama.cpp doesn't add a space prefix
         }
 
-        wantBOS &= useBOS;
         tokens.resize(text.length()+4);
-        int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), wantBOS, false);
+        int32_t n_tokens = llama_tokenize_gpt4all(
+            d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), /*add_special*/ wantBOS,
+            /*parse_special*/ false, /*insert_space*/ false
+        );
         if (n_tokens) {
             (void)eos_token;
-            assert((useEOS && wantBOS) == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
+            (void)useBOS;
+            assert((useEOS && wantBOS && useBOS) == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
             if (useEOS && wantBOS)
                 n_tokens--; // erase EOS/SEP
         }
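Note: the llamamodel.cpp hunks above make tokenization stateful. Instead of keying BOS insertion off n_past, LLamaModel::tokenize() remembers the last token it produced (m_tokenize_last_token) and asks llama_tokenize_gpt4all to add special tokens only at the very start of a sequence, and to insert the tokenizer's leading space only at the start or directly after a control/user-defined/unknown token; sampling likewise now reads the logits of the last decoded token (index -1) instead of tracking n_last_batch_tokens. Below is a minimal, self-contained sketch of the insert-space decision, using hypothetical stand-in names (TokenAttr, decideTokenizeFlags) rather than the real llama.cpp API; it is illustrative only, not part of the patch.

    // Sketch only: mirrors the atStart/insertSpace logic of the patched
    // LLamaModel::tokenize(), with plain enums instead of llama_token_get_attr().
    #include <cstdint>
    #include <cstdio>

    enum TokenAttr : uint32_t {
        TOKEN_ATTR_NORMAL       = 1u << 0,
        TOKEN_ATTR_UNKNOWN      = 1u << 1,
        TOKEN_ATTR_CONTROL      = 1u << 2,
        TOKEN_ATTR_USER_DEFINED = 1u << 3,
    };

    struct TokenizeFlags {
        bool addSpecial;  // add BOS/other special tokens (fresh sequence only)
        bool insertSpace; // prepend the tokenizer's leading space to this chunk
    };

    // lastToken == -1 means nothing has been tokenized into this context yet.
    TokenizeFlags decideTokenizeFlags(int32_t lastToken, uint32_t lastTokenAttr) {
        bool atStart = lastToken == -1;
        bool insertSpace = atStart ||
            (lastTokenAttr & (TOKEN_ATTR_CONTROL | TOKEN_ATTR_USER_DEFINED | TOKEN_ATTR_UNKNOWN)) != 0;
        return {atStart, insertSpace};
    }

    int main() {
        TokenizeFlags fresh     = decideTokenizeFlags(-1, 0);
        TokenizeFlags afterText = decideTokenizeFlags(1234, TOKEN_ATTR_NORMAL);
        TokenizeFlags afterCtrl = decideTokenizeFlags(2, TOKEN_ATTR_CONTROL);
        std::printf("fresh:         add_special=%d insert_space=%d\n", fresh.addSpecial, fresh.insertSpace);
        std::printf("after text:    add_special=%d insert_space=%d\n", afterText.addSpecial, afterText.insertSpace);
        std::printf("after control: add_special=%d insert_space=%d\n", afterCtrl.addSpecial, afterCtrl.insertSpace);
    }

The practical effect is that a prompt template split into prefix, user text, and suffix can be tokenized piecewise without a stray space being injected in the middle of plain text.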
diff --git a/gpt4all-backend/llamamodel_impl.h b/gpt4all-backend/llamamodel_impl.h
index f162a94d..019e5532 100644
--- a/gpt4all-backend/llamamodel_impl.h
+++ b/gpt4all-backend/llamamodel_impl.h
@@ -53,7 +53,7 @@ private:
     bool m_supportsCompletion = false;
 
 protected:
-    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override;
+    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override;
     std::string tokenToString(Token id) const override;
     Token sampleToken(PromptContext &ctx) const override;
     bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index 112421a9..f95dc3a8 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -14,11 +14,12 @@
 #include <...>
 #include <...>
 
+class Dlhandle;
+
 using namespace std::string_literals;
 
 #define LLMODEL_MAX_PROMPT_BATCH 128
 
-class Dlhandle;
 class LLModel {
 public:
     using Token = int32_t;
@@ -134,7 +135,6 @@ public:
         float repeat_penalty = 1.10f;
         int32_t repeat_last_n = 64;  // last n tokens to penalize
         float contextErase = 0.75f;  // percent of context to erase if we exceed the context window
-        int32_t n_last_batch_tokens = 0;
     };
 
     using ProgressCallback = std::function<bool(float progress)>;
@@ -212,7 +212,7 @@ public:
 protected:
     // These are pure virtual because subclasses need to implement as the default implementation of
     // 'prompt' above calls these functions
-    virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) const = 0;
+    virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0;
    virtual std::string tokenToString(Token id) const = 0;
     virtual Token sampleToken(PromptContext &ctx) const = 0;
     virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
@@ -256,7 +256,8 @@ protected:
                           std::function<bool(bool)> recalculateCallback,
                           PromptContext &promptCtx);
 
-private:
+    Token m_tokenize_last_token = -1; // not serialized
+
     friend class LLMImplementation;
 };
 
diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp
index 2f83376a..d6ba95e8 100644
--- a/gpt4all-backend/llmodel_c.cpp
+++ b/gpt4all-backend/llmodel_c.cpp
@@ -117,9 +117,6 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
         return response_callback(token_id, response.c_str());
     };
 
-    if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size())
-        wrapper->promptContext.tokens.resize(ctx->n_past);
-
     // Copy the C prompt context
     wrapper->promptContext.n_past = ctx->n_past;
     wrapper->promptContext.n_ctx = ctx->n_ctx;
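Note: the header and C-API hunks above are the knock-on effects of that statefulness: tokenize() loses its const qualifier in LLModel and its implementations, PromptContext drops n_last_batch_tokens, the tokenizer state lives on the model as a protected, non-serialized member, and the C wrapper stops resizing the token cache behind the library's back. Below is a rough sketch of the resulting shape, with hypothetical class names standing in for LLModel/LLamaModel; it is not the actual gpt4all code.

    #include <cstdint>
    #include <string>
    #include <vector>

    class ModelBase {
    public:
        using Token = int32_t;
        virtual ~ModelBase() = default;
        // non-const: implementations update m_tokenize_last_token as a side effect
        virtual std::vector<Token> tokenize(const std::string &text, bool special) = 0;

    protected:
        Token m_tokenize_last_token = -1; // -1 = nothing tokenized yet; not part of saved state
    };

    class ToyModel : public ModelBase {
    public:
        std::vector<Token> tokenize(const std::string &text, bool special) override {
            (void)special;
            std::vector<Token> out;
            for (unsigned char c : text)            // stand-in for a real tokenizer
                out.push_back(static_cast<Token>(c));
            if (!out.empty())
                m_tokenize_last_token = out.back(); // the state a const member function could not keep
            return out;
        }
    };

    int main() {
        ToyModel m;
        return m.tokenize("hello", /*special*/ false).empty() ? 1 : 0;
    }

Because the state lives on the model rather than in PromptContext, a caller that rewinds a conversation has to re-seed it, which is what the llmodel_shared.cpp change below does at the start of each prompt.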
diff --git a/gpt4all-backend/llmodel_c.h b/gpt4all-backend/llmodel_c.h
index ced23a96..75cfb862 100644
--- a/gpt4all-backend/llmodel_c.h
+++ b/gpt4all-backend/llmodel_c.h
@@ -30,8 +30,6 @@ typedef void *llmodel_model;
  * behavior.
  */
 struct llmodel_prompt_context {
-    float *logits;          // logits of current context
-    size_t logits_size;     // the size of the raw logits vector
     int32_t *tokens;        // current tokens in the context window
     size_t tokens_size;     // the size of the raw tokens vector
     int32_t n_past;         // number of tokens in past conversation
diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llmodel_shared.cpp
index be02b65b..68ea42e4 100644
--- a/gpt4all-backend/llmodel_shared.cpp
+++ b/gpt4all-backend/llmodel_shared.cpp
@@ -8,12 +8,16 @@
 #include <...>
 #include <...>
 #include <...>
+#include <sstream>
 #include <...>
 #include <...>
 #include <...>
 #include <...>
 
 // TODO(cebtenzzre): replace this with llama_kv_cache_seq_shift for llamamodel (GPT-J needs this as-is)
+// FIXME(jared): if recalculate returns false, we leave n_past < ...
 void LLModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate) {
     int n_keep = shouldAddBOS();
@@ -88,6 +92,16 @@ void LLModel::prompt(const std::string &prompt,
         return;
     }
 
+    // make sure token cache matches decode offset
+    if (promptCtx.tokens.size() < promptCtx.n_past) {
+        std::ostringstream ss;
+        ss << "expected n_past to be at most " << promptCtx.tokens.size() << ", got " << promptCtx.n_past;
+        throw std::out_of_range(ss.str());
+    }
+    if (promptCtx.n_past < promptCtx.tokens.size())
+        promptCtx.tokens.resize(promptCtx.n_past);
+    m_tokenize_last_token = promptCtx.tokens.empty() ? -1 : promptCtx.tokens.back(); // not serialized
+
     // parse the prompt template
     std::vector<std::smatch> placeholders;
     {
@@ -201,8 +215,6 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
 
         size_t tokens = batch_end - i;
         for (size_t t = 0; t < tokens; ++t) {
-            if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
-                promptCtx.tokens.erase(promptCtx.tokens.begin());
             promptCtx.tokens.push_back(batch.at(t));
             promptCtx.n_past += 1;
             if (!promptCallback(batch.at(t)))
@@ -270,8 +282,6 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
 
         // Empty the cache
         for (auto t : cachedTokens) {
-            if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
-                promptCtx.tokens.erase(promptCtx.tokens.begin());
             promptCtx.tokens.push_back(t);
             promptCtx.n_past += 1;
             //TODO: Conversion to std::string can be avoided here...
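Note: with the unused logits fields gone from llmodel_prompt_context and the cache-resizing removed from llmodel_c.cpp, LLModel::prompt() above now owns the consistency check itself: an n_past larger than the cached token list is rejected, a smaller one is treated as a rewind, and m_tokenize_last_token is re-seeded from whatever remains; tokens are also no longer erased from the front of the cache when the window fills. Below is a self-contained sketch of that check, with a plain std::vector standing in for PromptContext (syncTokenCache is a hypothetical helper, not the actual gpt4all code).

    #include <cstdint>
    #include <sstream>
    #include <stdexcept>
    #include <vector>

    // Returns the token to seed tokenizer state with (-1 if the cache is empty after syncing).
    int32_t syncTokenCache(std::vector<int32_t> &tokens, int32_t n_past) {
        if (tokens.size() < static_cast<size_t>(n_past)) {
            // caller claims more decoded tokens than are cached: fail loudly
            std::ostringstream ss;
            ss << "expected n_past to be at most " << tokens.size() << ", got " << n_past;
            throw std::out_of_range(ss.str());
        }
        if (static_cast<size_t>(n_past) < tokens.size())
            tokens.resize(n_past); // caller rewound the conversation; drop the stale tail
        return tokens.empty() ? -1 : tokens.back();
    }

    int main() {
        std::vector<int32_t> cache {11, 22, 33, 44};
        int32_t last = syncTokenCache(cache, /*n_past*/ 2); // rewind to the second token
        return (cache.size() == 2 && last == 22) ? 0 : 1;
    }

Throwing std::out_of_range surfaces caller bugs that the old silent resize and erase-from-the-front behaviour used to paper over.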
diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
index 892d72e7..284d8ab3 100644
--- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -73,8 +73,6 @@ llmodel = load_llmodel_library()
 
 class LLModelPromptContext(ctypes.Structure):
     _fields_ = [
-        ("logits", ctypes.POINTER(ctypes.c_float)),
-        ("logits_size", ctypes.c_size_t),
         ("tokens", ctypes.POINTER(ctypes.c_int32)),
         ("tokens_size", ctypes.c_size_t),
         ("n_past", ctypes.c_int32),
@@ -351,7 +349,6 @@ class LLModel:
     ):
         if self.context is None:
             context = LLModelPromptContext(
-                logits_size=0,
                 tokens_size=0,
                 n_past=0,
                 n_ctx=0,
diff --git a/gpt4all-chat/chatapi.h b/gpt4all-chat/chatapi.h
index 90aaf5b7..51ba8067 100644
--- a/gpt4all-chat/chatapi.h
+++ b/gpt4all-chat/chatapi.h
@@ -97,7 +97,7 @@ protected:
     // them as they are only called from the default implementation of 'prompt' which we override and
     // completely replace
-    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override {
+    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override {
         (void)ctx;
         (void)str;
         (void)special;
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index a0165026..18319cee 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -611,6 +611,7 @@ std::string trim_whitespace(const std::string& input)
     return std::string(first_non_whitespace, last_non_whitespace);
 }
 
+// FIXME(jared): we don't actually have to re-decode the prompt to generate a new response
 void ChatLLM::regenerateResponse()
 {
     // ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning