llamamodel: restore leading space removal logic

When llama.cpp was updated, I removed the space removal logic, but it
turns out it is still needed. It is now controlled by a proper
parameter, as we specifically want to disable only the *leading* space
when tokenizing input that comes after a normal token.

This fixes a regression in commit 290c6294 ("backend: rebase llama.cpp
submodule on latest upstream (#2694)").

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel 2024-07-30 15:38:09 -04:00
parent 44c8467e6d
commit b1ebe63820
6 changed files with 19 additions and 10 deletions
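
For background on the fix: SentencePiece-style tokenizers in llama.cpp typically prepend a space to the text they are given, which is correct at the start of a prompt but not when a continuation is tokenized separately from the text that precedes it. The sketch below illustrates the divergence using the upstream llama.cpp C API; model loading is omitted and tok is an illustrative helper, not part of this commit.

// Sketch: piecewise tokenization can diverge from whole-prompt tokenization
// because of the automatic leading space (upstream llama.cpp C API).
#include <llama.h>

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

static std::vector<llama_token> tok(const llama_model *model, const std::string &text, bool addSpecial)
{
    std::vector<llama_token> out(text.size() + 4);
    int32_t n = llama_tokenize(model, text.c_str(), int32_t(text.size()), out.data(), int32_t(out.size()),
                               addSpecial, /*parse_special*/ false);
    out.resize(std::max<int32_t>(n, 0));
    return out;
}

// tok(model, "The quick brown fox", true) is not guaranteed to equal
// tok(model, "The quick", true) followed by tok(model, " brown fox", false),
// because the second call may pick up an extra leading space. The insert_space
// flag restored by this commit lets the caller suppress that space when the
// previously tokenized token was ordinary text.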

@@ -1 +1 @@
-Subproject commit c6546b0544ad2c01e8a1630b101e92336a68b036
+Subproject commit 527a0f503ff2a98bd164bc3a06b1b71fdd11bab9


@@ -529,13 +529,19 @@ size_t LLamaModel::restoreState(const uint8_t *src)
     return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
 }
 
-std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special) const
+std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special)
 {
-    const bool atStart = ctx.n_past == 0;
+    bool atStart = m_tokenize_last_token == -1;
+    bool insertSpace = atStart || (
+        llama_token_get_attr(d_ptr->model, m_tokenize_last_token)
+        & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)
+    );
     std::vector<LLModel::Token> fres(str.length() + 4);
     int32_t fres_len = llama_tokenize(d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(),
-                                      /*add_special*/ atStart, /*parse_special*/ special);
+                                      /*add_special*/ atStart, /*parse_special*/ special, /*insert_space*/ insertSpace);
     fres.resize(fres_len);
+    if (fres_len)
+        m_tokenize_last_token = fres.back();
     return fres;
 }
@@ -953,7 +959,7 @@ void LLamaModel::embedInternal(
     }
 
     tokens.resize(text.length()+4);
-    int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), wantBOS, false);
+    int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), wantBOS, false, false);
     if (n_tokens) {
         (void)eos_token;
         (void)useBOS;
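
Pulled together, the change to LLamaModel::tokenize() above boils down to the following decision. This is a sketch only: it assumes the extra insert_space argument that GPT4All's llama.cpp fork patches onto llama_tokenize() (upstream does not take it); llama_token_get_attr() and the LLAMA_TOKEN_ATTR_* flags are the regular llama.cpp API.

// Sketch of the restored logic, in isolation. `lastToken` plays the role of
// m_tokenize_last_token; -1 means nothing has been tokenized yet.
#include <llama.h>

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

std::vector<llama_token> tokenizePiece(const llama_model *model, llama_token lastToken,
                                       const std::string &str, bool special)
{
    bool atStart = lastToken == -1;
    // Keep the automatic leading space only at the start of the prompt, or when
    // the preceding token was a control / user-defined / unknown token rather
    // than ordinary text.
    bool insertSpace = atStart || (
        llama_token_get_attr(model, lastToken)
        & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)
    );

    std::vector<llama_token> fres(str.length() + 4);
    int32_t n = llama_tokenize(model, str.c_str(), int32_t(str.length()), fres.data(), int32_t(fres.size()),
                               /*add_special*/ atStart, /*parse_special*/ special,
                               /*insert_space*/ insertSpace); // extra argument from the fork's patch
    fres.resize(std::max<int32_t>(n, 0));
    return fres; // the caller records fres.back() as the new lastToken
}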


@@ -53,7 +53,7 @@ private:
     bool m_supportsCompletion = false;
 
 protected:
-    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override;
+    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override;
     std::string tokenToString(Token id) const override;
     Token sampleToken(PromptContext &ctx) const override;
     bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;


@@ -14,11 +14,12 @@
 #include <utility>
 #include <vector>
 
-class Dlhandle;
 using namespace std::string_literals;
 
 #define LLMODEL_MAX_PROMPT_BATCH 128
 
+class Dlhandle;
+
 class LLModel {
 public:
     using Token = int32_t;
@@ -212,7 +213,7 @@ public:
 protected:
     // These are pure virtual because subclasses need to implement as the default implementation of
     // 'prompt' above calls these functions
-    virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) const = 0;
+    virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0;
     virtual std::string tokenToString(Token id) const = 0;
     virtual Token sampleToken(PromptContext &ctx) const = 0;
     virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
@@ -256,7 +257,8 @@ protected:
                          std::function<bool(bool)> recalculateCallback,
                          PromptContext &promptCtx);
 
-private:
+    Token m_tokenize_last_token = -1; // not serialized
+
     friend class LLMImplementation;
 };
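
Note that the hunk above also removes the private: specifier, leaving m_tokenize_last_token in the protected section of LLModel: the tracker is written both by the base class in prompt() and by the LLamaModel subclass in tokenize(), so it cannot be private. A minimal illustration of that constraint, with placeholder names:

// Sketch: the last-token tracker lives in the abstract base class but must be
// writable from the concrete backend, hence protected rather than private.
#include <cstdint>
#include <string>
#include <vector>

struct ModelBase {
    virtual ~ModelBase() = default;
    virtual std::vector<int32_t> tokenize(const std::string &s) = 0;
protected:
    int32_t m_lastToken = -1; // analogous to m_tokenize_last_token
};

struct Backend : ModelBase {
    std::vector<int32_t> tokenize(const std::string &s) override {
        std::vector<int32_t> toks(s.begin(), s.end()); // stand-in for real tokenization
        if (!toks.empty())
            m_lastToken = toks.back();                 // requires protected access
        return toks;
    }
};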


@@ -97,6 +97,7 @@ void LLModel::prompt(const std::string &prompt,
     }
     if (promptCtx.n_past < promptCtx.tokens.size())
         promptCtx.tokens.resize(promptCtx.n_past);
+    m_tokenize_last_token = promptCtx.tokens.empty() ? -1 : promptCtx.tokens.back(); // not serialized
 
     // parse the prompt template
     std::vector<std::smatch> placeholders;
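
The single line added to LLModel::prompt() keeps that tracker consistent when the context is rewound: once promptCtx.tokens has been truncated to n_past, the last tokenized token has to be recomputed from what actually remains. A standalone sketch of the invariant (names are illustrative):

// Sketch: after rewinding the retained token list, resynchronize the
// last-token tracker so the next tokenize call sees the right predecessor.
#include <cstddef>
#include <cstdint>
#include <vector>

struct PromptState {
    std::vector<int32_t> tokens;   // tokens currently held in the context
    int32_t lastTokenized = -1;    // analogous to m_tokenize_last_token
};

void rewindTo(PromptState &st, std::size_t nPast)
{
    if (nPast < st.tokens.size())
        st.tokens.resize(nPast);   // drop everything past the reuse point
    // -1 means "start of prompt": the next piece gets add_special and the leading space.
    st.lastTokenized = st.tokens.empty() ? -1 : st.tokens.back();
}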


@@ -97,7 +97,7 @@ protected:
     // them as they are only called from the default implementation of 'prompt' which we override and
     // completely replace
-    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override {
+    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override {
         (void)ctx;
         (void)str;
         (void)special;