feat: load model

This commit is contained in:
Zach Nussbaum 2023-05-07 06:03:04 -04:00 committed by Adam Treat
parent 58069dc8b9
commit 6a56bcaf06
3 changed files with 32 additions and 0 deletions

View File

@ -83,6 +83,7 @@ bool ChatLLM::loadModel(const QString &modelName)
} }
bool isGPTJ = false; bool isGPTJ = false;
bool isMPT = false;
QString filePath = modelFilePath(modelName); QString filePath = modelFilePath(modelName);
QFileInfo info(filePath); QFileInfo info(filePath);
if (info.exists()) { if (info.exists()) {
@ -93,9 +94,13 @@ bool ChatLLM::loadModel(const QString &modelName)
fin.seekg(0); fin.seekg(0);
fin.close(); fin.close();
isGPTJ = magic == 0x67676d6c; isGPTJ = magic == 0x67676d6c;
isMPT = magic == 0x67676d6d;
if (isGPTJ) { if (isGPTJ) {
m_llmodel = new GPTJ; m_llmodel = new GPTJ;
m_llmodel->loadModel(filePath.toStdString()); m_llmodel->loadModel(filePath.toStdString());
} else if (isMPT) {
m_llmodel = new MPT;
m_llmodel->loadModel(filePath.toStdString());
} else { } else {
m_llmodel = new LLamaModel; m_llmodel = new LLamaModel;
m_llmodel->loadModel(filePath.toStdString()); m_llmodel->loadModel(filePath.toStdString());

View File

@ -2,6 +2,7 @@
#include "gptj.h" #include "gptj.h"
#include "llamamodel.h" #include "llamamodel.h"
#include "mpt.h"
struct LLModelWrapper { struct LLModelWrapper {
LLModel *llModel = nullptr; LLModel *llModel = nullptr;
@ -22,6 +23,20 @@ void llmodel_gptj_destroy(llmodel_model gptj)
delete wrapper; delete wrapper;
} }
llmodel_model llmodel_mpt_create()
{
LLModelWrapper *wrapper = new LLModelWrapper;
wrapper->llModel = new MPT;
return reinterpret_cast<void*>(wrapper);
}
void llmodel_mpt_destroy(llmodel_model mpt)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(mpt);
delete wrapper->llModel;
delete wrapper;
}
llmodel_model llmodel_llama_create() llmodel_model llmodel_llama_create()
{ {
LLModelWrapper *wrapper = new LLModelWrapper; LLModelWrapper *wrapper = new LLModelWrapper;

View File

@ -71,6 +71,18 @@ llmodel_model llmodel_gptj_create();
*/ */
void llmodel_gptj_destroy(llmodel_model gptj); void llmodel_gptj_destroy(llmodel_model gptj);
/**
* Create a MPT instance.
* @return A pointer to the MPT instance.
*/
llmodel_model llmodel_mpt_create();
/**
* Destroy a MPT instance.
* @param gptj A pointer to the MPT instance.
*/
void llmodel_mpt_destroy(llmodel_model mpt);
/** /**
* Create a LLAMA instance. * Create a LLAMA instance.
* @return A pointer to the LLAMA instance. * @return A pointer to the LLAMA instance.