C# Bindings - Prompt formatting (#712)

* Added support for custom prompt formatting

* more docs added

* bump version
This commit is contained in:
mvenditto 2023-05-29 01:57:00 +02:00 committed by GitHub
parent 44c23cd2e8
commit 9eb81cb549
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 55 additions and 3 deletions

View File

@ -5,7 +5,7 @@
<Company></Company> <Company></Company>
<Copyright></Copyright> <Copyright></Copyright>
<NeutralLanguage>en-US</NeutralLanguage> <NeutralLanguage>en-US</NeutralLanguage>
<Version>0.5.0</Version> <Version>0.6.0</Version>
<VersionSuffix>$(VersionSuffix)</VersionSuffix> <VersionSuffix>$(VersionSuffix)</VersionSuffix>
<Version Condition=" '$(VersionSuffix)' != '' ">$(Version)$(VersionSuffix)</Version> <Version Condition=" '$(VersionSuffix)' != '' ">$(Version)$(VersionSuffix)</Version>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors> <TreatWarningsAsErrors>true</TreatWarningsAsErrors>

View File

@ -7,19 +7,33 @@ public class Gpt4All : IGpt4AllModel
{ {
private readonly ILLModel _model; private readonly ILLModel _model;
/// <inheritdoc/>
public IPromptFormatter? PromptFormatter { get; set; }
internal Gpt4All(ILLModel model) internal Gpt4All(ILLModel model)
{ {
_model = model; _model = model;
PromptFormatter = new DefaultPromptFormatter();
}
private string FormatPrompt(string prompt)
{
if (PromptFormatter == null) return prompt;
return PromptFormatter.FormatPrompt(prompt);
} }
public Task<ITextPredictionResult> GetPredictionAsync(string text, PredictRequestOptions opts, CancellationToken cancellationToken = default) public Task<ITextPredictionResult> GetPredictionAsync(string text, PredictRequestOptions opts, CancellationToken cancellationToken = default)
{ {
ArgumentNullException.ThrowIfNull(text);
return Task.Run(() => return Task.Run(() =>
{ {
var result = new TextPredictionResult(); var result = new TextPredictionResult();
var context = opts.ToPromptContext(); var context = opts.ToPromptContext();
var prompt = FormatPrompt(text);
_model.Prompt(text, context, responseCallback: e => _model.Prompt(prompt, context, responseCallback: e =>
{ {
if (e.IsError) if (e.IsError)
{ {
@ -37,6 +51,8 @@ public class Gpt4All : IGpt4AllModel
public Task<ITextPredictionStreamingResult> GetStreamingPredictionAsync(string text, PredictRequestOptions opts, CancellationToken cancellationToken = default) public Task<ITextPredictionStreamingResult> GetStreamingPredictionAsync(string text, PredictRequestOptions opts, CancellationToken cancellationToken = default)
{ {
ArgumentNullException.ThrowIfNull(text);
var result = new TextPredictionStreamingResult(); var result = new TextPredictionStreamingResult();
_ = Task.Run(() => _ = Task.Run(() =>
@ -44,8 +60,9 @@ public class Gpt4All : IGpt4AllModel
try try
{ {
var context = opts.ToPromptContext(); var context = opts.ToPromptContext();
var prompt = FormatPrompt(text);
_model.Prompt(text, context, responseCallback: e => _model.Prompt(prompt, context, responseCallback: e =>
{ {
if (e.IsError) if (e.IsError)
{ {

View File

@ -0,0 +1,16 @@
namespace Gpt4All;
public class DefaultPromptFormatter : IPromptFormatter
{
public string FormatPrompt(string prompt)
{
return $"""
### Instruction:
The prompt below is a question to answer, a task to complete, or a conversation
to respond to; decide which and write an appropriate response.
### Prompt:
{prompt}
### Response:
""";
}
}

View File

@ -2,4 +2,9 @@
public interface IGpt4AllModel : ITextPrediction, IDisposable public interface IGpt4AllModel : ITextPrediction, IDisposable
{ {
/// <summary>
/// The prompt formatter used to format the prompt before
/// feeding it to the model, if null no transformation is applied
/// </summary>
IPromptFormatter? PromptFormatter { get; set; }
} }

View File

@ -0,0 +1,14 @@
namespace Gpt4All;
/// <summary>
/// Formats a prompt
/// </summary>
public interface IPromptFormatter
{
/// <summary>
/// Format the provided prompt
/// </summary>
/// <param name="prompt">the input prompt</param>
/// <returns>The formatted prompt</returns>
string FormatPrompt(string prompt);
}