Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2024-10-01 01:06:10 -04:00)

Commit 8c84c24ee9 (parent f8fdcccc5d): transfer python bindings code
gpt4all-bindings/python/.gitignore (vendored, new file, 164 lines)
@@ -0,0 +1,164 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Cython
/*.c
*DO_NOT_MODIFY/
gpt4all-bindings/python/LICENSE.txt (new file, 19 lines)
@@ -0,0 +1,19 @@
Copyright (c) 2023 Nomic, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
gpt4all-bindings/python/MANIFEST.in (new file, 1 line)
@@ -0,0 +1 @@
recursive-include gpt4all/llmodel_DO_NOT_MODIFY *
gpt4all-bindings/python/README.md (new file, 41 lines)
@@ -0,0 +1,41 @@
# Python GPT4All

This package contains a set of Python bindings that wrap the `llmodel` C-API.

# Local Installation Instructions

TODO: The instructions in the main README currently depend on a Qt6 setup. For the Python bindings, only `llmodel` needs to be built, which is much simpler. In the future, the installation instructions below should be organized so that they assume the main README's instructions have already been followed.

1. Setup `llmodel`

```
git clone --recurse-submodules https://github.com/nomic-ai/gpt4all-chat
cd gpt4all-chat/llmodel/
mkdir build
cd build
cmake ..
cmake --build . --parallel
```
Confirm that `libllmodel.dylib` (`.so` on Linux, `.dll` on Windows) exists in `gpt4all-chat/llmodel/build`.

2. Setup Python package

```
cd ../../bindings/python
pip3 install -r requirements.txt
pip3 install -e .
```

3. Test it out! In a Python script or console:

```python
from gpt4all import GPT4All

gptj = GPT4All("ggml-gpt4all-j-v1.3-groovy")
messages = [{"role": "user", "content": "Name 3 colors"}]
gptj.chat_completion(messages)
```
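If you just want a raw completion without the chat prompt formatting, the same class also exposes a lower-level `generate` call that forwards keyword arguments to the underlying model. A minimal sketch, assuming the model file above has already been downloaded:

```python
from gpt4all import GPT4All

gptj = GPT4All("ggml-gpt4all-j-v1.3-groovy")
# keyword arguments such as n_predict and temp are passed straight through
# to the C prompt context used by the bindings
print(gptj.generate("Name 3 colors", n_predict=64, temp=0.1))
```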
BIN  gpt4all-bindings/python/docs/assets/favicon.ico (new binary file, 15 KiB, not shown)
BIN  gpt4all-bindings/python/docs/assets/nomic.png (new binary file, 25 KiB, not shown)
gpt4all-bindings/python/docs/css/custom.css (new file, 5 lines)
@@ -0,0 +1,5 @@
/* Remove the `In` and `Out` block in rendered Jupyter notebooks */
.md-container .jp-Cell-outputWrapper .jp-OutputPrompt.jp-OutputArea-prompt,
.md-container .jp-Cell-inputWrapper .jp-InputPrompt.jp-InputArea-prompt {
    display: none !important;
}
gpt4all-bindings/python/docs/gpt4all_api.md (new file, 6 lines)
@@ -0,0 +1,6 @@
# GPT4All API
The `GPT4All` class provides a universal API to call all GPT4All models and
introduces additional helpful functionality such as downloading models.

::: gpt4all.gpt4all.GPT4All
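The class also exposes the model registry used for downloads. A minimal sketch of listing the available model files; each entry in https://gpt4all.io/models/models.json is expected to carry at least a `"filename"` key, since that is what the download check matches against:

```python
from gpt4all import GPT4All

for entry in GPT4All.list_models():
    print(entry["filename"])
```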
gpt4all-bindings/python/docs/index.md (new file, 22 lines)
@@ -0,0 +1,22 @@
# GPT4All

In this package, we introduce Python bindings built around GPT4All's C/C++ ecosystem.

## Quickstart

```bash
pip install gpt4all
```

In Python, run the following commands to retrieve a GPT4All model and generate a response
to a prompt.

**Download Note:**
By default, models are stored in `~/.cache/gpt4all/` (you can change this with `model_path`). If the file already exists, model download will be skipped.

```python
import gpt4all
gptj = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
messages = [{"role": "user", "content": "Name 3 colors"}]
gptj.chat_completion(messages)
```
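To keep models somewhere other than `~/.cache/gpt4all/`, point `model_path` at an existing directory. A minimal sketch (`/data/models` is a hypothetical path; the directory must already exist, otherwise the bindings raise a `ValueError`):

```python
import gpt4all

# model is downloaded into (or loaded from) the given directory
gptj = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy", model_path="/data/models")
response = gptj.chat_completion([{"role": "user", "content": "Name 3 colors"}], verbose=False)
print(response["choices"][0]["message"]["content"])
```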
gpt4all-bindings/python/gpt4all/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .pyllmodel import LLModel # noqa
from .gpt4all import GPT4All # noqa
gpt4all-bindings/python/gpt4all/gpt4all.py (new file, 280 lines)
@@ -0,0 +1,280 @@
"""
Python only API for running all GPT4All models.
"""
import json
import os
from pathlib import Path
from typing import Dict, List

import requests
from tqdm import tqdm

from . import pyllmodel

# TODO: move to config
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")


class GPT4All():
    """Python API for retrieving and interacting with GPT4All models

    Attributes:
        model: Pointer to underlying C model.
    """
    def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download=True):
        """
        Constructor

        Args:
            model_name: Name of GPT4All or custom model. Including ".bin" file extension is optional but encouraged.
            model_path: Path to directory containing model file or, if file does not exist, where to download model.
                Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
            model_type: Model architecture to use - currently, only options are 'llama' or 'gptj'. Only required if model
                is custom. Note that these models still must be built from llama.cpp or GPTJ ggml architecture.
                Default is None.
            allow_download: Allow API to download models from gpt4all.io. Default is True.
        """
        self.model = None

        # Model type provided for when model is custom
        if model_type:
            self.model = GPT4All.get_model_from_type(model_type)
        # Else get model from gpt4all model filenames
        else:
            self.model = GPT4All.get_model_from_name(model_name)

        # Retrieve model and download if allowed
        model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
        self.model.load_model(model_dest)

    @staticmethod
    def list_models():
        """
        Fetch model list from https://gpt4all.io/models/models.json

        Returns:
            Model list in JSON format.
        """
        response = requests.get("https://gpt4all.io/models/models.json")
        model_json = json.loads(response.content)
        return model_json
    @staticmethod
    def retrieve_model(model_name: str, model_path: str = None, allow_download=True):
        """
        Find model file, and if it doesn't exist, download the model.

        Args:
            model_name: Name of model.
            model_path: Path to find model. Default is None in which case path is set to
                ~/.cache/gpt4all/.
            allow_download: Allow API to download model from gpt4all.io. Default is True.

        Returns:
            Model file destination.
        """
        model_filename = model_name
        if ".bin" not in model_filename:
            model_filename += ".bin"

        # Validate download directory
        if model_path is None:
            model_path = DEFAULT_MODEL_DIRECTORY
            if not os.path.exists(DEFAULT_MODEL_DIRECTORY):
                try:
                    os.makedirs(DEFAULT_MODEL_DIRECTORY)
                except OSError:
                    raise ValueError("Failed to create model download directory at ~/.cache/gpt4all/. \
                    Please specify model_path.")
        else:
            model_path = model_path.replace("\\", "\\\\")

        if os.path.exists(model_path):
            model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\")
            if os.path.exists(model_dest):
                print("Found model file.")
                return model_dest

            # If model file does not exist, download
            elif allow_download:
                # Make sure valid model filename before attempting download
                model_match = False
                for item in GPT4All.list_models():
                    if model_filename == item["filename"]:
                        model_match = True
                        break
                if not model_match:
                    raise ValueError(f"Model filename not in model list: {model_filename}")
                return GPT4All.download_model(model_filename, model_path)
            else:
                raise ValueError("Failed to retrieve model")
        else:
            raise ValueError("Invalid model directory")
    @staticmethod
    def download_model(model_filename, model_path):
        def get_download_url(model_filename):
            return f"https://gpt4all.io/models/{model_filename}"

        # Download model
        download_path = os.path.join(model_path, model_filename).replace("\\", "\\\\")
        download_url = get_download_url(model_filename)

        response = requests.get(download_url, stream=True)
        total_size_in_bytes = int(response.headers.get("content-length", 0))
        block_size = 1048576  # 1 MB
        progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
        with open(download_path, "wb") as file:
            for data in response.iter_content(block_size):
                progress_bar.update(len(data))
                file.write(data)
        progress_bar.close()

        # Validate download was successful
        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
            raise RuntimeError(
                "An error occurred during download. Downloaded file may not work."
            )

        print("Model downloaded at: " + download_path)
        return download_path

    def generate(self, prompt: str, **generate_kwargs):
        """
        Surfaced method of running generate without accessing model object.
        """
        return self.model.generate(prompt, **generate_kwargs)
    def chat_completion(self,
                        messages: List[Dict],
                        default_prompt_header: bool = True,
                        default_prompt_footer: bool = True,
                        verbose: bool = True) -> dict:
        """
        Format list of message dictionaries into a prompt and call model
        generate on prompt. Returns a response dictionary with metadata and
        generated content.

        Args:
            messages: Each dictionary should have a "role" key
                with value of "system", "assistant", or "user" and a "content" key with a
                string value. Messages are organized such that "system" messages are at top of prompt,
                and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
                "Response: {content}".
            default_prompt_header: If True (default), add default prompt header after any user specified system messages and
                before user/assistant messages.
            default_prompt_footer: If True (default), add default footer at end of prompt.
            verbose: If True (default), print full prompt and generated response.

        Returns:
            Response dictionary with:
                "model": name of model.
                "usage": a dictionary with number of full prompt tokens, number of
                    generated tokens in response, and total tokens.
                "choices": List of message dictionary where "content" is generated response and "role" is set
                    as "assistant". Right now, only one choice is returned by model.
        """

        full_prompt = self._build_prompt(messages,
                                         default_prompt_header=default_prompt_header,
                                         default_prompt_footer=default_prompt_footer)

        if verbose:
            print(full_prompt)

        response = self.model.generate(full_prompt)

        if verbose:
            print(response)

        response_dict = {
            "model": self.model.model_name,
            "usage": {"prompt_tokens": len(full_prompt),
                      "completion_tokens": len(response),
                      "total_tokens": len(full_prompt) + len(response)},
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": response
                    }
                }
            ]
        }

        return response_dict
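    # Illustrative only (not part of the original file): for
    # messages = [{"role": "user", "content": "Name 3 colors"}], the dictionary
    # returned above has the shape
    #
    #   {
    #       "model": "<model name>",
    #       "usage": {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...},
    #       "choices": [{"message": {"role": "assistant", "content": "<generated text>"}}],
    #   }
    #
    # so the generated text lives at response["choices"][0]["message"]["content"].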
    @staticmethod
    def _build_prompt(messages: List[Dict],
                      default_prompt_header=True,
                      default_prompt_footer=False) -> str:
        full_prompt = ""

        for message in messages:
            if message["role"] == "system":
                system_message = message["content"] + "\n"
                full_prompt += system_message

        if default_prompt_header:
            full_prompt += """### Instruction:
The prompt below is a question to answer, a task to complete, or a conversation
to respond to; decide which and write an appropriate response.
\n### Prompt: """

        for message in messages:
            if message["role"] == "user":
                user_message = "\n" + message["content"]
                full_prompt += user_message
            if message["role"] == "assistant":
                assistant_message = "\n### Response: " + message["content"]
                full_prompt += assistant_message

        if default_prompt_footer:
            full_prompt += "\n### Response:"

        return full_prompt
    @staticmethod
    def get_model_from_type(model_type: str) -> pyllmodel.LLModel:
        # This needs to be updated for each new model
        # TODO: Might be worth converting model_type to enum

        if model_type == "gptj":
            return pyllmodel.GPTJModel()
        elif model_type == "llama":
            return pyllmodel.LlamaModel()
        else:
            raise ValueError(f"No corresponding model for model_type: {model_type}")

    @staticmethod
    def get_model_from_name(model_name: str) -> pyllmodel.LLModel:
        # This needs to be updated for each new model

        # NOTE: We are doing this preprocessing a lot, maybe there's a better way to organize
        if ".bin" not in model_name:
            model_name += ".bin"

        GPTJ_MODELS = [
            "ggml-gpt4all-j-v1.3-groovy.bin",
            "ggml-gpt4all-j-v1.2-jazzy.bin",
            "ggml-gpt4all-j-v1.1-breezy.bin",
            "ggml-gpt4all-j.bin"
        ]

        LLAMA_MODELS = [
            "ggml-gpt4all-l13b-snoozy.bin",
            "ggml-vicuna-7b-1.1-q4_2.bin",
            "ggml-vicuna-13b-1.1-q4_2.bin",
            "ggml-wizardLM-7B.q4_2.bin",
            "ggml-stable-vicuna-13B.q4_2.bin"
        ]

        if model_name in GPTJ_MODELS:
            return pyllmodel.GPTJModel()
        elif model_name in LLAMA_MODELS:
            return pyllmodel.LlamaModel()
        else:
            err_msg = f"""No corresponding model for provided filename {model_name}.
            If this is a custom model, make sure to specify a valid model_type.
            """
            raise ValueError(err_msg)
gpt4all-bindings/python/gpt4all/pyllmodel.py (new file, 241 lines)
@@ -0,0 +1,241 @@
from io import StringIO
import pkg_resources
import ctypes
import os
import platform
import re
import sys

# TODO: provide a config file to make this more robust
LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build")


def load_llmodel_library():
    system = platform.system()

    def get_c_shared_lib_extension():
        if system == "Darwin":
            return "dylib"
        elif system == "Linux":
            return "so"
        elif system == "Windows":
            return "dll"
        else:
            raise Exception("Operating System not supported")

    c_lib_ext = get_c_shared_lib_extension()

    llmodel_file = "libllmodel" + '.' + c_lib_ext
    llama_file = "libllama" + '.' + c_lib_ext
    llama_dir = str(pkg_resources.resource_filename('gpt4all', os.path.join(LLMODEL_PATH, llama_file)))
    llmodel_dir = str(pkg_resources.resource_filename('gpt4all', os.path.join(LLMODEL_PATH, llmodel_file)))

    # For windows
    llama_dir = llama_dir.replace("\\", "\\\\")
    print(llama_dir)
    llmodel_dir = llmodel_dir.replace("\\", "\\\\")
    print(llmodel_dir)

    llama_lib = ctypes.CDLL(llama_dir, mode=ctypes.RTLD_GLOBAL)
    llmodel_lib = ctypes.CDLL(llmodel_dir)

    return llmodel_lib, llama_lib


llmodel, llama = load_llmodel_library()
# Define C function signatures using ctypes
llmodel.llmodel_gptj_create.restype = ctypes.c_void_p
llmodel.llmodel_gptj_destroy.argtypes = [ctypes.c_void_p]
llmodel.llmodel_llama_create.restype = ctypes.c_void_p
llmodel.llmodel_llama_destroy.argtypes = [ctypes.c_void_p]

llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
llmodel.llmodel_loadModel.restype = ctypes.c_bool
llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool


class LLModelPromptContext(ctypes.Structure):
    _fields_ = [("logits", ctypes.POINTER(ctypes.c_float)),
                ("logits_size", ctypes.c_size_t),
                ("tokens", ctypes.POINTER(ctypes.c_int32)),
                ("tokens_size", ctypes.c_size_t),
                ("n_past", ctypes.c_int32),
                ("n_ctx", ctypes.c_int32),
                ("n_predict", ctypes.c_int32),
                ("top_k", ctypes.c_int32),
                ("top_p", ctypes.c_float),
                ("temp", ctypes.c_float),
                ("n_batch", ctypes.c_int32),
                ("repeat_penalty", ctypes.c_float),
                ("repeat_last_n", ctypes.c_int32),
                ("context_erase", ctypes.c_float)]


ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)

llmodel.llmodel_prompt.argtypes = [ctypes.c_void_p,
                                   ctypes.c_char_p,
                                   ResponseCallback,
                                   ResponseCallback,
                                   RecalculateCallback,
                                   ctypes.POINTER(LLModelPromptContext)]
class LLModel:
    """
    Base class and universal wrapper for GPT4All language models
    built around llmodel C-API.

    Attributes
    ----------
    model : llmodel_model
        Ctype pointer to underlying model
    model_type : str
        Model architecture identifier
    """

    model_type: str = None

    def __init__(self):
        self.model = None
        self.model_name = None

    def __del__(self):
        pass

    def load_model(self, model_path: str) -> bool:
        """
        Load model from a file.

        Parameters
        ----------
        model_path : str
            Model filepath

        Returns
        -------
        True if model loaded successfully, False otherwise
        """
        llmodel.llmodel_loadModel(self.model, model_path.encode('utf-8'))
        filename = os.path.basename(model_path)
        self.model_name = os.path.splitext(filename)[0]

        if llmodel.llmodel_isModelLoaded(self.model):
            return True
        else:
            return False
    def generate(self,
                 prompt: str,
                 logits_size: int = 0,
                 tokens_size: int = 0,
                 n_past: int = 0,
                 n_ctx: int = 1024,
                 n_predict: int = 128,
                 top_k: int = 40,
                 top_p: float = .9,
                 temp: float = .1,
                 n_batch: int = 8,
                 repeat_penalty: float = 1.2,
                 repeat_last_n: int = 10,
                 context_erase: float = .5) -> str:
        """
        Generate response from model from a prompt.

        Parameters
        ----------
        prompt: str
            Question, task, or conversation for model to respond to
        n_predict: int, optional
            Maximum number of tokens to generate (default is 128)
        top_k: int, optional
            Sample from the top_k most likely tokens (default is 40)
        top_p: float, optional
            Nucleus sampling threshold (default is .9)
        temp: float, optional
            Sampling temperature (default is .1)
        The remaining arguments (logits_size, tokens_size, n_past, n_ctx, n_batch,
        repeat_penalty, repeat_last_n, context_erase) are forwarded unchanged into
        the LLModelPromptContext passed to the C prompt call.

        Returns
        -------
        Model response str
        """

        prompt = prompt.encode('utf-8')
        prompt = ctypes.c_char_p(prompt)

        # Change stdout to StringIO so we can collect response
        old_stdout = sys.stdout
        collect_response = StringIO()
        sys.stdout = collect_response

        context = LLModelPromptContext(
            logits_size=logits_size,
            tokens_size=tokens_size,
            n_past=n_past,
            n_ctx=n_ctx,
            n_predict=n_predict,
            top_k=top_k,
            top_p=top_p,
            temp=temp,
            n_batch=n_batch,
            repeat_penalty=repeat_penalty,
            repeat_last_n=repeat_last_n,
            context_erase=context_erase
        )

        llmodel.llmodel_prompt(self.model,
                               prompt,
                               ResponseCallback(self._prompt_callback),
                               ResponseCallback(self._response_callback),
                               RecalculateCallback(self._recalculate_callback),
                               context)

        response = collect_response.getvalue()
        sys.stdout = old_stdout

        # Remove the unnecessary new lines from response
        response = re.sub(r"\n(?!\n)", "", response).strip()

        return response
    # Empty prompt callback
    @staticmethod
    def _prompt_callback(token_id, response):
        return True

    # Empty response callback method that just prints response to be collected
    @staticmethod
    def _response_callback(token_id, response):
        print(response.decode('utf-8'))
        return True

    # Empty recalculate callback
    @staticmethod
    def _recalculate_callback(is_recalculating):
        return is_recalculating


class GPTJModel(LLModel):

    model_type = "gptj"

    def __init__(self):
        super().__init__()
        self.model = llmodel.llmodel_gptj_create()

    def __del__(self):
        if self.model is not None:
            llmodel.llmodel_gptj_destroy(self.model)
        super().__del__()


class LlamaModel(LLModel):

    model_type = "llama"

    def __init__(self):
        super().__init__()
        self.model = llmodel.llmodel_llama_create()

    def __del__(self):
        if self.model is not None:
            llmodel.llmodel_llama_destroy(self.model)
        super().__del__()
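# Illustrative usage (not part of the original module): the ctypes wrappers can be
# driven directly, assuming a local .bin model file built for the matching
# architecture ("/path/to/ggml-gpt4all-j-v1.3-groovy.bin" is a hypothetical path):
#
#   model = GPTJModel()
#   model.load_model("/path/to/ggml-gpt4all-j-v1.3-groovy.bin")
#   print(model.generate("Name 3 colors", n_predict=64))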
gpt4all-bindings/python/makefile (new file, 16 lines)
@@ -0,0 +1,16 @@
SHELL:=/bin/bash -o pipefail
ROOT_DIR:=$(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
PYTHON:=python3

venv:
	if [ ! -d $(ROOT_DIR)/env ]; then $(PYTHON) -m venv $(ROOT_DIR)/env; fi

documentation:
	rm -rf ./site && mkdocs build

wheel:
	rm -rf dist/ build/ gpt4all/llmodel_DO_NOT_MODIFY; python setup.py bdist_wheel;

clean:
	rm -rf {.pytest_cache,env,gpt4all.egg-info}
	find . | grep -E "(__pycache__|\.pyc|\.pyo$\)" | xargs rm -rf
gpt4all-bindings/python/mkdocs.yml (new file, 76 lines)
@@ -0,0 +1,76 @@
site_name: GPT4All Python Documentation
repo_url: https://github.com/nomic-ai/gpt4all
repo_name: nomic-ai/gpt4all
site_url: https://docs.nomic.ai # TODO: change
edit_uri: edit/main/docs/
site_description: Python bindings for GPT4All
copyright: Copyright © 2023 Nomic, Inc
use_directory_urls: false

nav:
  - 'index.md'
  - 'API Reference':
    - 'gpt4all_api.md'

theme:
  name: material
  palette:
    primary: white
  logo: assets/nomic.png
  favicon: assets/favicon.ico
  features:
    - navigation.instant
    - navigation.tracking
    - navigation.sections
    # - navigation.tabs
    # - navigation.tabs.sticky

markdown_extensions:
  - pymdownx.highlight:
      anchor_linenums: true
  - pymdownx.inlinehilite
  - pymdownx.snippets
  - pymdownx.details
  - pymdownx.superfences
  - pymdownx.tabbed:
      alternate_style: true
  - pymdownx.emoji:
      emoji_index: !!python/name:materialx.emoji.twemoji
      emoji_generator: !!python/name:materialx.emoji.to_svg
      options:
        custom_icons:
          - docs/overrides/.icons
  - tables
  - admonition
  - codehilite:
      css_class: highlight

extra_css:
  - css/custom.css

plugins:
  - mkdocstrings:
      handlers:
        python:
          options:
            show_root_heading: True
            heading_level: 4
            show_root_full_path: false
            docstring_section_style: list
  #- material/social:
  #    cards_font: Roboto

  #- mkdocs-jupyter:
  #    ignore_h1_titles: True
  #    show_input: True

extra:
  generator: false
  analytics:
    provider: google
    property: G-NPXC8BYHJV
  #social:
  #  - icon: fontawesome/brands/twitter
  #    link: https://twitter.com/nomic_ai
  #  - icon: material/fruit-pineapple
  #    link: https://www.youtube.com/watch?v=628eVJgHD6I
gpt4all-bindings/python/setup.py (new file, 89 lines)
@@ -0,0 +1,89 @@
from setuptools import setup, find_packages
import os
import platform
import shutil

package_name = "gpt4all"

# Define the location of your prebuilt C library files
SRC_CLIB_DIRECTORY = os.path.join("..", "..", "llmodel")
SRC_CLIB_BUILD_DIRECTORY = os.path.join("..", "..", "llmodel", "build")

LIB_NAME = "llmodel"

DEST_CLIB_DIRECTORY = os.path.join(package_name, f"{LIB_NAME}_DO_NOT_MODIFY")
DEST_CLIB_BUILD_DIRECTORY = os.path.join(DEST_CLIB_DIRECTORY, "build")

system = platform.system()


def get_c_shared_lib_extension():

    if system == "Darwin":
        return "dylib"
    elif system == "Linux":
        return "so"
    elif system == "Windows":
        return "dll"
    else:
        raise Exception("Operating System not supported")


lib_ext = get_c_shared_lib_extension()


def copy_prebuilt_C_lib(src_dir, dest_dir, dest_build_dir):
    files_copied = 0

    if not os.path.exists(dest_dir):
        os.mkdir(dest_dir)
        os.mkdir(dest_build_dir)

    for dirpath, _, filenames in os.walk(src_dir):
        for item in filenames:
            # copy over header files to dest dir
            s = os.path.join(dirpath, item)
            if item.endswith(".h"):
                d = os.path.join(dest_dir, item)
                shutil.copy2(s, d)
                files_copied += 1
            if item.endswith(lib_ext):
                s = os.path.join(dirpath, item)
                d = os.path.join(dest_build_dir, item)
                shutil.copy2(s, d)
                files_copied += 1

    return files_copied


# NOTE: You must provide correct path to the prebuilt llmodel C library.
# Specifically, the llmodel.h and C shared library are needed.
copy_prebuilt_C_lib(SRC_CLIB_DIRECTORY,
                    DEST_CLIB_DIRECTORY,
                    DEST_CLIB_BUILD_DIRECTORY)

setup(
    name=package_name,
    version="0.1.9",
    description="Python bindings for GPT4All",
    author="Richard Guo",
    author_email="richard@nomic.ai",
    url="https://pypi.org/project/gpt4all/",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires='>=3.8',
    packages=find_packages(),
    install_requires=['requests', 'tqdm'],
    extras_require={
        'dev': [
            'pytest',
            'twine',
            'mkdocs-material',
            'mkautodoc',
            'mkdocstrings[python]',
            'mkdocs-jupyter'
        ]
    },
    package_data={'llmodel': [os.path.join(DEST_CLIB_DIRECTORY, "*")]},
    include_package_data=True
)
gpt4all-bindings/python/tests/test_gpt4all.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import pytest

from gpt4all.gpt4all import GPT4All


def test_invalid_model_type():
    model_type = "bad_type"
    with pytest.raises(ValueError):
        GPT4All.get_model_from_type(model_type)


def test_valid_model_type():
    model_type = "gptj"
    assert GPT4All.get_model_from_type(model_type).model_type == model_type


def test_invalid_model_name():
    model_name = "bad_filename.bin"
    with pytest.raises(ValueError):
        GPT4All.get_model_from_name(model_name)


def test_valid_model_name():
    model_name = "ggml-gpt4all-l13b-snoozy"
    model_type = "llama"
    assert GPT4All.get_model_from_name(model_name).model_type == model_type
    model_name += ".bin"
    assert GPT4All.get_model_from_name(model_name).model_type == model_type


def test_build_prompt():
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": "Hello there."
        },
        {
            "role": "assistant",
            "content": "Hi, how can I help you?"
        },
        {
            "role": "user",
            "content": "Reverse a list in Python."
        }
    ]

    expected_prompt = """You are a helpful assistant.\
\n### Instruction:
The prompt below is a question to answer, a task to complete, or a conversation
to respond to; decide which and write an appropriate response.\
### Prompt:\
Hello there.\
Response: Hi, how can I help you?\
Reverse a list in Python.\
### Response:"""

    print(expected_prompt)

    full_prompt = GPT4All._build_prompt(messages, default_prompt_footer=True, default_prompt_header=True)

    print("\n\n\n")
    print(full_prompt)
    assert len(full_prompt) == len(expected_prompt)
gpt4all-bindings/python/tests/test_pyllmodel.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from io import StringIO
import sys

from gpt4all import pyllmodel

# TODO: Integration test for loadmodel and prompt.
# Right now, too slow b/c it requires file download.


def test_create_gptj():
    gptj = pyllmodel.GPTJModel()
    assert gptj.model_type == "gptj"


def test_create_llama():
    llama = pyllmodel.LlamaModel()
    assert llama.model_type == "llama"


def prompt_unloaded_gptj():
    gptj = pyllmodel.GPTJModel()
    old_stdout = sys.stdout
    collect_response = StringIO()
    sys.stdout = collect_response

    gptj.prompt("hello there")

    response = collect_response.getvalue()
    sys.stdout = old_stdout

    response = response.strip()
    assert response == "GPT-J ERROR: prompt won't work with an unloaded model!"


def prompt_unloaded_llama():
    llama = pyllmodel.LlamaModel()
    old_stdout = sys.stdout
    collect_response = StringIO()
    sys.stdout = collect_response

    llama.prompt("hello there")

    response = collect_response.getvalue()
    sys.stdout = old_stdout

    response = response.strip()
    assert response == "LLAMA ERROR: prompt won't work with an unloaded model!"