mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Golang bindings initial working version(#534)
* WIP * Fix includes * Try to fix linking issues * Refinements * allow to load MPT and llama models too * cleanup, add example, add README
This commit is contained in:
parent
2433902460
commit
3f63cc6b47
172
gpt4all-bindings/golang/Makefile
Normal file
172
gpt4all-bindings/golang/Makefile
Normal file
@ -0,0 +1,172 @@
|
||||
#
# Makefile for the gpt4all Golang bindings.
#
# Builds the native backend objects out of ../../gpt4all-backend via cmake,
# then archives them together with the cgo bridge (binding.o) into
# libgpt4all.a, which the Go linker consumes (see `test` and `example/main`).
#

INCLUDE_PATH := $(abspath ./)
LIBRARY_PATH := $(abspath ./)
CMAKEFLAGS :=

# Host detection; each variable may be pre-set by the caller.
ifndef UNAME_S
UNAME_S := $(shell uname -s)
endif

ifndef UNAME_P
UNAME_P := $(shell uname -p)
endif

ifndef UNAME_M
UNAME_M := $(shell uname -m)
endif

CCV := $(shell $(CC) --version | head -n 1)
CXXV := $(shell $(CXX) --version | head -n 1)

# Mac OS + Arm can report x86_64
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
ifeq ($(UNAME_S),Darwin)
ifneq ($(UNAME_P),arm)
SYSCTL_M := $(shell sysctl -n hw.optional.arm64 2>/dev/null)
ifeq ($(SYSCTL_M),1)
# UNAME_P := arm
# UNAME_M := arm64
warn := $(warning Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-1282546789)
endif
endif
endif

#
# Compile flags
#

# keep standard at C11 and C++17
# FIX: CFLAGS previously contained a stray `-I` with no argument, which made
# the compiler treat the following `-O3` as an include directory and silently
# disabled optimization for C objects.
CFLAGS = -I. -I../../gpt4all-backend/llama.cpp -I../../gpt4all-backend -O3 -DNDEBUG -std=c11 -fPIC
CXXFLAGS = -I. -I../../gpt4all-backend/llama.cpp -I../../gpt4all-backend -O3 -DNDEBUG -std=c++17 -fPIC
LDFLAGS =

# warnings
CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function
CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar

# OS specific
# TODO: support Windows
# All currently supported platforms only need -pthread, so the six duplicated
# per-OS blocks are collapsed into one $(filter) check.
ifneq ($(filter $(UNAME_S),Linux Darwin FreeBSD NetBSD OpenBSD Haiku),)
CFLAGS += -pthread
CXXFLAGS += -pthread
endif

# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
# Use all CPU extensions that are available:
CFLAGS += -march=native -mtune=native
CXXFLAGS += -march=native -mtune=native
endif
ifneq ($(filter ppc64%,$(UNAME_M)),)
POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
ifneq (,$(findstring POWER9,$(POWER9_M)))
CFLAGS += -mcpu=power9
CXXFLAGS += -mcpu=power9
endif
# Require c++23's std::byteswap for big-endian support.
ifeq ($(UNAME_M),ppc64)
CXXFLAGS += -std=c++23 -DGGML_BIG_ENDIAN
endif
endif
ifndef LLAMA_NO_ACCELERATE
# Mac M1 - include Accelerate framework.
# `-framework Accelerate` works on Mac Intel as well, with negliable performance boost (as of the predict time).
ifeq ($(UNAME_S),Darwin)
CFLAGS += -DGGML_USE_ACCELERATE
LDFLAGS += -framework Accelerate
endif
endif
ifdef LLAMA_OPENBLAS
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
LDFLAGS += -lopenblas
endif
ifdef LLAMA_GPROF
CFLAGS += -pg
CXXFLAGS += -pg
endif
ifneq ($(filter aarch64%,$(UNAME_M)),)
CFLAGS += -mcpu=native
CXXFLAGS += -mcpu=native
endif
ifneq ($(filter armv6%,$(UNAME_M)),)
# Raspberry Pi 1, 2, 3
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
endif
ifneq ($(filter armv7%,$(UNAME_M)),)
# Raspberry Pi 4
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
endif
ifneq ($(filter armv8%,$(UNAME_M)),)
# Raspberry Pi 4
CFLAGS += -mfp16-format=ieee -mno-unaligned-access
endif

#
# Print build information
#

$(info I go-gpt4all build info: )
$(info I UNAME_S: $(UNAME_S))
$(info I UNAME_P: $(UNAME_P))
$(info I UNAME_M: $(UNAME_M))
$(info I CFLAGS: $(CFLAGS))
$(info I CXXFLAGS: $(CXXFLAGS))
$(info I LDFLAGS: $(LDFLAGS))
$(info I CMAKEFLAGS: $(CMAKEFLAGS))
$(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )

# `clean` and `test` never name files they create.
.PHONY: clean test

# Compile llama.cpp out-of-tree and copy its object file next to this Makefile.
# FIX: `mkdir -p` so a rebuild does not fail when the directory already exists;
# `$(MAKE)` instead of bare `make` so -j/-n propagate to the sub-build.
llama.o:
	mkdir -p buildllama
	cd buildllama && cmake ../../../gpt4all-backend/llama.cpp $(CMAKEFLAGS) && $(MAKE) VERBOSE=1 llama.o && cp -rf CMakeFiles/llama.dir/llama.cpp.o ../llama.o

# Build the llmodel C wrapper plus the model implementations. NOTE: this
# recipe also produces common.o, gptj.o, llamamodel.o, utils.o and ggml.o
# as side outputs consumed by libgpt4all.a.
llmodel.o:
	mkdir -p buildllm
	cd buildllm && cmake ../../../gpt4all-backend/ $(CMAKEFLAGS) && $(MAKE) VERBOSE=1 llmodel ggml common
	cd buildllm && cp -rf CMakeFiles/llmodel.dir/llmodel_c.cpp.o ../llmodel.o
	cd buildllm && cp -rfv CMakeFiles/llmodel.dir/llama.cpp/examples/common.cpp.o ../common.o
	cd buildllm && cp -rf CMakeFiles/llmodel.dir/gptj.cpp.o ../gptj.o
	cd buildllm && cp -rf CMakeFiles/llmodel.dir/llamamodel.cpp.o ../llamamodel.o
	cd buildllm && cp -rf CMakeFiles/llmodel.dir/utils.cpp.o ../utils.o
	cd buildllm && cp -rf llama.cpp/CMakeFiles/ggml.dir/ggml.c.o ../ggml.o

clean:
	$(RM) *.o
	$(RM) *.a
	rm -rf buildllm
	rm -rf buildllama
	rm -rf example/main

# The cgo bridge between Go and the C llmodel API.
# FIX: declare the source files as prerequisites so edits trigger a rebuild.
binding.o: binding.cpp binding.h
	$(CXX) $(CXXFLAGS) binding.cpp -o binding.o -c $(LDFLAGS)

libgpt4all.a: binding.o llmodel.o llama.o
	ar rcs libgpt4all.a ggml.o common.o llama.o llamamodel.o utils.o llmodel.o gptj.o binding.o

test: libgpt4all.a
	@C_INCLUDE_PATH=$(INCLUDE_PATH) LIBRARY_PATH=$(LIBRARY_PATH) go test -v ./...

# FIX: LIBRARY_PATH previously reused INCLUDE_PATH (same value today, but
# semantically wrong); use the dedicated variable.
example/main: libgpt4all.a
	C_INCLUDE_PATH=$(INCLUDE_PATH) LIBRARY_PATH=$(LIBRARY_PATH) go build -o example/main ./example/
58
gpt4all-bindings/golang/README.md
Normal file
58
gpt4all-bindings/golang/README.md
Normal file
@ -0,0 +1,58 @@
|
||||
# GPT4All Golang bindings
|
||||
|
||||
The golang bindings have been tested on:
|
||||
- MacOS
|
||||
- Linux
|
||||
|
||||
### Usage
|
||||
|
||||
```
|
||||
import (
|
||||
"github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load the model
|
||||
model, err := gpt4all.New("model.bin", gpt4all.SetModelType(gpt4all.GPTJType))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer model.Free()
|
||||
|
||||
model.SetTokenCallback(func(s string) bool {
|
||||
fmt.Print(s)
|
||||
return true
|
||||
})
|
||||
|
||||
_, err = model.Predict("Here are 4 steps to create a website:", gpt4all.SetTemperature(0.1))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Building
|
||||
|
||||
In order to use the bindings you will need to build `libgpt4all.a`:
|
||||
|
||||
```
|
||||
git clone https://github.com/nomic-ai/gpt4all
|
||||
cd gpt4all/gpt4all-bindings/golang
|
||||
make libgpt4all.a
|
||||
```
|
||||
|
||||
To use the bindings in your own software:
|
||||
|
||||
- Import `github.com/nomic-ai/gpt4all/gpt4all-bindings/golang`;
|
||||
- Compile `libgpt4all.a` (you can use `make libgpt4all.a` in the `gpt4all-bindings/golang` directory);
|
||||
- Link your go binary against gpt4all by setting the environment variables `C_INCLUDE_PATH` and `LIBRARY_PATH` to point to the `binding.h` file directory and `libgpt4all.a` file directory respectively.
|
||||
|
||||
## Testing
|
||||
|
||||
To run tests, run `make test`:
|
||||
|
||||
```
|
||||
git clone https://github.com/nomic-ai/gpt4all
|
||||
cd gpt4all/gpt4all-bindings/golang
|
||||
make test
|
||||
```
|
127
gpt4all-bindings/golang/binding.cpp
Normal file
127
gpt4all-bindings/golang/binding.cpp
Normal file
@ -0,0 +1,127 @@
|
||||
#include "../../gpt4all-backend/llmodel_c.h"
|
||||
#include "../../gpt4all-backend/llmodel.h"
|
||||
#include "../../gpt4all-backend/llama.cpp/llama.h"
|
||||
#include "../../gpt4all-backend/llmodel_c.cpp"
|
||||
#include "../../gpt4all-backend/mpt.h"
|
||||
#include "../../gpt4all-backend/mpt.cpp"
|
||||
|
||||
#include "../../gpt4all-backend/llamamodel.h"
|
||||
#include "../../gpt4all-backend/gptj.h"
|
||||
#include "binding.h"
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <unistd.h>
|
||||
|
||||
void* load_mpt_model(const char *fname, int n_threads) {
|
||||
// load the model
|
||||
auto gptj = llmodel_mpt_create();
|
||||
|
||||
llmodel_setThreadCount(gptj, n_threads);
|
||||
if (!llmodel_loadModel(gptj, fname)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return gptj;
|
||||
}
|
||||
|
||||
void* load_llama_model(const char *fname, int n_threads) {
|
||||
// load the model
|
||||
auto gptj = llmodel_llama_create();
|
||||
|
||||
llmodel_setThreadCount(gptj, n_threads);
|
||||
if (!llmodel_loadModel(gptj, fname)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return gptj;
|
||||
}
|
||||
|
||||
void* load_gptj_model(const char *fname, int n_threads) {
|
||||
// load the model
|
||||
auto gptj = llmodel_gptj_create();
|
||||
|
||||
llmodel_setThreadCount(gptj, n_threads);
|
||||
if (!llmodel_loadModel(gptj, fname)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return gptj;
|
||||
}
|
||||
|
||||
std::string res = "";
|
||||
void * mm;
|
||||
|
||||
void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
|
||||
float top_p, float temp, int n_batch,float ctx_erase)
|
||||
{
|
||||
llmodel_model* model = (llmodel_model*) m;
|
||||
|
||||
// std::string res = "";
|
||||
|
||||
auto lambda_prompt = [](int token_id, const char *promptchars) {
|
||||
return true;
|
||||
};
|
||||
|
||||
mm=model;
|
||||
res="";
|
||||
|
||||
auto lambda_response = [](int token_id, const char *responsechars) {
|
||||
res.append((char*)responsechars);
|
||||
return !!getTokenCallback(mm, (char*)responsechars);
|
||||
};
|
||||
|
||||
auto lambda_recalculate = [](bool is_recalculating) {
|
||||
// You can handle recalculation requests here if needed
|
||||
return is_recalculating;
|
||||
};
|
||||
|
||||
llmodel_prompt_context* prompt_context = new llmodel_prompt_context{
|
||||
.logits = NULL,
|
||||
.logits_size = 0,
|
||||
.tokens = NULL,
|
||||
.tokens_size = 0,
|
||||
.n_past = 0,
|
||||
.n_ctx = 1024,
|
||||
.n_predict = 50,
|
||||
.top_k = 10,
|
||||
.top_p = 0.9,
|
||||
.temp = 1.0,
|
||||
.n_batch = 1,
|
||||
.repeat_penalty = 1.2,
|
||||
.repeat_last_n = 10,
|
||||
.context_erase = 0.5
|
||||
};
|
||||
|
||||
prompt_context->n_predict = tokens;
|
||||
prompt_context->repeat_last_n = repeat_last_n;
|
||||
prompt_context->repeat_penalty = repeat_penalty;
|
||||
prompt_context->n_ctx = n_ctx;
|
||||
prompt_context->top_k = top_k;
|
||||
prompt_context->context_erase = ctx_erase;
|
||||
prompt_context->top_p = top_p;
|
||||
prompt_context->temp = temp;
|
||||
prompt_context->n_batch = n_batch;
|
||||
|
||||
llmodel_prompt(model, prompt,
|
||||
lambda_prompt,
|
||||
lambda_response,
|
||||
lambda_recalculate,
|
||||
prompt_context );
|
||||
|
||||
strcpy(result, res.c_str());
|
||||
|
||||
free(prompt_context);
|
||||
}
|
||||
|
||||
void gptj_free_model(void *state_ptr) {
|
||||
llmodel_model* ctx = (llmodel_model*) state_ptr;
|
||||
llmodel_llama_destroy(ctx);
|
||||
}
|
||||
|
22
gpt4all-bindings/golang/binding.h
Normal file
22
gpt4all-bindings/golang/binding.h
Normal file
@ -0,0 +1,22 @@
|
||||
// C interface to the gpt4all backend, consumed by the cgo layer in gpt4all.go.
// FIX: added an include guard; the header previously had none.
#ifndef GPT4ALL_GOLANG_BINDING_H
#define GPT4ALL_GOLANG_BINDING_H

#ifdef __cplusplus
extern "C" {
#endif

#include <stdbool.h>

// Each loader returns an opaque model handle, or NULL on failure.
void* load_mpt_model(const char *fname, int n_threads);

void* load_llama_model(const char *fname, int n_threads);

void* load_gptj_model(const char *fname, int n_threads);

// Generates a completion for `prompt` on model handle `m`, writing the full
// response into the caller-allocated buffer `result`.
void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
                        float top_p, float temp, int n_batch,float ctx_erase);

// Frees a handle returned by one of the loaders.
void gptj_free_model(void *state_ptr);

// Implemented on the Go side (//export): receives one generated token and
// returns non-zero to continue generation.
extern unsigned char getTokenCallback(void *, char *);

#ifdef __cplusplus
}
#endif

#endif /* GPT4ALL_GOLANG_BINDING_H */
|
82
gpt4all-bindings/golang/example/main.go
Normal file
82
gpt4all-bindings/golang/example/main.go
Normal file
@ -0,0 +1,82 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
|
||||
)
|
||||
|
||||
var (
|
||||
threads = 4
|
||||
tokens = 128
|
||||
)
|
||||
|
||||
func main() {
|
||||
var model string
|
||||
|
||||
flags := flag.NewFlagSet(os.Args[0], flag.ExitOnError)
|
||||
flags.StringVar(&model, "m", "./models/7B/ggml-model-q4_0.bin", "path to q4_0.bin model file to load")
|
||||
flags.IntVar(&threads, "t", runtime.NumCPU(), "number of threads to use during computation")
|
||||
flags.IntVar(&tokens, "n", 512, "number of tokens to predict")
|
||||
|
||||
err := flags.Parse(os.Args[1:])
|
||||
if err != nil {
|
||||
fmt.Printf("Parsing program arguments failed: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
l, err := gpt4all.New(model, gpt4all.SetModelType(gpt4all.GPTJType), gpt4all.SetThreads(threads))
|
||||
if err != nil {
|
||||
fmt.Println("Loading the model failed:", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("Model loaded successfully.\n")
|
||||
|
||||
l.SetTokenCallback(func(token string) bool {
|
||||
fmt.Print(token)
|
||||
return true
|
||||
})
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
|
||||
for {
|
||||
text := readMultiLineInput(reader)
|
||||
|
||||
_, err := l.Predict(text, gpt4all.SetTokens(tokens), gpt4all.SetTopK(90), gpt4all.SetTopP(0.86))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
// readMultiLineInput reads input until an empty line is entered.
|
||||
func readMultiLineInput(reader *bufio.Reader) string {
|
||||
var lines []string
|
||||
fmt.Print(">>> ")
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
os.Exit(0)
|
||||
}
|
||||
fmt.Printf("Reading the prompt failed: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(strings.TrimSpace(line)) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
lines = append(lines, line)
|
||||
}
|
||||
|
||||
text := strings.Join(lines, "")
|
||||
return text
|
||||
}
|
20
gpt4all-bindings/golang/go.mod
Normal file
20
gpt4all-bindings/golang/go.mod
Normal file
@ -0,0 +1,20 @@
|
||||
module github.com/nomic-ai/gpt4all/gpt4all-bindings/golang
|
||||
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/onsi/ginkgo/v2 v2.9.4
|
||||
github.com/onsi/gomega v1.27.6
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/go-logr/logr v1.2.4 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
|
||||
golang.org/x/net v0.9.0 // indirect
|
||||
golang.org/x/sys v0.7.0 // indirect
|
||||
golang.org/x/text v0.9.0 // indirect
|
||||
golang.org/x/tools v0.8.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
40
gpt4all-bindings/golang/go.sum
Normal file
40
gpt4all-bindings/golang/go.sum
Normal file
@ -0,0 +1,40 @@
|
||||
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
|
||||
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
|
||||
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
|
||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
|
||||
github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE=
|
||||
github.com/onsi/ginkgo/v2 v2.9.4/go.mod h1:gCQYp2Q+kSoIj7ykSVb9nskRSsR6PUj4AiLywzIhbKM=
|
||||
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
|
||||
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
|
||||
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
||||
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
|
||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y=
|
||||
golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4=
|
||||
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
113
gpt4all-bindings/golang/gpt4all.go
Normal file
113
gpt4all-bindings/golang/gpt4all.go
Normal file
@ -0,0 +1,113 @@
|
||||
package gpt4all
|
||||
|
||||
// #cgo CFLAGS: -I../../gpt4all-backend/ -I../../gpt4all-backend/llama.cpp -I./
|
||||
// #cgo CXXFLAGS: -std=c++17 -I../../gpt4all-backend/ -I../../gpt4all-backend/llama.cpp -I./
|
||||
// #cgo darwin LDFLAGS: -framework Accelerate
|
||||
// #cgo darwin CXXFLAGS: -std=c++17
|
||||
// #cgo LDFLAGS: -lgpt4all -lm -lstdc++
|
||||
// void* load_mpt_model(const char *fname, int n_threads);
|
||||
// void* load_llama_model(const char *fname, int n_threads);
|
||||
// void* load_gptj_model(const char *fname, int n_threads);
|
||||
// void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
|
||||
// float top_p, float temp, int n_batch,float ctx_erase);
|
||||
// void gptj_free_model(void *state_ptr);
|
||||
// extern unsigned char getTokenCallback(void *, char *);
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// The following code is https://github.com/go-skynet/go-llama.cpp with small adaptations
|
||||
// Model is a thin Go wrapper around a native gpt4all model instance; state
// holds the opaque handle returned by the C loaders.
type Model struct {
	state unsafe.Pointer
}
|
||||
|
||||
func New(model string, opts ...ModelOption) (*Model, error) {
|
||||
ops := NewModelOptions(opts...)
|
||||
var state unsafe.Pointer
|
||||
|
||||
switch ops.ModelType {
|
||||
case LLaMAType:
|
||||
state = C.load_llama_model(C.CString(model), C.int(ops.Threads))
|
||||
case GPTJType:
|
||||
state = C.load_gptj_model(C.CString(model), C.int(ops.Threads))
|
||||
case MPTType:
|
||||
state = C.load_mpt_model(C.CString(model), C.int(ops.Threads))
|
||||
}
|
||||
|
||||
if state == nil {
|
||||
return nil, fmt.Errorf("failed loading model")
|
||||
}
|
||||
|
||||
gpt := &Model{state: state}
|
||||
// set a finalizer to remove any callbacks when the struct is reclaimed by the garbage collector.
|
||||
runtime.SetFinalizer(gpt, func(g *Model) {
|
||||
setTokenCallback(g.state, nil)
|
||||
})
|
||||
|
||||
return gpt, nil
|
||||
}
|
||||
|
||||
func (l *Model) Predict(text string, opts ...PredictOption) (string, error) {
|
||||
|
||||
po := NewPredictOptions(opts...)
|
||||
|
||||
input := C.CString(text)
|
||||
if po.Tokens == 0 {
|
||||
po.Tokens = 99999999
|
||||
}
|
||||
out := make([]byte, po.Tokens)
|
||||
|
||||
C.gptj_model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize),
|
||||
C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.int(po.Batch), C.float(po.ContextErase))
|
||||
|
||||
res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
|
||||
res = strings.TrimPrefix(res, " ")
|
||||
res = strings.TrimPrefix(res, text)
|
||||
res = strings.TrimPrefix(res, "\n")
|
||||
res = strings.TrimSuffix(res, "<|endoftext|>")
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (l *Model) Free() {
|
||||
C.gptj_free_model(l.state)
|
||||
}
|
||||
|
||||
func (l *Model) SetTokenCallback(callback func(token string) bool) {
|
||||
setTokenCallback(l.state, callback)
|
||||
}
|
||||
|
||||
var (
|
||||
m sync.Mutex
|
||||
callbacks = map[uintptr]func(string) bool{}
|
||||
)
|
||||
|
||||
//export getTokenCallback
|
||||
func getTokenCallback(statePtr unsafe.Pointer, token *C.char) bool {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
if callback, ok := callbacks[uintptr(statePtr)]; ok {
|
||||
return callback(C.GoString(token))
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// setCallback can be used to register a token callback for LLama. Pass in a nil callback to
|
||||
// remove the callback.
|
||||
func setTokenCallback(statePtr unsafe.Pointer, callback func(string) bool) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
if callback == nil {
|
||||
delete(callbacks, uintptr(statePtr))
|
||||
} else {
|
||||
callbacks[uintptr(statePtr)] = callback
|
||||
}
|
||||
}
|
13
gpt4all-bindings/golang/gpt4all_suite_test.go
Normal file
13
gpt4all-bindings/golang/gpt4all_suite_test.go
Normal file
@ -0,0 +1,13 @@
|
||||
package gpt4all_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestGPT(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "go-gpt4all-j test suite")
|
||||
}
|
27
gpt4all-bindings/golang/gpt4all_test.go
Normal file
27
gpt4all-bindings/golang/gpt4all_test.go
Normal file
@ -0,0 +1,27 @@
|
||||
package gpt4all_test
|
||||
|
||||
import (
|
||||
. "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("LLama binding", func() {
|
||||
Context("Declaration", func() {
|
||||
It("fails with no model", func() {
|
||||
model, err := New("not-existing")
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(model).To(BeNil())
|
||||
})
|
||||
It("fails with no model", func() {
|
||||
model, err := New("not-existing", SetModelType(MPTType))
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(model).To(BeNil())
|
||||
})
|
||||
It("fails with no model", func() {
|
||||
model, err := New("not-existing", SetModelType(LLaMAType))
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(model).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
127
gpt4all-bindings/golang/options.go
Normal file
127
gpt4all-bindings/golang/options.go
Normal file
@ -0,0 +1,127 @@
|
||||
package gpt4all
|
||||
|
||||
type PredictOptions struct {
|
||||
ContextSize, RepeatLastN, Tokens, TopK, Batch int
|
||||
TopP, Temperature, ContextErase, RepeatPenalty float64
|
||||
}
|
||||
|
||||
type PredictOption func(p *PredictOptions)
|
||||
|
||||
var DefaultOptions PredictOptions = PredictOptions{
|
||||
Tokens: 200,
|
||||
TopK: 10,
|
||||
TopP: 0.90,
|
||||
Temperature: 0.96,
|
||||
Batch: 1,
|
||||
ContextErase: 0.55,
|
||||
ContextSize: 1024,
|
||||
RepeatLastN: 10,
|
||||
RepeatPenalty: 1.2,
|
||||
}
|
||||
|
||||
var DefaultModelOptions ModelOptions = ModelOptions{
|
||||
Threads: 4,
|
||||
ModelType: GPTJType,
|
||||
}
|
||||
|
||||
type ModelOptions struct {
|
||||
Threads int
|
||||
ModelType ModelType
|
||||
}
|
||||
type ModelOption func(p *ModelOptions)
|
||||
|
||||
type ModelType int
|
||||
|
||||
const (
|
||||
LLaMAType ModelType = 0
|
||||
GPTJType ModelType = iota
|
||||
MPTType ModelType = iota
|
||||
)
|
||||
|
||||
// SetTokens sets the number of tokens to generate.
|
||||
func SetTokens(tokens int) PredictOption {
|
||||
return func(p *PredictOptions) {
|
||||
p.Tokens = tokens
|
||||
}
|
||||
}
|
||||
|
||||
// SetTopK sets the value for top-K sampling.
|
||||
func SetTopK(topk int) PredictOption {
|
||||
return func(p *PredictOptions) {
|
||||
p.TopK = topk
|
||||
}
|
||||
}
|
||||
|
||||
// SetTopP sets the value for nucleus sampling.
|
||||
func SetTopP(topp float64) PredictOption {
|
||||
return func(p *PredictOptions) {
|
||||
p.TopP = topp
|
||||
}
|
||||
}
|
||||
|
||||
// SetRepeatPenalty sets the repeat penalty.
|
||||
func SetRepeatPenalty(ce float64) PredictOption {
|
||||
return func(p *PredictOptions) {
|
||||
p.RepeatPenalty = ce
|
||||
}
|
||||
}
|
||||
|
||||
// SetRepeatLastN sets the RepeatLastN.
|
||||
func SetRepeatLastN(ce int) PredictOption {
|
||||
return func(p *PredictOptions) {
|
||||
p.RepeatLastN = ce
|
||||
}
|
||||
}
|
||||
|
||||
// SetContextErase sets the context erase %.
|
||||
func SetContextErase(ce float64) PredictOption {
|
||||
return func(p *PredictOptions) {
|
||||
p.ContextErase = ce
|
||||
}
|
||||
}
|
||||
|
||||
// SetTemperature sets the temperature value for text generation.
|
||||
func SetTemperature(temp float64) PredictOption {
|
||||
return func(p *PredictOptions) {
|
||||
p.Temperature = temp
|
||||
}
|
||||
}
|
||||
|
||||
// SetBatch sets the batch size.
|
||||
func SetBatch(size int) PredictOption {
|
||||
return func(p *PredictOptions) {
|
||||
p.Batch = size
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new PredictOptions object with the given options.
|
||||
func NewPredictOptions(opts ...PredictOption) PredictOptions {
|
||||
p := DefaultOptions
|
||||
for _, opt := range opts {
|
||||
opt(&p)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// SetThreads sets the number of threads to use for text generation.
|
||||
func SetThreads(c int) ModelOption {
|
||||
return func(p *ModelOptions) {
|
||||
p.Threads = c
|
||||
}
|
||||
}
|
||||
|
||||
// SetModelType sets the model type.
|
||||
func SetModelType(c ModelType) ModelOption {
|
||||
return func(p *ModelOptions) {
|
||||
p.ModelType = c
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new PredictOptions object with the given options.
|
||||
func NewModelOptions(opts ...ModelOption) ModelOptions {
|
||||
p := DefaultModelOptions
|
||||
for _, opt := range opts {
|
||||
opt(&p)
|
||||
}
|
||||
return p
|
||||
}
|
Loading…
Reference in New Issue
Block a user