Fix for special tokens.

This commit is contained in:
Adam Treat 2023-05-08 12:06:32 -04:00
parent b6886c0e31
commit dc559c1575

View File

@ -145,9 +145,16 @@ struct mpt_vocab {
std::map<id, token> id_to_token;
std::vector<std::string> special_tokens;
void add_special_token(const std::string &token);
void add_special_token(const std::string &token) {
special_tokens.push_back(token);
}
};
std::string regex_escape(const std::string &s) {
static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])");
return std::regex_replace(s, metacharacters, "\\$&");
}
// load the model's weights from a stream
bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, mpt_vocab & vocab) {
printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
@ -215,6 +222,9 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
// TODO: this only kind-of works, the gpt_tokenize can still incorrectly
// tokenize special tokens
if(special) {
vocab.add_special_token(regex_escape(word));
}
}
}