From 0efdbfcffee88b7b1e17b9d902eead654c38585a Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Sun, 9 Jul 2023 11:32:51 -0400 Subject: [PATCH] Bert --- gpt4all-backend/bert.cpp | 732 ++++++++++-------- gpt4all-backend/bert.h | 71 -- gpt4all-backend/bert_impl.h | 44 ++ gpt4all-backend/falcon_impl.h | 2 + gpt4all-backend/gptj_impl.h | 2 + gpt4all-backend/llamamodel_impl.h | 2 + gpt4all-backend/llmodel.h | 7 + gpt4all-backend/llmodel_c.cpp | 19 + gpt4all-backend/llmodel_c.h | 17 + gpt4all-backend/llmodel_shared.cpp | 16 + gpt4all-backend/mpt_impl.h | 2 + gpt4all-backend/replit_impl.h | 2 + .../scripts/convert_bert_hf_to_ggml.py | 102 +++ gpt4all-bindings/python/gpt4all/__init__.py | 2 +- gpt4all-bindings/python/gpt4all/gpt4all.py | 14 + gpt4all-bindings/python/gpt4all/pyllmodel.py | 24 + .../python/gpt4all/tests/test_gpt4all.py | 10 +- gpt4all-chat/chatgpt.h | 2 + gpt4all-chat/chatllm.cpp | 9 +- gpt4all-chat/chatllm.h | 1 + 20 files changed, 682 insertions(+), 398 deletions(-) delete mode 100644 gpt4all-backend/bert.h create mode 100644 gpt4all-backend/bert_impl.h create mode 100644 gpt4all-backend/scripts/convert_bert_hf_to_ggml.py diff --git a/gpt4all-backend/bert.cpp b/gpt4all-backend/bert.cpp index 318efdc7..66ee2515 100644 --- a/gpt4all-backend/bert.cpp +++ b/gpt4all-backend/bert.cpp @@ -1,4 +1,5 @@ -#include "bert.h" +#define BERT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#include "bert_impl.h" #include "ggml.h" #include @@ -16,6 +17,12 @@ //#define DEBUG_BERT +namespace { +const char *modelType_ = "Bert"; +} + +typedef int32_t bert_vocab_id; + // default hparams (all-MiniLM-L6-v2) struct bert_hparams { @@ -192,15 +199,11 @@ std::string bert_normalize_prompt(const std::string &text) } return text2; } -void bert_tokenize( + +std::vector bert_tokenize( struct bert_ctx * ctx, - const char * text, - bert_vocab_id * tokens, - int32_t * n_tokens, - int32_t n_max_tokens) + const char * text) { - int cls_tok_id = 101; - int sep_tok_id = 102; const bert_vocab &vocab = ctx->vocab; std::string str = text; @@ -225,10 +228,10 @@ void bert_tokenize( } } - int32_t t = 0; - tokens[t++] = cls_tok_id; - // find the longest tokens that form the words: + std::vector tokens; + int cls_tok_id = 101; + tokens.push_back(cls_tok_id); for (const auto &word : words) { if (word.size() == 0) @@ -237,21 +240,17 @@ void bert_tokenize( int i = 0; int n = word.size(); auto *token_map = &vocab.token_to_id; - loop: while (i < n) { - if (t >= n_max_tokens - 1) - break; int j = n; while (j > i) { auto it = token_map->find(word.substr(i, j - i)); if (it != token_map->end()) { - tokens[t++] = it->second; + tokens.push_back(it->second); i = j; token_map = &vocab.subword_token_to_id; - goto loop; } --j; } @@ -263,14 +262,247 @@ void bert_tokenize( } } } - tokens[t++] = sep_tok_id; - *n_tokens = t; + + return tokens; +} + +void bert_resize_ctx(bert_ctx * ctx, int32_t new_size) { + int64_t buf_size_new = ctx->mem_per_input * new_size; + + // TODO: Max memory should be a param? Now just 1 GB + int64_t GB = 1 << 30; +#if defined(DEBUG_BERT) + printf("%s: requested_buf_size %lldMB\n", __func__, buf_size_new / (1 << 20)); +#endif + if (buf_size_new > GB) { + int32_t adjusted_new_size = GB / ctx->mem_per_input; + if (adjusted_new_size < 1) adjusted_new_size = 1; +#if defined(DEBUG_BERT) + printf("%s: requested batch size %d, actual new batch size %d\n", __func__, new_size, adjusted_new_size); +#endif + new_size = adjusted_new_size; + buf_size_new = ctx->mem_per_input * new_size; + } + if (new_size > ctx->max_batch_n) { + ctx->buf_compute.resize(buf_size_new); + ctx->max_batch_n = new_size; + } +} + +void bert_eval( + struct bert_ctx *ctx, + int32_t n_threads, + const bert_vocab_id *raw_tokens, + int32_t n_tokens, + float *embeddings) +{ + const bert_model& model = ctx->model; + bool mem_req_mode = !embeddings; + + // batch_embeddings is nullptr for the initial memory requirements run + if (!mem_req_mode && 1 > ctx->max_batch_n) + bert_resize_ctx(ctx, 1); + + const int N = n_tokens; + const auto &tokens = raw_tokens; + + const auto &hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_max_tokens = hparams.n_max_tokens; + const int n_head = hparams.n_head; + + const int d_head = n_embd / n_head; + + std::vector result; + if (N > n_max_tokens) + { + fprintf(stderr, "Too many tokens, maximum is %d\n", n_max_tokens); + return; + } + + auto & mem_per_token = ctx->mem_per_token; + auto & buf_compute = ctx->buf_compute; + + struct ggml_init_params params = { + .mem_size = buf_compute.size, + .mem_buffer = buf_compute.data, + .no_alloc = false, + }; + + struct ggml_context *ctx0 = ggml_init(params); + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; + + // Embeddings. word_embeddings + token_type_embeddings + position_embeddings + struct ggml_tensor *token_layer = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(token_layer->data, tokens, N * ggml_element_size(token_layer)); + + struct ggml_tensor *token_types = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + ggml_set_zero(token_types); + + struct ggml_tensor *positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + for (int i = 0; i < N; i++) + { + ggml_set_i32_1d(positions, i, i); + } + + struct ggml_tensor *inpL = ggml_get_rows(ctx0, model.word_embeddings, token_layer); + + inpL = ggml_add(ctx0, + ggml_get_rows(ctx0, model.token_type_embeddings, token_types), + inpL); + inpL = ggml_add(ctx0, + ggml_get_rows(ctx0, model.position_embeddings, positions), + inpL); + + // embd norm + { + inpL = ggml_norm(ctx0, inpL); + + inpL = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.ln_e_w, inpL), + inpL), + ggml_repeat(ctx0, model.ln_e_b, inpL)); + } + // layers + for (int il = 0; il < n_layer; il++) + { + struct ggml_tensor *cur = inpL; + + // self-attention + { + struct ggml_tensor *Qcur = cur; + Qcur = ggml_reshape_3d(ctx0, + ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, Qcur), + ggml_mul_mat(ctx0, model.layers[il].q_w, Qcur)), + d_head, n_head, N); + struct ggml_tensor *Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + + struct ggml_tensor *Kcur = cur; + Kcur = ggml_reshape_3d(ctx0, + ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].k_b, Kcur), + ggml_mul_mat(ctx0, model.layers[il].k_w, Kcur)), + d_head, n_head, N); + struct ggml_tensor *K = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); + + struct ggml_tensor *Vcur = cur; + Vcur = ggml_reshape_3d(ctx0, + ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].v_b, Vcur), + ggml_mul_mat(ctx0, model.layers[il].v_w, Vcur)), + d_head, n_head, N); + struct ggml_tensor *V = ggml_permute(ctx0, Vcur, 0, 2, 1, 3); + + struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); + // KQ = soft_max(KQ / sqrt(head width)) + KQ = ggml_soft_max(ctx0, + ggml_scale(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f / sqrt((float)d_head)))); + + V = ggml_cont(ctx0, ggml_transpose(ctx0, V)); + struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + } + // attention output + cur = ggml_add(ctx0, + ggml_repeat(ctx0, model.layers[il].o_b, cur), + ggml_mul_mat(ctx0, model.layers[il].o_w, cur)); + + // re-add the layer input + cur = ggml_add(ctx0, cur, inpL); + + // attention norm + { + cur = ggml_norm(ctx0, cur); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].ln_att_w, cur), + cur), + ggml_repeat(ctx0, model.layers[il].ln_att_b, cur)); + } + struct ggml_tensor *att_output = cur; + // intermediate_output = self.intermediate(attention_output) + cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); + cur = ggml_add(ctx0, + ggml_repeat(ctx0, model.layers[il].ff_i_b, cur), + cur); + cur = ggml_gelu(ctx0, cur); + + // layer_output = self.output(intermediate_output, attention_output) + cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); + cur = ggml_add(ctx0, + ggml_repeat(ctx0, model.layers[il].ff_o_b, cur), + cur); + // attentions bypass the intermediate layer + cur = ggml_add(ctx0, att_output, cur); + + // output norm + { + cur = ggml_norm(ctx0, cur); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].ln_out_w, cur), + cur), + ggml_repeat(ctx0, model.layers[il].ln_out_b, cur)); + } + inpL = cur; + } + inpL = ggml_cont(ctx0, ggml_transpose(ctx0, inpL)); + // pooler + struct ggml_tensor *sum = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, 1); + ggml_set_f32(sum, 1.0f / N); + inpL = ggml_mul_mat(ctx0, inpL, sum); + + // normalizer + ggml_tensor *length = ggml_sqrt(ctx0, + ggml_sum(ctx0, ggml_sqr(ctx0, inpL))); + inpL = ggml_scale(ctx0, inpL, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length)); + + ggml_tensor *output = inpL; + // run the computation + ggml_build_forward_expand(&gf, output); + ggml_graph_compute(ctx0, &gf); + + + // float *dat = ggml_get_data_f32(output); + // pretty_print_tensor(dat, output->ne, output->nb, output->n_dims - 1, ""); + + #ifdef GGML_PERF + // print timing information per ggml operation (for debugging purposes) + // requires GGML_PERF to be defined + ggml_graph_print(&gf); + #endif + + if (!mem_req_mode) { + memcpy(embeddings, (float *)ggml_get_data(output), sizeof(float) * n_embd); + } else { + mem_per_token = ggml_used_mem(ctx0) / N; + } + + // printf("used_mem = %zu KB \n", ggml_used_mem(ctx0) / 1024); + // printf("mem_per_token = %zu KB \n", mem_per_token / 1024); + + ggml_free(ctx0); } // // Loading and setup // +void bert_free(bert_ctx * ctx) { + ggml_free(ctx->model.ctx); + delete ctx; +} + struct bert_ctx * bert_load_from_file(const char *fname) { #if defined(DEBUG_BERT) @@ -288,7 +520,7 @@ struct bert_ctx * bert_load_from_file(const char *fname) { uint32_t magic; fin.read((char *)&magic, sizeof(magic)); - if (magic != 0x67676d6c) + if (magic != 0x62657274) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname); return nullptr; @@ -506,7 +738,9 @@ struct bert_ctx * bert_load_from_file(const char *fname) // load weights { int n_tensors = 0; +#if defined(DEBUG_BERT) size_t total_size = 0; +#endif #if defined(DEBUG_BERT) printf("%s: ", __func__); @@ -609,8 +843,10 @@ struct bert_ctx * bert_load_from_file(const char *fname) fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); +#if defined(DEBUG_BERT) // printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); total_size += ggml_nbytes(tensor); +#endif if (++n_tensors % 8 == 0) { @@ -639,7 +875,7 @@ struct bert_ctx * bert_load_from_file(const char *fname) // TODO: Max tokens should be a param? int32_t N = new_bert->model.hparams.n_max_tokens; - new_bert->mem_per_input = 1.1 * (new_bert->mem_per_token * N); // add 10% to account for ggml object overhead + new_bert->mem_per_input = 1.9 * (new_bert->mem_per_token * N); // add 10% to account for ggml object overhead } #if defined(DEBUG_BERT) @@ -649,331 +885,183 @@ struct bert_ctx * bert_load_from_file(const char *fname) return new_bert; } -void bert_resize_ctx(bert_ctx * ctx, int32_t new_size) { - int64_t buf_size_new = ctx->mem_per_input * new_size; +struct BertPrivate { + const std::string modelPath; + bool modelLoaded; + bert_ctx *ctx = nullptr; + int64_t n_threads = 0; +}; - // TODO: Max memory should be a param? Now just 1 GB - int64_t GB = 1 << 30; - //printf("%s: requested_buf_size %ldMB\n", __func__, buf_size_new / (1 << 20)); - if (buf_size_new > GB) { - int32_t adjusted_new_size = GB / ctx->mem_per_input; - if (adjusted_new_size < 1) adjusted_new_size = 1; - //printf("%s: requested batch size %d, actual new batch size %d\n", __func__, new_size, adjusted_new_size); - new_size = adjusted_new_size; - buf_size_new = ctx->mem_per_input * new_size; - } - if (new_size > ctx->max_batch_n) { - ctx->buf_compute.resize(buf_size_new); - ctx->max_batch_n = new_size; - } +Bert::Bert() : d_ptr(new BertPrivate) { + d_ptr->modelLoaded = false; } -void bert_free(bert_ctx * ctx) { - ggml_free(ctx->model.ctx); - delete ctx; +Bert::~Bert() { + bert_free(d_ptr->ctx); } -void bert_eval( - struct bert_ctx *ctx, - int32_t n_threads, - bert_vocab_id *tokens, - int32_t n_tokens, - float *embeddings) +bool Bert::loadModel(const std::string &modelPath) { - bert_eval_batch(ctx, n_threads, 1, &tokens, &n_tokens, embeddings ? &embeddings : nullptr); + d_ptr->ctx = bert_load_from_file(modelPath.c_str()); + d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + d_ptr->modelLoaded = d_ptr->ctx != nullptr; + fflush(stdout); + return true; } -void bert_eval_batch( - bert_ctx * ctx, - int32_t n_threads, - int32_t n_batch_size, - bert_vocab_id ** batch_tokens, - int32_t * n_tokens, - float ** batch_embeddings) +bool Bert::isModelLoaded() const { - const bert_model& model = ctx->model; - bool mem_req_mode = !batch_embeddings; - // batch_embeddings is nullptr for the initial memory requirements run - if (!mem_req_mode && n_batch_size > ctx->max_batch_n) { - bert_resize_ctx(ctx, n_batch_size); - if (n_batch_size > ctx->max_batch_n) { - fprintf(stderr, "%s: tried to increase buffers to batch size %d but failed\n", __func__, n_batch_size); - return; - } - } + return d_ptr->modelLoaded; +} - // TODO: implement real batching - for (int ba = 0; ba < n_batch_size; ba++) - { - const int N = n_tokens[ba]; - const auto &tokens = batch_tokens[ba]; +size_t Bert::requiredMem(const std::string &/*modelPath*/) +{ + return 0; +} - const auto &hparams = model.hparams; +size_t Bert::stateSize() const +{ + return 0; +} - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_max_tokens = hparams.n_max_tokens; - const int n_head = hparams.n_head; +size_t Bert::saveState(uint8_t */*dest*/) const +{ + return 0; +} - const int d_head = n_embd / n_head; +size_t Bert::restoreState(const uint8_t */*src*/) +{ + return 0; +} - std::vector result; - if (N > n_max_tokens) - { - fprintf(stderr, "Too many tokens, maximum is %d\n", n_max_tokens); - return; - } +void Bert::setThreadCount(int32_t n_threads) +{ + d_ptr->n_threads = n_threads; +} - auto & mem_per_token = ctx->mem_per_token; - auto & buf_compute = ctx->buf_compute; +int32_t Bert::threadCount() const +{ + return d_ptr->n_threads; +} - struct ggml_init_params params = { - .mem_size = buf_compute.size, - .mem_buffer = buf_compute.data, - .no_alloc = false, - }; - - struct ggml_context *ctx0 = ggml_init(params); - struct ggml_cgraph gf = {}; - gf.n_threads = n_threads; - - // Embeddings. word_embeddings + token_type_embeddings + position_embeddings - struct ggml_tensor *token_layer = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - memcpy(token_layer->data, tokens, N * ggml_element_size(token_layer)); - - struct ggml_tensor *token_types = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - ggml_set_zero(token_types); - - struct ggml_tensor *positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - for (int i = 0; i < N; i++) - { - ggml_set_i32_1d(positions, i, i); - } - - struct ggml_tensor *inpL = ggml_get_rows(ctx0, model.word_embeddings, token_layer); - - inpL = ggml_add(ctx0, - ggml_get_rows(ctx0, model.token_type_embeddings, token_types), - inpL); - inpL = ggml_add(ctx0, - ggml_get_rows(ctx0, model.position_embeddings, positions), - inpL); - - // embd norm - { - inpL = ggml_norm(ctx0, inpL); - - inpL = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.ln_e_w, inpL), - inpL), - ggml_repeat(ctx0, model.ln_e_b, inpL)); - } - // layers - for (int il = 0; il < n_layer; il++) - { - struct ggml_tensor *cur = inpL; - - // self-attention - { - struct ggml_tensor *Qcur = cur; - Qcur = ggml_reshape_3d(ctx0, - ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, Qcur), - ggml_mul_mat(ctx0, model.layers[il].q_w, Qcur)), - d_head, n_head, N); - struct ggml_tensor *Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - - struct ggml_tensor *Kcur = cur; - Kcur = ggml_reshape_3d(ctx0, - ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].k_b, Kcur), - ggml_mul_mat(ctx0, model.layers[il].k_w, Kcur)), - d_head, n_head, N); - struct ggml_tensor *K = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); - - struct ggml_tensor *Vcur = cur; - Vcur = ggml_reshape_3d(ctx0, - ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].v_b, Vcur), - ggml_mul_mat(ctx0, model.layers[il].v_w, Vcur)), - d_head, n_head, N); - struct ggml_tensor *V = ggml_permute(ctx0, Vcur, 0, 2, 1, 3); - - struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); - // KQ = soft_max(KQ / sqrt(head width)) - KQ = ggml_soft_max(ctx0, - ggml_scale(ctx0, - KQ, - ggml_new_f32(ctx0, 1.0f / sqrt((float)d_head)))); - - V = ggml_cont(ctx0, ggml_transpose(ctx0, V)); - struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ); - KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - cur = ggml_cpy(ctx0, - KQV, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); - } - // attention output - cur = ggml_add(ctx0, - ggml_repeat(ctx0, model.layers[il].o_b, cur), - ggml_mul_mat(ctx0, model.layers[il].o_w, cur)); - - // re-add the layer input - cur = ggml_add(ctx0, cur, inpL); - - // attention norm - { - cur = ggml_norm(ctx0, cur); - - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].ln_att_w, cur), - cur), - ggml_repeat(ctx0, model.layers[il].ln_att_b, cur)); - } - struct ggml_tensor *att_output = cur; - // intermediate_output = self.intermediate(attention_output) - cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, model.layers[il].ff_i_b, cur), - cur); - cur = ggml_gelu(ctx0, cur); - - // layer_output = self.output(intermediate_output, attention_output) - cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, model.layers[il].ff_o_b, cur), - cur); - // attentions bypass the intermediate layer - cur = ggml_add(ctx0, att_output, cur); - - // output norm - { - cur = ggml_norm(ctx0, cur); - - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].ln_out_w, cur), - cur), - ggml_repeat(ctx0, model.layers[il].ln_out_b, cur)); - } - inpL = cur; - } - inpL = ggml_cont(ctx0, ggml_transpose(ctx0, inpL)); - // pooler - struct ggml_tensor *sum = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, 1); - ggml_set_f32(sum, 1.0f / N); - inpL = ggml_mul_mat(ctx0, inpL, sum); - - // normalizer - ggml_tensor *length = ggml_sqrt(ctx0, - ggml_sum(ctx0, ggml_sqr(ctx0, inpL))); - inpL = ggml_scale(ctx0, inpL, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length)); - - ggml_tensor *output = inpL; - // run the computation - ggml_build_forward_expand(&gf, output); - ggml_graph_compute(ctx0, &gf); - - - // float *dat = ggml_get_data_f32(output); - // pretty_print_tensor(dat, output->ne, output->nb, output->n_dims - 1, ""); - - #ifdef GGML_PERF - // print timing information per ggml operation (for debugging purposes) - // requires GGML_PERF to be defined - ggml_graph_print(&gf); - #endif - - if (!mem_req_mode) { - memcpy(batch_embeddings[ba], (float *)ggml_get_data(output), sizeof(float) * n_embd); +std::vector Bert::embedding(const std::string &text) +{ + const int overlap = 32; + const LLModel::Token clsToken = 101; + const size_t contextLength = bert_n_max_tokens(d_ptr->ctx); + typedef std::vector TokenString; + TokenString tokens = ::bert_tokenize(d_ptr->ctx, text.c_str()); +#if defined(DEBUG_BERT) + std::cerr << "embedding: " << tokens.size() + << " contextLength " << contextLength + << "\n"; +#endif + std::vector embeddingsSum(bert_n_embd(d_ptr->ctx), 0); + int embeddingsSumTotal = 0; + size_t start_pos = 0; + bool isFirstChunk = true; + while (start_pos < tokens.size()) { + TokenString chunk; + if (!isFirstChunk) + chunk.push_back(clsToken); + const size_t l = isFirstChunk ? contextLength : contextLength - 1; + if (tokens.size() - start_pos > l) { + chunk.insert(chunk.end(), tokens.begin() + start_pos, tokens.begin() + start_pos + l); + start_pos = start_pos + contextLength - overlap; } else { - mem_per_token = ggml_used_mem(ctx0) / N; - - // printf("used_mem = %zu KB \n", ggml_used_mem(ctx0) / 1024); - // printf("mem_per_token = %zu KB \n", mem_per_token / 1024); + chunk.insert(chunk.end(), tokens.begin() + start_pos, tokens.end()); + start_pos = tokens.size(); } - - ggml_free(ctx0); +#if defined(DEBUG_BERT) + std::cerr << "chunk length: " << chunk.size() + << " embeddingsSumTotal " << embeddingsSumTotal + << " contextLength " << contextLength + << " start_pos " << start_pos + << "\n"; +#endif + embeddingsSumTotal++; + std::vector embeddings(bert_n_embd(d_ptr->ctx)); + bert_eval(d_ptr->ctx, d_ptr->n_threads, chunk.data(), chunk.size(), embeddings.data()); + std::transform(embeddingsSum.begin(), embeddingsSum.end(), embeddings.begin(), embeddingsSum.begin(), std::plus()); + isFirstChunk = false; } + + std::transform(embeddingsSum.begin(), embeddingsSum.end(), embeddingsSum.begin(), [embeddingsSumTotal](float num){ return num / embeddingsSumTotal; }); + std::vector finalEmbeddings(embeddingsSum.begin(), embeddingsSum.end()); + return finalEmbeddings; } -void bert_encode( - struct bert_ctx *ctx, - int32_t n_threads, - const char *texts, - float *embeddings) +std::vector Bert::tokenize(PromptContext &, const std::string &str) const { - bert_encode_batch(ctx, n_threads, 1, 1, &texts, &embeddings); + return ::bert_tokenize(d_ptr->ctx, str.c_str()); } -void bert_encode_batch( - struct bert_ctx *ctx, - int32_t n_threads, - int32_t n_batch_size, - int32_t n_inputs, - const char ** texts, - float **embeddings) +LLModel::Token Bert::sampleToken(PromptContext &/*promptCtx*/) const { - // TODO: Disable batching for now - n_batch_size = 1; - /* - if (n_batch_size > n_inputs) { - n_batch_size = n_inputs; - } - if (n_batch_size > ctx->max_batch_n) { - bert_resize_ctx(ctx, n_batch_size); - n_batch_size = ctx->max_batch_n; - } - */ - - int32_t N = bert_n_max_tokens(ctx); - - std::vector buf_tokens; - // Most of this buffer will be unused in typical case where inputs are not that long. - buf_tokens.resize(N * n_inputs); - std::vector n_tokens = std::vector(n_inputs); - std::vector unsorted_tokens(n_inputs); - bert_vocab_id* it_tokens = buf_tokens.data(); - for (int i = 0; i < n_inputs; i++) { - unsorted_tokens[i] = it_tokens; - bert_tokenize(ctx, texts[i], it_tokens, &n_tokens[i], N); - it_tokens += n_tokens[i]; - } - - if (n_batch_size == n_inputs) { - bert_eval_batch(ctx, n_threads, n_batch_size, unsorted_tokens.data(), n_tokens.data(), embeddings); - } else { - // sort the inputs by tokenized length, batch and eval - - std::vector indices; - indices.reserve(n_inputs); - for (int i = 0; i < n_inputs; i++) - { - indices.push_back(i); - } - - std::vector sorted_n_tokens = std::vector(n_inputs); - - std::vector sorted_tokens(n_inputs); - - std::sort(indices.begin(), indices.end(), [&](int a, int b) - { return n_tokens[a] < n_tokens[b]; }); - - std::vector sorted_embeddings(n_inputs); - memcpy(sorted_embeddings.data(), embeddings, n_inputs * sizeof(float *)); - - for (int i = 0; i < n_inputs; i++) { - sorted_embeddings[i] = embeddings[indices[i]]; - sorted_tokens[i] = unsorted_tokens[indices[i]]; - sorted_n_tokens[i] = n_tokens[indices[i]]; - } - - for (int i = 0; i < n_inputs; i += n_batch_size) - { - if (i + n_batch_size > n_inputs) { - n_batch_size = n_inputs - i; - } - bert_eval_batch(ctx, n_threads, n_batch_size, &sorted_tokens[i], &sorted_n_tokens[i], &sorted_embeddings[i]); - } - } + return 999 /*!*/; } + +std::string Bert::tokenToString(Token id) const +{ + return bert_vocab_id_to_token(d_ptr->ctx, id); +} + +bool Bert::evalTokens(PromptContext &ctx, const std::vector &tokens) const +{ + std::vector embeddings(bert_n_embd(d_ptr->ctx)); + int32_t cls = 101; + const bool useCLS = tokens.front() != cls; + if (useCLS) { + std::vector myTokens; + myTokens.push_back(cls); + myTokens.insert(myTokens.end(), tokens.begin(), tokens.end()); + bert_eval(d_ptr->ctx, d_ptr->n_threads, myTokens.data(), myTokens.size(), embeddings.data()); + } else + bert_eval(d_ptr->ctx, d_ptr->n_threads, tokens.data(), tokens.size(), embeddings.data()); + ctx.n_past = 0; // bert does not store any context + return true; +} + +int32_t Bert::contextLength() const +{ + return bert_n_max_tokens(d_ptr->ctx); +} + +const std::vector &Bert::endTokens() const +{ + static const std::vector out = { 102 /*sep*/}; + return out; +} + +#if defined(_WIN32) +#define DLL_EXPORT __declspec(dllexport) +#else +#define DLL_EXPORT __attribute__ ((visibility ("default"))) +#endif + +extern "C" { +DLL_EXPORT bool is_g4a_backend_model_implementation() { + return true; +} + +DLL_EXPORT const char *get_model_type() { + return modelType_; +} + +DLL_EXPORT const char *get_build_variant() { + return GGML_BUILD_VARIANT; +} + +DLL_EXPORT bool magic_match(std::istream& f) { + uint32_t magic = 0; + f.read(reinterpret_cast(&magic), sizeof(magic)); + if (magic != 0x62657274) { + return false; + } + return true; +} + +DLL_EXPORT LLModel *construct() { + return new Bert; +} +} \ No newline at end of file diff --git a/gpt4all-backend/bert.h b/gpt4all-backend/bert.h deleted file mode 100644 index 28435ede..00000000 --- a/gpt4all-backend/bert.h +++ /dev/null @@ -1,71 +0,0 @@ -#ifndef BERT_H -#define BERT_H - -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -struct bert_ctx; - -typedef int32_t bert_vocab_id; - -struct bert_ctx * bert_load_from_file(const char * fname); -void bert_free(bert_ctx * ctx); - -// Main api, does both tokenizing and evaluation - -void bert_encode( - struct bert_ctx * ctx, - int32_t n_threads, - const char * texts, - float * embeddings); - -// n_batch_size - how many to process at a time -// n_inputs - total size of texts and embeddings arrays -void bert_encode_batch( - struct bert_ctx * ctx, - int32_t n_threads, - int32_t n_batch_size, - int32_t n_inputs, - const char ** texts, - float ** embeddings); - -// Api for separate tokenization & eval - -void bert_tokenize( - struct bert_ctx * ctx, - const char * text, - bert_vocab_id * tokens, - int32_t * n_tokens, - int32_t n_max_tokens); - -void bert_eval( - struct bert_ctx * ctx, - int32_t n_threads, - bert_vocab_id * tokens, - int32_t n_tokens, - float * embeddings); - -// NOTE: for batch processing the longest input must be first -void bert_eval_batch( - struct bert_ctx * ctx, - int32_t n_threads, - int32_t n_batch_size, - bert_vocab_id ** batch_tokens, - int32_t * n_tokens, - float ** batch_embeddings); - -int32_t bert_n_embd(bert_ctx * ctx); -int32_t bert_n_max_tokens(bert_ctx * ctx); - -const char* bert_vocab_id_to_token(bert_ctx * ctx, bert_vocab_id id); - -#ifdef __cplusplus -} -#endif - -#endif // BERT_H diff --git a/gpt4all-backend/bert_impl.h b/gpt4all-backend/bert_impl.h new file mode 100644 index 00000000..d1cc99f4 --- /dev/null +++ b/gpt4all-backend/bert_impl.h @@ -0,0 +1,44 @@ +#ifndef BERT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#error This file is NOT meant to be included outside of bert.cpp. Doing so is DANGEROUS. Be sure to know what you are doing before proceeding to #define BERT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#endif +#ifndef BERT_H +#define BERT_H + +#include +#include +#include +#include +#include "llmodel.h" + +struct BertPrivate; +class Bert : public LLModel { +public: + Bert(); + ~Bert(); + + bool supportsEmbedding() const override { return true; } + bool supportsCompletion() const override { return true; } + bool loadModel(const std::string &modelPath) override; + bool isModelLoaded() const override; + size_t requiredMem(const std::string &modelPath) override; + size_t stateSize() const override; + size_t saveState(uint8_t *dest) const override; + size_t restoreState(const uint8_t *src) override; + void setThreadCount(int32_t n_threads) override; + int32_t threadCount() const override; + + std::vector embedding(const std::string &text) override; + +private: + std::unique_ptr d_ptr; + +protected: + std::vector tokenize(PromptContext &, const std::string&) const override; + Token sampleToken(PromptContext &ctx) const override; + std::string tokenToString(Token) const override; + bool evalTokens(PromptContext &ctx, const std::vector &tokens) const override; + int32_t contextLength() const override; + const std::vector& endTokens() const override; +}; + +#endif // BERT_H diff --git a/gpt4all-backend/falcon_impl.h b/gpt4all-backend/falcon_impl.h index 017252ea..2362af9f 100644 --- a/gpt4all-backend/falcon_impl.h +++ b/gpt4all-backend/falcon_impl.h @@ -16,6 +16,8 @@ public: Falcon(); ~Falcon(); + bool supportsEmbedding() const override { return false; } + bool supportsCompletion() const override { return true; } bool loadModel(const std::string &modelPath) override; bool isModelLoaded() const override; size_t requiredMem(const std::string &modelPath) override; diff --git a/gpt4all-backend/gptj_impl.h b/gpt4all-backend/gptj_impl.h index 93e27319..e2b1826e 100644 --- a/gpt4all-backend/gptj_impl.h +++ b/gpt4all-backend/gptj_impl.h @@ -15,6 +15,8 @@ public: GPTJ(); ~GPTJ(); + bool supportsEmbedding() const override { return false; } + bool supportsCompletion() const override { return true; } bool loadModel(const std::string &modelPath) override; bool isModelLoaded() const override; size_t requiredMem(const std::string &modelPath) override; diff --git a/gpt4all-backend/llamamodel_impl.h b/gpt4all-backend/llamamodel_impl.h index 7623f157..e564c44a 100644 --- a/gpt4all-backend/llamamodel_impl.h +++ b/gpt4all-backend/llamamodel_impl.h @@ -15,6 +15,8 @@ public: LLamaModel(); ~LLamaModel(); + bool supportsEmbedding() const override { return false; } + bool supportsCompletion() const override { return true; } bool loadModel(const std::string &modelPath) override; bool isModelLoaded() const override; size_t requiredMem(const std::string &modelPath) override; diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h index 06f9d618..29706697 100644 --- a/gpt4all-backend/llmodel.h +++ b/gpt4all-backend/llmodel.h @@ -61,18 +61,25 @@ public: explicit LLModel() {} virtual ~LLModel() {} + virtual bool supportsEmbedding() const = 0; + virtual bool supportsCompletion() const = 0; virtual bool loadModel(const std::string &modelPath) = 0; virtual bool isModelLoaded() const = 0; virtual size_t requiredMem(const std::string &modelPath) = 0; virtual size_t stateSize() const { return 0; } virtual size_t saveState(uint8_t */*dest*/) const { return 0; } virtual size_t restoreState(const uint8_t */*src*/) { return 0; } + + // This method requires the model to return true from supportsCompletion otherwise it will throw + // an error virtual void prompt(const std::string &prompt, std::function promptCallback, std::function responseCallback, std::function recalculateCallback, PromptContext &ctx); + virtual std::vector embedding(const std::string &text); + virtual void setThreadCount(int32_t /*n_threads*/) {} virtual int32_t threadCount() const { return 1; } diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp index c7e13f79..fb916d95 100644 --- a/gpt4all-backend/llmodel_c.cpp +++ b/gpt4all-backend/llmodel_c.cpp @@ -166,6 +166,25 @@ void llmodel_prompt(llmodel_model model, const char *prompt, ctx->context_erase = wrapper->promptContext.contextErase; } +float *llmodel_embedding(llmodel_model model, const char *text, size_t *embedding_size) +{ + LLModelWrapper *wrapper = reinterpret_cast(model); + std::vector embeddingVector = wrapper->llModel->embedding(text); + float *embedding = (float *)malloc(embeddingVector.size() * sizeof(float)); + if(embedding == nullptr) { + *embedding_size = 0; + return nullptr; + } + std::copy(embeddingVector.begin(), embeddingVector.end(), embedding); + *embedding_size = embeddingVector.size(); + return embedding; +} + +void llmodel_free_embedding(float *ptr) +{ + free(ptr); +} + void llmodel_setThreadCount(llmodel_model model, int32_t n_threads) { LLModelWrapper *wrapper = reinterpret_cast(model); diff --git a/gpt4all-backend/llmodel_c.h b/gpt4all-backend/llmodel_c.h index 0d221c7e..8d582d08 100644 --- a/gpt4all-backend/llmodel_c.h +++ b/gpt4all-backend/llmodel_c.h @@ -171,6 +171,23 @@ void llmodel_prompt(llmodel_model model, const char *prompt, llmodel_recalculate_callback recalculate_callback, llmodel_prompt_context *ctx); +/** + * Generate an embedding using the model. + * @param model A pointer to the llmodel_model instance. + * @param text A string representing the text to generate an embedding for. + * @param embedding_size A pointer to a size_t type that will be set by the call indicating the length + * of the returned floating point array. + * @return A pointer to an array of floating point values passed to the calling method which then will + * be responsible for lifetime of this memory. + */ +float *llmodel_embedding(llmodel_model model, const char *text, size_t *embedding_size); + +/** + * Frees the memory allocated by the llmodel_embedding function. + * @param ptr A pointer to the embedding as returned from llmodel_embedding. + */ +void llmodel_free_embedding(float *ptr); + /** * Set the number of threads to be used by the model. * @param model A pointer to the llmodel_model instance. diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llmodel_shared.cpp index fe1db763..89ba32b5 100644 --- a/gpt4all-backend/llmodel_shared.cpp +++ b/gpt4all-backend/llmodel_shared.cpp @@ -37,6 +37,13 @@ void LLModel::prompt(const std::string &prompt, return; } + if (!supportsCompletion()) { + std::string errorMessage = "ERROR: this model does not support text completion or chat!\n"; + responseCallback(-1, errorMessage); + std::cerr << implementation().modelType() << errorMessage; + return; + } + // tokenize the prompt std::vector embd_inp = tokenize(promptCtx, prompt); @@ -158,3 +165,12 @@ void LLModel::prompt(const std::string &prompt, cachedTokens.clear(); } } + +std::vector LLModel::embedding(const std::string &/*text*/) +{ + if (!supportsCompletion()) { + std::string errorMessage = "ERROR: this model does not support generating embeddings!\n"; + std::cerr << implementation().modelType() << errorMessage; + } + return std::vector(); +} diff --git a/gpt4all-backend/mpt_impl.h b/gpt4all-backend/mpt_impl.h index f5156836..df7b7718 100644 --- a/gpt4all-backend/mpt_impl.h +++ b/gpt4all-backend/mpt_impl.h @@ -15,6 +15,8 @@ public: MPT(); ~MPT(); + bool supportsEmbedding() const override { return false; } + bool supportsCompletion() const override { return true; } bool loadModel(const std::string &modelPath) override; bool isModelLoaded() const override; size_t requiredMem(const std::string &modelPath) override; diff --git a/gpt4all-backend/replit_impl.h b/gpt4all-backend/replit_impl.h index 73a8ea80..f635f30d 100644 --- a/gpt4all-backend/replit_impl.h +++ b/gpt4all-backend/replit_impl.h @@ -17,6 +17,8 @@ public: Replit(); ~Replit(); + bool supportsEmbedding() const override { return false; } + bool supportsCompletion() const override { return true; } bool loadModel(const std::string &modelPath) override; bool isModelLoaded() const override; size_t requiredMem(const std::string & modelPath) override; diff --git a/gpt4all-backend/scripts/convert_bert_hf_to_ggml.py b/gpt4all-backend/scripts/convert_bert_hf_to_ggml.py new file mode 100644 index 00000000..ba7045ca --- /dev/null +++ b/gpt4all-backend/scripts/convert_bert_hf_to_ggml.py @@ -0,0 +1,102 @@ +import sys +import struct +import json +import torch +import numpy as np + +from transformers import AutoModel, AutoTokenizer + +if len(sys.argv) < 3: + print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n") + print(" ftype == 0 -> float32") + print(" ftype == 1 -> float16") + sys.exit(1) + +# output in the same directory as the model +dir_model = sys.argv[1] +fname_out = sys.argv[1] + "/ggml-model.bin" + +with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f: + encoder = json.load(f) + +with open(dir_model + "/config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + +with open(dir_model + "/vocab.txt", "r", encoding="utf-8") as f: + vocab = f.readlines() +# possible data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 +# +# map from ftype to string +ftype_str = ["f32", "f16"] + +ftype = 1 +if len(sys.argv) > 2: + ftype = int(sys.argv[2]) + if ftype < 0 or ftype > 1: + print("Invalid ftype: " + str(ftype)) + sys.exit(1) + fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + + +tokenizer = AutoTokenizer.from_pretrained(dir_model) +model = AutoModel.from_pretrained(dir_model, low_cpu_mem_usage=True) +print (model) + +print(tokenizer.encode('I believe the meaning of life is')) + +list_vars = model.state_dict() +for name in list_vars.keys(): + print(name, list_vars[name].shape, list_vars[name].dtype) + +fout = open(fname_out, "wb") + +print(hparams) + +fout.write(struct.pack("i", 0x62657274)) # magic: ggml in hex +fout.write(struct.pack("i", hparams["vocab_size"])) +fout.write(struct.pack("i", hparams["max_position_embeddings"])) +fout.write(struct.pack("i", hparams["hidden_size"])) +fout.write(struct.pack("i", hparams["intermediate_size"])) +fout.write(struct.pack("i", hparams["num_attention_heads"])) +fout.write(struct.pack("i", hparams["num_hidden_layers"])) +fout.write(struct.pack("i", ftype)) + +for i in range(hparams["vocab_size"]): + text = vocab[i][:-1] # strips newline at the end + #print(f"{i}:{text}") + data = bytes(text, 'utf-8') + fout.write(struct.pack("i", len(data))) + fout.write(data) + +for name in list_vars.keys(): + data = list_vars[name].squeeze().numpy() + if name in ['embeddings.position_ids', 'pooler.dense.weight', 'pooler.dense.bias']: + continue + print("Processing variable: " + name + " with shape: ", data.shape) + + n_dims = len(data.shape); + + # ftype == 0 -> float32, ftype == 1 -> float16 + if ftype == 1 and name[-7:] == ".weight" and n_dims == 2: + print(" Converting to float16") + data = data.astype(np.float16) + l_type = 1 + else: + l_type = 0 + + # header + str = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(str), l_type)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(str); + + # data + data.tofile(fout) + +fout.close() + +print("Done. Output file: " + fname_out) +print("") diff --git a/gpt4all-bindings/python/gpt4all/__init__.py b/gpt4all-bindings/python/gpt4all/__init__.py index 4c0cc9e6..54491d79 100644 --- a/gpt4all-bindings/python/gpt4all/__init__.py +++ b/gpt4all-bindings/python/gpt4all/__init__.py @@ -1,2 +1,2 @@ -from .gpt4all import GPT4All # noqa +from .gpt4all import GPT4All, embed # noqa from .pyllmodel import LLModel # noqa diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py index 7c126b76..1eddf2e6 100644 --- a/gpt4all-bindings/python/gpt4all/gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/gpt4all.py @@ -15,6 +15,20 @@ from . import pyllmodel # TODO: move to config DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\") +def embed( + text: str +) -> list[float]: + """ + Generate an embedding for all GPT4All. + + Args: + text: The text document to generate an embedding for. + + Returns: + An embedding of your document of text. + """ + model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin') + return model.model.generate_embedding(text) class GPT4All: """ diff --git a/gpt4all-bindings/python/gpt4all/pyllmodel.py b/gpt4all-bindings/python/gpt4all/pyllmodel.py index 7e091207..8aa33227 100644 --- a/gpt4all-bindings/python/gpt4all/pyllmodel.py +++ b/gpt4all-bindings/python/gpt4all/pyllmodel.py @@ -112,6 +112,19 @@ llmodel.llmodel_prompt.argtypes = [ llmodel.llmodel_prompt.restype = None +llmodel.llmodel_embedding.argtypes = [ + ctypes.c_void_p, + ctypes.c_char_p, + ctypes.POINTER(ctypes.c_size_t), +] + +llmodel.llmodel_embedding.restype = ctypes.POINTER(ctypes.c_float) + +llmodel.llmodel_free_embedding.argtypes = [ + ctypes.POINTER(ctypes.c_float) +] +llmodel.llmodel_free_embedding.restype = None + llmodel.llmodel_setThreadCount.argtypes = [ctypes.c_void_p, ctypes.c_int32] llmodel.llmodel_setThreadCount.restype = None @@ -233,6 +246,17 @@ class LLModel: self.context.repeat_last_n = repeat_last_n self.context.context_erase = context_erase + def generate_embedding( + self, + text: str + ) -> list[float]: + embedding_size = ctypes.c_size_t() + c_text = ctypes.c_char_p(text.encode('utf-8')) + embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size)) + embedding_array = ctypes.cast(embedding_ptr, ctypes.POINTER(ctypes.c_float * embedding_size.value)).contents + llmodel.llmodel_free_embedding(embedding_ptr) + return list(embedding_array) + def prompt_model( self, prompt: str, diff --git a/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py b/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py index df382ed5..dd9aa417 100644 --- a/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py @@ -1,7 +1,7 @@ import sys from io import StringIO -from gpt4all import GPT4All +from gpt4all import GPT4All, embed def test_inference(): @@ -99,3 +99,11 @@ def test_inference_mpt(): output = model.generate(prompt) assert isinstance(output, str) assert len(output) > 0 + +def test_embedding(): + text = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id estLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id estLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id estLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id estLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id estLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id estLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id estLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id estLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id estLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id estLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id estLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id estLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est' + output = embed(text) + #for i, value in enumerate(output): + #print(f'Value at index {i}: {value}') + assert len(output) == 384 + diff --git a/gpt4all-chat/chatgpt.h b/gpt4all-chat/chatgpt.h index b1f32298..0f835bee 100644 --- a/gpt4all-chat/chatgpt.h +++ b/gpt4all-chat/chatgpt.h @@ -46,6 +46,8 @@ public: ChatGPT(); virtual ~ChatGPT(); + bool supportsEmbedding() const override { return false; } + bool supportsCompletion() const override { return true; } bool loadModel(const std::string &modelPath) override; bool isModelLoaded() const override; size_t requiredMem(const std::string &modelPath) override; diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 37c92d53..80910274 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -14,6 +14,7 @@ #define REPLIT_INTERNAL_STATE_VERSION 0 #define LLAMA_INTERNAL_STATE_VERSION 0 #define FALCON_INTERNAL_STATE_VERSION 0 +#define BERT_INTERNAL_STATE_VERSION 0 class LLModelStore { public: @@ -264,6 +265,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) case 'M': m_llModelType = LLModelType::MPT_; break; case 'R': m_llModelType = LLModelType::REPLIT_; break; case 'F': m_llModelType = LLModelType::FALCON_; break; + case 'B': m_llModelType = LLModelType::BERT_; break; default: { delete std::exchange(m_llModelInfo.model, nullptr); @@ -628,8 +630,8 @@ bool ChatLLM::handleNameRecalculate(bool isRecalc) qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc; #endif Q_UNUSED(isRecalc); - Q_UNREACHABLE(); - return false; + qt_noop(); + return true; } bool ChatLLM::handleSystemPrompt(int32_t token) @@ -669,7 +671,8 @@ bool ChatLLM::serialize(QDataStream &stream, int version) case MPT_: stream << MPT_INTERNAL_STATE_VERSION; break; case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break; case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break; - case FALCON_: stream << LLAMA_INTERNAL_STATE_VERSION; break; + case FALCON_: stream << FALCON_INTERNAL_STATE_VERSION; break; + case BERT_: stream << BERT_INTERNAL_STATE_VERSION; break; default: Q_UNREACHABLE(); } } diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index aad8a0fb..f75d24e2 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -16,6 +16,7 @@ enum LLModelType { CHATGPT_, REPLIT_, FALCON_, + BERT_ }; struct LLModelInfo {