using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
namespace Gpt4All.Bindings;
/// <summary>
/// Arguments handed to the response-processing callback of <c>LLModel.Prompt</c>.
/// </summary>
/// <remarks>
/// The callback that receives these arguments returns a bool telling the model whether to keep generating.
/// </remarks>
/// <param name="TokenId">The token id of the response. A value of -1 marks <paramref name="Response"/> as an error string.</param>
/// <param name="Response">The generated text for this token, or an error message when <paramref name="TokenId"/> is -1.</param>
public record ModelResponseEventArgs(int TokenId, string Response)
{
    /// <summary>
    /// <c>true</c> when this event carries an error string rather than generated text (i.e. <see cref="TokenId"/> is -1).
    /// </summary>
    public bool IsError => TokenId is -1;
}
/// <summary>
/// Arguments handed to the prompt-processing callback of <c>LLModel.Prompt</c>.
/// </summary>
/// <remarks>
/// The callback that receives these arguments returns a bool telling the model whether to keep processing the prompt.
/// </remarks>
/// <param name="TokenId">The token id of the prompt token being processed.</param>
public record ModelPromptEventArgs(int TokenId);
/// <summary>
/// Arguments handed to the context-recalculation callback of <c>LLModel.Prompt</c>.
/// </summary>
/// <remarks>
/// The callback that receives these arguments returns a bool telling the model whether to keep generating.
/// </remarks>
/// <param name="IsRecalculating">Whether the model is currently recalculating its context.</param>
public record ModelRecalculatingEventArgs(bool IsRecalculating)
{
}
/// <summary>
/// Base class and universal wrapper for GPT4All language models built around the llmodel C API.
/// </summary>
/// <remarks>
/// The instance owns the native model handle it is constructed with and passes it to
/// <c>llmodel_model_destroy</c> on <see cref="Dispose()"/>. Once disposed, every member that would
/// hand the handle to native code throws <see cref="ObjectDisposedException"/> instead.
/// </remarks>
public class LLModel : ILLModel
{
    protected readonly IntPtr _handle;
    private readonly ModelType _modelType;
    private readonly ILogger _logger;
    private bool _disposed;

    /// <summary>
    /// The type of the wrapped model.
    /// </summary>
    public ModelType ModelType => _modelType;

    internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
    {
        _handle = handle;
        _modelType = modelType;
        _logger = logger ?? NullLogger.Instance;
    }

    /// <summary>
    /// Create a new model wrapper around an existing native model pointer.
    /// </summary>
    /// <param name="handle">Pointer to the underlying native model. The returned instance destroys it when disposed.</param>
    /// <param name="modelType">The model type.</param>
    /// <param name="logger">Optional logger; a no-op logger is used when <c>null</c>.</param>
    /// <returns>A new <see cref="LLModel"/> wrapping <paramref name="handle"/>.</returns>
    public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
    {
        return new LLModel(handle, modelType, logger: logger);
    }

    /// <summary>
    /// Generate a response using the model.
    /// </summary>
    /// <param name="text">The input prompt.</param>
    /// <param name="context">The prompt context; its underlying native context is passed by reference to the native call.</param>
    /// <param name="promptCallback">Invoked for each prompt token; return <c>false</c> to stop processing. When <c>null</c>, processing continues.</param>
    /// <param name="responseCallback">Invoked for each generated token; return <c>false</c> to stop generating. When <c>null</c>, generation continues.</param>
    /// <param name="recalculateCallback">Invoked for context-recalculation notifications; return <c>false</c> to stop. When <c>null</c>, generation continues.</param>
    /// <param name="cancellationToken">
    /// Once cancellation is requested, every callback answers <c>false</c> to the native side so it stops;
    /// no <see cref="OperationCanceledException"/> is thrown.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="text"/> is <c>null</c>.</exception>
    /// <exception cref="ObjectDisposedException">The model has already been disposed.</exception>
    public void Prompt(
        string text,
        LLModelPromptContext context,
        Func<ModelPromptEventArgs, bool>? promptCallback = null,
        Func<ModelResponseEventArgs, bool>? responseCallback = null,
        Func<ModelRecalculatingEventArgs, bool>? recalculateCallback = null,
        CancellationToken cancellationToken = default)
    {
        ThrowIfDisposed();
        // A null string would be marshalled to the native prompt function; reject it here instead.
        if (text is null)
        {
            throw new ArgumentNullException(nameof(text));
        }

        // No GC.KeepAlive is needed here: the callbacks are captured by the lambdas below, and the
        // P/Invoke marshaller keeps delegate arguments alive for the duration of this synchronous call.
        // (KeepAlive calls placed before the call, as previously, had no effect anyway.)
        _logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump());

        NativeMethods.llmodel_prompt(
            _handle,
            text,
            (tokenId) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (promptCallback == null) return true;
                var args = new ModelPromptEventArgs(tokenId);
                return promptCallback(args);
            },
            (tokenId, response) =>
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    _logger.LogDebug("ResponseCallback evt=CancellationRequested");
                    return false;
                }

                if (responseCallback == null) return true;
                var args = new ModelResponseEventArgs(tokenId, response);
                return responseCallback(args);
            },
            (isRecalculating) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (recalculateCallback == null) return true;
                var args = new ModelRecalculatingEventArgs(isRecalculating);
                return recalculateCallback(args);
            },
            ref context.UnderlyingContext
        );
    }

    /// <summary>
    /// Set the number of threads to be used by the model.
    /// </summary>
    /// <param name="threadCount">The new thread count.</param>
    /// <exception cref="ObjectDisposedException">The model has already been disposed.</exception>
    public void SetThreadCount(int threadCount)
    {
        ThrowIfDisposed();
        NativeMethods.llmodel_setThreadCount(_handle, threadCount);
    }

    /// <summary>
    /// Get the number of threads used by the model.
    /// </summary>
    /// <returns>The number of threads used by the model.</returns>
    /// <exception cref="ObjectDisposedException">The model has already been disposed.</exception>
    public int GetThreadCount()
    {
        ThrowIfDisposed();
        return NativeMethods.llmodel_threadCount(_handle);
    }

    /// <summary>
    /// Get the size of the internal state of the model.
    /// </summary>
    /// <remarks>
    /// This state data is specific to the type of model you have created.
    /// </remarks>
    /// <returns>The size in bytes of the internal state of the model.</returns>
    /// <exception cref="ObjectDisposedException">The model has already been disposed.</exception>
    public ulong GetStateSizeBytes()
    {
        ThrowIfDisposed();
        return NativeMethods.llmodel_get_state_size(_handle);
    }

    /// <summary>
    /// Saves the internal state of the model to the specified destination address.
    /// </summary>
    /// <param name="source">
    /// Despite its name, the destination buffer the state is written to (the name is kept so
    /// named-argument callers keep compiling). Presumably it must be at least
    /// <see cref="GetStateSizeBytes"/> bytes long; TODO confirm against the native API.
    /// </param>
    /// <returns>The number of bytes copied.</returns>
    /// <exception cref="ObjectDisposedException">The model has already been disposed.</exception>
    public unsafe ulong SaveStateData(byte* source)
    {
        ThrowIfDisposed();
        return NativeMethods.llmodel_save_state_data(_handle, source);
    }

    /// <summary>
    /// Restores the internal state of the model using data from the specified address.
    /// </summary>
    /// <param name="destination">
    /// Despite its name, the source buffer the state is read from, e.g. one previously filled by
    /// <see cref="SaveStateData"/> (the name is kept so named-argument callers keep compiling).
    /// </param>
    /// <returns>The number of bytes read.</returns>
    /// <exception cref="ObjectDisposedException">The model has already been disposed.</exception>
    public unsafe ulong RestoreStateData(byte* destination)
    {
        ThrowIfDisposed();
        return NativeMethods.llmodel_restore_state_data(_handle, destination);
    }

    /// <summary>
    /// Check if the model is loaded.
    /// </summary>
    /// <returns><c>true</c> if the model was loaded successfully, <c>false</c> otherwise.</returns>
    /// <exception cref="ObjectDisposedException">The model has already been disposed.</exception>
    public bool IsLoaded()
    {
        ThrowIfDisposed();
        return NativeMethods.llmodel_isModelLoaded(_handle);
    }

    /// <summary>
    /// Load the model from a file.
    /// </summary>
    /// <param name="modelPath">The path to the model file.</param>
    /// <returns><c>true</c> if the model was loaded successfully, <c>false</c> otherwise.</returns>
    /// <exception cref="ObjectDisposedException">The model has already been disposed.</exception>
    public bool Load(string modelPath)
    {
        ThrowIfDisposed();
        return NativeMethods.llmodel_loadModel(_handle, modelPath);
    }

    /// <summary>
    /// Passes the native handle to <c>llmodel_model_destroy</c>. Must be called at most once per handle.
    /// </summary>
    protected void Destroy()
    {
        NativeMethods.llmodel_model_destroy(_handle);
    }

    /// <summary>
    /// Releases the native model exactly once; later calls are no-ops.
    /// </summary>
    /// <param name="disposing"><c>true</c> when called from <see cref="Dispose()"/>. This class holds no managed state of its own.</param>
    protected virtual void Dispose(bool disposing)
    {
        if (_disposed) return;

        // Every model type is currently released the same way, so no per-type dispatch is needed.
        Destroy();
        _disposed = true;
    }

    /// <summary>
    /// Releases the native model. Further use of this instance throws <see cref="ObjectDisposedException"/>.
    /// </summary>
    public void Dispose()
    {
        Dispose(disposing: true);
        GC.SuppressFinalize(this);
    }

    // After Dispose the handle has been passed to llmodel_model_destroy; handing it to native code
    // again would be a use-after-destroy, so fail fast with a managed exception instead.
    private void ThrowIfDisposed()
    {
        if (_disposed)
        {
            throw new ObjectDisposedException(GetType().FullName);
        }
    }
}