#include "llmodel_c.h"

#include "llmodel.h"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <optional>
#include <string>
#include <vector>

struct LLModelWrapper {
    LLModel *llModel = nullptr;
    LLModel::PromptContext promptContext;
    ~LLModelWrapper() { delete llModel; }
};

llmodel_model llmodel_model_create(const char *model_path)
{
    const char *error;
    auto fres = llmodel_model_create2(model_path, "auto", &error);
    if (!fres) {
        fprintf(stderr, "Unable to instantiate model: %s\n", error);
    }
    return fres;
}

static void llmodel_set_error(const char **errptr, const char *message)
{
    // Store the message in a thread-local buffer so the pointer written to *errptr
    // stays valid after this call returns.
    thread_local static std::string last_error_message;
    if (errptr) {
        last_error_message = message;
        *errptr = last_error_message.c_str();
    }
}

llmodel_model llmodel_model_create2(const char *model_path, const char *backend, const char **error)
{
    LLModel *llModel;
    try {
        llModel = LLModel::Implementation::construct(model_path, backend);
    } catch (const std::exception& e) {
        llmodel_set_error(error, e.what());
        return nullptr;
    }

    auto wrapper = new LLModelWrapper;
    wrapper->llModel = llModel;
    return wrapper;
}

void llmodel_model_destroy(llmodel_model model)
{
    delete static_cast<LLModelWrapper *>(model);
}

size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_ctx, int ngl)
{
    auto *wrapper = static_cast<LLModelWrapper *>(model);
    return wrapper->llModel->requiredMem(model_path, n_ctx, ngl);
}

bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx, int ngl)
{
    auto *wrapper = static_cast<LLModelWrapper *>(model);

    std::string modelPath(model_path);
    if (wrapper->llModel->isModelBlacklisted(modelPath)) {
        size_t slash = modelPath.find_last_of("/\\");
        auto basename = slash == std::string::npos ? modelPath : modelPath.substr(slash + 1);
        std::cerr << "warning: model '" << basename << "' is out-of-date, please check for an updated version\n";
    }
    return wrapper->llModel->loadModel(modelPath, n_ctx, ngl);
}
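
// Illustrative usage sketch (comment only, not part of this translation unit):
// a binding might create, load, and later destroy a model roughly as follows.
// The "model.gguf" path and the n_ctx/ngl values are placeholders, not values
// defined anywhere in this file.
//
//     const char *error = nullptr;
//     llmodel_model model = llmodel_model_create2("model.gguf", "auto", &error);
//     if (!model) {
//         fprintf(stderr, "create failed: %s\n", error);
//     } else if (!llmodel_loadModel(model, "model.gguf", /*n_ctx*/ 2048, /*ngl*/ 100)) {
//         llmodel_model_destroy(model);
//     }
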
bool llmodel_isModelLoaded(llmodel_model model)
{
    auto *wrapper = static_cast<LLModelWrapper *>(model);
    return wrapper->llModel->isModelLoaded();
}

uint64_t llmodel_get_state_size(llmodel_model model)
{
    auto *wrapper = static_cast<LLModelWrapper *>(model);
    return wrapper->llModel->stateSize();
}

uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest)
{
    auto *wrapper = static_cast<LLModelWrapper *>(model);
    return wrapper->llModel->saveState(dest);
}

uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
{
    auto *wrapper = static_cast<LLModelWrapper *>(model);
    return wrapper->llModel->restoreState(src);
}

void llmodel_prompt(llmodel_model model, const char *prompt,
                    const char *prompt_template,
                    llmodel_prompt_callback prompt_callback,
                    llmodel_response_callback response_callback,
                    llmodel_recalculate_callback recalculate_callback,
                    llmodel_prompt_context *ctx,
                    bool special,
                    const char *fake_reply)
{
    auto *wrapper = static_cast<LLModelWrapper *>(model);

    // Adapt the C response callback to the std::string-based C++ interface
    auto response_func = [response_callback](int32_t token_id, const std::string &response) {
        return response_callback(token_id, response.c_str());
    };

    if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size())
        wrapper->promptContext.tokens.resize(ctx->n_past);

    // Copy the C prompt context
    wrapper->promptContext.n_past = ctx->n_past;
    wrapper->promptContext.n_ctx = ctx->n_ctx;
    wrapper->promptContext.n_predict = ctx->n_predict;
    wrapper->promptContext.top_k = ctx->top_k;
    wrapper->promptContext.top_p = ctx->top_p;
    wrapper->promptContext.min_p = ctx->min_p;
    wrapper->promptContext.temp = ctx->temp;
    wrapper->promptContext.n_batch = ctx->n_batch;
    wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
    wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
    wrapper->promptContext.contextErase = ctx->context_erase;

    std::string fake_reply_str;
    if (fake_reply) { fake_reply_str = fake_reply; }
    auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;

    // Call the C++ prompt method
    wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, recalculate_callback,
                             wrapper->promptContext, special, fake_reply_p);

    // Update the C context by giving access to the wrapper's raw pointers to std::vector data,
    // which involves no copies
    ctx->tokens = wrapper->promptContext.tokens.data();
    ctx->tokens_size = wrapper->promptContext.tokens.size();

    // Update the rest of the C prompt context
    ctx->n_past = wrapper->promptContext.n_past;
    ctx->n_ctx = wrapper->promptContext.n_ctx;
    ctx->n_predict = wrapper->promptContext.n_predict;
    ctx->top_k = wrapper->promptContext.top_k;
    ctx->top_p = wrapper->promptContext.top_p;
    ctx->min_p = wrapper->promptContext.min_p;
    ctx->temp = wrapper->promptContext.temp;
    ctx->n_batch = wrapper->promptContext.n_batch;
    ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
    ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
    ctx->context_erase = wrapper->promptContext.contextErase;
}

float *llmodel_embed(
    llmodel_model model, const char **texts, size_t *embedding_size, const char *prefix, int dimensionality,
    size_t *token_count, bool do_mean, bool atlas, llmodel_emb_cancel_callback cancel_cb, const char **error
) {
    auto *wrapper = static_cast<LLModelWrapper *>(model);

    if (!texts || !*texts) {
        llmodel_set_error(error, "'texts' is NULL or empty");
        return nullptr;
    }

    std::vector<std::string> textsVec;
    while (*texts) { textsVec.emplace_back(*texts++); }

    size_t embd_size;
    float *embedding;

    try {
        embd_size = wrapper->llModel->embeddingSize();
        if (dimensionality > 0 && dimensionality < int(embd_size))
            embd_size = dimensionality;

        embd_size *= textsVec.size();

        std::optional<std::string> prefixStr;
        if (prefix) { prefixStr = prefix; }

        embedding = new float[embd_size];
        wrapper->llModel->embed(textsVec, embedding, prefixStr, dimensionality, token_count, do_mean, atlas, cancel_cb);
    } catch (std::exception const &e) {
        llmodel_set_error(error, e.what());
        return nullptr;
    }

    *embedding_size = embd_size;
    return embedding;
}

void llmodel_free_embedding(float *ptr)
{
    delete[] ptr;
}
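
// Illustrative sketch (comment only) of calling the embedding API above; the
// prefix, dimensionality, and flag values are placeholder assumptions. The
// returned buffer belongs to the caller and must be released with
// llmodel_free_embedding().
//
//     const char *texts[] = { "first document", "second document", nullptr };
//     size_t embd_size = 0, token_count = 0;
//     const char *error = nullptr;
//     float *embd = llmodel_embed(model, texts, &embd_size, /*prefix*/ nullptr,
//                                 /*dimensionality*/ -1, &token_count,
//                                 /*do_mean*/ true, /*atlas*/ false,
//                                 /*cancel_cb*/ nullptr, &error);
//     if (embd) {
//         // embd holds embd_size floats (all texts concatenated)
//         llmodel_free_embedding(embd);
//     }
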
void llmodel_setThreadCount(llmodel_model model, int32_t n_threads)
{
    auto *wrapper = static_cast<LLModelWrapper *>(model);
    wrapper->llModel->setThreadCount(n_threads);
}

int32_t llmodel_threadCount(llmodel_model model)
{
    auto *wrapper = static_cast<LLModelWrapper *>(model);
    return wrapper->llModel->threadCount();
}

void llmodel_set_implementation_search_path(const char *path)
{
    LLModel::Implementation::setImplementationsSearchPath(path);
}

const char *llmodel_get_implementation_search_path()
{
    return LLModel::Implementation::implementationsSearchPath().c_str();
}

// RAII wrapper around a C-style struct
struct llmodel_gpu_device_cpp: llmodel_gpu_device {
    llmodel_gpu_device_cpp() = default;

    llmodel_gpu_device_cpp(const llmodel_gpu_device_cpp &) = delete;
    llmodel_gpu_device_cpp(      llmodel_gpu_device_cpp &&) = delete;

    const llmodel_gpu_device_cpp &operator=(const llmodel_gpu_device_cpp &) = delete;
          llmodel_gpu_device_cpp &operator=(      llmodel_gpu_device_cpp &&) = delete;

    ~llmodel_gpu_device_cpp() {
        free(const_cast<char *>(name));
        free(const_cast<char *>(vendor));
    }
};

static_assert(sizeof(llmodel_gpu_device_cpp) == sizeof(llmodel_gpu_device));

struct llmodel_gpu_device *llmodel_available_gpu_devices(size_t memoryRequired, int *num_devices)
{
    // The returned array lives in this thread_local pointer, so it stays valid
    // until the next call to this function on the same thread.
    static thread_local std::unique_ptr<llmodel_gpu_device_cpp[]> c_devices;

    auto devices = LLModel::Implementation::availableGPUDevices(memoryRequired);
    *num_devices = devices.size();

    if (devices.empty()) { return nullptr; /* no devices */ }

    c_devices = std::make_unique<llmodel_gpu_device_cpp[]>(devices.size());
    for (unsigned i = 0; i < devices.size(); i++) {
        const auto &dev = devices[i];
        auto &cdev = c_devices[i];
        cdev.backend = dev.backend;
        cdev.index = dev.index;
        cdev.type = dev.type;
        cdev.heapSize = dev.heapSize;
        cdev.name = strdup(dev.name.c_str());
        cdev.vendor = strdup(dev.vendor.c_str());
    }

    return c_devices.get();
}
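
// Illustrative sketch (comment only): enumerating GPU devices and initializing
// one on a previously created model. The 8 GiB memory requirement and the use
// of the first device are arbitrary placeholders.
//
//     int n_devices = 0;
//     llmodel_gpu_device *devs = llmodel_available_gpu_devices(8ull << 30, &n_devices);
//     if (devs && n_devices > 0) {
//         // devs[i].name / devs[i].vendor / devs[i].heapSize describe each device
//         llmodel_gpu_init_gpu_device_by_struct(model, &devs[0]);
//     }
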
bool llmodel_gpu_init_gpu_device_by_string(llmodel_model model, size_t memoryRequired, const char *device)
{
    auto *wrapper = static_cast<LLModelWrapper *>(model);
    return wrapper->llModel->initializeGPUDevice(memoryRequired, std::string(device));
}

bool llmodel_gpu_init_gpu_device_by_struct(llmodel_model model, const llmodel_gpu_device *device)
{
    auto *wrapper = static_cast<LLModelWrapper *>(model);
    return wrapper->llModel->initializeGPUDevice(device->index);
}

bool llmodel_gpu_init_gpu_device_by_int(llmodel_model model, int device)
{
    auto *wrapper = static_cast<LLModelWrapper *>(model);
    return wrapper->llModel->initializeGPUDevice(device);
}

const char *llmodel_model_backend_name(llmodel_model model)
{
    const auto *wrapper = static_cast<LLModelWrapper *>(model);
    return wrapper->llModel->backendName();
}

const char *llmodel_model_gpu_device_name(llmodel_model model)
{
    const auto *wrapper = static_cast<LLModelWrapper *>(model);
    return wrapper->llModel->gpuDeviceName();
}