Mirror of https://github.com/nomic-ai/gpt4all.git

commit 51bd01ae05 (parent da59c9f5ea)

backend: fix extra spaces in tokenization and a CUDA crash (#2778)

Also potentially improves accuracy of BOS insertion, token cache, and logit indexing.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
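
The stray spaces appear to come from the SentencePiece space prefix: the prompt is tokenized in several pieces (template prefix, user text, template suffix), and prepending the word-boundary space to every piece injects extra spaces into the middle of the sequence. After this change the space is inserted only at the true start of the sequence or directly after a special token, judged by the attributes of the last token produced. A minimal standalone sketch of that decision, using illustrative flag names in place of llama.cpp's token attributes (not the gpt4all API):

#include <cstdint>
#include <iostream>

// Illustrative stand-ins for the llama.cpp token-attribute flags used in the diff below.
enum TokenAttr : uint32_t {
    ATTR_NORMAL       = 0,
    ATTR_CONTROL      = 1 << 0, // BOS/EOS and chat-template control tokens
    ATTR_USER_DEFINED = 1 << 1,
    ATTR_UNKNOWN      = 1 << 2,
};

// Insert the SPM leading space only at the start of the sequence or right after
// a special token, never in the middle of ordinary text.
bool shouldInsertSpace(int32_t lastToken, uint32_t lastTokenAttr) {
    bool atStart = lastToken == -1; // nothing tokenized yet
    return atStart || (lastTokenAttr & (ATTR_CONTROL | ATTR_USER_DEFINED | ATTR_UNKNOWN));
}

int main() {
    std::cout << shouldInsertSpace(-1, ATTR_NORMAL)  << '\n'; // 1: start of prompt
    std::cout << shouldInsertSpace(42, ATTR_CONTROL) << '\n'; // 1: e.g. right after a control token
    std::cout << shouldInsertSpace(42, ATTR_NORMAL)  << '\n'; // 0: mid-text, no stray space
}

The hunks below wire this decision into LLamaModel::tokenize() through the insert_space argument of the fork-side llama_tokenize_gpt4all() helper.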
@@ -1 +1 @@
-Subproject commit c6546b0544ad2c01e8a1630b101e92336a68b036
+Subproject commit add387854ea73d83770a62282089dea666fa266f
@@ -145,9 +145,8 @@ static int llama_sample_top_p_top_k(
         float top_p,
         float min_p,
         float temp,
-        float repeat_penalty,
-        int32_t pos) {
-    auto logits = llama_get_logits_ith(ctx, pos);
+        float repeat_penalty) {
+    auto logits = llama_get_logits_ith(ctx, -1);
     auto n_vocab = llama_n_vocab(llama_get_model(ctx));
     // Populate initial list of all candidates
     std::vector<llama_token_data> candidates;
@@ -529,13 +528,21 @@ size_t LLamaModel::restoreState(const uint8_t *src)
     return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
 }

-std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special) const
+std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special)
 {
-    const bool wantBOS = ctx.n_past == 0 && ctx.tokens.empty();
-    const bool useBOS = wantBOS && shouldAddBOS();
+    bool atStart = m_tokenize_last_token == -1;
+    bool insertSpace = atStart || (
+        llama_token_get_attr(d_ptr->model, m_tokenize_last_token)
+        & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)
+    );
     std::vector<LLModel::Token> fres(str.length() + 4);
-    auto fres_len = llama_tokenize(d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), useBOS, special);
+    int32_t fres_len = llama_tokenize_gpt4all(
+        d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
+        /*parse_special*/ special, /*insert_space*/ insertSpace
+    );
     fres.resize(fres_len);
+    if (fres_len)
+        m_tokenize_last_token = fres.back();
     return fres;
 }

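Because each prompt is tokenized chunk by chunk, the decision above needs state that outlives a single call: m_tokenize_last_token remembers the last token emitted so the next chunk knows whether it begins the sequence. A rough sketch of the pattern with the backend call stubbed out (names are illustrative, not the gpt4all API):

#include <cstdint>
#include <string>
#include <vector>

struct ChunkTokenizer {
    int32_t lastToken = -1; // -1 means "nothing tokenized yet", mirroring the diff

    std::vector<int32_t> tokenize(const std::string &chunk) {
        bool atStart = (lastToken == -1);
        std::vector<int32_t> out = backendTokenize(chunk, /*addSpecial=*/atStart);
        if (!out.empty())
            lastToken = out.back(); // carried across calls; reset when the token cache is rewound
        return out;
    }

private:
    // Stand-in for the real tokenizer: emits a fake BOS id once, then byte ids.
    std::vector<int32_t> backendTokenize(const std::string &chunk, bool addSpecial) {
        std::vector<int32_t> ids;
        if (addSpecial)
            ids.push_back(1); // pretend 1 is BOS
        for (unsigned char c : chunk)
            ids.push_back(int32_t(c));
        return ids;
    }
};

The hunk in LLModel::prompt() further down re-seeds this state from the token cache, so rewinding n_past keeps the tokenizer and the cache consistent.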
@@ -561,7 +568,7 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
     return llama_sample_top_p_top_k(d_ptr->ctx,
         promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
         n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.min_p, promptCtx.temp,
-        promptCtx.repeat_penalty, promptCtx.n_last_batch_tokens - 1);
+        promptCtx.repeat_penalty);
 }

 bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
@@ -571,7 +578,6 @@ bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &toke
     llama_batch batch = llama_batch_init(tokens.size(), 0, 1);

     batch.n_tokens = tokens.size();
-    ctx.n_last_batch_tokens = tokens.size();

     for (int32_t i = 0; i < batch.n_tokens; i++) {
         batch.token [i] = tokens[i];
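The CUDA crash fix is in the three hunks above: instead of deriving a batch-relative position from n_last_batch_tokens (which is dropped entirely), the sampler now asks llama_get_logits_ith() for index -1, which llama.cpp resolves to the last row of logits produced by the previous decode. Only tokens that requested output have rows in that buffer, so a stale or miscomputed position can index past it, which is presumably what crashed the CUDA backend. A toy sketch of the negative-index convention (an assumed simplification, not the llama.cpp implementation):

#include <cassert>
#include <vector>

// Rows of logits for the tokens that requested output in the last decode.
using LogitRows = std::vector<std::vector<float>>;

// A negative index counts from the end, so -1 always names the newest row
// regardless of how many tokens were in the batch.
const float *logitsIth(const LogitRows &rows, int i) {
    if (i < 0)
        i += int(rows.size());              // -1 -> last produced row
    assert(i >= 0 && i < int(rows.size())); // a stale batch-relative index trips here
    return rows[i].data();
}

int main() {
    LogitRows rows = {{0.1f, 0.9f}}; // typical decode: only the last token kept its logits
    const float *last = logitsIth(rows, -1);
    return last[1] > last[0] ? 0 : 1;
}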
@@ -601,10 +607,7 @@ const std::vector<LLModel::Token> &LLamaModel::endTokens() const

 bool LLamaModel::shouldAddBOS() const
 {
-    int add_bos = llama_add_bos_token(d_ptr->model);
-    if (add_bos != -1) { return add_bos; }
-    auto vocab_type = llama_vocab_type(d_ptr->model);
-    return vocab_type == LLAMA_VOCAB_TYPE_SPM || vocab_type == LLAMA_VOCAB_TYPE_WPM;
+    return llama_add_bos_token(d_ptr->model);
 }

 int32_t LLamaModel::maxContextLength(std::string const &modelPath) const
@@ -946,7 +949,7 @@ void LLamaModel::embedInternal(
     const llama_token bos_token = llama_token_bos(d_ptr->model);
     const llama_token eos_token = llama_token_eos(d_ptr->model);

-    bool useBOS = shouldAddBOS();
+    bool useBOS = llama_add_bos_token(d_ptr->model);
     bool useEOS = llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_WPM;

     // no EOS, optional BOS
@@ -954,13 +957,16 @@ void LLamaModel::embedInternal(
         if (!text.empty() && text[0] != ' ') {
             text = ' ' + text; // normalize for SPM - our fork of llama.cpp doesn't add a space prefix
         }
-        wantBOS &= useBOS;

         tokens.resize(text.length()+4);
-        int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), wantBOS, false);
+        int32_t n_tokens = llama_tokenize_gpt4all(
+            d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), /*add_special*/ wantBOS,
+            /*parse_special*/ false, /*insert_space*/ false
+        );
         if (n_tokens) {
             (void)eos_token;
-            assert((useEOS && wantBOS) == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
+            (void)useBOS;
+            assert((useEOS && wantBOS && useBOS) == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
             if (useEOS && wantBOS)
                 n_tokens--; // erase EOS/SEP
         }
@@ -53,7 +53,7 @@ private:
     bool m_supportsCompletion = false;

 protected:
-    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override;
+    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override;
     std::string tokenToString(Token id) const override;
     Token sampleToken(PromptContext &ctx) const override;
     bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
@@ -14,11 +14,12 @@
 #include <utility>
 #include <vector>

+class Dlhandle;
+
 using namespace std::string_literals;

 #define LLMODEL_MAX_PROMPT_BATCH 128

-class Dlhandle;
 class LLModel {
 public:
     using Token = int32_t;
@@ -134,7 +135,6 @@ public:
         float repeat_penalty = 1.10f;
         int32_t repeat_last_n = 64; // last n tokens to penalize
         float contextErase = 0.75f; // percent of context to erase if we exceed the context window
-        int32_t n_last_batch_tokens = 0;
     };

     using ProgressCallback = std::function<bool(float progress)>;
@@ -212,7 +212,7 @@ public:
 protected:
     // These are pure virtual because subclasses need to implement as the default implementation of
     // 'prompt' above calls these functions
-    virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) const = 0;
+    virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0;
     virtual std::string tokenToString(Token id) const = 0;
     virtual Token sampleToken(PromptContext &ctx) const = 0;
     virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
@@ -256,7 +256,8 @@ protected:
                                std::function<bool(bool)> recalculateCallback,
                                PromptContext &promptCtx);

-private:
+    Token m_tokenize_last_token = -1; // not serialized
+
     friend class LLMImplementation;
 };

@@ -117,9 +117,6 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
         return response_callback(token_id, response.c_str());
     };

-    if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size())
-        wrapper->promptContext.tokens.resize(ctx->n_past);
-
     // Copy the C prompt context
     wrapper->promptContext.n_past = ctx->n_past;
     wrapper->promptContext.n_ctx = ctx->n_ctx;
@@ -30,8 +30,6 @@ typedef void *llmodel_model;
  * behavior.
  */
 struct llmodel_prompt_context {
-    float *logits;          // logits of current context
-    size_t logits_size;     // the size of the raw logits vector
     int32_t *tokens;        // current tokens in the context window
     size_t tokens_size;     // the size of the raw tokens vector
     int32_t n_past;         // number of tokens in past conversation
@@ -8,12 +8,16 @@
 #include <iostream>
 #include <optional>
 #include <regex>
+#include <sstream>
 #include <stdexcept>
 #include <string>
 #include <unordered_set>
 #include <vector>

 // TODO(cebtenzzre): replace this with llama_kv_cache_seq_shift for llamamodel (GPT-J needs this as-is)
+// FIXME(jared): if recalculate returns false, we leave n_past<tokens.size() and do not tell the caller to stop
+// FIXME(jared): if we get here during chat name or follow-up generation, bad things will happen when we try to restore
+// the old prompt context afterwards
 void LLModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
 {
     int n_keep = shouldAddBOS();
@@ -88,6 +92,16 @@ void LLModel::prompt(const std::string &prompt,
         return;
     }

+    // make sure token cache matches decode offset
+    if (promptCtx.tokens.size() < promptCtx.n_past) {
+        std::ostringstream ss;
+        ss << "expected n_past to be at most " << promptCtx.tokens.size() << ", got " << promptCtx.n_past;
+        throw std::out_of_range(ss.str());
+    }
+    if (promptCtx.n_past < promptCtx.tokens.size())
+        promptCtx.tokens.resize(promptCtx.n_past);
+    m_tokenize_last_token = promptCtx.tokens.empty() ? -1 : promptCtx.tokens.back(); // not serialized
+
     // parse the prompt template
     std::vector<std::smatch> placeholders;
     {
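The new block in LLModel::prompt() makes the token cache the single source of truth for n_past: a caller may rewind n_past (for example to erase part of the conversation) but may never claim more past tokens than are cached, and any cached tokens beyond n_past are dropped before decoding resumes. This replaces the silent resize removed from the C binding above. A standalone sketch of the invariant (assumed helper name, not the gpt4all API):

#include <cstddef>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <vector>

// Bring the cached tokens in line with the requested decode offset n_past.
void syncTokenCache(std::vector<int32_t> &tokens, int32_t n_past) {
    if (tokens.size() < std::size_t(n_past)) {
        std::ostringstream ss;
        ss << "expected n_past to be at most " << tokens.size() << ", got " << n_past;
        throw std::out_of_range(ss.str()); // the caller claimed more context than exists
    }
    if (std::size_t(n_past) < tokens.size())
        tokens.resize(n_past); // drop rolled-back tokens so the cache and the KV offset agree
}

The same hunk re-seeds m_tokenize_last_token from the back of the cache, so the space-insertion state introduced earlier survives a rewind.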
@@ -201,8 +215,6 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,

         size_t tokens = batch_end - i;
         for (size_t t = 0; t < tokens; ++t) {
-            if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
-                promptCtx.tokens.erase(promptCtx.tokens.begin());
             promptCtx.tokens.push_back(batch.at(t));
             promptCtx.n_past += 1;
             if (!promptCallback(batch.at(t)))
@@ -270,8 +282,6 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>

         // Empty the cache
         for (auto t : cachedTokens) {
-            if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
-                promptCtx.tokens.erase(promptCtx.tokens.begin());
             promptCtx.tokens.push_back(t);
             promptCtx.n_past += 1;
             //TODO: Conversion to std::string can be avoided here...
@@ -73,8 +73,6 @@ llmodel = load_llmodel_library()

 class LLModelPromptContext(ctypes.Structure):
     _fields_ = [
-        ("logits", ctypes.POINTER(ctypes.c_float)),
-        ("logits_size", ctypes.c_size_t),
         ("tokens", ctypes.POINTER(ctypes.c_int32)),
         ("tokens_size", ctypes.c_size_t),
         ("n_past", ctypes.c_int32),
@@ -351,7 +349,6 @@ class LLModel:
     ):
         if self.context is None:
             context = LLModelPromptContext(
-                logits_size=0,
                 tokens_size=0,
                 n_past=0,
                 n_ctx=0,
@@ -97,7 +97,7 @@ protected:
     // them as they are only called from the default implementation of 'prompt' which we override and
     // completely replace

-    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) const override {
+    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override {
         (void)ctx;
         (void)str;
         (void)special;
@@ -611,6 +611,7 @@ std::string trim_whitespace(const std::string& input)
     return std::string(first_non_whitespace, last_non_whitespace);
 }

+// FIXME(jared): we don't actually have to re-decode the prompt to generate a new response
 void ChatLLM::regenerateResponse()
 {
     // ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning