add requiredMem method to llmodel impls

Most of these can just shortcut out of the model loading logic. llama is a bit worse to deal with because we submodule it, so I have to at least parse the hparams, and then I just use the size on disk as an estimate for the mem size (which seems reasonable since we mmap() the llama files anyway).
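In practice, the estimate each stream loader reports boils down to the ggml context size plus room for the K and V caches over the full context window. A minimal sketch of that arithmetic, with illustrative names (estimate_mem_req is not part of the commit):

#include <cstddef>
#include <cstdint>

// Rough shape of the estimate the loaders below compute before returning early:
// weight/scratch context size + K cache + V cache + a small fixed pad.
static size_t estimate_mem_req(size_t ctx_size, int64_t n_embd, int64_t n_layer,
                               int64_t n_ctx, size_t kv_element_size /* 2 for fp16 */) {
    const int64_t n_mem      = n_layer * n_ctx;   // token positions cached across all layers
    const int64_t n_elements = n_embd * n_mem;    // elements in one cache (K or V)
    return ctx_size + 2u * n_elements * kv_element_size   // K cache + V cache
                    + 2u * 1024u * 1024u;                  // the 2 MiB pad used below
}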
Aaron Miller 2023-06-26 12:17:34 -07:00 committed by AT
parent dead954134
commit b19a3e5b2c
14 changed files with 154 additions and 8 deletions

View File

@ -158,8 +158,11 @@ static bool kv_cache_init(
}
// load the model's weights from a stream
bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & model, gpt_vocab & vocab) {
bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & model, gpt_vocab & vocab, size_t * mem_req = nullptr) {
printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
if(mem_req != nullptr) {
*mem_req = 0;
}
// verify magic
{
@ -276,6 +279,19 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
if (mem_req != nullptr) {
*mem_req += ctx_size;
const int n_embd = model.hparams.n_embd;
const int n_layer = model.hparams.n_layer;
const int64_t n_mem = (int64_t)n_layer*model.hparams.n_ctx;
const int64_t n_elements = n_embd*n_mem;
*mem_req += (2u*n_elements*ggml_type_size(wtype) + 2_MiB);
return false;
}
// create the ggml context
{
struct ggml_init_params params = {
@ -837,6 +853,15 @@ GPTJ::GPTJ()
d_ptr->modelLoaded = false;
}
size_t GPTJ::requiredMem(const std::string &modelPath) {
gptj_model dummy_model;
gpt_vocab dummy_vocab;
size_t mem_req;
auto fin = std::ifstream(modelPath, std::ios::binary);
gptj_model_load(modelPath, fin, dummy_model, dummy_vocab, &mem_req);
return mem_req;
}
bool GPTJ::loadModel(const std::string &modelPath) {
std::mt19937 rng(time(NULL));
d_ptr->rng = rng;

View File

@ -17,6 +17,7 @@ public:
bool loadModel(const std::string &modelPath) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;

View File

@ -97,6 +97,40 @@ LLamaModel::LLamaModel()
d_ptr->modelLoaded = false;
}
// default hparams (LLaMA 7B)
struct llama_file_hparams {
uint32_t n_vocab = 32000;
uint32_t n_embd = 4096;
uint32_t n_mult = 256;
uint32_t n_head = 32;
uint32_t n_layer = 32;
uint32_t n_rot = 64;
enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
};
size_t LLamaModel::requiredMem(const std::string &modelPath) {
auto fin = std::ifstream(modelPath, std::ios::binary);
fin.seekg(0, std::ios_base::end);
size_t filesize = fin.tellg();
fin.seekg(0, std::ios_base::beg);
uint32_t magic = 0;
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
if (magic != 0x67676a74) return 0;
uint32_t version = 0;
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
llama_file_hparams hparams;
fin.read(reinterpret_cast<char*>(&hparams.n_vocab), sizeof(hparams.n_vocab));
fin.read(reinterpret_cast<char*>(&hparams.n_embd), sizeof(hparams.n_embd));
fin.read(reinterpret_cast<char*>(&hparams.n_head), sizeof(hparams.n_head));
fin.read(reinterpret_cast<char*>(&hparams.n_layer), sizeof(hparams.n_layer));
fin.read(reinterpret_cast<char*>(&hparams.n_rot), sizeof(hparams.n_rot));
fin.read(reinterpret_cast<char*>(&hparams.ftype), sizeof(hparams.ftype));
const size_t n_ctx = 2048;
const size_t kvcache_element_size = 2; // fp16
const size_t est_kvcache_size = hparams.n_embd * hparams.n_layer * 2u * n_ctx * kvcache_element_size;
return filesize + est_kvcache_size;
}
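With the default 7B header values shown above (n_embd = 4096, n_layer = 32), n_ctx = 2048, and fp16 cache elements, the added KV-cache term comes to exactly 1 GiB on top of the file size. A standalone sanity check of that arithmetic (not part of the commit):

#include <cstddef>

// KV-cache estimate for the default LLaMA-7B hparams: 4096 * 32 * 2 * 2048 * 2 bytes.
constexpr std::size_t n_embd = 4096, n_layer = 32, n_ctx = 2048, kv_element = 2 /* fp16 */;
constexpr std::size_t est_kvcache_size = n_embd * n_layer * 2u * n_ctx * kv_element;
static_assert(est_kvcache_size == 1024ull * 1024 * 1024, "7B defaults give a 1 GiB KV-cache estimate");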
bool LLamaModel::loadModel(const std::string &modelPath)
{
// load the model

View File

@ -17,6 +17,7 @@ public:
bool loadModel(const std::string &modelPath) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;

View File

@ -59,6 +59,7 @@ public:
virtual bool loadModel(const std::string &modelPath) = 0;
virtual bool isModelLoaded() const = 0;
virtual size_t requiredMem(const std::string &modelPath) = 0;
virtual size_t stateSize() const { return 0; }
virtual size_t saveState(uint8_t */*dest*/) const { return 0; }
virtual size_t restoreState(const uint8_t */*src*/) { return 0; }

View File

@ -60,6 +60,12 @@ void llmodel_model_destroy(llmodel_model model) {
delete reinterpret_cast<LLModelWrapper*>(model);
}
size_t llmodel_required_mem(llmodel_model model, const char *model_path)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->requiredMem(model_path);
}
bool llmodel_loadModel(llmodel_model model, const char *model_path)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);

View File

@ -107,6 +107,14 @@ llmodel_model llmodel_model_create2(const char *model_path, const char *build_va
*/
void llmodel_model_destroy(llmodel_model model);
/**
* Estimate RAM requirement for a model file
* @param model A pointer to the llmodel_model instance.
* @param model_path A string representing the path to the model file.
* @return size greater than 0 if the model was parsed successfully, 0 if file could not be parsed.
*/
size_t llmodel_required_mem(llmodel_model model, const char *model_path);
/**
* Load a model from a file.
* @param model A pointer to the llmodel_model instance.
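For context, a caller is expected to probe the requirement before committing to a full load. A minimal sketch against the C API touched in this commit (error handling elided; the single-argument llmodel_model_create call mirrors what the Python binding does and should be treated as an assumption):

#include <cstddef>
#include <cstdio>
#include "llmodel_c.h"

int main(int argc, char **argv) {
    if (argc < 2) return 1;
    const char *path = argv[1];
    llmodel_model model = llmodel_model_create(path);   // assumption: same call the Python binding makes
    if (!model) return 1;
    size_t need = llmodel_required_mem(model, path);     // 0 => file could not be parsed
    std::printf("estimated RAM requirement: %zu bytes\n", need);
    if (need > 0 && llmodel_loadModel(model, path)) {
        // ... prompt the model ...
    }
    llmodel_model_destroy(model);
    return 0;
}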

View File

@ -152,9 +152,13 @@ static bool kv_cache_init(
return true;
}
// load the model's weights from a stream
bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, gpt_vocab & vocab) {
// load the model's weights from a stream. if mem_req ptr is passed the model is
// only partially parsed to estimate required memory
bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, gpt_vocab & vocab, size_t * mem_req) {
printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
if (mem_req != nullptr) {
*mem_req = 0;
}
// verify magic
{
@ -276,6 +280,18 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
if (mem_req != nullptr) {
*mem_req += ctx_size;
const int n_embd = model.hparams.n_embd;
const int n_layer = model.hparams.n_layer;
const int64_t n_mem = (int64_t)n_layer*model.hparams.n_ctx;
const int64_t n_elements = n_embd*n_mem;
*mem_req += (2u*n_elements*ggml_type_size(wtype) + 2_MiB);
return false;
}
// create the ggml context
{
struct ggml_init_params params = {
@ -431,7 +447,7 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
return false;
}
bool loaded = mpt_model_load(fname, fin, model, vocab);
bool loaded = mpt_model_load(fname, fin, model, vocab, nullptr);
fin.close();
return loaded;
}
@ -761,6 +777,15 @@ MPT::MPT()
d_ptr->modelLoaded = false;
}
size_t MPT::requiredMem(const std::string &modelPath) {
mpt_model dummy_model;
gpt_vocab dummy_vocab;
size_t mem_req;
auto fin = std::ifstream(modelPath, std::ios::binary);
mpt_model_load(modelPath, fin, dummy_model, dummy_vocab, &mem_req);
return mem_req;
}
bool MPT::loadModel(const std::string &modelPath) {
std::mt19937 rng(time(NULL));
d_ptr->rng = rng;
@ -768,7 +793,7 @@ bool MPT::loadModel(const std::string &modelPath) {
auto fin = std::ifstream(modelPath, std::ios::binary);
// load the model
if (!mpt_model_load(modelPath, fin, *d_ptr->model, d_ptr->vocab)) {
if (!mpt_model_load(modelPath, fin, *d_ptr->model, d_ptr->vocab, nullptr)) {
std::cerr << "MPT ERROR: failed to load model from " << modelPath;
return false;
}

View File

@ -17,6 +17,7 @@ public:
bool loadModel(const std::string &modelPath) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;

View File

@ -267,8 +267,11 @@ static bool kv_cache_init(
}
// load the model's weights from a stream
bool replit_model_load(const std::string & fname, std::istream &fin, replit_model & model, replit_tokenizer & vocab) {
bool replit_model_load(const std::string & fname, std::istream &fin, replit_model & model, replit_tokenizer & vocab, size_t *mem_req) {
printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
if (mem_req != nullptr) {
*mem_req = 0;
}
// verify magic
{
@ -352,6 +355,18 @@ bool replit_model_load(const std::string & fname, std::istream &fin, replit_mode
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size / (1024.0 * 1024.0));
}
if (mem_req != nullptr) {
*mem_req += ctx_size;
const int n_embd = model.hparams.n_embd;
const int n_layer = model.hparams.n_layer;
const int64_t n_mem = (int64_t)n_layer*model.hparams.n_ctx;
const int64_t n_elements = n_embd*n_mem;
*mem_req += (2u*n_elements*ggml_type_size(wtype) + 2_MiB);
return false;
}
// create the ggml context
{
struct ggml_init_params params = {
@ -544,7 +559,7 @@ bool replit_model_load(const std::string & fname, replit_model & model, replit_t
return false;
}
bool loaded = replit_model_load(fname, fin, model, vocab);
bool loaded = replit_model_load(fname, fin, model, vocab, nullptr);
fin.close();
return loaded;
}
@ -888,6 +903,15 @@ Replit::Replit()
d_ptr->modelLoaded = false;
}
size_t Replit::requiredMem(const std::string &modelPath) {
replit_model dummy_model;
replit_tokenizer dummy_vocab;
size_t mem_req;
auto fin = std::ifstream(modelPath, std::ios::binary);
replit_model_load(modelPath, fin, dummy_model, dummy_vocab, &mem_req);
return mem_req;
}
bool Replit::loadModel(const std::string &modelPath) {
std::mt19937 rng(time(NULL));
d_ptr->rng = rng;
@ -895,7 +919,7 @@ bool Replit::loadModel(const std::string &modelPath) {
auto fin = std::ifstream(modelPath, std::ios::binary);
// load the model
if (!replit_model_load(modelPath, fin, *d_ptr->model, d_ptr->vocab)) {
if (!replit_model_load(modelPath, fin, *d_ptr->model, d_ptr->vocab, nullptr)) {
std::cerr << "Replit ERROR: failed to load model from " << modelPath;
return false;
}

View File

@ -19,6 +19,7 @@ public:
bool loadModel(const std::string &modelPath) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string & modelPath) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;

View File

@ -81,6 +81,8 @@ llmodel.llmodel_model_destroy.restype = None
llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
llmodel.llmodel_loadModel.restype = ctypes.c_bool
llmodel.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
llmodel.llmodel_required_mem.restype = ctypes.c_size_t
llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
@ -131,6 +133,16 @@ class LLModel:
if self.model is not None:
llmodel.llmodel_model_destroy(self.model)
def memory_needed(self, model_path: str) -> int:
model_path_enc = model_path.encode("utf-8")
self.model = llmodel.llmodel_model_create(model_path_enc)
if self.model is not None:
return llmodel.llmodel_required_mem(self.model, model_path_enc)
else:
raise ValueError("Unable to instantiate model")
def load_model(self, model_path: str) -> bool:
"""
Load model from a file.

View File

@ -20,6 +20,12 @@ ChatGPT::ChatGPT()
{
}
size_t ChatGPT::requiredMem(const std::string &modelPath)
{
Q_UNUSED(modelPath);
return 0;
}
bool ChatGPT::loadModel(const std::string &modelPath)
{
Q_UNUSED(modelPath);

View File

@ -16,6 +16,7 @@ public:
bool loadModel(const std::string &modelPath) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;