Merge pull request #4488 from oobabooga/dev

Merge dev branch
This commit is contained in:
oobabooga 2023-11-06 12:18:55 -03:00 committed by GitHub
commit 1fba6db69f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 862 additions and 1598 deletions

View File

@ -22,7 +22,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab#character)
* Very efficient text streaming
* Markdown output with LaTeX rendering, to use for instance with [GALACTICA](https://github.com/paperswithcode/galai)
* API, including endpoints for websocket streaming ([see the examples](https://github.com/oobabooga/text-generation-webui/blob/main/api-examples))
* OpenAI-compatible API server
## Documentation
@ -412,8 +412,8 @@ Optionally, you can use the following command-line flags:
| `--api` | Enable the API extension. |
| `--public-api` | Create a public URL for the API using Cloudfare. |
| `--public-api-id PUBLIC_API_ID` | Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. |
| `--api-blocking-port BLOCKING_PORT` | The listening port for the blocking API. |
| `--api-streaming-port STREAMING_PORT` | The listening port for the streaming API. |
| `--api-port API_PORT` | The listening port for the API. |
| `--api-key API_KEY` | API authentication key. |
#### Multimodal

View File

@ -1,114 +0,0 @@
import asyncio
import html
import json
import sys
try:
import websockets
except ImportError:
print("Websockets package not found. Make sure it's installed.")
# For local streaming, the websockets are hosted without ssl - ws://
HOST = 'localhost:5005'
URI = f'ws://{HOST}/api/v1/chat-stream'
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
async def run(user_input, history):
# Note: the selected defaults change from time to time.
request = {
'user_input': user_input,
'max_new_tokens': 250,
'auto_max_new_tokens': False,
'max_tokens_second': 0,
'history': history,
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
'character': 'Example',
'instruction_template': 'Vicuna-v1.1', # Will get autodetected if unset
'your_name': 'You',
# 'name1': 'name of user', # Optional
# 'name2': 'name of character', # Optional
# 'context': 'character context', # Optional
# 'greeting': 'greeting', # Optional
# 'name1_instruct': 'You', # Optional
# 'name2_instruct': 'Assistant', # Optional
# 'context_instruct': 'context_instruct', # Optional
# 'turn_template': 'turn_template', # Optional
'regenerate': False,
'_continue': False,
'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True,
'temperature': 0.7,
'top_p': 0.1,
'typical_p': 1,
'epsilon_cutoff': 0, # In units of 1e-4
'eta_cutoff': 0, # In units of 1e-4
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1.18,
'presence_penalty': 0,
'frequency_penalty': 0,
'repetition_penalty_range': 0,
'top_k': 40,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0,
'length_penalty': 1,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'grammar_string': '',
'guidance_scale': 1,
'negative_prompt': '',
'seed': -1,
'add_bos_token': True,
'truncation_length': 2048,
'ban_eos_token': False,
'custom_token_bans': '',
'skip_special_tokens': True,
'stopping_strings': []
}
async with websockets.connect(URI, ping_interval=None) as websocket:
await websocket.send(json.dumps(request))
while True:
incoming_data = await websocket.recv()
incoming_data = json.loads(incoming_data)
match incoming_data['event']:
case 'text_stream':
yield incoming_data['history']
case 'stream_end':
return
async def print_response_stream(user_input, history):
cur_len = 0
async for new_history in run(user_input, history):
cur_message = new_history['visible'][-1][1][cur_len:]
cur_len += len(cur_message)
print(html.unescape(cur_message), end='')
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
if __name__ == '__main__':
user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."
# Basic example
history = {'internal': [], 'visible': []}
# "Continue" example. Make sure to set '_continue' to True above
# arr = [user_input, 'Surely, here is']
# history = {'internal': [arr], 'visible': [arr]}
asyncio.run(print_response_stream(user_input, history))

View File

@ -1,94 +0,0 @@
import html
import json
import requests
# For local streaming, the websockets are hosted without ssl - http://
HOST = 'localhost:5000'
URI = f'http://{HOST}/api/v1/chat'
# For reverse-proxied streaming, the remote will likely host with ssl - https://
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/chat'
def run(user_input, history):
request = {
'user_input': user_input,
'max_new_tokens': 250,
'auto_max_new_tokens': False,
'max_tokens_second': 0,
'history': history,
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
'character': 'Example',
'instruction_template': 'Vicuna-v1.1', # Will get autodetected if unset
'your_name': 'You',
# 'name1': 'name of user', # Optional
# 'name2': 'name of character', # Optional
# 'context': 'character context', # Optional
# 'greeting': 'greeting', # Optional
# 'name1_instruct': 'You', # Optional
# 'name2_instruct': 'Assistant', # Optional
# 'context_instruct': 'context_instruct', # Optional
# 'turn_template': 'turn_template', # Optional
'regenerate': False,
'_continue': False,
'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True,
'temperature': 0.7,
'top_p': 0.1,
'typical_p': 1,
'epsilon_cutoff': 0, # In units of 1e-4
'eta_cutoff': 0, # In units of 1e-4
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1.18,
'presence_penalty': 0,
'frequency_penalty': 0,
'repetition_penalty_range': 0,
'top_k': 40,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0,
'length_penalty': 1,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'grammar_string': '',
'guidance_scale': 1,
'negative_prompt': '',
'seed': -1,
'add_bos_token': True,
'truncation_length': 2048,
'ban_eos_token': False,
'custom_token_bans': '',
'skip_special_tokens': True,
'stopping_strings': []
}
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()['results'][0]['history']
print(json.dumps(result, indent=4))
print()
print(html.unescape(result['visible'][-1][1]))
if __name__ == '__main__':
user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."
# Basic example
history = {'internal': [], 'visible': []}
# "Continue" example. Make sure to set '_continue' to True above
# arr = [user_input, 'Surely, here is']
# history = {'internal': [arr], 'visible': [arr]}
run(user_input, history)

View File

@ -1,176 +0,0 @@
#!/usr/bin/env python3
import requests
HOST = '0.0.0.0:5000'
def generate(prompt, tokens=200):
request = {'prompt': prompt, 'max_new_tokens': tokens}
response = requests.post(f'http://{HOST}/api/v1/generate', json=request)
if response.status_code == 200:
return response.json()['results'][0]['text']
def model_api(request):
response = requests.post(f'http://{HOST}/api/v1/model', json=request)
return response.json()
# print some common settings
def print_basic_model_info(response):
basic_settings = ['truncation_length', 'instruction_template']
print("Model: ", response['result']['model_name'])
print("Lora(s): ", response['result']['lora_names'])
for setting in basic_settings:
print(setting, "=", response['result']['shared.settings'][setting])
# model info
def model_info():
response = model_api({'action': 'info'})
print_basic_model_info(response)
# simple loader
def model_load(model_name):
return model_api({'action': 'load', 'model_name': model_name})
# complex loader
def complex_model_load(model):
def guess_groupsize(model_name):
if '1024g' in model_name:
return 1024
elif '128g' in model_name:
return 128
elif '32g' in model_name:
return 32
else:
return -1
req = {
'action': 'load',
'model_name': model,
'args': {
'loader': 'AutoGPTQ',
'bf16': False,
'load_in_8bit': False,
'groupsize': 0,
'wbits': 0,
# llama.cpp
'threads': 0,
'n_batch': 512,
'no_mmap': False,
'mlock': False,
'cache_capacity': None,
'n_gpu_layers': 0,
'n_ctx': 2048,
# RWKV
'rwkv_strategy': None,
'rwkv_cuda_on': False,
# b&b 4-bit
# 'load_in_4bit': False,
# 'compute_dtype': 'float16',
# 'quant_type': 'nf4',
# 'use_double_quant': False,
# "cpu": false,
# "auto_devices": false,
# "gpu_memory": null,
# "cpu_memory": null,
# "disk": false,
# "disk_cache_dir": "cache",
},
}
model = model.lower()
if '4bit' in model or 'gptq' in model or 'int4' in model:
req['args']['wbits'] = 4
req['args']['groupsize'] = guess_groupsize(model)
elif '3bit' in model:
req['args']['wbits'] = 3
req['args']['groupsize'] = guess_groupsize(model)
else:
req['args']['gptq_for_llama'] = False
if '8bit' in model:
req['args']['load_in_8bit'] = True
elif '-hf' in model or 'fp16' in model:
if '7b' in model:
req['args']['bf16'] = True # for 24GB
elif '13b' in model:
req['args']['load_in_8bit'] = True # for 24GB
elif 'gguf' in model:
# req['args']['threads'] = 16
if '7b' in model:
req['args']['n_gpu_layers'] = 100
elif '13b' in model:
req['args']['n_gpu_layers'] = 100
elif '30b' in model or '33b' in model:
req['args']['n_gpu_layers'] = 59 # 24GB
elif '65b' in model:
req['args']['n_gpu_layers'] = 42 # 24GB
elif 'rwkv' in model:
req['args']['rwkv_cuda_on'] = True
if '14b' in model:
req['args']['rwkv_strategy'] = 'cuda f16i8' # 24GB
else:
req['args']['rwkv_strategy'] = 'cuda f16' # 24GB
return model_api(req)
if __name__ == '__main__':
for model in model_api({'action': 'list'})['result']:
try:
resp = complex_model_load(model)
if 'error' in resp:
print(f"{model} FAIL Error: {resp['error']['message']}")
continue
else:
print_basic_model_info(resp)
ans = generate("0,1,1,2,3,5,8,13,", tokens=2)
if '21' in ans:
print(f"{model} PASS ({ans})")
else:
print(f"{model} FAIL ({ans})")
except Exception as e:
print(f"{model} FAIL Exception: {repr(e)}")
# 0,1,1,2,3,5,8,13, is the fibonacci sequence, the next number is 21.
# Some results below.
""" $ ./model-api-example.py
Model: 4bit_gpt4-x-alpaca-13b-native-4bit-128g-cuda
Lora(s): []
truncation_length = 2048
instruction_template = Alpaca
4bit_gpt4-x-alpaca-13b-native-4bit-128g-cuda PASS (21)
Model: 4bit_WizardLM-13B-Uncensored-4bit-128g
Lora(s): []
truncation_length = 2048
instruction_template = WizardLM
4bit_WizardLM-13B-Uncensored-4bit-128g PASS (21)
Model: Aeala_VicUnlocked-alpaca-30b-4bit
Lora(s): []
truncation_length = 2048
instruction_template = Alpaca
Aeala_VicUnlocked-alpaca-30b-4bit PASS (21)
Model: alpaca-30b-4bit
Lora(s): []
truncation_length = 2048
instruction_template = Alpaca
alpaca-30b-4bit PASS (21)
"""

View File

@ -1,88 +0,0 @@
import asyncio
import json
import sys
try:
import websockets
except ImportError:
print("Websockets package not found. Make sure it's installed.")
# For local streaming, the websockets are hosted without ssl - ws://
HOST = 'localhost:5005'
URI = f'ws://{HOST}/api/v1/stream'
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
async def run(context):
# Note: the selected defaults change from time to time.
request = {
'prompt': context,
'max_new_tokens': 250,
'auto_max_new_tokens': False,
'max_tokens_second': 0,
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True,
'temperature': 0.7,
'top_p': 0.1,
'typical_p': 1,
'epsilon_cutoff': 0, # In units of 1e-4
'eta_cutoff': 0, # In units of 1e-4
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1.18,
'presence_penalty': 0,
'frequency_penalty': 0,
'repetition_penalty_range': 0,
'top_k': 40,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0,
'length_penalty': 1,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'grammar_string': '',
'guidance_scale': 1,
'negative_prompt': '',
'seed': -1,
'add_bos_token': True,
'truncation_length': 2048,
'ban_eos_token': False,
'custom_token_bans': '',
'skip_special_tokens': True,
'stopping_strings': []
}
async with websockets.connect(URI, ping_interval=None) as websocket:
await websocket.send(json.dumps(request))
yield context # Remove this if you just want to see the reply
while True:
incoming_data = await websocket.recv()
incoming_data = json.loads(incoming_data)
match incoming_data['event']:
case 'text_stream':
yield incoming_data['text']
case 'stream_end':
return
async def print_response_stream(prompt):
async for response in run(prompt):
print(response, end='')
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
if __name__ == '__main__':
prompt = "In order to make homemade bread, follow these steps:\n1)"
asyncio.run(print_response_stream(prompt))

View File

@ -1,65 +0,0 @@
import requests
# For local streaming, the websockets are hosted without ssl - http://
HOST = 'localhost:5000'
URI = f'http://{HOST}/api/v1/generate'
# For reverse-proxied streaming, the remote will likely host with ssl - https://
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'
def run(prompt):
request = {
'prompt': prompt,
'max_new_tokens': 250,
'auto_max_new_tokens': False,
'max_tokens_second': 0,
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True,
'temperature': 0.7,
'top_p': 0.1,
'typical_p': 1,
'epsilon_cutoff': 0, # In units of 1e-4
'eta_cutoff': 0, # In units of 1e-4
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1.18,
'presence_penalty': 0,
'frequency_penalty': 0,
'repetition_penalty_range': 0,
'top_k': 40,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0,
'length_penalty': 1,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'grammar_string': '',
'guidance_scale': 1,
'negative_prompt': '',
'seed': -1,
'add_bos_token': True,
'truncation_length': 2048,
'ban_eos_token': False,
'custom_token_bans': '',
'skip_special_tokens': True,
'stopping_strings': []
}
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()['results'][0]['text']
print(prompt + result)
if __name__ == '__main__':
prompt = "In order to make homemade bread, follow these steps:\n1)"
run(prompt)

View File

@ -1,124 +1,164 @@
# An OpenedAI API (openai like)
## OpenAI compatible API
This extension creates an API that works kind of like openai (ie. api.openai.com).
The main API for this project is meant to be a drop-in replacement to the OpenAI API, including Chat and Completions endpoints.
## Setup & installation
Install the requirements:
If you did not use the one-click installers, you may need to install the requirements first:
```
pip3 install -r requirements.txt
pip install -r extensions/openai/requirements.txt
```
It listens on `tcp port 5001` by default. You can use the `OPENEDAI_PORT` environment variable to change this.
### Starting the API
Make sure you enable it in server launch parameters, it should include:
Add `--extensions openai` to your command-line flags.
* To create a public Cloudflare URL, also add the `--public-api` flag.
* To listen on your local network, also add the `--listen` flag.
* To change the port, which is 5000 by default, use `--port 1234` (change 1234 to your desired port number).
* To use SSL, add `--ssl-keyfile key.pem --ssl-certfile cert.pem`. Note that it doesn't work with `--public-api`.
#### Environment variables
The following environment variables can be used (they take precendence over everything else):
| Variable Name | Description | Example Value |
|------------------------|------------------------------------|----------------------------|
| `OPENEDAI_PORT` | Port number | 5000 |
| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem |
| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | all-mpnet-base-v2 |
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
#### Persistent settings with `settings.yaml`
You can also set the following variables in your `settings.yaml` file:
```
--extensions openai
```
You can also use the `--listen` argument to make the server available on the networ, and/or the `--share` argument to enable a public Cloudflare endpoint.
To enable the basic image generation support (txt2img) set the environment variable `SD_WEBUI_URL` to point to your Stable Diffusion API ([Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)).
For example:
```
SD_WEBUI_URL=http://127.0.0.1:7861
```
## Quick start
1. Install the requirements.txt (pip)
2. Enable the `openeai` module (--extensions openai), restart the server.
3. Configure the openai client
Most openai application can be configured to connect the API if you set the following environment variables:
```shell
# Sample .env file:
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
OPENAI_API_BASE=http://0.0.0.0:5001/v1
```
If needed, replace 0.0.0.0 with the IP/port of your server.
### Settings
To adjust your default settings, you can add the following to your `settings.yaml` file.
```
openai-port: 5002
openai-embedding_device: cuda
openai-embedding_model: all-mpnet-base-v2
openai-sd_webui_url: http://127.0.0.1:7861
openai-debug: 1
```
If you've configured the environment variables, please note that settings from `settings.yaml` won't take effect. For instance, if you set `openai-port: 5002` in `settings.yaml` but `OPENEDAI_PORT=5001` in the environment variables, the extension will use `5001` as the port number.
### Examples
When using `cache_embedding_model.py` to preload the embedding model during Docker image building, consider the following:
For the documentation with all the parameters, consult `http://127.0.0.1:5000/docs` or the [typing.py](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/openai/typing.py) file.
- If you wish to use the default settings, leave the environment variables unset.
- If you intend to change the default embedding model, ensure that you configure the environment variable `OPENEDAI_EMBEDDING_MODEL` to the desired model. Avoid setting `openai-embedding_model` in `settings.yaml` because those settings only take effect after the server starts.
The official examples in the [OpenAI documentation](https://platform.openai.com/docs/api-reference) should also work, and the same parameters apply (although the API here has more optional parameters).
### Models
#### Completions
This has been successfully tested with Alpaca, Koala, Vicuna, WizardLM and their variants, (ex. gpt4-x-alpaca, GPT4all-snoozy, stable-vicuna, wizard-vicuna, etc.) and many others. Models that have been trained for **Instruction Following** work best. If you test with other models please let me know how it goes. Less than satisfying results (so far) from: RWKV-4-Raven, llama, mpt-7b-instruct/chat.
For best results across all API endpoints, a model like [vicuna-13b-v1.3-GPTQ](https://huggingface.co/TheBloke/vicuna-13b-v1.3-GPTQ), [stable-vicuna-13B-GPTQ](https://huggingface.co/TheBloke/stable-vicuna-13B-GPTQ) or [airoboros-13B-gpt4-1.3-GPTQ](https://huggingface.co/TheBloke/airoboros-13B-gpt4-1.3-GPTQ) is a good start.
For good results with the [Completions](https://platform.openai.com/docs/api-reference/completions) API endpoint, in addition to the above models, you can also try using a base model like [falcon-7b](https://huggingface.co/tiiuae/falcon-7b) or Llama.
For good results with the [ChatCompletions](https://platform.openai.com/docs/api-reference/chat) or [Edits](https://platform.openai.com/docs/api-reference/edits) API endpoints you can use almost any model trained for instruction following. Be sure that the proper instruction template is detected and loaded or the results will not be good.
For the proper instruction format to be detected you need to have a matching model entry in your `models/config.yaml` file. Be sure to keep this file up to date.
A matching instruction template file in the characters/instruction-following/ folder will loaded and applied to format messages correctly for the model - this is critical for good results.
For example, the Wizard-Vicuna family of models are trained with the Vicuna 1.1 format. In the models/config.yaml file there is this matching entry:
```
.*wizard.*vicuna:
mode: 'instruct'
instruction_template: 'Vicuna-v1.1'
```shell
curl http://127.0.0.1:5000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"prompt": "This is a cake recipe:\n\n1.",
"max_tokens": 200,
"temperature": 1,
"top_p": 0.9,
"seed": 10
}'
```
This refers to `characters/instruction-following/Vicuna-v1.1.yaml`, which looks like this:
#### Chat completions
```
user: "USER:"
bot: "ASSISTANT:"
turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|></s>\n"
context: "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n"
Works best with instruction-following models. If the "instruction_template" variable is not provided, it will be guessed automatically based on the model name using the regex patterns in `models/config.yaml`.
```shell
curl http://127.0.0.1:5000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [
{
"role": "user",
"content": "Hello!"
}
],
"mode": "instruct",
"instruction_template": "Alpaca"
}'
```
For most common models this is already setup, but if you are using a new or uncommon model you may need add a matching entry to the models/config.yaml and possibly create your own instruction-following template and for best results.
#### Chat completions with characters
If you see this in your logs, it probably means that the correct format could not be loaded:
```
Warning: Loaded default instruction-following template for model.
```shell
curl http://127.0.0.1:5000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [
{
"role": "user",
"content": "Hello! Who are you?"
}
],
"mode": "chat",
"character": "Example"
}'
```
### Embeddings (alpha)
#### SSE streaming
Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default `text-embedding-ada-002` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
```shell
curl http://127.0.0.1:5000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [
{
"role": "user",
"content": "Hello!"
}
],
"mode": "instruct",
"instruction_template": "Alpaca",
"stream": true
}'
```
| model name | dimensions | input max tokens | speed | size | Avg. performance |
| ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- |
| text-embedding-ada-002 | 1536 | 8192 | - | - | - |
| text-davinci-002 | 768 | 2046 | - | - | - |
| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 |
| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 |
#### Python chat example
In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable `OPENEDAI_EMBEDDING_MODEL`, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2".
```python
import requests
Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable.
url = "http://127.0.0.1:5000/v1/chat/completions"
headers = {
"Content-Type": "application/json"
}
history = []
while True:
user_message = input("> ")
history.append({"role": "user", "content": user_message})
data = {
"mode": "chat",
"character": "Example",
"messages": history
}
response = requests.post(url, headers=headers, json=data, verify=False)
assistant_message = response.json()['choices'][0]['message']['content']
history.append({"role": "assistant", "content": assistant_message})
print(assistant_message)
```
### Client Application Setup
Almost everything you use it with will require you to set a dummy OpenAI API key environment variable.
You can usually force an application that uses the OpenAI API to connect to the local API by using the following environment variables:
```shell
OPENAI_API_HOST=http://127.0.0.1:5000
```
or
```shell
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
OPENAI_API_BASE=http://127.0.0.1:500/v1
```
With the [official python openai client](https://github.com/openai/openai-python), set the `OPENAI_API_BASE` environment variables:
@ -128,7 +168,7 @@ OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
OPENAI_API_BASE=http://0.0.0.0:5001/v1
```
If needed, replace 0.0.0.0 with the IP/port of your server.
If needed, replace 127.0.0.1 with the IP/port of your server.
If using .env files to save the `OPENAI_API_BASE` and `OPENAI_API_KEY` variables, make sure the .env file is loaded before the openai module is imported:
@ -157,8 +197,22 @@ const api = new ChatGPTAPI({
apiBaseUrl: process.env.OPENAI_API_BASE
});
```
### Embeddings (alpha)
## API Documentation & Examples
Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default `text-embedding-ada-002` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
| model name | dimensions | input max tokens | speed | size | Avg. performance |
| ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- |
| text-embedding-ada-002 | 1536 | 8192 | - | - | - |
| text-davinci-002 | 768 | 2046 | - | - | - |
| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 |
| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 |
In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable `OPENEDAI_EMBEDDING_MODEL`, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2".
Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable.
### API Documentation & Examples
The OpenAI API is well documented, you can view the documentation here: https://platform.openai.com/docs/api-reference
@ -185,7 +239,7 @@ text = response['choices'][0]['message']['content']
print(text)
```
## Compatibility & not so compatibility
### Compatibility & not so compatibility
| API endpoint | tested with | notes |
| ------------------------- | ---------------------------------- | --------------------------------------------------------------------------- |
@ -195,7 +249,7 @@ print(text)
| /v1/moderations | openai.Moderation.create() | Basic initial support via embeddings |
| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
| /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models |
| /v1/edits | openai.Edit.create() | Removed, use /v1/chat/completions instead |
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model |
| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
| /v1/engines/\*/embeddings | python-openai v0.25 | Legacy endpoint |
@ -209,28 +263,8 @@ print(text)
| /v1/fine-tunes\* | openai.FineTune.\* | not yet supported |
| /v1/search | openai.search, engines.search | not yet supported |
Because of the differences in OpenAI model context sizes (2k, 4k, 8k, 16k, etc,) you may need to adjust the max_tokens to fit into the context of the model you choose.
Streaming, temperature, top_p, max_tokens, stop, should all work as expected, but not all parameters are mapped correctly.
Some hacky mappings:
| OpenAI | text-generation-webui | note |
| ----------------------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| model | - | Ignored, the model is not changed |
| frequency_penalty | encoder_repetition_penalty | this seems to operate with a different scale and defaults, I tried to scale it based on range & defaults, but the results are terrible. hardcoded to 1.18 until there is a better way |
| presence_penalty | repetition_penalty | same issues as frequency_penalty, hardcoded to 1.0 |
| best_of | top_k | default is 1 (top_k is 20 for chat, which doesn't support best_of) |
| n | 1 | variations are not supported yet. |
| 1 | num_beams | hardcoded to 1 |
| 1.0 | typical_p | hardcoded to 1.0 |
| logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' |
| messages.name | - | not supported yet |
| suffix | - | not supported yet |
| user | - | not supported yet |
| functions/function_call | - | function calls are not supported yet |
### Applications
#### Applications
Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment variable set, but there are some exceptions.
@ -249,15 +283,3 @@ Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment v
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
| ❌ | guidance | https://github.com/microsoft/guidance | logit_bias and logprobs not yet supported |
## Future plans
- better error handling
- model changing, esp. something for swapping loras or embedding models
- consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard)
## Bugs? Feedback? Comments? Pull requests?
To enable debugging and get copious output you can set the `OPENEDAI_DEBUG=1` environment variable.
Are all appreciated, please @matatonic and I'll try to get back to you as soon as possible.

View File

@ -3,9 +3,11 @@ import time
import extensions.api.blocking_api as blocking_api
import extensions.api.streaming_api as streaming_api
from modules import shared
from modules.logging_colors import logger
def setup():
logger.warning("\nThe current API is deprecated and will be replaced with the OpenAI compatible API on November 13th.\nTo test the new API, use \"--extensions openai\" instead of \"--api\".\nFor documentation on the new API, consult:\nhttps://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API")
blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)
if shared.args.public_api:
time.sleep(5)

View File

@ -1,18 +1,23 @@
import copy
import time
from collections import deque
import tiktoken
import torch
import torch.nn.functional as F
import yaml
from extensions.openai.defaults import clamp, default, get_default_req_params
from extensions.openai.errors import InvalidRequestError
from extensions.openai.utils import debug_msg, end_line
from extensions.openai.utils import debug_msg
from modules import shared
from modules.chat import (
generate_chat_prompt,
generate_chat_reply,
load_character_memoized
)
from modules.presets import load_preset_memoized
from modules.text_generation import decode, encode, generate_reply
from transformers import LogitsProcessor, LogitsProcessorList
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
class LogitsBiasProcessor(LogitsProcessor):
def __init__(self, logit_bias={}):
self.logit_bias = logit_bias
@ -28,6 +33,7 @@ class LogitsBiasProcessor(LogitsProcessor):
logits[0, self.keys] += self.values
debug_msg(" --> ", logits[0, self.keys])
debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))
return logits
def __repr__(self):
@ -47,6 +53,7 @@ class LogprobProcessor(LogitsProcessor):
top_probs = [float(x) for x in top_values[0]]
self.token_alternatives = dict(zip(top_tokens, top_probs))
debug_msg(repr(self))
return logits
def __repr__(self):
@ -66,43 +73,28 @@ def convert_logprobs_to_tiktoken(model, logprobs):
return logprobs
def marshal_common_params(body):
# Request Parameters
# Try to use openai defaults or map them to something with the same intent
def process_parameters(body, is_legacy=False):
generate_params = body
max_tokens_str = 'length' if is_legacy else 'max_tokens'
generate_params['max_new_tokens'] = body.pop(max_tokens_str)
if generate_params['truncation_length'] == 0:
if shared.args.loader and shared.args.loader.lower().startswith('exllama'):
generate_params['truncation_length'] = shared.args.max_seq_len
elif shared.args.loader and shared.args.loader in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
generate_params['truncation_length'] = shared.args.n_ctx
else:
generate_params['truncation_length'] = shared.settings['truncation_length']
req_params = get_default_req_params()
# Common request parameters
req_params['truncation_length'] = shared.settings['truncation_length']
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
# OpenAI API Parameters
# model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this
req_params['requested_model'] = body.get('model', shared.model_name)
req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.01, 1.99) # fixup absolute 0.0/2.0
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.01, 1.0)
n = default(body, 'n', 1)
if n != 1:
raise InvalidRequestError(message="Only n = 1 is supported.", param='n')
if body['preset'] is not None:
preset = load_preset_memoized(body['preset'])
generate_params.update(preset)
generate_params['custom_stopping_strings'] = []
if 'stop' in body: # str or array, max len 4 (ignored)
if isinstance(body['stop'], str):
req_params['stopping_strings'] = [body['stop']] # non-standard parameter
generate_params['custom_stopping_strings'] = [body['stop']]
elif isinstance(body['stop'], list):
req_params['stopping_strings'] = body['stop']
# presence_penalty - ignored
# frequency_penalty - ignored
# pass through unofficial params
req_params['repetition_penalty'] = default(body, 'repetition_penalty', req_params['repetition_penalty'])
req_params['encoder_repetition_penalty'] = default(body, 'encoder_repetition_penalty', req_params['encoder_repetition_penalty'])
# user - ignored
generate_params['custom_stopping_strings'] = body['stop']
logits_processor = []
logit_bias = body.get('logit_bias', None)
@ -110,12 +102,13 @@ def marshal_common_params(body):
# XXX convert tokens from tiktoken based on requested model
# Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
try:
encoder = tiktoken.encoding_for_model(req_params['requested_model'])
encoder = tiktoken.encoding_for_model(generate_params['model'])
new_logit_bias = {}
for logit, bias in logit_bias.items():
for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]:
if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens
continue
new_logit_bias[str(int(x))] = bias
debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias)
logit_bias = new_logit_bias
@ -126,238 +119,131 @@ def marshal_common_params(body):
logprobs = None # coming to chat eventually
if 'logprobs' in body:
logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5.
req_params['logprob_proc'] = LogprobProcessor(logprobs)
logits_processor.extend([req_params['logprob_proc']])
logprobs = body.get('logprobs', 0) # maybe cap at topk? don't clamp 0-5.
generate_params['logprob_proc'] = LogprobProcessor(logprobs)
logits_processor.extend([generate_params['logprob_proc']])
else:
logprobs = None
if logits_processor: # requires logits_processor support
req_params['logits_processor'] = LogitsProcessorList(logits_processor)
generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
return req_params
return generate_params
def messages_to_prompt(body: dict, req_params: dict, max_tokens):
# functions
if body.get('functions', []): # chat only
def convert_history(history):
'''
Chat histories in this program are in the format [message, reply].
This function converts OpenAI histories to that format.
'''
chat_dialogue = []
current_message = ""
current_reply = ""
user_input = ""
for entry in history:
content = entry["content"]
role = entry["role"]
if role == "user":
user_input = content
if current_message:
chat_dialogue.append([current_message, ''])
current_message = ""
current_message = content
elif role == "assistant":
current_reply = content
if current_message:
chat_dialogue.append([current_message, current_reply])
current_message = ""
current_reply = ""
else:
chat_dialogue.append(['', current_reply])
# if current_message:
# chat_dialogue.append([current_message, ''])
return user_input, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -> dict:
if body.get('functions', []):
raise InvalidRequestError(message="functions is not supported.", param='functions')
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
if body.get('function_call', ''):
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
if 'messages' not in body:
raise InvalidRequestError(message="messages is required", param='messages')
messages = body['messages']
role_formats = {
'user': 'User: {message}\n',
'assistant': 'Assistant: {message}\n',
'system': '{message}',
'context': 'You are a helpful assistant. Answer as concisely as possible.\nUser: I want your assistance.\nAssistant: Sure! What can I do for you?',
'prompt': 'Assistant:',
}
if 'stopping_strings' not in req_params:
req_params['stopping_strings'] = []
# Instruct models can be much better
if shared.settings['instruction_template']:
try:
instruct = yaml.safe_load(open(f"instruction-templates/{shared.settings['instruction_template']}.yaml", 'r'))
template = instruct['turn_template']
system_message_template = "{message}"
system_message_default = instruct.get('context', '') # can be missing
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct.get('user', ''))
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct.get('bot', ''))
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
role_formats = {
'user': user_message_template,
'assistant': bot_message_template,
'system': system_message_template,
'context': system_message_default,
'prompt': bot_prompt,
}
if 'Alpaca' in shared.settings['instruction_template']:
req_params['stopping_strings'].extend(['\n###'])
elif instruct['user']: # WizardLM and some others have no user prompt.
req_params['stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}")
except Exception as e:
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
print(f"Exception: When loading instruction-templates/{shared.settings['instruction_template']}.yaml: {repr(e)}")
print("Warning: Loaded default instruction-following template for model.")
else:
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
print("Warning: Loaded default instruction-following template for model.")
system_msgs = []
chat_msgs = []
# You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
context_msg = end_line(context_msg)
# Maybe they sent both? This is not documented in the API, but some clients seem to do this.
if 'prompt' in body:
context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg
for m in messages:
if 'role' not in m:
raise InvalidRequestError(message="messages: missing role", param='messages')
elif m['role'] == 'function':
raise InvalidRequestError(message="role: function is not supported.", param='messages')
if 'content' not in m:
raise InvalidRequestError(message="messages: missing content", param='messages')
role = m['role']
content = m['content']
# name = m.get('name', None)
# function_call = m.get('function_call', None) # user name or function name with output in content
msg = role_formats[role].format(message=content)
if role == 'system':
system_msgs.extend([msg])
elif role == 'function':
raise InvalidRequestError(message="role: function is not supported.", param='messages')
else:
chat_msgs.extend([msg])
system_msg = '\n'.join(system_msgs)
system_msg = end_line(system_msg)
prompt = system_msg + context_msg + ''.join(chat_msgs) + role_formats['prompt']
token_count = len(encode(prompt)[0])
if token_count >= req_params['truncation_length']:
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens."
raise InvalidRequestError(message=err_msg, param='messages')
if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
print(f"Warning: ${err_msg}")
# raise InvalidRequestError(message=err_msg, params='max_tokens')
return prompt, token_count
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
# Chat Completions
object_type = 'chat.completions'
object_type = 'chat.completions' if not stream else 'chat.completions.chunk'
created_time = int(time.time())
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
# common params
req_params = marshal_common_params(body)
req_params['stream'] = False
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
# generation parameters
generate_params = process_parameters(body, is_legacy=is_legacy)
continue_ = body['continue_']
# chat default max_tokens is 'inf', but also flexible
max_tokens = 0
max_tokens_str = 'length' if is_legacy else 'max_tokens'
if max_tokens_str in body:
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
req_params['max_new_tokens'] = max_tokens
else:
req_params['max_new_tokens'] = req_params['truncation_length']
# Instruction template
instruction_template = body['instruction_template'] or shared.settings['instruction_template']
instruction_template = "Alpaca" if instruction_template == "None" else instruction_template
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True)
name1_instruct = body['name1_instruct'] or name1_instruct
name2_instruct = body['name2_instruct'] or name2_instruct
context_instruct = body['context_instruct'] or context_instruct
turn_template = body['turn_template'] or turn_template
# format the prompt from messages
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
# Chat character
character = body['character'] or shared.settings['character']
character = "Assistant" if character == "None" else character
name1 = body['name1'] or shared.settings['name1']
name1, name2, _, greeting, context, _ = load_character_memoized(character, name1, '', instruct=False)
name2 = body['name2'] or name2
context = body['context'] or context
greeting = body['greeting'] or greeting
# set real max, avoid deeper errors
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
# History
user_input, history = convert_history(messages)
stopping_strings = req_params.pop('stopping_strings', [])
generate_params.update({
'mode': body['mode'],
'name1': name1,
'name2': name2,
'context': context,
'greeting': greeting,
'name1_instruct': name1_instruct,
'name2_instruct': name2_instruct,
'context_instruct': context_instruct,
'turn_template': turn_template,
'chat-instruct_command': body['chat_instruct_command'],
'history': history,
'stream': stream
})
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
max_tokens = generate_params['max_new_tokens']
if max_tokens in [None, 0]:
generate_params['max_new_tokens'] = 200
generate_params['auto_max_new_tokens'] = True
answer = ''
for a in generator:
answer = a
# strip extra leading space off new generated content
if answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
stop_reason = "length"
resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name, # TODO: add Lora info?
resp_list: [{
"index": 0,
"finish_reason": stop_reason,
"message": {"role": "assistant", "content": answer}
}],
"usage": {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
}
if logprob_proc: # not official for chat yet
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
# else:
# resp[resp_list][0]["logprobs"] = None
return resp
# generator
def stream_chat_completions(body: dict, is_legacy: bool = False):
# Chat Completions
stream_object_type = 'chat.completions.chunk'
created_time = int(time.time())
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
# common params
req_params = marshal_common_params(body)
req_params['stream'] = True
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
# chat default max_tokens is 'inf', but also flexible
max_tokens = 0
max_tokens_str = 'length' if is_legacy else 'max_tokens'
if max_tokens_str in body:
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
req_params['max_new_tokens'] = max_tokens
else:
req_params['max_new_tokens'] = req_params['truncation_length']
# format the prompt from messages
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
# set real max, avoid deeper errors
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
requested_model = generate_params.pop('model')
logprob_proc = generate_params.pop('logprob_proc', None)
def chat_streaming_chunk(content):
# begin streaming
chunk = {
"id": cmpl_id,
"object": stream_object_type,
"object": object_type,
"created": created_time,
"model": shared.model_name,
resp_list: [{
@ -376,262 +262,262 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
# chunk[resp_list][0]["logprobs"] = None
return chunk
yield chat_streaming_chunk('')
if stream:
yield chat_streaming_chunk('')
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
prompt = generate_chat_prompt(user_input, generate_params)
token_count = len(encode(prompt)[0])
debug_msg({'prompt': prompt, 'generate_params': generate_params})
stopping_strings = req_params.pop('stopping_strings', [])
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
generator = generate_chat_reply(
user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
answer = ''
seen_content = ''
completion_token_count = 0
for a in generator:
answer = a
answer = a['internal'][-1][1]
if stream:
len_seen = len(seen_content)
new_content = answer[len_seen:]
len_seen = len(seen_content)
new_content = answer[len_seen:]
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
continue
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
continue
seen_content = answer
seen_content = answer
# strip extra leading space off new generated content
if len_seen == 0 and new_content[0] == ' ':
new_content = new_content[1:]
# strip extra leading space off new generated content
if len_seen == 0 and new_content[0] == ' ':
new_content = new_content[1:]
chunk = chat_streaming_chunk(new_content)
chunk = chat_streaming_chunk(new_content)
yield chunk
# to get the correct token_count, strip leading space if present
if answer and answer[0] == ' ':
answer = answer[1:]
yield chunk
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
stop_reason = "length"
chunk = chat_streaming_chunk('')
chunk[resp_list][0]['finish_reason'] = stop_reason
chunk['usage'] = {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
if stream:
chunk = chat_streaming_chunk('')
chunk[resp_list][0]['finish_reason'] = stop_reason
chunk['usage'] = {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
yield chunk
yield chunk
else:
resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
resp_list: [{
"index": 0,
"finish_reason": stop_reason,
"message": {"role": "assistant", "content": answer}
}],
"usage": {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
}
if logprob_proc: # not official for chat yet
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
# else:
# resp[resp_list][0]["logprobs"] = None
yield resp
def completions(body: dict, is_legacy: bool = False):
# Legacy
# Text Completions
object_type = 'text_completion'
def completions_common(body: dict, is_legacy: bool = False, stream=False):
object_type = 'text_completion.chunk' if stream else 'text_completion'
created_time = int(time.time())
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
prompt_str = 'context' if is_legacy else 'prompt'
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
if prompt_str not in body:
raise InvalidRequestError("Missing required input", param=prompt_str)
prompt_arg = body[prompt_str]
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
prompt_arg = [prompt_arg]
# common params
req_params = marshal_common_params(body)
req_params['stream'] = False
max_tokens_str = 'length' if is_legacy else 'max_tokens'
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
req_params['max_new_tokens'] = max_tokens
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
stopping_strings = req_params.pop('stopping_strings', [])
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
generate_params = process_parameters(body, is_legacy=is_legacy)
max_tokens = generate_params['max_new_tokens']
generate_params['stream'] = stream
requested_model = generate_params.pop('model')
logprob_proc = generate_params.pop('logprob_proc', None)
# generate_params['suffix'] = body.get('suffix', generate_params['suffix'])
generate_params['echo'] = body.get('echo', generate_params['echo'])
resp_list_data = []
total_completion_token_count = 0
total_prompt_token_count = 0
if not stream:
prompt_arg = body[prompt_str]
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
prompt_arg = [prompt_arg]
for idx, prompt in enumerate(prompt_arg, start=0):
if isinstance(prompt[0], int):
# token lists
if requested_model == shared.model_name:
prompt = decode(prompt)[0]
else:
resp_list_data = []
total_completion_token_count = 0
total_prompt_token_count = 0
for idx, prompt in enumerate(prompt_arg, start=0):
if isinstance(prompt[0], int):
# token lists
if requested_model == shared.model_name:
prompt = decode(prompt)[0]
else:
try:
encoder = tiktoken.encoding_for_model(requested_model)
prompt = encoder.decode(prompt)
except KeyError:
prompt = decode(prompt)[0]
token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count
# generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''
for a in generator:
answer = a
# strip extra leading space off new generated content
if answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
respi = {
"index": idx,
"finish_reason": stop_reason,
"text": answer,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}
resp_list_data.extend([respi])
resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
resp_list: resp_list_data,
"usage": {
"prompt_tokens": total_prompt_token_count,
"completion_tokens": total_completion_token_count,
"total_tokens": total_prompt_token_count + total_completion_token_count
}
}
yield resp
else:
prompt = body[prompt_str]
if isinstance(prompt, list):
if prompt and isinstance(prompt[0], int):
try:
encoder = tiktoken.encoding_for_model(requested_model)
prompt = encoder.decode(prompt)
except KeyError:
prompt = decode(prompt)[0]
else:
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count
if token_count + max_tokens > req_params['truncation_length']:
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
# print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
def text_streaming_chunk(content):
# begin streaming
chunk = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
resp_list: [{
"index": 0,
"finish_reason": None,
"text": content,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}],
}
return chunk
yield text_streaming_chunk('')
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''
seen_content = ''
completion_token_count = 0
for a in generator:
answer = a
# strip extra leading space off new generated content
len_seen = len(seen_content)
new_content = answer[len_seen:]
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
continue
seen_content = answer
# strip extra leading space off new generated content
if len_seen == 0 and new_content[0] == ' ':
new_content = new_content[1:]
chunk = text_streaming_chunk(new_content)
yield chunk
# to get the correct count, we strip the leading space if present
if answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
respi = {
"index": idx,
"finish_reason": stop_reason,
"text": answer,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
chunk = text_streaming_chunk('')
chunk[resp_list][0]["finish_reason"] = stop_reason
chunk["usage"] = {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
resp_list_data.extend([respi])
resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name, # TODO: add Lora info?
resp_list: resp_list_data,
"usage": {
"prompt_tokens": total_prompt_token_count,
"completion_tokens": total_completion_token_count,
"total_tokens": total_prompt_token_count + total_completion_token_count
}
}
return resp
# generator
def stream_completions(body: dict, is_legacy: bool = False):
# Legacy
# Text Completions
# object_type = 'text_completion'
stream_object_type = 'text_completion.chunk'
created_time = int(time.time())
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
prompt_str = 'context' if is_legacy else 'prompt'
if prompt_str not in body:
raise InvalidRequestError("Missing required input", param=prompt_str)
prompt = body[prompt_str]
req_params = marshal_common_params(body)
requested_model = req_params.pop('requested_model')
if isinstance(prompt, list):
if prompt and isinstance(prompt[0], int):
try:
encoder = tiktoken.encoding_for_model(requested_model)
prompt = encoder.decode(prompt)
except KeyError:
prompt = decode(prompt)[0]
else:
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
# common params
req_params['stream'] = True
max_tokens_str = 'length' if is_legacy else 'max_tokens'
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
req_params['max_new_tokens'] = max_tokens
logprob_proc = req_params.pop('logprob_proc', None)
stopping_strings = req_params.pop('stopping_strings', [])
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
token_count = len(encode(prompt)[0])
if token_count + max_tokens > req_params['truncation_length']:
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
# print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
def text_streaming_chunk(content):
# begin streaming
chunk = {
"id": cmpl_id,
"object": stream_object_type,
"created": created_time,
"model": shared.model_name,
resp_list: [{
"index": 0,
"finish_reason": None,
"text": content,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}],
}
return chunk
yield text_streaming_chunk('')
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
seen_content = ''
completion_token_count = 0
for a in generator:
answer = a
len_seen = len(seen_content)
new_content = answer[len_seen:]
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
continue
seen_content = answer
# strip extra leading space off new generated content
if len_seen == 0 and new_content[0] == ' ':
new_content = new_content[1:]
chunk = text_streaming_chunk(new_content)
yield chunk
# to get the correct count, we strip the leading space if present
if answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
generator = chat_completions_common(body, is_legacy, stream=False)
return deque(generator, maxlen=1).pop()
chunk = text_streaming_chunk('')
chunk[resp_list][0]["finish_reason"] = stop_reason
chunk["usage"] = {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
yield chunk
def stream_chat_completions(body: dict, is_legacy: bool = False):
for resp in chat_completions_common(body, is_legacy, stream=True):
yield resp
def completions(body: dict, is_legacy: bool = False) -> dict:
generator = completions_common(body, is_legacy, stream=False)
return deque(generator, maxlen=1).pop()
def stream_completions(body: dict, is_legacy: bool = False):
for resp in completions_common(body, is_legacy, stream=True):
yield resp

View File

@ -1,78 +0,0 @@
import copy
# Slightly different defaults for OpenAI's API
# Data type is important, Ex. use 0.0 for a float 0
default_req_params = {
'max_new_tokens': 16, # 'Inf' for chat
'auto_max_new_tokens': False,
'max_tokens_second': 0,
'temperature': 1.0,
'temperature_last': False,
'top_p': 1.0,
'min_p': 0,
'top_k': 1, # choose 20 for chat in absence of another default
'repetition_penalty': 1.18,
'presence_penalty': 0,
'frequency_penalty': 0,
'repetition_penalty_range': 0,
'encoder_repetition_penalty': 1.0,
'suffix': None,
'stream': False,
'echo': False,
'seed': -1,
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
'truncation_length': 2048, # first use shared.settings value
'add_bos_token': True,
'do_sample': True,
'typical_p': 1.0,
'epsilon_cutoff': 0.0, # In units of 1e-4
'eta_cutoff': 0.0, # In units of 1e-4
'tfs': 1.0,
'top_a': 0.0,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0.0,
'length_penalty': 1.0,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5.0,
'mirostat_eta': 0.1,
'grammar_string': '',
'guidance_scale': 1,
'negative_prompt': '',
'ban_eos_token': False,
'custom_token_bans': '',
'skip_special_tokens': True,
'custom_stopping_strings': '',
# 'logits_processor' - conditionally passed
# 'stopping_strings' - temporarily used
# 'logprobs' - temporarily used
# 'requested_model' - temporarily used
}
def get_default_req_params():
return copy.deepcopy(default_req_params)
def default(dic, key, default):
'''
little helper to get defaults if arg is present but None and should be the same type as default.
'''
val = dic.get(key, default)
if not isinstance(val, type(default)):
# maybe it's just something like 1 instead of 1.0
try:
v = type(default)(val)
if type(val)(v) == val: # if it's the same value passed in, it's ok.
return v
except:
pass
val = default
return val
def clamp(value, minvalue, maxvalue):
return max(minvalue, min(value, maxvalue))

View File

@ -1,101 +0,0 @@
import time
import yaml
from extensions.openai.defaults import get_default_req_params
from extensions.openai.errors import InvalidRequestError
from extensions.openai.utils import debug_msg
from modules import shared
from modules.text_generation import encode, generate_reply
def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
created_time = int(time.time() * 1000)
# Request parameters
req_params = get_default_req_params()
stopping_strings = []
# Alpaca is verbose so a good default prompt
default_template = (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)
instruction_template = default_template
# Use the special instruction/input/response template for anything trained like Alpaca
if shared.settings['instruction_template']:
if 'Alpaca' in shared.settings['instruction_template']:
stopping_strings.extend(['\n###'])
else:
try:
instruct = yaml.safe_load(open(f"instruction-templates/{shared.settings['instruction_template']}.yaml", 'r'))
template = instruct['turn_template']
template = template\
.replace('<|user|>', instruct.get('user', ''))\
.replace('<|bot|>', instruct.get('bot', ''))\
.replace('<|user-message|>', '{instruction}\n{input}')
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
if instruct['user']:
stopping_strings.extend(['\n' + instruct['user'], instruct['user']])
except Exception as e:
instruction_template = default_template
print(f"Exception: When loading instruction-templates/{shared.settings['instruction_template']}.yaml: {repr(e)}")
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
else:
stopping_strings.extend(['\n###'])
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
edit_task = instruction_template.format(instruction=instruction, input=input)
truncation_length = shared.settings['truncation_length']
token_count = len(encode(edit_task)[0])
max_tokens = truncation_length - token_count
if max_tokens < 1:
err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens."
raise InvalidRequestError(err_msg, param='input')
req_params['max_new_tokens'] = max_tokens
req_params['truncation_length'] = truncation_length
req_params['temperature'] = temperature
req_params['top_p'] = top_p
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
for a in generator:
answer = a
# some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
if edit_task[-1] != '\n' and answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0])
resp = {
"object": "edit",
"created": created_time,
"choices": [{
"text": answer,
"index": 0,
}],
"usage": {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
}
return resp

View File

@ -6,9 +6,13 @@ from extensions.openai.utils import debug_msg, float_list_to_base64
from sentence_transformers import SentenceTransformer
embeddings_params_initialized = False
# using 'lazy loading' to avoid circular import
# so this function will be executed only once
def initialize_embedding_params():
'''
using 'lazy loading' to avoid circular import
so this function will be executed only once
'''
global embeddings_params_initialized
if not embeddings_params_initialized:
global st_model, embeddings_model, embeddings_device
@ -26,23 +30,21 @@ def load_embedding_model(model: str) -> SentenceTransformer:
initialize_embedding_params()
global embeddings_device, embeddings_model
try:
embeddings_model = 'loading...' # flag
print(f"Try embedding model: {model} on {embeddings_device}")
# see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
emb_model = SentenceTransformer(model, device=embeddings_device)
# ... emb_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
print(f"\nLoaded embedding model: {model} on {emb_model.device} [always seems to say 'cpu', even if 'cuda'], max sequence length: {emb_model.max_seq_length}")
embeddings_model = SentenceTransformer(model, device=embeddings_device)
# ... embeddings_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
print(f"\nLoaded embedding model: {model} on {embeddings_model.device} [always seems to say 'cpu', even if 'cuda'], max sequence length: {embeddings_model.max_seq_length}")
except Exception as e:
embeddings_model = None
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
return emb_model
def get_embeddings_model() -> SentenceTransformer:
initialize_embedding_params()
global embeddings_model, st_model
if st_model and not embeddings_model:
embeddings_model = load_embedding_model(st_model) # lazy load the model
load_embedding_model(st_model) # lazy load the model
return embeddings_model
@ -53,7 +55,11 @@ def get_embeddings_model_name() -> str:
def get_embeddings(input: list) -> np.ndarray:
return get_embeddings_model().encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False, device=embeddings_device)
model = get_embeddings_model()
debug_msg(f"embedding model : {model}")
embedding = model.encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False)
debug_msg(f"embedding result : {embedding}") # might be too long even for debug, use at you own will
return embedding
def embeddings(input: list, encoding_format: str) -> dict:

View File

@ -50,6 +50,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
'data': []
}
from extensions.openai.script import params
# TODO: support SD_WEBUI_AUTH username:password pair.
sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img"

View File

@ -1,4 +1,5 @@
SpeechRecognition==3.10.0
flask_cloudflared==0.0.12
flask_cloudflared==0.0.14
sentence-transformers
sse-starlette==1.6.5
tiktoken

View File

@ -1,351 +1,254 @@
import json
import os
import ssl
import traceback
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
import extensions.openai.completions as OAIcompletions
import extensions.openai.edits as OAIedits
import extensions.openai.embeddings as OAIembeddings
import extensions.openai.images as OAIimages
import extensions.openai.models as OAImodels
import extensions.openai.moderations as OAImoderations
from extensions.openai.defaults import clamp, default, get_default_req_params
from extensions.openai.errors import (
InvalidRequestError,
OpenAIError,
ServiceUnavailableError
)
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import debug_msg
from modules import shared
import cgi
import speech_recognition as sr
import uvicorn
from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import JSONResponse
from modules import shared
from modules.logging_colors import logger
from pydub import AudioSegment
from sse_starlette import EventSourceResponse
from .typing import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
to_dict
)
params = {
# default params
'port': 5001,
'embedding_device': 'cpu',
'embedding_model': 'all-mpnet-base-v2',
# optional params
'sd_webui_url': '',
'debug': 0
}
class Handler(BaseHTTPRequestHandler):
def send_access_control_headers(self):
self.send_header("Access-Control-Allow-Origin", "*")
self.send_header("Access-Control-Allow-Credentials", "true")
self.send_header(
"Access-Control-Allow-Methods",
"GET,HEAD,OPTIONS,POST,PUT"
)
self.send_header(
"Access-Control-Allow-Headers",
"Origin, Accept, X-Requested-With, Content-Type, "
"Access-Control-Request-Method, Access-Control-Request-Headers, "
"Authorization"
)
def do_OPTIONS(self):
self.send_response(200)
self.send_access_control_headers()
self.send_header('Content-Type', 'application/json')
self.end_headers()
self.wfile.write("OK".encode('utf-8'))
def verify_api_key(authorization: str = Header(None)) -> None:
expected_api_key = shared.args.api_key
if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
raise HTTPException(status_code=401, detail="Unauthorized")
def start_sse(self):
self.send_response(200)
self.send_access_control_headers()
self.send_header('Content-Type', 'text/event-stream')
self.send_header('Cache-Control', 'no-cache')
# self.send_header('Connection', 'keep-alive')
self.end_headers()
def send_sse(self, chunk: dict):
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
debug_msg(response[:-4])
self.wfile.write(response.encode('utf-8'))
app = FastAPI(dependencies=[Depends(verify_api_key)])
def end_sse(self):
response = 'data: [DONE]\r\n\r\n'
debug_msg(response[:-4])
self.wfile.write(response.encode('utf-8'))
# Configure CORS settings to allow all origins, methods, and headers
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "HEAD", "OPTIONS", "POST", "PUT"],
allow_headers=[
"Origin",
"Accept",
"X-Requested-With",
"Content-Type",
"Access-Control-Request-Method",
"Access-Control-Request-Headers",
"Authorization",
],
)
def return_json(self, ret: dict, code: int = 200, no_debug=False):
self.send_response(code)
self.send_access_control_headers()
self.send_header('Content-Type', 'application/json')
response = json.dumps(ret)
r_utf8 = response.encode('utf-8')
@app.options("/")
async def options_route():
return JSONResponse(content="OK")
self.send_header('Content-Length', str(len(r_utf8)))
self.end_headers()
self.wfile.write(r_utf8)
if not no_debug:
debug_msg(r_utf8)
@app.post('/v1/completions', response_model=CompletionResponse)
async def openai_completions(request: Request, request_data: CompletionRequest):
path = request.url.path
is_legacy = "/generate" in path
def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
if request_data.stream:
async def generator():
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
for resp in response:
yield {"data": json.dumps(resp)}
error_resp = {
'error': {
'message': message,
'code': code,
'type': error_type,
'param': param,
}
}
if internal_message:
print(error_type, message)
print(internal_message)
# error_resp['internal_message'] = internal_message
return EventSourceResponse(generator()) # SSE streaming
self.return_json(error_resp, code)
else:
response = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy)
return JSONResponse(response)
def openai_error_handler(func):
def wrapper(self):
try:
func(self)
except InvalidRequestError as e:
self.openai_error(e.message, e.code, e.__class__.__name__, e.param, internal_message=e.internal_message)
except OpenAIError as e:
self.openai_error(e.message, e.code, e.__class__.__name__, internal_message=e.internal_message)
except Exception as e:
self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc())
return wrapper
@app.post('/v1/chat/completions', response_model=ChatCompletionResponse)
async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
path = request.url.path
is_legacy = "/generate" in path
@openai_error_handler
def do_GET(self):
debug_msg(self.requestline)
debug_msg(self.headers)
if request_data.stream:
async def generator():
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
for resp in response:
yield {"data": json.dumps(resp)}
if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
is_legacy = 'engines' in self.path
is_list = self.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models']
if is_legacy and not is_list:
model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
resp = OAImodels.load_model(model_name)
elif is_list:
resp = OAImodels.list_models(is_legacy)
else:
model_name = self.path[len('/v1/models/'):]
resp = OAImodels.model_info(model_name)
return EventSourceResponse(generator()) # SSE streaming
self.return_json(resp)
else:
response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy)
return JSONResponse(response)
elif '/billing/usage' in self.path:
# Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
self.return_json({"total_usage": 0}, no_debug=True)
else:
self.send_error(404)
@app.get("/v1/models")
@app.get("/v1/engines")
async def handle_models(request: Request):
path = request.url.path
is_legacy = 'engines' in path
is_list = request.url.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models']
@openai_error_handler
def do_POST(self):
if is_legacy and not is_list:
model_name = path[path.find('/v1/engines/') + len('/v1/engines/'):]
resp = OAImodels.load_model(model_name)
elif is_list:
resp = OAImodels.list_models(is_legacy)
else:
model_name = path[len('/v1/models/'):]
resp = OAImodels.model_info(model_name)
if '/v1/audio/transcriptions' in self.path:
r = sr.Recognizer()
return JSONResponse(content=resp)
# Parse the form data
form = cgi.FieldStorage(
fp=self.rfile,
headers=self.headers,
environ={'REQUEST_METHOD': 'POST', 'CONTENT_TYPE': self.headers['Content-Type']}
)
audio_file = form['file'].file
audio_data = AudioSegment.from_file(audio_file)
# Convert AudioSegment to raw data
raw_data = audio_data.raw_data
# Create AudioData object
audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
whipser_language = form.getvalue('language', None)
whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
transcription = {"text": ""}
try:
transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
except sr.UnknownValueError:
print("Whisper could not understand audio")
transcription["text"] = "Whisper could not understand audio UnknownValueError"
except sr.RequestError as e:
print("Could not request results from Whisper", e)
transcription["text"] = "Whisper could not understand audio RequestError"
self.return_json(transcription, no_debug=True)
return
debug_msg(self.requestline)
debug_msg(self.headers)
@app.get('/v1/billing/usage')
def handle_billing_usage():
'''
Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
'''
return JSONResponse(content={"total_usage": 0})
content_length = self.headers.get('Content-Length')
transfer_encoding = self.headers.get('Transfer-Encoding')
if content_length:
body = json.loads(self.rfile.read(int(content_length)).decode('utf-8'))
elif transfer_encoding == 'chunked':
chunks = []
while True:
chunk_size = int(self.rfile.readline(), 16) # Read the chunk size
if chunk_size == 0:
break # End of chunks
chunks.append(self.rfile.read(chunk_size))
self.rfile.readline() # Consume the trailing newline after each chunk
body = json.loads(b''.join(chunks).decode('utf-8'))
else:
self.send_response(400, "Bad Request: Either Content-Length or Transfer-Encoding header expected.")
self.end_headers()
return
@app.post('/v1/audio/transcriptions')
async def handle_audio_transcription(request: Request):
r = sr.Recognizer()
debug_msg(body)
form = await request.form()
audio_file = await form["file"].read()
audio_data = AudioSegment.from_file(audio_file)
if '/completions' in self.path or '/generate' in self.path:
# Convert AudioSegment to raw data
raw_data = audio_data.raw_data
if not shared.model:
raise ServiceUnavailableError("No model loaded.")
# Create AudioData object
audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
whipser_language = form.getvalue('language', None)
whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
is_legacy = '/generate' in self.path
is_streaming = body.get('stream', False)
transcription = {"text": ""}
if is_streaming:
self.start_sse()
try:
transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
except sr.UnknownValueError:
print("Whisper could not understand audio")
transcription["text"] = "Whisper could not understand audio UnknownValueError"
except sr.RequestError as e:
print("Could not request results from Whisper", e)
transcription["text"] = "Whisper could not understand audio RequestError"
response = []
if 'chat' in self.path:
response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
else:
response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)
return JSONResponse(content=transcription)
for resp in response:
self.send_sse(resp)
self.end_sse()
@app.post('/v1/images/generations')
async def handle_image_generation(request: Request):
else:
response = ''
if 'chat' in self.path:
response = OAIcompletions.chat_completions(body, is_legacy=is_legacy)
else:
response = OAIcompletions.completions(body, is_legacy=is_legacy)
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
self.return_json(response)
body = await request.json()
prompt = body['prompt']
size = body.get('size', '1024x1024')
response_format = body.get('response_format', 'url') # or b64_json
n = body.get('n', 1) # ignore the batch limits of max 10
elif '/edits' in self.path:
# deprecated
response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
return JSONResponse(response)
if not shared.model:
raise ServiceUnavailableError("No model loaded.")
req_params = get_default_req_params()
@app.post("/v1/embeddings")
async def handle_embeddings(request: Request):
body = await request.json()
encoding_format = body.get("encoding_format", "")
instruction = body['instruction']
input = body.get('input', '')
temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
input = body.get('input', body.get('text', ''))
if not input:
raise HTTPException(status_code=400, detail="Missing required argument input")
response = OAIedits.edits(instruction, input, temperature, top_p)
if type(input) is str:
input = [input]
self.return_json(response)
response = OAIembeddings.embeddings(input, encoding_format)
return JSONResponse(response)
elif '/images/generations' in self.path:
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
prompt = body['prompt']
size = default(body, 'size', '1024x1024')
response_format = default(body, 'response_format', 'url') # or b64_json
n = default(body, 'n', 1) # ignore the batch limits of max 10
@app.post("/v1/moderations")
async def handle_moderations(request: Request):
body = await request.json()
input = body["input"]
if not input:
raise HTTPException(status_code=400, detail="Missing required argument input")
response = OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
response = OAImoderations.moderations(input)
return JSONResponse(response)
self.return_json(response, no_debug=True)
elif '/embeddings' in self.path:
encoding_format = body.get('encoding_format', '')
@app.post("/api/v1/token-count")
async def handle_token_count(request: Request):
body = await request.json()
response = token_count(body['prompt'])
return JSONResponse(response)
input = body.get('input', body.get('text', ''))
if not input:
raise InvalidRequestError("Missing required argument input", params='input')
if type(input) is str:
input = [input]
@app.post("/api/v1/token/encode")
async def handle_token_encode(request: Request):
body = await request.json()
encoding_format = body.get("encoding_format", "")
response = token_encode(body["input"], encoding_format)
return JSONResponse(response)
response = OAIembeddings.embeddings(input, encoding_format)
self.return_json(response, no_debug=True)
elif '/moderations' in self.path:
input = body['input']
if not input:
raise InvalidRequestError("Missing required argument input", params='input')
response = OAImoderations.moderations(input)
self.return_json(response, no_debug=True)
elif self.path == '/api/v1/token-count':
# NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
response = token_count(body['prompt'])
self.return_json(response, no_debug=True)
elif self.path == '/api/v1/token/encode':
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
encoding_format = body.get('encoding_format', '')
response = token_encode(body['input'], encoding_format)
self.return_json(response, no_debug=True)
elif self.path == '/api/v1/token/decode':
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
encoding_format = body.get('encoding_format', '')
response = token_decode(body['input'], encoding_format)
self.return_json(response, no_debug=True)
else:
self.send_error(404)
@app.post("/api/v1/token/decode")
async def handle_token_decode(request: Request):
body = await request.json()
encoding_format = body.get("encoding_format", "")
response = token_decode(body["input"], encoding_format)
return JSONResponse(response, no_debug=True)
def run_server():
port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001)))
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port)
server = ThreadingHTTPServer(server_addr, Handler)
ssl_certfile=os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
ssl_keyfile=os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
ssl_verify=True if (ssl_keyfile and ssl_certfile) else False
if ssl_verify:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ssl_certfile, ssl_keyfile)
server.socket = context.wrap_socket(server.socket, server_side=True)
if shared.args.share:
try:
from flask_cloudflared import _run_cloudflared
public_url = _run_cloudflared(port, port + 1)
print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1')
except ImportError:
print('You should install flask_cloudflared manually')
server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
ssl_certfile = os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
ssl_keyfile = os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
if shared.args.public_api:
def on_start(public_url: str):
logger.info(f'OpenAI compatible API URL:\n\n{public_url}/v1\n')
_start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start)
else:
if ssl_verify:
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=https://{server_addr[0]}:{server_addr[1]}/v1')
if ssl_keyfile and ssl_certfile:
logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}/v1\n')
else:
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
server.serve_forever()
logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}/v1\n')
if shared.args.api_key:
logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')
uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
def setup():

125
extensions/openai/typing.py Normal file
View File

@ -0,0 +1,125 @@
import json
import time
from typing import List
from pydantic import BaseModel, Field
class GenerationOptions(BaseModel):
preset: str | None = None
temperature: float = 1
top_p: float = 1
min_p: float = 0
top_k: int = 0
repetition_penalty: float = 1
presence_penalty: float = 0
frequency_penalty: float = 0
repetition_penalty_range: int = 0
typical_p: float = 1
tfs: float = 1
top_a: float = 0
epsilon_cutoff: float = 0
eta_cutoff: float = 0
guidance_scale: float = 1
negative_prompt: str = ''
penalty_alpha: float = 0
mirostat_mode: int = 0
mirostat_tau: float = 5
mirostat_eta: float = 0.1
temperature_last: bool = False
do_sample: bool = True
seed: int = -1
encoder_repetition_penalty: float = 1
no_repeat_ngram_size: int = 0
min_length: int = 0
num_beams: int = 1
length_penalty: float = 1
early_stopping: bool = False
truncation_length: int = 0
max_tokens_second: int = 0
custom_token_bans: str = ""
auto_max_new_tokens: bool = False
ban_eos_token: bool = False
add_bos_token: bool = True
skip_special_tokens: bool = True
grammar_string: str = ""
class CompletionRequest(GenerationOptions):
model: str | None = None
prompt: str | List[str]
best_of: int | None = 1
echo: bool | None = False
frequency_penalty: float | None = 0
logit_bias: dict | None = None
logprobs: int | None = None
max_tokens: int | None = 16
n: int | None = 1
presence_penalty: int | None = 0
stop: str | List[str] | None = None
stream: bool | None = False
suffix: str | None = None
temperature: float | None = 1
top_p: float | None = 1
user: str | None = None
class CompletionResponse(BaseModel):
id: str
choices: List[dict]
created: int = int(time.time())
model: str
object: str = "text_completion"
usage: dict
class ChatCompletionRequest(GenerationOptions):
messages: List[dict]
model: str | None = None
frequency_penalty: float | None = 0
function_call: str | dict | None = None
functions: List[dict] | None = None
logit_bias: dict | None = None
max_tokens: int | None = None
n: int | None = 1
presence_penalty: int | None = 0
stop: str | List[str] | None = None
stream: bool | None = False
temperature: float | None = 1
top_p: float | None = 1
user: str | None = None
mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/instruction-templates. If not set, the correct template will be guessed using the regex expressions in models/config.yaml.")
name1_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
name2_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
context_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
turn_template: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
character: str | None = Field(default=None, description="A character defined under text-generation-webui/characters. If not set, the default \"Assistant\" character will be used.")
name1: str | None = Field(default=None, description="Overwrites the value set by character.")
name2: str | None = Field(default=None, description="Overwrites the value set by character.")
context: str | None = Field(default=None, description="Overwrites the value set by character.")
greeting: str | None = Field(default=None, description="Overwrites the value set by character.")
chat_instruct_command: str | None = None
continue_: bool = Field(default=False, description="Makes the last bot message in the history be continued instead of starting a new message.")
class ChatCompletionResponse(BaseModel):
id: str
choices: List[dict]
created: int = int(time.time())
model: str
object: str = "chat.completion"
usage: dict
def to_json(obj):
return json.dumps(obj.__dict__, indent=4)
def to_dict(obj):
return obj.__dict__

View File

@ -1,8 +1,12 @@
import base64
import os
import time
import traceback
from typing import Callable, Optional
import numpy as np
def float_list_to_base64(float_array: np.ndarray) -> str:
# Convert the list to a float32 array that the OpenAPI client expects
# float_array = np.array(float_list, dtype="float32")
@ -18,13 +22,33 @@ def float_list_to_base64(float_array: np.ndarray) -> str:
return ascii_string
def end_line(s):
if s and s[-1] != '\n':
s = s + '\n'
return s
def debug_msg(*args, **kwargs):
from extensions.openai.script import params
if os.environ.get("OPENEDAI_DEBUG", params.get('debug', 0)):
print(*args, **kwargs)
def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
try:
from flask_cloudflared import _run_cloudflared
except ImportError:
print('You should install flask_cloudflared manually')
raise Exception(
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
for _ in range(max_attempts):
try:
if tunnel_id is not None:
public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id)
else:
public_url = _run_cloudflared(port, port + 1)
if on_start:
on_start(public_url)
return
except Exception:
traceback.print_exc()
time.sleep(3)
raise Exception('Could not start cloudflared.')

View File

@ -81,7 +81,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Find the maximum prompt size
max_length = get_max_prompt_length(state)
all_substrings = {
'chat': get_turn_substrings(state, instruct=False),
'chat': get_turn_substrings(state, instruct=False) if state['mode'] in ['chat', 'chat-instruct'] else None,
'instruct': get_turn_substrings(state, instruct=True)
}
@ -237,7 +237,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True)):
# Extract the reply
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
visible_reply = reply
if state['mode'] in ['chat', 'chat-instruct']:
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
visible_reply = html.escape(visible_reply)
if shared.stop_everything:

View File

@ -71,11 +71,12 @@ def load_model(model_name, loader=None):
'AutoAWQ': AutoAWQ_loader,
}
metadata = get_model_metadata(model_name)
if loader is None:
if shared.args.loader is not None:
loader = shared.args.loader
else:
loader = get_model_metadata(model_name)['loader']
loader = metadata['loader']
if loader is None:
logger.error('The path to the model does not exist. Exiting.')
return None, None
@ -95,6 +96,7 @@ def load_model(model_name, loader=None):
if any((shared.args.xformers, shared.args.sdp_attention)):
llama_attn_hijack.hijack_llama_attention()
shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer

View File

@ -6,33 +6,32 @@ import yaml
def default_preset():
return {
'do_sample': True,
'temperature': 1,
'temperature_last': False,
'top_p': 1,
'min_p': 0,
'top_k': 0,
'typical_p': 1,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1,
'presence_penalty': 0,
'frequency_penalty': 0,
'repetition_penalty_range': 0,
'typical_p': 1,
'tfs': 1,
'top_a': 0,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'guidance_scale': 1,
'penalty_alpha': 0,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'do_sample': True,
'encoder_repetition_penalty': 1,
'no_repeat_ngram_size': 0,
'min_length': 0,
'guidance_scale': 1,
'mirostat_mode': 0,
'mirostat_tau': 5.0,
'mirostat_eta': 0.1,
'penalty_alpha': 0,
'num_beams': 1,
'length_penalty': 1,
'early_stopping': False,
'custom_token_bans': '',
}

View File

@ -39,21 +39,21 @@ settings = {
'max_new_tokens': 200,
'max_new_tokens_min': 1,
'max_new_tokens_max': 4096,
'seed': -1,
'negative_prompt': '',
'seed': -1,
'truncation_length': 2048,
'truncation_length_min': 0,
'truncation_length_max': 32768,
'custom_stopping_strings': '',
'auto_max_new_tokens': False,
'max_tokens_second': 0,
'ban_eos_token': False,
'custom_stopping_strings': '',
'custom_token_bans': '',
'auto_max_new_tokens': False,
'ban_eos_token': False,
'add_bos_token': True,
'skip_special_tokens': True,
'stream': True,
'name1': 'You',
'character': 'Assistant',
'name1': 'You',
'instruction_template': 'Alpaca',
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
'autoload_model': False,
@ -167,8 +167,8 @@ parser.add_argument('--ssl-certfile', type=str, help='The path to the SSL certif
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
parser.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.')
parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
parser.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
parser.add_argument('--api-key', type=str, default='', help='API authentication key.')
# Multimodal
parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')
@ -178,6 +178,8 @@ parser.add_argument('--notebook', action='store_true', help='DEPRECATED')
parser.add_argument('--chat', action='store_true', help='DEPRECATED')
parser.add_argument('--no-stream', action='store_true', help='DEPRECATED')
parser.add_argument('--mul_mat_q', action='store_true', help='DEPRECATED')
parser.add_argument('--api-blocking-port', type=int, default=5000, help='DEPRECATED')
parser.add_argument('--api-streaming-port', type=int, default=5005, help='DEPRECATED')
args = parser.parse_args()
args_defaults = parser.parse_args([])
@ -233,10 +235,13 @@ def fix_loader_name(name):
return 'AutoAWQ'
def add_extension(name):
def add_extension(name, last=False):
if args.extensions is None:
args.extensions = [name]
elif 'api' not in args.extensions:
elif last:
args.extensions = [x for x in args.extensions if x != name]
args.extensions.append(name)
elif name not in args.extensions:
args.extensions.append(name)
@ -246,14 +251,15 @@ def is_chat():
args.loader = fix_loader_name(args.loader)
# Activate the API extension
if args.api or args.public_api:
add_extension('api')
# Activate the multimodal extension
if args.multimodal_pipeline is not None:
add_extension('multimodal')
# Activate the API extension
if args.api:
# add_extension('openai', last=True)
add_extension('api', last=True)
# Load model-specific settings
with Path(f'{args.model_dir}/config.yaml') as p:
if p.exists():

View File

@ -56,7 +56,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
# Find the stopping strings
all_stop_strings = []
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
for st in (stopping_strings, state['custom_stopping_strings']):
if type(st) is str:
st = ast.literal_eval(f"[{st}]")
if type(st) is list and len(st) > 0:
all_stop_strings += st

View File

@ -215,9 +215,6 @@ def load_model_wrapper(selected_model, loader, autoload=False):
if 'instruction_template' in settings:
output += '\n\nIt seems to be an instruction-following model with template "{}". In the chat tab, instruct or chat-instruct modes should be used.'.format(settings['instruction_template'])
# Applying the changes to the global shared settings (in-memory)
shared.settings.update({k: v for k, v in settings.items() if k in shared.settings})
yield output
else:
yield f"Failed to load `{selected_model}`."