mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Provide an initial impl. of the C interface. NOTE: has not been tested.
This commit is contained in:
parent
386ce08fca
commit
0e9f85bcda
@ -1,2 +1,120 @@
|
||||
#include "llmodel_c.h"
|
||||
|
||||
#include "gptj.h"
|
||||
#include "llamamodel.h"
|
||||
|
||||
struct LLModelWrapper {
|
||||
LLModel *llModel = nullptr;
|
||||
LLModel::PromptContext promptContext;
|
||||
};
|
||||
|
||||
/**
 * Allocates a wrapper holding a new GPTJ model and returns it as an opaque handle.
 * The handle is released with llmodel_gptj_destroy.
 */
llmodel_model llmodel_gptj_create()
{
    auto *handle = new LLModelWrapper;
    handle->llModel = new GPTJ;
    return static_cast<llmodel_model>(handle);
}
|
||||
|
||||
/**
 * Releases a handle created by llmodel_gptj_create, deleting both the wrapped model and
 * the wrapper itself.
 *
 * @param gptj Handle to destroy. A NULL handle is accepted and ignored so that cleanup
 *             paths in C callers do not need their own guard.
 */
void llmodel_gptj_destroy(llmodel_model gptj)
{
    LLModelWrapper *wrapper = static_cast<LLModelWrapper*>(gptj);
    if (!wrapper)
        return;

    // NOTE(review): deleting through LLModel* assumes LLModel has a virtual destructor — confirm.
    delete wrapper->llModel;
    delete wrapper;
}
|
||||
|
||||
/**
 * Allocates a wrapper holding a new LLamaModel and returns it as an opaque handle.
 * The handle is released with llmodel_llama_destroy.
 */
llmodel_model llmodel_llama_create()
{
    auto *handle = new LLModelWrapper;
    handle->llModel = new LLamaModel;
    return static_cast<llmodel_model>(handle);
}
|
||||
|
||||
/**
 * Releases a handle created by llmodel_llama_create, deleting both the wrapped model and
 * the wrapper itself.
 *
 * @param llama Handle to destroy. A NULL handle is accepted and ignored so that cleanup
 *              paths in C callers do not need their own guard.
 */
void llmodel_llama_destroy(llmodel_model llama)
{
    LLModelWrapper *wrapper = static_cast<LLModelWrapper*>(llama);
    if (!wrapper)
        return;

    // NOTE(review): deleting through LLModel* assumes LLModel has a virtual destructor — confirm.
    delete wrapper->llModel;
    delete wrapper;
}
|
||||
|
||||
/**
 * Asks the wrapped model to load its weights from the given path.
 *
 * @param model      Handle returned by one of the llmodel_*_create functions.
 * @param model_path NUL-terminated path forwarded to LLModel::loadModel.
 * @return The result of LLModel::loadModel, or false when model_path is NULL.
 */
bool llmodel_loadModel(llmodel_model model, const char *model_path)
{
    // Reject a NULL path here instead of forwarding it to the C++ side.
    // NOTE(review): presumably loadModel takes a std::string, and building one from a
    // null pointer is undefined behavior — confirm against llmodel.h.
    if (!model_path)
        return false;

    LLModelWrapper *wrapper = static_cast<LLModelWrapper*>(model);
    return wrapper->llModel->loadModel(model_path);
}
|
||||
|
||||
/**
 * Reports the wrapped model's isModelLoaded() state.
 *
 * @param model Handle returned by one of the llmodel_*_create functions.
 * @return Whatever LLModel::isModelLoaded returns for the wrapped model.
 */
bool llmodel_isModelLoaded(llmodel_model model)
{
    auto *handle = static_cast<LLModelWrapper*>(model);
    return handle->llModel->isModelLoaded();
}
|
||||
|
||||
// Wrapper functions for the C callbacks
|
||||
bool response_wrapper(int32_t token_id, const std::string &response, void *user_data) {
|
||||
llmodel_response_callback callback = reinterpret_cast<llmodel_response_callback>(user_data);
|
||||
return callback(token_id, response.c_str());
|
||||
}
|
||||
|
||||
bool recalculate_wrapper(bool is_recalculating, void *user_data) {
|
||||
llmodel_recalculate_callback callback = reinterpret_cast<llmodel_recalculate_callback>(user_data);
|
||||
return callback(is_recalculating);
|
||||
}
|
||||
|
||||
/**
 * Runs a prompt through the wrapped model, translating between the C callbacks / C prompt
 * context and their C++ counterparts.
 *
 * @param model       Handle returned by one of the llmodel_*_create functions.
 * @param prompt      NUL-terminated prompt text forwarded to LLModel::prompt.
 * @param response    C callback handed to LLModel::prompt as the response callback; it
 *                    receives the token id and the response text as a C string.
 * @param recalculate C callback handed to LLModel::prompt as the recalculate callback.
 * @param ctx         In/out prompt context. Its scalar settings are copied into the wrapper's
 *                    C++ context before the call and copied back afterwards. On return,
 *                    logits/tokens point at vectors owned by the wrapper.
 */
void llmodel_prompt(llmodel_model model, const char *prompt,
                    llmodel_response_callback response,
                    llmodel_recalculate_callback recalculate,
                    llmodel_prompt_context *ctx)
{
    LLModelWrapper *wrapper = static_cast<LLModelWrapper*>(model);

    // Adapt the C function pointers to the std::function signatures the C++ API expects.
    // The pointers are captured with their real types; round-tripping a function pointer
    // through void* is only conditionally-supported in C++.
    std::function<bool(int32_t, const std::string&)> response_func =
        [response](int32_t token_id, const std::string &text) -> bool {
            return response(token_id, text.c_str());
        };
    std::function<bool(bool)> recalc_func =
        [recalculate](bool is_recalculating) -> bool {
            return recalculate(is_recalculating);
        };

    // Copy the C prompt context into the wrapper's C++ context
    wrapper->promptContext.n_past = ctx->n_past;
    wrapper->promptContext.n_ctx = ctx->n_ctx;
    wrapper->promptContext.n_predict = ctx->n_predict;
    wrapper->promptContext.top_k = ctx->top_k;
    wrapper->promptContext.top_p = ctx->top_p;
    wrapper->promptContext.temp = ctx->temp;
    wrapper->promptContext.n_batch = ctx->n_batch;
    wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
    wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
    wrapper->promptContext.contextErase = ctx->context_erase;

    // Call the C++ prompt method
    wrapper->llModel->prompt(prompt, response_func, recalc_func, wrapper->promptContext);

    // Expose the wrapper's std::vector storage to the C context through raw pointers,
    // which involves no copies. The pointers refer to memory owned by the wrapper.
    ctx->logits = wrapper->promptContext.logits.data();
    ctx->logits_size = wrapper->promptContext.logits.size();
    ctx->tokens = wrapper->promptContext.tokens.data();
    ctx->tokens_size = wrapper->promptContext.tokens.size();

    // Copy the rest of the C++ prompt context back to the C context
    ctx->n_past = wrapper->promptContext.n_past;
    ctx->n_ctx = wrapper->promptContext.n_ctx;
    ctx->n_predict = wrapper->promptContext.n_predict;
    ctx->top_k = wrapper->promptContext.top_k;
    ctx->top_p = wrapper->promptContext.top_p;
    ctx->temp = wrapper->promptContext.temp;
    ctx->n_batch = wrapper->promptContext.n_batch;
    ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
    ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
    ctx->context_erase = wrapper->promptContext.contextErase;
}
|
||||
|
||||
/**
 * Forwards the requested thread count to the wrapped model.
 *
 * @param model     Handle returned by one of the llmodel_*_create functions.
 * @param n_threads Value passed to LLModel::setThreadCount.
 */
void llmodel_setThreadCount(llmodel_model model, int32_t n_threads)
{
    auto *handle = static_cast<LLModelWrapper*>(model);
    handle->llModel->setThreadCount(n_threads);
}
|
||||
|
||||
/**
 * Queries the wrapped model's thread count.
 *
 * @param model Handle returned by one of the llmodel_*_create functions.
 * @return Whatever LLModel::threadCount returns for the wrapped model.
 */
int32_t llmodel_threadCount(llmodel_model model)
{
    auto *handle = static_cast<LLModelWrapper*>(model);
    return handle->llModel->threadCount();
}
|
||||
|
@ -2,6 +2,7 @@
|
||||
#define LLMODEL_C_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
@ -15,10 +16,15 @@ typedef void *llmodel_model;
|
||||
|
||||
/**
|
||||
* llmodel_prompt_context structure for holding the prompt context.
|
||||
* NOTE: The implementation takes care of all the memory handling of the raw logits pointer and the
|
||||
* raw tokens pointer. Attempting to resize them or modify them in any way can lead to undefined
|
||||
* behavior.
|
||||
*/
|
||||
typedef struct {
|
||||
float *logits; // logits of current context
|
||||
size_t logits_size; // the size of the raw logits vector
|
||||
int32_t *tokens; // current tokens in the context window
|
||||
size_t tokens_size; // the size of the raw tokens vector
|
||||
int32_t n_past; // number of tokens in past conversation
|
||||
int32_t n_ctx; // number of tokens possible in context window
|
||||
int32_t n_predict; // number of tokens to predict
|
||||
|
Loading…
Reference in New Issue
Block a user