using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

namespace Gpt4All.Bindings;

/// <summary>
/// Arguments for the response processing callback.
/// </summary>
/// <param name="TokenId">The token id of the response. NOTE: a token id of -1 indicates the string is an error string.</param>
/// <param name="Response">The response string.</param>
public record ModelResponseEventArgs(int TokenId, string Response)
{
    /// <summary>True when the native layer reported an error (token id == -1).</summary>
    public bool IsError => TokenId == -1;
}

/// <summary>
/// Arguments for the prompt processing callback.
/// </summary>
/// <param name="TokenId">The token id of the prompt.</param>
public record ModelPromptEventArgs(int TokenId)
{
}

/// <summary>
/// Arguments for the recalculating callback.
/// </summary>
/// <param name="IsRecalculating">Whether the model is recalculating the context.</param>
public record ModelRecalculatingEventArgs(bool IsRecalculating);

/// <summary>
/// Base class and universal wrapper for GPT4All language models built around the llmodel C-API.
/// </summary>
public class LLModel : ILLModel
{
    // Handle to the native llmodel instance; owned by this wrapper and released in Dispose.
    protected readonly IntPtr _handle;
    private readonly ModelType _modelType;
    private readonly ILogger _logger;
    private bool _disposed;

    public ModelType ModelType => _modelType;

    internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
    {
        _handle = handle;
        _modelType = modelType;
        _logger = logger ?? NullLogger.Instance;
    }

    // Finalizer safety net: releases the native handle if Dispose was never called.
    ~LLModel()
    {
        Dispose(disposing: false);
    }

    /// <summary>
    /// Create a new model from a pointer.
    /// </summary>
    /// <param name="handle">Pointer to the underlying model.</param>
    /// <param name="modelType">The model type.</param>
    /// <param name="logger">Optional logger; a no-op logger is used when null.</param>
    public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
    {
        return new LLModel(handle, modelType, logger: logger);
    }

    /// <summary>
    /// Generate a response using the model.
    /// </summary>
    /// <param name="text">The input prompt.</param>
    /// <param name="context">The prompt context; its underlying native context is updated in place.</param>
    /// <param name="promptCallback">A callback function for handling the processing of the prompt. Return false to stop processing.</param>
    /// <param name="responseCallback">A callback function for handling the generated response. Return false to stop generating.</param>
    /// <param name="recalculateCallback">A callback function for handling recalculation requests. Return false to stop generating.</param>
    /// <param name="cancellationToken">Token used to cooperatively stop prompt processing and generation.</param>
    public void Prompt(
        string text,
        LLModelPromptContext context,
        Func<ModelPromptEventArgs, bool>? promptCallback = null,
        Func<ModelResponseEventArgs, bool>? responseCallback = null,
        Func<ModelRecalculatingEventArgs, bool>? recalculateCallback = null,
        CancellationToken cancellationToken = default)
    {
        // Keep the caller-supplied delegates reachable around the native call.
        // (A KeepAlive on the CancellationToken struct was removed: it copies a
        // value type and has no effect.)
        GC.KeepAlive(promptCallback);
        GC.KeepAlive(responseCallback);
        GC.KeepAlive(recalculateCallback);

        _logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump());

        NativeMethods.llmodel_prompt(
            _handle,
            text,
            (tokenId) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (promptCallback == null) return true;
                var args = new ModelPromptEventArgs(tokenId);
                return promptCallback(args);
            },
            (tokenId, response) =>
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    _logger.LogDebug("ResponseCallback evt=CancellationRequested");
                    return false;
                }
                if (responseCallback == null) return true;
                var args = new ModelResponseEventArgs(tokenId, response);
                return responseCallback(args);
            },
            (isRecalculating) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (recalculateCallback == null) return true;
                var args = new ModelRecalculatingEventArgs(isRecalculating);
                return recalculateCallback(args);
            },
            ref context.UnderlyingContext
        );
    }

    /// <summary>
    /// Set the number of threads to be used by the model.
    /// </summary>
    /// <param name="threadCount">The new thread count.</param>
    public void SetThreadCount(int threadCount)
    {
        NativeMethods.llmodel_setThreadCount(_handle, threadCount);
    }

    /// <summary>
    /// Get the number of threads used by the model.
    /// </summary>
    /// <returns>The number of threads used by the model.</returns>
    public int GetThreadCount()
    {
        return NativeMethods.llmodel_threadCount(_handle);
    }

    /// <summary>
    /// Get the size of the internal state of the model.
    /// </summary>
    /// <remarks>
    /// This state data is specific to the type of model you have created.
    /// </remarks>
    /// <returns>The size in bytes of the internal state of the model.</returns>
    public ulong GetStateSizeBytes()
    {
        return NativeMethods.llmodel_get_state_size(_handle);
    }

    /// <summary>
    /// Saves the internal state of the model to the specified destination address.
    /// </summary>
    /// <param name="source">A pointer to the destination buffer (must be at least <see cref="GetStateSizeBytes"/> bytes).</param>
    /// <returns>The number of bytes copied.</returns>
    public unsafe ulong SaveStateData(byte* source)
    {
        return NativeMethods.llmodel_save_state_data(_handle, source);
    }

    /// <summary>
    /// Restores the internal state of the model using data from the specified address.
    /// </summary>
    /// <param name="destination">A pointer to the serialized state data.</param>
    /// <returns>The number of bytes read.</returns>
    public unsafe ulong RestoreStateData(byte* destination)
    {
        return NativeMethods.llmodel_restore_state_data(_handle, destination);
    }

    /// <summary>
    /// Check if the model is loaded.
    /// </summary>
    /// <returns>true if the model was loaded successfully, false otherwise.</returns>
    public bool IsLoaded()
    {
        return NativeMethods.llmodel_isModelLoaded(_handle);
    }

    /// <summary>
    /// Load the model from a file.
    /// </summary>
    /// <param name="modelPath">The path to the model file.</param>
    /// <returns>true if the model was loaded successfully, false otherwise.</returns>
    public bool Load(string modelPath)
    {
        return NativeMethods.llmodel_loadModel(_handle, modelPath);
    }

    /// <summary>
    /// Destroys the underlying native model instance.
    /// </summary>
    protected void Destroy()
    {
        NativeMethods.llmodel_model_destroy(_handle);
    }

    protected virtual void Dispose(bool disposing)
    {
        if (_disposed) return;

        if (disposing)
        {
            // dispose managed state (none currently held)
        }

        // All model types currently share the same native destroy path; the
        // previous switch over _modelType had only a default arm.
        Destroy();

        _disposed = true;
    }

    public void Dispose()
    {
        Dispose(disposing: true);
        GC.SuppressFinalize(this);
    }
}