gpt4all/gpt4all-bindings/typescript/prompt.cc
Andreas Obersteiner a602f7fde7
typescript bindings maintenance (#2363)
* remove outdated comments

Signed-off-by: limez <limez@protonmail.com>

* simpler build from source

Signed-off-by: limez <limez@protonmail.com>

* update unix build script to create .so runtimes correctly

Signed-off-by: limez <limez@protonmail.com>

* configure ci build type, use RelWithDebInfo for dev build script

Signed-off-by: limez <limez@protonmail.com>

* add clean script

Signed-off-by: limez <limez@protonmail.com>

* fix streamed token decoding / emoji

Signed-off-by: limez <limez@protonmail.com>

* remove deprecated nCtx

Signed-off-by: limez <limez@protonmail.com>

* update typings

Signed-off-by: jacob <jacoobes@sern.dev>

update typings

Signed-off-by: jacob <jacoobes@sern.dev>

* readme,mspell

Signed-off-by: jacob <jacoobes@sern.dev>

* cuda/backend logic changes + name napi methods like their js counterparts

Signed-off-by: limez <limez@protonmail.com>

* convert llmodel example into a test, separate test suite that can run in ci

Signed-off-by: limez <limez@protonmail.com>

* update examples / naming

Signed-off-by: limez <limez@protonmail.com>

* update deps, remove the need for binding.ci.gyp, make node-gyp-build fallback easier testable

Signed-off-by: limez <limez@protonmail.com>

* make sure the assert-backend-sources.js script is published, but not the others

Signed-off-by: limez <limez@protonmail.com>

* build correctly on windows (regression on node-gyp-build)

Signed-off-by: Jacob Nguyen <76754747+jacoobes@users.noreply.github.com>

* codespell

Signed-off-by: limez <limez@protonmail.com>

* make sure dlhandle.cpp gets linked correctly

Signed-off-by: limez <limez@protonmail.com>

* add include for check_cxx_compiler_flag call during aarch64 builds

Signed-off-by: limez <limez@protonmail.com>

* x86 > arm64 cross compilation of runtimes and bindings

Signed-off-by: limez <limez@protonmail.com>

* default to cpu instead of kompute on arm64

Signed-off-by: limez <limez@protonmail.com>

* formatting, more minimal example

Signed-off-by: limez <limez@protonmail.com>

---------

Signed-off-by: limez <limez@protonmail.com>
Signed-off-by: jacob <jacoobes@sern.dev>
Signed-off-by: Jacob Nguyen <76754747+jacoobes@users.noreply.github.com>
Co-authored-by: Jacob Nguyen <76754747+jacoobes@users.noreply.github.com>
Co-authored-by: jacob <jacoobes@sern.dev>
2024-06-03 11:12:55 -05:00

198 lines
6.3 KiB
C++

#include "prompt.h"

#include <cstring>
#include <future>
#include <iostream>
#include <mutex>
#include <type_traits>
/// Builds a worker that runs a prompt on the libuv thread pool.
/// Creates thread-safe function handles for the JS response/prompt callbacks
/// so they can be invoked from the worker thread.
/// Note: base classes are always initialized before members, so AsyncWorker
/// is listed first in the init list (the old order was misleading and
/// triggered -Wreorder).
PromptWorker::PromptWorker(Napi::Env env, PromptWorkerConfig config)
    : AsyncWorker(env), promise(Napi::Promise::Deferred::New(env)), _config(config)
{
    if (_config.hasResponseCallback)
    {
        _responseCallbackFn = Napi::ThreadSafeFunction::New(config.responseCallback.Env(), config.responseCallback,
                                                           "PromptWorker", 0, 1, this);
    }
    if (_config.hasPromptCallback)
    {
        _promptCallbackFn = Napi::ThreadSafeFunction::New(config.promptCallback.Env(), config.promptCallback,
                                                          "PromptWorker", 0, 1, this);
    }
}
/// Releases the thread-safe function handles acquired in the constructor,
/// allowing the event loop to shut down once no other references remain.
PromptWorker::~PromptWorker()
{
    if (_config.hasResponseCallback)
        _responseCallbackFn.Release();
    if (_config.hasPromptCallback)
        _promptCallbackFn.Release();
}
void PromptWorker::Execute()
{
_config.mutex->lock();
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper *>(_config.model);
auto ctx = &_config.context;
if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size())
wrapper->promptContext.tokens.resize(ctx->n_past);
// Copy the C prompt context
wrapper->promptContext.n_past = ctx->n_past;
wrapper->promptContext.n_ctx = ctx->n_ctx;
wrapper->promptContext.n_predict = ctx->n_predict;
wrapper->promptContext.top_k = ctx->top_k;
wrapper->promptContext.top_p = ctx->top_p;
wrapper->promptContext.temp = ctx->temp;
wrapper->promptContext.n_batch = ctx->n_batch;
wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
wrapper->promptContext.contextErase = ctx->context_erase;
// Call the C++ prompt method
wrapper->llModel->prompt(
_config.prompt, _config.promptTemplate, [this](int32_t token_id) { return PromptCallback(token_id); },
[this](int32_t token_id, const std::string token) { return ResponseCallback(token_id, token); },
[](bool isRecalculating) { return isRecalculating; }, wrapper->promptContext, _config.special,
_config.fakeReply);
// Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies
ctx->logits = wrapper->promptContext.logits.data();
ctx->logits_size = wrapper->promptContext.logits.size();
ctx->tokens = wrapper->promptContext.tokens.data();
ctx->tokens_size = wrapper->promptContext.tokens.size();
// Update the rest of the C prompt context
ctx->n_past = wrapper->promptContext.n_past;
ctx->n_ctx = wrapper->promptContext.n_ctx;
ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p;
ctx->temp = wrapper->promptContext.temp;
ctx->n_batch = wrapper->promptContext.n_batch;
ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
ctx->context_erase = wrapper->promptContext.contextErase;
_config.mutex->unlock();
}
void PromptWorker::OnOK()
{
Napi::Object returnValue = Napi::Object::New(Env());
returnValue.Set("text", result);
returnValue.Set("nPast", _config.context.n_past);
promise.Resolve(returnValue);
delete _config.fakeReply;
}
/// Runs on the JS thread if Execute() failed: frees the owned fake-reply
/// buffer and rejects the promise with the error.
void PromptWorker::OnError(const Napi::Error &e)
{
    delete _config.fakeReply;
    promise.Reject(e.Value());
}
/// Exposes the deferred's promise so JS callers can await completion.
Napi::Promise PromptWorker::GetPromise()
{
    return this->promise.Promise();
}
bool PromptWorker::ResponseCallback(int32_t token_id, const std::string token)
{
if (token_id == -1)
{
return false;
}
if (!_config.hasResponseCallback)
{
return true;
}
result += token;
std::promise<bool> promise;
auto info = new ResponseCallbackData();
info->tokenId = token_id;
info->token = token;
auto future = promise.get_future();
auto status = _responseCallbackFn.BlockingCall(
info, [&promise](Napi::Env env, Napi::Function jsCallback, ResponseCallbackData *value) {
try
{
// Transform native data into JS data, passing it to the provided
// `jsCallback` -- the TSFN's JavaScript function.
auto token_id = Napi::Number::New(env, value->tokenId);
auto token = Napi::Uint8Array::New(env, value->token.size());
memcpy(token.Data(), value->token.data(), value->token.size());
auto jsResult = jsCallback.Call({token_id, token}).ToBoolean();
promise.set_value(jsResult);
}
catch (const Napi::Error &e)
{
std::cerr << "Error in onResponseToken callback: " << e.what() << std::endl;
promise.set_value(false);
}
delete value;
});
if (status != napi_ok)
{
Napi::Error::Fatal("PromptWorkerResponseCallback", "Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
}
return future.get();
}
/// Pass-through recalculation hook: keep recalculating for as long as the
/// model reports that it is doing so.
bool PromptWorker::RecalculateCallback(bool isRecalculating)
{
    return isRecalculating;
}
bool PromptWorker::PromptCallback(int32_t token_id)
{
if (!_config.hasPromptCallback)
{
return true;
}
std::promise<bool> promise;
auto info = new PromptCallbackData();
info->tokenId = token_id;
auto future = promise.get_future();
auto status = _promptCallbackFn.BlockingCall(
info, [&promise](Napi::Env env, Napi::Function jsCallback, PromptCallbackData *value) {
try
{
// Transform native data into JS data, passing it to the provided
// `jsCallback` -- the TSFN's JavaScript function.
auto token_id = Napi::Number::New(env, value->tokenId);
auto jsResult = jsCallback.Call({token_id}).ToBoolean();
promise.set_value(jsResult);
}
catch (const Napi::Error &e)
{
std::cerr << "Error in onPromptToken callback: " << e.what() << std::endl;
promise.set_value(false);
}
delete value;
});
if (status != napi_ok)
{
Napi::Error::Fatal("PromptWorkerPromptCallback", "Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
}
return future.get();
}