mirror of https://github.com/nomic-ai/gpt4all.git (synced 2024-10-01 01:06:10 -04:00)

commit 79cef86bec (parent fca2578a81): Adapt code

The diff below adapts the golang binding to the unified llmodel C API: the three per-architecture loaders (MPT, LLaMA, GPT-J) collapse into a single load_gpt4all_model, and the ModelType option is dropped, since llmodel_model_create picks the implementation from the model file itself.
@@ -2,14 +2,11 @@
 #include "../../gpt4all-backend/llmodel.h"
 #include "../../gpt4all-backend/llama.cpp/llama.h"
 #include "../../gpt4all-backend/llmodel_c.cpp"
-#include "../../gpt4all-backend/mpt.h"
-#include "../../gpt4all-backend/mpt.cpp"
 
-#include "../../gpt4all-backend/llamamodel.h"
-#include "../../gpt4all-backend/gptj.h"
 #include "binding.h"
 #include <cassert>
 #include <cmath>
+#include <cstddef>
 #include <cstdio>
 #include <cstring>
 #include <fstream>
@@ -19,46 +16,24 @@
 #include <iostream>
 #include <unistd.h>
 
-void* load_mpt_model(const char *fname, int n_threads) {
+void* load_gpt4all_model(const char *fname, int n_threads) {
     // load the model
-    auto gptj = llmodel_mpt_create();
-    llmodel_setThreadCount(gptj, n_threads);
-    if (!llmodel_loadModel(gptj, fname)) {
+    auto gptj4all = llmodel_model_create(fname);
+    if (gptj4all == NULL ){
+        return nullptr;
+    }
+    llmodel_setThreadCount(gptj4all, n_threads);
+    if (!llmodel_loadModel(gptj4all, fname)) {
         return nullptr;
     }
 
-    return gptj;
-}
-
-void* load_llama_model(const char *fname, int n_threads) {
-    // load the model
-    auto gptj = llmodel_llama_create();
-
-    llmodel_setThreadCount(gptj, n_threads);
-    if (!llmodel_loadModel(gptj, fname)) {
-        return nullptr;
-    }
-
-    return gptj;
-}
-
-void* load_gptj_model(const char *fname, int n_threads) {
-    // load the model
-    auto gptj = llmodel_gptj_create();
-
-    llmodel_setThreadCount(gptj, n_threads);
-    if (!llmodel_loadModel(gptj, fname)) {
-        return nullptr;
-    }
-
-    return gptj;
+    return gptj4all;
 }
 
 std::string res = "";
 void * mm;
 
-void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
+void gpt4all_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
                         float top_p, float temp, int n_batch,float ctx_erase)
 {
     llmodel_model* model = (llmodel_model*) m;
@@ -120,8 +95,8 @@ void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_la
     free(prompt_context);
 }
 
-void gptj_free_model(void *state_ptr) {
+void gpt4all_free_model(void *state_ptr) {
     llmodel_model* ctx = (llmodel_model*) state_ptr;
-    llmodel_llama_destroy(ctx);
+    llmodel_model_destroy(*ctx);
 }
 
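Taken together, the binding.cpp changes leave a three-call C lifecycle: load_gpt4all_model, gpt4all_model_prompt, gpt4all_free_model. Below is a minimal, hypothetical cgo sketch of driving that surface directly from Go, bypassing the wrapper package. The declarations are copied from this diff; the model path, the sampling values, and the assumption that a non-zero getTokenCallback return means "keep generating" are illustrative only.

package main

// Declarations mirror binding.h from this commit; link flags match the
// cgo block shown later in the diff.

// #cgo LDFLAGS: -lgpt4all -lm -lstdc++
// void* load_gpt4all_model(const char *fname, int n_threads);
// void gpt4all_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
// float top_p, float temp, int n_batch,float ctx_erase);
// void gpt4all_free_model(void *state_ptr);
import "C"

import (
	"fmt"
	"unsafe"
)

// binding.cpp declares this symbol as extern; exporting it from Go
// satisfies the linker and streams tokens as they are generated.
// Returning non-zero is assumed to mean "continue".
//export getTokenCallback
func getTokenCallback(state unsafe.Pointer, token *C.char) C.uchar {
	fmt.Print(C.GoString(token))
	return 1
}

func main() {
	// One loader for every architecture: the backend picks the model
	// type from the file, so no per-type entry point remains.
	// (C.CString allocations are leaked here for brevity.)
	state := C.load_gpt4all_model(C.CString("./model.bin"), C.int(4))
	if state == nil {
		panic("failed loading model")
	}
	defer C.gpt4all_free_model(state)

	out := make([]byte, 128) // result buffer, sized to the token budget
	C.gpt4all_model_prompt(C.CString("Hello"), state,
		(*C.char)(unsafe.Pointer(&out[0])),
		C.int(64), C.float(1.3), // repeat_last_n, repeat_penalty
		C.int(1024), C.int(128), // n_ctx, tokens
		C.int(90), C.float(0.86), C.float(0.28), // top_k, top_p, temp
		C.int(8), C.float(0.5)) // n_batch, ctx_erase
	fmt.Println(C.GoString((*C.char)(unsafe.Pointer(&out[0]))))
}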
@@ -4,16 +4,12 @@ extern "C" {
 
 #include <stdbool.h>
 
-void* load_mpt_model(const char *fname, int n_threads);
+void* load_gpt4all_model(const char *fname, int n_threads);
 
-void* load_llama_model(const char *fname, int n_threads);
-
-void* load_gptj_model(const char *fname, int n_threads);
-
-void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
+void gpt4all_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
                         float top_p, float temp, int n_batch,float ctx_erase);
 
-void gptj_free_model(void *state_ptr);
+void gpt4all_free_model(void *state_ptr);
 
 extern unsigned char getTokenCallback(void *, char *);
 
@@ -30,7 +30,7 @@ func main() {
 		fmt.Printf("Parsing program arguments failed: %s", err)
 		os.Exit(1)
 	}
-	l, err := gpt4all.New(model, gpt4all.SetModelType(gpt4all.GPTJType), gpt4all.SetThreads(threads))
+	l, err := gpt4all.New(model, gpt4all.SetThreads(threads))
 	if err != nil {
 		fmt.Println("Loading the model failed:", err.Error())
 		os.Exit(1)
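End to end, the updated high-level usage through the wrapper reduces to the sketch below. New, SetThreads, SetTokenCallback, Predict, SetTokens, and Free all appear elsewhere in this diff; the import path and the callback's continue-on-true semantics are assumptions.

package main

import (
	"fmt"
	"os"

	// Import path assumed from this repository's layout.
	gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
)

func main() {
	// No SetModelType any more: a single constructor covers GPT-J,
	// LLaMA and MPT models alike.
	l, err := gpt4all.New("./model.bin", gpt4all.SetThreads(4))
	if err != nil {
		fmt.Println("Loading the model failed:", err.Error())
		os.Exit(1)
	}
	defer l.Free()

	// Stream tokens as they arrive; returning true is assumed to
	// continue generation.
	l.SetTokenCallback(func(token string) bool {
		fmt.Print(token)
		return true
	})

	if _, err := l.Predict("Hello, world!", gpt4all.SetTokens(128)); err != nil {
		fmt.Println("Predict failed:", err.Error())
		os.Exit(1)
	}
}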
@@ -5,12 +5,10 @@ package gpt4all
 // #cgo darwin LDFLAGS: -framework Accelerate
 // #cgo darwin CXXFLAGS: -std=c++17
 // #cgo LDFLAGS: -lgpt4all -lm -lstdc++
-// void* load_mpt_model(const char *fname, int n_threads);
-// void* load_llama_model(const char *fname, int n_threads);
-// void* load_gptj_model(const char *fname, int n_threads);
-// void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
+// void* load_gpt4all_model(const char *fname, int n_threads);
+// void gpt4all_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
 // float top_p, float temp, int n_batch,float ctx_erase);
-// void gptj_free_model(void *state_ptr);
+// void gpt4all_free_model(void *state_ptr);
 // extern unsigned char getTokenCallback(void *, char *);
 import "C"
 import (
@@ -28,16 +26,8 @@ type Model struct {
 
 func New(model string, opts ...ModelOption) (*Model, error) {
 	ops := NewModelOptions(opts...)
-	var state unsafe.Pointer
 
-	switch ops.ModelType {
-	case LLaMAType:
-		state = C.load_llama_model(C.CString(model), C.int(ops.Threads))
-	case GPTJType:
-		state = C.load_gptj_model(C.CString(model), C.int(ops.Threads))
-	case MPTType:
-		state = C.load_mpt_model(C.CString(model), C.int(ops.Threads))
-	}
+	state := C.load_gpt4all_model(C.CString(model), C.int(ops.Threads))
 
 	if state == nil {
 		return nil, fmt.Errorf("failed loading model")
@@ -62,7 +52,7 @@ func (l *Model) Predict(text string, opts ...PredictOption) (string, error) {
 	}
 	out := make([]byte, po.Tokens)
 
-	C.gptj_model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize),
+	C.gpt4all_model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize),
 		C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.int(po.Batch), C.float(po.ContextErase))
 
 	res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
@@ -75,7 +65,7 @@ func (l *Model) Predict(text string, opts ...PredictOption) (string, error) {
 }
 
 func (l *Model) Free() {
-	C.gptj_free_model(l.state)
+	C.gpt4all_free_model(l.state)
 }
 
 func (l *Model) SetTokenCallback(callback func(token string) bool) {
@@ -13,15 +13,5 @@ var _ = Describe("LLama binding", func() {
 			Expect(err).To(HaveOccurred())
 			Expect(model).To(BeNil())
 		})
-		It("fails with no model", func() {
-			model, err := New("not-existing", SetModelType(MPTType))
-			Expect(err).To(HaveOccurred())
-			Expect(model).To(BeNil())
-		})
-		It("fails with no model", func() {
-			model, err := New("not-existing", SetModelType(LLaMAType))
-			Expect(err).To(HaveOccurred())
-			Expect(model).To(BeNil())
-		})
 	})
 })
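The two per-type failure specs disappear because SetModelType no longer exists; the surviving spec (only its assertions are visible as context above) presumably covers the same path without a type option, along these lines:

It("fails with no model", func() {
	// Hypothetical reconstruction of the spec this commit keeps:
	// same failure path, no model-type option.
	model, err := New("not-existing")
	Expect(err).To(HaveOccurred())
	Expect(model).To(BeNil())
})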
@@ -20,24 +20,14 @@ var DefaultOptions PredictOptions = PredictOptions{
 }
 
 var DefaultModelOptions ModelOptions = ModelOptions{
 	Threads: 4,
-	ModelType: GPTJType,
 }
 
 type ModelOptions struct {
 	Threads int
-	ModelType ModelType
 }
 type ModelOption func(p *ModelOptions)
 
-type ModelType int
-
-const (
-	LLaMAType ModelType = 0
-	GPTJType ModelType = iota
-	MPTType ModelType = iota
-)
-
 // SetTokens sets the number of tokens to generate.
 func SetTokens(tokens int) PredictOption {
 	return func(p *PredictOptions) {
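With ModelType gone, ModelOptions shrinks to a single Threads field, but the functional-options machinery is untouched. A short sketch of how it composes; the doubling option is purely illustrative, and NewModelOptions is assumed to apply each option in order over DefaultModelOptions, as the next hunk's context suggests:

// Defaults first, then each option in order, so later writes win.
ops := NewModelOptions()             // ops.Threads == 4 (DefaultModelOptions)
ops = NewModelOptions(SetThreads(8)) // ops.Threads == 8

// Any func(p *ModelOptions) qualifies as an option; this one is
// hypothetical.
double := ModelOption(func(p *ModelOptions) { p.Threads *= 2 })
ops = NewModelOptions(SetThreads(3), double) // ops.Threads == 6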
@@ -110,13 +100,6 @@ func SetThreads(c int) ModelOption {
 	}
 }
 
-// SetModelType sets the model type.
-func SetModelType(c ModelType) ModelOption {
-	return func(p *ModelOptions) {
-		p.ModelType = c
-	}
-}
-
 // Create a new PredictOptions object with the given options.
 func NewModelOptions(opts ...ModelOption) ModelOptions {
 	p := DefaultModelOptions