backend: rebase llama.cpp on upstream as of Sep 26th (#2998)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel 2024-09-27 12:05:59 -04:00 committed by GitHub
parent 8bd937eb68
commit f9d6be8afb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 165 additions and 600 deletions

@ -1 +1 @@
Subproject commit ced74fbad4b258507f3ec06e77eec9445583511a
Subproject commit b3b5c0571eda3065035a7f25f7b84640b159d821

View File

@ -7,6 +7,7 @@
#include <cstdint>
#include <functional>
#include <optional>
#include <span>
#include <stdexcept>
#include <string>
#include <string_view>
@ -149,9 +150,9 @@ public:
virtual bool isEmbeddingModel(const std::string &modelPath) const { (void)modelPath; return false; }
virtual bool isModelLoaded() const = 0;
virtual size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) = 0;
virtual size_t stateSize() const { return 0; }
virtual size_t saveState(uint8_t *dest) const { (void)dest; return 0; }
virtual size_t restoreState(const uint8_t *src) { (void)src; return 0; }
virtual size_t stateSize() const = 0;
virtual size_t saveState(std::span<uint8_t> dest) const = 0;
virtual size_t restoreState(std::span<const uint8_t> src) = 0;
// This method requires the model to return true from supportsCompletion otherwise it will throw
// an error
@ -215,7 +216,8 @@ protected:
virtual std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special = false) = 0;
virtual bool isSpecialToken(Token id) const = 0;
virtual std::string tokenToString(Token id) const = 0;
virtual Token sampleToken(PromptContext &ctx) const = 0;
virtual void initSampler(PromptContext &ctx) = 0;
virtual Token sampleToken() const = 0;
virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
virtual void shiftContext(PromptContext &promptCtx) = 0;
virtual int32_t contextLength() const = 0;

View File

@ -148,18 +148,20 @@ uint64_t llmodel_get_state_size(llmodel_model model);
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @param dest A pointer to the destination.
* @return the number of bytes copied
* @param size The size of the destination buffer.
* @return the number of bytes copied, or zero on error.
*/
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest);
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest, uint64_t size);
/**
* Restores the internal state of the model using data from the specified address.
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @param src A pointer to the src.
* @return the number of bytes read
* @param src A pointer to the state data.
* @param size The size of the source data.
* @return The number of bytes read, or zero on error.
*/
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src, size_t size);
/**
* Generate a response using the model.

View File

@ -978,10 +978,13 @@ function(include_ggml SUFFIX)
add_library(llama${SUFFIX} STATIC
${DIRECTORY}/include/llama.h
${DIRECTORY}/src/llama-grammar.cpp
${DIRECTORY}/src/llama-sampling.cpp
${DIRECTORY}/src/llama-vocab.cpp
${DIRECTORY}/src/llama.cpp
${DIRECTORY}/src/unicode.h
${DIRECTORY}/src/unicode.cpp
${DIRECTORY}/src/unicode-data.cpp
${DIRECTORY}/src/unicode.cpp
${DIRECTORY}/src/unicode.h
)
target_include_directories(llama${SUFFIX} PUBLIC ${DIRECTORY}/include ${DIRECTORY}/ggml/include)

View File

@ -2,6 +2,7 @@
#include "llamamodel_impl.h"
#include "llmodel.h"
#include "utils.h"
#include <ggml.h>
#include <llama.h>
@ -103,26 +104,34 @@ static bool llama_verbose()
return var && *var;
}
static void llama_log_callback(enum ggml_log_level level, const char *text, void *userdata)
static void llama_log_callback(ggml_log_level level, const char *text, void *userdata, bool warn)
{
(void)userdata;
if (llama_verbose() || level <= GGML_LOG_LEVEL_ERROR) {
fputs(text, stderr);
static ggml_log_level lastlevel = GGML_LOG_LEVEL_NONE;
if (!llama_verbose()) {
auto efflevel = level == GGML_LOG_LEVEL_CONT ? lastlevel : level;
lastlevel = efflevel;
switch (efflevel) {
case GGML_LOG_LEVEL_CONT:
UNREACHABLE();
break;
case GGML_LOG_LEVEL_WARN:
if (warn) break;
[[fallthrough]];
case GGML_LOG_LEVEL_NONE: // not used?
case GGML_LOG_LEVEL_INFO:
case GGML_LOG_LEVEL_DEBUG:
return; // suppress
case GGML_LOG_LEVEL_ERROR:
;
}
}
#ifdef GGML_USE_CUDA
static void cuda_log_callback(enum ggml_log_level level, const char *text, void *userdata)
{
(void)userdata;
if (llama_verbose() || level <= GGML_LOG_LEVEL_WARN) {
fputs(text, stderr);
}
}
#endif
struct gpt_params {
int32_t seed = -1; // RNG seed
int32_t n_keep = 0; // number of tokens to keep from initial prompt
// sampling parameters
@ -137,44 +146,6 @@ struct gpt_params {
bool use_mlock = false; // use mlock to keep model in memory
};
static llama_token llama_sample_top_p_top_k(
llama_context *ctx,
const llama_token *last_n_tokens_data,
int last_n_tokens_size,
int top_k,
float top_p,
float min_p,
float temp,
float repeat_penalty) {
auto logits = llama_get_logits_ith(ctx, -1);
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
// Populate initial list of all candidates
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (int token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
// Sample repeat penalty
llama_sample_repetition_penalties(nullptr, &candidates_p, last_n_tokens_data, last_n_tokens_size, repeat_penalty, 0.0f, 0.0f);
llama_token id;
if (temp == 0.0) {
// greedy sampling, no probs
id = llama_sample_token_greedy(ctx, &candidates_p);
} else {
// temperature sampling
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1);
llama_sample_typical(ctx, &candidates_p, 1.0f, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_min_p(ctx, &candidates_p, min_p, 1);
llama_sample_temp(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
}
return id;
}
const char *get_arch_name(gguf_context *ctx_gguf)
{
const int kid = gguf_find_key(ctx_gguf, "general.architecture");
@ -241,21 +212,26 @@ cleanup:
}
struct LLamaPrivate {
const std::string modelPath;
bool modelLoaded = false;
int device = -1;
std::string deviceName;
int64_t n_threads = 0;
std::vector<LLModel::Token> end_tokens;
const char *backend_name = nullptr;
llama_model *model = nullptr;
llama_context *ctx = nullptr;
llama_model_params model_params;
llama_context_params ctx_params;
int64_t n_threads = 0;
std::vector<LLModel::Token> end_tokens;
const char *backend_name = nullptr;
llama_sampler *sampler_chain;
};
LLamaModel::LLamaModel()
: d_ptr(new LLamaPrivate) {}
: d_ptr(std::make_unique<LLamaPrivate>())
{
auto sparams = llama_sampler_chain_default_params();
d_ptr->sampler_chain = llama_sampler_chain_init(sparams);
}
// default hparams (LLaMA 7B)
struct llama_file_hparams {
@ -445,7 +421,6 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
}
d_ptr->ctx_params.n_ctx = n_ctx;
d_ptr->ctx_params.seed = params.seed;
d_ptr->ctx_params.type_k = params.kv_type;
d_ptr->ctx_params.type_v = params.kv_type;
@ -513,6 +488,7 @@ LLamaModel::~LLamaModel()
llama_free(d_ptr->ctx);
}
llama_free_model(d_ptr->model);
llama_sampler_free(d_ptr->sampler_chain);
}
bool LLamaModel::isModelLoaded() const
@ -522,18 +498,17 @@ bool LLamaModel::isModelLoaded() const
size_t LLamaModel::stateSize() const
{
return llama_get_state_size(d_ptr->ctx);
return llama_state_get_size(d_ptr->ctx);
}
size_t LLamaModel::saveState(uint8_t *dest) const
size_t LLamaModel::saveState(std::span<uint8_t> dest) const
{
return llama_copy_state_data(d_ptr->ctx, dest);
return llama_state_get_data(d_ptr->ctx, dest.data(), dest.size());
}
size_t LLamaModel::restoreState(const uint8_t *src)
size_t LLamaModel::restoreState(std::span<const uint8_t> src)
{
// const_cast is required, see: https://github.com/ggerganov/llama.cpp/pull/1540
return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
return llama_state_set_data(d_ptr->ctx, src.data(), src.size());
}
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, std::string_view str, bool special)
@ -573,13 +548,50 @@ std::string LLamaModel::tokenToString(Token id) const
return std::string(result.data(), result.size());
}
LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
void LLamaModel::initSampler(PromptContext &promptCtx)
{
const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
return llama_sample_top_p_top_k(d_ptr->ctx,
promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.min_p, promptCtx.temp,
promptCtx.repeat_penalty);
auto *model = d_ptr->model;
auto *chain = d_ptr->sampler_chain;
// clear sampler chain
for (int i = llama_sampler_chain_n(chain) - 1; i >= 0; i--) {
auto *smpl = llama_sampler_chain_remove(chain, i);
llama_sampler_free(smpl);
}
// build new chain
llama_sampler_chain_add(chain,
llama_sampler_init_penalties(
llama_n_vocab(model),
llama_token_eos(model),
llama_token_nl(model),
promptCtx.repeat_last_n,
promptCtx.repeat_penalty,
// TODO(jared): consider making the below configurable
/*penalty_freq*/ 0.0f,
/*penalty_present*/ 0.0f,
/*penalize_nl*/ true,
/*ignore_eos*/ false
)
);
if (promptCtx.temp == 0.0f) {
llama_sampler_chain_add(chain, llama_sampler_init_greedy());
} else {
struct llama_sampler *samplers[] = {
llama_sampler_init_top_k(promptCtx.top_k),
llama_sampler_init_top_p(promptCtx.top_p, 1),
llama_sampler_init_min_p(promptCtx.min_p, 1),
llama_sampler_init_temp(promptCtx.temp),
llama_sampler_init_dist(LLAMA_DEFAULT_SEED)
};
for (auto *smpl : samplers)
llama_sampler_chain_add(chain, smpl);
}
}
LLModel::Token LLamaModel::sampleToken() const
{
return llama_sampler_sample(d_ptr->sampler_chain, d_ptr->ctx, -1);
}
bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
@ -1227,9 +1239,9 @@ DLL_EXPORT bool is_arch_supported(const char *arch)
DLL_EXPORT LLModel *construct()
{
llama_log_set(llama_log_callback, nullptr);
llama_log_set([](auto l, auto t, auto u) { llama_log_callback(l, t, u, false); }, nullptr);
#ifdef GGML_USE_CUDA
ggml_backend_cuda_log_set_callback(cuda_log_callback, nullptr);
ggml_backend_cuda_log_set_callback([](auto l, auto t, auto u) { llama_log_callback(l, t, u, true); }, nullptr);
#endif
return new LLamaModel;
}

View File

@ -7,6 +7,7 @@
#include "llmodel.h"
#include <memory>
#include <span>
#include <string>
#include <string_view>
#include <vector>
@ -27,8 +28,8 @@ public:
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
size_t saveState(std::span<uint8_t> dest) const override;
size_t restoreState(std::span<const uint8_t> src) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;
std::vector<GPUDevice> availableGPUDevices(size_t memoryRequired = 0) const override;
@ -56,7 +57,8 @@ protected:
std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special) override;
bool isSpecialToken(Token id) const override;
std::string tokenToString(Token id) const override;
Token sampleToken(PromptContext &ctx) const override;
void initSampler(PromptContext &ctx) override;
Token sampleToken() const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
void shiftContext(PromptContext &promptCtx) override;
int32_t contextLength() const override;

View File

@ -91,16 +91,16 @@ uint64_t llmodel_get_state_size(llmodel_model model)
return wrapper->llModel->stateSize();
}
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest)
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest, uint64_t size)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->saveState(dest);
return wrapper->llModel->saveState({dest, size_t(size)});
}
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src, uint64_t size)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->restoreState(src);
return wrapper->llModel->restoreState({src, size_t(size)});
}
void llmodel_prompt(llmodel_model model, const char *prompt,

View File

@ -244,6 +244,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
return;
}
initSampler(promptCtx);
std::string cachedResponse;
std::vector<Token> cachedTokens;
int n_predicted = 0;
@ -251,7 +253,7 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
// Predict next tokens
for (bool stop = false; !stop;) {
// Sample next token
std::optional<Token> new_tok = sampleToken(promptCtx);
std::optional<Token> new_tok = sampleToken();
std::string new_piece = tokenToString(new_tok.value());
cachedTokens.push_back(new_tok.value());
cachedResponse += new_piece;

View File

@ -1,49 +0,0 @@
#pragma once
#include <ggml.h>
#include <cstddef>
#include <cstdint>
#include <vector>
struct llm_buffer {
uint8_t * addr = NULL;
size_t size = 0;
void resize(size_t size) {
delete[] addr;
addr = new uint8_t[size];
this->size = size;
}
~llm_buffer() {
delete[] addr;
}
};
struct llm_kv_cache {
struct ggml_tensor * k;
struct ggml_tensor * v;
struct ggml_context * ctx = NULL;
llm_buffer buf;
int n; // number of tokens currently in the cache
~llm_kv_cache() {
if (ctx) {
ggml_free(ctx);
}
}
};
inline void ggml_graph_compute_g4a(llm_buffer& buf, ggml_cgraph * graph, int n_threads)
{
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
if (plan.work_size > 0) {
buf.resize(plan.work_size);
plan.work_data = buf.addr;
}
ggml_graph_compute(graph, &plan);
}

View File

@ -1,339 +0,0 @@
#include "utils.h"
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iterator>
#include <regex>
#include <utility>
void replace(std::string & str, const std::string & needle, const std::string & replacement)
{
size_t pos = 0;
while ((pos = str.find(needle, pos)) != std::string::npos) {
str.replace(pos, needle.length(), replacement);
pos += replacement.length();
}
}
std::map<std::string, int32_t> json_parse(const std::string & fname)
{
std::map<std::string, int32_t> result;
// read file into string
std::string json;
{
std::ifstream ifs(fname);
if (!ifs) {
fprintf(stderr, "Failed to open %s\n", fname.c_str());
exit(1);
}
json = std::string((std::istreambuf_iterator<char>(ifs)),
(std::istreambuf_iterator<char>()));
}
if (json[0] != '{') {
return result;
}
// parse json
{
bool has_key = false;
bool in_token = false;
std::string str_key = "";
std::string str_val = "";
int n = json.size();
for (int i = 1; i < n; ++i) {
if (!in_token) {
if (json[i] == ' ') continue;
if (json[i] == '"') {
in_token = true;
continue;
}
} else {
if (json[i] == '\\' && i+1 < n) {
if (has_key == false) {
str_key += json[i];
} else {
str_val += json[i];
}
++i;
} else if (json[i] == '"') {
if (has_key == false) {
has_key = true;
++i;
while (json[i] == ' ') ++i;
++i; // :
while (json[i] == ' ') ++i;
if (json[i] != '\"') {
while (json[i] != ',' && json[i] != '}') {
str_val += json[i++];
}
has_key = false;
} else {
in_token = true;
continue;
}
} else {
has_key = false;
}
::replace(str_key, "\\u0120", " " ); // \u0120 -> space
::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
::replace(str_key, "\\\"", "\""); // \\\" -> "
try {
result[str_key] = std::stoi(str_val);
} catch (...) {
//fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());
}
str_key = "";
str_val = "";
in_token = false;
continue;
}
if (has_key == false) {
str_key += json[i];
} else {
str_val += json[i];
}
}
}
}
return result;
}
std::vector<gpt_vocab::id> gpt_tokenize_inner(const gpt_vocab & vocab, const std::string & text)
{
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
std::regex re(pat);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
// find the longest tokens that form the words:
std::vector<gpt_vocab::id> tokens;
for (const auto & word : words) {
if (word.size() == 0) continue;
int i = 0;
int n = word.size();
while (i < n) {
int j = n;
while (j > i) {
auto it = vocab.token_to_id.find(word.substr(i, j-i));
if (it != vocab.token_to_id.end()) {
tokens.push_back(it->second);
i = j;
break;
}
--j;
}
if (i == n) {
break;
}
if (j == i) {
auto sub = word.substr(i, 1);
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
tokens.push_back(vocab.token_to_id.at(sub));
} else {
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
}
++i;
}
}
}
return tokens;
}
std::string regex_escape(const std::string &s)
{
static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])");
return std::regex_replace(s, metacharacters, "\\$&");
}
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text)
{
// Generate the subpattern from the special_tokens vector if it's not empty
if (!vocab.special_tokens.empty()) {
std::vector<gpt_vocab::id> out;
std::vector<std::string> chunks;
std::string str = text;
std::string special_tokens_subpattern;
for (const auto &token : vocab.special_tokens) {
if (!special_tokens_subpattern.empty()) {
special_tokens_subpattern += "|";
}
special_tokens_subpattern += regex_escape(token);
}
std::regex re(special_tokens_subpattern);
std::smatch m;
while (std::regex_search(str, m, re)) {
auto tok = vocab.token_to_id.find(m.str());
if (tok != vocab.token_to_id.end()) {
auto tokid = tok->second;
auto pfxtoks = gpt_tokenize_inner(vocab, m.prefix());
out.insert(out.end(), pfxtoks.begin(), pfxtoks.end());
out.push_back(tokid);
str = m.suffix();
}
}
if (!str.empty()) {
auto tokrest = gpt_tokenize_inner(vocab, str);
out.insert(out.end(), tokrest.begin(), tokrest.end());
}
return out;
} else {
return gpt_tokenize_inner(vocab, text);
}
}
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab)
{
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
vocab.token_to_id = ::json_parse(fname);
for (const auto & kv : vocab.token_to_id) {
vocab.id_to_token[kv.second] = kv.first;
}
printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
// print the vocabulary
//for (auto kv : vocab.token_to_id) {
// printf("'%s' -> %d\n", kv.first.data(), kv.second);
//}
return true;
}
gpt_vocab::id gpt_sample_top_k_top_p(
const size_t actualVocabSize,
const int32_t * last_n_tokens_data,
int last_n_tokens_size,
const std::vector<float> logits,
int top_k,
double top_p,
double temp,
float repeat_penalty,
std::mt19937 & rng) {
int n_logits = actualVocabSize;
const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
const auto * plogits = logits.data();
if (temp <= 0) {
// select the token with the highest logit directly
float max_logit = plogits[0];
gpt_vocab::id max_id = 0;
for (int i = 1; i < n_logits; ++i) {
if (plogits[i] > max_logit) {
max_logit = plogits[i];
max_id = i;
}
}
return max_id;
}
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
{
const float scale = 1.0f/temp;
for (int i = 0; i < n_logits; ++i) {
// repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if (plogits[i] < 0.0f) {
logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
}
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale, i));
}
}
}
// find the top K tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
return a.first > b.first;
});
logits_id.resize(top_k);
double maxl = -INFINITY;
for (const auto & kv : logits_id) {
maxl = std::max(maxl, kv.first);
}
// compute probs for the top K tokens
std::vector<double> probs;
probs.reserve(logits_id.size());
double sum = 0.0;
for (const auto & kv : logits_id) {
double p = exp(kv.first - maxl);
probs.push_back(p);
sum += p;
}
// normalize the probs
for (auto & p : probs) {
p /= sum;
}
if (top_p < 1.0f) {
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += probs[i];
if (cumsum >= top_p) {
top_k = i + 1;
probs.resize(top_k);
logits_id.resize(top_k);
break;
}
}
cumsum = 1.0/cumsum;
for (int i = 0; i < (int) probs.size(); i++) {
probs[i] *= cumsum;
}
}
//printf("\n");
//for (int i = 0; i < (int) probs.size(); i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
//}
//exit(0);
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);
return logits_id[idx].second;
}

View File

@ -1,101 +1,17 @@
// Various helper functions and utilities
#pragma once
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <map>
#include <random>
#include <string>
#include <thread>
#include <vector>
#include <cassert>
//
// General purpose inline functions
//
constexpr inline unsigned long long operator ""_MiB(unsigned long long bytes)
{
return bytes*1024*1024;
}
//
// CLI argument parsing
//
struct gpt_params {
int32_t seed = -1; // RNG seed
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_predict = 200; // new tokens to predict
// sampling parameters
int32_t top_k = 40;
float top_p = 0.9f;
float temp = 0.9f;
int32_t n_batch = 8; // batch size for prompt processing
std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
std::string prompt;
};
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
std::string gpt_random_prompt(std::mt19937 & rng);
//
// Vocab utils
//
struct gpt_vocab {
using id = int32_t;
using token = std::string;
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
std::vector<std::string> special_tokens;
void add_special_token(const std::string &token) {
special_tokens.push_back(token);
}
};
void replace(std::string & str, const std::string & needle, const std::string & replacement);
// poor-man's JSON parsing
std::map<std::string, int32_t> json_parse(const std::string & fname);
// split text into tokens
//
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
//
// Regex (Python):
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
//
// Regex (C++):
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
//
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
// load the tokens from encoder.json
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
// sample next token given probabilities for each embedding
//
// - consider only the top K tokens
// - from them, consider only the top tokens with cumulative probability > P
//
// TODO: not sure if this implementation is correct
//
gpt_vocab::id gpt_sample_top_k_top_p(
const size_t actualVocabSize,
const int32_t * last_n_tokens_data,
int last_n_tokens_size,
const std::vector<float> logits,
int top_k,
double top_p,
double temp,
float repeat_penalty,
std::mt19937 & rng);
#ifdef NDEBUG
# ifdef __has_builtin
# if __has_builtin(__builtin_unreachable)
# define UNREACHABLE() __builtin_unreachable()
# else
# define UNREACHABLE() do {} while (0)
# endif
# else
# define UNREACHABLE() do {} while (0)
# endif
#else
# define UNREACHABLE() assert(!"Unreachable statement was reached")
#endif

View File

@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
### Added
- Warn on Windows if the Microsoft Visual C++ runtime libraries are not found ([#2920](https://github.com/nomic-ai/gpt4all/pull/2920))
### Changed
- Rebase llama.cpp on latest upstream as of September 26th ([#2998](https://github.com/nomic-ai/gpt4all/pull/2998))
## [2.8.2] - 2024-08-14
### Fixed

View File

@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
### Added
- Add bm25 hybrid search to localdocs ([#2969](https://github.com/nomic-ai/gpt4all/pull/2969))
### Changed
- Rebase llama.cpp on latest upstream as of September 26th ([#2998](https://github.com/nomic-ai/gpt4all/pull/2998))
### Fixed
- Fix a crash when attempting to continue a chat loaded from disk ([#2995](https://github.com/nomic-ai/gpt4all/pull/2995))
- Fix the local server rejecting min\_p/top\_p less than 1 ([#2996](https://github.com/nomic-ai/gpt4all/pull/2996))

View File

@ -71,19 +71,19 @@ bool ChatAPI::isModelLoaded() const
// All three of the state virtual functions are handled custom inside of chatllm save/restore
size_t ChatAPI::stateSize() const
{
return 0;
throw std::logic_error("not implemented");
}
size_t ChatAPI::saveState(uint8_t *dest) const
size_t ChatAPI::saveState(std::span<uint8_t> dest) const
{
Q_UNUSED(dest);
return 0;
throw std::logic_error("not implemented");
}
size_t ChatAPI::restoreState(const uint8_t *src)
size_t ChatAPI::restoreState(std::span<const uint8_t> src)
{
Q_UNUSED(src);
return 0;
throw std::logic_error("not implemented");
}
void ChatAPI::prompt(const std::string &prompt,

View File

@ -64,8 +64,8 @@ public:
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
size_t saveState(std::span<uint8_t> dest) const override;
size_t restoreState(std::span<const uint8_t> src) override;
void prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
@ -118,12 +118,14 @@ protected:
throw std::logic_error("not implemented");
}
Token sampleToken(PromptContext &ctx) const override
void initSampler(PromptContext &ctx) override
{
(void)ctx;
throw std::logic_error("not implemented");
}
Token sampleToken() const override { throw std::logic_error("not implemented"); }
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override
{
(void)ctx;

View File

@ -1174,7 +1174,13 @@ void ChatLLM::saveState()
#if defined(DEBUG)
qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
m_llModelInfo.model->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
if (!ok) {
// FIXME(jared): how badly does this situation break GPT4All?
qWarning() << "ChatLLM failed to save LLModel state";
m_state.clear();
m_state.squeeze();
}
}
void ChatLLM::restoreState()
@ -1183,7 +1189,7 @@ void ChatLLM::restoreState()
return;
if (m_llModelType == LLModelType::API_) {
QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
QDataStream stream(m_state);
stream.setVersion(QDataStream::Qt_6_4);
ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model.get());
QList<QString> context;
@ -1201,12 +1207,12 @@ void ChatLLM::restoreState()
if (m_state.isEmpty())
return;
if (m_llModelInfo.model->stateSize() == m_state.size()) {
m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
size_t bytesRead = m_llModelInfo.model->restoreState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
if (bytesRead) {
m_processedSystemPrompt = true;
m_pristineLoadedState = true;
} else {
qWarning() << "restoring state from text because" << m_llModelInfo.model->stateSize() << "!=" << m_state.size();
qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
m_restoreStateFromText = true;
}