typescript: async generator and token stream (#1897)
Signed-off-by: Tare Ebelo <75279482+TareHimself@users.noreply.github.com>
Signed-off-by: jacob <jacoobes@sern.dev>
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: jacob <jacoobes@sern.dev>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
Parent: ef518fae3e
Commit: a153cc5b25
@@ -611,6 +611,7 @@ jobs:
$Env:Path += ";$MinGwBin"
$Env:Path += ";C:\Program Files\CMake\bin"
$Env:Path += ";C:\VulkanSDK\1.3.261.1\bin"
$Env:VULKAN_SDK = "C:\VulkanSDK\1.3.261.1"
cd gpt4all-backend
mkdir runtimes/win-x64
cd runtimes/win-x64

@@ -651,6 +652,7 @@ jobs:
command: |
$Env:Path += ";C:\Program Files\CMake\bin"
$Env:Path += ";C:\VulkanSDK\1.3.261.1\bin"
$Env:VULKAN_SDK = "C:\VulkanSDK\1.3.261.1"
cd gpt4all-backend
mkdir runtimes/win-x64_msvc
cd runtimes/win-x64_msvc

@@ -1107,8 +1109,12 @@ workflows:
  jobs:
    - hold:
        type: approval
    - csharp-hold:
        type: approval
    - nuget-hold:
        type: approval
+    - nodejs-hold:
+        type: approval
+    - npm-hold:
+        type: approval
    - build-bindings-backend-linux:

@@ -1151,21 +1157,21 @@ workflows:
          branches:
            only:
        requires:
          - npm-hold
          - nodejs-hold
          - build-bindings-backend-linux
    - build-nodejs-windows:
        filters:
          branches:
            only:
        requires:
          - npm-hold
          - nodejs-hold
          - build-bindings-backend-windows-msvc
    - build-nodejs-macos:
        filters:
          branches:
            only:
        requires:
          - npm-hold
          - nodejs-hold
          - build-bindings-backend-macos

@@ -1175,21 +1181,21 @@ workflows:
          branches:
            only:
        requires:
          - nuget-hold
          - csharp-hold
          - build-bindings-backend-linux
    - build-csharp-windows:
        filters:
          branches:
            only:
        requires:
          - nuget-hold
          - csharp-hold
          - build-bindings-backend-windows
    - build-csharp-macos:
        filters:
          branches:
            only:
        requires:
          - nuget-hold
          - csharp-hold
          - build-bindings-backend-macos
    - store-and-upload-nupkgs:
        filters:

@@ -159,6 +159,7 @@ This package is in active development, and breaking changes may happen until the
* [mpt](#mpt)
* [replit](#replit)
* [type](#type)
+* [TokenCallback](#tokencallback)
* [InferenceModel](#inferencemodel)
* [dispose](#dispose)
* [EmbeddingModel](#embeddingmodel)

@@ -184,16 +185,17 @@ This package is in active development, and breaking changes may happen until the
* [Parameters](#parameters-5)
* [hasGpuDevice](#hasgpudevice)
* [listGpu](#listgpu)
* [Parameters](#parameters-6)
* [dispose](#dispose-2)
* [GpuDevice](#gpudevice)
* [type](#type-2)
* [LoadModelOptions](#loadmodeloptions)
* [loadModel](#loadmodel)
* [Parameters](#parameters-6)
* [createCompletion](#createcompletion)
* [Parameters](#parameters-7)
* [createEmbedding](#createembedding)
* [createCompletion](#createcompletion)
* [Parameters](#parameters-8)
* [createEmbedding](#createembedding)
* [Parameters](#parameters-9)
* [CompletionOptions](#completionoptions)
* [verbose](#verbose)
* [systemPromptTemplate](#systemprompttemplate)

@@ -225,15 +227,15 @@ This package is in active development, and breaking changes may happen until the
* [repeatPenalty](#repeatpenalty)
* [repeatLastN](#repeatlastn)
* [contextErase](#contexterase)
-* [createTokenStream](#createtokenstream)
-  * [Parameters](#parameters-9)
+* [generateTokens](#generatetokens)
+  * [Parameters](#parameters-10)
* [DEFAULT\_DIRECTORY](#default_directory)
* [DEFAULT\_LIBRARIES\_DIRECTORY](#default_libraries_directory)
* [DEFAULT\_MODEL\_CONFIG](#default_model_config)
* [DEFAULT\_PROMPT\_CONTEXT](#default_prompt_context)
* [DEFAULT\_MODEL\_LIST\_URL](#default_model_list_url)
* [downloadModel](#downloadmodel)
-  * [Parameters](#parameters-10)
+  * [Parameters](#parameters-11)
  * [Examples](#examples)
* [DownloadModelOptions](#downloadmodeloptions)
  * [modelPath](#modelpath)

@@ -279,6 +281,12 @@ Model architecture. This argument currently does not have any functionality and

Type: ModelType

+#### TokenCallback
+
+Callback for controlling token generation
+
+Type: function (tokenId: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), token: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String), total: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)): [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
+
#### InferenceModel

InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers.

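As a brief aside (not part of the commit), the callback contract documented above can be used to cut generation short. A minimal sketch, assuming a model file is available locally (the file name is the one used in spec/generator.mjs) and the `generate` signature introduced in this change; the 500-character cap is arbitrary:

```js
import gpt from 'gpt4all'

const model = await gpt.loadModel('mistral-7b-openorca.Q4_0.gguf', { device: 'cpu' })

// The callback receives the token id, the token text, and the accumulated response so far.
// Returning false tells the binding to stop generating further tokens.
const cappedCallback = (tokenId, token, total) => {
    process.stdout.write(token)
    return total.length < 500 // arbitrary cap for this sketch
}

// generate() takes a raw (already formatted) prompt string plus an optional prompt context.
const response = await model.generate('Tell me a short story.', { nPredict: 1024 }, cappedCallback)
console.log('\n---\n' + response)

model.dispose()
```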
@@ -362,9 +370,9 @@ Use the prompt function exported for a value

* `q` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input.
* `params` **Partial<[LLModelPromptContext](#llmodelpromptcontext)>** Optional parameters for the prompt context.
-* `callback` **function (res: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)): void**
+* `callback` **[TokenCallback](#tokencallback)?** optional callback to control token generation.

-Returns **void** The result of the model prompt.
+Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The result of the model prompt.

##### embed

@@ -424,6 +432,12 @@ Returns **[boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/

GPUs that are usable for this LLModel

###### Parameters

* `nCtx` **[number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)** Maximum size of context window

<!---->

* Throws **any** if hasGpuDevice returns false (i think)

Returns **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[GpuDevice](#gpudevice)>**

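For illustration only (not part of the diff), listing devices with the new `nCtx` argument. This sketch assumes the underlying native LLModel is reachable as `model.llm` on a loaded model, the access path used elsewhere in this changeset (`model.llm.name()`):

```js
import gpt from 'gpt4all'

const model = await gpt.loadModel('mistral-7b-openorca.Q4_0.gguf', { device: 'gpu' })

// Ask the native LLModel which GPUs could host this model with a 2048-token context window.
// Per the note above, this may throw when no usable GPU device is present.
const devices = model.llm.listGpu(2048)
console.log(devices)

model.dispose()
```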
@@ -690,17 +704,18 @@ The percentage of context to erase if the context window is exceeded.

Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)

-#### createTokenStream
+#### generateTokens

-TODO: Help wanted to implement this
+Creates an async generator of tokens

##### Parameters

-* `llmodel` **[LLModel](#llmodel)**
-* `messages` **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[PromptMessage](#promptmessage)>**
-* `options` **[CompletionOptions](#completionoptions)**
+* `llmodel` **[InferenceModel](#inferencemodel)** The language model object.
+* `messages` **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[PromptMessage](#promptmessage)>** The array of messages for the conversation.
+* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion.
+* `callback` **[TokenCallback](#tokencallback)** optional callback to control token generation.

-Returns **function (ll: [LLModel](#llmodel)): AsyncGenerator<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>**
+Returns **AsyncGenerator<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The stream of generated tokens

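The new spec file added later in this commit (spec/generator.mjs) exercises this generator; a condensed sketch of the same usage, assuming the published `gpt4all` package and the model file used in that spec:

```js
import gpt from 'gpt4all'

const model = await gpt.loadModel('mistral-7b-openorca.Q4_0.gguf', { device: 'gpu' })

// generateTokens runs the prompt off the main thread and yields each token as it arrives.
const tokens = gpt.generateTokens(model, [
    { role: 'user', content: 'How are you?' },
], { nPredict: 2048 })

for await (const token of tokens) {
    process.stdout.write(token)
}

model.dispose()
```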
#### DEFAULT\_DIRECTORY
@@ -136,14 +136,18 @@ yarn test

This package is in active development, and breaking changes may happen until the api stabilizes. Here's what's the todo list:

* \[ ] Purely offline. Per the gui, which can be run completely offline, the bindings should be as well.
* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help)
  * Should include prebuilds to avoid painful node-gyp errors
* \[ ] createChatSession ( the python equivalent to create\_chat\_session )
* \[x] generateTokens, the new name for createTokenStream. As of 3.2.0, this is released but not 100% tested. Check spec/generator.mjs!
* \[x] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete
* \[x] prompt models via a threadsafe function in order to have proper non blocking behavior in nodejs
* \[ ] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete
* \[x] generateTokens is the new name for this^
* \[x] proper unit testing (integrate with circle ci)
* \[x] publish to npm under alpha tag `gpt4all@alpha`
* \[x] have more people test on other platforms (mac tester needed)
* \[x] switch to new pluggable backend
* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help)
  * Should include prebuilds to avoid painful node-gyp errors
* \[ ] createChatSession ( the python equivalent to create\_chat\_session )


### API Reference

@@ -3,9 +3,9 @@

Napi::Function NodeModelWrapper::GetClass(Napi::Env env) {
    Napi::Function self = DefineClass(env, "LLModel", {
-        InstanceMethod("type", &NodeModelWrapper::getType),
+        InstanceMethod("type", &NodeModelWrapper::GetType),
        InstanceMethod("isModelLoaded", &NodeModelWrapper::IsModelLoaded),
-        InstanceMethod("name", &NodeModelWrapper::getName),
+        InstanceMethod("name", &NodeModelWrapper::GetName),
        InstanceMethod("stateSize", &NodeModelWrapper::StateSize),
        InstanceMethod("raw_prompt", &NodeModelWrapper::Prompt),
        InstanceMethod("setThreadCount", &NodeModelWrapper::SetThreadCount),

@@ -28,14 +28,14 @@ Napi::Function NodeModelWrapper::GetClass(Napi::Env env) {
Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
{
    auto env = info.Env();
-    return Napi::Number::New(env, static_cast<uint32_t>( llmodel_required_mem(GetInference(), full_model_path.c_str(), 2048, 100) ));
+    return Napi::Number::New(env, static_cast<uint32_t>(llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers) ));

}
Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo& info)
{
    auto env = info.Env();
    int num_devices = 0;
-    auto mem_size = llmodel_required_mem(GetInference(), full_model_path.c_str());
+    auto mem_size = llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers);
    llmodel_gpu_device* all_devices = llmodel_available_gpu_devices(GetInference(), mem_size, &num_devices);
    if(all_devices == nullptr) {
        Napi::Error::New(

@@ -70,7 +70,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
    return js_array;
}

-Napi::Value NodeModelWrapper::getType(const Napi::CallbackInfo& info)
+Napi::Value NodeModelWrapper::GetType(const Napi::CallbackInfo& info)
{
    if(type.empty()) {
        return info.Env().Undefined();

@@ -132,6 +132,9 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
            library_path = ".";
        }
        device = config_object.Get("device").As<Napi::String>();
+
+        nCtx = config_object.Get("nCtx").As<Napi::Number>().Int32Value();
+        nGpuLayers = config_object.Get("ngl").As<Napi::Number>().Int32Value();
    }
    llmodel_set_implementation_search_path(library_path.c_str());
    const char* e;

@@ -148,20 +151,17 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
        return;
    }
    if(device != "cpu") {
-        size_t mem = llmodel_required_mem(GetInference(), full_weight_path.c_str());
        std::cout << "Initiating GPU\n";
+        size_t mem = llmodel_required_mem(GetInference(), full_weight_path.c_str(),nCtx, nGpuLayers);

        auto success = llmodel_gpu_init_gpu_device_by_string(GetInference(), mem, device.c_str());
-        if(success) {
-            std::cout << "GPU init successfully\n";
-        } else {
+        if(!success) {
            //https://github.com/nomic-ai/gpt4all/blob/3acbef14b7c2436fe033cae9036e695d77461a16/gpt4all-bindings/python/gpt4all/pyllmodel.py#L215
            //Haven't implemented this but it is still open to contribution
            std::cout << "WARNING: Failed to init GPU\n";
        }
    }

-    auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), 2048, 100);
+    auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), nCtx, nGpuLayers);
    if(!success) {
        Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException();
        return;

@@ -254,6 +254,9 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
        .repeat_last_n = 10,
        .context_erase = 0.5
    };
+
+    PromptWorkerConfig promptWorkerConfig;
+
    if(info[1].IsObject())
    {
        auto inputObject = info[1].As<Napi::Object>();

@@ -285,29 +288,33 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
        if(inputObject.Has("context_erase"))
            promptContext.context_erase = inputObject.Get("context_erase").As<Napi::Number>().FloatValue();
    }
-    //copy to protect llmodel resources when splitting to new thread
-    llmodel_prompt_context copiedPrompt = promptContext;
+    else
+    {
+        Napi::Error::New(info.Env(), "Missing Prompt Options").ThrowAsJavaScriptException();
+        return info.Env().Undefined();
+    }

-    std::string copiedQuestion = question;
-    PromptWorkContext pc = {
-        copiedQuestion,
-        inference_,
-        copiedPrompt,
-        ""
-    };
-    auto threadSafeContext = new TsfnContext(env, pc);
-    threadSafeContext->tsfn = Napi::ThreadSafeFunction::New(
-        env, // Environment
-        info[2].As<Napi::Function>(), // JS function from caller
-        "PromptCallback", // Resource name
-        0, // Max queue size (0 = unlimited).
-        1, // Initial thread count
-        threadSafeContext, // Context,
-        FinalizerCallback, // Finalizer
-        (void*)nullptr // Finalizer data
-    );
-    threadSafeContext->nativeThread = std::thread(threadEntry, threadSafeContext);
-    return threadSafeContext->deferred_.Promise();
+    if(info.Length() >= 3 && info[2].IsFunction()){
+        promptWorkerConfig.bHasTokenCallback = true;
+        promptWorkerConfig.tokenCallback = info[2].As<Napi::Function>();
+    }
+
+    //copy to protect llmodel resources when splitting to new thread
+    // llmodel_prompt_context copiedPrompt = promptContext;
+    promptWorkerConfig.context = promptContext;
+    promptWorkerConfig.model = GetInference();
+    promptWorkerConfig.mutex = &inference_mutex;
+    promptWorkerConfig.prompt = question;
+    promptWorkerConfig.result = "";
+
+    auto worker = new PromptWorker(env, promptWorkerConfig);
+
+    worker->Queue();
+
+    return worker->GetPromise();
}
void NodeModelWrapper::Dispose(const Napi::CallbackInfo& info) {
    llmodel_model_destroy(inference_);

@@ -321,7 +328,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
    }
}

-Napi::Value NodeModelWrapper::getName(const Napi::CallbackInfo& info) {
+Napi::Value NodeModelWrapper::GetName(const Napi::CallbackInfo& info) {
    return Napi::String::New(info.Env(), name);
}
Napi::Value NodeModelWrapper::ThreadCount(const Napi::CallbackInfo& info) {

@@ -7,14 +7,17 @@
#include <memory>
#include <filesystem>
#include <set>
+#include <mutex>

namespace fs = std::filesystem;


class NodeModelWrapper: public Napi::ObjectWrap<NodeModelWrapper> {

public:
    NodeModelWrapper(const Napi::CallbackInfo &);
    //virtual ~NodeModelWrapper();
-    Napi::Value getType(const Napi::CallbackInfo& info);
+    Napi::Value GetType(const Napi::CallbackInfo& info);
    Napi::Value IsModelLoaded(const Napi::CallbackInfo& info);
    Napi::Value StateSize(const Napi::CallbackInfo& info);
    //void Finalize(Napi::Env env) override;

@@ -25,7 +28,7 @@ public:
    Napi::Value Prompt(const Napi::CallbackInfo& info);
    void SetThreadCount(const Napi::CallbackInfo& info);
    void Dispose(const Napi::CallbackInfo& info);
-    Napi::Value getName(const Napi::CallbackInfo& info);
+    Napi::Value GetName(const Napi::CallbackInfo& info);
    Napi::Value ThreadCount(const Napi::CallbackInfo& info);
    Napi::Value GenerateEmbedding(const Napi::CallbackInfo& info);
    Napi::Value HasGpuDevice(const Napi::CallbackInfo& info);

@@ -48,8 +51,12 @@ private:
    */
    llmodel_model inference_;

+    std::mutex inference_mutex;
+
    std::string type;
    // corresponds to LLModel::name() in typescript
    std::string name;
+    int nCtx{};
+    int nGpuLayers{};
    std::string full_model_path;
};

@@ -1,6 +1,6 @@
{
  "name": "gpt4all",
-  "version": "3.1.0",
+  "version": "3.2.0",
  "packageManager": "yarn@3.6.1",
  "main": "src/gpt4all.js",
  "repository": "nomic-ai/gpt4all",

@@ -1,60 +1,146 @@
#include "prompt.h"
#include <future>


TsfnContext::TsfnContext(Napi::Env env, const PromptWorkContext& pc)
    : deferred_(Napi::Promise::Deferred::New(env)), pc(pc) {
PromptWorker::PromptWorker(Napi::Env env, PromptWorkerConfig config)
    : promise(Napi::Promise::Deferred::New(env)), _config(config), AsyncWorker(env) {
    if(_config.bHasTokenCallback){
        _tsfn = Napi::ThreadSafeFunction::New(config.tokenCallback.Env(),config.tokenCallback,"PromptWorker",0,1,this);
    }
namespace {
    static std::string *res;
}

bool response_callback(int32_t token_id, const char *response) {
    *res += response;
    return token_id != -1;
PromptWorker::~PromptWorker()
{
    if(_config.bHasTokenCallback){
        _tsfn.Release();
    }
bool recalculate_callback (bool isrecalculating) {
    return isrecalculating;
};
bool prompt_callback (int32_t tid) {
}

void PromptWorker::Execute()
{
    _config.mutex->lock();

    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper *>(_config.model);

    auto ctx = &_config.context;

    if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size())
        wrapper->promptContext.tokens.resize(ctx->n_past);

    // Copy the C prompt context
    wrapper->promptContext.n_past = ctx->n_past;
    wrapper->promptContext.n_ctx = ctx->n_ctx;
    wrapper->promptContext.n_predict = ctx->n_predict;
    wrapper->promptContext.top_k = ctx->top_k;
    wrapper->promptContext.top_p = ctx->top_p;
    wrapper->promptContext.temp = ctx->temp;
    wrapper->promptContext.n_batch = ctx->n_batch;
    wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
    wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
    wrapper->promptContext.contextErase = ctx->context_erase;

    // Napi::Error::Fatal(
    //     "SUPRA",
    //     "About to prompt");
    // Call the C++ prompt method
    wrapper->llModel->prompt(
        _config.prompt,
        [](int32_t tid) { return true; },
        [this](int32_t token_id, const std::string tok)
        {
            return ResponseCallback(token_id, tok);
        },
        [](bool isRecalculating)
        {
            return isRecalculating;
        },
        wrapper->promptContext);

    // Update the C context by giving access to the wrappers raw pointers to std::vector data
    // which involves no copies
    ctx->logits = wrapper->promptContext.logits.data();
    ctx->logits_size = wrapper->promptContext.logits.size();
    ctx->tokens = wrapper->promptContext.tokens.data();
    ctx->tokens_size = wrapper->promptContext.tokens.size();

    // Update the rest of the C prompt context
    ctx->n_past = wrapper->promptContext.n_past;
    ctx->n_ctx = wrapper->promptContext.n_ctx;
    ctx->n_predict = wrapper->promptContext.n_predict;
    ctx->top_k = wrapper->promptContext.top_k;
    ctx->top_p = wrapper->promptContext.top_p;
    ctx->temp = wrapper->promptContext.temp;
    ctx->n_batch = wrapper->promptContext.n_batch;
    ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
    ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
    ctx->context_erase = wrapper->promptContext.contextErase;

    _config.mutex->unlock();
}

void PromptWorker::OnOK()
{
    promise.Resolve(Napi::String::New(Env(), result));
}

void PromptWorker::OnError(const Napi::Error &e)
{
    promise.Reject(e.Value());
}

Napi::Promise PromptWorker::GetPromise()
{
    return promise.Promise();
}

bool PromptWorker::ResponseCallback(int32_t token_id, const std::string token)
{
    if (token_id == -1)
    {
        return false;
    }

    if(!_config.bHasTokenCallback){
        return true;
    };
}

// The thread entry point. This takes as its arguments the specific
// threadsafe-function context created inside the main thread.
void threadEntry(TsfnContext* context) {
    static std::mutex mtx;
    std::lock_guard<std::mutex> lock(mtx);
    res = &context->pc.res;
    // Perform a call into JavaScript.
    napi_status status =
        context->tsfn.BlockingCall(&context->pc,
        [](Napi::Env env, Napi::Function jsCallback, PromptWorkContext* pc) {
            llmodel_prompt(
                pc->inference_,
                pc->question.c_str(),
                &prompt_callback,
                &response_callback,
                &recalculate_callback,
                &pc->prompt_params
            );
    result += token;

    std::promise<bool> promise;

    auto info = new TokenCallbackInfo();
    info->tokenId = token_id;
    info->token = token;
    info->total = result;

    auto future = promise.get_future();

    auto status = _tsfn.BlockingCall(info, [&promise](Napi::Env env, Napi::Function jsCallback, TokenCallbackInfo *value)
    {
        // Transform native data into JS data, passing it to the provided
        // `jsCallback` -- the TSFN's JavaScript function.
        auto token_id = Napi::Number::New(env, value->tokenId);
        auto token = Napi::String::New(env, value->token);
        auto total = Napi::String::New(env,value->total);
        auto jsResult = jsCallback.Call({ token_id, token, total}).ToBoolean();
        promise.set_value(jsResult);
        // We're finished with the data.
        delete value;
    });

    if (status != napi_ok) {
        Napi::Error::Fatal(
            "ThreadEntry",
            "PromptWorkerResponseCallback",
            "Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
    }
    // Release the thread-safe function. This decrements the internal thread
    // count, and will perform finalization since the count will reach 0.
    context->tsfn.Release();

    return future.get();
}

void FinalizerCallback(Napi::Env env,
    void* finalizeData,
    TsfnContext* context) {
    // Resolve the Promise previously returned to JS
    context->deferred_.Resolve(Napi::String::New(env, context->pc.res));
    // Wait for the thread to finish executing before proceeding.
    context->nativeThread.join();
    delete context;
bool PromptWorker::RecalculateCallback(bool isRecalculating)
{
    return isRecalculating;
}

bool PromptWorker::PromptCallback(int32_t tid)
{
    return true;
}

@@ -1,44 +1,59 @@
#ifndef TSFN_CONTEXT_H
#define TSFN_CONTEXT_H
#ifndef PREDICT_WORKER_H
#define PREDICT_WORKER_H

#include "napi.h"
#include "llmodel_c.h"
#include "llmodel.h"
#include <thread>
#include <mutex>
#include <iostream>
#include <atomic>
#include <memory>
struct PromptWorkContext {
    std::string question;
    llmodel_model inference_;
    llmodel_prompt_context prompt_params;
    std::string res;

struct TokenCallbackInfo
{
    int32_t tokenId;
    std::string total;
    std::string token;
};

struct TsfnContext {
struct LLModelWrapper
{
    LLModel *llModel = nullptr;
    LLModel::PromptContext promptContext;
    ~LLModelWrapper() { delete llModel; }
};

struct PromptWorkerConfig
{
    Napi::Function tokenCallback;
    bool bHasTokenCallback = false;
    llmodel_model model;
    std::mutex * mutex;
    std::string prompt;
    llmodel_prompt_context context;
    std::string result;
};

class PromptWorker : public Napi::AsyncWorker
{
public:
    TsfnContext(Napi::Env env, const PromptWorkContext &pc);
    std::thread nativeThread;
    Napi::Promise::Deferred deferred_;
    PromptWorkContext pc;
    Napi::ThreadSafeFunction tsfn;
    PromptWorker(Napi::Env env, PromptWorkerConfig config);
    ~PromptWorker();
    void Execute() override;
    void OnOK() override;
    void OnError(const Napi::Error &e) override;
    Napi::Promise GetPromise();

    // Some data to pass around
    // int ints[ARRAY_LENGTH];
    bool ResponseCallback(int32_t token_id, const std::string token);
    bool RecalculateCallback(bool isrecalculating);
    bool PromptCallback(int32_t tid);

private:
    Napi::Promise::Deferred promise;
    std::string result;
    PromptWorkerConfig _config;
    Napi::ThreadSafeFunction _tsfn;
};

// The thread entry point. This takes as its arguments the specific
// threadsafe-function context created inside the main thread.
void threadEntry(TsfnContext*);

// The thread-safe function finalizer callback. This callback executes
// at destruction of thread-safe function, taking as arguments the finalizer
// data and threadsafe-function context.
void FinalizerCallback(Napi::Env, void* finalizeData, TsfnContext*);

bool response_callback(int32_t token_id, const char *response);
bool recalculate_callback (bool isrecalculating);
bool prompt_callback (int32_t tid);
#endif // TSFN_CONTEXT_H
#endif // PREDICT_WORKER_H

gpt4all-bindings/typescript/spec/generator.mjs (new file, 41 lines)
@@ -0,0 +1,41 @@
import gpt from '../src/gpt4all.js'

const model = await gpt.loadModel("mistral-7b-openorca.Q4_0.gguf", { device: 'gpu' })

process.stdout.write('Response: ')


const tokens = gpt.generateTokens(model, [{
    role: 'user',
    content: "How are you ?"
}], { nPredict: 2048 })
for await (const token of tokens){
    process.stdout.write(token);
}


const result = await gpt.createCompletion(model, [{
    role: 'user',
    content: "You sure?"
}])

console.log(result)

const result2 = await gpt.createCompletion(model, [{
    role: 'user',
    content: "You sure you sure?"
}])

console.log(result2)


const tokens2 = gpt.generateTokens(model, [{
    role: 'user',
    content: "If 3 + 3 is 5, what is 2 + 2?"
}], { nPredict: 2048 })
for await (const token of tokens2){
    process.stdout.write(token);
}
console.log("done")
model.dispose();

gpt4all-bindings/typescript/src/gpt4all.d.ts (vendored, 46 lines changed)
@@ -49,6 +49,12 @@ interface ModelConfig {
    path: string;
    url?: string;
}

+/**
+ * Callback for controlling token generation
+ */
+type TokenCallback = (tokenId: number, token: string, total: string) => boolean
+
/**
 *
 * InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers.

@@ -61,7 +67,8 @@ declare class InferenceModel {

    generate(
        prompt: string,
-        options?: Partial<LLModelPromptContext>
+        options?: Partial<LLModelPromptContext>,
+        callback?: TokenCallback
    ): Promise<string>;

    /**

@@ -132,13 +139,14 @@ declare class LLModel {
     * Use the prompt function exported for a value
     * @param q The prompt input.
     * @param params Optional parameters for the prompt context.
+    * @param callback - optional callback to control token generation.
     * @returns The result of the model prompt.
     */
    raw_prompt(
        q: string,
        params: Partial<LLModelPromptContext>,
-        callback: (res: string) => void
-    ): void; // TODO work on return type
+        callback?: TokenCallback
+    ): Promise<string>

    /**
     * Embed text with the model. Keep in mind that

@@ -176,10 +184,11 @@ declare class LLModel {
    hasGpuDevice(): boolean
    /**
     * GPUs that are usable for this LLModel
+    * @param nCtx Maximum size of context window
     * @throws if hasGpuDevice returns false (i think)
     * @returns
     */
-    listGpu() : GpuDevice[]
+    listGpu(nCtx: number) : GpuDevice[]

    /**
     * delete and cleanup the native model

@@ -224,6 +233,16 @@ interface LoadModelOptions {
     * model.
     */
    device?: string;
+    /*
+     * The Maximum window size of this model
+     * Default of 2048
+     */
+    nCtx?: number;
+    /*
+     * Number of gpu layers needed
+     * Default of 100
+     */
+    ngl?: number;
}

interface InferenceModelOptions extends LoadModelOptions {

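A usage sketch (not part of the diff) showing the new options; the model file name is illustrative, and `nCtx`/`ngl` fall back to 2048 and 100 when omitted, matching the defaults added in `gpt4all.js` further down:

```js
import gpt from 'gpt4all'

// Offload 80 layers to the GPU and allow a 4096-token context window (illustrative values).
const model = await gpt.loadModel('mistral-7b-openorca.Q4_0.gguf', {
    device: 'gpu', // 'cpu' is the default
    nCtx: 4096,
    ngl: 80,
})
```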
@@ -442,14 +461,21 @@ interface LLModelPromptContext {
    contextErase: number;
}


/**
- * TODO: Help wanted to implement this
+ * Creates an async generator of tokens
+ * @param {InferenceModel} llmodel - The language model object.
+ * @param {PromptMessage[]} messages - The array of messages for the conversation.
+ * @param {CompletionOptions} options - The options for creating the completion.
+ * @param {TokenCallback} callback - optional callback to control token generation.
+ * @returns {AsyncGenerator<string>} The stream of generated tokens
 */
-declare function createTokenStream(
-    llmodel: LLModel,
+declare function generateTokens(
+    llmodel: InferenceModel,
    messages: PromptMessage[],
-    options: CompletionOptions
-): (ll: LLModel) => AsyncGenerator<string>;
+    options: CompletionOptions,
+    callback?: TokenCallback
+): AsyncGenerator<string>;
/**
 * From python api:
 * models will be stored in (homedir)/.cache/gpt4all/`

@@ -568,7 +594,7 @@ export {
    loadModel,
    createCompletion,
    createEmbedding,
-    createTokenStream,
+    generateTokens,
    DEFAULT_DIRECTORY,
    DEFAULT_LIBRARIES_DIRECTORY,
    DEFAULT_MODEL_CONFIG,

@@ -18,6 +18,7 @@ const {
    DEFAULT_MODEL_LIST_URL,
} = require("./config.js");
const { InferenceModel, EmbeddingModel } = require("./models.js");
+const Stream = require('stream')
const assert = require("assert");

/**

@@ -36,6 +37,8 @@ async function loadModel(modelName, options = {}) {
        allowDownload: true,
        verbose: true,
        device: 'cpu',
+        nCtx: 2048,
+        ngl : 100,
        ...options,
    };

@@ -58,11 +61,14 @@ async function loadModel(modelName, options = {}) {
        model_path: loadOptions.modelPath,
        library_path: existingPaths,
        device: loadOptions.device,
+        nCtx: loadOptions.nCtx,
+        ngl: loadOptions.ngl
    };

    if (loadOptions.verbose) {
        console.debug("Creating LLModel with options:", llmOptions);
    }
+    console.log(modelConfig)
    const llmodel = new LLModel(llmOptions);
    if (loadOptions.type === "embedding") {
        return new EmbeddingModel(llmodel, modelConfig);

@@ -149,11 +155,7 @@ const defaultCompletionOptions = {
    ...DEFAULT_PROMPT_CONTEXT,
};

-async function createCompletion(
-    model,
-    messages,
-    options = defaultCompletionOptions
-) {
+function preparePromptAndContext(model,messages,options){
    if (options.hasDefaultHeader !== undefined) {
        console.warn(
            "hasDefaultHeader (bool) is deprecated and has no effect, use promptHeader (string) instead"

@@ -180,6 +182,7 @@ async function createCompletion(
        ...promptContext
    } = optionsWithDefaults;


    const prompt = formatChatPrompt(messages, {
        systemPromptTemplate,
        defaultSystemPrompt: model.config.systemPrompt,

|
||||
// promptFooter: '### Response:',
|
||||
});
|
||||
|
||||
return {
|
||||
prompt, promptContext, verbose
|
||||
}
|
||||
}
|
||||
|
||||
async function createCompletion(
|
||||
model,
|
||||
messages,
|
||||
options = defaultCompletionOptions
|
||||
) {
|
||||
const { prompt, promptContext, verbose } = preparePromptAndContext(model,messages,options);
|
||||
|
||||
if (verbose) {
|
||||
console.debug("Sending Prompt:\n" + prompt);
|
||||
}
|
||||
|
||||
const response = await model.generate(prompt, promptContext);
|
||||
let tokensGenerated = 0
|
||||
|
||||
const response = await model.generate(prompt, promptContext,() => {
|
||||
tokensGenerated++;
|
||||
return true;
|
||||
});
|
||||
|
||||
if (verbose) {
|
||||
console.debug("Received Response:\n" + response);
|
||||
@@ -206,8 +226,8 @@ async function createCompletion(
        llmodel: model.llm.name(),
        usage: {
            prompt_tokens: prompt.length,
-            completion_tokens: response.length, //TODO
-            total_tokens: prompt.length + response.length, //TODO
+            completion_tokens: tokensGenerated,
+            total_tokens: prompt.length + tokensGenerated, //TODO Not sure how to get tokens in prompt
        },
        choices: [
            {

@@ -220,8 +240,77 @@ async function createCompletion(
    };
}

-function createTokenStream() {
-    throw Error("This API has not been completed yet!");
+function _internal_createTokenStream(stream,model,
+    messages,
+    options = defaultCompletionOptions,callback = undefined) {
+    const { prompt, promptContext, verbose } = preparePromptAndContext(model,messages,options);
+
+    if (verbose) {
+        console.debug("Sending Prompt:\n" + prompt);
+    }
+
+    model.generate(prompt, promptContext,(tokenId, token, total) => {
+        stream.push(token);
+
+        if(callback !== undefined){
+            return callback(tokenId,token,total);
+        }
+
+        return true;
+    }).then(() => {
+        stream.end()
+    })
+
+    return stream;
+}
+
+function _createTokenStream(model,
+    messages,
+    options = defaultCompletionOptions,callback = undefined) {
+
+    // Silent crash if we dont do this here
+    const stream = new Stream.PassThrough({
+        encoding: 'utf-8'
+    });
+    return _internal_createTokenStream(stream,model,messages,options,callback);
+}
+
+async function* generateTokens(model,
+    messages,
+    options = defaultCompletionOptions, callback = undefined) {
+    const stream = _createTokenStream(model,messages,options,callback)
+
+    let bHasFinished = false;
+    let activeDataCallback = undefined;
+    const finishCallback = () => {
+        bHasFinished = true;
+        if(activeDataCallback !== undefined){
+            activeDataCallback(undefined);
+        }
+    }
+
+    stream.on("finish",finishCallback)
+
+    while (!bHasFinished) {
+        const token = await new Promise((res) => {
+
+            activeDataCallback = (d) => {
+                stream.off("data",activeDataCallback)
+                activeDataCallback = undefined
+                res(d);
+            }
+            stream.on('data', activeDataCallback)
+        })
+
+        if (token == undefined) {
+            break;
+        }
+
+        yield token;
+    }
+
+    stream.off("finish",finishCallback);
}

module.exports = {

|
||||
downloadModel,
|
||||
retrieveModel,
|
||||
loadModel,
|
||||
createTokenStream,
|
||||
generateTokens
|
||||
};
|
||||
|
@@ -9,10 +9,10 @@ class InferenceModel {
        this.config = config;
    }

-    async generate(prompt, promptContext) {
+    async generate(prompt, promptContext,callback) {
        warnOnSnakeCaseKeys(promptContext);
        const normalizedPromptContext = normalizePromptContext(promptContext);
-        const result = this.llm.raw_prompt(prompt, normalizedPromptContext, () => {});
+        const result = this.llm.raw_prompt(prompt, normalizedPromptContext,callback);
        return result;
    }

@@ -224,7 +224,6 @@ async function retrieveModel(modelName, options = {}) {
        verbose: true,
        ...options,
    };

    await mkdirp(retrieveOptions.modelPath);

    const modelFileName = appendBinSuffixIfMissing(modelName);

@@ -284,7 +283,6 @@ async function retrieveModel(modelName, options = {}) {
    } else {
        throw Error("Failed to retrieve model.");
    }

    return config;
}

File diff suppressed because it is too large