diff --git a/gpt4all-bindings/golang/README.md b/gpt4all-bindings/golang/README.md index 634c5234..38a41867 100644 --- a/gpt4all-bindings/golang/README.md +++ b/gpt4all-bindings/golang/README.md @@ -24,7 +24,7 @@ func main() { return true }) - _, err = model.Predict("Here are 4 steps to create a website:", gpt4all.SetTemperature(0.1)) + _, err = model.Predict("Here are 4 steps to create a website:", "", "", gpt4all.SetTemperature(0.1)) if err != nil { panic(err) } diff --git a/gpt4all-bindings/golang/binding.cpp b/gpt4all-bindings/golang/binding.cpp index de730262..e3f47b56 100644 --- a/gpt4all-bindings/golang/binding.cpp +++ b/gpt4all-bindings/golang/binding.cpp @@ -35,8 +35,9 @@ void* load_model(const char *fname, int n_threads) { std::string res = ""; void * mm; -void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, - float top_p, float min_p, float temp, int n_batch,float ctx_erase) +void model_prompt(const char *prompt, const char *prompt_template, int special, const char *fake_reply, + void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, + int top_k, float top_p, float min_p, float temp, int n_batch,float ctx_erase) { llmodel_model* model = (llmodel_model*) m; @@ -88,11 +89,11 @@ void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, prompt_context->temp = temp; prompt_context->n_batch = n_batch; - llmodel_prompt(model, prompt, + llmodel_prompt(model, prompt, prompt_template, lambda_prompt, lambda_response, lambda_recalculate, - prompt_context ); + prompt_context, special, fake_reply); strcpy(result, res.c_str()); diff --git a/gpt4all-bindings/golang/binding.h b/gpt4all-bindings/golang/binding.h index 3a4d3656..990f10e8 100644 --- a/gpt4all-bindings/golang/binding.h +++ b/gpt4all-bindings/golang/binding.h @@ -6,8 +6,9 @@ extern "C" { void* load_model(const char *fname, int n_threads); -void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, - float top_p, float min_p, float temp, int n_batch,float ctx_erase); +void model_prompt(const char *prompt, const char *prompt_template, int special, const char *fake_reply, + void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, + int top_k, float top_p, float min_p, float temp, int n_batch,float ctx_erase); void free_model(void *state_ptr); diff --git a/gpt4all-bindings/golang/example/main.go b/gpt4all-bindings/golang/example/main.go index 2e692927..7351e855 100644 --- a/gpt4all-bindings/golang/example/main.go +++ b/gpt4all-bindings/golang/example/main.go @@ -47,7 +47,7 @@ func main() { for { text := readMultiLineInput(reader) - _, err := l.Predict(text, gpt4all.SetTokens(tokens), gpt4all.SetTopK(90), gpt4all.SetTopP(0.86)) + _, err := l.Predict(text, "", "", gpt4all.SetTokens(tokens), gpt4all.SetTopK(90), gpt4all.SetTopP(0.86)) if err != nil { panic(err) } diff --git a/gpt4all-bindings/golang/gpt4all.go b/gpt4all-bindings/golang/gpt4all.go index f97eebf6..57604cf4 100644 --- a/gpt4all-bindings/golang/gpt4all.go +++ b/gpt4all-bindings/golang/gpt4all.go @@ -6,7 +6,7 @@ package gpt4all // #cgo darwin CXXFLAGS: -std=c++17 // #cgo LDFLAGS: -lgpt4all -lm -lstdc++ -ldl // void* load_model(const char *fname, int n_threads); -// void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, +// void model_prompt( const char *prompt, const char *prompt_template, int special, const char *fake_reply, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, // float top_p, float min_p, float temp, int n_batch,float ctx_erase); // void free_model(void *state_ptr); // extern unsigned char getTokenCallback(void *, char *); @@ -47,7 +47,7 @@ func New(model string, opts ...ModelOption) (*Model, error) { return gpt, nil } -func (l *Model) Predict(text string, opts ...PredictOption) (string, error) { +func (l *Model) Predict(text, template, fakeReplyText string, opts ...PredictOption) (string, error) { po := NewPredictOptions(opts...) @@ -55,10 +55,14 @@ func (l *Model) Predict(text string, opts ...PredictOption) (string, error) { if po.Tokens == 0 { po.Tokens = 99999999 } + templateInput := C.CString(template) + fakeReplyInput := C.CString(fakeReplyText) out := make([]byte, po.Tokens) - C.model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize), - C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.MinP), C.float(po.Temperature), C.int(po.Batch), C.float(po.ContextErase)) + C.model_prompt(input, templateInput, C.int(po.Special), fakeReplyInput, l.state, (*C.char)(unsafe.Pointer(&out[0])), + C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize), C.int(po.Tokens), + C.int(po.TopK), C.float(po.TopP), C.float(po.MinP), C.float(po.Temperature), C.int(po.Batch), + C.float(po.ContextErase)) res := C.GoString((*C.char)(unsafe.Pointer(&out[0]))) res = strings.TrimPrefix(res, " ") diff --git a/gpt4all-bindings/golang/options.go b/gpt4all-bindings/golang/options.go index e2650ca0..56b0efc8 100644 --- a/gpt4all-bindings/golang/options.go +++ b/gpt4all-bindings/golang/options.go @@ -1,8 +1,8 @@ package gpt4all type PredictOptions struct { - ContextSize, RepeatLastN, Tokens, TopK, Batch int - TopP, MinP, Temperature, ContextErase, RepeatPenalty float64 + ContextSize, RepeatLastN, Tokens, TopK, Batch, Special int + TopP, MinP, Temperature, ContextErase, RepeatPenalty float64 } type PredictOption func(p *PredictOptions) @@ -11,9 +11,10 @@ var DefaultOptions PredictOptions = PredictOptions{ Tokens: 200, TopK: 10, TopP: 0.90, - MinP: 0.0, + MinP: 0.0, Temperature: 0.96, Batch: 1, + Special: 0, ContextErase: 0.55, ContextSize: 1024, RepeatLastN: 10, @@ -93,6 +94,17 @@ func SetBatch(size int) PredictOption { } } +// SetSpecial is true if special tokens in the prompt should be processed, false otherwise. +func SetSpecial(special bool) PredictOption { + return func(p *PredictOptions) { + if special { + p.Special = 1 + } else { + p.Special = 0 + } + } +} + // Create a new PredictOptions object with the given options. func NewPredictOptions(opts ...PredictOption) PredictOptions { p := DefaultOptions