Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2024-10-01 01:06:10 -04:00)

typescript: async generator and token stream (#1897)

Signed-off-by: Tare Ebelo <75279482+TareHimself@users.noreply.github.com>
Signed-off-by: jacob <jacoobes@sern.dev>
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: jacob <jacoobes@sern.dev>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>

This commit is contained in:
parent: ef518fae3e
commit: a153cc5b25
@@ -611,6 +611,7 @@ jobs:
 $Env:Path += ";$MinGwBin"
 $Env:Path += ";C:\Program Files\CMake\bin"
 $Env:Path += ";C:\VulkanSDK\1.3.261.1\bin"
+$Env:VULKAN_SDK = "C:\VulkanSDK\1.3.261.1"
 cd gpt4all-backend
 mkdir runtimes/win-x64
 cd runtimes/win-x64
@@ -651,6 +652,7 @@ jobs:
 command: |
 $Env:Path += ";C:\Program Files\CMake\bin"
 $Env:Path += ";C:\VulkanSDK\1.3.261.1\bin"
+$Env:VULKAN_SDK = "C:\VulkanSDK\1.3.261.1"
 cd gpt4all-backend
 mkdir runtimes/win-x64_msvc
 cd runtimes/win-x64_msvc
@@ -1107,8 +1109,12 @@ workflows:
 jobs:
 - hold:
 type: approval
+- csharp-hold:
+type: approval
 - nuget-hold:
 type: approval
+- nodejs-hold:
+type: approval
 - npm-hold:
 type: approval
 - build-bindings-backend-linux:
@@ -1151,21 +1157,21 @@ workflows:
 branches:
 only:
 requires:
-- npm-hold
+- nodejs-hold
 - build-bindings-backend-linux
 - build-nodejs-windows:
 filters:
 branches:
 only:
 requires:
-- npm-hold
+- nodejs-hold
 - build-bindings-backend-windows-msvc
 - build-nodejs-macos:
 filters:
 branches:
 only:
 requires:
-- npm-hold
+- nodejs-hold
 - build-bindings-backend-macos
 
 
@@ -1175,21 +1181,21 @@ workflows:
 branches:
 only:
 requires:
-- nuget-hold
+- csharp-hold
 - build-bindings-backend-linux
 - build-csharp-windows:
 filters:
 branches:
 only:
 requires:
-- nuget-hold
+- csharp-hold
 - build-bindings-backend-windows
 - build-csharp-macos:
 filters:
 branches:
 only:
 requires:
-- nuget-hold
+- csharp-hold
 - build-bindings-backend-macos
 - store-and-upload-nupkgs:
 filters:
@@ -159,6 +159,7 @@ This package is in active development, and breaking changes may happen until the
 * [mpt](#mpt)
 * [replit](#replit)
 * [type](#type)
+* [TokenCallback](#tokencallback)
 * [InferenceModel](#inferencemodel)
 * [dispose](#dispose)
 * [EmbeddingModel](#embeddingmodel)
@@ -184,16 +185,17 @@ This package is in active development, and breaking changes may happen until the
 * [Parameters](#parameters-5)
 * [hasGpuDevice](#hasgpudevice)
 * [listGpu](#listgpu)
+* [Parameters](#parameters-6)
 * [dispose](#dispose-2)
 * [GpuDevice](#gpudevice)
 * [type](#type-2)
 * [LoadModelOptions](#loadmodeloptions)
 * [loadModel](#loadmodel)
-* [Parameters](#parameters-6)
-* [createCompletion](#createcompletion)
 * [Parameters](#parameters-7)
-* [createEmbedding](#createembedding)
+* [createCompletion](#createcompletion)
 * [Parameters](#parameters-8)
+* [createEmbedding](#createembedding)
+* [Parameters](#parameters-9)
 * [CompletionOptions](#completionoptions)
 * [verbose](#verbose)
 * [systemPromptTemplate](#systemprompttemplate)
@@ -225,15 +227,15 @@ This package is in active development, and breaking changes may happen until the
 * [repeatPenalty](#repeatpenalty)
 * [repeatLastN](#repeatlastn)
 * [contextErase](#contexterase)
-* [createTokenStream](#createtokenstream)
+* [generateTokens](#generatetokens)
-* [Parameters](#parameters-9)
+* [Parameters](#parameters-10)
 * [DEFAULT\_DIRECTORY](#default_directory)
 * [DEFAULT\_LIBRARIES\_DIRECTORY](#default_libraries_directory)
 * [DEFAULT\_MODEL\_CONFIG](#default_model_config)
 * [DEFAULT\_PROMPT\_CONTEXT](#default_prompt_context)
 * [DEFAULT\_MODEL\_LIST\_URL](#default_model_list_url)
 * [downloadModel](#downloadmodel)
-* [Parameters](#parameters-10)
+* [Parameters](#parameters-11)
 * [Examples](#examples)
 * [DownloadModelOptions](#downloadmodeloptions)
 * [modelPath](#modelpath)
@@ -279,6 +281,12 @@ Model architecture. This argument currently does not have any functionality and
 
 Type: ModelType
 
+#### TokenCallback
+
+Callback for controlling token generation
+
+Type: function (tokenId: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), token: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String), total: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)): [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
+
 #### InferenceModel
 
 InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers.
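As an aside to the TokenCallback addition above, a minimal sketch of a callback with that signature (the 64-token cutoff is an arbitrary value for illustration): returning false stops generation early.

```js
// Stop generation after 64 tokens; otherwise keep going.
let emitted = 0
const stopAfter64 = (tokenId, token, total) => {
    emitted += 1
    return emitted < 64
}
```

Such a callback can be passed to generateTokens, or as the optional third argument of InferenceModel.generate, as documented in the hunks below.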
@@ -362,9 +370,9 @@ Use the prompt function exported for a value
 
 * `q` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input.
 * `params` **Partial<[LLModelPromptContext](#llmodelpromptcontext)>** Optional parameters for the prompt context.
-* `callback` **function (res: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)): void**
+* `callback` **[TokenCallback](#tokencallback)?** optional callback to control token generation.
 
-Returns **void** The result of the model prompt.
+Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The result of the model prompt.
 
 ##### embed
 
@@ -424,6 +432,12 @@ Returns **[boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/
 
 GPUs that are usable for this LLModel
 
+###### Parameters
+
+* `nCtx` **[number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)** Maximum size of context window
+
+<!---->
+
 * Throws **any** if hasGpuDevice returns false (i think)
 
 Returns **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[GpuDevice](#gpudevice)>**
@@ -690,17 +704,18 @@ The percentage of context to erase if the context window is exceeded.
 
 Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
 
-#### createTokenStream
+#### generateTokens
 
-TODO: Help wanted to implement this
+Creates an async generator of tokens
 
 ##### Parameters
 
-* `llmodel` **[LLModel](#llmodel)**
+* `llmodel` **[InferenceModel](#inferencemodel)** The language model object.
-* `messages` **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[PromptMessage](#promptmessage)>**
+* `messages` **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[PromptMessage](#promptmessage)>** The array of messages for the conversation.
-* `options` **[CompletionOptions](#completionoptions)**
+* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion.
+* `callback` **[TokenCallback](#tokencallback)** optional callback to control token generation.
 
-Returns **function (ll: [LLModel](#llmodel)): AsyncGenerator<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>**
+Returns **AsyncGenerator<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The stream of generated tokens
 
 #### DEFAULT\_DIRECTORY
 
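A usage sketch of the new generateTokens API, mirroring spec/generator.mjs from this commit (the model name and nPredict value are taken from that spec):

```js
import gpt from '../src/gpt4all.js'

const model = await gpt.loadModel("mistral-7b-openorca.Q4_0.gguf", { device: 'gpu' })

// Consume the async generator; each yielded value is one generated token.
const tokens = gpt.generateTokens(model, [
    { role: 'user', content: "How are you ?" },
], { nPredict: 2048 })

for await (const token of tokens) {
    process.stdout.write(token)
}

model.dispose()
```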
@@ -136,14 +136,18 @@ yarn test
 
 This package is in active development, and breaking changes may happen until the api stabilizes. Here's what's the todo list:
 
+* \[ ] Purely offline. Per the gui, which can be run completely offline, the bindings should be as well.
+* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help)
+* Should include prebuilds to avoid painful node-gyp errors
+* \[ ] createChatSession ( the python equivalent to create\_chat\_session )
+* \[x] generateTokens, the new name for createTokenStream. As of 3.2.0, this is released but not 100% tested. Check spec/generator.mjs!
+* \[x] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete
 * \[x] prompt models via a threadsafe function in order to have proper non blocking behavior in nodejs
-* \[ ] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete
+* \[x] generateTokens is the new name for this^
 * \[x] proper unit testing (integrate with circle ci)
 * \[x] publish to npm under alpha tag `gpt4all@alpha`
 * \[x] have more people test on other platforms (mac tester needed)
 * \[x] switch to new pluggable backend
-* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help)
-* Should include prebuilds to avoid painful node-gyp errors
-* \[ ] createChatSession ( the python equivalent to create\_chat\_session )
 
 ### API Reference
@@ -3,9 +3,9 @@
 
 Napi::Function NodeModelWrapper::GetClass(Napi::Env env) {
 Napi::Function self = DefineClass(env, "LLModel", {
-InstanceMethod("type", &NodeModelWrapper::getType),
+InstanceMethod("type", &NodeModelWrapper::GetType),
 InstanceMethod("isModelLoaded", &NodeModelWrapper::IsModelLoaded),
-InstanceMethod("name", &NodeModelWrapper::getName),
+InstanceMethod("name", &NodeModelWrapper::GetName),
 InstanceMethod("stateSize", &NodeModelWrapper::StateSize),
 InstanceMethod("raw_prompt", &NodeModelWrapper::Prompt),
 InstanceMethod("setThreadCount", &NodeModelWrapper::SetThreadCount),
@@ -28,14 +28,14 @@ Napi::Function NodeModelWrapper::GetClass(Napi::Env env) {
 Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 {
 auto env = info.Env();
-return Napi::Number::New(env, static_cast<uint32_t>( llmodel_required_mem(GetInference(), full_model_path.c_str(), 2048, 100) ));
+return Napi::Number::New(env, static_cast<uint32_t>(llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers) ));
 
 }
 Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo& info)
 {
 auto env = info.Env();
 int num_devices = 0;
-auto mem_size = llmodel_required_mem(GetInference(), full_model_path.c_str());
+auto mem_size = llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers);
 llmodel_gpu_device* all_devices = llmodel_available_gpu_devices(GetInference(), mem_size, &num_devices);
 if(all_devices == nullptr) {
 Napi::Error::New(
@@ -70,7 +70,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 return js_array;
 }
 
-Napi::Value NodeModelWrapper::getType(const Napi::CallbackInfo& info)
+Napi::Value NodeModelWrapper::GetType(const Napi::CallbackInfo& info)
 {
 if(type.empty()) {
 return info.Env().Undefined();
@@ -132,6 +132,9 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 library_path = ".";
 }
 device = config_object.Get("device").As<Napi::String>();
+
+nCtx = config_object.Get("nCtx").As<Napi::Number>().Int32Value();
+nGpuLayers = config_object.Get("ngl").As<Napi::Number>().Int32Value();
 }
 llmodel_set_implementation_search_path(library_path.c_str());
 const char* e;
@@ -148,20 +151,17 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 return;
 }
 if(device != "cpu") {
-size_t mem = llmodel_required_mem(GetInference(), full_weight_path.c_str());
+size_t mem = llmodel_required_mem(GetInference(), full_weight_path.c_str(),nCtx, nGpuLayers);
-std::cout << "Initiating GPU\n";
 
 auto success = llmodel_gpu_init_gpu_device_by_string(GetInference(), mem, device.c_str());
-if(success) {
+if(!success) {
-std::cout << "GPU init successfully\n";
-} else {
 //https://github.com/nomic-ai/gpt4all/blob/3acbef14b7c2436fe033cae9036e695d77461a16/gpt4all-bindings/python/gpt4all/pyllmodel.py#L215
 //Haven't implemented this but it is still open to contribution
 std::cout << "WARNING: Failed to init GPU\n";
 }
 }
 
-auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), 2048, 100);
+auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), nCtx, nGpuLayers);
 if(!success) {
 Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException();
 return;
@@ -254,6 +254,9 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 .repeat_last_n = 10,
 .context_erase = 0.5
 };
+
+PromptWorkerConfig promptWorkerConfig;
+
 if(info[1].IsObject())
 {
 auto inputObject = info[1].As<Napi::Object>();
@@ -285,29 +288,33 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 if(inputObject.Has("context_erase"))
 promptContext.context_erase = inputObject.Get("context_erase").As<Napi::Number>().FloatValue();
 }
-//copy to protect llmodel resources when splitting to new thread
+else
-llmodel_prompt_context copiedPrompt = promptContext;
+{
+Napi::Error::New(info.Env(), "Missing Prompt Options").ThrowAsJavaScriptException();
+return info.Env().Undefined();
+}
+
-std::string copiedQuestion = question;
+if(info.Length() >= 3 && info[2].IsFunction()){
-PromptWorkContext pc = {
+promptWorkerConfig.bHasTokenCallback = true;
-copiedQuestion,
+promptWorkerConfig.tokenCallback = info[2].As<Napi::Function>();
-inference_,
+}
-copiedPrompt,
-""
-};
-auto threadSafeContext = new TsfnContext(env, pc);
+//copy to protect llmodel resources when splitting to new thread
-threadSafeContext->tsfn = Napi::ThreadSafeFunction::New(
+// llmodel_prompt_context copiedPrompt = promptContext;
-env, // Environment
+promptWorkerConfig.context = promptContext;
-info[2].As<Napi::Function>(), // JS function from caller
+promptWorkerConfig.model = GetInference();
-"PromptCallback", // Resource name
+promptWorkerConfig.mutex = &inference_mutex;
-0, // Max queue size (0 = unlimited).
+promptWorkerConfig.prompt = question;
-1, // Initial thread count
+promptWorkerConfig.result = "";
-threadSafeContext, // Context,
-FinalizerCallback, // Finalizer
-(void*)nullptr // Finalizer data
+auto worker = new PromptWorker(env, promptWorkerConfig);
-);
-threadSafeContext->nativeThread = std::thread(threadEntry, threadSafeContext);
+worker->Queue();
-return threadSafeContext->deferred_.Promise();
+return worker->GetPromise();
 }
 void NodeModelWrapper::Dispose(const Napi::CallbackInfo& info) {
 llmodel_model_destroy(inference_);
@@ -321,7 +328,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 }
 }
 
-Napi::Value NodeModelWrapper::getName(const Napi::CallbackInfo& info) {
+Napi::Value NodeModelWrapper::GetName(const Napi::CallbackInfo& info) {
 return Napi::String::New(info.Env(), name);
 }
 Napi::Value NodeModelWrapper::ThreadCount(const Napi::CallbackInfo& info) {
@@ -7,14 +7,17 @@
 #include <memory>
 #include <filesystem>
 #include <set>
+#include <mutex>
 
 namespace fs = std::filesystem;
 
 
 class NodeModelWrapper: public Napi::ObjectWrap<NodeModelWrapper> {
 
 public:
 NodeModelWrapper(const Napi::CallbackInfo &);
 //virtual ~NodeModelWrapper();
-Napi::Value getType(const Napi::CallbackInfo& info);
+Napi::Value GetType(const Napi::CallbackInfo& info);
 Napi::Value IsModelLoaded(const Napi::CallbackInfo& info);
 Napi::Value StateSize(const Napi::CallbackInfo& info);
 //void Finalize(Napi::Env env) override;
@@ -25,7 +28,7 @@ public:
 Napi::Value Prompt(const Napi::CallbackInfo& info);
 void SetThreadCount(const Napi::CallbackInfo& info);
 void Dispose(const Napi::CallbackInfo& info);
-Napi::Value getName(const Napi::CallbackInfo& info);
+Napi::Value GetName(const Napi::CallbackInfo& info);
 Napi::Value ThreadCount(const Napi::CallbackInfo& info);
 Napi::Value GenerateEmbedding(const Napi::CallbackInfo& info);
 Napi::Value HasGpuDevice(const Napi::CallbackInfo& info);
@@ -48,8 +51,12 @@ private:
 */
 llmodel_model inference_;
 
+std::mutex inference_mutex;
+
 std::string type;
 // corresponds to LLModel::name() in typescript
 std::string name;
+int nCtx{};
+int nGpuLayers{};
 std::string full_model_path;
 };
@@ -1,6 +1,6 @@
 {
 "name": "gpt4all",
-"version": "3.1.0",
+"version": "3.2.0",
 "packageManager": "yarn@3.6.1",
 "main": "src/gpt4all.js",
 "repository": "nomic-ai/gpt4all",
@@ -1,60 +1,146 @@
 #include "prompt.h"
+#include <future>
+
+PromptWorker::PromptWorker(Napi::Env env, PromptWorkerConfig config)
+: promise(Napi::Promise::Deferred::New(env)), _config(config), AsyncWorker(env) {
+if(_config.bHasTokenCallback){
+_tsfn = Napi::ThreadSafeFunction::New(config.tokenCallback.Env(),config.tokenCallback,"PromptWorker",0,1,this);
+}
+}
+
-TsfnContext::TsfnContext(Napi::Env env, const PromptWorkContext& pc)
+PromptWorker::~PromptWorker()
-: deferred_(Napi::Promise::Deferred::New(env)), pc(pc) {
+{
-}
+if(_config.bHasTokenCallback){
-namespace {
+_tsfn.Release();
-static std::string *res;
+}
 }
 
-bool response_callback(int32_t token_id, const char *response) {
+void PromptWorker::Execute()
-*res += response;
+{
-return token_id != -1;
+_config.mutex->lock();
-}
-bool recalculate_callback (bool isrecalculating) {
-return isrecalculating;
-};
-bool prompt_callback (int32_t tid) {
-return true;
-};
 
-// The thread entry point. This takes as its arguments the specific
+LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper *>(_config.model);
-// threadsafe-function context created inside the main thread.
-void threadEntry(TsfnContext* context) {
-static std::mutex mtx;
-std::lock_guard<std::mutex> lock(mtx);
-res = &context->pc.res;
-// Perform a call into JavaScript.
-napi_status status =
-context->tsfn.BlockingCall(&context->pc,
-[](Napi::Env env, Napi::Function jsCallback, PromptWorkContext* pc) {
-llmodel_prompt(
-pc->inference_,
-pc->question.c_str(),
-&prompt_callback,
-&response_callback,
-&recalculate_callback,
-&pc->prompt_params
-);
-});
 
-if (status != napi_ok) {
+auto ctx = &_config.context;
-Napi::Error::Fatal(
-"ThreadEntry",
-"Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
-}
-// Release the thread-safe function. This decrements the internal thread
-// count, and will perform finalization since the count will reach 0.
-context->tsfn.Release();
-}
 
-void FinalizerCallback(Napi::Env env,
+if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size())
-void* finalizeData,
+wrapper->promptContext.tokens.resize(ctx->n_past);
-TsfnContext* context) {
-// Resolve the Promise previously returned to JS
+// Copy the C prompt context
-context->deferred_.Resolve(Napi::String::New(env, context->pc.res));
+wrapper->promptContext.n_past = ctx->n_past;
-// Wait for the thread to finish executing before proceeding.
+wrapper->promptContext.n_ctx = ctx->n_ctx;
-context->nativeThread.join();
+wrapper->promptContext.n_predict = ctx->n_predict;
-delete context;
+wrapper->promptContext.top_k = ctx->top_k;
-}
+wrapper->promptContext.top_p = ctx->top_p;
+wrapper->promptContext.temp = ctx->temp;
+wrapper->promptContext.n_batch = ctx->n_batch;
+wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
+wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
+wrapper->promptContext.contextErase = ctx->context_erase;
+
+// Napi::Error::Fatal(
+// "SUPRA",
+// "About to prompt");
+// Call the C++ prompt method
+wrapper->llModel->prompt(
+_config.prompt,
+[](int32_t tid) { return true; },
+[this](int32_t token_id, const std::string tok)
+{
+return ResponseCallback(token_id, tok);
+},
+[](bool isRecalculating)
+{
+return isRecalculating;
+},
+wrapper->promptContext);
+
+// Update the C context by giving access to the wrappers raw pointers to std::vector data
+// which involves no copies
+ctx->logits = wrapper->promptContext.logits.data();
+ctx->logits_size = wrapper->promptContext.logits.size();
+ctx->tokens = wrapper->promptContext.tokens.data();
+ctx->tokens_size = wrapper->promptContext.tokens.size();
+
+// Update the rest of the C prompt context
+ctx->n_past = wrapper->promptContext.n_past;
+ctx->n_ctx = wrapper->promptContext.n_ctx;
+ctx->n_predict = wrapper->promptContext.n_predict;
+ctx->top_k = wrapper->promptContext.top_k;
+ctx->top_p = wrapper->promptContext.top_p;
+ctx->temp = wrapper->promptContext.temp;
+ctx->n_batch = wrapper->promptContext.n_batch;
+ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
+ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
+ctx->context_erase = wrapper->promptContext.contextErase;
+
+_config.mutex->unlock();
+}
+
+void PromptWorker::OnOK()
+{
+promise.Resolve(Napi::String::New(Env(), result));
+}
+
+void PromptWorker::OnError(const Napi::Error &e)
+{
+promise.Reject(e.Value());
+}
+
+Napi::Promise PromptWorker::GetPromise()
+{
+return promise.Promise();
+}
+
+bool PromptWorker::ResponseCallback(int32_t token_id, const std::string token)
+{
+if (token_id == -1)
+{
+return false;
+}
+
+if(!_config.bHasTokenCallback){
+return true;
+}
+
+result += token;
+
+std::promise<bool> promise;
+
+auto info = new TokenCallbackInfo();
+info->tokenId = token_id;
+info->token = token;
+info->total = result;
+
+auto future = promise.get_future();
+
+auto status = _tsfn.BlockingCall(info, [&promise](Napi::Env env, Napi::Function jsCallback, TokenCallbackInfo *value)
+{
+// Transform native data into JS data, passing it to the provided
+// `jsCallback` -- the TSFN's JavaScript function.
+auto token_id = Napi::Number::New(env, value->tokenId);
+auto token = Napi::String::New(env, value->token);
+auto total = Napi::String::New(env,value->total);
+auto jsResult = jsCallback.Call({ token_id, token, total}).ToBoolean();
+promise.set_value(jsResult);
+// We're finished with the data.
+delete value;
+});
+if (status != napi_ok) {
+Napi::Error::Fatal(
+"PromptWorkerResponseCallback",
+"Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
+}
+
+return future.get();
+}
+
+bool PromptWorker::RecalculateCallback(bool isRecalculating)
+{
+return isRecalculating;
+}
+
+bool PromptWorker::PromptCallback(int32_t tid)
+{
+return true;
+}
@@ -1,44 +1,59 @@
-#ifndef TSFN_CONTEXT_H
+#ifndef PREDICT_WORKER_H
-#define TSFN_CONTEXT_H
+#define PREDICT_WORKER_H
 
 #include "napi.h"
 #include "llmodel_c.h"
+#include "llmodel.h"
 #include <thread>
 #include <mutex>
 #include <iostream>
 #include <atomic>
 #include <memory>
-struct PromptWorkContext {
-std::string question;
-llmodel_model inference_;
-llmodel_prompt_context prompt_params;
-std::string res;
 
-};
+struct TokenCallbackInfo
+{
+int32_t tokenId;
+std::string total;
+std::string token;
+};
+
-struct TsfnContext {
+struct LLModelWrapper
-public:
+{
-TsfnContext(Napi::Env env, const PromptWorkContext &pc);
+LLModel *llModel = nullptr;
-std::thread nativeThread;
+LLModel::PromptContext promptContext;
-Napi::Promise::Deferred deferred_;
+~LLModelWrapper() { delete llModel; }
-PromptWorkContext pc;
+};
-Napi::ThreadSafeFunction tsfn;
 
-// Some data to pass around
+struct PromptWorkerConfig
-// int ints[ARRAY_LENGTH];
+{
+Napi::Function tokenCallback;
+bool bHasTokenCallback = false;
+llmodel_model model;
+std::mutex * mutex;
+std::string prompt;
+llmodel_prompt_context context;
+std::string result;
+};
+
-};
+class PromptWorker : public Napi::AsyncWorker
+{
+public:
+PromptWorker(Napi::Env env, PromptWorkerConfig config);
+~PromptWorker();
+void Execute() override;
+void OnOK() override;
+void OnError(const Napi::Error &e) override;
+Napi::Promise GetPromise();
+
-// The thread entry point. This takes as its arguments the specific
+bool ResponseCallback(int32_t token_id, const std::string token);
-// threadsafe-function context created inside the main thread.
+bool RecalculateCallback(bool isrecalculating);
-void threadEntry(TsfnContext*);
+bool PromptCallback(int32_t tid);
 
-// The thread-safe function finalizer callback. This callback executes
+private:
-// at destruction of thread-safe function, taking as arguments the finalizer
+Napi::Promise::Deferred promise;
-// data and threadsafe-function context.
+std::string result;
-void FinalizerCallback(Napi::Env, void* finalizeData, TsfnContext*);
+PromptWorkerConfig _config;
+Napi::ThreadSafeFunction _tsfn;
+};
+
-bool response_callback(int32_t token_id, const char *response);
+#endif // PREDICT_WORKER_H
-bool recalculate_callback (bool isrecalculating);
-bool prompt_callback (int32_t tid);
-#endif // TSFN_CONTEXT_H

gpt4all-bindings/typescript/spec/generator.mjs (new file, 41 lines)

@@ -0,0 +1,41 @@
+import gpt from '../src/gpt4all.js'
+
+const model = await gpt.loadModel("mistral-7b-openorca.Q4_0.gguf", { device: 'gpu' })
+
+process.stdout.write('Response: ')
+
+const tokens = gpt.generateTokens(model, [{
+role: 'user',
+content: "How are you ?"
+}], { nPredict: 2048 })
+for await (const token of tokens){
+process.stdout.write(token);
+}
+
+const result = await gpt.createCompletion(model, [{
+role: 'user',
+content: "You sure?"
+}])
+
+console.log(result)
+
+const result2 = await gpt.createCompletion(model, [{
+role: 'user',
+content: "You sure you sure?"
+}])
+
+console.log(result2)
+
+const tokens2 = gpt.generateTokens(model, [{
+role: 'user',
+content: "If 3 + 3 is 5, what is 2 + 2?"
+}], { nPredict: 2048 })
+for await (const token of tokens2){
+process.stdout.write(token);
+}
+console.log("done")
+model.dispose();

gpt4all-bindings/typescript/src/gpt4all.d.ts (vendored, 46 changed lines)
@@ -49,6 +49,12 @@ interface ModelConfig {
 path: string;
 url?: string;
 }
+
+/**
+* Callback for controlling token generation
+*/
+type TokenCallback = (tokenId: number, token: string, total: string) => boolean
+
 /**
 *
 * InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers.
@@ -61,7 +67,8 @@ declare class InferenceModel {
 
 generate(
 prompt: string,
-options?: Partial<LLModelPromptContext>
+options?: Partial<LLModelPromptContext>,
+callback?: TokenCallback
 ): Promise<string>;
 
 /**
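A sketch of the extended generate signature above (the prompt text and nPredict value are placeholders chosen for illustration): the optional callback receives each token as it is produced and returning true keeps generation going.

```js
// Stream tokens from InferenceModel.generate via the new optional callback.
const text = await model.generate("Tell me a joke.", { nPredict: 128 }, (tokenId, token, total) => {
    process.stdout.write(token)
    return true // return false to stop early
})
```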
@@ -132,13 +139,14 @@ declare class LLModel {
 * Use the prompt function exported for a value
 * @param q The prompt input.
 * @param params Optional parameters for the prompt context.
+* @param callback - optional callback to control token generation.
 * @returns The result of the model prompt.
 */
 raw_prompt(
 q: string,
 params: Partial<LLModelPromptContext>,
-callback: (res: string) => void
+callback?: TokenCallback
-): void; // TODO work on return type
+): Promise<string>
 
 /**
 * Embed text with the model. Keep in mind that
@@ -176,10 +184,11 @@ declare class LLModel {
 hasGpuDevice(): boolean
 /**
 * GPUs that are usable for this LLModel
+* @param nCtx Maximum size of context window
 * @throws if hasGpuDevice returns false (i think)
 * @returns
 */
-listGpu() : GpuDevice[]
+listGpu(nCtx: number) : GpuDevice[]
 
 /**
 * delete and cleanup the native model
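A small sketch of the updated listGpu call (the 2048-token context size is an example value; model.llm is the underlying LLModel handle used elsewhere in the bindings):

```js
// Inspect usable GPU devices for a loaded model, given a context window size.
if (model.llm.hasGpuDevice()) {
    console.log(model.llm.listGpu(2048))
}
```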
@@ -224,6 +233,16 @@ interface LoadModelOptions {
 model.
 */
 device?: string;
+/*
+* The Maximum window size of this model
+* Default of 2048
+*/
+nCtx?: number;
+/*
+* Number of gpu layers needed
+* Default of 100
+*/
+ngl?: number;
 }
 
 interface InferenceModelOptions extends LoadModelOptions {
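A usage sketch of the new nCtx and ngl load options (the model file name follows spec/generator.mjs; the values shown are examples, with the documented defaults being 2048 and 100):

```js
import gpt from '../src/gpt4all.js'

const model = await gpt.loadModel("mistral-7b-openorca.Q4_0.gguf", {
    device: 'gpu',
    nCtx: 4096, // maximum context window for this model instance
    ngl: 100,   // number of layers to offload to the GPU
})
```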
@@ -442,14 +461,21 @@ interface LLModelPromptContext {
 contextErase: number;
 }
 
 
 /**
-* TODO: Help wanted to implement this
+* Creates an async generator of tokens
+* @param {InferenceModel} llmodel - The language model object.
+* @param {PromptMessage[]} messages - The array of messages for the conversation.
+* @param {CompletionOptions} options - The options for creating the completion.
+* @param {TokenCallback} callback - optional callback to control token generation.
+* @returns {AsyncGenerator<string>} The stream of generated tokens
 */
-declare function createTokenStream(
+declare function generateTokens(
-llmodel: LLModel,
+llmodel: InferenceModel,
 messages: PromptMessage[],
-options: CompletionOptions
+options: CompletionOptions,
-): (ll: LLModel) => AsyncGenerator<string>;
+callback?: TokenCallback
+): AsyncGenerator<string>;
 /**
 * From python api:
 * models will be stored in (homedir)/.cache/gpt4all/`
@@ -568,7 +594,7 @@ export {
 loadModel,
 createCompletion,
 createEmbedding,
-createTokenStream,
+generateTokens,
 DEFAULT_DIRECTORY,
 DEFAULT_LIBRARIES_DIRECTORY,
 DEFAULT_MODEL_CONFIG,
@@ -18,6 +18,7 @@ const {
 DEFAULT_MODEL_LIST_URL,
 } = require("./config.js");
 const { InferenceModel, EmbeddingModel } = require("./models.js");
+const Stream = require('stream')
 const assert = require("assert");
 
 /**
@@ -36,6 +37,8 @@ async function loadModel(modelName, options = {}) {
 allowDownload: true,
 verbose: true,
 device: 'cpu',
+nCtx: 2048,
+ngl : 100,
 ...options,
 };
 
@@ -58,11 +61,14 @@ async function loadModel(modelName, options = {}) {
 model_path: loadOptions.modelPath,
 library_path: existingPaths,
 device: loadOptions.device,
+nCtx: loadOptions.nCtx,
+ngl: loadOptions.ngl
 };
 
 if (loadOptions.verbose) {
 console.debug("Creating LLModel with options:", llmOptions);
 }
+console.log(modelConfig)
 const llmodel = new LLModel(llmOptions);
 if (loadOptions.type === "embedding") {
 return new EmbeddingModel(llmodel, modelConfig);
@@ -149,11 +155,7 @@ const defaultCompletionOptions = {
 ...DEFAULT_PROMPT_CONTEXT,
 };
 
-async function createCompletion(
+function preparePromptAndContext(model,messages,options){
-model,
-messages,
-options = defaultCompletionOptions
-) {
 if (options.hasDefaultHeader !== undefined) {
 console.warn(
 "hasDefaultHeader (bool) is deprecated and has no effect, use promptHeader (string) instead"
@@ -180,6 +182,7 @@ async function createCompletion(
 ...promptContext
 } = optionsWithDefaults;
 
+
 const prompt = formatChatPrompt(messages, {
 systemPromptTemplate,
 defaultSystemPrompt: model.config.systemPrompt,
@@ -192,11 +195,28 @@ async function createCompletion(
 // promptFooter: '### Response:',
 });
 
+return {
+prompt, promptContext, verbose
+}
+}
+
+async function createCompletion(
+model,
+messages,
+options = defaultCompletionOptions
+) {
+const { prompt, promptContext, verbose } = preparePromptAndContext(model,messages,options);
+
 if (verbose) {
 console.debug("Sending Prompt:\n" + prompt);
 }
 
-const response = await model.generate(prompt, promptContext);
+let tokensGenerated = 0
+
+const response = await model.generate(prompt, promptContext,() => {
+tokensGenerated++;
+return true;
+});
+
 if (verbose) {
 console.debug("Received Response:\n" + response);
@@ -206,8 +226,8 @@ async function createCompletion(
 llmodel: model.llm.name(),
 usage: {
 prompt_tokens: prompt.length,
-completion_tokens: response.length, //TODO
+completion_tokens: tokensGenerated,
-total_tokens: prompt.length + response.length, //TODO
+total_tokens: prompt.length + tokensGenerated, //TODO Not sure how to get tokens in prompt
 },
 choices: [
 {
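The next hunk adds generateTokens, which pushes tokens into a Node PassThrough stream and then hand-rolls 'data'/'finish' event handling to turn that stream into an async generator. As a side note (a sketch only, assuming a Node version where Readable streams are async iterable; this is not what the shipped code below does), the consumption side could also be written as:

```js
const { PassThrough } = require('stream')

// Hypothetical helper: yield every chunk written to a PassThrough stream.
async function* streamToTokens(stream) {
    for await (const chunk of stream) {
        yield chunk
    }
}

// Usage sketch: a producer elsewhere calls stream.push(token) and stream.end().
const stream = new PassThrough({ encoding: 'utf-8' })
```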
@@ -220,8 +240,77 @@ async function createCompletion(
 };
 }
 
-function createTokenStream() {
+function _internal_createTokenStream(stream,model,
-throw Error("This API has not been completed yet!");
+messages,
+options = defaultCompletionOptions,callback = undefined) {
+const { prompt, promptContext, verbose } = preparePromptAndContext(model,messages,options);
+
+
+if (verbose) {
+console.debug("Sending Prompt:\n" + prompt);
+}
+
+model.generate(prompt, promptContext,(tokenId, token, total) => {
+stream.push(token);
+
+if(callback !== undefined){
+return callback(tokenId,token,total);
+}
+
+return true;
+}).then(() => {
+stream.end()
+})
+
+return stream;
+}
+
+function _createTokenStream(model,
+messages,
+options = defaultCompletionOptions,callback = undefined) {
+
+// Silent crash if we dont do this here
+const stream = new Stream.PassThrough({
+encoding: 'utf-8'
+});
+return _internal_createTokenStream(stream,model,messages,options,callback);
+}
+
+async function* generateTokens(model,
+messages,
+options = defaultCompletionOptions, callback = undefined) {
+const stream = _createTokenStream(model,messages,options,callback)
+
+let bHasFinished = false;
+let activeDataCallback = undefined;
+const finishCallback = () => {
+bHasFinished = true;
+if(activeDataCallback !== undefined){
+activeDataCallback(undefined);
+}
+}
+
+stream.on("finish",finishCallback)
+
+while (!bHasFinished) {
+const token = await new Promise((res) => {
+
+activeDataCallback = (d) => {
+stream.off("data",activeDataCallback)
+activeDataCallback = undefined
+res(d);
+}
+stream.on('data', activeDataCallback)
+})
+
+if (token == undefined) {
+break;
+}
+
+yield token;
+}
+
+stream.off("finish",finishCallback);
 }
 
 module.exports = {

@@ -238,5 +327,5 @@ module.exports = {
 downloadModel,
 retrieveModel,
 loadModel,
-createTokenStream,
+generateTokens
 };

@@ -9,10 +9,10 @@ class InferenceModel {
 this.config = config;
 }
 
-async generate(prompt, promptContext) {
+async generate(prompt, promptContext,callback) {
 warnOnSnakeCaseKeys(promptContext);
 const normalizedPromptContext = normalizePromptContext(promptContext);
-const result = this.llm.raw_prompt(prompt, normalizedPromptContext, () => {});
+const result = this.llm.raw_prompt(prompt, normalizedPromptContext,callback);
 return result;
 }
 

@@ -224,7 +224,6 @@ async function retrieveModel(modelName, options = {}) {
 verbose: true,
 ...options,
 };
-
 await mkdirp(retrieveOptions.modelPath);
 
 const modelFileName = appendBinSuffixIfMissing(modelName);

@@ -284,7 +283,6 @@ async function retrieveModel(modelName, options = {}) {
 } else {
 throw Error("Failed to retrieve model.");
 }
-
 return config;
 }
 

One file diff was suppressed because it is too large.