Mirror of https://github.com/nomic-ai/gpt4all.git
add requiredMem method to llmodel impls
Most of these can just shortcut out of the model loading logic. llama is a bit worse to deal with because we submodule it, so I have to at least parse the hparams, and then I just use the size on disk as an estimate for the mem size (which seems reasonable since we mmap() the llama files anyway).
This commit is contained in:
parent dead954134, commit b19a3e5b2c
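Every backend in this commit follows the same pattern: the internal *_model_load function gains an optional mem_req out-parameter, and when it is non-null the loader adds up the ggml context size plus a KV-cache estimate and returns false before allocating any tensors. A minimal sketch of that calling pattern, using the GPT-J names from the diff below (the free-standing helper and its name are illustrative, not part of the patch; gptj_model, gpt_vocab and gptj_model_load come from the backend's internal headers):

    #include <fstream>
    #include <string>

    // Sketch: estimate memory without fully loading the model. Mirrors
    // GPTJ::requiredMem() in the diff -- the dummy model/vocab are never
    // populated because gptj_model_load() bails out early once *mem_req
    // has been filled in.
    size_t estimate_required_mem(const std::string &modelPath) {
        gptj_model dummy_model;
        gpt_vocab  dummy_vocab;
        size_t mem_req = 0;

        auto fin = std::ifstream(modelPath, std::ios::binary);
        gptj_model_load(modelPath, fin, dummy_model, dummy_vocab, &mem_req);
        return mem_req;
    }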
@@ -158,8 +158,11 @@ static bool kv_cache_init(
 }
 
 // load the model's weights from a stream
-bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & model, gpt_vocab & vocab) {
+bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & model, gpt_vocab & vocab, size_t * mem_req = nullptr) {
     printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
+    if(mem_req != nullptr) {
+        *mem_req = 0;
+    }
 
     // verify magic
     {
@@ -276,6 +279,19 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
         printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
     }
 
+    if (mem_req != nullptr) {
+        *mem_req += ctx_size;
+        const int n_embd = model.hparams.n_embd;
+        const int n_layer = model.hparams.n_layer;
+
+        const int64_t n_mem = (int64_t)n_layer*model.hparams.n_ctx;
+        const int64_t n_elements = n_embd*n_mem;
+
+        *mem_req += (2u*n_elements*ggml_type_size(wtype) + 2_MiB);
+        return false;
+    }
+
     // create the ggml context
     {
         struct ggml_init_params params = {
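The added estimate is the ggml context size plus two buffers (presumably the K and V caches) of n_layer * n_ctx positions, each holding n_embd values stored at the size of one wtype element, plus a fixed 2 MiB of headroom (the 2_MiB constant in the patch). Spelled out factor by factor (a sketch with a hypothetical helper name, not code from the patch):

    #include <cstddef>

    // KV-cache portion of the estimate above, written out explicitly.
    //   2 buffers * (n_layer * n_ctx) positions * n_embd values * bytes/element
    //   + 2 MiB of headroom
    size_t kv_cache_estimate(size_t n_embd, size_t n_layer, size_t n_ctx,
                             size_t element_size) {
        const size_t n_mem      = n_layer * n_ctx;
        const size_t n_elements = n_embd * n_mem;
        return 2u * n_elements * element_size + 2u * 1024u * 1024u;
    }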
@@ -837,6 +853,15 @@ GPTJ::GPTJ()
     d_ptr->modelLoaded = false;
 }
 
+size_t GPTJ::requiredMem(const std::string &modelPath) {
+    gptj_model dummy_model;
+    gpt_vocab dummy_vocab;
+    size_t mem_req;
+    auto fin = std::ifstream(modelPath, std::ios::binary);
+    gptj_model_load(modelPath, fin, dummy_model, dummy_vocab, &mem_req);
+    return mem_req;
+}
+
 bool GPTJ::loadModel(const std::string &modelPath) {
     std::mt19937 rng(time(NULL));
     d_ptr->rng = rng;
@@ -17,6 +17,7 @@ public:
 
     bool loadModel(const std::string &modelPath) override;
     bool isModelLoaded() const override;
+    size_t requiredMem(const std::string &modelPath) override;
     size_t stateSize() const override;
     size_t saveState(uint8_t *dest) const override;
     size_t restoreState(const uint8_t *src) override;
@@ -97,6 +97,40 @@ LLamaModel::LLamaModel()
     d_ptr->modelLoaded = false;
 }
 
+// default hparams (LLaMA 7B)
+struct llama_file_hparams {
+    uint32_t n_vocab = 32000;
+    uint32_t n_embd = 4096;
+    uint32_t n_mult = 256;
+    uint32_t n_head = 32;
+    uint32_t n_layer = 32;
+    uint32_t n_rot = 64;
+    enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
+};
+
+size_t LLamaModel::requiredMem(const std::string &modelPath) {
+    auto fin = std::ifstream(modelPath, std::ios::binary);
+    fin.seekg(0, std::ios_base::end);
+    size_t filesize = fin.tellg();
+    fin.seekg(0, std::ios_base::beg);
+    uint32_t magic = 0;
+    fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
+    if (magic != 0x67676a74) return 0;
+    uint32_t version = 0;
+    fin.read(reinterpret_cast<char*>(&version), sizeof(version));
+    llama_file_hparams hparams;
+    fin.read(reinterpret_cast<char*>(&hparams.n_vocab), sizeof(hparams.n_vocab));
+    fin.read(reinterpret_cast<char*>(&hparams.n_embd), sizeof(hparams.n_embd));
+    fin.read(reinterpret_cast<char*>(&hparams.n_head), sizeof(hparams.n_head));
+    fin.read(reinterpret_cast<char*>(&hparams.n_layer), sizeof(hparams.n_layer));
+    fin.read(reinterpret_cast<char*>(&hparams.n_rot), sizeof(hparams.n_rot));
+    fin.read(reinterpret_cast<char*>(&hparams.ftype), sizeof(hparams.ftype));
+    const size_t n_ctx = 2048;
+    const size_t kvcache_element_size = 2; // fp16
+    const size_t est_kvcache_size = hparams.n_embd * hparams.n_layer * 2u * n_ctx * kvcache_element_size;
+    return filesize + est_kvcache_size;
+}
+
 bool LLamaModel::loadModel(const std::string &modelPath)
 {
     // load the model
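With the default 7B hyperparameters shown above (n_embd = 4096, n_layer = 32, n_ctx = 2048, 2 bytes per fp16 element), the KV-cache term works out to 4096 * 32 * 2 * 2048 * 2 bytes = 1 GiB, so the estimate is roughly the on-disk file size plus 1 GiB. A quick self-contained check of that arithmetic (illustrative only):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // LLaMA-7B defaults from the struct above; 2 bytes per fp16 element.
        const uint64_t n_embd = 4096, n_layer = 32, n_ctx = 2048, elem = 2;
        const uint64_t kv = n_embd * n_layer * 2u * n_ctx * elem;
        std::printf("est. KV cache: %llu bytes (%.2f GiB)\n",
                    (unsigned long long)kv, kv / (1024.0 * 1024.0 * 1024.0));
        return 0; // prints 1073741824 bytes (1.00 GiB)
    }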
@@ -17,6 +17,7 @@ public:
 
     bool loadModel(const std::string &modelPath) override;
    bool isModelLoaded() const override;
+    size_t requiredMem(const std::string &modelPath) override;
     size_t stateSize() const override;
     size_t saveState(uint8_t *dest) const override;
     size_t restoreState(const uint8_t *src) override;
@@ -59,6 +59,7 @@ public:
 
     virtual bool loadModel(const std::string &modelPath) = 0;
     virtual bool isModelLoaded() const = 0;
+    virtual size_t requiredMem(const std::string &modelPath) = 0;
     virtual size_t stateSize() const { return 0; }
     virtual size_t saveState(uint8_t */*dest*/) const { return 0; }
     virtual size_t restoreState(const uint8_t */*src*/) { return 0; }
@@ -60,6 +60,12 @@ void llmodel_model_destroy(llmodel_model model) {
     delete reinterpret_cast<LLModelWrapper*>(model);
 }
 
+size_t llmodel_required_mem(llmodel_model model, const char *model_path)
+{
+    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
+    return wrapper->llModel->requiredMem(model_path);
+}
+
 bool llmodel_loadModel(llmodel_model model, const char *model_path)
 {
     LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
@@ -107,6 +107,14 @@ llmodel_model llmodel_model_create2(const char *model_path, const char *build_va
  */
 void llmodel_model_destroy(llmodel_model model);
 
+/**
+ * Estimate RAM requirement for a model file
+ * @param model A pointer to the llmodel_model instance.
+ * @param model_path A string representing the path to the model file.
+ * @return size greater than 0 if the model was parsed successfully, 0 if file could not be parsed.
+ */
+size_t llmodel_required_mem(llmodel_model model, const char *model_path);
+
 /**
  * Load a model from a file.
  * @param model A pointer to the llmodel_model instance.
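A short sketch of how a client of the C API might use the new call to check a model's footprint before committing to a full load. The header name, the model path, and the error handling are assumptions; llmodel_model_create is the single-argument creator that the Python bindings in this commit call, and the other llmodel_* functions appear in the diffs above:

    #include <cstdio>
    #include "llmodel_c.h"   // assumed header name for the C API

    int main() {
        const char *path = "/path/to/model.bin";            // placeholder path
        llmodel_model model = llmodel_model_create(path);    // wrapper for this file
        if (!model) return 1;

        // Cheap: only parses headers / file size, never loads the weights.
        size_t need = llmodel_required_mem(model, path);
        std::printf("estimated RAM requirement: %zu bytes\n", need);

        // Do the expensive load only if the file parsed and we want to proceed.
        if (need > 0 && llmodel_loadModel(model, path)) {
            // ... run inference ...
        }
        llmodel_model_destroy(model);
        return 0;
    }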
@@ -152,9 +152,13 @@ static bool kv_cache_init(
     return true;
 }
 
-// load the model's weights from a stream
-bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, gpt_vocab & vocab) {
+// load the model's weights from a stream. if mem_req ptr is passed the model is
+// only partially parsed to estimate required memory
+bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, gpt_vocab & vocab, size_t * mem_req) {
     printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
+    if (mem_req != nullptr) {
+        *mem_req = 0;
+    }
 
     // verify magic
     {
@@ -276,6 +280,18 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
         printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
     }
 
+    if (mem_req != nullptr) {
+        *mem_req += ctx_size;
+        const int n_embd = model.hparams.n_embd;
+        const int n_layer = model.hparams.n_layer;
+
+        const int64_t n_mem = (int64_t)n_layer*model.hparams.n_ctx;
+        const int64_t n_elements = n_embd*n_mem;
+
+        *mem_req += (2u*n_elements*ggml_type_size(wtype) + 2_MiB);
+        return false;
+    }
+
     // create the ggml context
     {
         struct ggml_init_params params = {
@@ -431,7 +447,7 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
         return false;
     }
 
-    bool loaded = mpt_model_load(fname, fin, model, vocab);
+    bool loaded = mpt_model_load(fname, fin, model, vocab, nullptr);
     fin.close();
     return loaded;
 }
@@ -761,6 +777,15 @@ MPT::MPT()
     d_ptr->modelLoaded = false;
 }
 
+size_t MPT::requiredMem(const std::string &modelPath) {
+    mpt_model dummy_model;
+    gpt_vocab dummy_vocab;
+    size_t mem_req;
+    auto fin = std::ifstream(modelPath, std::ios::binary);
+    mpt_model_load(modelPath, fin, dummy_model, dummy_vocab, &mem_req);
+    return mem_req;
+}
+
 bool MPT::loadModel(const std::string &modelPath) {
     std::mt19937 rng(time(NULL));
     d_ptr->rng = rng;
@@ -768,7 +793,7 @@ bool MPT::loadModel(const std::string &modelPath) {
     auto fin = std::ifstream(modelPath, std::ios::binary);
 
     // load the model
-    if (!mpt_model_load(modelPath, fin, *d_ptr->model, d_ptr->vocab)) {
+    if (!mpt_model_load(modelPath, fin, *d_ptr->model, d_ptr->vocab, nullptr)) {
         std::cerr << "MPT ERROR: failed to load model from " << modelPath;
         return false;
     }
@@ -17,6 +17,7 @@ public:
 
     bool loadModel(const std::string &modelPath) override;
     bool isModelLoaded() const override;
+    size_t requiredMem(const std::string &modelPath) override;
     size_t stateSize() const override;
     size_t saveState(uint8_t *dest) const override;
     size_t restoreState(const uint8_t *src) override;
@@ -267,8 +267,11 @@ static bool kv_cache_init(
 }
 
 // load the model's weights from a stream
-bool replit_model_load(const std::string & fname, std::istream &fin, replit_model & model, replit_tokenizer & vocab) {
+bool replit_model_load(const std::string & fname, std::istream &fin, replit_model & model, replit_tokenizer & vocab, size_t *mem_req) {
     printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
+    if (mem_req != nullptr) {
+        *mem_req = 0;
+    }
 
     // verify magic
     {
@@ -352,6 +355,18 @@ bool replit_model_load(const std::string & fname, std::istream &fin, replit_mode
         printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size / (1024.0 * 1024.0));
     }
 
+    if (mem_req != nullptr) {
+        *mem_req += ctx_size;
+        const int n_embd = model.hparams.n_embd;
+        const int n_layer = model.hparams.n_layer;
+
+        const int64_t n_mem = (int64_t)n_layer*model.hparams.n_ctx;
+        const int64_t n_elements = n_embd*n_mem;
+
+        *mem_req += (2u*n_elements*ggml_type_size(wtype) + 2_MiB);
+        return false;
+    }
+
     // create the ggml context
     {
         struct ggml_init_params params = {
@@ -544,7 +559,7 @@ bool replit_model_load(const std::string & fname, replit_model & model, replit_t
         return false;
     }
 
-    bool loaded = replit_model_load(fname, fin, model, vocab);
+    bool loaded = replit_model_load(fname, fin, model, vocab, nullptr);
     fin.close();
     return loaded;
 }
@@ -888,6 +903,15 @@ Replit::Replit()
     d_ptr->modelLoaded = false;
 }
 
+size_t Replit::requiredMem(const std::string &modelPath) {
+    replit_model dummy_model;
+    replit_tokenizer dummy_vocab;
+    size_t mem_req;
+    auto fin = std::ifstream(modelPath, std::ios::binary);
+    replit_model_load(modelPath, fin, dummy_model, dummy_vocab, &mem_req);
+    return mem_req;
+}
+
 bool Replit::loadModel(const std::string &modelPath) {
     std::mt19937 rng(time(NULL));
     d_ptr->rng = rng;
@@ -895,7 +919,7 @@ bool Replit::loadModel(const std::string &modelPath) {
     auto fin = std::ifstream(modelPath, std::ios::binary);
 
     // load the model
-    if (!replit_model_load(modelPath, fin, *d_ptr->model, d_ptr->vocab)) {
+    if (!replit_model_load(modelPath, fin, *d_ptr->model, d_ptr->vocab, nullptr)) {
         std::cerr << "Replit ERROR: failed to load model from " << modelPath;
         return false;
     }
@@ -19,6 +19,7 @@ public:
 
     bool loadModel(const std::string &modelPath) override;
     bool isModelLoaded() const override;
+    size_t requiredMem(const std::string & modelPath) override;
     size_t stateSize() const override;
     size_t saveState(uint8_t *dest) const override;
     size_t restoreState(const uint8_t *src) override;
@@ -81,6 +81,8 @@ llmodel.llmodel_model_destroy.restype = None
 
 llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
 llmodel.llmodel_loadModel.restype = ctypes.c_bool
+llmodel.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+llmodel.llmodel_required_mem.restype = ctypes.c_size_t
 llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
 
@@ -131,6 +133,16 @@ class LLModel:
         if self.model is not None:
             llmodel.llmodel_model_destroy(self.model)
 
+    def memory_needed(self, model_path: str) -> int:
+        model_path_enc = model_path.encode("utf-8")
+        self.model = llmodel.llmodel_model_create(model_path_enc)
+
+        if self.model is not None:
+            return llmodel.llmodel_required_mem(self.model, model_path_enc)
+        else:
+            raise ValueError("Unable to instantiate model")
+
+
     def load_model(self, model_path: str) -> bool:
         """
         Load model from a file.
@@ -20,6 +20,12 @@ ChatGPT::ChatGPT()
 {
 }
 
+size_t ChatGPT::requiredMem(const std::string &modelPath)
+{
+    Q_UNUSED(modelPath);
+    return 0;
+}
+
 bool ChatGPT::loadModel(const std::string &modelPath)
 {
     Q_UNUSED(modelPath);
@@ -16,6 +16,7 @@ public:
 
     bool loadModel(const std::string &modelPath) override;
     bool isModelLoaded() const override;
+    size_t requiredMem(const std::string &modelPath) override;
     size_t stateSize() const override;
     size_t saveState(uint8_t *dest) const override;
     size_t restoreState(const uint8_t *src) override;