Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2024-10-01 01:06:10 -04:00)

Dlopen backend 5 (#779)

Major change to the backend that allows for pluggable versions of llama.cpp/ggml. This was squash-merged from dlopen_backend_5, where the history is preserved.

parent f4a1f7340c
commit 48275d0dcc

10	.gitmodules (vendored)
@@ -1,3 +1,9 @@
[submodule "llama.cpp"]
	path = gpt4all-backend/llama.cpp
[submodule "llama.cpp-230519"]
	path = gpt4all-backend/llama.cpp-230519
	url = https://github.com/ggerganov/llama.cpp.git
[submodule "llama.cpp-230511"]
	path = gpt4all-backend/llama.cpp-230511
	url = https://github.com/manyoso/llama.cpp.git
[submodule "llama.cpp-mainline"]
	path = gpt4all-backend/llama.cpp-mainline
	url = https://github.com/ggerganov/llama.cpp.git
gpt4all-backend/CMakeLists.txt
@@ -17,36 +17,97 @@ endif()
include_directories("${CMAKE_CURRENT_BINARY_DIR}")

set(LLMODEL_VERSION_MAJOR 0)
set(LLMODEL_VERSION_MINOR 1)
set(LLMODEL_VERSION_PATCH 1)
set(LLMODEL_VERSION_MINOR 2)
set(LLMODEL_VERSION_PATCH 0)
set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}")
project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
set(BUILD_SHARED_LIBS ON)

set(LLAMA_BUILD_EXAMPLES ON CACHE BOOL "llama: build examples" FORCE)
set(BUILD_SHARED_LIBS ON FORCE)

set(CMAKE_VERBOSE_MAKEFILE ON)
if (GPT4ALL_AVX_ONLY)
    set(LLAMA_AVX2 OFF CACHE BOOL "llama: enable AVX2" FORCE)
    set(LLAMA_F16C OFF CACHE BOOL "llama: enable F16C" FORCE)
    set(LLAMA_FMA OFF CACHE BOOL "llama: enable FMA" FORCE)
# Check for IPO support
include(CheckIPOSupported)
check_ipo_supported(RESULT IPO_SUPPORTED OUTPUT IPO_ERROR)
if (NOT IPO_SUPPORTED)
    message(WARNING "Interprocedural optimization is not supported by your toolchain! This will lead to bigger file sizes and worse performance: ${IPO_ERROR}")
else()
    message(STATUS "Interprocedural optimization support detected")
endif()

add_subdirectory(llama.cpp)
include(llama.cpp.cmake)

set(BUILD_VARIANTS default avxonly)

set(CMAKE_VERBOSE_MAKEFILE ON)

# Go through each build variant
foreach(BUILD_VARIANT IN LISTS BUILD_VARIANTS)
    # Determine flags
    if (BUILD_VARIANT STREQUAL avxonly)
        set(GPT4ALL_ALLOW_NON_AVX NO)
    else()
        set(GPT4ALL_ALLOW_NON_AVX YES)
    endif()
    set(LLAMA_AVX2 ${GPT4ALL_ALLOW_NON_AVX})
    set(LLAMA_F16C ${GPT4ALL_ALLOW_NON_AVX})
    set(LLAMA_FMA ${GPT4ALL_ALLOW_NON_AVX})

    # Include GGML
    include_ggml(llama.cpp-mainline -mainline-${BUILD_VARIANT} ON)
    include_ggml(llama.cpp-230511 -230511-${BUILD_VARIANT} ON)
    include_ggml(llama.cpp-230519 -230519-${BUILD_VARIANT} ON)

    # Function for preparing individual implementations
    function(prepare_target TARGET_NAME BASE_LIB)
        set(TARGET_NAME ${TARGET_NAME}-${BUILD_VARIANT})
        message(STATUS "Configuring model implementation target ${TARGET_NAME}")
        # Link to ggml/llama
        target_link_libraries(${TARGET_NAME}
            PUBLIC ${BASE_LIB}-${BUILD_VARIANT})
        # Let it know about its build variant
        target_compile_definitions(${TARGET_NAME}
            PRIVATE GGML_BUILD_VARIANT="${BUILD_VARIANT}")
        # Enable IPO if possible
        set_property(TARGET ${TARGET_NAME}
                     PROPERTY INTERPROCEDURAL_OPTIMIZATION ${IPO_SUPPORTED})
    endfunction()

    # Add each individual implementations
    add_library(llamamodel-mainline-${BUILD_VARIANT} SHARED
        llamamodel.cpp)
    target_compile_definitions(llamamodel-mainline-${BUILD_VARIANT} PRIVATE
        LLAMA_VERSIONS=>=3 LLAMA_DATE=999999)
    prepare_target(llamamodel-mainline llama-mainline)

    add_library(llamamodel-230519-${BUILD_VARIANT} SHARED
        llamamodel.cpp)
    target_compile_definitions(llamamodel-230519-${BUILD_VARIANT} PRIVATE
        LLAMA_VERSIONS===2 LLAMA_DATE=230519)
    prepare_target(llamamodel-230519 llama-230519)

    add_library(llamamodel-230511-${BUILD_VARIANT} SHARED
        llamamodel.cpp)
    target_compile_definitions(llamamodel-230511-${BUILD_VARIANT} PRIVATE
        LLAMA_VERSIONS=<=1 LLAMA_DATE=230511)
    prepare_target(llamamodel-230511 llama-230511)

    add_library(gptj-${BUILD_VARIANT} SHARED
        gptj.cpp utils.h utils.cpp)
    prepare_target(gptj ggml-230511)

    add_library(mpt-${BUILD_VARIANT} SHARED
        mpt.cpp utils.h utils.cpp)
    prepare_target(mpt ggml-230511)
endforeach()

add_library(llmodel
    gptj.h gptj.cpp
    llamamodel.h llamamodel.cpp
    llama.cpp/examples/common.cpp
    llmodel.h llmodel_c.h llmodel_c.cpp
    mpt.h mpt.cpp
    utils.h utils.cpp
    llmodel.h llmodel.cpp
    llmodel_c.h llmodel_c.cpp
    dlhandle.h
)

target_link_libraries(llmodel
    PRIVATE llama)
target_compile_definitions(llmodel PRIVATE LIB_FILE_EXT="${CMAKE_SHARED_LIBRARY_SUFFIX}")

set_target_properties(llmodel PROPERTIES
    VERSION ${PROJECT_VERSION}
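Each implementation target above is compiled once per build variant, and its identity is baked in through compile definitions rather than through a registry. The fragment below is only an illustrative sketch of how those definitions are meant to be consumed inside the implementation sources (the real uses appear in gptj.cpp and llamamodel.cpp later in this diff); the version_supported helper is a hypothetical name.

#include <cstdint>

// GGML_BUILD_VARIANT is defined per target as "default" or "avxonly";
// LLAMA_VERSIONS expands to a comparison such as `>=3`, `==2` or `<=1`;
// LLAMA_DATE is an integer snapshot date usable in preprocessor checks.
extern "C" const char *get_build_variant() {
    return GGML_BUILD_VARIANT;
}

// Hypothetical helper: accept only the ggml file versions this library targets.
static bool version_supported(uint32_t version) {
    return version LLAMA_VERSIONS;
}

#if LLAMA_DATE <= 230511
// Code paths that only exist against the 230511 llama.cpp snapshot would go here.
#endif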
101	gpt4all-backend/dlhandle.h (new file)
@@ -0,0 +1,101 @@
#ifndef DLHANDLE_H
#define DLHANDLE_H
#ifndef _WIN32
#include <string>
#include <stdexcept>
#include <utility>
#include <dlfcn.h>


class Dlhandle {
    void *chandle;

public:
    class Exception : public std::runtime_error {
    public:
        using std::runtime_error::runtime_error;
    };

    Dlhandle() : chandle(nullptr) {}
    Dlhandle(const std::string& fpath, int flags = RTLD_LAZY) {
        chandle = dlopen(fpath.c_str(), flags);
        if (!chandle) {
            throw Exception("dlopen(\""+fpath+"\"): "+dlerror());
        }
    }
    Dlhandle(const Dlhandle& o) = delete;
    Dlhandle(Dlhandle&& o) : chandle(o.chandle) {
        o.chandle = nullptr;
    }
    ~Dlhandle() {
        if (chandle) dlclose(chandle);
    }

    auto operator =(Dlhandle&& o) {
        chandle = std::exchange(o.chandle, nullptr);
    }

    bool is_valid() const {
        return chandle != nullptr;
    }
    operator bool() const {
        return is_valid();
    }

    template<typename T>
    T* get(const std::string& fname) {
        auto fres = reinterpret_cast<T*>(dlsym(chandle, fname.c_str()));
        return (dlerror()==NULL)?fres:nullptr;
    }
    auto get_fnc(const std::string& fname) {
        return get<void*(...)>(fname);
    }
};
#else
#include <string>
#include <exception>
#include <stdexcept>
#include <windows.h>
#include <libloaderapi.h>


class Dlhandle {
    HMODULE chandle;

public:
    class Exception : public std::runtime_error {
    public:
        using std::runtime_error::runtime_error;
    };

    Dlhandle() : chandle(nullptr) {}
    Dlhandle(const std::string& fpath) {
        chandle = LoadLibraryA(fpath.c_str());
        if (!chandle) {
            throw Exception("dlopen(\""+fpath+"\"): Error");
        }
    }
    Dlhandle(const Dlhandle& o) = delete;
    Dlhandle(Dlhandle&& o) : chandle(o.chandle) {
        o.chandle = nullptr;
    }
    ~Dlhandle() {
        if (chandle) FreeLibrary(chandle);
    }

    bool is_valid() const {
        return chandle != nullptr;
    }

    template<typename T>
    T* get(const std::string& fname) {
        return reinterpret_cast<T*>(GetProcAddress(chandle, fname.c_str()));
    }
    auto get_fnc(const std::string& fname) {
        return get<void*(...)>(fname);
    }
};
#endif
#endif // DLHANDLE_H
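Both branches of this header expose the same small RAII surface: construct to load a library, get<T>() to resolve a typed symbol (nullptr on failure), move-only ownership, automatic dlclose/FreeLibrary on destruction. A minimal usage sketch of the probing pattern the loader in llmodel.cpp (later in this diff) relies on; the library file name is a hypothetical example of what the CMake targets above would produce on Linux.

#include "dlhandle.h"
#include <cstdio>

int main() {
    try {
        // Hypothetical plugin name; the build above would produce names like
        // libgptj-default.so / libgptj-avxonly.so on Linux.
        Dlhandle dl("libgptj-default.so");
        auto is_impl  = dl.get<bool()>("is_g4a_backend_model_implementation");
        auto get_type = dl.get<const char *()>("get_model_type");
        if (is_impl && is_impl() && get_type)
            std::printf("model implementation: %s\n", get_type());
    } catch (const Dlhandle::Exception &e) {
        std::fprintf(stderr, "%s\n", e.what());  // dlopen/LoadLibrary failed
    }
    return 0;
}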
@ -1,5 +1,5 @@
|
||||
#include "gptj.h"
|
||||
#include "llama.cpp/ggml.h"
|
||||
#define GPTJ_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
|
||||
#include "gptj_impl.h"
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
@ -25,10 +25,16 @@
|
||||
#endif
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
#include <ggml.h>
|
||||
|
||||
|
||||
namespace {
|
||||
const char *modelType_ = "MPT";
|
||||
|
||||
static const size_t MB = 1024*1024;
|
||||
}
|
||||
|
||||
// default hparams (GPT-J 6B)
|
||||
static const size_t MB = 1024*1024;
|
||||
|
||||
struct gptj_hparams {
|
||||
int32_t n_vocab = 50400;
|
||||
int32_t n_ctx = 2048;
|
||||
@ -229,8 +235,6 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
|
||||
}
|
||||
}
|
||||
|
||||
const ggml_type wtype2 = GGML_TYPE_F32;
|
||||
|
||||
auto & ctx = model.ctx;
|
||||
|
||||
size_t ctx_size = 0;
|
||||
@ -279,6 +283,7 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
|
||||
struct ggml_init_params params = {
|
||||
.mem_size = ctx_size,
|
||||
.mem_buffer = NULL,
|
||||
.no_alloc = false
|
||||
};
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
@ -294,7 +299,6 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
model.layers.resize(n_layer);
|
||||
@ -355,14 +359,6 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
|
||||
// key + value memory
|
||||
{
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
|
||||
const int n_mem = n_layer*n_ctx;
|
||||
const int n_elements = n_embd*n_mem;
|
||||
|
||||
if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) {
|
||||
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
||||
ggml_free(ctx);
|
||||
@ -505,8 +501,6 @@ bool gptj_eval(
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
const int n_rot = hparams.n_rot;
|
||||
|
||||
const int d_key = n_embd/n_head;
|
||||
|
||||
const size_t init_buf_size = 1024u*MB;
|
||||
if (!model.buf.addr || model.buf.size < init_buf_size)
|
||||
model.buf.resize(init_buf_size);
|
||||
@ -526,10 +520,12 @@ bool gptj_eval(
|
||||
struct ggml_init_params params = {
|
||||
.mem_size = model.buf.size,
|
||||
.mem_buffer = model.buf.addr,
|
||||
.no_alloc = false
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
struct ggml_cgraph gf = { .n_threads = n_threads };
|
||||
struct ggml_cgraph gf = {};
|
||||
gf.n_threads = n_threads;
|
||||
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
|
||||
@ -772,8 +768,7 @@ size_t gptj_copy_state_data(const gptj_model &model, const std::mt19937 &rng, ui
|
||||
}
|
||||
|
||||
const size_t written = out - dest;
|
||||
const size_t expected = gptj_get_state_size(model);
|
||||
assert(written == expected);
|
||||
assert(written == gptj_get_state_size(model));
|
||||
fflush(stdout);
|
||||
return written;
|
||||
}
|
||||
@ -822,8 +817,7 @@ size_t gptj_set_state_data(gptj_model *model, std::mt19937 *rng, const uint8_t *
|
||||
}
|
||||
|
||||
const size_t nread = in - src;
|
||||
const size_t expected = gptj_get_state_size(*model);
|
||||
assert(nread == expected);
|
||||
assert(nread == gptj_get_state_size(*model));
|
||||
fflush(stdout);
|
||||
return nread;
|
||||
}
|
||||
@ -840,6 +834,7 @@ struct GPTJPrivate {
|
||||
|
||||
GPTJ::GPTJ()
|
||||
: d_ptr(new GPTJPrivate) {
|
||||
modelType = modelType_;
|
||||
|
||||
d_ptr->model = new gptj_model;
|
||||
d_ptr->modelLoaded = false;
|
||||
@ -908,12 +903,6 @@ void GPTJ::prompt(const std::string &prompt,
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t t_main_start_us = ggml_time_us();
|
||||
|
||||
int64_t t_sample_us = 0;
|
||||
int64_t t_predict_us = 0;
|
||||
int64_t t_prompt_us = 0;
|
||||
|
||||
// tokenize the prompt
|
||||
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt);
|
||||
|
||||
@ -942,20 +931,19 @@ void GPTJ::prompt(const std::string &prompt,
|
||||
|
||||
// process the prompt in batches
|
||||
size_t i = 0;
|
||||
const int64_t t_start_prompt_us = ggml_time_us();
|
||||
while (i < embd_inp.size()) {
|
||||
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
|
||||
std::vector<gpt_vocab::id> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
|
||||
|
||||
// Check if the context has run out...
|
||||
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
|
||||
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
|
||||
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
|
||||
// Erase the first percentage of context from the tokens...
|
||||
std::cerr << "GPTJ: reached the end of the context window so resizing\n";
|
||||
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
|
||||
promptCtx.n_past = promptCtx.tokens.size();
|
||||
recalculateContext(promptCtx, recalculateCallback);
|
||||
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
|
||||
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
|
||||
}
|
||||
|
||||
if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
|
||||
@ -966,7 +954,7 @@ void GPTJ::prompt(const std::string &prompt,
|
||||
|
||||
size_t tokens = batch_end - i;
|
||||
for (size_t t = 0; t < tokens; ++t) {
|
||||
if (promptCtx.tokens.size() == promptCtx.n_ctx)
|
||||
if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
|
||||
promptCtx.tokens.erase(promptCtx.tokens.begin());
|
||||
promptCtx.tokens.push_back(batch.at(t));
|
||||
if (!promptCallback(batch.at(t)))
|
||||
@ -975,10 +963,6 @@ void GPTJ::prompt(const std::string &prompt,
|
||||
promptCtx.n_past += batch.size();
|
||||
i = batch_end;
|
||||
}
|
||||
t_prompt_us += ggml_time_us() - t_start_prompt_us;
|
||||
|
||||
int p_instructFound = 0;
|
||||
int r_instructFound = 0;
|
||||
|
||||
std::string cachedResponse;
|
||||
std::vector<gpt_vocab::id> cachedTokens;
|
||||
@ -986,24 +970,20 @@ void GPTJ::prompt(const std::string &prompt,
|
||||
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" };
|
||||
|
||||
// predict next tokens
|
||||
int32_t totalPredictions = 0;
|
||||
for (int i = 0; i < promptCtx.n_predict; i++) {
|
||||
|
||||
// sample next token
|
||||
const int n_vocab = d_ptr->model->hparams.n_vocab;
|
||||
gpt_vocab::id id = 0;
|
||||
{
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
|
||||
id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
|
||||
id = gpt_sample_top_k_top_p(n_vocab,
|
||||
promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
|
||||
n_prev_toks,
|
||||
promptCtx.logits,
|
||||
promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
|
||||
promptCtx.repeat_penalty,
|
||||
d_ptr->rng);
|
||||
|
||||
t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
|
||||
// Check if the context has run out...
|
||||
@ -1017,29 +997,24 @@ void GPTJ::prompt(const std::string &prompt,
|
||||
assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
|
||||
}
|
||||
|
||||
const int64_t t_start_predict_us = ggml_time_us();
|
||||
if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
|
||||
d_ptr->mem_per_token)) {
|
||||
std::cerr << "GPT-J ERROR: Failed to predict next token\n";
|
||||
return;
|
||||
}
|
||||
t_predict_us += ggml_time_us() - t_start_predict_us;
|
||||
|
||||
promptCtx.n_past += 1;
|
||||
// display text
|
||||
++totalPredictions;
|
||||
|
||||
if (id == 50256 /*end of text*/)
|
||||
goto stop_generating;
|
||||
return;
|
||||
|
||||
const std::string str = d_ptr->vocab.id_to_token[id];
|
||||
|
||||
// Check if the provided str is part of our reverse prompts
|
||||
bool foundPartialReversePrompt = false;
|
||||
const std::string completed = cachedResponse + str;
|
||||
if (reversePrompts.find(completed) != reversePrompts.end()) {
|
||||
goto stop_generating;
|
||||
}
|
||||
if (reversePrompts.find(completed) != reversePrompts.end())
|
||||
return;
|
||||
|
||||
// Check if it partially matches our reverse prompts and if so, cache
|
||||
for (auto s : reversePrompts) {
|
||||
@ -1059,32 +1034,14 @@ void GPTJ::prompt(const std::string &prompt,
|
||||
|
||||
// Empty the cache
|
||||
for (auto t : cachedTokens) {
|
||||
if (promptCtx.tokens.size() == promptCtx.n_ctx)
|
||||
if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
|
||||
promptCtx.tokens.erase(promptCtx.tokens.begin());
|
||||
promptCtx.tokens.push_back(t);
|
||||
if (!responseCallback(t, d_ptr->vocab.id_to_token[t]))
|
||||
goto stop_generating;
|
||||
return;
|
||||
}
|
||||
cachedTokens.clear();
|
||||
}
|
||||
|
||||
stop_generating:
|
||||
|
||||
#if 0
|
||||
// report timing
|
||||
{
|
||||
const int64_t t_main_end_us = ggml_time_us();
|
||||
|
||||
std::cout << "GPT-J INFO: mem per token = " << mem_per_token << " bytes\n";
|
||||
std::cout << "GPT-J INFO: sample time = " << t_sample_us/1000.0f << " ms\n";
|
||||
std::cout << "GPT-J INFO: prompt time = " << t_prompt_us/1000.0f << " ms\n";
|
||||
std::cout << "GPT-J INFO: predict time = " << t_predict_us/1000.0f << " ms / " << t_predict_us/1000.0f/totalPredictions << " ms per token\n";
|
||||
std::cout << "GPT-J INFO: total time = " << (t_main_end_us - t_main_start_us)/1000.0f << " ms\n";
|
||||
fflush(stdout);
|
||||
}
|
||||
#endif
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
void GPTJ::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
|
||||
@ -1095,7 +1052,7 @@ void GPTJ::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)
|
||||
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
|
||||
std::vector<gpt_vocab::id> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
|
||||
|
||||
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
|
||||
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
|
||||
|
||||
if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
|
||||
d_ptr->mem_per_token)) {
|
||||
@ -1107,8 +1064,38 @@ void GPTJ::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)
|
||||
goto stop_generating;
|
||||
i = batch_end;
|
||||
}
|
||||
assert(promptCtx.n_past == promptCtx.tokens.size());
|
||||
assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
|
||||
|
||||
stop_generating:
|
||||
recalculate(false);
|
||||
}
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define DLL_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define DLL_EXPORT __attribute__ ((visibility ("default")))
|
||||
#endif
|
||||
|
||||
extern "C" {
|
||||
DLL_EXPORT bool is_g4a_backend_model_implementation() {
|
||||
return true;
|
||||
}
|
||||
|
||||
DLL_EXPORT const char *get_model_type() {
|
||||
return modelType_;
|
||||
}
|
||||
|
||||
DLL_EXPORT const char *get_build_variant() {
|
||||
return GGML_BUILD_VARIANT;
|
||||
}
|
||||
|
||||
DLL_EXPORT bool magic_match(std::istream& f) {
|
||||
uint32_t magic = 0;
|
||||
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||
return magic == 0x67676d6c;
|
||||
}
|
||||
|
||||
DLL_EXPORT LLModel *construct() {
|
||||
return new GPTJ;
|
||||
}
|
||||
}
|
||||
|
@@ -1,3 +1,6 @@
#ifndef GPTJ_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#error This file is NOT meant to be included outside of gptj.cpp. Doing so is DANGEROUS. Be sure to know what you are doing before proceeding to #define GPTJ_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#endif
#ifndef GPTJ_H
#define GPTJ_H

@@ -6,7 +9,7 @@
#include <vector>
#include "llmodel.h"

class GPTJPrivate;
struct GPTJPrivate;
class GPTJ : public LLModel {
public:
    GPTJ();
1	gpt4all-backend/llama.cpp-230519 (submodule)
@@ -0,0 +1 @@
Subproject commit 5ea43392731040b454c293123839b90e159cbb99

1	gpt4all-backend/llama.cpp-mainline (submodule)
@@ -0,0 +1 @@
Subproject commit ea600071cb005267e9e8f2629c1e406dd5fde083

364	gpt4all-backend/llama.cpp.cmake (new file)
@@ -0,0 +1,364 @@
|
||||
cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
|
||||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
|
||||
endif()
|
||||
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
|
||||
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||
set(LLAMA_STANDALONE ON)
|
||||
|
||||
# configure project version
|
||||
# TODO
|
||||
else()
|
||||
set(LLAMA_STANDALONE OFF)
|
||||
endif()
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
||||
|
||||
option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON)
|
||||
else()
|
||||
if (MINGW)
|
||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
||||
else()
|
||||
set(BUILD_SHARED_LIBS_DEFAULT ON)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
#
|
||||
# Option list
|
||||
#
|
||||
|
||||
# general
|
||||
option(LLAMA_STATIC "llama: static link libraries" OFF)
|
||||
option(LLAMA_NATIVE "llama: enable -march=native flag" OFF)
|
||||
option(LLAMA_LTO "llama: enable link time optimization" OFF)
|
||||
|
||||
# debug
|
||||
option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON)
|
||||
option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF)
|
||||
option(LLAMA_GPROF "llama: enable gprof" OFF)
|
||||
|
||||
# sanitizers
|
||||
option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF)
|
||||
option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF)
|
||||
option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)
|
||||
|
||||
# instruction set specific
|
||||
#option(LLAMA_AVX "llama: enable AVX" ON)
|
||||
#option(LLAMA_AVX2 "llama: enable AVX2" ON)
|
||||
#option(LLAMA_AVX512 "llama: enable AVX512" OFF)
|
||||
#option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
|
||||
#option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
|
||||
#option(LLAMA_FMA "llama: enable FMA" ON)
|
||||
# in MSVC F16C is implied with AVX2/AVX512
|
||||
#if (NOT MSVC)
|
||||
# option(LLAMA_F16C "llama: enable F16C" ON)
|
||||
#endif()
|
||||
|
||||
# 3rd party libs
|
||||
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
|
||||
option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
|
||||
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
|
||||
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
|
||||
|
||||
#
|
||||
# Compile flags
|
||||
#
|
||||
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_C_STANDARD_REQUIRED true)
|
||||
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
if (NOT MSVC)
|
||||
if (LLAMA_SANITIZE_THREAD)
|
||||
add_compile_options(-fsanitize=thread)
|
||||
link_libraries(-fsanitize=thread)
|
||||
endif()
|
||||
|
||||
if (LLAMA_SANITIZE_ADDRESS)
|
||||
add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
|
||||
link_libraries(-fsanitize=address)
|
||||
endif()
|
||||
|
||||
if (LLAMA_SANITIZE_UNDEFINED)
|
||||
add_compile_options(-fsanitize=undefined)
|
||||
link_libraries(-fsanitize=undefined)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (APPLE AND LLAMA_ACCELERATE)
|
||||
find_library(ACCELERATE_FRAMEWORK Accelerate)
|
||||
if (ACCELERATE_FRAMEWORK)
|
||||
message(STATUS "Accelerate framework found")
|
||||
|
||||
add_compile_definitions(GGML_USE_ACCELERATE)
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
|
||||
else()
|
||||
message(WARNING "Accelerate framework not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (LLAMA_OPENBLAS)
|
||||
if (LLAMA_STATIC)
|
||||
set(BLA_STATIC ON)
|
||||
endif()
|
||||
|
||||
set(BLA_VENDOR OpenBLAS)
|
||||
find_package(BLAS)
|
||||
if (BLAS_FOUND)
|
||||
message(STATUS "OpenBLAS found")
|
||||
|
||||
add_compile_definitions(GGML_USE_OPENBLAS)
|
||||
add_link_options(${BLAS_LIBRARIES})
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas)
|
||||
|
||||
# find header file
|
||||
set(OPENBLAS_INCLUDE_SEARCH_PATHS
|
||||
/usr/include
|
||||
/usr/include/openblas
|
||||
/usr/include/openblas-base
|
||||
/usr/local/include
|
||||
/usr/local/include/openblas
|
||||
/usr/local/include/openblas-base
|
||||
/opt/OpenBLAS/include
|
||||
$ENV{OpenBLAS_HOME}
|
||||
$ENV{OpenBLAS_HOME}/include
|
||||
)
|
||||
find_path(OPENBLAS_INC NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
|
||||
add_compile_options(-I${OPENBLAS_INC})
|
||||
else()
|
||||
message(WARNING "OpenBLAS not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (LLAMA_ALL_WARNINGS)
|
||||
if (NOT MSVC)
|
||||
set(c_flags
|
||||
-Wall
|
||||
-Wextra
|
||||
-Wpedantic
|
||||
-Wcast-qual
|
||||
-Wdouble-promotion
|
||||
-Wshadow
|
||||
-Wstrict-prototypes
|
||||
-Wpointer-arith
|
||||
)
|
||||
set(cxx_flags
|
||||
-Wall
|
||||
-Wextra
|
||||
-Wpedantic
|
||||
-Wcast-qual
|
||||
-Wno-unused-function
|
||||
-Wno-multichar
|
||||
)
|
||||
else()
|
||||
# todo : msvc
|
||||
endif()
|
||||
|
||||
add_compile_options(
|
||||
"$<$<COMPILE_LANGUAGE:C>:${c_flags}>"
|
||||
"$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags}>"
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
if (MSVC)
|
||||
add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (LLAMA_LTO)
|
||||
include(CheckIPOSupported)
|
||||
check_ipo_supported(RESULT result OUTPUT output)
|
||||
if (result)
|
||||
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
|
||||
else()
|
||||
message(WARNING "IPO is not supported: ${output}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Architecture specific
|
||||
# TODO: probably these flags need to be tweaked on some architectures
|
||||
# feel free to update the Makefile for your architecture and send a pull request or issue
|
||||
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
if (NOT MSVC)
|
||||
if (LLAMA_STATIC)
|
||||
add_link_options(-static)
|
||||
if (MINGW)
|
||||
add_link_options(-static-libgcc -static-libstdc++)
|
||||
endif()
|
||||
endif()
|
||||
if (LLAMA_GPROF)
|
||||
add_compile_options(-pg)
|
||||
endif()
|
||||
if (LLAMA_NATIVE)
|
||||
add_compile_options(-march=native)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
function(include_ggml DIRECTORY SUFFIX WITH_LLAMA)
|
||||
message(STATUS "Configuring ggml implementation target llama${SUFFIX} in ${CMAKE_CURRENT_SOURCE_DIR}/${DIRECTORY}")
|
||||
|
||||
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
|
||||
message(STATUS "ARM detected")
|
||||
if (MSVC)
|
||||
# TODO: arm msvc?
|
||||
else()
|
||||
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
|
||||
add_compile_options(-mcpu=native)
|
||||
endif()
|
||||
# TODO: armv6,7,8 version specific flags
|
||||
endif()
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
|
||||
message(STATUS "x86 detected")
|
||||
if (MSVC)
|
||||
if (LLAMA_AVX512)
|
||||
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
|
||||
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
|
||||
# MSVC has no compile-time flags enabling specific
|
||||
# AVX512 extensions, neither it defines the
|
||||
# macros corresponding to the extensions.
|
||||
# Do it manually.
|
||||
if (LLAMA_AVX512_VBMI)
|
||||
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
|
||||
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
|
||||
endif()
|
||||
if (LLAMA_AVX512_VNNI)
|
||||
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
|
||||
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
|
||||
endif()
|
||||
elseif (LLAMA_AVX2)
|
||||
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
|
||||
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
|
||||
elseif (LLAMA_AVX)
|
||||
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
|
||||
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
|
||||
endif()
|
||||
else()
|
||||
if (LLAMA_F16C)
|
||||
add_compile_options(-mf16c)
|
||||
endif()
|
||||
if (LLAMA_FMA)
|
||||
add_compile_options(-mfma)
|
||||
endif()
|
||||
if (LLAMA_AVX)
|
||||
add_compile_options(-mavx)
|
||||
endif()
|
||||
if (LLAMA_AVX2)
|
||||
add_compile_options(-mavx2)
|
||||
endif()
|
||||
if (LLAMA_AVX512)
|
||||
add_compile_options(-mavx512f)
|
||||
add_compile_options(-mavx512bw)
|
||||
endif()
|
||||
if (LLAMA_AVX512_VBMI)
|
||||
add_compile_options(-mavx512vbmi)
|
||||
endif()
|
||||
if (LLAMA_AVX512_VNNI)
|
||||
add_compile_options(-mavx512vnni)
|
||||
endif()
|
||||
endif()
|
||||
else()
|
||||
# TODO: support PowerPC
|
||||
message(STATUS "Unknown architecture")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Build libraries
|
||||
#
|
||||
|
||||
if (LLAMA_CUBLAS AND EXISTS ${DIRECTORY}/ggml-cuda.h)
|
||||
cmake_minimum_required(VERSION 3.17)
|
||||
|
||||
find_package(CUDAToolkit)
|
||||
if (CUDAToolkit_FOUND)
|
||||
message(STATUS "cuBLAS found")
|
||||
|
||||
enable_language(CUDA)
|
||||
|
||||
set(GGML_CUDA_SOURCES ${DIRECTORY}/ggml-cuda.cu ${DIRECTORY}/ggml-cuda.h)
|
||||
|
||||
add_compile_definitions(GGML_USE_CUBLAS)
|
||||
|
||||
if (LLAMA_STATIC)
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
||||
else()
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
|
||||
endif()
|
||||
|
||||
else()
|
||||
message(WARNING "cuBLAS not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (LLAMA_CLBLAST AND EXISTS ${DIRECTORY}/ggml-opencl.h)
|
||||
find_package(CLBlast)
|
||||
if (CLBlast_FOUND)
|
||||
message(STATUS "CLBlast found")
|
||||
|
||||
set(GGML_OPENCL_SOURCES ${DIRECTORY}/ggml-opencl.c ${DIRECTORY}/ggml-opencl.h)
|
||||
|
||||
add_compile_definitions(GGML_USE_CLBLAST)
|
||||
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast)
|
||||
else()
|
||||
message(WARNING "CLBlast not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_library(ggml${SUFFIX} OBJECT
|
||||
${DIRECTORY}/ggml.c
|
||||
${DIRECTORY}/ggml.h
|
||||
${GGML_CUDA_SOURCES}
|
||||
${GGML_OPENCL_SOURCES})
|
||||
|
||||
target_include_directories(ggml${SUFFIX} PUBLIC ${DIRECTORY})
|
||||
target_compile_features(ggml${SUFFIX} PUBLIC c_std_11) # don't bump
|
||||
target_link_libraries(ggml${SUFFIX} PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(ggml${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
|
||||
if (WITH_LLAMA)
|
||||
# Backwards compatibility with old llama.cpp versions
|
||||
set(LLAMA_UTIL_SOURCE_FILE llama-util.h)
|
||||
if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${DIRECTORY}/${LLAMA_UTIL_SOURCE_FILE})
|
||||
set(LLAMA_UTIL_SOURCE_FILE llama_util.h)
|
||||
endif()
|
||||
|
||||
add_library(llama${SUFFIX}
|
||||
${DIRECTORY}/llama.cpp
|
||||
${DIRECTORY}/llama.h
|
||||
${DIRECTORY}/${LLAMA_UTIL_SOURCE_FILE})
|
||||
|
||||
target_include_directories(llama${SUFFIX} PUBLIC ${DIRECTORY})
|
||||
target_compile_features(llama${SUFFIX} PUBLIC cxx_std_11) # don't bump
|
||||
target_link_libraries(llama${SUFFIX} PRIVATE ggml${SUFFIX} ${LLAMA_EXTRA_LIBS})
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(llama${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(llama${SUFFIX} PRIVATE LLAMA_SHARED LLAMA_BUILD)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_SOURCES)
|
||||
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
|
||||
set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF)
|
||||
set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
|
||||
if (WITH_LLAMA)
|
||||
set_property(TARGET llama${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF)
|
||||
endif()
|
||||
endif()
|
||||
endfunction()
|
@ -1,8 +1,5 @@
|
||||
#include "llamamodel.h"
|
||||
|
||||
#include "llama.cpp/examples/common.h"
|
||||
#include "llama.cpp/llama.h"
|
||||
#include "llama.cpp/ggml.h"
|
||||
#define LLAMAMODEL_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
|
||||
#include "llamamodel_impl.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
@ -28,16 +25,77 @@
|
||||
#include <thread>
|
||||
#include <unordered_set>
|
||||
|
||||
#include <llama.h>
|
||||
#include <ggml.h>
|
||||
|
||||
|
||||
namespace {
|
||||
const char *modelType_ = "LLaMA";
|
||||
}
|
||||
|
||||
struct gpt_params {
|
||||
int32_t seed = -1; // RNG seed
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
#if LLAMA_DATE <= 230511
|
||||
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
|
||||
#endif
|
||||
|
||||
#if LLAMA_DATE >= 230519
|
||||
// sampling parameters
|
||||
float tfs_z = 1.0f; // 1.0 = disabled
|
||||
float typical_p = 1.0f; // 1.0 = disabled
|
||||
#endif
|
||||
|
||||
std::string prompt = "";
|
||||
|
||||
bool memory_f16 = true; // use f16 instead of f32 for memory kv
|
||||
|
||||
bool use_mmap = true; // use mmap for faster loads
|
||||
bool use_mlock = false; // use mlock to keep model in memory
|
||||
};
|
||||
|
||||
#if LLAMA_DATE >= 230519
|
||||
static int llama_sample_top_p_top_k(
|
||||
llama_context *ctx,
|
||||
const llama_token *last_n_tokens_data,
|
||||
int last_n_tokens_size,
|
||||
int top_k,
|
||||
float top_p,
|
||||
float temp,
|
||||
float repeat_penalty) {
|
||||
auto logits = llama_get_logits(ctx);
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
// Populate initial list of all candidates
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (int token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
|
||||
// Sample repeat penalty
|
||||
llama_sample_repetition_penalty(nullptr, &candidates_p, last_n_tokens_data, last_n_tokens_size, repeat_penalty);
|
||||
// Temperature sampling
|
||||
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
|
||||
llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1);
|
||||
llama_sample_typical(ctx, &candidates_p, 1.0f, 1);
|
||||
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
return llama_sample_token(ctx, &candidates_p);
|
||||
}
|
||||
#endif
|
||||
|
||||
struct LLamaPrivate {
|
||||
const std::string modelPath;
|
||||
bool modelLoaded;
|
||||
llama_context *ctx = nullptr;
|
||||
llama_context_params params;
|
||||
int64_t n_threads = 0;
|
||||
bool empty = true;
|
||||
};
|
||||
|
||||
LLamaModel::LLamaModel()
|
||||
: d_ptr(new LLamaPrivate) {
|
||||
modelType = modelType_;
|
||||
|
||||
d_ptr->modelLoaded = false;
|
||||
}
|
||||
@ -49,14 +107,12 @@ bool LLamaModel::loadModel(const std::string &modelPath)
|
||||
|
||||
gpt_params params;
|
||||
d_ptr->params.n_ctx = 2048;
|
||||
d_ptr->params.n_parts = params.n_parts;
|
||||
d_ptr->params.seed = params.seed;
|
||||
d_ptr->params.f16_kv = params.memory_f16;
|
||||
d_ptr->params.use_mmap = params.use_mmap;
|
||||
#if defined (__APPLE__)
|
||||
d_ptr->params.use_mlock = true;
|
||||
#else
|
||||
d_ptr->params.use_mlock = params.use_mlock;
|
||||
#if LLAMA_DATE <= 230511
|
||||
d_ptr->params.n_parts = params.n_parts;
|
||||
#endif
|
||||
|
||||
d_ptr->ctx = llama_init_from_file(modelPath.c_str(), d_ptr->params);
|
||||
@ -75,8 +131,7 @@ void LLamaModel::setThreadCount(int32_t n_threads) {
|
||||
d_ptr->n_threads = n_threads;
|
||||
}
|
||||
|
||||
int32_t LLamaModel::threadCount() const
|
||||
{
|
||||
int32_t LLamaModel::threadCount() const {
|
||||
return d_ptr->n_threads;
|
||||
}
|
||||
|
||||
@ -102,7 +157,8 @@ size_t LLamaModel::saveState(uint8_t *dest) const
|
||||
|
||||
size_t LLamaModel::restoreState(const uint8_t *src)
|
||||
{
|
||||
return llama_set_state_data(d_ptr->ctx, src);
|
||||
// const_cast is required, see: https://github.com/ggerganov/llama.cpp/pull/1540
|
||||
return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
|
||||
}
|
||||
|
||||
void LLamaModel::prompt(const std::string &prompt,
|
||||
@ -123,7 +179,11 @@ void LLamaModel::prompt(const std::string &prompt,
|
||||
params.prompt.insert(0, 1, ' ');
|
||||
|
||||
// tokenize the prompt
|
||||
auto embd_inp = ::llama_tokenize(d_ptr->ctx, params.prompt, false);
|
||||
std::vector<llama_token> embd_inp(params.prompt.size() + 4);
|
||||
int n = llama_tokenize(d_ptr->ctx, params.prompt.c_str(), embd_inp.data(), embd_inp.size(), d_ptr->empty);
|
||||
assert(n >= 0);
|
||||
embd_inp.resize(n);
|
||||
d_ptr->empty = false;
|
||||
|
||||
// save the context size
|
||||
promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx);
|
||||
@ -143,20 +203,19 @@ void LLamaModel::prompt(const std::string &prompt,
|
||||
|
||||
// process the prompt in batches
|
||||
size_t i = 0;
|
||||
const int64_t t_start_prompt_us = ggml_time_us();
|
||||
while (i < embd_inp.size()) {
|
||||
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
|
||||
std::vector<llama_token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
|
||||
|
||||
// Check if the context has run out...
|
||||
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
|
||||
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
|
||||
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
|
||||
// Erase the first percentage of context from the tokens...
|
||||
std::cerr << "LLAMA: reached the end of the context window so resizing\n";
|
||||
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
|
||||
promptCtx.n_past = promptCtx.tokens.size();
|
||||
recalculateContext(promptCtx, recalculateCallback);
|
||||
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
|
||||
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
|
||||
}
|
||||
|
||||
if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
|
||||
@ -166,7 +225,7 @@ void LLamaModel::prompt(const std::string &prompt,
|
||||
|
||||
size_t tokens = batch_end - i;
|
||||
for (size_t t = 0; t < tokens; ++t) {
|
||||
if (promptCtx.tokens.size() == promptCtx.n_ctx)
|
||||
if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
|
||||
promptCtx.tokens.erase(promptCtx.tokens.begin());
|
||||
promptCtx.tokens.push_back(batch.at(t));
|
||||
if (!promptCallback(batch.at(t)))
|
||||
@ -179,10 +238,9 @@ void LLamaModel::prompt(const std::string &prompt,
|
||||
std::string cachedResponse;
|
||||
std::vector<llama_token> cachedTokens;
|
||||
std::unordered_set<std::string> reversePrompts
|
||||
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" };
|
||||
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" };
|
||||
|
||||
// predict next tokens
|
||||
int32_t totalPredictions = 0;
|
||||
for (int i = 0; i < promptCtx.n_predict; i++) {
|
||||
// sample next token
|
||||
const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
|
||||
@ -209,7 +267,6 @@ void LLamaModel::prompt(const std::string &prompt,
|
||||
|
||||
promptCtx.n_past += 1;
|
||||
// display text
|
||||
++totalPredictions;
|
||||
if (id == llama_token_eos())
|
||||
return;
|
||||
|
||||
@ -240,7 +297,7 @@ void LLamaModel::prompt(const std::string &prompt,
|
||||
|
||||
// Empty the cache
|
||||
for (auto t : cachedTokens) {
|
||||
if (promptCtx.tokens.size() == promptCtx.n_ctx)
|
||||
if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
|
||||
promptCtx.tokens.erase(promptCtx.tokens.begin());
|
||||
promptCtx.tokens.push_back(t);
|
||||
if (!responseCallback(t, llama_token_to_str(d_ptr->ctx, t)))
|
||||
@ -258,7 +315,7 @@ void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function<bool
|
||||
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
|
||||
std::vector<llama_token> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
|
||||
|
||||
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
|
||||
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
|
||||
|
||||
if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
|
||||
std::cerr << "LLAMA ERROR: Failed to process prompt\n";
|
||||
@ -269,8 +326,43 @@ void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function<bool
|
||||
goto stop_generating;
|
||||
i = batch_end;
|
||||
}
|
||||
assert(promptCtx.n_past == promptCtx.tokens.size());
|
||||
assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
|
||||
|
||||
stop_generating:
|
||||
recalculate(false);
|
||||
}
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define DLL_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define DLL_EXPORT __attribute__ ((visibility ("default")))
|
||||
#endif
|
||||
|
||||
extern "C" {
|
||||
DLL_EXPORT bool is_g4a_backend_model_implementation() {
|
||||
return true;
|
||||
}
|
||||
|
||||
DLL_EXPORT const char *get_model_type() {
|
||||
return modelType_;
|
||||
}
|
||||
|
||||
DLL_EXPORT const char *get_build_variant() {
|
||||
return GGML_BUILD_VARIANT;
|
||||
}
|
||||
|
||||
DLL_EXPORT bool magic_match(std::istream& f) {
|
||||
// Check magic
|
||||
uint32_t magic = 0;
|
||||
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||
if (magic != 0x67676a74) return false;
|
||||
// Check version
|
||||
uint32_t version = 0;
|
||||
f.read(reinterpret_cast<char*>(&version), sizeof(version));
|
||||
return version LLAMA_VERSIONS;
|
||||
}
|
||||
|
||||
DLL_EXPORT LLModel *construct() {
|
||||
return new LLamaModel;
|
||||
}
|
||||
}
|
||||
|
@@ -1,3 +1,6 @@
#ifndef LLAMAMODEL_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#error This file is NOT meant to be included outside of llamamodel.cpp. Doing so is DANGEROUS. Be sure to know what you are doing before proceeding to #define LLAMAMODEL_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#endif
#ifndef LLAMAMODEL_H
#define LLAMAMODEL_H

@@ -6,7 +9,7 @@
#include <vector>
#include "llmodel.h"

class LLamaPrivate;
struct LLamaPrivate;
class LLamaModel : public LLModel {
public:
    LLamaModel();
@@ -33,4 +36,4 @@ private:
    LLamaPrivate *d_ptr;
};

#endif // LLAMAMODEL_H
#endif // LLAMAMODEL_H
90	gpt4all-backend/llmodel.cpp (new file)
@@ -0,0 +1,90 @@
#include "llmodel.h"
#include "dlhandle.h"

#include <string>
#include <vector>
#include <fstream>
#include <filesystem>

static Dlhandle *get_implementation(std::ifstream& f, const std::string& buildVariant) {
    // Collect all model implementation libraries
    // NOTE: allocated on heap so we leak intentionally on exit so we have a chance to clean up the
    // individual models without the cleanup of the static list interfering
    static auto* libs = new std::vector<Dlhandle>([] () {
        std::vector<Dlhandle> fres;

        auto search_in_directory = [&](const std::filesystem::path& path) {
            // Iterate over all libraries
            for (const auto& f : std::filesystem::directory_iterator(path)) {
                // Get path
                const auto& p = f.path();
                // Check extension
                if (p.extension() != LIB_FILE_EXT) continue;
                // Add to list if model implementation
                try {
                    Dlhandle dl(p.string());
                    if (dl.get<bool(uint32_t)>("is_g4a_backend_model_implementation")) {
                        fres.emplace_back(std::move(dl));
                    }
                } catch (...) {}
            }
        };

        search_in_directory(".");
#if defined(__APPLE__)
        search_in_directory("../../../");
#endif
        return fres;
    }());
    // Iterate over all libraries
    for (auto& dl : *libs) {
        f.seekg(0);
        // Check that magic matches
        auto magic_match = dl.get<bool(std::ifstream&)>("magic_match");
        if (!magic_match || !magic_match(f)) {
            continue;
        }
        // Check that build variant is correct
        auto get_build_variant = dl.get<const char *()>("get_build_variant");
        if (buildVariant != (get_build_variant?get_build_variant():"default")) {
            continue;
        }
        // Looks like we're good to go, return this dlhandle
        return &dl;
    }
    // Nothing found, so return nothing
    return nullptr;
}

static bool requires_avxonly() {
#ifdef __x86_64__
    return !__builtin_cpu_supports("avx2") && !__builtin_cpu_supports("fma");
#else
    return false; // Don't know how to handle ARM
#endif
}

LLModel *LLModel::construct(const std::string &modelPath, std::string buildVariant) {
    //TODO: Auto-detect
    if (buildVariant == "auto") {
        if (requires_avxonly()) {
            buildVariant = "avxonly";
        } else {
            buildVariant = "default";
        }
    }
    // Read magic
    std::ifstream f(modelPath, std::ios::binary);
    if (!f) return nullptr;
    // Get correct implementation
    auto impl = get_implementation(f, buildVariant);
    if (!impl) return nullptr;
    f.close();
    // Get inference constructor
    auto constructor = impl->get<LLModel *()>("construct");
    if (!constructor) return nullptr;
    // Construct llmodel implementation
    auto fres = constructor();
    // Return final instance
    return fres;
}
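In other words, construct() only scans the working directory for plugin libraries, asks each one's magic_match() whether it recognises the file header, filters by build variant, and calls the winning library's construct() export; the weights themselves are not read until the caller invokes loadModel(). A hedged usage sketch (the model path is a placeholder):

#include "llmodel.h"
#include <iostream>

int main(int argc, char **argv) {
    const std::string path = argc > 1 ? argv[1] : "ggml-model.bin";  // hypothetical file
    // "auto" resolves to "avxonly" on x86-64 CPUs without AVX2/FMA, "default" otherwise.
    LLModel *model = LLModel::construct(path, "auto");
    if (!model) {
        std::cerr << "no backend implementation recognised this file\n";
        return 1;
    }
    std::cout << "dispatched to: " << model->getModelType() << '\n';
    if (!model->loadModel(path)) {
        std::cerr << "failed to load model weights\n";
        delete model;
        return 1;
    }
    // ... prompt the model here ...
    delete model;  // the caller owns the returned instance
    return 0;
}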
@@ -1,21 +1,23 @@
#ifndef LLMODEL_H
#define LLMODEL_H

#include <string>
#include <functional>
#include <vector>
#include <cstdint>


class LLModel {
public:
    explicit LLModel() {}
    virtual ~LLModel() {}

    static LLModel *construct(const std::string &modelPath, std::string buildVariant = "default");

    virtual bool loadModel(const std::string &modelPath) = 0;
    virtual bool isModelLoaded() const = 0;
    virtual size_t stateSize() const { return 0; }
    virtual size_t saveState(uint8_t *dest) const { return 0; }
    virtual size_t restoreState(const uint8_t *src) { return 0; }
    virtual size_t saveState(uint8_t */*dest*/) const { return 0; }
    virtual size_t restoreState(const uint8_t */*src*/) { return 0; }
    struct PromptContext {
        std::vector<float> logits;      // logits of current context
        std::vector<int32_t> tokens;    // current tokens in the context window
@@ -36,12 +38,18 @@ public:
        std::function<bool(int32_t, const std::string&)> responseCallback,
        std::function<bool(bool)> recalculateCallback,
        PromptContext &ctx) = 0;
    virtual void setThreadCount(int32_t n_threads) {}
    virtual void setThreadCount(int32_t /*n_threads*/) {}
    virtual int32_t threadCount() const { return 1; }

    const char *getModelType() const {
        return modelType;
    }

protected:
    virtual void recalculateContext(PromptContext &promptCtx,
        std::function<bool(bool)> recalculate) = 0;

    const char *modelType;
};

#endif // LLMODEL_H
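The prompt() entry point takes three callbacks (prompt-batch progress, generated tokens, and context-recalculation progress) plus a PromptContext carrying the sampling settings and the rolling token window. A sketch of driving it with lambdas, assuming the first callback parameter is the prompt-progress callback as the GPTJ implementation's calls suggest; the field values are illustrative, not defaults prescribed by this commit.

#include "llmodel.h"
#include <iostream>

void ask(LLModel &model, const std::string &question) {
    LLModel::PromptContext ctx;
    ctx.n_predict      = 128;   // illustrative sampling settings
    ctx.top_k          = 40;
    ctx.top_p          = 0.9f;
    ctx.temp           = 0.7f;
    ctx.n_batch        = 8;
    ctx.repeat_penalty = 1.1f;
    ctx.repeat_last_n  = 64;
    ctx.contextErase   = 0.5f;

    model.prompt(question,
        [](int32_t /*token_id*/) { return true; },            // prompt batch processed
        [](int32_t /*token_id*/, const std::string &piece) {  // generated text
            std::cout << piece << std::flush;
            return true;                                      // return false to stop early
        },
        [](bool /*is_recalculating*/) { return true; },       // context recalculation
        ctx);
}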
@@ -1,79 +1,57 @@
#include "llmodel_c.h"
#include "llmodel.h"

#include <cstring>
#include <cerrno>
#include <utility>

#include "gptj.h"
#include "llamamodel.h"
#include "mpt.h"

struct LLModelWrapper {
    LLModel *llModel = nullptr;
    LLModel::PromptContext promptContext;
};

llmodel_model llmodel_gptj_create()
{
    LLModelWrapper *wrapper = new LLModelWrapper;
    wrapper->llModel = new GPTJ;
    return reinterpret_cast<void*>(wrapper);
}

void llmodel_gptj_destroy(llmodel_model gptj)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(gptj);
    delete wrapper->llModel;
    delete wrapper;
}
thread_local static std::string last_error_message;

llmodel_model llmodel_mpt_create()
{
    LLModelWrapper *wrapper = new LLModelWrapper;
    wrapper->llModel = new MPT;
    return reinterpret_cast<void*>(wrapper);
}

void llmodel_mpt_destroy(llmodel_model mpt)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(mpt);
    delete wrapper->llModel;
    delete wrapper;
}

llmodel_model llmodel_llama_create()
{
    LLModelWrapper *wrapper = new LLModelWrapper;
    wrapper->llModel = new LLamaModel;
    return reinterpret_cast<void*>(wrapper);
}

void llmodel_llama_destroy(llmodel_model llama)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(llama);
    delete wrapper->llModel;
    delete wrapper;
}

llmodel_model llmodel_model_create(const char *model_path) {
    auto fres = llmodel_model_create2(model_path, "auto", nullptr);
    if (!fres) {
        fprintf(stderr, "Invalid model file\n");
    }
    return fres;
}

    uint32_t magic;
    llmodel_model model;
    FILE *f = fopen(model_path, "rb");
    fread(&magic, sizeof(magic), 1, f);
llmodel_model llmodel_model_create2(const char *model_path, const char *build_variant, llmodel_error *error) {
    auto wrapper = new LLModelWrapper;
    llmodel_error new_error{};

    if (magic == 0x67676d6c) { model = llmodel_gptj_create(); }
    else if (magic == 0x67676a74) { model = llmodel_llama_create(); }
    else if (magic == 0x67676d6d) { model = llmodel_mpt_create(); }
    else {fprintf(stderr, "Invalid model file\n");}
    fclose(f);
    return model;
    try {
        wrapper->llModel = LLModel::construct(model_path, build_variant);
    } catch (const std::exception& e) {
        new_error.code = EINVAL;
        last_error_message = e.what();
    }

    if (!wrapper->llModel) {
        delete std::exchange(wrapper, nullptr);
        // Get errno and error message if none
        if (new_error.code == 0) {
            new_error.code = errno;
            last_error_message = strerror(errno);
        }
        // Set message pointer
        new_error.message = last_error_message.c_str();
        // Set error argument
        if (error) *error = new_error;
    }
    return reinterpret_cast<llmodel_model*>(wrapper);
}

void llmodel_model_destroy(llmodel_model model) {

    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    const std::type_info &modelTypeInfo = typeid(*wrapper->llModel);

    if (modelTypeInfo == typeid(GPTJ)) { llmodel_gptj_destroy(model); }
    if (modelTypeInfo == typeid(LLamaModel)) { llmodel_llama_destroy(model); }
    if (modelTypeInfo == typeid(MPT)) { llmodel_mpt_destroy(model); }
    delete wrapper->llModel;
}

bool llmodel_loadModel(llmodel_model model, const char *model_path)
@@ -84,20 +62,20 @@ bool llmodel_loadModel(llmodel_model model, const char *model_path)

bool llmodel_isModelLoaded(llmodel_model model)
{
    const auto *llm = reinterpret_cast<LLModelWrapper*>(model)->llModel;
    return llm->isModelLoaded();
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->isModelLoaded();
}

uint64_t llmodel_get_state_size(llmodel_model model)
{
    const auto *llm = reinterpret_cast<LLModelWrapper*>(model)->llModel;
    return llm->stateSize();
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->stateSize();
}

uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest)
{
    const auto *llm = reinterpret_cast<LLModelWrapper*>(model)->llModel;
    return llm->saveState(dest);
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->saveState(dest);
}

uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
@@ -181,6 +159,6 @@ void llmodel_setThreadCount(llmodel_model model, int32_t n_threads)

int32_t llmodel_threadCount(llmodel_model model)
{
    const auto *llm = reinterpret_cast<LLModelWrapper*>(model)->llModel;
    return llm->threadCount();
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->threadCount();
}
@ -5,6 +5,15 @@
|
||||
#include <stddef.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifdef __GNUC__
|
||||
#define DEPRECATED __attribute__ ((deprecated))
|
||||
#elif defined(_MSC_VER)
|
||||
#define DEPRECATED __declspec(deprecated)
|
||||
#else
|
||||
#pragma message("WARNING: You need to implement DEPRECATED for this compiler")
|
||||
#define DEPRECATED
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@ -14,13 +23,24 @@ extern "C" {
|
||||
*/
|
||||
typedef void *llmodel_model;
|
||||
|
||||
/**
|
||||
* Structure containing any errors that may eventually occur
|
||||
*/
|
||||
struct llmodel_error {
|
||||
const char *message; // Human readable error description; Thread-local; guaranteed to survive until next llmodel C API call
|
||||
int code; // errno; 0 if none
|
||||
};
|
||||
#ifndef __cplusplus
|
||||
typedef struct llmodel_error llmodel_error;
|
||||
#endif
|
||||
|
||||
/**
|
||||
* llmodel_prompt_context structure for holding the prompt context.
|
||||
* NOTE: The implementation takes care of all the memory handling of the raw logits pointer and the
|
||||
* raw tokens pointer. Attempting to resize them or modify them in any way can lead to undefined
|
||||
* behavior.
|
||||
*/
|
||||
typedef struct {
|
||||
struct llmodel_prompt_context {
|
||||
float *logits; // logits of current context
|
||||
size_t logits_size; // the size of the raw logits vector
|
||||
int32_t *tokens; // current tokens in the context window
|
||||
@ -35,7 +55,10 @@ typedef struct {
|
||||
float repeat_penalty; // penalty factor for repeated tokens
|
||||
int32_t repeat_last_n; // last n tokens to penalize
|
||||
float context_erase; // percent of context to erase if we exceed the context window
|
||||
} llmodel_prompt_context;
|
||||
};
|
||||
#ifndef __cplusplus
|
||||
typedef struct llmodel_prompt_context llmodel_prompt_context;
|
||||
#endif

/**
* Callback type for prompt processing.
@ -60,48 +83,22 @@ typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response
typedef bool (*llmodel_recalculate_callback)(bool is_recalculating);

/**
* Create a GPTJ instance.
* @return A pointer to the GPTJ instance.
* Create a llmodel instance.
* Recognises correct model type from file at model_path
* @param model_path A string representing the path to the model file.
* @return A pointer to the llmodel_model instance; NULL on error.
*/
llmodel_model llmodel_gptj_create();

/**
* Destroy a GPTJ instance.
* @param gptj A pointer to the GPTJ instance.
*/
void llmodel_gptj_destroy(llmodel_model gptj);

/**
* Create a MPT instance.
* @return A pointer to the MPT instance.
*/
llmodel_model llmodel_mpt_create();

/**
* Destroy a MPT instance.
* @param gptj A pointer to the MPT instance.
*/
void llmodel_mpt_destroy(llmodel_model mpt);

/**
* Create a LLAMA instance.
* @return A pointer to the LLAMA instance.
*/
llmodel_model llmodel_llama_create();

/**
* Destroy a LLAMA instance.
* @param llama A pointer to the LLAMA instance.
*/
void llmodel_llama_destroy(llmodel_model llama);
DEPRECATED llmodel_model llmodel_model_create(const char *model_path);

/**
* Create a llmodel instance.
* Recognises correct model type from file at model_path
* @param model_path A string representing the path to the model file.
* @return A pointer to the llmodel_model instance.
* @param model_path A string representing the path to the model file; will only be used to detect model type.
* @param build_variant A string representing the implementation to use (auto, default, avxonly, ...),
* @param error A pointer to a llmodel_error; will only be set on error.
* @return A pointer to the llmodel_model instance; NULL on error.
*/
llmodel_model llmodel_model_create(const char *model_path);
llmodel_model llmodel_model_create2(const char *model_path, const char *build_variant, llmodel_error *error);

/**
* Destroy a llmodel instance.
@ -110,7 +107,6 @@ llmodel_model llmodel_model_create(const char *model_path);
*/
void llmodel_model_destroy(llmodel_model model);


/**
* Load a model from a file.
* @param model A pointer to the llmodel_model instance.
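
Taken together, the error struct and llmodel_model_create2 replace the per-architecture constructors deprecated above. A minimal caller sketch, hedged rather than taken from this commit: the header name is assumed, and the llmodel_loadModel call comes from the rest of this header, which is outside this hunk:

    #include "llmodel_c.h"   // assumed header name for this C API
    #include <stdio.h>

    int main(int argc, char **argv) {
        llmodel_error err = { NULL, 0 };
        /* "auto" lets the backend pick a build variant (default, avxonly, ...) for this machine. */
        llmodel_model model = llmodel_model_create2(argv[1], "auto", &err);
        if (model == NULL) {
            fprintf(stderr, "create failed: %s (errno %d)\n",
                    err.message ? err.message : "unknown", err.code);
            return 1;
        }
        /* llmodel_loadModel is declared further down this header (not shown in this hunk). */
        if (!llmodel_loadModel(model, argv[1])) {
            fprintf(stderr, "failed to load %s\n", argv[1]);
            llmodel_model_destroy(model);
            return 1;
        }
        llmodel_model_destroy(model);
        return 0;
    }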

@ -1,5 +1,5 @@
#include "mpt.h"
#include "llama.cpp/ggml.h"
#define MPT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#include "mpt_impl.h"

#include "utils.h"

@ -28,8 +28,14 @@
#include <thread>
#include <unordered_set>
#include <regex>
#include <ggml.h>


namespace {
const char *modelType_ = "MPT";

static const size_t MB = 1024*1024;
}

// default hparams (MPT 7B)
struct mpt_hparams {
@ -293,7 +299,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod

const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
const int expand = hparams.expand;

@ -331,14 +336,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
// key + value memory
{
const auto & hparams = model.hparams;

const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;

const int n_mem = n_layer*n_ctx;
const int n_elements = n_embd*n_mem;

if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
ggml_free(ctx);
@ -457,9 +454,6 @@ bool mpt_eval(
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
const int expand = hparams.expand;

const int d_key = n_embd/n_head;

const size_t init_buf_size = 1024u*MB;
if (!model.buf.addr || model.buf.size < init_buf_size)
@ -480,10 +474,12 @@ bool mpt_eval(
struct ggml_init_params params = {
.mem_size = model.buf.size,
.mem_buffer = model.buf.addr,
.no_alloc = false
};

struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = { .n_threads = n_threads };
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
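
The graph is now value-initialized and n_threads assigned afterwards instead of using a partial designated initializer. The same pattern, shown on a hypothetical struct purely for illustration (not ggml code):

    struct graph_opts { int n_nodes; int n_threads; };

    // C-style partial designated init:  struct graph_opts g = { .n_threads = 4 };
    // The equivalent used above:
    graph_opts g = {};   // zero-initialize every member
    g.n_threads = 4;     // then set the one field that matters here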

struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@ -695,8 +691,7 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
}

const size_t written = out - dest;
const size_t expected = mpt_get_state_size(model);
assert(written == expected);
assert(written == mpt_get_state_size(model));
fflush(stdout);
return written;
}
@ -745,8 +740,7 @@ size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *sr
}

const size_t nread = in - src;
const size_t expected = mpt_get_state_size(*model);
assert(nread == expected);
assert(nread == mpt_get_state_size(*model));
fflush(stdout);
return nread;
}
@ -764,6 +758,7 @@ struct MPTPrivate {

MPT::MPT()
: d_ptr(new MPTPrivate) {
modelType = modelType_;

d_ptr->model = new mpt_model;
d_ptr->modelLoaded = false;
@ -833,12 +828,6 @@ void MPT::prompt(const std::string &prompt,
return;
}

const int64_t t_main_start_us = ggml_time_us();

int64_t t_sample_us = 0;
int64_t t_predict_us = 0;
int64_t t_prompt_us = 0;

// tokenize the prompt
std::vector<int> embd_inp = gpt_tokenize(d_ptr->vocab, prompt);

@ -867,20 +856,19 @@ void MPT::prompt(const std::string &prompt,

// process the prompt in batches
size_t i = 0;
const int64_t t_start_prompt_us = ggml_time_us();
while (i < embd_inp.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
std::vector<int> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);

// Check if the context has run out...
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << "MPT: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
}
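
This overflow branch is the sliding-window policy in one place: once n_past plus the incoming batch would exceed n_ctx, the oldest context_erase fraction of the cached tokens is dropped and the remainder is re-evaluated via recalculateContext. A standalone sketch of the same arithmetic, with illustrative names rather than code from this commit:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Drop the oldest `contextErase` fraction of tokens so a new batch fits.
    // Mirrors the branch above; the caller then re-evaluates what is left.
    void shrink_context(std::vector<int> &tokens, int32_t &n_past,
                        int32_t n_ctx, float contextErase)
    {
        const auto erasePoint = size_t(n_ctx * contextErase);
        tokens.erase(tokens.begin(), tokens.begin() + std::min(erasePoint, tokens.size()));
        n_past = int32_t(tokens.size());
    }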

if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
@ -891,7 +879,7 @@ void MPT::prompt(const std::string &prompt,

size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t) {
if (promptCtx.tokens.size() == promptCtx.n_ctx)
if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(batch.at(t));
if (!promptCallback(batch.at(t)))
@ -900,10 +888,6 @@ void MPT::prompt(const std::string &prompt,
promptCtx.n_past += batch.size();
i = batch_end;
}
t_prompt_us += ggml_time_us() - t_start_prompt_us;

int p_instructFound = 0;
int r_instructFound = 0;

std::string cachedResponse;
std::vector<int> cachedTokens;
@ -911,24 +895,20 @@ void MPT::prompt(const std::string &prompt,
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" };

// predict next tokens
int32_t totalPredictions = 0;
for (int i = 0; i < promptCtx.n_predict; i++) {

// sample next token
const int n_vocab = d_ptr->model->hparams.n_vocab;
int id = 0;
{
const int64_t t_start_sample_us = ggml_time_us();
const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
id = gpt_sample_top_k_top_p(n_vocab,
promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
n_prev_toks,
promptCtx.logits,
promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
promptCtx.repeat_penalty,
d_ptr->rng);

t_sample_us += ggml_time_us() - t_start_sample_us;
}

// Check if the context has run out...
@ -942,33 +922,28 @@ void MPT::prompt(const std::string &prompt,
assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
}

const int64_t t_start_predict_us = ggml_time_us();
if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
d_ptr->mem_per_token)) {
std::cerr << "GPT-J ERROR: Failed to predict next token\n";
return;
}
t_predict_us += ggml_time_us() - t_start_predict_us;

promptCtx.n_past += 1;
// display text
++totalPredictions;

// display tex
// mpt-7b-chat has special token for end
if (d_ptr->has_im_end && id == d_ptr->vocab.token_to_id["<|im_end|>"])
goto stop_generating;
return;

if (id == 0 /*end of text*/)
goto stop_generating;
return;

const std::string str = d_ptr->vocab.id_to_token[id];

// Check if the provided str is part of our reverse prompts
bool foundPartialReversePrompt = false;
const std::string completed = cachedResponse + str;
if (reversePrompts.find(completed) != reversePrompts.end()) {
goto stop_generating;
}
if (reversePrompts.find(completed) != reversePrompts.end())
return;

// Check if it partially matches our reverse prompts and if so, cache
for (auto s : reversePrompts) {
@ -988,32 +963,14 @@ void MPT::prompt(const std::string &prompt,

// Empty the cache
for (auto t : cachedTokens) {
if (promptCtx.tokens.size() == promptCtx.n_ctx)
if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(t);
if (!responseCallback(t, d_ptr->vocab.id_to_token[t]))
goto stop_generating;
return;
}
cachedTokens.clear();
}

stop_generating:

#if 0
// report timing
{
const int64_t t_main_end_us = ggml_time_us();

std::cout << "GPT-J INFO: mem per token = " << mem_per_token << " bytes\n";
std::cout << "GPT-J INFO: sample time = " << t_sample_us/1000.0f << " ms\n";
std::cout << "GPT-J INFO: prompt time = " << t_prompt_us/1000.0f << " ms\n";
std::cout << "GPT-J INFO: predict time = " << t_predict_us/1000.0f << " ms / " << t_predict_us/1000.0f/totalPredictions << " ms per token\n";
std::cout << "GPT-J INFO: total time = " << (t_main_end_us - t_main_start_us)/1000.0f << " ms\n";
fflush(stdout);
}
#endif

return;
}

void MPT::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
@ -1024,7 +981,7 @@ void MPT::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)>
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
std::vector<int> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);

assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);

if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
d_ptr->mem_per_token)) {
@ -1036,8 +993,38 @@ void MPT::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)>
goto stop_generating;
i = batch_end;
}
assert(promptCtx.n_past == promptCtx.tokens.size());
assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));

stop_generating:
recalculate(false);
}

#if defined(_WIN32)
#define DLL_EXPORT __declspec(dllexport)
#else
#define DLL_EXPORT __attribute__ ((visibility ("default")))
#endif

extern "C" {
DLL_EXPORT bool is_g4a_backend_model_implementation() {
return true;
}

DLL_EXPORT const char *get_model_type() {
return modelType_;
}

DLL_EXPORT const char *get_build_variant() {
return GGML_BUILD_VARIANT;
}

DLL_EXPORT bool magic_match(std::istream& f) {
uint32_t magic = 0;
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
return magic == 0x67676d6d;
}

DLL_EXPORT LLModel *construct() {
return new MPT;
}
}
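
These extern "C" exports are what make the MPT implementation loadable as a plugin: the application side (the dlopen logic in the backend, not part of this hunk) can probe each library for is_g4a_backend_model_implementation, use magic_match to pair it with a model file, and call construct to obtain an LLModel. A hedged sketch of that consumer side, assuming POSIX dlopen/dlsym and the symbol names shown above:

    #include <dlfcn.h>
    #include <fstream>

    #include "llmodel.h"   // base class declared by the backend (see mpt.h below)

    // Illustrative only: probe one shared library and construct its model if the
    // model file's magic matches. The real lookup and dispatch live in the backend.
    LLModel *try_plugin(const char *libpath, const char *modelfile)
    {
        void *handle = dlopen(libpath, RTLD_NOW | RTLD_LOCAL);
        if (!handle) return nullptr;

        auto is_impl     = reinterpret_cast<bool (*)()>(dlsym(handle, "is_g4a_backend_model_implementation"));
        auto magic_match = reinterpret_cast<bool (*)(std::istream &)>(dlsym(handle, "magic_match"));
        auto construct   = reinterpret_cast<LLModel *(*)()>(dlsym(handle, "construct"));
        if (!is_impl || !magic_match || !construct || !is_impl()) {
            dlclose(handle);
            return nullptr;
        }

        std::ifstream f(modelfile, std::ios::binary);
        if (!magic_match(f)) {            // e.g. MPT's 0x67676d6d magic above
            dlclose(handle);
            return nullptr;
        }
        return construct();               // keep `handle` open for the model's lifetime
    }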

@ -1,3 +1,6 @@
#ifndef MPT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#error This file is NOT meant to be included outside of mpt.cpp. Doing so is DANGEROUS. Be sure to know what you are doing before proceeding to #define MPT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#endif
#ifndef MPT_H
#define MPT_H

@ -6,7 +9,7 @@
#include <vector>
#include "llmodel.h"

class MPTPrivate;
struct MPTPrivate;
class MPT : public LLModel {
public:
MPT();

@ -218,7 +218,6 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
}

gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const size_t actualVocabSize,
const int32_t * last_n_tokens_data,
int last_n_tokens_size,

@ -79,7 +79,6 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
// TODO: not sure if this implementation is correct
//
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const size_t actualVocabSize,
const int32_t * last_n_tokens_data,
int last_n_tokens_size,

@ -55,10 +55,10 @@ get_filename_component(Qt6_ROOT_DIR "${Qt6_ROOT_DIR}/.." ABSOLUTE)
message(STATUS "qmake binary: ${QMAKE_EXECUTABLE}")
message(STATUS "Qt 6 root directory: ${Qt6_ROOT_DIR}")

add_subdirectory(../gpt4all-backend llmodel)

set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

add_subdirectory(../gpt4all-backend llmodel)

qt_add_executable(chat
main.cpp
chat.h chat.cpp
@ -146,7 +146,6 @@ endif()

install(TARGETS chat DESTINATION bin COMPONENT ${COMPONENT_NAME_MAIN})
install(TARGETS llmodel DESTINATION lib COMPONENT ${COMPONENT_NAME_MAIN})
install(TARGETS llama DESTINATION lib COMPONENT ${COMPONENT_NAME_MAIN})

set(CPACK_GENERATOR "IFW")
set(CPACK_VERBATIM_VARIABLES YES)

@ -2,9 +2,7 @@
#include "chat.h"
#include "download.h"
#include "network.h"
#include "../gpt4all-backend/gptj.h"
#include "../gpt4all-backend/llamamodel.h"
#include "../gpt4all-backend/mpt.h"
#include "../gpt4all-backend/llmodel.h"
#include "chatgpt.h"

#include <QCoreApplication>
@ -215,25 +213,15 @@ bool ChatLLM::loadModel(const QString &modelName)
model->setAPIKey(apiKey);
m_modelInfo.model = model;
} else {
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
fin.seekg(0);
fin.close();
const bool isGPTJ = magic == 0x67676d6c;
const bool isMPT = magic == 0x67676d6d;
if (isGPTJ) {
m_modelType = LLModelType::GPTJ_;
m_modelInfo.model = new GPTJ;
m_modelInfo.model->loadModel(filePath.toStdString());
} else if (isMPT) {
m_modelType = LLModelType::MPT_;
m_modelInfo.model = new MPT;
m_modelInfo.model->loadModel(filePath.toStdString());
} else {
m_modelType = LLModelType::LLAMA_;
m_modelInfo.model = new LLamaModel;
m_modelInfo.model = LLModel::construct(filePath.toStdString());
if (m_modelInfo.model) {
m_modelInfo.model->loadModel(filePath.toStdString());
switch (m_modelInfo.model->getModelType()[0]) {
case 'L': m_modelType = LLModelType::LLAMA_; break;
case 'G': m_modelType = LLModelType::GPTJ_; break;
case 'M': m_modelType = LLModelType::MPT_; break;
default: delete std::exchange(m_modelInfo.model, nullptr);
}
}
}
#if defined(DEBUG_MODEL_LOADING)

@ -112,7 +112,7 @@ void Server::start()
);

m_server->route("/v1/completions", QHttpServerRequest::Method::Post,
[=](const QHttpServerRequest &request) {
[this](const QHttpServerRequest &request) {
if (!LLM::globalInstance()->serverEnabled())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
return handleCompletionRequest(request, false);
@ -120,7 +120,7 @@ void Server::start()
);

m_server->route("/v1/chat/completions", QHttpServerRequest::Method::Post,
[=](const QHttpServerRequest &request) {
[this](const QHttpServerRequest &request) {
if (!LLM::globalInstance()->serverEnabled())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
return handleCompletionRequest(request, true);
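
Both route handlers switch from the capture-default [=] to an explicit [this]: the lambdas only need the Server object, and implicit capture of this via [=] is deprecated as of C++20. A minimal illustration on a hypothetical class (not code from this file):

    struct Counter {
        int hits = 0;
        auto handler() {
            // return [=]() { ++hits; };   // captures `this` implicitly; deprecated in C++20
            return [this]() { ++hits; };   // explicit: the lambda borrows `this`, nothing else
        }
    };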