From 102f68b18cdde795462366d6e1aaabff3a666d7a Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Tue, 25 Apr 2023 21:03:10 -0400 Subject: [PATCH] Fixup the api a bit. --- llmodel/llmodel_c.h | 56 ++++++++++++++++++++++----------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/llmodel/llmodel_c.h b/llmodel/llmodel_c.h index e68cf045..5979e335 100644 --- a/llmodel/llmodel_c.h +++ b/llmodel/llmodel_c.h @@ -11,9 +11,9 @@ extern "C" { /** * Opaque pointers to the underlying C++ classes. */ -typedef void *LLMODEL_C; -typedef void *GPTJ_C; -typedef void *LLAMA_C; +typedef void *llmodel_model; +typedef void *llmodel_gptj; +typedef void *llmodel_llama; /** * PromptContext_C structure for holding the prompt context. @@ -31,37 +31,46 @@ typedef struct { float repeat_penalty; // penalty factor for repeated tokens int32_t repeat_last_n; // last n tokens to penalize float contextErase; // percent of context to erase if we exceed the context window -} PromptContext_C; +} llmodel_prompt_context; /** - * Callback types for response and recalculation. + * Callback type for response. + * @param token_id The token id of the response. + * @param response The response string. + * @return a bool indicating whether the model should keep generating. */ -typedef bool (*ResponseCallback)(int32_t, const char *); -typedef bool (*RecalculateCallback)(bool); +typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response); + +/** + * Callback type for recalculation of context. + * @param is_recalculating whether the model is recalculating the context. + * @return a bool indicating whether the model should keep generating. + */ +typedef bool (*llmodel_recalculate_callback)(bool is_recalculating); /** * Create a GPTJ instance. * @return A pointer to the GPTJ instance. */ -GPTJ_C GPTJ_create(); +llmodel_gptj llmodel_gptj_create(); /** * Destroy a GPTJ instance. * @param gptj A pointer to the GPTJ instance. 
*/ -void GPTJ_destroy(GPTJ_C gptj); +void llmodel_gptj_destroy(llmodel_gptj gptj); /** * Create a LLAMA instance. * @return A pointer to the LLAMA instance. */ -LLAMA_C LLAMA_create(); +llmodel_llama llmodel_llama_create(); /** * Destroy a LLAMA instance. * @param llama A pointer to the LLAMA instance. */ -void LLAMA_destroy(LLAMA_C llama); +void llmodel_llama_destroy(llmodel_llama llama); /** * Load a model from a file. @@ -69,23 +78,14 @@ void LLAMA_destroy(LLAMA_C llama); * @param modelPath A string representing the path to the model file. * @return true if the model was loaded successfully, false otherwise. */ -bool LLMODEL_loadModel(LLMODEL_C model, const char *modelPath); - -/** - * Load a model from an input stream. - * @param model A pointer to the LLMODEL_C instance. - * @param modelPath A string representing the path to the model file. - * @param fin A pointer to the input stream. - * @return true if the model was loaded successfully, false otherwise. - */ -bool LLMODEL_loadModelStream(LLMODEL_C model, const char *modelPath, void *fin); +bool llmodel_loadModel(llmodel_model model, const char *model_path); /** * Check if a model is loaded. * @param model A pointer to the LLMODEL_C instance. * @return true if the model is loaded, false otherwise. */ -bool LLMODEL_isModelLoaded(LLMODEL_C model); +bool llmodel_isModelLoaded(llmodel_model model); /** * Generate a response using the model. @@ -95,24 +95,24 @@ bool LLMODEL_isModelLoaded(LLMODEL_C model); * @param recalculate A callback function for handling recalculation requests. * @param ctx A pointer to the PromptContext_C structure. */ -void LLMODEL_prompt(LLMODEL_C model, const char *prompt, - ResponseCallback response, - RecalculateCallback recalculate, - PromptContext_C *ctx); +void llmodel_prompt(llmodel_model model, const char *prompt, + llmodel_response_callback response, + llmodel_recalculate_callback recalculate, + llmodel_prompt_context *ctx); /** * Set the number of threads to be used by the model. 
* @param model A pointer to the LLMODEL_C instance. * @param n_threads The number of threads to be used. */ -void LLMODEL_setThreadCount(LLMODEL_C model, int32_t n_threads); +void llmodel_setThreadCount(llmodel_model model, int32_t n_threads); /** * Get the number of threads currently being used by the model. * @param model A pointer to the LLMODEL_C instance. * @return The number of threads currently being used. */ -int32_t LLMODEL_threadCount(LLMODEL_C model); +int32_t llmodel_threadCount(llmodel_model model); #ifdef __cplusplus }