typescript: async generator and token stream (#1897)

Signed-off-by: Tare Ebelo <75279482+TareHimself@users.noreply.github.com>
Signed-off-by: jacob <jacoobes@sern.dev>
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: jacob <jacoobes@sern.dev>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
TareHimself 2024-02-24 17:50:14 -05:00 committed by GitHub
parent ef518fae3e
commit a153cc5b25
14 changed files with 1517 additions and 955 deletions


@@ -611,6 +611,7 @@ jobs:
 $Env:Path += ";$MinGwBin"
 $Env:Path += ";C:\Program Files\CMake\bin"
 $Env:Path += ";C:\VulkanSDK\1.3.261.1\bin"
+$Env:VULKAN_SDK = "C:\VulkanSDK\1.3.261.1"
 cd gpt4all-backend
 mkdir runtimes/win-x64
 cd runtimes/win-x64
@@ -651,6 +652,7 @@ jobs:
 command: |
 $Env:Path += ";C:\Program Files\CMake\bin"
 $Env:Path += ";C:\VulkanSDK\1.3.261.1\bin"
+$Env:VULKAN_SDK = "C:\VulkanSDK\1.3.261.1"
 cd gpt4all-backend
 mkdir runtimes/win-x64_msvc
 cd runtimes/win-x64_msvc
@@ -1107,8 +1109,12 @@ workflows:
 jobs:
 - hold:
 type: approval
+- csharp-hold:
+type: approval
 - nuget-hold:
 type: approval
+- nodejs-hold:
+type: approval
 - npm-hold:
 type: approval
 - build-bindings-backend-linux:
@@ -1151,21 +1157,21 @@ workflows:
 branches:
 only:
 requires:
-- npm-hold
+- nodejs-hold
 - build-bindings-backend-linux
 - build-nodejs-windows:
 filters:
 branches:
 only:
 requires:
-- npm-hold
+- nodejs-hold
 - build-bindings-backend-windows-msvc
 - build-nodejs-macos:
 filters:
 branches:
 only:
 requires:
-- npm-hold
+- nodejs-hold
 - build-bindings-backend-macos
@@ -1175,21 +1181,21 @@ workflows:
 branches:
 only:
 requires:
-- nuget-hold
+- csharp-hold
 - build-bindings-backend-linux
 - build-csharp-windows:
 filters:
 branches:
 only:
 requires:
-- nuget-hold
+- csharp-hold
 - build-bindings-backend-windows
 - build-csharp-macos:
 filters:
 branches:
 only:
 requires:
-- nuget-hold
+- csharp-hold
 - build-bindings-backend-macos
 - store-and-upload-nupkgs:
 filters:


@@ -159,6 +159,7 @@ This package is in active development, and breaking changes may happen until the
 * [mpt](#mpt)
 * [replit](#replit)
 * [type](#type)
+* [TokenCallback](#tokencallback)
 * [InferenceModel](#inferencemodel)
 * [dispose](#dispose)
 * [EmbeddingModel](#embeddingmodel)
@@ -184,16 +185,17 @@ This package is in active development, and breaking changes may happen until the
 * [Parameters](#parameters-5)
 * [hasGpuDevice](#hasgpudevice)
 * [listGpu](#listgpu)
+* [Parameters](#parameters-6)
 * [dispose](#dispose-2)
 * [GpuDevice](#gpudevice)
 * [type](#type-2)
 * [LoadModelOptions](#loadmodeloptions)
 * [loadModel](#loadmodel)
-* [Parameters](#parameters-6)
-* [createCompletion](#createcompletion)
 * [Parameters](#parameters-7)
-* [createEmbedding](#createembedding)
+* [createCompletion](#createcompletion)
 * [Parameters](#parameters-8)
+* [createEmbedding](#createembedding)
+* [Parameters](#parameters-9)
 * [CompletionOptions](#completionoptions)
 * [verbose](#verbose)
 * [systemPromptTemplate](#systemprompttemplate)
@@ -225,15 +227,15 @@ This package is in active development, and breaking changes may happen until the
 * [repeatPenalty](#repeatpenalty)
 * [repeatLastN](#repeatlastn)
 * [contextErase](#contexterase)
-* [createTokenStream](#createtokenstream)
-* [Parameters](#parameters-9)
+* [generateTokens](#generatetokens)
+* [Parameters](#parameters-10)
 * [DEFAULT\_DIRECTORY](#default_directory)
 * [DEFAULT\_LIBRARIES\_DIRECTORY](#default_libraries_directory)
 * [DEFAULT\_MODEL\_CONFIG](#default_model_config)
 * [DEFAULT\_PROMPT\_CONTEXT](#default_prompt_context)
 * [DEFAULT\_MODEL\_LIST\_URL](#default_model_list_url)
 * [downloadModel](#downloadmodel)
-* [Parameters](#parameters-10)
+* [Parameters](#parameters-11)
 * [Examples](#examples)
 * [DownloadModelOptions](#downloadmodeloptions)
 * [modelPath](#modelpath)
@@ -279,6 +281,12 @@ Model architecture. This argument currently does not have any functionality and
 Type: ModelType
+#### TokenCallback
+Callback for controlling token generation
+Type: function (tokenId: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), token: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String), total: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)): [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
 #### InferenceModel
 InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers.
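For orientation, here is a minimal sketch (not part of this diff) of a TokenCallback as documented above. It assumes that returning false stops further generation, matching the boolean return type; the helper name is illustrative.

```ts
// Sketch only: a TokenCallback that streams tokens and stops after maxTokens.
// The (tokenId, token, total) => boolean shape comes from the docs above;
// returning false is assumed to halt generation.
type TokenCallback = (tokenId: number, token: string, total: string) => boolean;

function stopAfter(maxTokens: number): TokenCallback {
    let count = 0;
    return (_tokenId, token, _total) => {
        process.stdout.write(token); // print each token as it arrives
        count += 1;
        return count < maxTokens;    // false => stop generating
    };
}
```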
@@ -362,9 +370,9 @@ Use the prompt function exported for a value
 * `q` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input.
 * `params` **Partial<[LLModelPromptContext](#llmodelpromptcontext)>** Optional parameters for the prompt context.
-* `callback` **function (res: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)): void**&#x20;
+* `callback` **[TokenCallback](#tokencallback)?** optional callback to control token generation.
-Returns **void** The result of the model prompt.
+Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The result of the model prompt.
 ##### embed
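For context, a sketch of calling the lower-level raw_prompt directly, reflecting the Promise<string> return type and optional TokenCallback introduced above. Most callers should prefer createCompletion or generateTokens; the import path, model file, and prompt template here are illustrative.

```ts
import gpt from "gpt4all"; // illustrative import path

const model = await gpt.loadModel("mistral-7b-openorca.Q4_0.gguf");

// The InferenceModel wrapper keeps the native LLModel on `model.llm` (see models.js).
const text = await model.llm.raw_prompt(
    "### Human:\nHello!\n### Assistant:\n", // hand-built prompt; no templating applied
    { nPredict: 128 },
    (_tokenId, _token, _total) => true      // keep generating until the model stops
);
console.log(text);
```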
@@ -424,6 +432,12 @@ Returns **[boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/
 GPUs that are usable for this LLModel
+###### Parameters
+* `nCtx` **[number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)** Maximum size of context window
+<!---->
 * Throws **any** if hasGpuDevice returns false (i think)
 Returns **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[GpuDevice](#gpudevice)>**&#x20;
@@ -690,17 +704,18 @@ The percentage of context to erase if the context window is exceeded.
 Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
-#### createTokenStream
+#### generateTokens
-TODO: Help wanted to implement this
+Creates an async generator of tokens
 ##### Parameters
-* `llmodel` **[LLModel](#llmodel)**&#x20;
-* `messages` **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[PromptMessage](#promptmessage)>**&#x20;
-* `options` **[CompletionOptions](#completionoptions)**&#x20;
+* `llmodel` **[InferenceModel](#inferencemodel)** The language model object.
+* `messages` **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[PromptMessage](#promptmessage)>** The array of messages for the conversation.
+* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion.
+* `callback` **[TokenCallback](#tokencallback)** optional callback to control token generation.
-Returns **function (ll: [LLModel](#llmodel)): AsyncGenerator<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>**&#x20;
+Returns **AsyncGenerator<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The stream of generated tokens
 #### DEFAULT\_DIRECTORY
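A short usage sketch of the generateTokens API documented above, modeled on the spec file added by this commit (spec/generator.mjs); the package import path, model file, and early-stop rule are illustrative.

```ts
import gpt from "gpt4all"; // illustrative import; the spec file imports '../src/gpt4all.js'

const model = await gpt.loadModel("mistral-7b-openorca.Q4_0.gguf", { device: "gpu" });

// Stream tokens as they are produced; the optional callback can stop
// generation early by returning false (assumed TokenCallback semantics).
const tokens = gpt.generateTokens(
    model,
    [{ role: "user", content: "How are you?" }],
    { nPredict: 2048 },
    (_tokenId, _token, total) => total.length < 4000
);

for await (const token of tokens) {
    process.stdout.write(token);
}

model.dispose();
```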


@@ -136,14 +136,18 @@ yarn test
 This package is in active development, and breaking changes may happen until the api stabilizes. Here's what's the todo list:
+* \[ ] Purely offline. Per the gui, which can be run completely offline, the bindings should be as well.
+* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help)
+* Should include prebuilds to avoid painful node-gyp errors
+* \[ ] createChatSession ( the python equivalent to create\_chat\_session )
+* \[x] generateTokens, the new name for createTokenStream. As of 3.2.0, this is released but not 100% tested. Check spec/generator.mjs!
+* \[x] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete
 * \[x] prompt models via a threadsafe function in order to have proper non blocking behavior in nodejs
-* \[ ] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete
+* \[x] generateTokens is the new name for this^
 * \[x] proper unit testing (integrate with circle ci)
 * \[x] publish to npm under alpha tag `gpt4all@alpha`
 * \[x] have more people test on other platforms (mac tester needed)
 * \[x] switch to new pluggable backend
-* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help)
-* Should include prebuilds to avoid painful node-gyp errors
-* \[ ] createChatSession ( the python equivalent to create\_chat\_session )
 ### API Reference


@@ -3,9 +3,9 @@
 Napi::Function NodeModelWrapper::GetClass(Napi::Env env) {
 Napi::Function self = DefineClass(env, "LLModel", {
-InstanceMethod("type", &NodeModelWrapper::getType),
+InstanceMethod("type", &NodeModelWrapper::GetType),
 InstanceMethod("isModelLoaded", &NodeModelWrapper::IsModelLoaded),
-InstanceMethod("name", &NodeModelWrapper::getName),
+InstanceMethod("name", &NodeModelWrapper::GetName),
 InstanceMethod("stateSize", &NodeModelWrapper::StateSize),
 InstanceMethod("raw_prompt", &NodeModelWrapper::Prompt),
 InstanceMethod("setThreadCount", &NodeModelWrapper::SetThreadCount),
@@ -28,14 +28,14 @@ Napi::Function NodeModelWrapper::GetClass(Napi::Env env) {
 Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 {
 auto env = info.Env();
-return Napi::Number::New(env, static_cast<uint32_t>( llmodel_required_mem(GetInference(), full_model_path.c_str(), 2048, 100) ));
+return Napi::Number::New(env, static_cast<uint32_t>(llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers) ));
 }
 Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo& info)
 {
 auto env = info.Env();
 int num_devices = 0;
-auto mem_size = llmodel_required_mem(GetInference(), full_model_path.c_str());
+auto mem_size = llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers);
 llmodel_gpu_device* all_devices = llmodel_available_gpu_devices(GetInference(), mem_size, &num_devices);
 if(all_devices == nullptr) {
 Napi::Error::New(
@@ -70,7 +70,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 return js_array;
 }
-Napi::Value NodeModelWrapper::getType(const Napi::CallbackInfo& info)
+Napi::Value NodeModelWrapper::GetType(const Napi::CallbackInfo& info)
 {
 if(type.empty()) {
 return info.Env().Undefined();
@@ -132,6 +132,9 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 library_path = ".";
 }
 device = config_object.Get("device").As<Napi::String>();
+nCtx = config_object.Get("nCtx").As<Napi::Number>().Int32Value();
+nGpuLayers = config_object.Get("ngl").As<Napi::Number>().Int32Value();
 }
 llmodel_set_implementation_search_path(library_path.c_str());
 const char* e;
@@ -148,20 +151,17 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 return;
 }
 if(device != "cpu") {
-size_t mem = llmodel_required_mem(GetInference(), full_weight_path.c_str());
+size_t mem = llmodel_required_mem(GetInference(), full_weight_path.c_str(),nCtx, nGpuLayers);
-std::cout << "Initiating GPU\n";
 auto success = llmodel_gpu_init_gpu_device_by_string(GetInference(), mem, device.c_str());
-if(success) {
-std::cout << "GPU init successfully\n";
-} else {
+if(!success) {
 //https://github.com/nomic-ai/gpt4all/blob/3acbef14b7c2436fe033cae9036e695d77461a16/gpt4all-bindings/python/gpt4all/pyllmodel.py#L215
 //Haven't implemented this but it is still open to contribution
 std::cout << "WARNING: Failed to init GPU\n";
 }
 }
-auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), 2048, 100);
+auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), nCtx, nGpuLayers);
 if(!success) {
 Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException();
 return;
@@ -254,6 +254,9 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 .repeat_last_n = 10,
 .context_erase = 0.5
 };
+PromptWorkerConfig promptWorkerConfig;
 if(info[1].IsObject())
 {
 auto inputObject = info[1].As<Napi::Object>();
@@ -285,29 +288,33 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 if(inputObject.Has("context_erase"))
 promptContext.context_erase = inputObject.Get("context_erase").As<Napi::Number>().FloatValue();
 }
-//copy to protect llmodel resources when splitting to new thread
-llmodel_prompt_context copiedPrompt = promptContext;
-std::string copiedQuestion = question;
-PromptWorkContext pc = {
-copiedQuestion,
-inference_,
-copiedPrompt,
-""
-};
-auto threadSafeContext = new TsfnContext(env, pc);
-threadSafeContext->tsfn = Napi::ThreadSafeFunction::New(
-env, // Environment
-info[2].As<Napi::Function>(), // JS function from caller
-"PromptCallback", // Resource name
-0, // Max queue size (0 = unlimited).
-1, // Initial thread count
-threadSafeContext, // Context,
-FinalizerCallback, // Finalizer
-(void*)nullptr // Finalizer data
-);
-threadSafeContext->nativeThread = std::thread(threadEntry, threadSafeContext);
-return threadSafeContext->deferred_.Promise();
+else
+{
+Napi::Error::New(info.Env(), "Missing Prompt Options").ThrowAsJavaScriptException();
+return info.Env().Undefined();
+}
+if(info.Length() >= 3 && info[2].IsFunction()){
+promptWorkerConfig.bHasTokenCallback = true;
+promptWorkerConfig.tokenCallback = info[2].As<Napi::Function>();
+}
+//copy to protect llmodel resources when splitting to new thread
+// llmodel_prompt_context copiedPrompt = promptContext;
+promptWorkerConfig.context = promptContext;
+promptWorkerConfig.model = GetInference();
+promptWorkerConfig.mutex = &inference_mutex;
+promptWorkerConfig.prompt = question;
+promptWorkerConfig.result = "";
+auto worker = new PromptWorker(env, promptWorkerConfig);
+worker->Queue();
+return worker->GetPromise();
 }
 void NodeModelWrapper::Dispose(const Napi::CallbackInfo& info) {
 llmodel_model_destroy(inference_);
@@ -321,7 +328,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 }
 }
-Napi::Value NodeModelWrapper::getName(const Napi::CallbackInfo& info) {
+Napi::Value NodeModelWrapper::GetName(const Napi::CallbackInfo& info) {
 return Napi::String::New(info.Env(), name);
 }
 Napi::Value NodeModelWrapper::ThreadCount(const Napi::CallbackInfo& info) {


@@ -7,14 +7,17 @@
 #include <memory>
 #include <filesystem>
 #include <set>
+#include <mutex>
 namespace fs = std::filesystem;
 class NodeModelWrapper: public Napi::ObjectWrap<NodeModelWrapper> {
 public:
 NodeModelWrapper(const Napi::CallbackInfo &);
 //virtual ~NodeModelWrapper();
-Napi::Value getType(const Napi::CallbackInfo& info);
+Napi::Value GetType(const Napi::CallbackInfo& info);
 Napi::Value IsModelLoaded(const Napi::CallbackInfo& info);
 Napi::Value StateSize(const Napi::CallbackInfo& info);
 //void Finalize(Napi::Env env) override;
@@ -25,7 +28,7 @@ public:
 Napi::Value Prompt(const Napi::CallbackInfo& info);
 void SetThreadCount(const Napi::CallbackInfo& info);
 void Dispose(const Napi::CallbackInfo& info);
-Napi::Value getName(const Napi::CallbackInfo& info);
+Napi::Value GetName(const Napi::CallbackInfo& info);
 Napi::Value ThreadCount(const Napi::CallbackInfo& info);
 Napi::Value GenerateEmbedding(const Napi::CallbackInfo& info);
 Napi::Value HasGpuDevice(const Napi::CallbackInfo& info);
@@ -48,8 +51,12 @@ private:
 */
 llmodel_model inference_;
+std::mutex inference_mutex;
 std::string type;
 // corresponds to LLModel::name() in typescript
 std::string name;
+int nCtx{};
+int nGpuLayers{};
 std::string full_model_path;
 };


@@ -1,6 +1,6 @@
 {
 "name": "gpt4all",
-"version": "3.1.0",
+"version": "3.2.0",
 "packageManager": "yarn@3.6.1",
 "main": "src/gpt4all.js",
 "repository": "nomic-ai/gpt4all",


@@ -1,60 +1,146 @@
 #include "prompt.h"
+#include <future>
-TsfnContext::TsfnContext(Napi::Env env, const PromptWorkContext& pc)
-: deferred_(Napi::Promise::Deferred::New(env)), pc(pc) {
-}
-namespace {
-static std::string *res;
-}
-bool response_callback(int32_t token_id, const char *response) {
-*res += response;
-return token_id != -1;
-}
-bool recalculate_callback (bool isrecalculating) {
-return isrecalculating;
-};
-bool prompt_callback (int32_t tid) {
-return true;
-};
-// The thread entry point. This takes as its arguments the specific
-// threadsafe-function context created inside the main thread.
-void threadEntry(TsfnContext* context) {
-static std::mutex mtx;
-std::lock_guard<std::mutex> lock(mtx);
-res = &context->pc.res;
-// Perform a call into JavaScript.
-napi_status status =
-context->tsfn.BlockingCall(&context->pc,
-[](Napi::Env env, Napi::Function jsCallback, PromptWorkContext* pc) {
-llmodel_prompt(
-pc->inference_,
-pc->question.c_str(),
-&prompt_callback,
-&response_callback,
-&recalculate_callback,
-&pc->prompt_params
-);
-});
-if (status != napi_ok) {
-Napi::Error::Fatal(
-"ThreadEntry",
-"Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
-}
-// Release the thread-safe function. This decrements the internal thread
-// count, and will perform finalization since the count will reach 0.
-context->tsfn.Release();
-}
-void FinalizerCallback(Napi::Env env,
-void* finalizeData,
-TsfnContext* context) {
-// Resolve the Promise previously returned to JS
-context->deferred_.Resolve(Napi::String::New(env, context->pc.res));
-// Wait for the thread to finish executing before proceeding.
-context->nativeThread.join();
-delete context;
-}
+PromptWorker::PromptWorker(Napi::Env env, PromptWorkerConfig config)
+: promise(Napi::Promise::Deferred::New(env)), _config(config), AsyncWorker(env) {
+if(_config.bHasTokenCallback){
+_tsfn = Napi::ThreadSafeFunction::New(config.tokenCallback.Env(),config.tokenCallback,"PromptWorker",0,1,this);
+}
+}
+PromptWorker::~PromptWorker()
+{
+if(_config.bHasTokenCallback){
+_tsfn.Release();
+}
+}
+void PromptWorker::Execute()
+{
+_config.mutex->lock();
+LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper *>(_config.model);
+auto ctx = &_config.context;
+if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size())
+wrapper->promptContext.tokens.resize(ctx->n_past);
+// Copy the C prompt context
+wrapper->promptContext.n_past = ctx->n_past;
+wrapper->promptContext.n_ctx = ctx->n_ctx;
+wrapper->promptContext.n_predict = ctx->n_predict;
+wrapper->promptContext.top_k = ctx->top_k;
+wrapper->promptContext.top_p = ctx->top_p;
+wrapper->promptContext.temp = ctx->temp;
+wrapper->promptContext.n_batch = ctx->n_batch;
+wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
+wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
+wrapper->promptContext.contextErase = ctx->context_erase;
+// Napi::Error::Fatal(
+// "SUPRA",
+// "About to prompt");
+// Call the C++ prompt method
+wrapper->llModel->prompt(
+_config.prompt,
+[](int32_t tid) { return true; },
+[this](int32_t token_id, const std::string tok)
+{
+return ResponseCallback(token_id, tok);
+},
+[](bool isRecalculating)
+{
+return isRecalculating;
+},
+wrapper->promptContext);
+// Update the C context by giving access to the wrappers raw pointers to std::vector data
+// which involves no copies
+ctx->logits = wrapper->promptContext.logits.data();
+ctx->logits_size = wrapper->promptContext.logits.size();
+ctx->tokens = wrapper->promptContext.tokens.data();
+ctx->tokens_size = wrapper->promptContext.tokens.size();
+// Update the rest of the C prompt context
+ctx->n_past = wrapper->promptContext.n_past;
+ctx->n_ctx = wrapper->promptContext.n_ctx;
+ctx->n_predict = wrapper->promptContext.n_predict;
+ctx->top_k = wrapper->promptContext.top_k;
+ctx->top_p = wrapper->promptContext.top_p;
+ctx->temp = wrapper->promptContext.temp;
+ctx->n_batch = wrapper->promptContext.n_batch;
+ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
+ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
+ctx->context_erase = wrapper->promptContext.contextErase;
+_config.mutex->unlock();
+}
+void PromptWorker::OnOK()
+{
+promise.Resolve(Napi::String::New(Env(), result));
+}
+void PromptWorker::OnError(const Napi::Error &e)
+{
+promise.Reject(e.Value());
+}
+Napi::Promise PromptWorker::GetPromise()
+{
+return promise.Promise();
+}
+bool PromptWorker::ResponseCallback(int32_t token_id, const std::string token)
+{
+if (token_id == -1)
+{
+return false;
+}
+if(!_config.bHasTokenCallback){
+return true;
+}
+result += token;
+std::promise<bool> promise;
+auto info = new TokenCallbackInfo();
+info->tokenId = token_id;
+info->token = token;
+info->total = result;
+auto future = promise.get_future();
+auto status = _tsfn.BlockingCall(info, [&promise](Napi::Env env, Napi::Function jsCallback, TokenCallbackInfo *value)
+{
+// Transform native data into JS data, passing it to the provided
+// `jsCallback` -- the TSFN's JavaScript function.
+auto token_id = Napi::Number::New(env, value->tokenId);
+auto token = Napi::String::New(env, value->token);
+auto total = Napi::String::New(env,value->total);
+auto jsResult = jsCallback.Call({ token_id, token, total}).ToBoolean();
+promise.set_value(jsResult);
+// We're finished with the data.
+delete value;
+});
+if (status != napi_ok) {
+Napi::Error::Fatal(
+"PromptWorkerResponseCallback",
+"Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
+}
+return future.get();
+}
+bool PromptWorker::RecalculateCallback(bool isRecalculating)
+{
+return isRecalculating;
+}
+bool PromptWorker::PromptCallback(int32_t tid)
+{
+return true;
+}


@@ -1,44 +1,59 @@
-#ifndef TSFN_CONTEXT_H
-#define TSFN_CONTEXT_H
+#ifndef PREDICT_WORKER_H
+#define PREDICT_WORKER_H
 #include "napi.h"
 #include "llmodel_c.h"
+#include "llmodel.h"
 #include <thread>
 #include <mutex>
 #include <iostream>
 #include <atomic>
 #include <memory>
-struct PromptWorkContext {
-std::string question;
-llmodel_model inference_;
-llmodel_prompt_context prompt_params;
-std::string res;
-};
-struct TsfnContext {
-public:
-TsfnContext(Napi::Env env, const PromptWorkContext &pc);
-std::thread nativeThread;
-Napi::Promise::Deferred deferred_;
-PromptWorkContext pc;
-Napi::ThreadSafeFunction tsfn;
-// Some data to pass around
-// int ints[ARRAY_LENGTH];
-};
-// The thread entry point. This takes as its arguments the specific
-// threadsafe-function context created inside the main thread.
-void threadEntry(TsfnContext*);
-// The thread-safe function finalizer callback. This callback executes
-// at destruction of thread-safe function, taking as arguments the finalizer
-// data and threadsafe-function context.
-void FinalizerCallback(Napi::Env, void* finalizeData, TsfnContext*);
-bool response_callback(int32_t token_id, const char *response);
-bool recalculate_callback (bool isrecalculating);
-bool prompt_callback (int32_t tid);
-#endif // TSFN_CONTEXT_H
+struct TokenCallbackInfo
+{
+int32_t tokenId;
+std::string total;
+std::string token;
+};
+struct LLModelWrapper
+{
+LLModel *llModel = nullptr;
+LLModel::PromptContext promptContext;
+~LLModelWrapper() { delete llModel; }
+};
+struct PromptWorkerConfig
+{
+Napi::Function tokenCallback;
+bool bHasTokenCallback = false;
+llmodel_model model;
+std::mutex * mutex;
+std::string prompt;
+llmodel_prompt_context context;
+std::string result;
+};
+class PromptWorker : public Napi::AsyncWorker
+{
+public:
+PromptWorker(Napi::Env env, PromptWorkerConfig config);
+~PromptWorker();
+void Execute() override;
+void OnOK() override;
+void OnError(const Napi::Error &e) override;
+Napi::Promise GetPromise();
+bool ResponseCallback(int32_t token_id, const std::string token);
+bool RecalculateCallback(bool isrecalculating);
+bool PromptCallback(int32_t tid);
+private:
+Napi::Promise::Deferred promise;
+std::string result;
+PromptWorkerConfig _config;
+Napi::ThreadSafeFunction _tsfn;
+};
+#endif // PREDICT_WORKER_H


@@ -0,0 +1,41 @@
import gpt from '../src/gpt4all.js'
const model = await gpt.loadModel("mistral-7b-openorca.Q4_0.gguf", { device: 'gpu' })
process.stdout.write('Response: ')
const tokens = gpt.generateTokens(model, [{
role: 'user',
content: "How are you ?"
}], { nPredict: 2048 })
for await (const token of tokens){
process.stdout.write(token);
}
const result = await gpt.createCompletion(model, [{
role: 'user',
content: "You sure?"
}])
console.log(result)
const result2 = await gpt.createCompletion(model, [{
role: 'user',
content: "You sure you sure?"
}])
console.log(result2)
const tokens2 = gpt.generateTokens(model, [{
role: 'user',
content: "If 3 + 3 is 5, what is 2 + 2?"
}], { nPredict: 2048 })
for await (const token of tokens2){
process.stdout.write(token);
}
console.log("done")
model.dispose();


@@ -49,6 +49,12 @@ interface ModelConfig {
 path: string;
 url?: string;
 }
+/**
+ * Callback for controlling token generation
+ */
+type TokenCallback = (tokenId: number, token: string, total: string) => boolean
 /**
 *
 * InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers.
@@ -61,7 +67,8 @@ declare class InferenceModel {
 generate(
 prompt: string,
-options?: Partial<LLModelPromptContext>
+options?: Partial<LLModelPromptContext>,
+callback?: TokenCallback
 ): Promise<string>;
 /**
@@ -132,13 +139,14 @@ declare class LLModel {
 * Use the prompt function exported for a value
 * @param q The prompt input.
 * @param params Optional parameters for the prompt context.
+* @param callback - optional callback to control token generation.
 * @returns The result of the model prompt.
 */
 raw_prompt(
 q: string,
 params: Partial<LLModelPromptContext>,
-callback: (res: string) => void
-): void; // TODO work on return type
+callback?: TokenCallback
+): Promise<string>
 /**
 * Embed text with the model. Keep in mind that
@@ -176,10 +184,11 @@ declare class LLModel {
 hasGpuDevice(): boolean
 /**
 * GPUs that are usable for this LLModel
+* @param nCtx Maximum size of context window
 * @throws if hasGpuDevice returns false (i think)
 * @returns
 */
-listGpu() : GpuDevice[]
+listGpu(nCtx: number) : GpuDevice[]
 /**
 * delete and cleanup the native model
@@ -224,6 +233,16 @@ interface LoadModelOptions {
 model.
 */
 device?: string;
+/*
+ * The Maximum window size of this model
+ * Default of 2048
+ */
+nCtx?: number;
+/*
+ * Number of gpu layers needed
+ * Default of 100
+ */
+ngl?: number;
 }
 interface InferenceModelOptions extends LoadModelOptions {
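A sketch of passing the new nCtx and ngl options; the option names and defaults come from the interface above, while the import path, model file, and values are illustrative.

```ts
import gpt from "gpt4all"; // illustrative import path

// nCtx: maximum context window (default 2048); ngl: number of GPU layers (default 100).
const model = await gpt.loadModel("mistral-7b-openorca.Q4_0.gguf", {
    device: "gpu",
    nCtx: 4096,
    ngl: 80,
});
```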
@@ -442,14 +461,21 @@ interface LLModelPromptContext {
 contextErase: number;
 }
 /**
- * TODO: Help wanted to implement this
+ * Creates an async generator of tokens
+ * @param {InferenceModel} llmodel - The language model object.
+ * @param {PromptMessage[]} messages - The array of messages for the conversation.
+ * @param {CompletionOptions} options - The options for creating the completion.
+ * @param {TokenCallback} callback - optional callback to control token generation.
+ * @returns {AsyncGenerator<string>} The stream of generated tokens
 */
-declare function createTokenStream(
-llmodel: LLModel,
-messages: PromptMessage[],
-options: CompletionOptions
-): (ll: LLModel) => AsyncGenerator<string>;
+declare function generateTokens(
+llmodel: InferenceModel,
+messages: PromptMessage[],
+options: CompletionOptions,
+callback?: TokenCallback
+): AsyncGenerator<string>;
 /**
 * From python api:
 * models will be stored in (homedir)/.cache/gpt4all/`
@@ -568,7 +594,7 @@ export {
 loadModel,
 createCompletion,
 createEmbedding,
-createTokenStream,
+generateTokens,
 DEFAULT_DIRECTORY,
 DEFAULT_LIBRARIES_DIRECTORY,
 DEFAULT_MODEL_CONFIG,


@@ -18,6 +18,7 @@ const {
 DEFAULT_MODEL_LIST_URL,
 } = require("./config.js");
 const { InferenceModel, EmbeddingModel } = require("./models.js");
+const Stream = require('stream')
 const assert = require("assert");
 /**
@@ -36,6 +37,8 @@ async function loadModel(modelName, options = {}) {
 allowDownload: true,
 verbose: true,
 device: 'cpu',
+nCtx: 2048,
+ngl : 100,
 ...options,
 };
@@ -58,11 +61,14 @@ async function loadModel(modelName, options = {}) {
 model_path: loadOptions.modelPath,
 library_path: existingPaths,
 device: loadOptions.device,
+nCtx: loadOptions.nCtx,
+ngl: loadOptions.ngl
 };
 if (loadOptions.verbose) {
 console.debug("Creating LLModel with options:", llmOptions);
 }
+console.log(modelConfig)
 const llmodel = new LLModel(llmOptions);
 if (loadOptions.type === "embedding") {
 return new EmbeddingModel(llmodel, modelConfig);
@@ -149,11 +155,7 @@ const defaultCompletionOptions = {
 ...DEFAULT_PROMPT_CONTEXT,
 };
-async function createCompletion(
-model,
-messages,
-options = defaultCompletionOptions
-) {
+function preparePromptAndContext(model,messages,options){
 if (options.hasDefaultHeader !== undefined) {
 console.warn(
 "hasDefaultHeader (bool) is deprecated and has no effect, use promptHeader (string) instead"
@@ -180,6 +182,7 @@ async function createCompletion(
 ...promptContext
 } = optionsWithDefaults;
 const prompt = formatChatPrompt(messages, {
 systemPromptTemplate,
 defaultSystemPrompt: model.config.systemPrompt,
@@ -192,11 +195,28 @@ async function createCompletion(
 // promptFooter: '### Response:',
 });
+return {
+prompt, promptContext, verbose
+}
+}
+async function createCompletion(
+model,
+messages,
+options = defaultCompletionOptions
+) {
+const { prompt, promptContext, verbose } = preparePromptAndContext(model,messages,options);
 if (verbose) {
 console.debug("Sending Prompt:\n" + prompt);
 }
-const response = await model.generate(prompt, promptContext);
+let tokensGenerated = 0
+const response = await model.generate(prompt, promptContext,() => {
+tokensGenerated++;
+return true;
+});
 if (verbose) {
 console.debug("Received Response:\n" + response);
@@ -206,8 +226,8 @@ async function createCompletion(
 llmodel: model.llm.name(),
 usage: {
 prompt_tokens: prompt.length,
-completion_tokens: response.length, //TODO
-total_tokens: prompt.length + response.length, //TODO
+completion_tokens: tokensGenerated,
+total_tokens: prompt.length + tokensGenerated, //TODO Not sure how to get tokens in prompt
 },
 choices: [
 {
@@ -220,8 +240,77 @@ async function createCompletion(
 };
 }
-function createTokenStream() {
-throw Error("This API has not been completed yet!");
+function _internal_createTokenStream(stream,model,
+messages,
+options = defaultCompletionOptions,callback = undefined) {
+const { prompt, promptContext, verbose } = preparePromptAndContext(model,messages,options);
+if (verbose) {
+console.debug("Sending Prompt:\n" + prompt);
+}
+model.generate(prompt, promptContext,(tokenId, token, total) => {
+stream.push(token);
+if(callback !== undefined){
+return callback(tokenId,token,total);
+}
+return true;
+}).then(() => {
+stream.end()
+})
+return stream;
+}
+function _createTokenStream(model,
+messages,
+options = defaultCompletionOptions,callback = undefined) {
+// Silent crash if we dont do this here
+const stream = new Stream.PassThrough({
+encoding: 'utf-8'
+});
+return _internal_createTokenStream(stream,model,messages,options,callback);
+}
+async function* generateTokens(model,
+messages,
+options = defaultCompletionOptions, callback = undefined) {
+const stream = _createTokenStream(model,messages,options,callback)
+let bHasFinished = false;
+let activeDataCallback = undefined;
+const finishCallback = () => {
+bHasFinished = true;
+if(activeDataCallback !== undefined){
+activeDataCallback(undefined);
+}
+}
+stream.on("finish",finishCallback)
+while (!bHasFinished) {
+const token = await new Promise((res) => {
+activeDataCallback = (d) => {
+stream.off("data",activeDataCallback)
+activeDataCallback = undefined
+res(d);
+}
+stream.on('data', activeDataCallback)
+})
+if (token == undefined) {
+break;
+}
+yield token;
+}
+stream.off("finish",finishCallback);
 }
 module.exports = {
@@ -238,5 +327,5 @@ module.exports = {
 downloadModel,
 retrieveModel,
 loadModel,
-createTokenStream,
+generateTokens
 };


@@ -9,10 +9,10 @@ class InferenceModel {
 this.config = config;
 }
-async generate(prompt, promptContext) {
+async generate(prompt, promptContext,callback) {
 warnOnSnakeCaseKeys(promptContext);
 const normalizedPromptContext = normalizePromptContext(promptContext);
-const result = this.llm.raw_prompt(prompt, normalizedPromptContext, () => {});
+const result = this.llm.raw_prompt(prompt, normalizedPromptContext,callback);
 return result;
 }


@@ -224,7 +224,6 @@ async function retrieveModel(modelName, options = {}) {
 verbose: true,
 ...options,
 };
 await mkdirp(retrieveOptions.modelPath);
 const modelFileName = appendBinSuffixIfMissing(modelName);
@@ -284,7 +283,6 @@ async function retrieveModel(modelName, options = {}) {
 } else {
 throw Error("Failed to retrieve model.");
 }
 return config;
 }

One file's diff is suppressed because it is too large.