add min_p sampling parameter (#2014)

Signed-off-by: Christopher Barrera <cb@arda.tx.rr.com>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
This commit is contained in:
chrisbarrera 2024-02-24 16:51:34 -06:00 committed by GitHub
parent a153cc5b25
commit f8b1069a1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 176 additions and 14 deletions

View File

@ -64,6 +64,7 @@ static int llama_sample_top_p_top_k(
int last_n_tokens_size, int last_n_tokens_size,
int top_k, int top_k,
float top_p, float top_p,
float min_p,
float temp, float temp,
float repeat_penalty, float repeat_penalty,
int32_t pos) { int32_t pos) {
@ -83,6 +84,7 @@ static int llama_sample_top_p_top_k(
llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1); llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1);
llama_sample_typical(ctx, &candidates_p, 1.0f, 1); llama_sample_typical(ctx, &candidates_p, 1.0f, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1); llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_min_p(ctx, &candidates_p, min_p, 1);
llama_sample_temp(ctx, &candidates_p, temp); llama_sample_temp(ctx, &candidates_p, temp);
return llama_sample_token(ctx, &candidates_p); return llama_sample_token(ctx, &candidates_p);
} }
@ -392,7 +394,7 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size()); const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
return llama_sample_top_p_top_k(d_ptr->ctx, return llama_sample_top_p_top_k(d_ptr->ctx,
promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks, promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.temp, n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.min_p, promptCtx.temp,
promptCtx.repeat_penalty, promptCtx.n_last_batch_tokens - 1); promptCtx.repeat_penalty, promptCtx.n_last_batch_tokens - 1);
} }

View File

@ -66,6 +66,7 @@ public:
int32_t n_predict = 200; int32_t n_predict = 200;
int32_t top_k = 40; int32_t top_k = 40;
float top_p = 0.9f; float top_p = 0.9f;
float min_p = 0.0f;
float temp = 0.9f; float temp = 0.9f;
int32_t n_batch = 9; int32_t n_batch = 9;
float repeat_penalty = 1.10f; float repeat_penalty = 1.10f;

View File

@ -134,6 +134,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
wrapper->promptContext.n_predict = ctx->n_predict; wrapper->promptContext.n_predict = ctx->n_predict;
wrapper->promptContext.top_k = ctx->top_k; wrapper->promptContext.top_k = ctx->top_k;
wrapper->promptContext.top_p = ctx->top_p; wrapper->promptContext.top_p = ctx->top_p;
wrapper->promptContext.min_p = ctx->min_p;
wrapper->promptContext.temp = ctx->temp; wrapper->promptContext.temp = ctx->temp;
wrapper->promptContext.n_batch = ctx->n_batch; wrapper->promptContext.n_batch = ctx->n_batch;
wrapper->promptContext.repeat_penalty = ctx->repeat_penalty; wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
@ -156,6 +157,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
ctx->n_predict = wrapper->promptContext.n_predict; ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k; ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p; ctx->top_p = wrapper->promptContext.top_p;
ctx->min_p = wrapper->promptContext.min_p;
ctx->temp = wrapper->promptContext.temp; ctx->temp = wrapper->promptContext.temp;
ctx->n_batch = wrapper->promptContext.n_batch; ctx->n_batch = wrapper->promptContext.n_batch;
ctx->repeat_penalty = wrapper->promptContext.repeat_penalty; ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;

View File

@ -39,6 +39,7 @@ struct llmodel_prompt_context {
int32_t n_predict; // number of tokens to predict int32_t n_predict; // number of tokens to predict
int32_t top_k; // top k logits to sample from int32_t top_k; // top k logits to sample from
float top_p; // nucleus sampling probability threshold float top_p; // nucleus sampling probability threshold
float min_p; // Min P sampling
float temp; // temperature to adjust model's output distribution float temp; // temperature to adjust model's output distribution
int32_t n_batch; // number of predictions to generate in parallel int32_t n_batch; // number of predictions to generate in parallel
float repeat_penalty; // penalty factor for repeated tokens float repeat_penalty; // penalty factor for repeated tokens

View File

@ -120,6 +120,7 @@ def _old_loop(gpt4all_instance):
n_predict=200, n_predict=200,
top_k=40, top_k=40,
top_p=0.9, top_p=0.9,
min_p=0.0,
temp=0.9, temp=0.9,
n_batch=9, n_batch=9,
repeat_penalty=1.1, repeat_penalty=1.1,
@ -156,6 +157,7 @@ def _new_loop(gpt4all_instance):
temp=0.9, temp=0.9,
top_k=40, top_k=40,
top_p=0.9, top_p=0.9,
min_p=0.0,
repeat_penalty=1.1, repeat_penalty=1.1,
repeat_last_n=64, repeat_last_n=64,
n_batch=9, n_batch=9,

View File

@ -64,6 +64,15 @@ public unsafe class LLModelPromptContext
set => _ctx.top_p = value; set => _ctx.top_p = value;
} }
/// <summary>
/// min p sampling probability threshold
/// </summary>
public float MinP
{
get => _ctx.min_p;
set => _ctx.min_p = value;
}
/// <summary> /// <summary>
/// temperature to adjust model's output distribution /// temperature to adjust model's output distribution
/// </summary> /// </summary>

View File

@ -29,6 +29,8 @@ public unsafe partial struct llmodel_prompt_context
public float top_p; public float top_p;
public float min_p;
public float temp; public float temp;
[NativeTypeName("int32_t")] [NativeTypeName("int32_t")]

View File

@ -16,6 +16,7 @@ internal static class LLPromptContextExtensions
n_predict = {ctx.n_predict} n_predict = {ctx.n_predict}
top_k = {ctx.top_k} top_k = {ctx.top_k}
top_p = {ctx.top_p} top_p = {ctx.top_p}
min_p = {ctx.min_p}
temp = {ctx.temp} temp = {ctx.temp}
n_batch = {ctx.n_batch} n_batch = {ctx.n_batch}
repeat_penalty = {ctx.repeat_penalty} repeat_penalty = {ctx.repeat_penalty}

View File

@ -12,6 +12,7 @@ public static class PredictRequestOptionsExtensions
TokensSize = opts.TokensSize, TokensSize = opts.TokensSize,
TopK = opts.TopK, TopK = opts.TopK,
TopP = opts.TopP, TopP = opts.TopP,
MinP = opts.MinP,
PastNum = opts.PastConversationTokensNum, PastNum = opts.PastConversationTokensNum,
RepeatPenalty = opts.RepeatPenalty, RepeatPenalty = opts.RepeatPenalty,
Temperature = opts.Temperature, Temperature = opts.Temperature,

View File

@ -16,6 +16,8 @@ public record PredictRequestOptions
public float TopP { get; init; } = 0.9f; public float TopP { get; init; } = 0.9f;
public float MinP { get; init; } = 0.0f;
public float Temperature { get; init; } = 0.1f; public float Temperature { get; init; } = 0.1f;
public int Batches { get; init; } = 8; public int Batches { get; init; } = 8;

View File

@ -36,7 +36,7 @@ std::string res = "";
void * mm; void * mm;
void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
float top_p, float temp, int n_batch,float ctx_erase) float top_p, float min_p, float temp, int n_batch,float ctx_erase)
{ {
llmodel_model* model = (llmodel_model*) m; llmodel_model* model = (llmodel_model*) m;
@ -69,6 +69,7 @@ void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n,
.n_predict = 50, .n_predict = 50,
.top_k = 10, .top_k = 10,
.top_p = 0.9, .top_p = 0.9,
.min_p = 0.0,
.temp = 1.0, .temp = 1.0,
.n_batch = 1, .n_batch = 1,
.repeat_penalty = 1.2, .repeat_penalty = 1.2,
@ -83,6 +84,7 @@ void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n,
prompt_context->top_k = top_k; prompt_context->top_k = top_k;
prompt_context->context_erase = ctx_erase; prompt_context->context_erase = ctx_erase;
prompt_context->top_p = top_p; prompt_context->top_p = top_p;
prompt_context->min_p = min_p;
prompt_context->temp = temp; prompt_context->temp = temp;
prompt_context->n_batch = n_batch; prompt_context->n_batch = n_batch;

View File

@ -7,7 +7,7 @@ extern "C" {
void* load_model(const char *fname, int n_threads); void* load_model(const char *fname, int n_threads);
void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
float top_p, float temp, int n_batch,float ctx_erase); float top_p, float min_p, float temp, int n_batch,float ctx_erase);
void free_model(void *state_ptr); void free_model(void *state_ptr);

View File

@ -7,7 +7,7 @@ package gpt4all
// #cgo LDFLAGS: -lgpt4all -lm -lstdc++ -ldl // #cgo LDFLAGS: -lgpt4all -lm -lstdc++ -ldl
// void* load_model(const char *fname, int n_threads); // void* load_model(const char *fname, int n_threads);
// void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, // void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
// float top_p, float temp, int n_batch,float ctx_erase); // float top_p, float min_p, float temp, int n_batch,float ctx_erase);
// void free_model(void *state_ptr); // void free_model(void *state_ptr);
// extern unsigned char getTokenCallback(void *, char *); // extern unsigned char getTokenCallback(void *, char *);
// void llmodel_set_implementation_search_path(const char *path); // void llmodel_set_implementation_search_path(const char *path);
@ -58,7 +58,7 @@ func (l *Model) Predict(text string, opts ...PredictOption) (string, error) {
out := make([]byte, po.Tokens) out := make([]byte, po.Tokens)
C.model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize), C.model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize),
C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.int(po.Batch), C.float(po.ContextErase)) C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.MinP), C.float(po.Temperature), C.int(po.Batch), C.float(po.ContextErase))
res := C.GoString((*C.char)(unsafe.Pointer(&out[0]))) res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
res = strings.TrimPrefix(res, " ") res = strings.TrimPrefix(res, " ")

View File

@ -2,7 +2,7 @@ package gpt4all
type PredictOptions struct { type PredictOptions struct {
ContextSize, RepeatLastN, Tokens, TopK, Batch int ContextSize, RepeatLastN, Tokens, TopK, Batch int
TopP, Temperature, ContextErase, RepeatPenalty float64 TopP, MinP, Temperature, ContextErase, RepeatPenalty float64
} }
type PredictOption func(p *PredictOptions) type PredictOption func(p *PredictOptions)
@ -11,6 +11,7 @@ var DefaultOptions PredictOptions = PredictOptions{
Tokens: 200, Tokens: 200,
TopK: 10, TopK: 10,
TopP: 0.90, TopP: 0.90,
MinP: 0.0,
Temperature: 0.96, Temperature: 0.96,
Batch: 1, Batch: 1,
ContextErase: 0.55, ContextErase: 0.55,
@ -50,6 +51,13 @@ func SetTopP(topp float64) PredictOption {
} }
} }
// SetMinP sets the value for min p sampling
func SetMinP(minp float64) PredictOption {
return func(p *PredictOptions) {
p.MinP = minp
}
}
// SetRepeatPenalty sets the repeat penalty. // SetRepeatPenalty sets the repeat penalty.
func SetRepeatPenalty(ce float64) PredictOption { func SetRepeatPenalty(ce float64) PredictOption {
return func(p *PredictOptions) { return func(p *PredictOptions) {

View File

@ -32,6 +32,7 @@ public class LLModel implements AutoCloseable {
n_predict.set(128); n_predict.set(128);
top_k.set(40); top_k.set(40);
top_p.set(0.95); top_p.set(0.95);
min_p.set(0.0);
temp.set(0.28); temp.set(0.28);
n_batch.set(8); n_batch.set(8);
repeat_penalty.set(1.1); repeat_penalty.set(1.1);
@ -71,6 +72,11 @@ public class LLModel implements AutoCloseable {
return this; return this;
} }
public Builder withMinP(float min_p) {
configToBuild.min_p.set(min_p);
return this;
}
public Builder withTemp(float temp) { public Builder withTemp(float temp) {
configToBuild.temp.set(temp); configToBuild.temp.set(temp);
return this; return this;

View File

@ -48,6 +48,7 @@ public interface LLModelLibrary {
public final int32_t n_predict = new int32_t(); public final int32_t n_predict = new int32_t();
public final int32_t top_k = new int32_t(); public final int32_t top_k = new int32_t();
public final Float top_p = new Float(); public final Float top_p = new Float();
public final Float min_p = new Float();
public final Float temp = new Float(); public final Float temp = new Float();
public final int32_t n_batch = new int32_t(); public final int32_t n_batch = new int32_t();
public final Float repeat_penalty = new Float(); public final Float repeat_penalty = new Float();

View File

@ -49,6 +49,7 @@ class LLModelPromptContext(ctypes.Structure):
("n_predict", ctypes.c_int32), ("n_predict", ctypes.c_int32),
("top_k", ctypes.c_int32), ("top_k", ctypes.c_int32),
("top_p", ctypes.c_float), ("top_p", ctypes.c_float),
("min_p", ctypes.c_float),
("temp", ctypes.c_float), ("temp", ctypes.c_float),
("n_batch", ctypes.c_int32), ("n_batch", ctypes.c_int32),
("repeat_penalty", ctypes.c_float), ("repeat_penalty", ctypes.c_float),
@ -241,6 +242,7 @@ class LLModel:
n_predict: int = 4096, n_predict: int = 4096,
top_k: int = 40, top_k: int = 40,
top_p: float = 0.9, top_p: float = 0.9,
min_p: float = 0.0,
temp: float = 0.1, temp: float = 0.1,
n_batch: int = 8, n_batch: int = 8,
repeat_penalty: float = 1.2, repeat_penalty: float = 1.2,
@ -257,6 +259,7 @@ class LLModel:
n_predict=n_predict, n_predict=n_predict,
top_k=top_k, top_k=top_k,
top_p=top_p, top_p=top_p,
min_p=min_p,
temp=temp, temp=temp,
n_batch=n_batch, n_batch=n_batch,
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
@ -272,6 +275,7 @@ class LLModel:
self.context.n_predict = n_predict self.context.n_predict = n_predict
self.context.top_k = top_k self.context.top_k = top_k
self.context.top_p = top_p self.context.top_p = top_p
self.context.min_p = min_p
self.context.temp = temp self.context.temp = temp
self.context.n_batch = n_batch self.context.n_batch = n_batch
self.context.repeat_penalty = repeat_penalty self.context.repeat_penalty = repeat_penalty
@ -297,6 +301,7 @@ class LLModel:
n_predict: int = 4096, n_predict: int = 4096,
top_k: int = 40, top_k: int = 40,
top_p: float = 0.9, top_p: float = 0.9,
min_p: float = 0.0,
temp: float = 0.1, temp: float = 0.1,
n_batch: int = 8, n_batch: int = 8,
repeat_penalty: float = 1.2, repeat_penalty: float = 1.2,
@ -334,6 +339,7 @@ class LLModel:
n_predict=n_predict, n_predict=n_predict,
top_k=top_k, top_k=top_k,
top_p=top_p, top_p=top_p,
min_p=min_p,
temp=temp, temp=temp,
n_batch=n_batch, n_batch=n_batch,
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,

View File

@ -289,6 +289,7 @@ class GPT4All:
temp: float = 0.7, temp: float = 0.7,
top_k: int = 40, top_k: int = 40,
top_p: float = 0.4, top_p: float = 0.4,
min_p: float = 0.0,
repeat_penalty: float = 1.18, repeat_penalty: float = 1.18,
repeat_last_n: int = 64, repeat_last_n: int = 64,
n_batch: int = 8, n_batch: int = 8,
@ -305,6 +306,7 @@ class GPT4All:
temp: The model temperature. Larger values increase creativity but decrease factuality. temp: The model temperature. Larger values increase creativity but decrease factuality.
top_k: Randomly sample from the top_k most likely tokens at each generation step. Set this to 1 for greedy decoding. top_k: Randomly sample from the top_k most likely tokens at each generation step. Set this to 1 for greedy decoding.
top_p: Randomly sample at each generation step from the top most likely tokens whose probabilities add up to top_p. top_p: Randomly sample at each generation step from the top most likely tokens whose probabilities add up to top_p.
min_p: Randomly sample at each generation step from the top most likely tokens whose probabilities are at least min_p.
repeat_penalty: Penalize the model for repetition. Higher values result in less repetition. repeat_penalty: Penalize the model for repetition. Higher values result in less repetition.
repeat_last_n: How far in the models generation history to apply the repeat penalty. repeat_last_n: How far in the models generation history to apply the repeat penalty.
n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements. n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.
@ -325,6 +327,7 @@ class GPT4All:
temp=temp, temp=temp,
top_k=top_k, top_k=top_k,
top_p=top_p, top_p=top_p,
min_p=min_p,
repeat_penalty=repeat_penalty, repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n, repeat_last_n=repeat_last_n,
n_batch=n_batch, n_batch=n_batch,

View File

@ -248,6 +248,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
.n_predict = 128, .n_predict = 128,
.top_k = 40, .top_k = 40,
.top_p = 0.9f, .top_p = 0.9f,
.min_p = 0.0f,
.temp = 0.72f, .temp = 0.72f,
.n_batch = 8, .n_batch = 8,
.repeat_penalty = 1.0f, .repeat_penalty = 1.0f,
@ -277,6 +278,8 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
promptContext.top_k = inputObject.Get("top_k").As<Napi::Number>().Int32Value(); promptContext.top_k = inputObject.Get("top_k").As<Napi::Number>().Int32Value();
if(inputObject.Has("top_p")) if(inputObject.Has("top_p"))
promptContext.top_p = inputObject.Get("top_p").As<Napi::Number>().FloatValue(); promptContext.top_p = inputObject.Get("top_p").As<Napi::Number>().FloatValue();
if(inputObject.Has("min_p"))
promptContext.min_p = inputObject.Get("min_p").As<Napi::Number>().FloatValue();
if(inputObject.Has("temp")) if(inputObject.Has("temp"))
promptContext.temp = inputObject.Get("temp").As<Napi::Number>().FloatValue(); promptContext.temp = inputObject.Get("temp").As<Napi::Number>().FloatValue();
if(inputObject.Has("n_batch")) if(inputObject.Has("n_batch"))

View File

@ -568,16 +568,17 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo); const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo); const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo); const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo); const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo); const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo); const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
return promptInternal(collectionList, prompt, promptTemplate, n_predict, top_k, top_p, temp, n_batch, return promptInternal(collectionList, prompt, promptTemplate, n_predict, top_k, top_p, min_p, temp, n_batch,
repeat_penalty, repeat_penalty_tokens); repeat_penalty, repeat_penalty_tokens);
} }
bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate, bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens) int32_t repeat_penalty_tokens)
{ {
if (!isModelLoaded()) if (!isModelLoaded())
@ -608,6 +609,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
m_ctx.n_predict = n_predict; m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k; m_ctx.top_k = top_k;
m_ctx.top_p = top_p; m_ctx.top_p = top_p;
m_ctx.min_p = min_p;
m_ctx.temp = temp; m_ctx.temp = temp;
m_ctx.n_batch = n_batch; m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty; m_ctx.repeat_penalty = repeat_penalty;
@ -1020,6 +1022,7 @@ void ChatLLM::processSystemPrompt()
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo); const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo); const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo); const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo); const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo); const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
@ -1028,6 +1031,7 @@ void ChatLLM::processSystemPrompt()
m_ctx.n_predict = n_predict; m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k; m_ctx.top_k = top_k;
m_ctx.top_p = top_p; m_ctx.top_p = top_p;
m_ctx.min_p = min_p;
m_ctx.temp = temp; m_ctx.temp = temp;
m_ctx.n_batch = n_batch; m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty; m_ctx.repeat_penalty = repeat_penalty;
@ -1067,6 +1071,7 @@ void ChatLLM::processRestoreStateFromText()
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo); const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo); const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo); const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo); const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo); const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
@ -1075,6 +1080,7 @@ void ChatLLM::processRestoreStateFromText()
m_ctx.n_predict = n_predict; m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k; m_ctx.top_k = top_k;
m_ctx.top_p = top_p; m_ctx.top_p = top_p;
m_ctx.min_p = min_p;
m_ctx.temp = temp; m_ctx.temp = temp;
m_ctx.n_batch = n_batch; m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty; m_ctx.repeat_penalty = repeat_penalty;

View File

@ -139,7 +139,7 @@ Q_SIGNALS:
protected: protected:
bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate, bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens); int32_t repeat_penalty_tokens);
bool handlePrompt(int32_t token); bool handlePrompt(int32_t token);
bool handleResponse(int32_t token, const std::string &response); bool handleResponse(int32_t token, const std::string &response);

View File

@ -1380,6 +1380,7 @@ Window {
MySettings.maxLength, MySettings.maxLength,
MySettings.topK, MySettings.topK,
MySettings.topP, MySettings.topP,
MySettings.minP,
MySettings.temperature, MySettings.temperature,
MySettings.promptBatchSize, MySettings.promptBatchSize,
MySettings.repeatPenalty, MySettings.repeatPenalty,

View File

@ -60,12 +60,23 @@ double ModelInfo::topP() const
return MySettings::globalInstance()->modelTopP(*this); return MySettings::globalInstance()->modelTopP(*this);
} }
double ModelInfo::minP() const
{
return MySettings::globalInstance()->modelMinP(*this);
}
void ModelInfo::setTopP(double p) void ModelInfo::setTopP(double p)
{ {
if (isClone) MySettings::globalInstance()->setModelTopP(*this, p, isClone /*force*/); if (isClone) MySettings::globalInstance()->setModelTopP(*this, p, isClone /*force*/);
m_topP = p; m_topP = p;
} }
void ModelInfo::setMinP(double p)
{
if (isClone) MySettings::globalInstance()->setModelMinP(*this, p, isClone /*force*/);
m_minP = p;
}
int ModelInfo::topK() const int ModelInfo::topK() const
{ {
return MySettings::globalInstance()->modelTopK(*this); return MySettings::globalInstance()->modelTopK(*this);
@ -321,6 +332,7 @@ ModelList::ModelList()
connect(MySettings::globalInstance(), &MySettings::nameChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::nameChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::temperatureChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::temperatureChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::topPChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::topPChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::minPChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings);
@ -571,6 +583,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
return info->temperature(); return info->temperature();
case TopPRole: case TopPRole:
return info->topP(); return info->topP();
case MinPRole:
return info->minP();
case TopKRole: case TopKRole:
return info->topK(); return info->topK();
case MaxLengthRole: case MaxLengthRole:
@ -700,6 +714,8 @@ void ModelList::updateData(const QString &id, int role, const QVariant &value)
info->setTemperature(value.toDouble()); break; info->setTemperature(value.toDouble()); break;
case TopPRole: case TopPRole:
info->setTopP(value.toDouble()); break; info->setTopP(value.toDouble()); break;
case MinPRole:
info->setMinP(value.toDouble()); break;
case TopKRole: case TopKRole:
info->setTopK(value.toInt()); break; info->setTopK(value.toInt()); break;
case MaxLengthRole: case MaxLengthRole:
@ -797,6 +813,7 @@ QString ModelList::clone(const ModelInfo &model)
updateData(id, ModelList::OnlineRole, model.isOnline); updateData(id, ModelList::OnlineRole, model.isOnline);
updateData(id, ModelList::TemperatureRole, model.temperature()); updateData(id, ModelList::TemperatureRole, model.temperature());
updateData(id, ModelList::TopPRole, model.topP()); updateData(id, ModelList::TopPRole, model.topP());
updateData(id, ModelList::MinPRole, model.minP());
updateData(id, ModelList::TopKRole, model.topK()); updateData(id, ModelList::TopKRole, model.topK());
updateData(id, ModelList::MaxLengthRole, model.maxLength()); updateData(id, ModelList::MaxLengthRole, model.maxLength());
updateData(id, ModelList::PromptBatchSizeRole, model.promptBatchSize()); updateData(id, ModelList::PromptBatchSizeRole, model.promptBatchSize());
@ -1163,6 +1180,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
updateData(id, ModelList::TemperatureRole, obj["temperature"].toDouble()); updateData(id, ModelList::TemperatureRole, obj["temperature"].toDouble());
if (obj.contains("topP")) if (obj.contains("topP"))
updateData(id, ModelList::TopPRole, obj["topP"].toDouble()); updateData(id, ModelList::TopPRole, obj["topP"].toDouble());
if (obj.contains("minP"))
updateData(id, ModelList::MinPRole, obj["minP"].toDouble());
if (obj.contains("topK")) if (obj.contains("topK"))
updateData(id, ModelList::TopKRole, obj["topK"].toInt()); updateData(id, ModelList::TopKRole, obj["topK"].toInt());
if (obj.contains("maxLength")) if (obj.contains("maxLength"))
@ -1287,6 +1306,8 @@ void ModelList::updateModelsFromSettings()
const double temperature = settings.value(g + "/temperature").toDouble(); const double temperature = settings.value(g + "/temperature").toDouble();
Q_ASSERT(settings.contains(g + "/topP")); Q_ASSERT(settings.contains(g + "/topP"));
const double topP = settings.value(g + "/topP").toDouble(); const double topP = settings.value(g + "/topP").toDouble();
Q_ASSERT(settings.contains(g + "/minP"));
const double minP = settings.value(g + "/minP").toDouble();
Q_ASSERT(settings.contains(g + "/topK")); Q_ASSERT(settings.contains(g + "/topK"));
const int topK = settings.value(g + "/topK").toInt(); const int topK = settings.value(g + "/topK").toInt();
Q_ASSERT(settings.contains(g + "/maxLength")); Q_ASSERT(settings.contains(g + "/maxLength"));
@ -1312,6 +1333,7 @@ void ModelList::updateModelsFromSettings()
updateData(id, ModelList::FilenameRole, filename); updateData(id, ModelList::FilenameRole, filename);
updateData(id, ModelList::TemperatureRole, temperature); updateData(id, ModelList::TemperatureRole, temperature);
updateData(id, ModelList::TopPRole, topP); updateData(id, ModelList::TopPRole, topP);
updateData(id, ModelList::MinPRole, minP);
updateData(id, ModelList::TopKRole, topK); updateData(id, ModelList::TopKRole, topK);
updateData(id, ModelList::MaxLengthRole, maxLength); updateData(id, ModelList::MaxLengthRole, maxLength);
updateData(id, ModelList::PromptBatchSizeRole, promptBatchSize); updateData(id, ModelList::PromptBatchSizeRole, promptBatchSize);

View File

@ -36,6 +36,7 @@ struct ModelInfo {
Q_PROPERTY(bool isClone MEMBER isClone) Q_PROPERTY(bool isClone MEMBER isClone)
Q_PROPERTY(double temperature READ temperature WRITE setTemperature) Q_PROPERTY(double temperature READ temperature WRITE setTemperature)
Q_PROPERTY(double topP READ topP WRITE setTopP) Q_PROPERTY(double topP READ topP WRITE setTopP)
Q_PROPERTY(double minP READ minP WRITE setMinP)
Q_PROPERTY(int topK READ topK WRITE setTopK) Q_PROPERTY(int topK READ topK WRITE setTopK)
Q_PROPERTY(int maxLength READ maxLength WRITE setMaxLength) Q_PROPERTY(int maxLength READ maxLength WRITE setMaxLength)
Q_PROPERTY(int promptBatchSize READ promptBatchSize WRITE setPromptBatchSize) Q_PROPERTY(int promptBatchSize READ promptBatchSize WRITE setPromptBatchSize)
@ -92,6 +93,8 @@ public:
void setTemperature(double t); void setTemperature(double t);
double topP() const; double topP() const;
void setTopP(double p); void setTopP(double p);
double minP() const;
void setMinP(double p);
int topK() const; int topK() const;
void setTopK(int k); void setTopK(int k);
int maxLength() const; int maxLength() const;
@ -119,6 +122,7 @@ private:
QString m_filename; QString m_filename;
double m_temperature = 0.7; double m_temperature = 0.7;
double m_topP = 0.4; double m_topP = 0.4;
double m_minP = 0.0;
int m_topK = 40; int m_topK = 40;
int m_maxLength = 4096; int m_maxLength = 4096;
int m_promptBatchSize = 128; int m_promptBatchSize = 128;
@ -247,6 +251,7 @@ public:
RepeatPenaltyTokensRole, RepeatPenaltyTokensRole,
PromptTemplateRole, PromptTemplateRole,
SystemPromptRole, SystemPromptRole,
MinPRole,
}; };
QHash<int, QByteArray> roleNames() const override QHash<int, QByteArray> roleNames() const override
@ -282,6 +287,7 @@ public:
roles[IsCloneRole] = "isClone"; roles[IsCloneRole] = "isClone";
roles[TemperatureRole] = "temperature"; roles[TemperatureRole] = "temperature";
roles[TopPRole] = "topP"; roles[TopPRole] = "topP";
roles[MinPRole] = "minP";
roles[TopKRole] = "topK"; roles[TopKRole] = "topK";
roles[MaxLengthRole] = "maxLength"; roles[MaxLengthRole] = "maxLength";
roles[PromptBatchSizeRole] = "promptBatchSize"; roles[PromptBatchSizeRole] = "promptBatchSize";

View File

@ -87,6 +87,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &model)
{ {
setModelTemperature(model, model.m_temperature); setModelTemperature(model, model.m_temperature);
setModelTopP(model, model.m_topP); setModelTopP(model, model.m_topP);
setModelMinP(model, model.m_minP);
setModelTopK(model, model.m_topK);; setModelTopK(model, model.m_topK);;
setModelMaxLength(model, model.m_maxLength); setModelMaxLength(model, model.m_maxLength);
setModelPromptBatchSize(model, model.m_promptBatchSize); setModelPromptBatchSize(model, model.m_promptBatchSize);
@ -201,6 +202,13 @@ double MySettings::modelTopP(const ModelInfo &m) const
return setting.value(QString("model-%1").arg(m.id()) + "/topP", m.m_topP).toDouble(); return setting.value(QString("model-%1").arg(m.id()) + "/topP", m.m_topP).toDouble();
} }
double MySettings::modelMinP(const ModelInfo &m) const
{
QSettings setting;
setting.sync();
return setting.value(QString("model-%1").arg(m.id()) + "/minP", m.m_minP).toDouble();
}
void MySettings::setModelTopP(const ModelInfo &m, double p, bool force) void MySettings::setModelTopP(const ModelInfo &m, double p, bool force)
{ {
if (modelTopP(m) == p && !force) if (modelTopP(m) == p && !force)
@ -216,6 +224,21 @@ void MySettings::setModelTopP(const ModelInfo &m, double p, bool force)
emit topPChanged(m); emit topPChanged(m);
} }
void MySettings::setModelMinP(const ModelInfo &m, double p, bool force)
{
if (modelMinP(m) == p && !force)
return;
QSettings setting;
if (m.m_minP == p && !m.isClone)
setting.remove(QString("model-%1").arg(m.id()) + "/minP");
else
setting.setValue(QString("model-%1").arg(m.id()) + "/minP", p);
setting.sync();
if (!force)
emit minPChanged(m);
}
int MySettings::modelTopK(const ModelInfo &m) const int MySettings::modelTopK(const ModelInfo &m) const
{ {
QSettings setting; QSettings setting;

View File

@ -47,6 +47,8 @@ public:
Q_INVOKABLE void setModelTemperature(const ModelInfo &m, double t, bool force = false); Q_INVOKABLE void setModelTemperature(const ModelInfo &m, double t, bool force = false);
double modelTopP(const ModelInfo &m) const; double modelTopP(const ModelInfo &m) const;
Q_INVOKABLE void setModelTopP(const ModelInfo &m, double p, bool force = false); Q_INVOKABLE void setModelTopP(const ModelInfo &m, double p, bool force = false);
double modelMinP(const ModelInfo &m) const;
Q_INVOKABLE void setModelMinP(const ModelInfo &m, double p, bool force = false);
int modelTopK(const ModelInfo &m) const; int modelTopK(const ModelInfo &m) const;
Q_INVOKABLE void setModelTopK(const ModelInfo &m, int k, bool force = false); Q_INVOKABLE void setModelTopK(const ModelInfo &m, int k, bool force = false);
int modelMaxLength(const ModelInfo &m) const; int modelMaxLength(const ModelInfo &m) const;
@ -119,6 +121,7 @@ Q_SIGNALS:
void filenameChanged(const ModelInfo &model); void filenameChanged(const ModelInfo &model);
void temperatureChanged(const ModelInfo &model); void temperatureChanged(const ModelInfo &model);
void topPChanged(const ModelInfo &model); void topPChanged(const ModelInfo &model);
void minPChanged(const ModelInfo &model);
void topKChanged(const ModelInfo &model); void topKChanged(const ModelInfo &model);
void maxLengthChanged(const ModelInfo &model); void maxLengthChanged(const ModelInfo &model);
void promptBatchSizeChanged(const ModelInfo &model); void promptBatchSizeChanged(const ModelInfo &model);

View File

@ -452,6 +452,50 @@ MySettingsTab {
Accessible.name: topPLabel.text Accessible.name: topPLabel.text
Accessible.description: ToolTip.text Accessible.description: ToolTip.text
} }
MySettingsLabel {
id: minPLabel
text: qsTr("Min P")
Layout.row: 3
Layout.column: 0
}
MyTextField {
id: minPField
text: root.currentModelInfo.minP
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Sets the minimum relative probability for a token to be considered.")
ToolTip.visible: hovered
Layout.row: 3
Layout.column: 1
validator: DoubleValidator {
locale: "C"
}
Connections {
target: MySettings
function onMinPChanged() {
minPField.text = root.currentModelInfo.minP;
}
}
Connections {
target: root
function onCurrentModelInfoChanged() {
minPField.text = root.currentModelInfo.minP;
}
}
onEditingFinished: {
var val = parseFloat(text)
if (!isNaN(val)) {
MySettings.setModelMinP(root.currentModelInfo, val)
focus = false
} else {
text = root.currentModelInfo.minP
}
}
Accessible.role: Accessible.EditableText
Accessible.name: minPLabel.text
Accessible.description: ToolTip.text
}
MySettingsLabel { MySettingsLabel {
id: topKLabel id: topKLabel
visible: !root.currentModelInfo.isOnline visible: !root.currentModelInfo.isOnline
@ -592,8 +636,8 @@ MySettingsTab {
id: repeatPenaltyLabel id: repeatPenaltyLabel
visible: !root.currentModelInfo.isOnline visible: !root.currentModelInfo.isOnline
text: qsTr("Repeat Penalty") text: qsTr("Repeat Penalty")
Layout.row: 3 Layout.row: 4
Layout.column: 0 Layout.column: 2
} }
MyTextField { MyTextField {
id: repeatPenaltyField id: repeatPenaltyField
@ -603,8 +647,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Amount to penalize repetitiveness of the output") ToolTip.text: qsTr("Amount to penalize repetitiveness of the output")
ToolTip.visible: hovered ToolTip.visible: hovered
Layout.row: 3 Layout.row: 4
Layout.column: 1 Layout.column: 3
validator: DoubleValidator { validator: DoubleValidator {
locale: "C" locale: "C"
} }

View File

@ -205,6 +205,10 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
if (body.contains("top_p")) if (body.contains("top_p"))
top_p = body["top_p"].toDouble(); top_p = body["top_p"].toDouble();
float min_p = 0.f;
if (body.contains("min_p"))
min_p = body["min_p"].toDouble();
int n = 1; int n = 1;
if (body.contains("n")) if (body.contains("n"))
n = body["n"].toInt(); n = body["n"].toInt();
@ -312,6 +316,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
max_tokens /*n_predict*/, max_tokens /*n_predict*/,
top_k, top_k,
top_p, top_p,
min_p,
temperature, temperature,
n_batch, n_batch,
repeat_penalty, repeat_penalty,