mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Fix up mpt.
This commit is contained in:
parent
61e2aabadb
commit
b6886c0e31
@ -40,7 +40,7 @@ struct mpt_layer {
|
||||
// attention
|
||||
struct ggml_tensor * attn_Wqkv_w;
|
||||
struct ggml_tensor * attn_out_proj_w;
|
||||
|
||||
|
||||
// ff
|
||||
struct ggml_tensor * ffn_up_proj_w;
|
||||
struct ggml_tensor * ffn_down_proj_w;
|
||||
@ -87,7 +87,7 @@ struct mpt_model {
|
||||
struct ggml_tensor * norm_f_w;
|
||||
|
||||
struct ggml_tensor * wte; // position embedding
|
||||
|
||||
|
||||
// mpt does weight tying
|
||||
|
||||
std::vector<mpt_layer> layers;
|
||||
@ -260,7 +260,7 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
|
||||
|
||||
ctx_size += n_layer*(expand*n_embd*n_embd*ggml_type_sizef(wtype)); // ffn_up_proj_w
|
||||
ctx_size += n_layer*(expand*n_embd*n_embd*ggml_type_sizef(wtype)); // ffn_down_proj_w
|
||||
|
||||
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_k
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_v
|
||||
|
||||
@ -427,7 +427,7 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
|
||||
|
||||
|
||||
// load the model's weights from a file path
|
||||
bool gptj_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vocab) {
|
||||
bool mpt_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vocab) {
|
||||
|
||||
auto fin = std::ifstream(fname, std::ios::binary);
|
||||
if (!fin) {
|
||||
@ -528,7 +528,7 @@ bool mpt_eval(
|
||||
0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * K =
|
||||
ggml_permute(ctx0,
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.k)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
@ -641,7 +641,7 @@ std::vector<int> mpt_tokenize(const mpt_vocab & vocab, const std::string & text)
|
||||
// not sure if this entirely right?
|
||||
std::vector<std::string> words;
|
||||
|
||||
|
||||
|
||||
// first split the text into words
|
||||
{
|
||||
std::string str = text;
|
||||
@ -771,6 +771,7 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
|
||||
|
||||
mpt_vocab::id mpt_sample_top_k_top_p(
|
||||
const mpt_vocab & vocab,
|
||||
const size_t actualVocabSize,
|
||||
const int32_t * last_n_tokens_data,
|
||||
int last_n_tokens_size,
|
||||
const std::vector<float> logits,
|
||||
@ -779,7 +780,7 @@ mpt_vocab::id mpt_sample_top_k_top_p(
|
||||
double temp,
|
||||
float repeat_penalty,
|
||||
std::mt19937 & rng) {
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
int n_logits = actualVocabSize;
|
||||
|
||||
const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
|
||||
const auto * plogits = logits.data() + logits.size() - n_logits;
|
||||
@ -1038,7 +1039,7 @@ void MPT::prompt(const std::string &prompt,
|
||||
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
|
||||
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
|
||||
// Erase the first percentage of context from the tokens...
|
||||
std::cerr << "GPTJ: reached the end of the context window so resizing\n";
|
||||
std::cerr << "MPT: reached the end of the context window so resizing\n";
|
||||
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
|
||||
promptCtx.n_past = promptCtx.tokens.size();
|
||||
recalculateContext(promptCtx, recalculateCallback);
|
||||
@ -1081,7 +1082,7 @@ void MPT::prompt(const std::string &prompt,
|
||||
int id = 0;
|
||||
{
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
id = mpt_sample_top_k_top_p(d_ptr->vocab,
|
||||
id = mpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
|
||||
promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
|
||||
promptCtx.n_ctx,
|
||||
promptCtx.logits,
|
||||
@ -1096,7 +1097,7 @@ void MPT::prompt(const std::string &prompt,
|
||||
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
|
||||
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
|
||||
// Erase the first percentage of context from the tokens...
|
||||
std::cerr << "GPTJ: reached the end of the context window so resizing\n";
|
||||
std::cerr << "MPT: reached the end of the context window so resizing\n";
|
||||
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
|
||||
promptCtx.n_past = promptCtx.tokens.size();
|
||||
recalculateContext(promptCtx, recalculateCallback);
|
||||
@ -1185,7 +1186,7 @@ void MPT::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)>
|
||||
|
||||
if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
|
||||
d_ptr->mem_per_token)) {
|
||||
std::cerr << "GPTJ ERROR: Failed to process prompt\n";
|
||||
std::cerr << "MPT ERROR: Failed to process prompt\n";
|
||||
goto stop_generating;
|
||||
}
|
||||
promptCtx.n_past += batch.size();
|
||||
|
Loading…
Reference in New Issue
Block a user