Dlopen backend 5 (#779)

Major change to the backend that allows for pluggable versions of llama.cpp/ggml. This was squashed merged from dlopen_backend_5 where the history is preserved.
This commit is contained in:
AT 2023-05-31 17:04:01 -04:00 committed by GitHub
parent f4a1f7340c
commit 48275d0dcc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 993 additions and 327 deletions

10
.gitmodules vendored
View File

@ -1,3 +1,9 @@
[submodule "llama.cpp"]
path = gpt4all-backend/llama.cpp
[submodule "llama.cpp-230519"]
path = gpt4all-backend/llama.cpp-230519
url = https://github.com/ggerganov/llama.cpp.git
[submodule "llama.cpp-230511"]
path = gpt4all-backend/llama.cpp-230511
url = https://github.com/manyoso/llama.cpp.git
[submodule "llama.cpp-mainline"]
path = gpt4all-backend/llama.cpp-mainline
url = https://github.com/ggerganov/llama.cpp.git

View File

@ -17,36 +17,97 @@ endif()
include_directories("${CMAKE_CURRENT_BINARY_DIR}")
set(LLMODEL_VERSION_MAJOR 0)
set(LLMODEL_VERSION_MINOR 1)
set(LLMODEL_VERSION_PATCH 1)
set(LLMODEL_VERSION_MINOR 2)
set(LLMODEL_VERSION_PATCH 0)
set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}")
project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
set(BUILD_SHARED_LIBS ON)
set(LLAMA_BUILD_EXAMPLES ON CACHE BOOL "llama: build examples" FORCE)
set(BUILD_SHARED_LIBS ON FORCE)
set(CMAKE_VERBOSE_MAKEFILE ON)
if (GPT4ALL_AVX_ONLY)
set(LLAMA_AVX2 OFF CACHE BOOL "llama: enable AVX2" FORCE)
set(LLAMA_F16C OFF CACHE BOOL "llama: enable F16C" FORCE)
set(LLAMA_FMA OFF CACHE BOOL "llama: enable FMA" FORCE)
# Check for IPO support
include(CheckIPOSupported)
check_ipo_supported(RESULT IPO_SUPPORTED OUTPUT IPO_ERROR)
if (NOT IPO_SUPPORTED)
message(WARNING "Interprocedural optimization is not supported by your toolchain! This will lead to bigger file sizes and worse performance: ${IPO_ERROR}")
else()
message(STATUS "Interprocedural optimization support detected")
endif()
add_subdirectory(llama.cpp)
include(llama.cpp.cmake)
set(BUILD_VARIANTS default avxonly)
set(CMAKE_VERBOSE_MAKEFILE ON)
# Go through each build variant
foreach(BUILD_VARIANT IN LISTS BUILD_VARIANTS)
# Determine flags
if (BUILD_VARIANT STREQUAL avxonly)
set(GPT4ALL_ALLOW_NON_AVX NO)
else()
set(GPT4ALL_ALLOW_NON_AVX YES)
endif()
set(LLAMA_AVX2 ${GPT4ALL_ALLOW_NON_AVX})
set(LLAMA_F16C ${GPT4ALL_ALLOW_NON_AVX})
set(LLAMA_FMA ${GPT4ALL_ALLOW_NON_AVX})
# Include GGML
include_ggml(llama.cpp-mainline -mainline-${BUILD_VARIANT} ON)
include_ggml(llama.cpp-230511 -230511-${BUILD_VARIANT} ON)
include_ggml(llama.cpp-230519 -230519-${BUILD_VARIANT} ON)
# Function for preparing individual implementations
function(prepare_target TARGET_NAME BASE_LIB)
set(TARGET_NAME ${TARGET_NAME}-${BUILD_VARIANT})
message(STATUS "Configuring model implementation target ${TARGET_NAME}")
# Link to ggml/llama
target_link_libraries(${TARGET_NAME}
PUBLIC ${BASE_LIB}-${BUILD_VARIANT})
# Let it know about its build variant
target_compile_definitions(${TARGET_NAME}
PRIVATE GGML_BUILD_VARIANT="${BUILD_VARIANT}")
# Enable IPO if possible
set_property(TARGET ${TARGET_NAME}
PROPERTY INTERPROCEDURAL_OPTIMIZATION ${IPO_SUPPORTED})
endfunction()
# Add each individual implementations
add_library(llamamodel-mainline-${BUILD_VARIANT} SHARED
llamamodel.cpp)
target_compile_definitions(llamamodel-mainline-${BUILD_VARIANT} PRIVATE
LLAMA_VERSIONS=>=3 LLAMA_DATE=999999)
prepare_target(llamamodel-mainline llama-mainline)
add_library(llamamodel-230519-${BUILD_VARIANT} SHARED
llamamodel.cpp)
target_compile_definitions(llamamodel-230519-${BUILD_VARIANT} PRIVATE
LLAMA_VERSIONS===2 LLAMA_DATE=230519)
prepare_target(llamamodel-230519 llama-230519)
add_library(llamamodel-230511-${BUILD_VARIANT} SHARED
llamamodel.cpp)
target_compile_definitions(llamamodel-230511-${BUILD_VARIANT} PRIVATE
LLAMA_VERSIONS=<=1 LLAMA_DATE=230511)
prepare_target(llamamodel-230511 llama-230511)
add_library(gptj-${BUILD_VARIANT} SHARED
gptj.cpp utils.h utils.cpp)
prepare_target(gptj ggml-230511)
add_library(mpt-${BUILD_VARIANT} SHARED
mpt.cpp utils.h utils.cpp)
prepare_target(mpt ggml-230511)
endforeach()
add_library(llmodel
gptj.h gptj.cpp
llamamodel.h llamamodel.cpp
llama.cpp/examples/common.cpp
llmodel.h llmodel_c.h llmodel_c.cpp
mpt.h mpt.cpp
utils.h utils.cpp
llmodel.h llmodel.cpp
llmodel_c.h llmodel_c.cpp
dlhandle.h
)
target_link_libraries(llmodel
PRIVATE llama)
target_compile_definitions(llmodel PRIVATE LIB_FILE_EXT="${CMAKE_SHARED_LIBRARY_SUFFIX}")
set_target_properties(llmodel PROPERTIES
VERSION ${PROJECT_VERSION}

101
gpt4all-backend/dlhandle.h Normal file
View File

@ -0,0 +1,101 @@
#ifndef DLHANDLE_H
#define DLHANDLE_H
#ifndef _WIN32
#include <string>
#include <stdexcept>
#include <utility>
#include <dlfcn.h>
class Dlhandle {
void *chandle;
public:
class Exception : public std::runtime_error {
public:
using std::runtime_error::runtime_error;
};
Dlhandle() : chandle(nullptr) {}
Dlhandle(const std::string& fpath, int flags = RTLD_LAZY) {
chandle = dlopen(fpath.c_str(), flags);
if (!chandle) {
throw Exception("dlopen(\""+fpath+"\"): "+dlerror());
}
}
Dlhandle(const Dlhandle& o) = delete;
Dlhandle(Dlhandle&& o) : chandle(o.chandle) {
o.chandle = nullptr;
}
~Dlhandle() {
if (chandle) dlclose(chandle);
}
auto operator =(Dlhandle&& o) {
chandle = std::exchange(o.chandle, nullptr);
}
bool is_valid() const {
return chandle != nullptr;
}
operator bool() const {
return is_valid();
}
template<typename T>
T* get(const std::string& fname) {
auto fres = reinterpret_cast<T*>(dlsym(chandle, fname.c_str()));
return (dlerror()==NULL)?fres:nullptr;
}
auto get_fnc(const std::string& fname) {
return get<void*(...)>(fname);
}
};
#else
#include <string>
#include <exception>
#include <stdexcept>
#include <windows.h>
#include <libloaderapi.h>
class Dlhandle {
HMODULE chandle;
public:
class Exception : public std::runtime_error {
public:
using std::runtime_error::runtime_error;
};
Dlhandle() : chandle(nullptr) {}
Dlhandle(const std::string& fpath) {
chandle = LoadLibraryA(fpath.c_str());
if (!chandle) {
throw Exception("dlopen(\""+fpath+"\"): Error");
}
}
Dlhandle(const Dlhandle& o) = delete;
Dlhandle(Dlhandle&& o) : chandle(o.chandle) {
o.chandle = nullptr;
}
~Dlhandle() {
if (chandle) FreeLibrary(chandle);
}
bool is_valid() const {
return chandle != nullptr;
}
template<typename T>
T* get(const std::string& fname) {
return reinterpret_cast<T*>(GetProcAddress(chandle, fname.c_str()));
}
auto get_fnc(const std::string& fname) {
return get<void*(...)>(fname);
}
};
#endif
#endif // DLHANDLE_H

View File

@ -1,5 +1,5 @@
#include "gptj.h"
#include "llama.cpp/ggml.h"
#define GPTJ_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#include "gptj_impl.h"
#include "utils.h"
@ -25,10 +25,16 @@
#endif
#include <sstream>
#include <unordered_set>
#include <ggml.h>
namespace {
const char *modelType_ = "MPT";
static const size_t MB = 1024*1024;
}
// default hparams (GPT-J 6B)
static const size_t MB = 1024*1024;
struct gptj_hparams {
int32_t n_vocab = 50400;
int32_t n_ctx = 2048;
@ -229,8 +235,6 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
}
}
const ggml_type wtype2 = GGML_TYPE_F32;
auto & ctx = model.ctx;
size_t ctx_size = 0;
@ -279,6 +283,7 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
struct ggml_init_params params = {
.mem_size = ctx_size,
.mem_buffer = NULL,
.no_alloc = false
};
model.ctx = ggml_init(params);
@ -294,7 +299,6 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
model.layers.resize(n_layer);
@ -355,14 +359,6 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
// key + value memory
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_mem = n_layer*n_ctx;
const int n_elements = n_embd*n_mem;
if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
ggml_free(ctx);
@ -505,8 +501,6 @@ bool gptj_eval(
const int n_vocab = hparams.n_vocab;
const int n_rot = hparams.n_rot;
const int d_key = n_embd/n_head;
const size_t init_buf_size = 1024u*MB;
if (!model.buf.addr || model.buf.size < init_buf_size)
model.buf.resize(init_buf_size);
@ -526,10 +520,12 @@ bool gptj_eval(
struct ggml_init_params params = {
.mem_size = model.buf.size,
.mem_buffer = model.buf.addr,
.no_alloc = false
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = { .n_threads = n_threads };
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@ -772,8 +768,7 @@ size_t gptj_copy_state_data(const gptj_model &model, const std::mt19937 &rng, ui
}
const size_t written = out - dest;
const size_t expected = gptj_get_state_size(model);
assert(written == expected);
assert(written == gptj_get_state_size(model));
fflush(stdout);
return written;
}
@ -822,8 +817,7 @@ size_t gptj_set_state_data(gptj_model *model, std::mt19937 *rng, const uint8_t *
}
const size_t nread = in - src;
const size_t expected = gptj_get_state_size(*model);
assert(nread == expected);
assert(nread == gptj_get_state_size(*model));
fflush(stdout);
return nread;
}
@ -840,6 +834,7 @@ struct GPTJPrivate {
GPTJ::GPTJ()
: d_ptr(new GPTJPrivate) {
modelType = modelType_;
d_ptr->model = new gptj_model;
d_ptr->modelLoaded = false;
@ -908,12 +903,6 @@ void GPTJ::prompt(const std::string &prompt,
return;
}
const int64_t t_main_start_us = ggml_time_us();
int64_t t_sample_us = 0;
int64_t t_predict_us = 0;
int64_t t_prompt_us = 0;
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt);
@ -942,20 +931,19 @@ void GPTJ::prompt(const std::string &prompt,
// process the prompt in batches
size_t i = 0;
const int64_t t_start_prompt_us = ggml_time_us();
while (i < embd_inp.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
std::vector<gpt_vocab::id> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
// Check if the context has run out...
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << "GPTJ: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
}
if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
@ -966,7 +954,7 @@ void GPTJ::prompt(const std::string &prompt,
size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t) {
if (promptCtx.tokens.size() == promptCtx.n_ctx)
if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(batch.at(t));
if (!promptCallback(batch.at(t)))
@ -975,10 +963,6 @@ void GPTJ::prompt(const std::string &prompt,
promptCtx.n_past += batch.size();
i = batch_end;
}
t_prompt_us += ggml_time_us() - t_start_prompt_us;
int p_instructFound = 0;
int r_instructFound = 0;
std::string cachedResponse;
std::vector<gpt_vocab::id> cachedTokens;
@ -986,24 +970,20 @@ void GPTJ::prompt(const std::string &prompt,
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" };
// predict next tokens
int32_t totalPredictions = 0;
for (int i = 0; i < promptCtx.n_predict; i++) {
// sample next token
const int n_vocab = d_ptr->model->hparams.n_vocab;
gpt_vocab::id id = 0;
{
const int64_t t_start_sample_us = ggml_time_us();
const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
id = gpt_sample_top_k_top_p(n_vocab,
promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
n_prev_toks,
promptCtx.logits,
promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
promptCtx.repeat_penalty,
d_ptr->rng);
t_sample_us += ggml_time_us() - t_start_sample_us;
}
// Check if the context has run out...
@ -1017,29 +997,24 @@ void GPTJ::prompt(const std::string &prompt,
assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
}
const int64_t t_start_predict_us = ggml_time_us();
if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
d_ptr->mem_per_token)) {
std::cerr << "GPT-J ERROR: Failed to predict next token\n";
return;
}
t_predict_us += ggml_time_us() - t_start_predict_us;
promptCtx.n_past += 1;
// display text
++totalPredictions;
if (id == 50256 /*end of text*/)
goto stop_generating;
return;
const std::string str = d_ptr->vocab.id_to_token[id];
// Check if the provided str is part of our reverse prompts
bool foundPartialReversePrompt = false;
const std::string completed = cachedResponse + str;
if (reversePrompts.find(completed) != reversePrompts.end()) {
goto stop_generating;
}
if (reversePrompts.find(completed) != reversePrompts.end())
return;
// Check if it partially matches our reverse prompts and if so, cache
for (auto s : reversePrompts) {
@ -1059,32 +1034,14 @@ void GPTJ::prompt(const std::string &prompt,
// Empty the cache
for (auto t : cachedTokens) {
if (promptCtx.tokens.size() == promptCtx.n_ctx)
if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(t);
if (!responseCallback(t, d_ptr->vocab.id_to_token[t]))
goto stop_generating;
return;
}
cachedTokens.clear();
}
stop_generating:
#if 0
// report timing
{
const int64_t t_main_end_us = ggml_time_us();
std::cout << "GPT-J INFO: mem per token = " << mem_per_token << " bytes\n";
std::cout << "GPT-J INFO: sample time = " << t_sample_us/1000.0f << " ms\n";
std::cout << "GPT-J INFO: prompt time = " << t_prompt_us/1000.0f << " ms\n";
std::cout << "GPT-J INFO: predict time = " << t_predict_us/1000.0f << " ms / " << t_predict_us/1000.0f/totalPredictions << " ms per token\n";
std::cout << "GPT-J INFO: total time = " << (t_main_end_us - t_main_start_us)/1000.0f << " ms\n";
fflush(stdout);
}
#endif
return;
}
void GPTJ::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
@ -1095,7 +1052,7 @@ void GPTJ::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
std::vector<gpt_vocab::id> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
d_ptr->mem_per_token)) {
@ -1107,8 +1064,38 @@ void GPTJ::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)
goto stop_generating;
i = batch_end;
}
assert(promptCtx.n_past == promptCtx.tokens.size());
assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
stop_generating:
recalculate(false);
}
#if defined(_WIN32)
#define DLL_EXPORT __declspec(dllexport)
#else
#define DLL_EXPORT __attribute__ ((visibility ("default")))
#endif
extern "C" {
DLL_EXPORT bool is_g4a_backend_model_implementation() {
return true;
}
DLL_EXPORT const char *get_model_type() {
return modelType_;
}
DLL_EXPORT const char *get_build_variant() {
return GGML_BUILD_VARIANT;
}
DLL_EXPORT bool magic_match(std::istream& f) {
uint32_t magic = 0;
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
return magic == 0x67676d6c;
}
DLL_EXPORT LLModel *construct() {
return new GPTJ;
}
}

View File

@ -1,3 +1,6 @@
#ifndef GPTJ_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#error This file is NOT meant to be included outside of gptj.cpp. Doing so is DANGEROUS. Be sure to know what you are doing before proceeding to #define GPTJ_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#endif
#ifndef GPTJ_H
#define GPTJ_H
@ -6,7 +9,7 @@
#include <vector>
#include "llmodel.h"
class GPTJPrivate;
struct GPTJPrivate;
class GPTJ : public LLModel {
public:
GPTJ();

@ -0,0 +1 @@
Subproject commit 5ea43392731040b454c293123839b90e159cbb99

@ -0,0 +1 @@
Subproject commit ea600071cb005267e9e8f2629c1e406dd5fde083

View File

@ -0,0 +1,364 @@
cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(LLAMA_STANDALONE ON)
# configure project version
# TODO
else()
set(LLAMA_STANDALONE OFF)
endif()
if (EMSCRIPTEN)
set(BUILD_SHARED_LIBS_DEFAULT OFF)
option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON)
else()
if (MINGW)
set(BUILD_SHARED_LIBS_DEFAULT OFF)
else()
set(BUILD_SHARED_LIBS_DEFAULT ON)
endif()
endif()
#
# Option list
#
# general
option(LLAMA_STATIC "llama: static link libraries" OFF)
option(LLAMA_NATIVE "llama: enable -march=native flag" OFF)
option(LLAMA_LTO "llama: enable link time optimization" OFF)
# debug
option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON)
option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF)
option(LLAMA_GPROF "llama: enable gprof" OFF)
# sanitizers
option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF)
option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF)
option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)
# instruction set specific
#option(LLAMA_AVX "llama: enable AVX" ON)
#option(LLAMA_AVX2 "llama: enable AVX2" ON)
#option(LLAMA_AVX512 "llama: enable AVX512" OFF)
#option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
#option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
#option(LLAMA_FMA "llama: enable FMA" ON)
# in MSVC F16C is implied with AVX2/AVX512
#if (NOT MSVC)
# option(LLAMA_F16C "llama: enable F16C" ON)
#endif()
# 3rd party libs
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
#
# Compile flags
#
set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED true)
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
if (NOT MSVC)
if (LLAMA_SANITIZE_THREAD)
add_compile_options(-fsanitize=thread)
link_libraries(-fsanitize=thread)
endif()
if (LLAMA_SANITIZE_ADDRESS)
add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
link_libraries(-fsanitize=address)
endif()
if (LLAMA_SANITIZE_UNDEFINED)
add_compile_options(-fsanitize=undefined)
link_libraries(-fsanitize=undefined)
endif()
endif()
if (APPLE AND LLAMA_ACCELERATE)
find_library(ACCELERATE_FRAMEWORK Accelerate)
if (ACCELERATE_FRAMEWORK)
message(STATUS "Accelerate framework found")
add_compile_definitions(GGML_USE_ACCELERATE)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
else()
message(WARNING "Accelerate framework not found")
endif()
endif()
if (LLAMA_OPENBLAS)
if (LLAMA_STATIC)
set(BLA_STATIC ON)
endif()
set(BLA_VENDOR OpenBLAS)
find_package(BLAS)
if (BLAS_FOUND)
message(STATUS "OpenBLAS found")
add_compile_definitions(GGML_USE_OPENBLAS)
add_link_options(${BLAS_LIBRARIES})
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas)
# find header file
set(OPENBLAS_INCLUDE_SEARCH_PATHS
/usr/include
/usr/include/openblas
/usr/include/openblas-base
/usr/local/include
/usr/local/include/openblas
/usr/local/include/openblas-base
/opt/OpenBLAS/include
$ENV{OpenBLAS_HOME}
$ENV{OpenBLAS_HOME}/include
)
find_path(OPENBLAS_INC NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
add_compile_options(-I${OPENBLAS_INC})
else()
message(WARNING "OpenBLAS not found")
endif()
endif()
if (LLAMA_ALL_WARNINGS)
if (NOT MSVC)
set(c_flags
-Wall
-Wextra
-Wpedantic
-Wcast-qual
-Wdouble-promotion
-Wshadow
-Wstrict-prototypes
-Wpointer-arith
)
set(cxx_flags
-Wall
-Wextra
-Wpedantic
-Wcast-qual
-Wno-unused-function
-Wno-multichar
)
else()
# todo : msvc
endif()
add_compile_options(
"$<$<COMPILE_LANGUAGE:C>:${c_flags}>"
"$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags}>"
)
endif()
if (MSVC)
add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
if (BUILD_SHARED_LIBS)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()
endif()
if (LLAMA_LTO)
include(CheckIPOSupported)
check_ipo_supported(RESULT result OUTPUT output)
if (result)
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
else()
message(WARNING "IPO is not supported: ${output}")
endif()
endif()
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
if (NOT MSVC)
if (LLAMA_STATIC)
add_link_options(-static)
if (MINGW)
add_link_options(-static-libgcc -static-libstdc++)
endif()
endif()
if (LLAMA_GPROF)
add_compile_options(-pg)
endif()
if (LLAMA_NATIVE)
add_compile_options(-march=native)
endif()
endif()
function(include_ggml DIRECTORY SUFFIX WITH_LLAMA)
message(STATUS "Configuring ggml implementation target llama${SUFFIX} in ${CMAKE_CURRENT_SOURCE_DIR}/${DIRECTORY}")
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
message(STATUS "ARM detected")
if (MSVC)
# TODO: arm msvc?
else()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
add_compile_options(-mcpu=native)
endif()
# TODO: armv6,7,8 version specific flags
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
message(STATUS "x86 detected")
if (MSVC)
if (LLAMA_AVX512)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
# MSVC has no compile-time flags enabling specific
# AVX512 extensions, neither it defines the
# macros corresponding to the extensions.
# Do it manually.
if (LLAMA_AVX512_VBMI)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
endif()
if (LLAMA_AVX512_VNNI)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
endif()
elseif (LLAMA_AVX2)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
elseif (LLAMA_AVX)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
endif()
else()
if (LLAMA_F16C)
add_compile_options(-mf16c)
endif()
if (LLAMA_FMA)
add_compile_options(-mfma)
endif()
if (LLAMA_AVX)
add_compile_options(-mavx)
endif()
if (LLAMA_AVX2)
add_compile_options(-mavx2)
endif()
if (LLAMA_AVX512)
add_compile_options(-mavx512f)
add_compile_options(-mavx512bw)
endif()
if (LLAMA_AVX512_VBMI)
add_compile_options(-mavx512vbmi)
endif()
if (LLAMA_AVX512_VNNI)
add_compile_options(-mavx512vnni)
endif()
endif()
else()
# TODO: support PowerPC
message(STATUS "Unknown architecture")
endif()
#
# Build libraries
#
if (LLAMA_CUBLAS AND EXISTS ${DIRECTORY}/ggml-cuda.h)
cmake_minimum_required(VERSION 3.17)
find_package(CUDAToolkit)
if (CUDAToolkit_FOUND)
message(STATUS "cuBLAS found")
enable_language(CUDA)
set(GGML_CUDA_SOURCES ${DIRECTORY}/ggml-cuda.cu ${DIRECTORY}/ggml-cuda.h)
add_compile_definitions(GGML_USE_CUBLAS)
if (LLAMA_STATIC)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()
else()
message(WARNING "cuBLAS not found")
endif()
endif()
if (LLAMA_CLBLAST AND EXISTS ${DIRECTORY}/ggml-opencl.h)
find_package(CLBlast)
if (CLBlast_FOUND)
message(STATUS "CLBlast found")
set(GGML_OPENCL_SOURCES ${DIRECTORY}/ggml-opencl.c ${DIRECTORY}/ggml-opencl.h)
add_compile_definitions(GGML_USE_CLBLAST)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast)
else()
message(WARNING "CLBlast not found")
endif()
endif()
add_library(ggml${SUFFIX} OBJECT
${DIRECTORY}/ggml.c
${DIRECTORY}/ggml.h
${GGML_CUDA_SOURCES}
${GGML_OPENCL_SOURCES})
target_include_directories(ggml${SUFFIX} PUBLIC ${DIRECTORY})
target_compile_features(ggml${SUFFIX} PUBLIC c_std_11) # don't bump
target_link_libraries(ggml${SUFFIX} PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
if (BUILD_SHARED_LIBS)
set_target_properties(ggml${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
if (WITH_LLAMA)
# Backwards compatibility with old llama.cpp versions
set(LLAMA_UTIL_SOURCE_FILE llama-util.h)
if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${DIRECTORY}/${LLAMA_UTIL_SOURCE_FILE})
set(LLAMA_UTIL_SOURCE_FILE llama_util.h)
endif()
add_library(llama${SUFFIX}
${DIRECTORY}/llama.cpp
${DIRECTORY}/llama.h
${DIRECTORY}/${LLAMA_UTIL_SOURCE_FILE})
target_include_directories(llama${SUFFIX} PUBLIC ${DIRECTORY})
target_compile_features(llama${SUFFIX} PUBLIC cxx_std_11) # don't bump
target_link_libraries(llama${SUFFIX} PRIVATE ggml${SUFFIX} ${LLAMA_EXTRA_LIBS})
if (BUILD_SHARED_LIBS)
set_target_properties(llama${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(llama${SUFFIX} PRIVATE LLAMA_SHARED LLAMA_BUILD)
endif()
endif()
if (GGML_CUDA_SOURCES)
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF)
set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
if (WITH_LLAMA)
set_property(TARGET llama${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF)
endif()
endif()
endfunction()

View File

@ -1,8 +1,5 @@
#include "llamamodel.h"
#include "llama.cpp/examples/common.h"
#include "llama.cpp/llama.h"
#include "llama.cpp/ggml.h"
#define LLAMAMODEL_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#include "llamamodel_impl.h"
#include <cassert>
#include <cmath>
@ -28,16 +25,77 @@
#include <thread>
#include <unordered_set>
#include <llama.h>
#include <ggml.h>
namespace {
const char *modelType_ = "LLaMA";
}
struct gpt_params {
int32_t seed = -1; // RNG seed
int32_t n_keep = 0; // number of tokens to keep from initial prompt
#if LLAMA_DATE <= 230511
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
#endif
#if LLAMA_DATE >= 230519
// sampling parameters
float tfs_z = 1.0f; // 1.0 = disabled
float typical_p = 1.0f; // 1.0 = disabled
#endif
std::string prompt = "";
bool memory_f16 = true; // use f16 instead of f32 for memory kv
bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory
};
#if LLAMA_DATE >= 230519
static int llama_sample_top_p_top_k(
llama_context *ctx,
const llama_token *last_n_tokens_data,
int last_n_tokens_size,
int top_k,
float top_p,
float temp,
float repeat_penalty) {
auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
// Populate initial list of all candidates
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (int token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
// Sample repeat penalty
llama_sample_repetition_penalty(nullptr, &candidates_p, last_n_tokens_data, last_n_tokens_size, repeat_penalty);
// Temperature sampling
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1);
llama_sample_typical(ctx, &candidates_p, 1.0f, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_temperature(ctx, &candidates_p, temp);
return llama_sample_token(ctx, &candidates_p);
}
#endif
struct LLamaPrivate {
const std::string modelPath;
bool modelLoaded;
llama_context *ctx = nullptr;
llama_context_params params;
int64_t n_threads = 0;
bool empty = true;
};
LLamaModel::LLamaModel()
: d_ptr(new LLamaPrivate) {
modelType = modelType_;
d_ptr->modelLoaded = false;
}
@ -49,14 +107,12 @@ bool LLamaModel::loadModel(const std::string &modelPath)
gpt_params params;
d_ptr->params.n_ctx = 2048;
d_ptr->params.n_parts = params.n_parts;
d_ptr->params.seed = params.seed;
d_ptr->params.f16_kv = params.memory_f16;
d_ptr->params.use_mmap = params.use_mmap;
#if defined (__APPLE__)
d_ptr->params.use_mlock = true;
#else
d_ptr->params.use_mlock = params.use_mlock;
#if LLAMA_DATE <= 230511
d_ptr->params.n_parts = params.n_parts;
#endif
d_ptr->ctx = llama_init_from_file(modelPath.c_str(), d_ptr->params);
@ -75,8 +131,7 @@ void LLamaModel::setThreadCount(int32_t n_threads) {
d_ptr->n_threads = n_threads;
}
int32_t LLamaModel::threadCount() const
{
int32_t LLamaModel::threadCount() const {
return d_ptr->n_threads;
}
@ -102,7 +157,8 @@ size_t LLamaModel::saveState(uint8_t *dest) const
size_t LLamaModel::restoreState(const uint8_t *src)
{
return llama_set_state_data(d_ptr->ctx, src);
// const_cast is required, see: https://github.com/ggerganov/llama.cpp/pull/1540
return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
}
void LLamaModel::prompt(const std::string &prompt,
@ -123,7 +179,11 @@ void LLamaModel::prompt(const std::string &prompt,
params.prompt.insert(0, 1, ' ');
// tokenize the prompt
auto embd_inp = ::llama_tokenize(d_ptr->ctx, params.prompt, false);
std::vector<llama_token> embd_inp(params.prompt.size() + 4);
int n = llama_tokenize(d_ptr->ctx, params.prompt.c_str(), embd_inp.data(), embd_inp.size(), d_ptr->empty);
assert(n >= 0);
embd_inp.resize(n);
d_ptr->empty = false;
// save the context size
promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx);
@ -143,20 +203,19 @@ void LLamaModel::prompt(const std::string &prompt,
// process the prompt in batches
size_t i = 0;
const int64_t t_start_prompt_us = ggml_time_us();
while (i < embd_inp.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
std::vector<llama_token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
// Check if the context has run out...
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << "LLAMA: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
}
if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
@ -166,7 +225,7 @@ void LLamaModel::prompt(const std::string &prompt,
size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t) {
if (promptCtx.tokens.size() == promptCtx.n_ctx)
if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(batch.at(t));
if (!promptCallback(batch.at(t)))
@ -179,10 +238,9 @@ void LLamaModel::prompt(const std::string &prompt,
std::string cachedResponse;
std::vector<llama_token> cachedTokens;
std::unordered_set<std::string> reversePrompts
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" };
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant" };
// predict next tokens
int32_t totalPredictions = 0;
for (int i = 0; i < promptCtx.n_predict; i++) {
// sample next token
const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
@ -209,7 +267,6 @@ void LLamaModel::prompt(const std::string &prompt,
promptCtx.n_past += 1;
// display text
++totalPredictions;
if (id == llama_token_eos())
return;
@ -240,7 +297,7 @@ void LLamaModel::prompt(const std::string &prompt,
// Empty the cache
for (auto t : cachedTokens) {
if (promptCtx.tokens.size() == promptCtx.n_ctx)
if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(t);
if (!responseCallback(t, llama_token_to_str(d_ptr->ctx, t)))
@ -258,7 +315,7 @@ void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function<bool
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
std::vector<llama_token> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
std::cerr << "LLAMA ERROR: Failed to process prompt\n";
@ -269,8 +326,43 @@ void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function<bool
goto stop_generating;
i = batch_end;
}
assert(promptCtx.n_past == promptCtx.tokens.size());
assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
stop_generating:
recalculate(false);
}
#if defined(_WIN32)
#define DLL_EXPORT __declspec(dllexport)
#else
#define DLL_EXPORT __attribute__ ((visibility ("default")))
#endif
extern "C" {
DLL_EXPORT bool is_g4a_backend_model_implementation() {
return true;
}
DLL_EXPORT const char *get_model_type() {
return modelType_;
}
DLL_EXPORT const char *get_build_variant() {
return GGML_BUILD_VARIANT;
}
DLL_EXPORT bool magic_match(std::istream& f) {
// Check magic
uint32_t magic = 0;
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
if (magic != 0x67676a74) return false;
// Check version
uint32_t version = 0;
f.read(reinterpret_cast<char*>(&version), sizeof(version));
return version LLAMA_VERSIONS;
}
DLL_EXPORT LLModel *construct() {
return new LLamaModel;
}
}

View File

@ -1,3 +1,6 @@
#ifndef LLAMAMODEL_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#error This file is NOT meant to be included outside of llamamodel.cpp. Doing so is DANGEROUS. Be sure to know what you are doing before proceeding to #define LLAMAMODEL_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#endif
#ifndef LLAMAMODEL_H
#define LLAMAMODEL_H
@ -6,7 +9,7 @@
#include <vector>
#include "llmodel.h"
class LLamaPrivate;
struct LLamaPrivate;
class LLamaModel : public LLModel {
public:
LLamaModel();
@ -33,4 +36,4 @@ private:
LLamaPrivate *d_ptr;
};
#endif // LLAMAMODEL_H
#endif // LLAMAMODEL_H

View File

@ -0,0 +1,90 @@
#include "llmodel.h"
#include "dlhandle.h"
#include <string>
#include <vector>
#include <fstream>
#include <filesystem>
static Dlhandle *get_implementation(std::ifstream& f, const std::string& buildVariant) {
// Collect all model implementation libraries
// NOTE: allocated on heap so we leak intentionally on exit so we have a chance to clean up the
// individual models without the cleanup of the static list interfering
static auto* libs = new std::vector<Dlhandle>([] () {
std::vector<Dlhandle> fres;
auto search_in_directory = [&](const std::filesystem::path& path) {
// Iterate over all libraries
for (const auto& f : std::filesystem::directory_iterator(path)) {
// Get path
const auto& p = f.path();
// Check extension
if (p.extension() != LIB_FILE_EXT) continue;
// Add to list if model implementation
try {
Dlhandle dl(p.string());
if (dl.get<bool(uint32_t)>("is_g4a_backend_model_implementation")) {
fres.emplace_back(std::move(dl));
}
} catch (...) {}
}
};
search_in_directory(".");
#if defined(__APPLE__)
search_in_directory("../../../");
#endif
return fres;
}());
// Iterate over all libraries
for (auto& dl : *libs) {
f.seekg(0);
// Check that magic matches
auto magic_match = dl.get<bool(std::ifstream&)>("magic_match");
if (!magic_match || !magic_match(f)) {
continue;
}
// Check that build variant is correct
auto get_build_variant = dl.get<const char *()>("get_build_variant");
if (buildVariant != (get_build_variant?get_build_variant():"default")) {
continue;
}
// Looks like we're good to go, return this dlhandle
return &dl;
}
// Nothing found, so return nothing
return nullptr;
}
static bool requires_avxonly() {
#ifdef __x86_64__
return !__builtin_cpu_supports("avx2") && !__builtin_cpu_supports("fma");
#else
return false; // Don't know how to handle ARM
#endif
}
LLModel *LLModel::construct(const std::string &modelPath, std::string buildVariant) {
//TODO: Auto-detect
if (buildVariant == "auto") {
if (requires_avxonly()) {
buildVariant = "avxonly";
} else {
buildVariant = "default";
}
}
// Read magic
std::ifstream f(modelPath, std::ios::binary);
if (!f) return nullptr;
// Get correct implementation
auto impl = get_implementation(f, buildVariant);
if (!impl) return nullptr;
f.close();
// Get inference constructor
auto constructor = impl->get<LLModel *()>("construct");
if (!constructor) return nullptr;
// Construct llmodel implementation
auto fres = constructor();
// Return final instance
return fres;
}

View File

@ -1,21 +1,23 @@
#ifndef LLMODEL_H
#define LLMODEL_H
#include <string>
#include <functional>
#include <vector>
#include <cstdint>
class LLModel {
public:
explicit LLModel() {}
virtual ~LLModel() {}
static LLModel *construct(const std::string &modelPath, std::string buildVariant = "default");
virtual bool loadModel(const std::string &modelPath) = 0;
virtual bool isModelLoaded() const = 0;
virtual size_t stateSize() const { return 0; }
virtual size_t saveState(uint8_t *dest) const { return 0; }
virtual size_t restoreState(const uint8_t *src) { return 0; }
virtual size_t saveState(uint8_t */*dest*/) const { return 0; }
virtual size_t restoreState(const uint8_t */*src*/) { return 0; }
struct PromptContext {
std::vector<float> logits; // logits of current context
std::vector<int32_t> tokens; // current tokens in the context window
@ -36,12 +38,18 @@ public:
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &ctx) = 0;
virtual void setThreadCount(int32_t n_threads) {}
virtual void setThreadCount(int32_t /*n_threads*/) {}
virtual int32_t threadCount() const { return 1; }
const char *getModelType() const {
return modelType;
}
protected:
virtual void recalculateContext(PromptContext &promptCtx,
std::function<bool(bool)> recalculate) = 0;
const char *modelType;
};
#endif // LLMODEL_H

View File

@ -1,79 +1,57 @@
#include "llmodel_c.h"
#include "llmodel.h"
#include <cstring>
#include <cerrno>
#include <utility>
#include "gptj.h"
#include "llamamodel.h"
#include "mpt.h"
struct LLModelWrapper {
LLModel *llModel = nullptr;
LLModel::PromptContext promptContext;
};
llmodel_model llmodel_gptj_create()
{
LLModelWrapper *wrapper = new LLModelWrapper;
wrapper->llModel = new GPTJ;
return reinterpret_cast<void*>(wrapper);
}
void llmodel_gptj_destroy(llmodel_model gptj)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(gptj);
delete wrapper->llModel;
delete wrapper;
}
thread_local static std::string last_error_message;
llmodel_model llmodel_mpt_create()
{
LLModelWrapper *wrapper = new LLModelWrapper;
wrapper->llModel = new MPT;
return reinterpret_cast<void*>(wrapper);
}
void llmodel_mpt_destroy(llmodel_model mpt)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(mpt);
delete wrapper->llModel;
delete wrapper;
}
llmodel_model llmodel_llama_create()
{
LLModelWrapper *wrapper = new LLModelWrapper;
wrapper->llModel = new LLamaModel;
return reinterpret_cast<void*>(wrapper);
}
void llmodel_llama_destroy(llmodel_model llama)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(llama);
delete wrapper->llModel;
delete wrapper;
}
llmodel_model llmodel_model_create(const char *model_path) {
auto fres = llmodel_model_create2(model_path, "auto", nullptr);
if (!fres) {
fprintf(stderr, "Invalid model file\n");
}
return fres;
}
uint32_t magic;
llmodel_model model;
FILE *f = fopen(model_path, "rb");
fread(&magic, sizeof(magic), 1, f);
llmodel_model llmodel_model_create2(const char *model_path, const char *build_variant, llmodel_error *error) {
auto wrapper = new LLModelWrapper;
llmodel_error new_error{};
if (magic == 0x67676d6c) { model = llmodel_gptj_create(); }
else if (magic == 0x67676a74) { model = llmodel_llama_create(); }
else if (magic == 0x67676d6d) { model = llmodel_mpt_create(); }
else {fprintf(stderr, "Invalid model file\n");}
fclose(f);
return model;
try {
wrapper->llModel = LLModel::construct(model_path, build_variant);
} catch (const std::exception& e) {
new_error.code = EINVAL;
last_error_message = e.what();
}
if (!wrapper->llModel) {
delete std::exchange(wrapper, nullptr);
// Get errno and error message if none
if (new_error.code == 0) {
new_error.code = errno;
last_error_message = strerror(errno);
}
// Set message pointer
new_error.message = last_error_message.c_str();
// Set error argument
if (error) *error = new_error;
}
return reinterpret_cast<llmodel_model*>(wrapper);
}
void llmodel_model_destroy(llmodel_model model) {
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
const std::type_info &modelTypeInfo = typeid(*wrapper->llModel);
if (modelTypeInfo == typeid(GPTJ)) { llmodel_gptj_destroy(model); }
if (modelTypeInfo == typeid(LLamaModel)) { llmodel_llama_destroy(model); }
if (modelTypeInfo == typeid(MPT)) { llmodel_mpt_destroy(model); }
delete wrapper->llModel;
}
bool llmodel_loadModel(llmodel_model model, const char *model_path)
@ -84,20 +62,20 @@ bool llmodel_loadModel(llmodel_model model, const char *model_path)
bool llmodel_isModelLoaded(llmodel_model model)
{
const auto *llm = reinterpret_cast<LLModelWrapper*>(model)->llModel;
return llm->isModelLoaded();
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->isModelLoaded();
}
uint64_t llmodel_get_state_size(llmodel_model model)
{
const auto *llm = reinterpret_cast<LLModelWrapper*>(model)->llModel;
return llm->stateSize();
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->stateSize();
}
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest)
{
const auto *llm = reinterpret_cast<LLModelWrapper*>(model)->llModel;
return llm->saveState(dest);
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->saveState(dest);
}
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
@ -181,6 +159,6 @@ void llmodel_setThreadCount(llmodel_model model, int32_t n_threads)
int32_t llmodel_threadCount(llmodel_model model)
{
const auto *llm = reinterpret_cast<LLModelWrapper*>(model)->llModel;
return llm->threadCount();
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->threadCount();
}

View File

@ -5,6 +5,15 @@
#include <stddef.h>
#include <stdbool.h>
#ifdef __GNUC__
#define DEPRECATED __attribute__ ((deprecated))
#elif defined(_MSC_VER)
#define DEPRECATED __declspec(deprecated)
#else
#pragma message("WARNING: You need to implement DEPRECATED for this compiler")
#define DEPRECATED
#endif
#ifdef __cplusplus
extern "C" {
#endif
@ -14,13 +23,24 @@ extern "C" {
*/
typedef void *llmodel_model;
/**
* Structure containing any errors that may eventually occur
*/
struct llmodel_error {
const char *message; // Human readable error description; Thread-local; guaranteed to survive until next llmodel C API call
int code; // errno; 0 if none
};
#ifndef __cplusplus
typedef struct llmodel_error llmodel_error;
#endif
/**
* llmodel_prompt_context structure for holding the prompt context.
* NOTE: The implementation takes care of all the memory handling of the raw logits pointer and the
* raw tokens pointer. Attempting to resize them or modify them in any way can lead to undefined
* behavior.
*/
typedef struct {
struct llmodel_prompt_context {
float *logits; // logits of current context
size_t logits_size; // the size of the raw logits vector
int32_t *tokens; // current tokens in the context window
@ -35,7 +55,10 @@ typedef struct {
float repeat_penalty; // penalty factor for repeated tokens
int32_t repeat_last_n; // last n tokens to penalize
float context_erase; // percent of context to erase if we exceed the context window
} llmodel_prompt_context;
};
#ifndef __cplusplus
typedef struct llmodel_prompt_context llmodel_prompt_context;
#endif
/**
* Callback type for prompt processing.
@ -60,48 +83,22 @@ typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response
typedef bool (*llmodel_recalculate_callback)(bool is_recalculating);
/**
* Create a GPTJ instance.
* @return A pointer to the GPTJ instance.
* Create a llmodel instance.
* Recognises correct model type from file at model_path
* @param model_path A string representing the path to the model file.
* @return A pointer to the llmodel_model instance; NULL on error.
*/
llmodel_model llmodel_gptj_create();
/**
* Destroy a GPTJ instance.
* @param gptj A pointer to the GPTJ instance.
*/
void llmodel_gptj_destroy(llmodel_model gptj);
/**
* Create a MPT instance.
* @return A pointer to the MPT instance.
*/
llmodel_model llmodel_mpt_create();
/**
* Destroy a MPT instance.
* @param gptj A pointer to the MPT instance.
*/
void llmodel_mpt_destroy(llmodel_model mpt);
/**
* Create a LLAMA instance.
* @return A pointer to the LLAMA instance.
*/
llmodel_model llmodel_llama_create();
/**
* Destroy a LLAMA instance.
* @param llama A pointer to the LLAMA instance.
*/
void llmodel_llama_destroy(llmodel_model llama);
DEPRECATED llmodel_model llmodel_model_create(const char *model_path);
/**
* Create a llmodel instance.
* Recognises correct model type from file at model_path
* @param model_path A string representing the path to the model file.
* @return A pointer to the llmodel_model instance.
* @param model_path A string representing the path to the model file; will only be used to detect model type.
* @param build_variant A string representing the implementation to use (auto, default, avxonly, ...),
* @param error A pointer to a llmodel_error; will only be set on error.
* @return A pointer to the llmodel_model instance; NULL on error.
*/
llmodel_model llmodel_model_create(const char *model_path);
llmodel_model llmodel_model_create2(const char *model_path, const char *build_variant, llmodel_error *error);
/**
* Destroy a llmodel instance.
@ -110,7 +107,6 @@ llmodel_model llmodel_model_create(const char *model_path);
*/
void llmodel_model_destroy(llmodel_model model);
/**
* Load a model from a file.
* @param model A pointer to the llmodel_model instance.

View File

@ -1,5 +1,5 @@
#include "mpt.h"
#include "llama.cpp/ggml.h"
#define MPT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#include "mpt_impl.h"
#include "utils.h"
@ -28,8 +28,14 @@
#include <thread>
#include <unordered_set>
#include <regex>
#include <ggml.h>
namespace {
const char *modelType_ = "MPT";
static const size_t MB = 1024*1024;
}
// default hparams (MPT 7B)
struct mpt_hparams {
@ -293,7 +299,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
const int expand = hparams.expand;
@ -331,14 +336,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
// key + value memory
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_mem = n_layer*n_ctx;
const int n_elements = n_embd*n_mem;
if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
ggml_free(ctx);
@ -457,9 +454,6 @@ bool mpt_eval(
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
const int expand = hparams.expand;
const int d_key = n_embd/n_head;
const size_t init_buf_size = 1024u*MB;
if (!model.buf.addr || model.buf.size < init_buf_size)
@ -480,10 +474,12 @@ bool mpt_eval(
struct ggml_init_params params = {
.mem_size = model.buf.size,
.mem_buffer = model.buf.addr,
.no_alloc = false
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = { .n_threads = n_threads };
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@ -695,8 +691,7 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
}
const size_t written = out - dest;
const size_t expected = mpt_get_state_size(model);
assert(written == expected);
assert(written == mpt_get_state_size(model));
fflush(stdout);
return written;
}
@ -745,8 +740,7 @@ size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *sr
}
const size_t nread = in - src;
const size_t expected = mpt_get_state_size(*model);
assert(nread == expected);
assert(nread == mpt_get_state_size(*model));
fflush(stdout);
return nread;
}
@ -764,6 +758,7 @@ struct MPTPrivate {
MPT::MPT()
: d_ptr(new MPTPrivate) {
modelType = modelType_;
d_ptr->model = new mpt_model;
d_ptr->modelLoaded = false;
@ -833,12 +828,6 @@ void MPT::prompt(const std::string &prompt,
return;
}
const int64_t t_main_start_us = ggml_time_us();
int64_t t_sample_us = 0;
int64_t t_predict_us = 0;
int64_t t_prompt_us = 0;
// tokenize the prompt
std::vector<int> embd_inp = gpt_tokenize(d_ptr->vocab, prompt);
@ -867,20 +856,19 @@ void MPT::prompt(const std::string &prompt,
// process the prompt in batches
size_t i = 0;
const int64_t t_start_prompt_us = ggml_time_us();
while (i < embd_inp.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
std::vector<int> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
// Check if the context has run out...
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << "MPT: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
}
if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
@ -891,7 +879,7 @@ void MPT::prompt(const std::string &prompt,
size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t) {
if (promptCtx.tokens.size() == promptCtx.n_ctx)
if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(batch.at(t));
if (!promptCallback(batch.at(t)))
@ -900,10 +888,6 @@ void MPT::prompt(const std::string &prompt,
promptCtx.n_past += batch.size();
i = batch_end;
}
t_prompt_us += ggml_time_us() - t_start_prompt_us;
int p_instructFound = 0;
int r_instructFound = 0;
std::string cachedResponse;
std::vector<int> cachedTokens;
@ -911,24 +895,20 @@ void MPT::prompt(const std::string &prompt,
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" };
// predict next tokens
int32_t totalPredictions = 0;
for (int i = 0; i < promptCtx.n_predict; i++) {
// sample next token
const int n_vocab = d_ptr->model->hparams.n_vocab;
int id = 0;
{
const int64_t t_start_sample_us = ggml_time_us();
const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
id = gpt_sample_top_k_top_p(n_vocab,
promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
n_prev_toks,
promptCtx.logits,
promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
promptCtx.repeat_penalty,
d_ptr->rng);
t_sample_us += ggml_time_us() - t_start_sample_us;
}
// Check if the context has run out...
@ -942,33 +922,28 @@ void MPT::prompt(const std::string &prompt,
assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
}
const int64_t t_start_predict_us = ggml_time_us();
if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
d_ptr->mem_per_token)) {
std::cerr << "GPT-J ERROR: Failed to predict next token\n";
return;
}
t_predict_us += ggml_time_us() - t_start_predict_us;
promptCtx.n_past += 1;
// display text
++totalPredictions;
// display tex
// mpt-7b-chat has special token for end
if (d_ptr->has_im_end && id == d_ptr->vocab.token_to_id["<|im_end|>"])
goto stop_generating;
return;
if (id == 0 /*end of text*/)
goto stop_generating;
return;
const std::string str = d_ptr->vocab.id_to_token[id];
// Check if the provided str is part of our reverse prompts
bool foundPartialReversePrompt = false;
const std::string completed = cachedResponse + str;
if (reversePrompts.find(completed) != reversePrompts.end()) {
goto stop_generating;
}
if (reversePrompts.find(completed) != reversePrompts.end())
return;
// Check if it partially matches our reverse prompts and if so, cache
for (auto s : reversePrompts) {
@ -988,32 +963,14 @@ void MPT::prompt(const std::string &prompt,
// Empty the cache
for (auto t : cachedTokens) {
if (promptCtx.tokens.size() == promptCtx.n_ctx)
if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
promptCtx.tokens.erase(promptCtx.tokens.begin());
promptCtx.tokens.push_back(t);
if (!responseCallback(t, d_ptr->vocab.id_to_token[t]))
goto stop_generating;
return;
}
cachedTokens.clear();
}
stop_generating:
#if 0
// report timing
{
const int64_t t_main_end_us = ggml_time_us();
std::cout << "GPT-J INFO: mem per token = " << mem_per_token << " bytes\n";
std::cout << "GPT-J INFO: sample time = " << t_sample_us/1000.0f << " ms\n";
std::cout << "GPT-J INFO: prompt time = " << t_prompt_us/1000.0f << " ms\n";
std::cout << "GPT-J INFO: predict time = " << t_predict_us/1000.0f << " ms / " << t_predict_us/1000.0f/totalPredictions << " ms per token\n";
std::cout << "GPT-J INFO: total time = " << (t_main_end_us - t_main_start_us)/1000.0f << " ms\n";
fflush(stdout);
}
#endif
return;
}
void MPT::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
@ -1024,7 +981,7 @@ void MPT::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)>
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
std::vector<int> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
d_ptr->mem_per_token)) {
@ -1036,8 +993,38 @@ void MPT::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)>
goto stop_generating;
i = batch_end;
}
assert(promptCtx.n_past == promptCtx.tokens.size());
assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
stop_generating:
recalculate(false);
}
#if defined(_WIN32)
#define DLL_EXPORT __declspec(dllexport)
#else
#define DLL_EXPORT __attribute__ ((visibility ("default")))
#endif
extern "C" {
DLL_EXPORT bool is_g4a_backend_model_implementation() {
return true;
}
DLL_EXPORT const char *get_model_type() {
return modelType_;
}
DLL_EXPORT const char *get_build_variant() {
return GGML_BUILD_VARIANT;
}
DLL_EXPORT bool magic_match(std::istream& f) {
uint32_t magic = 0;
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
return magic == 0x67676d6d;
}
DLL_EXPORT LLModel *construct() {
return new MPT;
}
}

View File

@ -1,3 +1,6 @@
#ifndef MPT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#error This file is NOT meant to be included outside of mpt.cpp. Doing so is DANGEROUS. Be sure to know what you are doing before proceeding to #define MPT_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE
#endif
#ifndef MPT_H
#define MPT_H
@ -6,7 +9,7 @@
#include <vector>
#include "llmodel.h"
class MPTPrivate;
struct MPTPrivate;
class MPT : public LLModel {
public:
MPT();

View File

@ -218,7 +218,6 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
}
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const size_t actualVocabSize,
const int32_t * last_n_tokens_data,
int last_n_tokens_size,

View File

@ -79,7 +79,6 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
// TODO: not sure if this implementation is correct
//
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const size_t actualVocabSize,
const int32_t * last_n_tokens_data,
int last_n_tokens_size,

View File

@ -55,10 +55,10 @@ get_filename_component(Qt6_ROOT_DIR "${Qt6_ROOT_DIR}/.." ABSOLUTE)
message(STATUS "qmake binary: ${QMAKE_EXECUTABLE}")
message(STATUS "Qt 6 root directory: ${Qt6_ROOT_DIR}")
add_subdirectory(../gpt4all-backend llmodel)
set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
add_subdirectory(../gpt4all-backend llmodel)
qt_add_executable(chat
main.cpp
chat.h chat.cpp
@ -146,7 +146,6 @@ endif()
install(TARGETS chat DESTINATION bin COMPONENT ${COMPONENT_NAME_MAIN})
install(TARGETS llmodel DESTINATION lib COMPONENT ${COMPONENT_NAME_MAIN})
install(TARGETS llama DESTINATION lib COMPONENT ${COMPONENT_NAME_MAIN})
set(CPACK_GENERATOR "IFW")
set(CPACK_VERBATIM_VARIABLES YES)

View File

@ -2,9 +2,7 @@
#include "chat.h"
#include "download.h"
#include "network.h"
#include "../gpt4all-backend/gptj.h"
#include "../gpt4all-backend/llamamodel.h"
#include "../gpt4all-backend/mpt.h"
#include "../gpt4all-backend/llmodel.h"
#include "chatgpt.h"
#include <QCoreApplication>
@ -215,25 +213,15 @@ bool ChatLLM::loadModel(const QString &modelName)
model->setAPIKey(apiKey);
m_modelInfo.model = model;
} else {
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
fin.seekg(0);
fin.close();
const bool isGPTJ = magic == 0x67676d6c;
const bool isMPT = magic == 0x67676d6d;
if (isGPTJ) {
m_modelType = LLModelType::GPTJ_;
m_modelInfo.model = new GPTJ;
m_modelInfo.model->loadModel(filePath.toStdString());
} else if (isMPT) {
m_modelType = LLModelType::MPT_;
m_modelInfo.model = new MPT;
m_modelInfo.model->loadModel(filePath.toStdString());
} else {
m_modelType = LLModelType::LLAMA_;
m_modelInfo.model = new LLamaModel;
m_modelInfo.model = LLModel::construct(filePath.toStdString());
if (m_modelInfo.model) {
m_modelInfo.model->loadModel(filePath.toStdString());
switch (m_modelInfo.model->getModelType()[0]) {
case 'L': m_modelType = LLModelType::LLAMA_; break;
case 'G': m_modelType = LLModelType::GPTJ_; break;
case 'M': m_modelType = LLModelType::MPT_; break;
default: delete std::exchange(m_modelInfo.model, nullptr);
}
}
}
#if defined(DEBUG_MODEL_LOADING)

View File

@ -112,7 +112,7 @@ void Server::start()
);
m_server->route("/v1/completions", QHttpServerRequest::Method::Post,
[=](const QHttpServerRequest &request) {
[this](const QHttpServerRequest &request) {
if (!LLM::globalInstance()->serverEnabled())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
return handleCompletionRequest(request, false);
@ -120,7 +120,7 @@ void Server::start()
);
m_server->route("/v1/chat/completions", QHttpServerRequest::Method::Post,
[=](const QHttpServerRequest &request) {
[this](const QHttpServerRequest &request) {
if (!LLM::globalInstance()->serverEnabled())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
return handleCompletionRequest(request, true);