typescript: async generator and token stream (#1897)

Signed-off-by: Tare Ebelo <75279482+TareHimself@users.noreply.github.com>
Signed-off-by: jacob <jacoobes@sern.dev>
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: jacob <jacoobes@sern.dev>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
TareHimself 2024-02-24 17:50:14 -05:00 committed by GitHub
parent ef518fae3e
commit a153cc5b25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 1517 additions and 955 deletions

View File

@ -611,6 +611,7 @@ jobs:
$Env:Path += ";$MinGwBin"
$Env:Path += ";C:\Program Files\CMake\bin"
$Env:Path += ";C:\VulkanSDK\1.3.261.1\bin"
$Env:VULKAN_SDK = "C:\VulkanSDK\1.3.261.1"
cd gpt4all-backend
mkdir runtimes/win-x64
cd runtimes/win-x64
@ -651,6 +652,7 @@ jobs:
command: |
$Env:Path += ";C:\Program Files\CMake\bin"
$Env:Path += ";C:\VulkanSDK\1.3.261.1\bin"
$Env:VULKAN_SDK = "C:\VulkanSDK\1.3.261.1"
cd gpt4all-backend
mkdir runtimes/win-x64_msvc
cd runtimes/win-x64_msvc
@ -1107,8 +1109,12 @@ workflows:
jobs:
- hold:
type: approval
- csharp-hold:
type: approval
- nuget-hold:
type: approval
- nodejs-hold:
type: approval
- npm-hold:
type: approval
- build-bindings-backend-linux:
@ -1151,21 +1157,21 @@ workflows:
branches:
only:
requires:
- npm-hold
- nodejs-hold
- build-bindings-backend-linux
- build-nodejs-windows:
filters:
branches:
only:
requires:
- npm-hold
- nodejs-hold
- build-bindings-backend-windows-msvc
- build-nodejs-macos:
filters:
branches:
only:
requires:
- npm-hold
- nodejs-hold
- build-bindings-backend-macos
@ -1175,21 +1181,21 @@ workflows:
branches:
only:
requires:
- nuget-hold
- csharp-hold
- build-bindings-backend-linux
- build-csharp-windows:
filters:
branches:
only:
requires:
- nuget-hold
- csharp-hold
- build-bindings-backend-windows
- build-csharp-macos:
filters:
branches:
only:
requires:
- nuget-hold
- csharp-hold
- build-bindings-backend-macos
- store-and-upload-nupkgs:
filters:

View File

@ -159,6 +159,7 @@ This package is in active development, and breaking changes may happen until the
* [mpt](#mpt)
* [replit](#replit)
* [type](#type)
* [TokenCallback](#tokencallback)
* [InferenceModel](#inferencemodel)
* [dispose](#dispose)
* [EmbeddingModel](#embeddingmodel)
@ -184,16 +185,17 @@ This package is in active development, and breaking changes may happen until the
* [Parameters](#parameters-5)
* [hasGpuDevice](#hasgpudevice)
* [listGpu](#listgpu)
* [Parameters](#parameters-6)
* [dispose](#dispose-2)
* [GpuDevice](#gpudevice)
* [type](#type-2)
* [LoadModelOptions](#loadmodeloptions)
* [loadModel](#loadmodel)
* [Parameters](#parameters-6)
* [createCompletion](#createcompletion)
* [Parameters](#parameters-7)
* [createEmbedding](#createembedding)
* [createCompletion](#createcompletion)
* [Parameters](#parameters-8)
* [createEmbedding](#createembedding)
* [Parameters](#parameters-9)
* [CompletionOptions](#completionoptions)
* [verbose](#verbose)
* [systemPromptTemplate](#systemprompttemplate)
@ -225,15 +227,15 @@ This package is in active development, and breaking changes may happen until the
* [repeatPenalty](#repeatpenalty)
* [repeatLastN](#repeatlastn)
* [contextErase](#contexterase)
* [createTokenStream](#createtokenstream)
* [Parameters](#parameters-9)
* [generateTokens](#generatetokens)
* [Parameters](#parameters-10)
* [DEFAULT\_DIRECTORY](#default_directory)
* [DEFAULT\_LIBRARIES\_DIRECTORY](#default_libraries_directory)
* [DEFAULT\_MODEL\_CONFIG](#default_model_config)
* [DEFAULT\_PROMPT\_CONTEXT](#default_prompt_context)
* [DEFAULT\_MODEL\_LIST\_URL](#default_model_list_url)
* [downloadModel](#downloadmodel)
* [Parameters](#parameters-10)
* [Parameters](#parameters-11)
* [Examples](#examples)
* [DownloadModelOptions](#downloadmodeloptions)
* [modelPath](#modelpath)
@ -279,6 +281,12 @@ Model architecture. This argument currently does not have any functionality and
Type: ModelType
#### TokenCallback
Callback for controlling token generation
Type: function (tokenId: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), token: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String), total: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)): [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
#### InferenceModel
InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers.
@ -362,9 +370,9 @@ Use the prompt function exported for a value
* `q` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input.
* `params` **Partial<[LLModelPromptContext](#llmodelpromptcontext)>** Optional parameters for the prompt context.
* `callback` **function (res: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)): void**&#x20;
* `callback` **[TokenCallback](#tokencallback)?** optional callback to control token generation.
Returns **void** The result of the model prompt.
Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The result of the model prompt.
##### embed
@ -424,6 +432,12 @@ Returns **[boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/
GPUs that are usable for this LLModel
###### Parameters
* `nCtx` **[number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)** Maximum size of context window
<!---->
* Throws **any** if hasGpuDevice returns false (i think)
Returns **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[GpuDevice](#gpudevice)>**&#x20;
@ -690,17 +704,18 @@ The percentage of context to erase if the context window is exceeded.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
#### createTokenStream
#### generateTokens
TODO: Help wanted to implement this
Creates an async generator of tokens
##### Parameters
* `llmodel` **[LLModel](#llmodel)**&#x20;
* `messages` **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[PromptMessage](#promptmessage)>**&#x20;
* `options` **[CompletionOptions](#completionoptions)**&#x20;
* `llmodel` **[InferenceModel](#inferencemodel)** The language model object.
* `messages` **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[PromptMessage](#promptmessage)>** The array of messages for the conversation.
* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion.
* `callback` **[TokenCallback](#tokencallback)** optional callback to control token generation.
Returns **function (ll: [LLModel](#llmodel)): AsyncGenerator<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>**&#x20;
Returns **AsyncGenerator<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The stream of generated tokens
#### DEFAULT\_DIRECTORY

View File

@ -136,14 +136,18 @@ yarn test
This package is in active development, and breaking changes may happen until the api stabilizes. Here's what's the todo list:
* \[ ] Purely offline. Per the gui, which can be run completely offline, the bindings should be as well.
* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help)
* Should include prebuilds to avoid painful node-gyp errors
* \[ ] createChatSession ( the python equivalent to create\_chat\_session )
* \[x] generateTokens, the new name for createTokenStream. As of 3.2.0, this is released but not 100% tested. Check spec/generator.mjs!
* \[x] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete
* \[x] prompt models via a threadsafe function in order to have proper non blocking behavior in nodejs
* \[ ] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete
* \[x] generateTokens is the new name for this^
* \[x] proper unit testing (integrate with circle ci)
* \[x] publish to npm under alpha tag `gpt4all@alpha`
* \[x] have more people test on other platforms (mac tester needed)
* \[x] switch to new pluggable backend
* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help)
* Should include prebuilds to avoid painful node-gyp errors
* \[ ] createChatSession ( the python equivalent to create\_chat\_session )
### API Reference

View File

@ -3,9 +3,9 @@
Napi::Function NodeModelWrapper::GetClass(Napi::Env env) {
Napi::Function self = DefineClass(env, "LLModel", {
InstanceMethod("type", &NodeModelWrapper::getType),
InstanceMethod("type", &NodeModelWrapper::GetType),
InstanceMethod("isModelLoaded", &NodeModelWrapper::IsModelLoaded),
InstanceMethod("name", &NodeModelWrapper::getName),
InstanceMethod("name", &NodeModelWrapper::GetName),
InstanceMethod("stateSize", &NodeModelWrapper::StateSize),
InstanceMethod("raw_prompt", &NodeModelWrapper::Prompt),
InstanceMethod("setThreadCount", &NodeModelWrapper::SetThreadCount),
@ -28,14 +28,14 @@ Napi::Function NodeModelWrapper::GetClass(Napi::Env env) {
Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
{
auto env = info.Env();
return Napi::Number::New(env, static_cast<uint32_t>( llmodel_required_mem(GetInference(), full_model_path.c_str(), 2048, 100) ));
return Napi::Number::New(env, static_cast<uint32_t>(llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers) ));
}
Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo& info)
{
auto env = info.Env();
int num_devices = 0;
auto mem_size = llmodel_required_mem(GetInference(), full_model_path.c_str());
auto mem_size = llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers);
llmodel_gpu_device* all_devices = llmodel_available_gpu_devices(GetInference(), mem_size, &num_devices);
if(all_devices == nullptr) {
Napi::Error::New(
@ -70,7 +70,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
return js_array;
}
Napi::Value NodeModelWrapper::getType(const Napi::CallbackInfo& info)
Napi::Value NodeModelWrapper::GetType(const Napi::CallbackInfo& info)
{
if(type.empty()) {
return info.Env().Undefined();
@ -132,6 +132,9 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
library_path = ".";
}
device = config_object.Get("device").As<Napi::String>();
nCtx = config_object.Get("nCtx").As<Napi::Number>().Int32Value();
nGpuLayers = config_object.Get("ngl").As<Napi::Number>().Int32Value();
}
llmodel_set_implementation_search_path(library_path.c_str());
const char* e;
@ -148,20 +151,17 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
return;
}
if(device != "cpu") {
size_t mem = llmodel_required_mem(GetInference(), full_weight_path.c_str());
std::cout << "Initiating GPU\n";
size_t mem = llmodel_required_mem(GetInference(), full_weight_path.c_str(),nCtx, nGpuLayers);
auto success = llmodel_gpu_init_gpu_device_by_string(GetInference(), mem, device.c_str());
if(success) {
std::cout << "GPU init successfully\n";
} else {
if(!success) {
//https://github.com/nomic-ai/gpt4all/blob/3acbef14b7c2436fe033cae9036e695d77461a16/gpt4all-bindings/python/gpt4all/pyllmodel.py#L215
//Haven't implemented this but it is still open to contribution
std::cout << "WARNING: Failed to init GPU\n";
}
}
auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), 2048, 100);
auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), nCtx, nGpuLayers);
if(!success) {
Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException();
return;
@ -254,6 +254,9 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
.repeat_last_n = 10,
.context_erase = 0.5
};
PromptWorkerConfig promptWorkerConfig;
if(info[1].IsObject())
{
auto inputObject = info[1].As<Napi::Object>();
@ -285,29 +288,33 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
if(inputObject.Has("context_erase"))
promptContext.context_erase = inputObject.Get("context_erase").As<Napi::Number>().FloatValue();
}
//copy to protect llmodel resources when splitting to new thread
llmodel_prompt_context copiedPrompt = promptContext;
else
{
Napi::Error::New(info.Env(), "Missing Prompt Options").ThrowAsJavaScriptException();
return info.Env().Undefined();
}
std::string copiedQuestion = question;
PromptWorkContext pc = {
copiedQuestion,
inference_,
copiedPrompt,
""
};
auto threadSafeContext = new TsfnContext(env, pc);
threadSafeContext->tsfn = Napi::ThreadSafeFunction::New(
env, // Environment
info[2].As<Napi::Function>(), // JS function from caller
"PromptCallback", // Resource name
0, // Max queue size (0 = unlimited).
1, // Initial thread count
threadSafeContext, // Context,
FinalizerCallback, // Finalizer
(void*)nullptr // Finalizer data
);
threadSafeContext->nativeThread = std::thread(threadEntry, threadSafeContext);
return threadSafeContext->deferred_.Promise();
if(info.Length() >= 3 && info[2].IsFunction()){
promptWorkerConfig.bHasTokenCallback = true;
promptWorkerConfig.tokenCallback = info[2].As<Napi::Function>();
}
//copy to protect llmodel resources when splitting to new thread
// llmodel_prompt_context copiedPrompt = promptContext;
promptWorkerConfig.context = promptContext;
promptWorkerConfig.model = GetInference();
promptWorkerConfig.mutex = &inference_mutex;
promptWorkerConfig.prompt = question;
promptWorkerConfig.result = "";
auto worker = new PromptWorker(env, promptWorkerConfig);
worker->Queue();
return worker->GetPromise();
}
void NodeModelWrapper::Dispose(const Napi::CallbackInfo& info) {
llmodel_model_destroy(inference_);
@ -321,7 +328,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
}
}
Napi::Value NodeModelWrapper::getName(const Napi::CallbackInfo& info) {
Napi::Value NodeModelWrapper::GetName(const Napi::CallbackInfo& info) {
return Napi::String::New(info.Env(), name);
}
Napi::Value NodeModelWrapper::ThreadCount(const Napi::CallbackInfo& info) {

View File

@ -7,14 +7,17 @@
#include <memory>
#include <filesystem>
#include <set>
#include <mutex>
namespace fs = std::filesystem;
class NodeModelWrapper: public Napi::ObjectWrap<NodeModelWrapper> {
public:
NodeModelWrapper(const Napi::CallbackInfo &);
//virtual ~NodeModelWrapper();
Napi::Value getType(const Napi::CallbackInfo& info);
Napi::Value GetType(const Napi::CallbackInfo& info);
Napi::Value IsModelLoaded(const Napi::CallbackInfo& info);
Napi::Value StateSize(const Napi::CallbackInfo& info);
//void Finalize(Napi::Env env) override;
@ -25,7 +28,7 @@ public:
Napi::Value Prompt(const Napi::CallbackInfo& info);
void SetThreadCount(const Napi::CallbackInfo& info);
void Dispose(const Napi::CallbackInfo& info);
Napi::Value getName(const Napi::CallbackInfo& info);
Napi::Value GetName(const Napi::CallbackInfo& info);
Napi::Value ThreadCount(const Napi::CallbackInfo& info);
Napi::Value GenerateEmbedding(const Napi::CallbackInfo& info);
Napi::Value HasGpuDevice(const Napi::CallbackInfo& info);
@ -48,8 +51,12 @@ private:
*/
llmodel_model inference_;
std::mutex inference_mutex;
std::string type;
// corresponds to LLModel::name() in typescript
std::string name;
int nCtx{};
int nGpuLayers{};
std::string full_model_path;
};

View File

@ -1,6 +1,6 @@
{
"name": "gpt4all",
"version": "3.1.0",
"version": "3.2.0",
"packageManager": "yarn@3.6.1",
"main": "src/gpt4all.js",
"repository": "nomic-ai/gpt4all",

View File

@ -1,60 +1,146 @@
#include "prompt.h"
#include <future>
TsfnContext::TsfnContext(Napi::Env env, const PromptWorkContext& pc)
: deferred_(Napi::Promise::Deferred::New(env)), pc(pc) {
PromptWorker::PromptWorker(Napi::Env env, PromptWorkerConfig config)
: promise(Napi::Promise::Deferred::New(env)), _config(config), AsyncWorker(env) {
if(_config.bHasTokenCallback){
_tsfn = Napi::ThreadSafeFunction::New(config.tokenCallback.Env(),config.tokenCallback,"PromptWorker",0,1,this);
}
namespace {
static std::string *res;
}
bool response_callback(int32_t token_id, const char *response) {
*res += response;
return token_id != -1;
PromptWorker::~PromptWorker()
{
if(_config.bHasTokenCallback){
_tsfn.Release();
}
bool recalculate_callback (bool isrecalculating) {
return isrecalculating;
};
bool prompt_callback (int32_t tid) {
}
void PromptWorker::Execute()
{
_config.mutex->lock();
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper *>(_config.model);
auto ctx = &_config.context;
if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size())
wrapper->promptContext.tokens.resize(ctx->n_past);
// Copy the C prompt context
wrapper->promptContext.n_past = ctx->n_past;
wrapper->promptContext.n_ctx = ctx->n_ctx;
wrapper->promptContext.n_predict = ctx->n_predict;
wrapper->promptContext.top_k = ctx->top_k;
wrapper->promptContext.top_p = ctx->top_p;
wrapper->promptContext.temp = ctx->temp;
wrapper->promptContext.n_batch = ctx->n_batch;
wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
wrapper->promptContext.contextErase = ctx->context_erase;
// Napi::Error::Fatal(
// "SUPRA",
// "About to prompt");
// Call the C++ prompt method
wrapper->llModel->prompt(
_config.prompt,
[](int32_t tid) { return true; },
[this](int32_t token_id, const std::string tok)
{
return ResponseCallback(token_id, tok);
},
[](bool isRecalculating)
{
return isRecalculating;
},
wrapper->promptContext);
// Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies
ctx->logits = wrapper->promptContext.logits.data();
ctx->logits_size = wrapper->promptContext.logits.size();
ctx->tokens = wrapper->promptContext.tokens.data();
ctx->tokens_size = wrapper->promptContext.tokens.size();
// Update the rest of the C prompt context
ctx->n_past = wrapper->promptContext.n_past;
ctx->n_ctx = wrapper->promptContext.n_ctx;
ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p;
ctx->temp = wrapper->promptContext.temp;
ctx->n_batch = wrapper->promptContext.n_batch;
ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
ctx->context_erase = wrapper->promptContext.contextErase;
_config.mutex->unlock();
}
void PromptWorker::OnOK()
{
promise.Resolve(Napi::String::New(Env(), result));
}
void PromptWorker::OnError(const Napi::Error &e)
{
promise.Reject(e.Value());
}
Napi::Promise PromptWorker::GetPromise()
{
return promise.Promise();
}
bool PromptWorker::ResponseCallback(int32_t token_id, const std::string token)
{
if (token_id == -1)
{
return false;
}
if(!_config.bHasTokenCallback){
return true;
};
}
// The thread entry point. This takes as its arguments the specific
// threadsafe-function context created inside the main thread.
void threadEntry(TsfnContext* context) {
static std::mutex mtx;
std::lock_guard<std::mutex> lock(mtx);
res = &context->pc.res;
// Perform a call into JavaScript.
napi_status status =
context->tsfn.BlockingCall(&context->pc,
[](Napi::Env env, Napi::Function jsCallback, PromptWorkContext* pc) {
llmodel_prompt(
pc->inference_,
pc->question.c_str(),
&prompt_callback,
&response_callback,
&recalculate_callback,
&pc->prompt_params
);
result += token;
std::promise<bool> promise;
auto info = new TokenCallbackInfo();
info->tokenId = token_id;
info->token = token;
info->total = result;
auto future = promise.get_future();
auto status = _tsfn.BlockingCall(info, [&promise](Napi::Env env, Napi::Function jsCallback, TokenCallbackInfo *value)
{
// Transform native data into JS data, passing it to the provided
// `jsCallback` -- the TSFN's JavaScript function.
auto token_id = Napi::Number::New(env, value->tokenId);
auto token = Napi::String::New(env, value->token);
auto total = Napi::String::New(env,value->total);
auto jsResult = jsCallback.Call({ token_id, token, total}).ToBoolean();
promise.set_value(jsResult);
// We're finished with the data.
delete value;
});
if (status != napi_ok) {
Napi::Error::Fatal(
"ThreadEntry",
"PromptWorkerResponseCallback",
"Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
}
// Release the thread-safe function. This decrements the internal thread
// count, and will perform finalization since the count will reach 0.
context->tsfn.Release();
return future.get();
}
void FinalizerCallback(Napi::Env env,
void* finalizeData,
TsfnContext* context) {
// Resolve the Promise previously returned to JS
context->deferred_.Resolve(Napi::String::New(env, context->pc.res));
// Wait for the thread to finish executing before proceeding.
context->nativeThread.join();
delete context;
bool PromptWorker::RecalculateCallback(bool isRecalculating)
{
return isRecalculating;
}
bool PromptWorker::PromptCallback(int32_t tid)
{
return true;
}

View File

@ -1,44 +1,59 @@
#ifndef TSFN_CONTEXT_H
#define TSFN_CONTEXT_H
#ifndef PREDICT_WORKER_H
#define PREDICT_WORKER_H
#include "napi.h"
#include "llmodel_c.h"
#include "llmodel.h"
#include <thread>
#include <mutex>
#include <iostream>
#include <atomic>
#include <memory>
struct PromptWorkContext {
std::string question;
llmodel_model inference_;
llmodel_prompt_context prompt_params;
std::string res;
struct TokenCallbackInfo
{
int32_t tokenId;
std::string total;
std::string token;
};
struct TsfnContext {
struct LLModelWrapper
{
LLModel *llModel = nullptr;
LLModel::PromptContext promptContext;
~LLModelWrapper() { delete llModel; }
};
struct PromptWorkerConfig
{
Napi::Function tokenCallback;
bool bHasTokenCallback = false;
llmodel_model model;
std::mutex * mutex;
std::string prompt;
llmodel_prompt_context context;
std::string result;
};
class PromptWorker : public Napi::AsyncWorker
{
public:
TsfnContext(Napi::Env env, const PromptWorkContext &pc);
std::thread nativeThread;
Napi::Promise::Deferred deferred_;
PromptWorkContext pc;
Napi::ThreadSafeFunction tsfn;
PromptWorker(Napi::Env env, PromptWorkerConfig config);
~PromptWorker();
void Execute() override;
void OnOK() override;
void OnError(const Napi::Error &e) override;
Napi::Promise GetPromise();
// Some data to pass around
// int ints[ARRAY_LENGTH];
bool ResponseCallback(int32_t token_id, const std::string token);
bool RecalculateCallback(bool isrecalculating);
bool PromptCallback(int32_t tid);
private:
Napi::Promise::Deferred promise;
std::string result;
PromptWorkerConfig _config;
Napi::ThreadSafeFunction _tsfn;
};
// The thread entry point. This takes as its arguments the specific
// threadsafe-function context created inside the main thread.
void threadEntry(TsfnContext*);
// The thread-safe function finalizer callback. This callback executes
// at destruction of thread-safe function, taking as arguments the finalizer
// data and threadsafe-function context.
void FinalizerCallback(Napi::Env, void* finalizeData, TsfnContext*);
bool response_callback(int32_t token_id, const char *response);
bool recalculate_callback (bool isrecalculating);
bool prompt_callback (int32_t tid);
#endif // TSFN_CONTEXT_H
#endif // PREDICT_WORKER_H

View File

@ -0,0 +1,41 @@
import gpt from '../src/gpt4all.js'
const model = await gpt.loadModel("mistral-7b-openorca.Q4_0.gguf", { device: 'gpu' })
process.stdout.write('Response: ')
const tokens = gpt.generateTokens(model, [{
role: 'user',
content: "How are you ?"
}], { nPredict: 2048 })
for await (const token of tokens){
process.stdout.write(token);
}
const result = await gpt.createCompletion(model, [{
role: 'user',
content: "You sure?"
}])
console.log(result)
const result2 = await gpt.createCompletion(model, [{
role: 'user',
content: "You sure you sure?"
}])
console.log(result2)
const tokens2 = gpt.generateTokens(model, [{
role: 'user',
content: "If 3 + 3 is 5, what is 2 + 2?"
}], { nPredict: 2048 })
for await (const token of tokens2){
process.stdout.write(token);
}
console.log("done")
model.dispose();

View File

@ -49,6 +49,12 @@ interface ModelConfig {
path: string;
url?: string;
}
/**
* Callback for controlling token generation
*/
type TokenCallback = (tokenId: number, token: string, total: string) => boolean
/**
*
* InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers.
@ -61,7 +67,8 @@ declare class InferenceModel {
generate(
prompt: string,
options?: Partial<LLModelPromptContext>
options?: Partial<LLModelPromptContext>,
callback?: TokenCallback
): Promise<string>;
/**
@ -132,13 +139,14 @@ declare class LLModel {
* Use the prompt function exported for a value
* @param q The prompt input.
* @param params Optional parameters for the prompt context.
* @param callback - optional callback to control token generation.
* @returns The result of the model prompt.
*/
raw_prompt(
q: string,
params: Partial<LLModelPromptContext>,
callback: (res: string) => void
): void; // TODO work on return type
callback?: TokenCallback
): Promise<string>
/**
* Embed text with the model. Keep in mind that
@ -176,10 +184,11 @@ declare class LLModel {
hasGpuDevice(): boolean
/**
* GPUs that are usable for this LLModel
* @param nCtx Maximum size of context window
* @throws if hasGpuDevice returns false (i think)
* @returns
*/
listGpu() : GpuDevice[]
listGpu(nCtx: number) : GpuDevice[]
/**
* delete and cleanup the native model
@ -224,6 +233,16 @@ interface LoadModelOptions {
model.
*/
device?: string;
/*
* The Maximum window size of this model
* Default of 2048
*/
nCtx?: number;
/*
* Number of gpu layers needed
* Default of 100
*/
ngl?: number;
}
interface InferenceModelOptions extends LoadModelOptions {
@ -442,14 +461,21 @@ interface LLModelPromptContext {
contextErase: number;
}
/**
* TODO: Help wanted to implement this
* Creates an async generator of tokens
* @param {InferenceModel} llmodel - The language model object.
* @param {PromptMessage[]} messages - The array of messages for the conversation.
* @param {CompletionOptions} options - The options for creating the completion.
* @param {TokenCallback} callback - optional callback to control token generation.
* @returns {AsyncGenerator<string>} The stream of generated tokens
*/
declare function createTokenStream(
llmodel: LLModel,
declare function generateTokens(
llmodel: InferenceModel,
messages: PromptMessage[],
options: CompletionOptions
): (ll: LLModel) => AsyncGenerator<string>;
options: CompletionOptions,
callback?: TokenCallback
): AsyncGenerator<string>;
/**
* From python api:
* models will be stored in (homedir)/.cache/gpt4all/`
@ -568,7 +594,7 @@ export {
loadModel,
createCompletion,
createEmbedding,
createTokenStream,
generateTokens,
DEFAULT_DIRECTORY,
DEFAULT_LIBRARIES_DIRECTORY,
DEFAULT_MODEL_CONFIG,

View File

@ -18,6 +18,7 @@ const {
DEFAULT_MODEL_LIST_URL,
} = require("./config.js");
const { InferenceModel, EmbeddingModel } = require("./models.js");
const Stream = require('stream')
const assert = require("assert");
/**
@ -36,6 +37,8 @@ async function loadModel(modelName, options = {}) {
allowDownload: true,
verbose: true,
device: 'cpu',
nCtx: 2048,
ngl : 100,
...options,
};
@ -58,11 +61,14 @@ async function loadModel(modelName, options = {}) {
model_path: loadOptions.modelPath,
library_path: existingPaths,
device: loadOptions.device,
nCtx: loadOptions.nCtx,
ngl: loadOptions.ngl
};
if (loadOptions.verbose) {
console.debug("Creating LLModel with options:", llmOptions);
}
console.log(modelConfig)
const llmodel = new LLModel(llmOptions);
if (loadOptions.type === "embedding") {
return new EmbeddingModel(llmodel, modelConfig);
@ -149,11 +155,7 @@ const defaultCompletionOptions = {
...DEFAULT_PROMPT_CONTEXT,
};
async function createCompletion(
model,
messages,
options = defaultCompletionOptions
) {
function preparePromptAndContext(model,messages,options){
if (options.hasDefaultHeader !== undefined) {
console.warn(
"hasDefaultHeader (bool) is deprecated and has no effect, use promptHeader (string) instead"
@ -180,6 +182,7 @@ async function createCompletion(
...promptContext
} = optionsWithDefaults;
const prompt = formatChatPrompt(messages, {
systemPromptTemplate,
defaultSystemPrompt: model.config.systemPrompt,
@ -192,11 +195,28 @@ async function createCompletion(
// promptFooter: '### Response:',
});
return {
prompt, promptContext, verbose
}
}
async function createCompletion(
model,
messages,
options = defaultCompletionOptions
) {
const { prompt, promptContext, verbose } = preparePromptAndContext(model,messages,options);
if (verbose) {
console.debug("Sending Prompt:\n" + prompt);
}
const response = await model.generate(prompt, promptContext);
let tokensGenerated = 0
const response = await model.generate(prompt, promptContext,() => {
tokensGenerated++;
return true;
});
if (verbose) {
console.debug("Received Response:\n" + response);
@ -206,8 +226,8 @@ async function createCompletion(
llmodel: model.llm.name(),
usage: {
prompt_tokens: prompt.length,
completion_tokens: response.length, //TODO
total_tokens: prompt.length + response.length, //TODO
completion_tokens: tokensGenerated,
total_tokens: prompt.length + tokensGenerated, //TODO Not sure how to get tokens in prompt
},
choices: [
{
@ -220,8 +240,77 @@ async function createCompletion(
};
}
function createTokenStream() {
throw Error("This API has not been completed yet!");
function _internal_createTokenStream(stream,model,
messages,
options = defaultCompletionOptions,callback = undefined) {
const { prompt, promptContext, verbose } = preparePromptAndContext(model,messages,options);
if (verbose) {
console.debug("Sending Prompt:\n" + prompt);
}
model.generate(prompt, promptContext,(tokenId, token, total) => {
stream.push(token);
if(callback !== undefined){
return callback(tokenId,token,total);
}
return true;
}).then(() => {
stream.end()
})
return stream;
}
function _createTokenStream(model,
messages,
options = defaultCompletionOptions,callback = undefined) {
// Silent crash if we dont do this here
const stream = new Stream.PassThrough({
encoding: 'utf-8'
});
return _internal_createTokenStream(stream,model,messages,options,callback);
}
async function* generateTokens(model,
messages,
options = defaultCompletionOptions, callback = undefined) {
const stream = _createTokenStream(model,messages,options,callback)
let bHasFinished = false;
let activeDataCallback = undefined;
const finishCallback = () => {
bHasFinished = true;
if(activeDataCallback !== undefined){
activeDataCallback(undefined);
}
}
stream.on("finish",finishCallback)
while (!bHasFinished) {
const token = await new Promise((res) => {
activeDataCallback = (d) => {
stream.off("data",activeDataCallback)
activeDataCallback = undefined
res(d);
}
stream.on('data', activeDataCallback)
})
if (token == undefined) {
break;
}
yield token;
}
stream.off("finish",finishCallback);
}
module.exports = {
@ -238,5 +327,5 @@ module.exports = {
downloadModel,
retrieveModel,
loadModel,
createTokenStream,
generateTokens
};

View File

@ -9,10 +9,10 @@ class InferenceModel {
this.config = config;
}
async generate(prompt, promptContext) {
async generate(prompt, promptContext,callback) {
warnOnSnakeCaseKeys(promptContext);
const normalizedPromptContext = normalizePromptContext(promptContext);
const result = this.llm.raw_prompt(prompt, normalizedPromptContext, () => {});
const result = this.llm.raw_prompt(prompt, normalizedPromptContext,callback);
return result;
}

View File

@ -224,7 +224,6 @@ async function retrieveModel(modelName, options = {}) {
verbose: true,
...options,
};
await mkdirp(retrieveOptions.modelPath);
const modelFileName = appendBinSuffixIfMissing(modelName);
@ -284,7 +283,6 @@ async function retrieveModel(modelName, options = {}) {
} else {
throw Error("Failed to retrieve model.");
}
return config;
}

File diff suppressed because it is too large Load Diff