diff --git a/include/turbopilot/model.hpp b/include/turbopilot/model.hpp index 3f372c6..2849d08 100644 --- a/include/turbopilot/model.hpp +++ b/include/turbopilot/model.hpp @@ -8,6 +8,9 @@ #include #include +typedef void (*offload_func_t)(struct ggml_tensor * tensor); +void ggml_nop(struct ggml_tensor * tensor); + struct gpt_vocab { using id = int32_t; diff --git a/include/turbopilot/starcoder.hpp b/include/turbopilot/starcoder.hpp new file mode 100644 index 0000000..f5b7344 --- /dev/null +++ b/include/turbopilot/starcoder.hpp @@ -0,0 +1,79 @@ +#ifndef __TURBOPILOT_STARCODER_H +#define __TURBOPILOT_STARCODER_H + +#include + +// default hparams (GPT-2 117M) +// https://huggingface.co/bigcode/gpt_bigcode-santacoder/blob/main/config.json +struct starcoder_hparams { + int32_t n_vocab = 49280; + int32_t n_ctx = 2048; + int32_t n_embd = 2048; + int32_t n_head = 16; + int32_t n_layer = 24; + int32_t ftype = 1; +}; + +struct starcoder_layer { + // normalization + struct ggml_tensor * ln_1_g; + struct ggml_tensor * ln_1_b; + + struct ggml_tensor * ln_2_g; + struct ggml_tensor * ln_2_b; + + // attention + struct ggml_tensor * c_attn_attn_w; + struct ggml_tensor * c_attn_attn_b; + + struct ggml_tensor * c_attn_proj_w; + struct ggml_tensor * c_attn_proj_b; + + // mlp + struct ggml_tensor * c_mlp_fc_w; + struct ggml_tensor * c_mlp_fc_b; + + struct ggml_tensor * c_mlp_proj_w; + struct ggml_tensor * c_mlp_proj_b; +}; + +struct starcoder_model { + starcoder_hparams hparams; + + // normalization + struct ggml_tensor * ln_f_g; + struct ggml_tensor * ln_f_b; + + struct ggml_tensor * wte; // position embedding + struct ggml_tensor * wpe; // token embedding + struct ggml_tensor * lm_head; // language model head + + std::vector layers; + + // key + value memory + struct ggml_tensor * memory_k; + struct ggml_tensor * memory_v; + + // + struct ggml_context * ctx; + std::map tensors; +}; + + +class StarcoderModel : public TurbopilotModel { +public: + StarcoderModel(ModelConfig config, std::mt19937 &rng) : TurbopilotModel(config, rng){ + this->model = new starcoder_model{}; + this->vocab = new gpt_vocab{}; + } + virtual ~StarcoderModel(); + bool load_model(std::string path); + virtual std::stringstream predict(std::string prompt, int max_length, bool include_prompt); + +private: + starcoder_model *model = NULL; + gpt_vocab *vocab = NULL; +}; + + +#endif //__TURBOPILOT_STARCODER_H \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ee093fd..02e485a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -9,8 +9,10 @@ add_executable(${TURBOPILOT_TARGET} gptj.cpp common.cpp server.cpp + starcoder.cpp ../include/turbopilot/model.hpp ../include/turbopilot/gptj.hpp + ../include/turbopilot/starcoder.hpp ) diff --git a/src/common.cpp b/src/common.cpp index 03faaaf..2439c56 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -4,11 +4,15 @@ #include #include +void llama_nop(struct ggml_tensor * tensor) { // don't offload by default + (void) tensor; +} void gpt_vocab::add_special_token(const std::string & token) { special_tokens.push_back(token); } + void gpt_split_words(std::string str, std::vector& words) { const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; const std::regex re(pattern); diff --git a/src/gptj.cpp b/src/gptj.cpp index 4d3a4c1..64ebd69 100644 --- a/src/gptj.cpp +++ b/src/gptj.cpp @@ -249,6 +249,7 @@ bool gptj_eval( GPTJModel::~GPTJModel(){ ggml_free(model->ctx); free(model); + free(vocab); } bool GPTJModel::load_model(std::string fname) { diff --git a/src/main.cpp b/src/main.cpp index 94d79a3..d951c63 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -9,6 +9,7 @@ #include #include "turbopilot/model.hpp" +#include "turbopilot/starcoder.hpp" #include "turbopilot/gptj.hpp" #include "turbopilot/server.hpp" @@ -64,6 +65,9 @@ int main(int argc, char **argv) if(model_type.compare("codegen") == 0) { spdlog::info("Initializing GPT-J type model for '{}' model", model_type); model = new GPTJModel(config, rng); + }else if(model_type.compare("starcoder") == 0 || model_type.compare("wizardcoder") == 0){ + spdlog::info("Initializing Starcoder/Wizardcoder type model for '{}' model type", model_type); + model = new StarcoderModel(config, rng); }else{ spdlog::error("Invalid model type: {}", model_type); } diff --git a/src/starcoder.cpp b/src/starcoder.cpp new file mode 100644 index 0000000..74aaea8 --- /dev/null +++ b/src/starcoder.cpp @@ -0,0 +1,771 @@ +#include +#include + +#include +#include +#include + + +// evaluate the transformer +// +// - model: the model +// - n_threads: number of threads to use +// - n_past: the context size so far +// - embd_inp: the embeddings of the tokens in the context +// - embd_w: the predicted logits for the next token +// +bool starcoder_eval( + const starcoder_model & model, + const int n_threads, + const int n_past, + const std::vector & embd_inp, + std::vector & embd_w, + size_t & mem_per_token) { + const int N = embd_inp.size(); + + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_head = hparams.n_head; + const int n_vocab = hparams.n_vocab; + + static size_t buf_size = 256u*1024*1024; + static void * buf = malloc(buf_size); + + // use 2 scratch buffers + // TODO: very hacky solution - reimplement in a more elegant way + static size_t scr0_size = 256u*1024*1024; + static void * scr0 = malloc(scr0_size); + + static size_t scr1_size = 256u*1024*1024; + static void * scr1 = malloc(scr1_size); + + if (mem_per_token > 0 && mem_per_token*N > buf_size) { + const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead + //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); + + // reallocate + buf_size = buf_size_new; + buf = realloc(buf, buf_size); + if (buf == nullptr) { + fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size); + return false; + } + } + + struct ggml_init_params params = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf, + /*.no_alloc =*/ false, + }; + + struct ggml_context * ctx0 = ggml_init(params); + struct ggml_cgraph gf = {}; + + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd)); + + struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + for (int i = 0; i < N; ++i) { + ((int32_t *) position->data)[i] = n_past + i; + } + + // wte + wpe + struct ggml_tensor * inpL = + ggml_add(ctx0, + ggml_get_rows(ctx0, model.wte, embd), + ggml_get_rows(ctx0, model.wpe, position)); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * cur; + + ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); + + // norm + { + // [ 768, N] + cur = ggml_norm(ctx0, inpL); + + // cur = ln_1_g*cur + ln_1_b + // [ 768, N] + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].ln_1_g, cur), + cur), + ggml_repeat(ctx0, model.layers[il].ln_1_b, cur)); + } + + // attn + // [2304, 768] - model.layers[il].c_attn_attn_w + // [2304, 1] - model.layers[il].c_attn_attn_b + // [ 768, N] - cur (in) + // [2304, N] - cur (out) + // + // cur = attn_w*cur + attn_b + // [2304, N] + { + cur = ggml_mul_mat(ctx0, + model.layers[il].c_attn_attn_w, + cur); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur), + cur); + } + + // self-attention + { + struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd); + struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd); + struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd); + + // store key and value to memory + if (N >= 1) { + struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past)); + + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); + } + + // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) + // [64, N, 12] + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)), + 0, 2, 1, 3); + + // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) + // [64, n_past + N, 12] + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd), + n_embd/n_head, n_head, n_past + N), + 0, 2, 1, 3); //TODO: need to be tiled + + // GG: flash attention + //struct ggml_tensor * V = + // ggml_cpy(ctx0, + // ggml_permute(ctx0, + // ggml_reshape_3d(ctx0, + // ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), + // n_embd/n_head, n_head, n_past + N), + // 1, 2, 0, 3), + // ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head)); + + //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true); + + // K * Q + // [n_past + N, N, 12] + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); //TODO: check if it broadcasts + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + // [n_past + N, N, 12] + struct ggml_tensor * KQ_scaled = + ggml_scale_inplace(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)) + ); + + // KQ_masked = mask_past(KQ_scaled) + // [n_past + N, N, 12] + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + // [n_past + N, N, 12] + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + // [n_past + N, 64, 12] + struct ggml_tensor * V_trans = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), + n_embd/n_head, n_head, n_past + N), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head)); + + // KQV = transpose(V) * KQ_soft_max + // [64, N, 12] + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + // [64, 12, N] + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + // [768, N] + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + } + + // projection + // [ 768, 768] - model.layers[il].c_attn_proj_w + // [ 768, 1] - model.layers[il].c_attn_proj_b + // [ 768, N] - cur (in) + // [ 768, N] - cur (out) + // + // cur = proj_w*cur + proj_b + // [768, N] + { + cur = ggml_mul_mat(ctx0, + model.layers[il].c_attn_proj_w, + cur); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur), + cur); + } + + // add the input + cur = ggml_add(ctx0, cur, inpL); + + struct ggml_tensor * inpFF = cur; + + ggml_set_scratch(ctx0, { 0, scr1_size, scr1, }); + + // feed-forward network + { + // norm + { + cur = ggml_norm(ctx0, inpFF); + + // cur = ln_2_g*cur + ln_2_b + // [ 768, N] + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].ln_2_g, cur), + cur), + ggml_repeat(ctx0, model.layers[il].ln_2_b, cur)); + } + + // fully connected + // [3072, 768] - model.layers[il].c_mlp_fc_w + // [3072, 1] - model.layers[il].c_mlp_fc_b + // [ 768, N] - cur (in) + // [3072, N] - cur (out) + // + // cur = fc_w*cur + fc_b + // [3072, N] + cur = ggml_mul_mat(ctx0, + model.layers[il].c_mlp_fc_w, + cur); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur), + cur); + + // GELU activation + // [3072, N] + cur = ggml_gelu(ctx0, cur); + + // projection + // [ 768, 3072] - model.layers[il].c_mlp_proj_w + // [ 768, 1] - model.layers[il].c_mlp_proj_b + // [3072, N] - cur (in) + // [ 768, N] - cur (out) + // + // cur = proj_w*cur + proj_b + // [768, N] + cur = ggml_mul_mat(ctx0, + model.layers[il].c_mlp_proj_w, + cur); + + cur = ggml_add(ctx0, + ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur), + cur); + } + + // input for next layer + inpL = ggml_add(ctx0, cur, inpFF); + } + + ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); + + // norm + { + // [ 768, N] + inpL = ggml_norm(ctx0, inpL); + + // inpL = ln_f_g*inpL + ln_f_b + // [ 768, N] + inpL = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.ln_f_g, inpL), + inpL), + ggml_repeat(ctx0, model.ln_f_b, inpL)); + } + + ggml_set_scratch(ctx0, { 0, 0, nullptr, }); + + // inpL = WTE * inpL + // [ 768, 50257] - model.lm_head + // [ 768, N] - inpL + inpL = ggml_mul_mat(ctx0, model.lm_head, inpL); + + // logits -> probs + //inpL = ggml_soft_max_inplace(ctx0, inpL); + + // run the computation + ggml_build_forward_expand(&gf, inpL); + ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); + + //if (n_past%100 == 0) { + // ggml_graph_print (&gf); + // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot"); + //} + + //embd_w.resize(n_vocab*N); + //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); + + // return result just for the last token + embd_w.resize(n_vocab); + memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + + if (mem_per_token == 0) { + mem_per_token = ggml_used_mem(ctx0)/N; + } + //printf("used_mem = %zu MB\n", ggml_used_mem(ctx0)/(1024*1024)); + + ggml_free(ctx0); + + return true; +} + + +StarcoderModel::~StarcoderModel(){ + ggml_free(model->ctx); + free(model); + free(vocab); +} + +bool StarcoderModel::load_model(std::string fname) { + printf("%s: loading model from '%s'\n", __func__, fname.c_str()); + + auto fin = std::ifstream(fname, std::ios::binary); + if (!fin) { + fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str()); + return false; + } + + // verify magic + { + uint32_t magic; + fin.read((char *) &magic, sizeof(magic)); + if (magic != GGML_FILE_MAGIC) { + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); + return false; + } + } + + // load hparams + { + auto & hparams = model->hparams; + + fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx)); + fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd)); + fin.read((char *) &hparams.n_head, sizeof(hparams.n_head)); + fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer)); + fin.read((char *) &hparams.ftype, sizeof(hparams.ftype)); + + const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR; + + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); + printf("%s: n_embd = %d\n", __func__, hparams.n_embd); + printf("%s: n_head = %d\n", __func__, hparams.n_head); + printf("%s: n_layer = %d\n", __func__, hparams.n_layer); + printf("%s: ftype = %d\n", __func__, hparams.ftype); + printf("%s: qntvr = %d\n", __func__, qntvr); + + hparams.ftype %= GGML_QNT_VERSION_FACTOR; + } + + // load vocab + { + int32_t n_vocab = 0; + fin.read((char *) &n_vocab, sizeof(n_vocab)); + + if (n_vocab != model->hparams.n_vocab) { + fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", + __func__, fname.c_str(), n_vocab, model->hparams.n_vocab); + return false; + } + + std::string word; + std::vector buf(128); + + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + fin.read((char *) &len, sizeof(len)); + + buf.resize(len); + fin.read((char *) buf.data(), len); + word.assign(buf.data(), len); + + vocab->token_to_id[word] = i; + vocab->id_to_token[i] = word; + + // if (i < 10) fprintf(stderr, "%.s: vocab[%d] = '%s'\n", __func__, i, word.c_str()); + } + + // Add StarChat special tokens. + for (const std::string & token : { + "<|system|>", + "<|user|>", + "<|assistant|>", + "<|end|>", + "", + "", + "", + "", + "<|end_of_turn|>" + }) { + if (vocab->token_to_id.find(token) != vocab->token_to_id.end()) { + vocab->add_special_token(token); + } + } + } + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model->hparams.ftype)); + if (wtype == GGML_TYPE_COUNT) { + fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", + __func__, fname.c_str(), model->hparams.ftype); + return false; + } + + auto & ctx = model->ctx; + + size_t ctx_size = 0; + + { + const auto & hparams = model->hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_vocab = hparams.n_vocab; + + const int head_dim = n_embd / hparams.n_head; + const int kv_heads = hparams.n_head; // 1 if MQA else hparams.n_head + const int kv_dim = kv_heads * head_dim; + + ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g + ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b + + ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte + ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe + ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head + + ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g + ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b + + ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g + ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b + + ctx_size += n_layer*((n_embd + 2*kv_dim)*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w // TODO: + ctx_size += n_layer*( (n_embd + 2*kv_dim)*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b + + ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w + ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b + + ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w + ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b + + ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w + ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b + + ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k + ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v + + ctx_size += (6 + 12*n_layer)*512; // object overhead + + printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + } + + // create the ggml context + { + struct ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ false, + }; + + model->ctx = ggml_init(params); + if (!model->ctx) { + fprintf(stderr, "%s: ggml_init() failed\n", __func__); + return false; + } + } + + // prepare memory for the weights + { + const auto & hparams = model->hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_vocab = hparams.n_vocab; + + const int head_dim = n_embd / hparams.n_head; + const int kv_heads = hparams.n_head; // 1 if MQA else hparams.n_head + const int kv_dim = kv_heads * head_dim; + + model->layers.resize(n_layer); + + model->ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + model->ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + model->wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab); + model->wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx); + model->lm_head = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab); + + // map by name + model->tensors["model/ln_f/g"] = model->ln_f_g; + model->tensors["model/ln_f/b"] = model->ln_f_b; + + model->tensors["model/wte"] = model->wte; + model->tensors["model/wpe"] = model->wpe; + model->tensors["model/lm_head"] = model->lm_head; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model->layers[i]; + + layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd + 2*kv_dim); + layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd + 2*kv_dim); + + layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd); //TODO: 4*n_embd = config.n_inner + layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd); + + layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd); + layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + // map by name + model->tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g; + model->tensors["model/h" + std::to_string(i) + "/ln_1/b"] = layer.ln_1_b; + + model->tensors["model/h" + std::to_string(i) + "/ln_2/g"] = layer.ln_2_g; + model->tensors["model/h" + std::to_string(i) + "/ln_2/b"] = layer.ln_2_b; + + model->tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w; + model->tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b; + + model->tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w; + model->tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b; + + model->tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w; + model->tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b; + + model->tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w; + model->tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b; + } + } + + // key + value memory + { + const auto & hparams = model->hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + + const int n_mem = n_layer*n_ctx; + const int n_elements = n_embd*n_mem; + + model->memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements); + model->memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements); + + const size_t memory_size = ggml_nbytes(model->memory_k) + ggml_nbytes(model->memory_v); + + printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem); + } + + // load weights + { + size_t total_size = 0; + + bool has_lm_head = false; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ttype), sizeof(ttype)); + + if (fin.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[2] = { 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } + + std::string name(length, 0); + fin.read(&name[0], length); + + if (model->tensors.find(name.data()) == model->tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = model->tensors[name.data()]; + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", + __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]); + return false; + } + if (ggml_nelements(tensor) != nelements) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file. got %d, expected %d\n", + __func__, name.data(), (int) ggml_nelements(tensor), nelements); + return false; + } + + // for debugging + if (0) { + printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor)); + } + + const size_t bpe = ggml_type_size(ggml_type(ttype)); + + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + return false; + } + + fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); + + // GPT-2 models share the WTE tensor as the LM head + if (name == "model/wte" && has_lm_head == false) { + memcpy(model->lm_head->data, tensor->data, ggml_nbytes(tensor)); + } + + if (name == "model/lm_head") { + has_lm_head = true; + } + + total_size += ggml_nbytes(tensor); + } + + printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0); + } + + fin.close(); + + return true; +} + + +std::stringstream StarcoderModel::predict(std::string prompt, int max_length, bool include_prompt) { + + std::stringstream result; + // tokenize the prompt + std::vector embd_inp = ::gpt_tokenize((*vocab), prompt); + + int n_past = 0; + + int64_t t_sample_us = 0; + int64_t t_predict_us = 0; + + int n_predict = std::min(max_length, model->hparams.n_ctx - (int) embd_inp.size()); + + spdlog::debug("{}: number of tokens in prompt = {}", __func__, embd_inp.size()); + + std::vector embd; + + // determine the required inference memory per token: + size_t mem_per_token = 0; + + std::vector logits; + + starcoder_eval((*model), config.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token); + + for (int i = embd.size(); i < embd_inp.size() + n_predict; i++) { + // predict + if (embd.size() > 0) { + const int64_t t_start_us = ggml_time_us(); + + if (!starcoder_eval((*model), config.n_threads, n_past, embd, logits, mem_per_token)) { + throw std::runtime_error("Failed to predict"); + } + + t_predict_us += ggml_time_us() - t_start_us; + } + + n_past += embd.size(); + embd.clear(); + + if (i >= embd_inp.size()) { + // sample next token + const int top_k = config.top_k; + const float top_p = config.top_p; + const float temp = config.temp; + + const int n_vocab = model->hparams.n_vocab; + + gpt_vocab::id id = 0; + + { + const int64_t t_start_sample_us = ggml_time_us(); + + id = gpt_sample_top_k_top_p((*vocab), logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng); + + t_sample_us += ggml_time_us() - t_start_sample_us; + } + + // add it to the context + embd.push_back(id); + + if(id != 50256){ + result << vocab->id_to_token[id].c_str(); + } + + } else { + // if here, it means we are still processing the input prompt + for (int k = i; k < embd_inp.size(); k++) { + embd.push_back(embd_inp[k]); + + if(include_prompt){ + result << vocab->id_to_token[embd_inp[k]].c_str(); + } + + if (embd.size() > config.n_batch) { + break; + } + } + i += embd.size() - 1; + } + + // end of text token + if (embd.back() == 50256) { + break; + } + } + + return result; +}