mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Make OpenAI API the default API (#4430)
This commit is contained in:
parent
84d957ba62
commit
ec17a5d2b7
@ -22,7 +22,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
|||||||
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab#character)
|
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab#character)
|
||||||
* Very efficient text streaming
|
* Very efficient text streaming
|
||||||
* Markdown output with LaTeX rendering, to use for instance with [GALACTICA](https://github.com/paperswithcode/galai)
|
* Markdown output with LaTeX rendering, to use for instance with [GALACTICA](https://github.com/paperswithcode/galai)
|
||||||
* API, including endpoints for websocket streaming ([see the examples](https://github.com/oobabooga/text-generation-webui/blob/main/api-examples))
|
* OpenAI-compatible API server
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
@ -412,8 +412,8 @@ Optionally, you can use the following command-line flags:
|
|||||||
| `--api` | Enable the API extension. |
|
| `--api` | Enable the API extension. |
|
||||||
| `--public-api` | Create a public URL for the API using Cloudfare. |
|
| `--public-api` | Create a public URL for the API using Cloudfare. |
|
||||||
| `--public-api-id PUBLIC_API_ID` | Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. |
|
| `--public-api-id PUBLIC_API_ID` | Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. |
|
||||||
| `--api-blocking-port BLOCKING_PORT` | The listening port for the blocking API. |
|
| `--api-port API_PORT` | The listening port for the API. |
|
||||||
| `--api-streaming-port STREAMING_PORT` | The listening port for the streaming API. |
|
| `--api-key API_KEY` | API authentication key. |
|
||||||
|
|
||||||
#### Multimodal
|
#### Multimodal
|
||||||
|
|
||||||
|
@ -1,114 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import html
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
|
|
||||||
try:
|
|
||||||
import websockets
|
|
||||||
except ImportError:
|
|
||||||
print("Websockets package not found. Make sure it's installed.")
|
|
||||||
|
|
||||||
# For local streaming, the websockets are hosted without ssl - ws://
|
|
||||||
HOST = 'localhost:5005'
|
|
||||||
URI = f'ws://{HOST}/api/v1/chat-stream'
|
|
||||||
|
|
||||||
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
|
|
||||||
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
|
|
||||||
|
|
||||||
|
|
||||||
async def run(user_input, history):
|
|
||||||
# Note: the selected defaults change from time to time.
|
|
||||||
request = {
|
|
||||||
'user_input': user_input,
|
|
||||||
'max_new_tokens': 250,
|
|
||||||
'auto_max_new_tokens': False,
|
|
||||||
'max_tokens_second': 0,
|
|
||||||
'history': history,
|
|
||||||
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
|
||||||
'character': 'Example',
|
|
||||||
'instruction_template': 'Vicuna-v1.1', # Will get autodetected if unset
|
|
||||||
'your_name': 'You',
|
|
||||||
# 'name1': 'name of user', # Optional
|
|
||||||
# 'name2': 'name of character', # Optional
|
|
||||||
# 'context': 'character context', # Optional
|
|
||||||
# 'greeting': 'greeting', # Optional
|
|
||||||
# 'name1_instruct': 'You', # Optional
|
|
||||||
# 'name2_instruct': 'Assistant', # Optional
|
|
||||||
# 'context_instruct': 'context_instruct', # Optional
|
|
||||||
# 'turn_template': 'turn_template', # Optional
|
|
||||||
'regenerate': False,
|
|
||||||
'_continue': False,
|
|
||||||
'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
|
|
||||||
|
|
||||||
# Generation params. If 'preset' is set to different than 'None', the values
|
|
||||||
# in presets/preset-name.yaml are used instead of the individual numbers.
|
|
||||||
'preset': 'None',
|
|
||||||
'do_sample': True,
|
|
||||||
'temperature': 0.7,
|
|
||||||
'top_p': 0.1,
|
|
||||||
'typical_p': 1,
|
|
||||||
'epsilon_cutoff': 0, # In units of 1e-4
|
|
||||||
'eta_cutoff': 0, # In units of 1e-4
|
|
||||||
'tfs': 1,
|
|
||||||
'top_a': 0,
|
|
||||||
'repetition_penalty': 1.18,
|
|
||||||
'presence_penalty': 0,
|
|
||||||
'frequency_penalty': 0,
|
|
||||||
'repetition_penalty_range': 0,
|
|
||||||
'top_k': 40,
|
|
||||||
'min_length': 0,
|
|
||||||
'no_repeat_ngram_size': 0,
|
|
||||||
'num_beams': 1,
|
|
||||||
'penalty_alpha': 0,
|
|
||||||
'length_penalty': 1,
|
|
||||||
'early_stopping': False,
|
|
||||||
'mirostat_mode': 0,
|
|
||||||
'mirostat_tau': 5,
|
|
||||||
'mirostat_eta': 0.1,
|
|
||||||
'grammar_string': '',
|
|
||||||
'guidance_scale': 1,
|
|
||||||
'negative_prompt': '',
|
|
||||||
|
|
||||||
'seed': -1,
|
|
||||||
'add_bos_token': True,
|
|
||||||
'truncation_length': 2048,
|
|
||||||
'ban_eos_token': False,
|
|
||||||
'custom_token_bans': '',
|
|
||||||
'skip_special_tokens': True,
|
|
||||||
'stopping_strings': []
|
|
||||||
}
|
|
||||||
|
|
||||||
async with websockets.connect(URI, ping_interval=None) as websocket:
|
|
||||||
await websocket.send(json.dumps(request))
|
|
||||||
|
|
||||||
while True:
|
|
||||||
incoming_data = await websocket.recv()
|
|
||||||
incoming_data = json.loads(incoming_data)
|
|
||||||
|
|
||||||
match incoming_data['event']:
|
|
||||||
case 'text_stream':
|
|
||||||
yield incoming_data['history']
|
|
||||||
case 'stream_end':
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
async def print_response_stream(user_input, history):
|
|
||||||
cur_len = 0
|
|
||||||
async for new_history in run(user_input, history):
|
|
||||||
cur_message = new_history['visible'][-1][1][cur_len:]
|
|
||||||
cur_len += len(cur_message)
|
|
||||||
print(html.unescape(cur_message), end='')
|
|
||||||
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."
|
|
||||||
|
|
||||||
# Basic example
|
|
||||||
history = {'internal': [], 'visible': []}
|
|
||||||
|
|
||||||
# "Continue" example. Make sure to set '_continue' to True above
|
|
||||||
# arr = [user_input, 'Surely, here is']
|
|
||||||
# history = {'internal': [arr], 'visible': [arr]}
|
|
||||||
|
|
||||||
asyncio.run(print_response_stream(user_input, history))
|
|
@ -1,94 +0,0 @@
|
|||||||
import html
|
|
||||||
import json
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
# For local streaming, the websockets are hosted without ssl - http://
|
|
||||||
HOST = 'localhost:5000'
|
|
||||||
URI = f'http://{HOST}/api/v1/chat'
|
|
||||||
|
|
||||||
# For reverse-proxied streaming, the remote will likely host with ssl - https://
|
|
||||||
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/chat'
|
|
||||||
|
|
||||||
|
|
||||||
def run(user_input, history):
|
|
||||||
request = {
|
|
||||||
'user_input': user_input,
|
|
||||||
'max_new_tokens': 250,
|
|
||||||
'auto_max_new_tokens': False,
|
|
||||||
'max_tokens_second': 0,
|
|
||||||
'history': history,
|
|
||||||
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
|
||||||
'character': 'Example',
|
|
||||||
'instruction_template': 'Vicuna-v1.1', # Will get autodetected if unset
|
|
||||||
'your_name': 'You',
|
|
||||||
# 'name1': 'name of user', # Optional
|
|
||||||
# 'name2': 'name of character', # Optional
|
|
||||||
# 'context': 'character context', # Optional
|
|
||||||
# 'greeting': 'greeting', # Optional
|
|
||||||
# 'name1_instruct': 'You', # Optional
|
|
||||||
# 'name2_instruct': 'Assistant', # Optional
|
|
||||||
# 'context_instruct': 'context_instruct', # Optional
|
|
||||||
# 'turn_template': 'turn_template', # Optional
|
|
||||||
'regenerate': False,
|
|
||||||
'_continue': False,
|
|
||||||
'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
|
|
||||||
|
|
||||||
# Generation params. If 'preset' is set to different than 'None', the values
|
|
||||||
# in presets/preset-name.yaml are used instead of the individual numbers.
|
|
||||||
'preset': 'None',
|
|
||||||
'do_sample': True,
|
|
||||||
'temperature': 0.7,
|
|
||||||
'top_p': 0.1,
|
|
||||||
'typical_p': 1,
|
|
||||||
'epsilon_cutoff': 0, # In units of 1e-4
|
|
||||||
'eta_cutoff': 0, # In units of 1e-4
|
|
||||||
'tfs': 1,
|
|
||||||
'top_a': 0,
|
|
||||||
'repetition_penalty': 1.18,
|
|
||||||
'presence_penalty': 0,
|
|
||||||
'frequency_penalty': 0,
|
|
||||||
'repetition_penalty_range': 0,
|
|
||||||
'top_k': 40,
|
|
||||||
'min_length': 0,
|
|
||||||
'no_repeat_ngram_size': 0,
|
|
||||||
'num_beams': 1,
|
|
||||||
'penalty_alpha': 0,
|
|
||||||
'length_penalty': 1,
|
|
||||||
'early_stopping': False,
|
|
||||||
'mirostat_mode': 0,
|
|
||||||
'mirostat_tau': 5,
|
|
||||||
'mirostat_eta': 0.1,
|
|
||||||
'grammar_string': '',
|
|
||||||
'guidance_scale': 1,
|
|
||||||
'negative_prompt': '',
|
|
||||||
|
|
||||||
'seed': -1,
|
|
||||||
'add_bos_token': True,
|
|
||||||
'truncation_length': 2048,
|
|
||||||
'ban_eos_token': False,
|
|
||||||
'custom_token_bans': '',
|
|
||||||
'skip_special_tokens': True,
|
|
||||||
'stopping_strings': []
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(URI, json=request)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()['results'][0]['history']
|
|
||||||
print(json.dumps(result, indent=4))
|
|
||||||
print()
|
|
||||||
print(html.unescape(result['visible'][-1][1]))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."
|
|
||||||
|
|
||||||
# Basic example
|
|
||||||
history = {'internal': [], 'visible': []}
|
|
||||||
|
|
||||||
# "Continue" example. Make sure to set '_continue' to True above
|
|
||||||
# arr = [user_input, 'Surely, here is']
|
|
||||||
# history = {'internal': [arr], 'visible': [arr]}
|
|
||||||
|
|
||||||
run(user_input, history)
|
|
@ -1,88 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
|
|
||||||
try:
|
|
||||||
import websockets
|
|
||||||
except ImportError:
|
|
||||||
print("Websockets package not found. Make sure it's installed.")
|
|
||||||
|
|
||||||
# For local streaming, the websockets are hosted without ssl - ws://
|
|
||||||
HOST = 'localhost:5005'
|
|
||||||
URI = f'ws://{HOST}/api/v1/stream'
|
|
||||||
|
|
||||||
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
|
|
||||||
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
|
|
||||||
|
|
||||||
|
|
||||||
async def run(context):
|
|
||||||
# Note: the selected defaults change from time to time.
|
|
||||||
request = {
|
|
||||||
'prompt': context,
|
|
||||||
'max_new_tokens': 250,
|
|
||||||
'auto_max_new_tokens': False,
|
|
||||||
'max_tokens_second': 0,
|
|
||||||
|
|
||||||
# Generation params. If 'preset' is set to different than 'None', the values
|
|
||||||
# in presets/preset-name.yaml are used instead of the individual numbers.
|
|
||||||
'preset': 'None',
|
|
||||||
'do_sample': True,
|
|
||||||
'temperature': 0.7,
|
|
||||||
'top_p': 0.1,
|
|
||||||
'typical_p': 1,
|
|
||||||
'epsilon_cutoff': 0, # In units of 1e-4
|
|
||||||
'eta_cutoff': 0, # In units of 1e-4
|
|
||||||
'tfs': 1,
|
|
||||||
'top_a': 0,
|
|
||||||
'repetition_penalty': 1.18,
|
|
||||||
'presence_penalty': 0,
|
|
||||||
'frequency_penalty': 0,
|
|
||||||
'repetition_penalty_range': 0,
|
|
||||||
'top_k': 40,
|
|
||||||
'min_length': 0,
|
|
||||||
'no_repeat_ngram_size': 0,
|
|
||||||
'num_beams': 1,
|
|
||||||
'penalty_alpha': 0,
|
|
||||||
'length_penalty': 1,
|
|
||||||
'early_stopping': False,
|
|
||||||
'mirostat_mode': 0,
|
|
||||||
'mirostat_tau': 5,
|
|
||||||
'mirostat_eta': 0.1,
|
|
||||||
'grammar_string': '',
|
|
||||||
'guidance_scale': 1,
|
|
||||||
'negative_prompt': '',
|
|
||||||
|
|
||||||
'seed': -1,
|
|
||||||
'add_bos_token': True,
|
|
||||||
'truncation_length': 2048,
|
|
||||||
'ban_eos_token': False,
|
|
||||||
'custom_token_bans': '',
|
|
||||||
'skip_special_tokens': True,
|
|
||||||
'stopping_strings': []
|
|
||||||
}
|
|
||||||
|
|
||||||
async with websockets.connect(URI, ping_interval=None) as websocket:
|
|
||||||
await websocket.send(json.dumps(request))
|
|
||||||
|
|
||||||
yield context # Remove this if you just want to see the reply
|
|
||||||
|
|
||||||
while True:
|
|
||||||
incoming_data = await websocket.recv()
|
|
||||||
incoming_data = json.loads(incoming_data)
|
|
||||||
|
|
||||||
match incoming_data['event']:
|
|
||||||
case 'text_stream':
|
|
||||||
yield incoming_data['text']
|
|
||||||
case 'stream_end':
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
async def print_response_stream(prompt):
|
|
||||||
async for response in run(prompt):
|
|
||||||
print(response, end='')
|
|
||||||
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
prompt = "In order to make homemade bread, follow these steps:\n1)"
|
|
||||||
asyncio.run(print_response_stream(prompt))
|
|
@ -1,65 +0,0 @@
|
|||||||
import requests
|
|
||||||
|
|
||||||
# For local streaming, the websockets are hosted without ssl - http://
|
|
||||||
HOST = 'localhost:5000'
|
|
||||||
URI = f'http://{HOST}/api/v1/generate'
|
|
||||||
|
|
||||||
# For reverse-proxied streaming, the remote will likely host with ssl - https://
|
|
||||||
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'
|
|
||||||
|
|
||||||
|
|
||||||
def run(prompt):
|
|
||||||
request = {
|
|
||||||
'prompt': prompt,
|
|
||||||
'max_new_tokens': 250,
|
|
||||||
'auto_max_new_tokens': False,
|
|
||||||
'max_tokens_second': 0,
|
|
||||||
|
|
||||||
# Generation params. If 'preset' is set to different than 'None', the values
|
|
||||||
# in presets/preset-name.yaml are used instead of the individual numbers.
|
|
||||||
'preset': 'None',
|
|
||||||
'do_sample': True,
|
|
||||||
'temperature': 0.7,
|
|
||||||
'top_p': 0.1,
|
|
||||||
'typical_p': 1,
|
|
||||||
'epsilon_cutoff': 0, # In units of 1e-4
|
|
||||||
'eta_cutoff': 0, # In units of 1e-4
|
|
||||||
'tfs': 1,
|
|
||||||
'top_a': 0,
|
|
||||||
'repetition_penalty': 1.18,
|
|
||||||
'presence_penalty': 0,
|
|
||||||
'frequency_penalty': 0,
|
|
||||||
'repetition_penalty_range': 0,
|
|
||||||
'top_k': 40,
|
|
||||||
'min_length': 0,
|
|
||||||
'no_repeat_ngram_size': 0,
|
|
||||||
'num_beams': 1,
|
|
||||||
'penalty_alpha': 0,
|
|
||||||
'length_penalty': 1,
|
|
||||||
'early_stopping': False,
|
|
||||||
'mirostat_mode': 0,
|
|
||||||
'mirostat_tau': 5,
|
|
||||||
'mirostat_eta': 0.1,
|
|
||||||
'grammar_string': '',
|
|
||||||
'guidance_scale': 1,
|
|
||||||
'negative_prompt': '',
|
|
||||||
|
|
||||||
'seed': -1,
|
|
||||||
'add_bos_token': True,
|
|
||||||
'truncation_length': 2048,
|
|
||||||
'ban_eos_token': False,
|
|
||||||
'custom_token_bans': '',
|
|
||||||
'skip_special_tokens': True,
|
|
||||||
'stopping_strings': []
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(URI, json=request)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()['results'][0]['text']
|
|
||||||
print(prompt + result)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
prompt = "In order to make homemade bread, follow these steps:\n1)"
|
|
||||||
run(prompt)
|
|
@ -1,124 +1,64 @@
|
|||||||
# An OpenedAI API (openai like)
|
## OpenAI compatible API
|
||||||
|
|
||||||
This extension creates an API that works kind of like openai (ie. api.openai.com).
|
This project includes an API compatible with multiple OpenAI endpoints, including Chat and Completions.
|
||||||
|
|
||||||
## Setup & installation
|
If you did not use the one-click installers, you may need to install the requirements first:
|
||||||
|
|
||||||
Install the requirements:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
pip3 install -r requirements.txt
|
pip install -r extensions/openai/requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
It listens on `tcp port 5001` by default. You can use the `OPENEDAI_PORT` environment variable to change this.
|
### Starting the API
|
||||||
|
|
||||||
Make sure you enable it in server launch parameters, it should include:
|
Add `--extensions openai` to your command-line flags.
|
||||||
|
|
||||||
|
* To create a public Cloudflare URL, add the `--public-api` flag.
|
||||||
|
* To listen on your local network, add the `--listen` flag.
|
||||||
|
* To change the port, which is 5000 by default, use `--port 1234` (change 1234 to your desired port number).
|
||||||
|
* To use SSL, add `--ssl-keyfile key.pem --ssl-certfile cert.pem`. Note that it doesn't work with `--public-api`.
|
||||||
|
|
||||||
|
#### Environment variables
|
||||||
|
|
||||||
|
The following environment variables can be used (they take precendence over everything else):
|
||||||
|
|
||||||
|
| Variable Name | Description | Example Value |
|
||||||
|
|------------------------|------------------------------------|----------------------------|
|
||||||
|
| `OPENEDAI_PORT` | Port number | 5000 |
|
||||||
|
| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem |
|
||||||
|
| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
|
||||||
|
| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
|
||||||
|
| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
|
||||||
|
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | all-mpnet-base-v2 |
|
||||||
|
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
|
||||||
|
|
||||||
|
#### Persistent settings with `settings.yaml`
|
||||||
|
|
||||||
|
You can also set default values by adding these lines to your `settings.yaml` file:
|
||||||
|
|
||||||
```
|
```
|
||||||
--extensions openai
|
|
||||||
```
|
|
||||||
|
|
||||||
You can also use the `--listen` argument to make the server available on the networ, and/or the `--share` argument to enable a public Cloudflare endpoint.
|
|
||||||
|
|
||||||
To enable the basic image generation support (txt2img) set the environment variable `SD_WEBUI_URL` to point to your Stable Diffusion API ([Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)).
|
|
||||||
|
|
||||||
For example:
|
|
||||||
|
|
||||||
```
|
|
||||||
SD_WEBUI_URL=http://127.0.0.1:7861
|
|
||||||
```
|
|
||||||
|
|
||||||
## Quick start
|
|
||||||
|
|
||||||
1. Install the requirements.txt (pip)
|
|
||||||
2. Enable the `openeai` module (--extensions openai), restart the server.
|
|
||||||
3. Configure the openai client
|
|
||||||
|
|
||||||
Most openai application can be configured to connect the API if you set the following environment variables:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
# Sample .env file:
|
|
||||||
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
|
|
||||||
OPENAI_API_BASE=http://0.0.0.0:5001/v1
|
|
||||||
```
|
|
||||||
|
|
||||||
If needed, replace 0.0.0.0 with the IP/port of your server.
|
|
||||||
|
|
||||||
|
|
||||||
### Settings
|
|
||||||
|
|
||||||
To adjust your default settings, you can add the following to your `settings.yaml` file.
|
|
||||||
|
|
||||||
```
|
|
||||||
openai-port: 5002
|
|
||||||
openai-embedding_device: cuda
|
openai-embedding_device: cuda
|
||||||
|
openai-embedding_model: all-mpnet-base-v2
|
||||||
openai-sd_webui_url: http://127.0.0.1:7861
|
openai-sd_webui_url: http://127.0.0.1:7861
|
||||||
openai-debug: 1
|
openai-debug: 1
|
||||||
```
|
```
|
||||||
|
|
||||||
If you've configured the environment variables, please note that settings from `settings.yaml` won't take effect. For instance, if you set `openai-port: 5002` in `settings.yaml` but `OPENEDAI_PORT=5001` in the environment variables, the extension will use `5001` as the port number.
|
### Examples
|
||||||
|
|
||||||
When using `cache_embedding_model.py` to preload the embedding model during Docker image building, consider the following:
|
|
||||||
|
|
||||||
- If you wish to use the default settings, leave the environment variables unset.
|
|
||||||
- If you intend to change the default embedding model, ensure that you configure the environment variable `OPENEDAI_EMBEDDING_MODEL` to the desired model. Avoid setting `openai-embedding_model` in `settings.yaml` because those settings only take effect after the server starts.
|
|
||||||
|
|
||||||
### Models
|
|
||||||
|
|
||||||
This has been successfully tested with Alpaca, Koala, Vicuna, WizardLM and their variants, (ex. gpt4-x-alpaca, GPT4all-snoozy, stable-vicuna, wizard-vicuna, etc.) and many others. Models that have been trained for **Instruction Following** work best. If you test with other models please let me know how it goes. Less than satisfying results (so far) from: RWKV-4-Raven, llama, mpt-7b-instruct/chat.
|
|
||||||
|
|
||||||
For best results across all API endpoints, a model like [vicuna-13b-v1.3-GPTQ](https://huggingface.co/TheBloke/vicuna-13b-v1.3-GPTQ), [stable-vicuna-13B-GPTQ](https://huggingface.co/TheBloke/stable-vicuna-13B-GPTQ) or [airoboros-13B-gpt4-1.3-GPTQ](https://huggingface.co/TheBloke/airoboros-13B-gpt4-1.3-GPTQ) is a good start.
|
|
||||||
|
|
||||||
For good results with the [Completions](https://platform.openai.com/docs/api-reference/completions) API endpoint, in addition to the above models, you can also try using a base model like [falcon-7b](https://huggingface.co/tiiuae/falcon-7b) or Llama.
|
|
||||||
|
|
||||||
For good results with the [ChatCompletions](https://platform.openai.com/docs/api-reference/chat) or [Edits](https://platform.openai.com/docs/api-reference/edits) API endpoints you can use almost any model trained for instruction following. Be sure that the proper instruction template is detected and loaded or the results will not be good.
|
|
||||||
|
|
||||||
For the proper instruction format to be detected you need to have a matching model entry in your `models/config.yaml` file. Be sure to keep this file up to date.
|
|
||||||
A matching instruction template file in the characters/instruction-following/ folder will loaded and applied to format messages correctly for the model - this is critical for good results.
|
|
||||||
|
|
||||||
For example, the Wizard-Vicuna family of models are trained with the Vicuna 1.1 format. In the models/config.yaml file there is this matching entry:
|
|
||||||
|
|
||||||
```
|
|
||||||
.*wizard.*vicuna:
|
|
||||||
mode: 'instruct'
|
|
||||||
instruction_template: 'Vicuna-v1.1'
|
|
||||||
```
|
|
||||||
|
|
||||||
This refers to `characters/instruction-following/Vicuna-v1.1.yaml`, which looks like this:
|
|
||||||
|
|
||||||
```
|
|
||||||
user: "USER:"
|
|
||||||
bot: "ASSISTANT:"
|
|
||||||
turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|></s>\n"
|
|
||||||
context: "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n"
|
|
||||||
```
|
|
||||||
|
|
||||||
For most common models this is already setup, but if you are using a new or uncommon model you may need add a matching entry to the models/config.yaml and possibly create your own instruction-following template and for best results.
|
|
||||||
|
|
||||||
If you see this in your logs, it probably means that the correct format could not be loaded:
|
|
||||||
|
|
||||||
```
|
|
||||||
Warning: Loaded default instruction-following template for model.
|
|
||||||
```
|
|
||||||
|
|
||||||
### Embeddings (alpha)
|
|
||||||
|
|
||||||
Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default `text-embedding-ada-002` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
|
|
||||||
|
|
||||||
| model name | dimensions | input max tokens | speed | size | Avg. performance |
|
|
||||||
| ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- |
|
|
||||||
| text-embedding-ada-002 | 1536 | 8192 | - | - | - |
|
|
||||||
| text-davinci-002 | 768 | 2046 | - | - | - |
|
|
||||||
| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 |
|
|
||||||
| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 |
|
|
||||||
|
|
||||||
In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable `OPENEDAI_EMBEDDING_MODEL`, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2".
|
|
||||||
|
|
||||||
Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable.
|
|
||||||
|
|
||||||
### Client Application Setup
|
### Client Application Setup
|
||||||
|
|
||||||
Almost everything you use it with will require you to set a dummy OpenAI API key environment variable.
|
|
||||||
|
You can usually force an application that uses the OpenAI API to connect to the local API by using the following environment variables:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
OPENAI_API_HOST=http://127.0.0.1:5000
|
||||||
|
```
|
||||||
|
|
||||||
|
or
|
||||||
|
|
||||||
|
```shell
|
||||||
|
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
|
||||||
|
OPENAI_API_BASE=http://127.0.0.1:500/v1
|
||||||
|
```
|
||||||
|
|
||||||
With the [official python openai client](https://github.com/openai/openai-python), set the `OPENAI_API_BASE` environment variables:
|
With the [official python openai client](https://github.com/openai/openai-python), set the `OPENAI_API_BASE` environment variables:
|
||||||
|
|
||||||
@ -128,7 +68,7 @@ OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
|
|||||||
OPENAI_API_BASE=http://0.0.0.0:5001/v1
|
OPENAI_API_BASE=http://0.0.0.0:5001/v1
|
||||||
```
|
```
|
||||||
|
|
||||||
If needed, replace 0.0.0.0 with the IP/port of your server.
|
If needed, replace 127.0.0.1 with the IP/port of your server.
|
||||||
|
|
||||||
If using .env files to save the `OPENAI_API_BASE` and `OPENAI_API_KEY` variables, make sure the .env file is loaded before the openai module is imported:
|
If using .env files to save the `OPENAI_API_BASE` and `OPENAI_API_KEY` variables, make sure the .env file is loaded before the openai module is imported:
|
||||||
|
|
||||||
@ -157,8 +97,22 @@ const api = new ChatGPTAPI({
|
|||||||
apiBaseUrl: process.env.OPENAI_API_BASE
|
apiBaseUrl: process.env.OPENAI_API_BASE
|
||||||
});
|
});
|
||||||
```
|
```
|
||||||
|
### Embeddings (alpha)
|
||||||
|
|
||||||
## API Documentation & Examples
|
Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default `text-embedding-ada-002` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
|
||||||
|
|
||||||
|
| model name | dimensions | input max tokens | speed | size | Avg. performance |
|
||||||
|
| ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- |
|
||||||
|
| text-embedding-ada-002 | 1536 | 8192 | - | - | - |
|
||||||
|
| text-davinci-002 | 768 | 2046 | - | - | - |
|
||||||
|
| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 |
|
||||||
|
| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 |
|
||||||
|
|
||||||
|
In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable `OPENEDAI_EMBEDDING_MODEL`, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2".
|
||||||
|
|
||||||
|
Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable.
|
||||||
|
|
||||||
|
### API Documentation & Examples
|
||||||
|
|
||||||
The OpenAI API is well documented, you can view the documentation here: https://platform.openai.com/docs/api-reference
|
The OpenAI API is well documented, you can view the documentation here: https://platform.openai.com/docs/api-reference
|
||||||
|
|
||||||
@ -185,7 +139,7 @@ text = response['choices'][0]['message']['content']
|
|||||||
print(text)
|
print(text)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Compatibility & not so compatibility
|
### Compatibility & not so compatibility
|
||||||
|
|
||||||
| API endpoint | tested with | notes |
|
| API endpoint | tested with | notes |
|
||||||
| ------------------------- | ---------------------------------- | --------------------------------------------------------------------------- |
|
| ------------------------- | ---------------------------------- | --------------------------------------------------------------------------- |
|
||||||
@ -195,7 +149,7 @@ print(text)
|
|||||||
| /v1/moderations | openai.Moderation.create() | Basic initial support via embeddings |
|
| /v1/moderations | openai.Moderation.create() | Basic initial support via embeddings |
|
||||||
| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
|
| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
|
||||||
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
|
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
|
||||||
| /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models |
|
| /v1/edits | openai.Edit.create() | Removed, use /v1/chat/completions instead |
|
||||||
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model |
|
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model |
|
||||||
| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
|
| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
|
||||||
| /v1/engines/\*/embeddings | python-openai v0.25 | Legacy endpoint |
|
| /v1/engines/\*/embeddings | python-openai v0.25 | Legacy endpoint |
|
||||||
@ -209,28 +163,8 @@ print(text)
|
|||||||
| /v1/fine-tunes\* | openai.FineTune.\* | not yet supported |
|
| /v1/fine-tunes\* | openai.FineTune.\* | not yet supported |
|
||||||
| /v1/search | openai.search, engines.search | not yet supported |
|
| /v1/search | openai.search, engines.search | not yet supported |
|
||||||
|
|
||||||
Because of the differences in OpenAI model context sizes (2k, 4k, 8k, 16k, etc,) you may need to adjust the max_tokens to fit into the context of the model you choose.
|
|
||||||
|
|
||||||
Streaming, temperature, top_p, max_tokens, stop, should all work as expected, but not all parameters are mapped correctly.
|
#### Applications
|
||||||
|
|
||||||
Some hacky mappings:
|
|
||||||
|
|
||||||
| OpenAI | text-generation-webui | note |
|
|
||||||
| ----------------------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
|
||||||
| model | - | Ignored, the model is not changed |
|
|
||||||
| frequency_penalty | encoder_repetition_penalty | this seems to operate with a different scale and defaults, I tried to scale it based on range & defaults, but the results are terrible. hardcoded to 1.18 until there is a better way |
|
|
||||||
| presence_penalty | repetition_penalty | same issues as frequency_penalty, hardcoded to 1.0 |
|
|
||||||
| best_of | top_k | default is 1 (top_k is 20 for chat, which doesn't support best_of) |
|
|
||||||
| n | 1 | variations are not supported yet. |
|
|
||||||
| 1 | num_beams | hardcoded to 1 |
|
|
||||||
| 1.0 | typical_p | hardcoded to 1.0 |
|
|
||||||
| logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' |
|
|
||||||
| messages.name | - | not supported yet |
|
|
||||||
| suffix | - | not supported yet |
|
|
||||||
| user | - | not supported yet |
|
|
||||||
| functions/function_call | - | function calls are not supported yet |
|
|
||||||
|
|
||||||
### Applications
|
|
||||||
|
|
||||||
Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment variable set, but there are some exceptions.
|
Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment variable set, but there are some exceptions.
|
||||||
|
|
||||||
@ -249,15 +183,3 @@ Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment v
|
|||||||
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
|
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
|
||||||
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
||||||
| ❌ | guidance | https://github.com/microsoft/guidance | logit_bias and logprobs not yet supported |
|
| ❌ | guidance | https://github.com/microsoft/guidance | logit_bias and logprobs not yet supported |
|
||||||
|
|
||||||
## Future plans
|
|
||||||
|
|
||||||
- better error handling
|
|
||||||
- model changing, esp. something for swapping loras or embedding models
|
|
||||||
- consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard)
|
|
||||||
|
|
||||||
## Bugs? Feedback? Comments? Pull requests?
|
|
||||||
|
|
||||||
To enable debugging and get copious output you can set the `OPENEDAI_DEBUG=1` environment variable.
|
|
||||||
|
|
||||||
Are all appreciated, please @matatonic and I'll try to get back to you as soon as possible.
|
|
@ -3,9 +3,11 @@ import time
|
|||||||
import extensions.api.blocking_api as blocking_api
|
import extensions.api.blocking_api as blocking_api
|
||||||
import extensions.api.streaming_api as streaming_api
|
import extensions.api.streaming_api as streaming_api
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
|
logger.warning("The current API is deprecated and will be replaced with the OpenAI compatible API on November xxth. To test the new API, use \"--extensions openai\" instead of \"--api\".")
|
||||||
blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)
|
blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)
|
||||||
if shared.args.public_api:
|
if shared.args.public_api:
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
@ -1,18 +1,23 @@
|
|||||||
|
import copy
|
||||||
import time
|
import time
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import yaml
|
|
||||||
from extensions.openai.defaults import clamp, default, get_default_req_params
|
|
||||||
from extensions.openai.errors import InvalidRequestError
|
from extensions.openai.errors import InvalidRequestError
|
||||||
from extensions.openai.utils import debug_msg, end_line
|
from extensions.openai.utils import debug_msg
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
from modules.chat import (
|
||||||
|
generate_chat_prompt,
|
||||||
|
generate_chat_reply,
|
||||||
|
load_character_memoized
|
||||||
|
)
|
||||||
|
from modules.presets import load_preset_memoized
|
||||||
from modules.text_generation import decode, encode, generate_reply
|
from modules.text_generation import decode, encode, generate_reply
|
||||||
from transformers import LogitsProcessor, LogitsProcessorList
|
from transformers import LogitsProcessor, LogitsProcessorList
|
||||||
|
|
||||||
|
|
||||||
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
|
|
||||||
class LogitsBiasProcessor(LogitsProcessor):
|
class LogitsBiasProcessor(LogitsProcessor):
|
||||||
def __init__(self, logit_bias={}):
|
def __init__(self, logit_bias={}):
|
||||||
self.logit_bias = logit_bias
|
self.logit_bias = logit_bias
|
||||||
@ -28,6 +33,7 @@ class LogitsBiasProcessor(LogitsProcessor):
|
|||||||
logits[0, self.keys] += self.values
|
logits[0, self.keys] += self.values
|
||||||
debug_msg(" --> ", logits[0, self.keys])
|
debug_msg(" --> ", logits[0, self.keys])
|
||||||
debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))
|
debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@ -47,6 +53,7 @@ class LogprobProcessor(LogitsProcessor):
|
|||||||
top_probs = [float(x) for x in top_values[0]]
|
top_probs = [float(x) for x in top_values[0]]
|
||||||
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
||||||
debug_msg(repr(self))
|
debug_msg(repr(self))
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@ -66,43 +73,28 @@ def convert_logprobs_to_tiktoken(model, logprobs):
|
|||||||
return logprobs
|
return logprobs
|
||||||
|
|
||||||
|
|
||||||
def marshal_common_params(body):
|
def process_parameters(body, is_legacy=False):
|
||||||
# Request Parameters
|
generate_params = body
|
||||||
# Try to use openai defaults or map them to something with the same intent
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
|
generate_params['max_new_tokens'] = body.pop(max_tokens_str)
|
||||||
|
if generate_params['truncation_length'] == 0:
|
||||||
|
if shared.args.loader and shared.args.loader.lower().startswith('exllama'):
|
||||||
|
generate_params['truncation_length'] = shared.args.max_seq_len
|
||||||
|
elif shared.args.loader and shared.args.loader in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
|
||||||
|
generate_params['truncation_length'] = shared.args.n_ctx
|
||||||
|
else:
|
||||||
|
generate_params['truncation_length'] = shared.settings['truncation_length']
|
||||||
|
|
||||||
req_params = get_default_req_params()
|
if body['preset'] is not None:
|
||||||
|
preset = load_preset_memoized(body['preset'])
|
||||||
# Common request parameters
|
generate_params.update(preset)
|
||||||
req_params['truncation_length'] = shared.settings['truncation_length']
|
|
||||||
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
|
|
||||||
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
|
|
||||||
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
|
|
||||||
|
|
||||||
# OpenAI API Parameters
|
|
||||||
# model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this
|
|
||||||
req_params['requested_model'] = body.get('model', shared.model_name)
|
|
||||||
|
|
||||||
req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
|
||||||
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.01, 1.99) # fixup absolute 0.0/2.0
|
|
||||||
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.01, 1.0)
|
|
||||||
n = default(body, 'n', 1)
|
|
||||||
if n != 1:
|
|
||||||
raise InvalidRequestError(message="Only n = 1 is supported.", param='n')
|
|
||||||
|
|
||||||
|
generate_params['custom_stopping_strings'] = []
|
||||||
if 'stop' in body: # str or array, max len 4 (ignored)
|
if 'stop' in body: # str or array, max len 4 (ignored)
|
||||||
if isinstance(body['stop'], str):
|
if isinstance(body['stop'], str):
|
||||||
req_params['stopping_strings'] = [body['stop']] # non-standard parameter
|
generate_params['custom_stopping_strings'] = [body['stop']]
|
||||||
elif isinstance(body['stop'], list):
|
elif isinstance(body['stop'], list):
|
||||||
req_params['stopping_strings'] = body['stop']
|
generate_params['custom_stopping_strings'] = body['stop']
|
||||||
|
|
||||||
# presence_penalty - ignored
|
|
||||||
# frequency_penalty - ignored
|
|
||||||
|
|
||||||
# pass through unofficial params
|
|
||||||
req_params['repetition_penalty'] = default(body, 'repetition_penalty', req_params['repetition_penalty'])
|
|
||||||
req_params['encoder_repetition_penalty'] = default(body, 'encoder_repetition_penalty', req_params['encoder_repetition_penalty'])
|
|
||||||
|
|
||||||
# user - ignored
|
|
||||||
|
|
||||||
logits_processor = []
|
logits_processor = []
|
||||||
logit_bias = body.get('logit_bias', None)
|
logit_bias = body.get('logit_bias', None)
|
||||||
@ -110,12 +102,13 @@ def marshal_common_params(body):
|
|||||||
# XXX convert tokens from tiktoken based on requested model
|
# XXX convert tokens from tiktoken based on requested model
|
||||||
# Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
|
# Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
|
||||||
try:
|
try:
|
||||||
encoder = tiktoken.encoding_for_model(req_params['requested_model'])
|
encoder = tiktoken.encoding_for_model(generate_params['model'])
|
||||||
new_logit_bias = {}
|
new_logit_bias = {}
|
||||||
for logit, bias in logit_bias.items():
|
for logit, bias in logit_bias.items():
|
||||||
for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]:
|
for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]:
|
||||||
if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens
|
if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens
|
||||||
continue
|
continue
|
||||||
|
|
||||||
new_logit_bias[str(int(x))] = bias
|
new_logit_bias[str(int(x))] = bias
|
||||||
debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias)
|
debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias)
|
||||||
logit_bias = new_logit_bias
|
logit_bias = new_logit_bias
|
||||||
@ -126,238 +119,129 @@ def marshal_common_params(body):
|
|||||||
|
|
||||||
logprobs = None # coming to chat eventually
|
logprobs = None # coming to chat eventually
|
||||||
if 'logprobs' in body:
|
if 'logprobs' in body:
|
||||||
logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
logprobs = body.get('logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
||||||
req_params['logprob_proc'] = LogprobProcessor(logprobs)
|
generate_params['logprob_proc'] = LogprobProcessor(logprobs)
|
||||||
logits_processor.extend([req_params['logprob_proc']])
|
logits_processor.extend([generate_params['logprob_proc']])
|
||||||
else:
|
else:
|
||||||
logprobs = None
|
logprobs = None
|
||||||
|
|
||||||
if logits_processor: # requires logits_processor support
|
if logits_processor: # requires logits_processor support
|
||||||
req_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
||||||
|
|
||||||
return req_params
|
return generate_params
|
||||||
|
|
||||||
|
|
||||||
def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
def convert_history(history):
|
||||||
# functions
|
'''
|
||||||
if body.get('functions', []): # chat only
|
Chat histories in this program are in the format [message, reply].
|
||||||
|
This function converts OpenAI histories to that format.
|
||||||
|
'''
|
||||||
|
chat_dialogue = []
|
||||||
|
current_message = ""
|
||||||
|
current_reply = ""
|
||||||
|
user_input = ""
|
||||||
|
|
||||||
|
for entry in history:
|
||||||
|
content = entry["content"]
|
||||||
|
role = entry["role"]
|
||||||
|
|
||||||
|
if role == "user":
|
||||||
|
user_input = content
|
||||||
|
if current_message:
|
||||||
|
chat_dialogue.append([current_message, ''])
|
||||||
|
current_message = ""
|
||||||
|
current_message = content
|
||||||
|
elif role == "assistant":
|
||||||
|
current_reply = content
|
||||||
|
if current_message:
|
||||||
|
chat_dialogue.append([current_message, current_reply])
|
||||||
|
current_message = ""
|
||||||
|
current_reply = ""
|
||||||
|
else:
|
||||||
|
chat_dialogue.append(['', current_reply])
|
||||||
|
|
||||||
|
# if current_message:
|
||||||
|
# chat_dialogue.append([current_message, ''])
|
||||||
|
|
||||||
|
return user_input, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
|
||||||
|
|
||||||
|
|
||||||
|
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -> dict:
|
||||||
|
if body.get('functions', []):
|
||||||
raise InvalidRequestError(message="functions is not supported.", param='functions')
|
raise InvalidRequestError(message="functions is not supported.", param='functions')
|
||||||
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
|
|
||||||
|
if body.get('function_call', ''):
|
||||||
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
|
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
|
||||||
|
|
||||||
if 'messages' not in body:
|
if 'messages' not in body:
|
||||||
raise InvalidRequestError(message="messages is required", param='messages')
|
raise InvalidRequestError(message="messages is required", param='messages')
|
||||||
|
|
||||||
messages = body['messages']
|
messages = body['messages']
|
||||||
|
|
||||||
role_formats = {
|
|
||||||
'user': 'User: {message}\n',
|
|
||||||
'assistant': 'Assistant: {message}\n',
|
|
||||||
'system': '{message}',
|
|
||||||
'context': 'You are a helpful assistant. Answer as concisely as possible.\nUser: I want your assistance.\nAssistant: Sure! What can I do for you?',
|
|
||||||
'prompt': 'Assistant:',
|
|
||||||
}
|
|
||||||
|
|
||||||
if 'stopping_strings' not in req_params:
|
|
||||||
req_params['stopping_strings'] = []
|
|
||||||
|
|
||||||
# Instruct models can be much better
|
|
||||||
if shared.settings['instruction_template']:
|
|
||||||
try:
|
|
||||||
instruct = yaml.safe_load(open(f"instruction-templates/{shared.settings['instruction_template']}.yaml", 'r'))
|
|
||||||
|
|
||||||
template = instruct['turn_template']
|
|
||||||
system_message_template = "{message}"
|
|
||||||
system_message_default = instruct.get('context', '') # can be missing
|
|
||||||
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
|
||||||
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct.get('user', ''))
|
|
||||||
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct.get('bot', ''))
|
|
||||||
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
|
|
||||||
|
|
||||||
role_formats = {
|
|
||||||
'user': user_message_template,
|
|
||||||
'assistant': bot_message_template,
|
|
||||||
'system': system_message_template,
|
|
||||||
'context': system_message_default,
|
|
||||||
'prompt': bot_prompt,
|
|
||||||
}
|
|
||||||
|
|
||||||
if 'Alpaca' in shared.settings['instruction_template']:
|
|
||||||
req_params['stopping_strings'].extend(['\n###'])
|
|
||||||
elif instruct['user']: # WizardLM and some others have no user prompt.
|
|
||||||
req_params['stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
|
|
||||||
|
|
||||||
debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
|
|
||||||
|
|
||||||
print(f"Exception: When loading instruction-templates/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
|
||||||
print("Warning: Loaded default instruction-following template for model.")
|
|
||||||
|
|
||||||
else:
|
|
||||||
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
|
|
||||||
print("Warning: Loaded default instruction-following template for model.")
|
|
||||||
|
|
||||||
system_msgs = []
|
|
||||||
chat_msgs = []
|
|
||||||
|
|
||||||
# You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
|
|
||||||
context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
|
|
||||||
context_msg = end_line(context_msg)
|
|
||||||
|
|
||||||
# Maybe they sent both? This is not documented in the API, but some clients seem to do this.
|
|
||||||
if 'prompt' in body:
|
|
||||||
context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg
|
|
||||||
|
|
||||||
for m in messages:
|
for m in messages:
|
||||||
if 'role' not in m:
|
if 'role' not in m:
|
||||||
raise InvalidRequestError(message="messages: missing role", param='messages')
|
raise InvalidRequestError(message="messages: missing role", param='messages')
|
||||||
|
elif m['role'] == 'function':
|
||||||
|
raise InvalidRequestError(message="role: function is not supported.", param='messages')
|
||||||
if 'content' not in m:
|
if 'content' not in m:
|
||||||
raise InvalidRequestError(message="messages: missing content", param='messages')
|
raise InvalidRequestError(message="messages: missing content", param='messages')
|
||||||
|
|
||||||
role = m['role']
|
|
||||||
content = m['content']
|
|
||||||
# name = m.get('name', None)
|
|
||||||
# function_call = m.get('function_call', None) # user name or function name with output in content
|
|
||||||
msg = role_formats[role].format(message=content)
|
|
||||||
if role == 'system':
|
|
||||||
system_msgs.extend([msg])
|
|
||||||
elif role == 'function':
|
|
||||||
raise InvalidRequestError(message="role: function is not supported.", param='messages')
|
|
||||||
else:
|
|
||||||
chat_msgs.extend([msg])
|
|
||||||
|
|
||||||
system_msg = '\n'.join(system_msgs)
|
|
||||||
system_msg = end_line(system_msg)
|
|
||||||
|
|
||||||
prompt = system_msg + context_msg + ''.join(chat_msgs) + role_formats['prompt']
|
|
||||||
|
|
||||||
token_count = len(encode(prompt)[0])
|
|
||||||
|
|
||||||
if token_count >= req_params['truncation_length']:
|
|
||||||
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens."
|
|
||||||
raise InvalidRequestError(message=err_msg, param='messages')
|
|
||||||
|
|
||||||
if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
|
|
||||||
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
|
|
||||||
print(f"Warning: ${err_msg}")
|
|
||||||
# raise InvalidRequestError(message=err_msg, params='max_tokens')
|
|
||||||
|
|
||||||
return prompt, token_count
|
|
||||||
|
|
||||||
|
|
||||||
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
|
||||||
# Chat Completions
|
# Chat Completions
|
||||||
object_type = 'chat.completions'
|
object_type = 'chat.completions' if not stream else 'chat.completions.chunk'
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
|
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
|
||||||
resp_list = 'data' if is_legacy else 'choices'
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
# common params
|
# generation parameters
|
||||||
req_params = marshal_common_params(body)
|
generate_params = process_parameters(body, is_legacy=is_legacy)
|
||||||
req_params['stream'] = False
|
continue_ = body['continue_']
|
||||||
requested_model = req_params.pop('requested_model')
|
|
||||||
logprob_proc = req_params.pop('logprob_proc', None)
|
|
||||||
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
|
||||||
|
|
||||||
# chat default max_tokens is 'inf', but also flexible
|
# Instruction template
|
||||||
max_tokens = 0
|
instruction_template = body['instruction_template'] or shared.settings['instruction_template']
|
||||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True)
|
||||||
if max_tokens_str in body:
|
name1_instruct = body['name1_instruct'] or name1_instruct
|
||||||
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
|
name2_instruct = body['name2_instruct'] or name2_instruct
|
||||||
req_params['max_new_tokens'] = max_tokens
|
context_instruct = body['context_instruct'] or context_instruct
|
||||||
else:
|
turn_template = body['turn_template'] or turn_template
|
||||||
req_params['max_new_tokens'] = req_params['truncation_length']
|
|
||||||
|
|
||||||
# format the prompt from messages
|
# Chat character
|
||||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
|
character = body['character'] or shared.settings['character']
|
||||||
|
name1 = body['name1'] or shared.settings['name1']
|
||||||
|
name1, name2, _, greeting, context, _ = load_character_memoized(character, name1, '', instruct=False)
|
||||||
|
name2 = body['name2'] or name2
|
||||||
|
context = body['context'] or context
|
||||||
|
greeting = body['greeting'] or greeting
|
||||||
|
|
||||||
# set real max, avoid deeper errors
|
# History
|
||||||
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
|
user_input, history = convert_history(messages)
|
||||||
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
|
||||||
|
|
||||||
stopping_strings = req_params.pop('stopping_strings', [])
|
generate_params.update({
|
||||||
|
'mode': body['mode'],
|
||||||
|
'name1': name1,
|
||||||
|
'name2': name2,
|
||||||
|
'context': context,
|
||||||
|
'greeting': greeting,
|
||||||
|
'name1_instruct': name1_instruct,
|
||||||
|
'name2_instruct': name2_instruct,
|
||||||
|
'context_instruct': context_instruct,
|
||||||
|
'turn_template': turn_template,
|
||||||
|
'chat-instruct_command': body['chat_instruct_command'],
|
||||||
|
'history': history,
|
||||||
|
'stream': stream
|
||||||
|
})
|
||||||
|
|
||||||
# generate reply #######################################
|
max_tokens = generate_params['max_new_tokens']
|
||||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
if max_tokens in [None, 0]:
|
||||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
generate_params['max_new_tokens'] = 200
|
||||||
|
generate_params['auto_max_new_tokens'] = True
|
||||||
|
|
||||||
answer = ''
|
requested_model = generate_params.pop('model')
|
||||||
for a in generator:
|
logprob_proc = generate_params.pop('logprob_proc', None)
|
||||||
answer = a
|
|
||||||
|
|
||||||
# strip extra leading space off new generated content
|
|
||||||
if answer and answer[0] == ' ':
|
|
||||||
answer = answer[1:]
|
|
||||||
|
|
||||||
completion_token_count = len(encode(answer)[0])
|
|
||||||
stop_reason = "stop"
|
|
||||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
|
|
||||||
stop_reason = "length"
|
|
||||||
|
|
||||||
resp = {
|
|
||||||
"id": cmpl_id,
|
|
||||||
"object": object_type,
|
|
||||||
"created": created_time,
|
|
||||||
"model": shared.model_name, # TODO: add Lora info?
|
|
||||||
resp_list: [{
|
|
||||||
"index": 0,
|
|
||||||
"finish_reason": stop_reason,
|
|
||||||
"message": {"role": "assistant", "content": answer}
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": token_count,
|
|
||||||
"completion_tokens": completion_token_count,
|
|
||||||
"total_tokens": token_count + completion_token_count
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if logprob_proc: # not official for chat yet
|
|
||||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
|
||||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
|
||||||
# else:
|
|
||||||
# resp[resp_list][0]["logprobs"] = None
|
|
||||||
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
# generator
|
|
||||||
def stream_chat_completions(body: dict, is_legacy: bool = False):
|
|
||||||
|
|
||||||
# Chat Completions
|
|
||||||
stream_object_type = 'chat.completions.chunk'
|
|
||||||
created_time = int(time.time())
|
|
||||||
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
|
|
||||||
resp_list = 'data' if is_legacy else 'choices'
|
|
||||||
|
|
||||||
# common params
|
|
||||||
req_params = marshal_common_params(body)
|
|
||||||
req_params['stream'] = True
|
|
||||||
requested_model = req_params.pop('requested_model')
|
|
||||||
logprob_proc = req_params.pop('logprob_proc', None)
|
|
||||||
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
|
||||||
|
|
||||||
# chat default max_tokens is 'inf', but also flexible
|
|
||||||
max_tokens = 0
|
|
||||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
|
||||||
if max_tokens_str in body:
|
|
||||||
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
|
|
||||||
req_params['max_new_tokens'] = max_tokens
|
|
||||||
else:
|
|
||||||
req_params['max_new_tokens'] = req_params['truncation_length']
|
|
||||||
|
|
||||||
# format the prompt from messages
|
|
||||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
|
|
||||||
|
|
||||||
# set real max, avoid deeper errors
|
|
||||||
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
|
|
||||||
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
|
||||||
|
|
||||||
def chat_streaming_chunk(content):
|
def chat_streaming_chunk(content):
|
||||||
# begin streaming
|
# begin streaming
|
||||||
chunk = {
|
chunk = {
|
||||||
"id": cmpl_id,
|
"id": cmpl_id,
|
||||||
"object": stream_object_type,
|
"object": object_type,
|
||||||
"created": created_time,
|
"created": created_time,
|
||||||
"model": shared.model_name,
|
"model": shared.model_name,
|
||||||
resp_list: [{
|
resp_list: [{
|
||||||
@ -376,262 +260,262 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
|
|||||||
# chunk[resp_list][0]["logprobs"] = None
|
# chunk[resp_list][0]["logprobs"] = None
|
||||||
return chunk
|
return chunk
|
||||||
|
|
||||||
yield chat_streaming_chunk('')
|
if stream:
|
||||||
|
yield chat_streaming_chunk('')
|
||||||
|
|
||||||
# generate reply #######################################
|
# generate reply #######################################
|
||||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
prompt = generate_chat_prompt(user_input, generate_params)
|
||||||
|
token_count = len(encode(prompt)[0])
|
||||||
|
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||||
|
|
||||||
stopping_strings = req_params.pop('stopping_strings', [])
|
generator = generate_chat_reply(
|
||||||
|
user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
|
||||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
|
||||||
|
|
||||||
answer = ''
|
answer = ''
|
||||||
seen_content = ''
|
seen_content = ''
|
||||||
completion_token_count = 0
|
completion_token_count = 0
|
||||||
|
|
||||||
for a in generator:
|
for a in generator:
|
||||||
answer = a
|
answer = a['internal'][-1][1]
|
||||||
|
if stream:
|
||||||
|
len_seen = len(seen_content)
|
||||||
|
new_content = answer[len_seen:]
|
||||||
|
|
||||||
len_seen = len(seen_content)
|
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||||
new_content = answer[len_seen:]
|
continue
|
||||||
|
|
||||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
seen_content = answer
|
||||||
continue
|
|
||||||
|
|
||||||
seen_content = answer
|
# strip extra leading space off new generated content
|
||||||
|
if len_seen == 0 and new_content[0] == ' ':
|
||||||
|
new_content = new_content[1:]
|
||||||
|
|
||||||
# strip extra leading space off new generated content
|
chunk = chat_streaming_chunk(new_content)
|
||||||
if len_seen == 0 and new_content[0] == ' ':
|
|
||||||
new_content = new_content[1:]
|
|
||||||
|
|
||||||
chunk = chat_streaming_chunk(new_content)
|
yield chunk
|
||||||
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
# to get the correct token_count, strip leading space if present
|
|
||||||
if answer and answer[0] == ' ':
|
|
||||||
answer = answer[1:]
|
|
||||||
|
|
||||||
completion_token_count = len(encode(answer)[0])
|
completion_token_count = len(encode(answer)[0])
|
||||||
stop_reason = "stop"
|
stop_reason = "stop"
|
||||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
|
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
|
||||||
stop_reason = "length"
|
stop_reason = "length"
|
||||||
|
|
||||||
chunk = chat_streaming_chunk('')
|
if stream:
|
||||||
chunk[resp_list][0]['finish_reason'] = stop_reason
|
chunk = chat_streaming_chunk('')
|
||||||
chunk['usage'] = {
|
chunk[resp_list][0]['finish_reason'] = stop_reason
|
||||||
"prompt_tokens": token_count,
|
chunk['usage'] = {
|
||||||
"completion_tokens": completion_token_count,
|
"prompt_tokens": token_count,
|
||||||
"total_tokens": token_count + completion_token_count
|
"completion_tokens": completion_token_count,
|
||||||
}
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
else:
|
||||||
|
resp = {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name,
|
||||||
|
resp_list: [{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": stop_reason,
|
||||||
|
"message": {"role": "assistant", "content": answer}
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if logprob_proc: # not official for chat yet
|
||||||
|
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||||
|
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||||
|
# else:
|
||||||
|
# resp[resp_list][0]["logprobs"] = None
|
||||||
|
|
||||||
|
yield resp
|
||||||
|
|
||||||
|
|
||||||
def completions(body: dict, is_legacy: bool = False):
|
def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
||||||
# Legacy
|
object_type = 'text_completion.chunk' if stream else 'text_completion'
|
||||||
# Text Completions
|
|
||||||
object_type = 'text_completion'
|
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
||||||
resp_list = 'data' if is_legacy else 'choices'
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
|
||||||
prompt_str = 'context' if is_legacy else 'prompt'
|
prompt_str = 'context' if is_legacy else 'prompt'
|
||||||
|
|
||||||
|
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||||
if prompt_str not in body:
|
if prompt_str not in body:
|
||||||
raise InvalidRequestError("Missing required input", param=prompt_str)
|
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||||
|
|
||||||
prompt_arg = body[prompt_str]
|
|
||||||
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
|
|
||||||
prompt_arg = [prompt_arg]
|
|
||||||
|
|
||||||
# common params
|
# common params
|
||||||
req_params = marshal_common_params(body)
|
generate_params = process_parameters(body, is_legacy=is_legacy)
|
||||||
req_params['stream'] = False
|
max_tokens = generate_params['max_new_tokens']
|
||||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
generate_params['stream'] = stream
|
||||||
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
|
requested_model = generate_params.pop('model')
|
||||||
req_params['max_new_tokens'] = max_tokens
|
logprob_proc = generate_params.pop('logprob_proc', None)
|
||||||
requested_model = req_params.pop('requested_model')
|
# generate_params['suffix'] = body.get('suffix', generate_params['suffix'])
|
||||||
logprob_proc = req_params.pop('logprob_proc', None)
|
generate_params['echo'] = body.get('echo', generate_params['echo'])
|
||||||
stopping_strings = req_params.pop('stopping_strings', [])
|
|
||||||
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
|
||||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
|
||||||
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
|
||||||
|
|
||||||
resp_list_data = []
|
if not stream:
|
||||||
total_completion_token_count = 0
|
prompt_arg = body[prompt_str]
|
||||||
total_prompt_token_count = 0
|
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
|
||||||
|
prompt_arg = [prompt_arg]
|
||||||
|
|
||||||
for idx, prompt in enumerate(prompt_arg, start=0):
|
resp_list_data = []
|
||||||
if isinstance(prompt[0], int):
|
total_completion_token_count = 0
|
||||||
# token lists
|
total_prompt_token_count = 0
|
||||||
if requested_model == shared.model_name:
|
|
||||||
prompt = decode(prompt)[0]
|
for idx, prompt in enumerate(prompt_arg, start=0):
|
||||||
else:
|
if isinstance(prompt[0], int):
|
||||||
|
# token lists
|
||||||
|
if requested_model == shared.model_name:
|
||||||
|
prompt = decode(prompt)[0]
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
encoder = tiktoken.encoding_for_model(requested_model)
|
||||||
|
prompt = encoder.decode(prompt)
|
||||||
|
except KeyError:
|
||||||
|
prompt = decode(prompt)[0]
|
||||||
|
|
||||||
|
token_count = len(encode(prompt)[0])
|
||||||
|
total_prompt_token_count += token_count
|
||||||
|
|
||||||
|
# generate reply #######################################
|
||||||
|
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||||
|
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||||
|
answer = ''
|
||||||
|
|
||||||
|
for a in generator:
|
||||||
|
answer = a
|
||||||
|
|
||||||
|
# strip extra leading space off new generated content
|
||||||
|
if answer and answer[0] == ' ':
|
||||||
|
answer = answer[1:]
|
||||||
|
|
||||||
|
completion_token_count = len(encode(answer)[0])
|
||||||
|
total_completion_token_count += completion_token_count
|
||||||
|
stop_reason = "stop"
|
||||||
|
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||||
|
stop_reason = "length"
|
||||||
|
|
||||||
|
respi = {
|
||||||
|
"index": idx,
|
||||||
|
"finish_reason": stop_reason,
|
||||||
|
"text": answer,
|
||||||
|
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp_list_data.extend([respi])
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name,
|
||||||
|
resp_list: resp_list_data,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": total_prompt_token_count,
|
||||||
|
"completion_tokens": total_completion_token_count,
|
||||||
|
"total_tokens": total_prompt_token_count + total_completion_token_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
yield resp
|
||||||
|
else:
|
||||||
|
prompt = body[prompt_str]
|
||||||
|
if isinstance(prompt, list):
|
||||||
|
if prompt and isinstance(prompt[0], int):
|
||||||
try:
|
try:
|
||||||
encoder = tiktoken.encoding_for_model(requested_model)
|
encoder = tiktoken.encoding_for_model(requested_model)
|
||||||
prompt = encoder.decode(prompt)
|
prompt = encoder.decode(prompt)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
prompt = decode(prompt)[0]
|
prompt = decode(prompt)[0]
|
||||||
|
else:
|
||||||
|
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||||
|
|
||||||
token_count = len(encode(prompt)[0])
|
token_count = len(encode(prompt)[0])
|
||||||
total_prompt_token_count += token_count
|
|
||||||
|
|
||||||
if token_count + max_tokens > req_params['truncation_length']:
|
def text_streaming_chunk(content):
|
||||||
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
# begin streaming
|
||||||
# print(f"Warning: ${err_msg}")
|
chunk = {
|
||||||
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
"id": cmpl_id,
|
||||||
|
"object": object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name,
|
||||||
|
resp_list: [{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": None,
|
||||||
|
"text": content,
|
||||||
|
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
yield text_streaming_chunk('')
|
||||||
|
|
||||||
# generate reply #######################################
|
# generate reply #######################################
|
||||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||||
|
|
||||||
answer = ''
|
answer = ''
|
||||||
|
seen_content = ''
|
||||||
|
completion_token_count = 0
|
||||||
|
|
||||||
for a in generator:
|
for a in generator:
|
||||||
answer = a
|
answer = a
|
||||||
|
|
||||||
# strip extra leading space off new generated content
|
len_seen = len(seen_content)
|
||||||
|
new_content = answer[len_seen:]
|
||||||
|
|
||||||
|
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||||
|
continue
|
||||||
|
|
||||||
|
seen_content = answer
|
||||||
|
|
||||||
|
# strip extra leading space off new generated content
|
||||||
|
if len_seen == 0 and new_content[0] == ' ':
|
||||||
|
new_content = new_content[1:]
|
||||||
|
|
||||||
|
chunk = text_streaming_chunk(new_content)
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
# to get the correct count, we strip the leading space if present
|
||||||
if answer and answer[0] == ' ':
|
if answer and answer[0] == ' ':
|
||||||
answer = answer[1:]
|
answer = answer[1:]
|
||||||
|
|
||||||
completion_token_count = len(encode(answer)[0])
|
completion_token_count = len(encode(answer)[0])
|
||||||
total_completion_token_count += completion_token_count
|
|
||||||
stop_reason = "stop"
|
stop_reason = "stop"
|
||||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||||
stop_reason = "length"
|
stop_reason = "length"
|
||||||
|
|
||||||
respi = {
|
chunk = text_streaming_chunk('')
|
||||||
"index": idx,
|
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||||
"finish_reason": stop_reason,
|
chunk["usage"] = {
|
||||||
"text": answer,
|
"prompt_tokens": token_count,
|
||||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
}
|
}
|
||||||
|
|
||||||
resp_list_data.extend([respi])
|
|
||||||
|
|
||||||
resp = {
|
|
||||||
"id": cmpl_id,
|
|
||||||
"object": object_type,
|
|
||||||
"created": created_time,
|
|
||||||
"model": shared.model_name, # TODO: add Lora info?
|
|
||||||
resp_list: resp_list_data,
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": total_prompt_token_count,
|
|
||||||
"completion_tokens": total_completion_token_count,
|
|
||||||
"total_tokens": total_prompt_token_count + total_completion_token_count
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
# generator
|
|
||||||
def stream_completions(body: dict, is_legacy: bool = False):
|
|
||||||
# Legacy
|
|
||||||
# Text Completions
|
|
||||||
# object_type = 'text_completion'
|
|
||||||
stream_object_type = 'text_completion.chunk'
|
|
||||||
created_time = int(time.time())
|
|
||||||
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
|
||||||
resp_list = 'data' if is_legacy else 'choices'
|
|
||||||
|
|
||||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
|
||||||
prompt_str = 'context' if is_legacy else 'prompt'
|
|
||||||
if prompt_str not in body:
|
|
||||||
raise InvalidRequestError("Missing required input", param=prompt_str)
|
|
||||||
|
|
||||||
prompt = body[prompt_str]
|
|
||||||
req_params = marshal_common_params(body)
|
|
||||||
requested_model = req_params.pop('requested_model')
|
|
||||||
if isinstance(prompt, list):
|
|
||||||
if prompt and isinstance(prompt[0], int):
|
|
||||||
try:
|
|
||||||
encoder = tiktoken.encoding_for_model(requested_model)
|
|
||||||
prompt = encoder.decode(prompt)
|
|
||||||
except KeyError:
|
|
||||||
prompt = decode(prompt)[0]
|
|
||||||
else:
|
|
||||||
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
|
||||||
|
|
||||||
# common params
|
|
||||||
req_params['stream'] = True
|
|
||||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
|
||||||
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
|
|
||||||
req_params['max_new_tokens'] = max_tokens
|
|
||||||
logprob_proc = req_params.pop('logprob_proc', None)
|
|
||||||
stopping_strings = req_params.pop('stopping_strings', [])
|
|
||||||
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
|
||||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
|
||||||
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
|
||||||
|
|
||||||
token_count = len(encode(prompt)[0])
|
|
||||||
|
|
||||||
if token_count + max_tokens > req_params['truncation_length']:
|
|
||||||
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
|
||||||
# print(f"Warning: ${err_msg}")
|
|
||||||
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
|
||||||
|
|
||||||
def text_streaming_chunk(content):
|
|
||||||
# begin streaming
|
|
||||||
chunk = {
|
|
||||||
"id": cmpl_id,
|
|
||||||
"object": stream_object_type,
|
|
||||||
"created": created_time,
|
|
||||||
"model": shared.model_name,
|
|
||||||
resp_list: [{
|
|
||||||
"index": 0,
|
|
||||||
"finish_reason": None,
|
|
||||||
"text": content,
|
|
||||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
|
|
||||||
return chunk
|
|
||||||
|
|
||||||
yield text_streaming_chunk('')
|
|
||||||
|
|
||||||
# generate reply #######################################
|
|
||||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
|
||||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
|
||||||
|
|
||||||
answer = ''
|
|
||||||
seen_content = ''
|
|
||||||
completion_token_count = 0
|
|
||||||
|
|
||||||
for a in generator:
|
|
||||||
answer = a
|
|
||||||
|
|
||||||
len_seen = len(seen_content)
|
|
||||||
new_content = answer[len_seen:]
|
|
||||||
|
|
||||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
|
||||||
continue
|
|
||||||
|
|
||||||
seen_content = answer
|
|
||||||
|
|
||||||
# strip extra leading space off new generated content
|
|
||||||
if len_seen == 0 and new_content[0] == ' ':
|
|
||||||
new_content = new_content[1:]
|
|
||||||
|
|
||||||
chunk = text_streaming_chunk(new_content)
|
|
||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
# to get the correct count, we strip the leading space if present
|
|
||||||
if answer and answer[0] == ' ':
|
|
||||||
answer = answer[1:]
|
|
||||||
|
|
||||||
completion_token_count = len(encode(answer)[0])
|
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
||||||
stop_reason = "stop"
|
generator = chat_completions_common(body, is_legacy, stream=False)
|
||||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
return deque(generator, maxlen=1).pop()
|
||||||
stop_reason = "length"
|
|
||||||
|
|
||||||
chunk = text_streaming_chunk('')
|
|
||||||
chunk[resp_list][0]["finish_reason"] = stop_reason
|
|
||||||
chunk["usage"] = {
|
|
||||||
"prompt_tokens": token_count,
|
|
||||||
"completion_tokens": completion_token_count,
|
|
||||||
"total_tokens": token_count + completion_token_count
|
|
||||||
}
|
|
||||||
|
|
||||||
yield chunk
|
def stream_chat_completions(body: dict, is_legacy: bool = False):
|
||||||
|
for resp in chat_completions_common(body, is_legacy, stream=True):
|
||||||
|
yield resp
|
||||||
|
|
||||||
|
|
||||||
|
def completions(body: dict, is_legacy: bool = False) -> dict:
|
||||||
|
generator = completions_common(body, is_legacy, stream=False)
|
||||||
|
return deque(generator, maxlen=1).pop()
|
||||||
|
|
||||||
|
|
||||||
|
def stream_completions(body: dict, is_legacy: bool = False):
|
||||||
|
for resp in completions_common(body, is_legacy, stream=True):
|
||||||
|
yield resp
|
||||||
|
@ -1,78 +0,0 @@
|
|||||||
import copy
|
|
||||||
|
|
||||||
# Slightly different defaults for OpenAI's API
|
|
||||||
# Data type is important, Ex. use 0.0 for a float 0
|
|
||||||
default_req_params = {
|
|
||||||
'max_new_tokens': 16, # 'Inf' for chat
|
|
||||||
'auto_max_new_tokens': False,
|
|
||||||
'max_tokens_second': 0,
|
|
||||||
'temperature': 1.0,
|
|
||||||
'temperature_last': False,
|
|
||||||
'top_p': 1.0,
|
|
||||||
'min_p': 0,
|
|
||||||
'top_k': 1, # choose 20 for chat in absence of another default
|
|
||||||
'repetition_penalty': 1.18,
|
|
||||||
'presence_penalty': 0,
|
|
||||||
'frequency_penalty': 0,
|
|
||||||
'repetition_penalty_range': 0,
|
|
||||||
'encoder_repetition_penalty': 1.0,
|
|
||||||
'suffix': None,
|
|
||||||
'stream': False,
|
|
||||||
'echo': False,
|
|
||||||
'seed': -1,
|
|
||||||
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
|
|
||||||
'truncation_length': 2048, # first use shared.settings value
|
|
||||||
'add_bos_token': True,
|
|
||||||
'do_sample': True,
|
|
||||||
'typical_p': 1.0,
|
|
||||||
'epsilon_cutoff': 0.0, # In units of 1e-4
|
|
||||||
'eta_cutoff': 0.0, # In units of 1e-4
|
|
||||||
'tfs': 1.0,
|
|
||||||
'top_a': 0.0,
|
|
||||||
'min_length': 0,
|
|
||||||
'no_repeat_ngram_size': 0,
|
|
||||||
'num_beams': 1,
|
|
||||||
'penalty_alpha': 0.0,
|
|
||||||
'length_penalty': 1.0,
|
|
||||||
'early_stopping': False,
|
|
||||||
'mirostat_mode': 0,
|
|
||||||
'mirostat_tau': 5.0,
|
|
||||||
'mirostat_eta': 0.1,
|
|
||||||
'grammar_string': '',
|
|
||||||
'guidance_scale': 1,
|
|
||||||
'negative_prompt': '',
|
|
||||||
'ban_eos_token': False,
|
|
||||||
'custom_token_bans': '',
|
|
||||||
'skip_special_tokens': True,
|
|
||||||
'custom_stopping_strings': '',
|
|
||||||
# 'logits_processor' - conditionally passed
|
|
||||||
# 'stopping_strings' - temporarily used
|
|
||||||
# 'logprobs' - temporarily used
|
|
||||||
# 'requested_model' - temporarily used
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_req_params():
|
|
||||||
return copy.deepcopy(default_req_params)
|
|
||||||
|
|
||||||
|
|
||||||
def default(dic, key, default):
|
|
||||||
'''
|
|
||||||
little helper to get defaults if arg is present but None and should be the same type as default.
|
|
||||||
'''
|
|
||||||
val = dic.get(key, default)
|
|
||||||
if not isinstance(val, type(default)):
|
|
||||||
# maybe it's just something like 1 instead of 1.0
|
|
||||||
try:
|
|
||||||
v = type(default)(val)
|
|
||||||
if type(val)(v) == val: # if it's the same value passed in, it's ok.
|
|
||||||
return v
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
val = default
|
|
||||||
return val
|
|
||||||
|
|
||||||
|
|
||||||
def clamp(value, minvalue, maxvalue):
|
|
||||||
return max(minvalue, min(value, maxvalue))
|
|
@ -1,101 +0,0 @@
|
|||||||
import time
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
from extensions.openai.defaults import get_default_req_params
|
|
||||||
from extensions.openai.errors import InvalidRequestError
|
|
||||||
from extensions.openai.utils import debug_msg
|
|
||||||
from modules import shared
|
|
||||||
from modules.text_generation import encode, generate_reply
|
|
||||||
|
|
||||||
|
|
||||||
def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
|
|
||||||
|
|
||||||
created_time = int(time.time() * 1000)
|
|
||||||
|
|
||||||
# Request parameters
|
|
||||||
req_params = get_default_req_params()
|
|
||||||
stopping_strings = []
|
|
||||||
|
|
||||||
# Alpaca is verbose so a good default prompt
|
|
||||||
default_template = (
|
|
||||||
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
|
||||||
"Write a response that appropriately completes the request.\n\n"
|
|
||||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
instruction_template = default_template
|
|
||||||
|
|
||||||
# Use the special instruction/input/response template for anything trained like Alpaca
|
|
||||||
if shared.settings['instruction_template']:
|
|
||||||
if 'Alpaca' in shared.settings['instruction_template']:
|
|
||||||
stopping_strings.extend(['\n###'])
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
instruct = yaml.safe_load(open(f"instruction-templates/{shared.settings['instruction_template']}.yaml", 'r'))
|
|
||||||
|
|
||||||
template = instruct['turn_template']
|
|
||||||
template = template\
|
|
||||||
.replace('<|user|>', instruct.get('user', ''))\
|
|
||||||
.replace('<|bot|>', instruct.get('bot', ''))\
|
|
||||||
.replace('<|user-message|>', '{instruction}\n{input}')
|
|
||||||
|
|
||||||
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
|
||||||
if instruct['user']:
|
|
||||||
stopping_strings.extend(['\n' + instruct['user'], instruct['user']])
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
instruction_template = default_template
|
|
||||||
print(f"Exception: When loading instruction-templates/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
|
||||||
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
|
||||||
else:
|
|
||||||
stopping_strings.extend(['\n###'])
|
|
||||||
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
|
||||||
|
|
||||||
edit_task = instruction_template.format(instruction=instruction, input=input)
|
|
||||||
|
|
||||||
truncation_length = shared.settings['truncation_length']
|
|
||||||
|
|
||||||
token_count = len(encode(edit_task)[0])
|
|
||||||
max_tokens = truncation_length - token_count
|
|
||||||
|
|
||||||
if max_tokens < 1:
|
|
||||||
err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens."
|
|
||||||
raise InvalidRequestError(err_msg, param='input')
|
|
||||||
|
|
||||||
req_params['max_new_tokens'] = max_tokens
|
|
||||||
req_params['truncation_length'] = truncation_length
|
|
||||||
req_params['temperature'] = temperature
|
|
||||||
req_params['top_p'] = top_p
|
|
||||||
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
|
|
||||||
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
|
|
||||||
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
|
|
||||||
|
|
||||||
debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
|
||||||
|
|
||||||
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
|
|
||||||
|
|
||||||
answer = ''
|
|
||||||
for a in generator:
|
|
||||||
answer = a
|
|
||||||
|
|
||||||
# some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
|
|
||||||
if edit_task[-1] != '\n' and answer and answer[0] == ' ':
|
|
||||||
answer = answer[1:]
|
|
||||||
|
|
||||||
completion_token_count = len(encode(answer)[0])
|
|
||||||
|
|
||||||
resp = {
|
|
||||||
"object": "edit",
|
|
||||||
"created": created_time,
|
|
||||||
"choices": [{
|
|
||||||
"text": answer,
|
|
||||||
"index": 0,
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": token_count,
|
|
||||||
"completion_tokens": completion_token_count,
|
|
||||||
"total_tokens": token_count + completion_token_count
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp
|
|
@ -6,9 +6,13 @@ from extensions.openai.utils import debug_msg, float_list_to_base64
|
|||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
embeddings_params_initialized = False
|
embeddings_params_initialized = False
|
||||||
# using 'lazy loading' to avoid circular import
|
|
||||||
# so this function will be executed only once
|
|
||||||
def initialize_embedding_params():
|
def initialize_embedding_params():
|
||||||
|
'''
|
||||||
|
using 'lazy loading' to avoid circular import
|
||||||
|
so this function will be executed only once
|
||||||
|
'''
|
||||||
global embeddings_params_initialized
|
global embeddings_params_initialized
|
||||||
if not embeddings_params_initialized:
|
if not embeddings_params_initialized:
|
||||||
global st_model, embeddings_model, embeddings_device
|
global st_model, embeddings_model, embeddings_device
|
||||||
@ -26,7 +30,7 @@ def load_embedding_model(model: str) -> SentenceTransformer:
|
|||||||
initialize_embedding_params()
|
initialize_embedding_params()
|
||||||
global embeddings_device, embeddings_model
|
global embeddings_device, embeddings_model
|
||||||
try:
|
try:
|
||||||
print(f"\Try embedding model: {model} on {embeddings_device}")
|
print(f"Try embedding model: {model} on {embeddings_device}")
|
||||||
# see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
|
# see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
|
||||||
embeddings_model = SentenceTransformer(model, device=embeddings_device)
|
embeddings_model = SentenceTransformer(model, device=embeddings_device)
|
||||||
# ... embeddings_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
|
# ... embeddings_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
|
||||||
@ -54,7 +58,7 @@ def get_embeddings(input: list) -> np.ndarray:
|
|||||||
model = get_embeddings_model()
|
model = get_embeddings_model()
|
||||||
debug_msg(f"embedding model : {model}")
|
debug_msg(f"embedding model : {model}")
|
||||||
embedding = model.encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False)
|
embedding = model.encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False)
|
||||||
debug_msg(f"embedding result : {embedding}") # might be too long even for debug, use at you own will
|
debug_msg(f"embedding result : {embedding}") # might be too long even for debug, use at you own will
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,6 +50,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
|
|||||||
'data': []
|
'data': []
|
||||||
}
|
}
|
||||||
from extensions.openai.script import params
|
from extensions.openai.script import params
|
||||||
|
|
||||||
# TODO: support SD_WEBUI_AUTH username:password pair.
|
# TODO: support SD_WEBUI_AUTH username:password pair.
|
||||||
sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img"
|
sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img"
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
SpeechRecognition==3.10.0
|
SpeechRecognition==3.10.0
|
||||||
flask_cloudflared==0.0.12
|
flask_cloudflared==0.0.14
|
||||||
sentence-transformers
|
sentence-transformers
|
||||||
|
sse-starlette==1.6.5
|
||||||
tiktoken
|
tiktoken
|
||||||
|
@ -1,351 +1,255 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import ssl
|
|
||||||
import traceback
|
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
import extensions.openai.completions as OAIcompletions
|
import extensions.openai.completions as OAIcompletions
|
||||||
import extensions.openai.edits as OAIedits
|
|
||||||
import extensions.openai.embeddings as OAIembeddings
|
import extensions.openai.embeddings as OAIembeddings
|
||||||
import extensions.openai.images as OAIimages
|
import extensions.openai.images as OAIimages
|
||||||
import extensions.openai.models as OAImodels
|
import extensions.openai.models as OAImodels
|
||||||
import extensions.openai.moderations as OAImoderations
|
import extensions.openai.moderations as OAImoderations
|
||||||
from extensions.openai.defaults import clamp, default, get_default_req_params
|
|
||||||
from extensions.openai.errors import (
|
|
||||||
InvalidRequestError,
|
|
||||||
OpenAIError,
|
|
||||||
ServiceUnavailableError
|
|
||||||
)
|
|
||||||
from extensions.openai.tokens import token_count, token_decode, token_encode
|
|
||||||
from extensions.openai.utils import debug_msg
|
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
import cgi
|
|
||||||
import speech_recognition as sr
|
import speech_recognition as sr
|
||||||
|
import uvicorn
|
||||||
|
from extensions.openai.errors import ServiceUnavailableError
|
||||||
|
from extensions.openai.tokens import token_count, token_decode, token_encode
|
||||||
|
from extensions.openai.utils import _start_cloudflared
|
||||||
|
from fastapi import Depends, FastAPI, Header, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.requests import Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from modules import shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
|
from sse_starlette import EventSourceResponse
|
||||||
|
|
||||||
|
from .typing import (
|
||||||
|
ChatCompletionRequest,
|
||||||
|
ChatCompletionResponse,
|
||||||
|
CompletionRequest,
|
||||||
|
CompletionResponse,
|
||||||
|
to_dict
|
||||||
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
# default params
|
|
||||||
'port': 5001,
|
|
||||||
'embedding_device': 'cpu',
|
'embedding_device': 'cpu',
|
||||||
'embedding_model': 'all-mpnet-base-v2',
|
'embedding_model': 'all-mpnet-base-v2',
|
||||||
|
|
||||||
# optional params
|
|
||||||
'sd_webui_url': '',
|
'sd_webui_url': '',
|
||||||
'debug': 0
|
'debug': 0
|
||||||
}
|
}
|
||||||
|
|
||||||
class Handler(BaseHTTPRequestHandler):
|
|
||||||
def send_access_control_headers(self):
|
|
||||||
self.send_header("Access-Control-Allow-Origin", "*")
|
|
||||||
self.send_header("Access-Control-Allow-Credentials", "true")
|
|
||||||
self.send_header(
|
|
||||||
"Access-Control-Allow-Methods",
|
|
||||||
"GET,HEAD,OPTIONS,POST,PUT"
|
|
||||||
)
|
|
||||||
self.send_header(
|
|
||||||
"Access-Control-Allow-Headers",
|
|
||||||
"Origin, Accept, X-Requested-With, Content-Type, "
|
|
||||||
"Access-Control-Request-Method, Access-Control-Request-Headers, "
|
|
||||||
"Authorization"
|
|
||||||
)
|
|
||||||
|
|
||||||
def do_OPTIONS(self):
|
def verify_api_key(authorization: str = Header(None)) -> None:
|
||||||
self.send_response(200)
|
expected_api_key = shared.args.api_key
|
||||||
self.send_access_control_headers()
|
if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
|
||||||
self.send_header('Content-Type', 'application/json')
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
self.end_headers()
|
|
||||||
self.wfile.write("OK".encode('utf-8'))
|
|
||||||
|
|
||||||
def start_sse(self):
|
|
||||||
self.send_response(200)
|
|
||||||
self.send_access_control_headers()
|
|
||||||
self.send_header('Content-Type', 'text/event-stream')
|
|
||||||
self.send_header('Cache-Control', 'no-cache')
|
|
||||||
# self.send_header('Connection', 'keep-alive')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
def send_sse(self, chunk: dict):
|
app = FastAPI(dependencies=[Depends(verify_api_key)])
|
||||||
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
|
||||||
debug_msg(response[:-4])
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
def end_sse(self):
|
# Configure CORS settings to allow all origins, methods, and headers
|
||||||
response = 'data: [DONE]\r\n\r\n'
|
app.add_middleware(
|
||||||
debug_msg(response[:-4])
|
CORSMiddleware,
|
||||||
self.wfile.write(response.encode('utf-8'))
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["GET", "HEAD", "OPTIONS", "POST", "PUT"],
|
||||||
|
allow_headers=[
|
||||||
|
"Origin",
|
||||||
|
"Accept",
|
||||||
|
"X-Requested-With",
|
||||||
|
"Content-Type",
|
||||||
|
"Access-Control-Request-Method",
|
||||||
|
"Access-Control-Request-Headers",
|
||||||
|
"Authorization",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
def return_json(self, ret: dict, code: int = 200, no_debug=False):
|
|
||||||
self.send_response(code)
|
|
||||||
self.send_access_control_headers()
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
|
|
||||||
response = json.dumps(ret)
|
@app.options("/")
|
||||||
r_utf8 = response.encode('utf-8')
|
async def options_route():
|
||||||
|
return JSONResponse(content="OK")
|
||||||
|
|
||||||
self.send_header('Content-Length', str(len(r_utf8)))
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
self.wfile.write(r_utf8)
|
@app.post('/v1/completions', response_model=CompletionResponse)
|
||||||
if not no_debug:
|
@app.post('/v1/generate', response_model=CompletionResponse)
|
||||||
debug_msg(r_utf8)
|
async def openai_completions(request: Request, request_data: CompletionRequest):
|
||||||
|
path = request.url.path
|
||||||
|
is_legacy = "/generate" in path
|
||||||
|
|
||||||
def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
|
if request_data.stream:
|
||||||
|
async def generator():
|
||||||
|
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
|
||||||
|
for resp in response:
|
||||||
|
yield {"data": json.dumps(resp)}
|
||||||
|
|
||||||
error_resp = {
|
return EventSourceResponse(generator()) # SSE streaming
|
||||||
'error': {
|
|
||||||
'message': message,
|
|
||||||
'code': code,
|
|
||||||
'type': error_type,
|
|
||||||
'param': param,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if internal_message:
|
|
||||||
print(error_type, message)
|
|
||||||
print(internal_message)
|
|
||||||
# error_resp['internal_message'] = internal_message
|
|
||||||
|
|
||||||
self.return_json(error_resp, code)
|
else:
|
||||||
|
response = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy)
|
||||||
|
return JSONResponse(response)
|
||||||
|
|
||||||
def openai_error_handler(func):
|
|
||||||
def wrapper(self):
|
|
||||||
try:
|
|
||||||
func(self)
|
|
||||||
except InvalidRequestError as e:
|
|
||||||
self.openai_error(e.message, e.code, e.__class__.__name__, e.param, internal_message=e.internal_message)
|
|
||||||
except OpenAIError as e:
|
|
||||||
self.openai_error(e.message, e.code, e.__class__.__name__, internal_message=e.internal_message)
|
|
||||||
except Exception as e:
|
|
||||||
self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc())
|
|
||||||
|
|
||||||
return wrapper
|
@app.post('/v1/chat/completions', response_model=ChatCompletionResponse)
|
||||||
|
async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
|
||||||
|
path = request.url.path
|
||||||
|
is_legacy = "/generate" in path
|
||||||
|
|
||||||
@openai_error_handler
|
if request_data.stream:
|
||||||
def do_GET(self):
|
async def generator():
|
||||||
debug_msg(self.requestline)
|
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
|
||||||
debug_msg(self.headers)
|
for resp in response:
|
||||||
|
yield {"data": json.dumps(resp)}
|
||||||
|
|
||||||
if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
|
return EventSourceResponse(generator()) # SSE streaming
|
||||||
is_legacy = 'engines' in self.path
|
|
||||||
is_list = self.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models']
|
|
||||||
if is_legacy and not is_list:
|
|
||||||
model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
|
|
||||||
resp = OAImodels.load_model(model_name)
|
|
||||||
elif is_list:
|
|
||||||
resp = OAImodels.list_models(is_legacy)
|
|
||||||
else:
|
|
||||||
model_name = self.path[len('/v1/models/'):]
|
|
||||||
resp = OAImodels.model_info(model_name)
|
|
||||||
|
|
||||||
self.return_json(resp)
|
else:
|
||||||
|
response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy)
|
||||||
|
return JSONResponse(response)
|
||||||
|
|
||||||
elif '/billing/usage' in self.path:
|
|
||||||
# Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
|
|
||||||
self.return_json({"total_usage": 0}, no_debug=True)
|
|
||||||
|
|
||||||
else:
|
@app.get("/v1/models")
|
||||||
self.send_error(404)
|
@app.get("/v1/engines")
|
||||||
|
async def handle_models(request: Request):
|
||||||
|
path = request.url.path
|
||||||
|
is_legacy = 'engines' in path
|
||||||
|
is_list = request.url.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models']
|
||||||
|
|
||||||
@openai_error_handler
|
if is_legacy and not is_list:
|
||||||
def do_POST(self):
|
model_name = path[path.find('/v1/engines/') + len('/v1/engines/'):]
|
||||||
|
resp = OAImodels.load_model(model_name)
|
||||||
|
elif is_list:
|
||||||
|
resp = OAImodels.list_models(is_legacy)
|
||||||
|
else:
|
||||||
|
model_name = path[len('/v1/models/'):]
|
||||||
|
resp = OAImodels.model_info(model_name)
|
||||||
|
|
||||||
if '/v1/audio/transcriptions' in self.path:
|
return JSONResponse(content=resp)
|
||||||
r = sr.Recognizer()
|
|
||||||
|
|
||||||
# Parse the form data
|
|
||||||
form = cgi.FieldStorage(
|
|
||||||
fp=self.rfile,
|
|
||||||
headers=self.headers,
|
|
||||||
environ={'REQUEST_METHOD': 'POST', 'CONTENT_TYPE': self.headers['Content-Type']}
|
|
||||||
)
|
|
||||||
|
|
||||||
audio_file = form['file'].file
|
|
||||||
audio_data = AudioSegment.from_file(audio_file)
|
|
||||||
|
|
||||||
# Convert AudioSegment to raw data
|
|
||||||
raw_data = audio_data.raw_data
|
|
||||||
|
|
||||||
# Create AudioData object
|
|
||||||
audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
|
|
||||||
whipser_language = form.getvalue('language', None)
|
|
||||||
whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
|
|
||||||
|
|
||||||
transcription = {"text": ""}
|
@app.get('/v1/billing/usage')
|
||||||
|
def handle_billing_usage():
|
||||||
try:
|
'''
|
||||||
transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
|
Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
|
||||||
except sr.UnknownValueError:
|
'''
|
||||||
print("Whisper could not understand audio")
|
return JSONResponse(content={"total_usage": 0})
|
||||||
transcription["text"] = "Whisper could not understand audio UnknownValueError"
|
|
||||||
except sr.RequestError as e:
|
|
||||||
print("Could not request results from Whisper", e)
|
|
||||||
transcription["text"] = "Whisper could not understand audio RequestError"
|
|
||||||
|
|
||||||
self.return_json(transcription, no_debug=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
debug_msg(self.requestline)
|
|
||||||
debug_msg(self.headers)
|
|
||||||
|
|
||||||
content_length = self.headers.get('Content-Length')
|
|
||||||
transfer_encoding = self.headers.get('Transfer-Encoding')
|
|
||||||
|
|
||||||
if content_length:
|
@app.post('/v1/audio/transcriptions')
|
||||||
body = json.loads(self.rfile.read(int(content_length)).decode('utf-8'))
|
async def handle_audio_transcription(request: Request):
|
||||||
elif transfer_encoding == 'chunked':
|
r = sr.Recognizer()
|
||||||
chunks = []
|
|
||||||
while True:
|
|
||||||
chunk_size = int(self.rfile.readline(), 16) # Read the chunk size
|
|
||||||
if chunk_size == 0:
|
|
||||||
break # End of chunks
|
|
||||||
chunks.append(self.rfile.read(chunk_size))
|
|
||||||
self.rfile.readline() # Consume the trailing newline after each chunk
|
|
||||||
body = json.loads(b''.join(chunks).decode('utf-8'))
|
|
||||||
else:
|
|
||||||
self.send_response(400, "Bad Request: Either Content-Length or Transfer-Encoding header expected.")
|
|
||||||
self.end_headers()
|
|
||||||
return
|
|
||||||
|
|
||||||
debug_msg(body)
|
form = await request.form()
|
||||||
|
audio_file = await form["file"].read()
|
||||||
|
audio_data = AudioSegment.from_file(audio_file)
|
||||||
|
|
||||||
if '/completions' in self.path or '/generate' in self.path:
|
# Convert AudioSegment to raw data
|
||||||
|
raw_data = audio_data.raw_data
|
||||||
|
|
||||||
if not shared.model:
|
# Create AudioData object
|
||||||
raise ServiceUnavailableError("No model loaded.")
|
audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
|
||||||
|
whipser_language = form.getvalue('language', None)
|
||||||
|
whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
|
||||||
|
|
||||||
is_legacy = '/generate' in self.path
|
transcription = {"text": ""}
|
||||||
is_streaming = body.get('stream', False)
|
|
||||||
|
|
||||||
if is_streaming:
|
try:
|
||||||
self.start_sse()
|
transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
|
||||||
|
except sr.UnknownValueError:
|
||||||
|
print("Whisper could not understand audio")
|
||||||
|
transcription["text"] = "Whisper could not understand audio UnknownValueError"
|
||||||
|
except sr.RequestError as e:
|
||||||
|
print("Could not request results from Whisper", e)
|
||||||
|
transcription["text"] = "Whisper could not understand audio RequestError"
|
||||||
|
|
||||||
response = []
|
return JSONResponse(content=transcription)
|
||||||
if 'chat' in self.path:
|
|
||||||
response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
|
|
||||||
else:
|
|
||||||
response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)
|
|
||||||
|
|
||||||
for resp in response:
|
|
||||||
self.send_sse(resp)
|
|
||||||
|
|
||||||
self.end_sse()
|
@app.post('/v1/images/generations')
|
||||||
|
async def handle_image_generation(request: Request):
|
||||||
|
|
||||||
else:
|
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
|
||||||
response = ''
|
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
|
||||||
if 'chat' in self.path:
|
|
||||||
response = OAIcompletions.chat_completions(body, is_legacy=is_legacy)
|
|
||||||
else:
|
|
||||||
response = OAIcompletions.completions(body, is_legacy=is_legacy)
|
|
||||||
|
|
||||||
self.return_json(response)
|
body = await request.json()
|
||||||
|
prompt = body['prompt']
|
||||||
|
size = body.get('size', '1024x1024')
|
||||||
|
response_format = body.get('response_format', 'url') # or b64_json
|
||||||
|
n = body.get('n', 1) # ignore the batch limits of max 10
|
||||||
|
|
||||||
elif '/edits' in self.path:
|
response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
|
||||||
# deprecated
|
return JSONResponse(response)
|
||||||
|
|
||||||
if not shared.model:
|
|
||||||
raise ServiceUnavailableError("No model loaded.")
|
|
||||||
|
|
||||||
req_params = get_default_req_params()
|
@app.post("/v1/embeddings")
|
||||||
|
async def handle_embeddings(request: Request):
|
||||||
|
body = await request.json()
|
||||||
|
encoding_format = body.get("encoding_format", "")
|
||||||
|
|
||||||
instruction = body['instruction']
|
input = body.get('input', body.get('text', ''))
|
||||||
input = body.get('input', '')
|
if not input:
|
||||||
temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
raise HTTPException(status_code=400, detail="Missing required argument input")
|
||||||
top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
|
|
||||||
|
|
||||||
response = OAIedits.edits(instruction, input, temperature, top_p)
|
if type(input) is str:
|
||||||
|
input = [input]
|
||||||
|
|
||||||
self.return_json(response)
|
response = OAIembeddings.embeddings(input, encoding_format)
|
||||||
|
return JSONResponse(response)
|
||||||
|
|
||||||
elif '/images/generations' in self.path:
|
|
||||||
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
|
|
||||||
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
|
|
||||||
|
|
||||||
prompt = body['prompt']
|
@app.post("/v1/moderations")
|
||||||
size = default(body, 'size', '1024x1024')
|
async def handle_moderations(request: Request):
|
||||||
response_format = default(body, 'response_format', 'url') # or b64_json
|
body = await request.json()
|
||||||
n = default(body, 'n', 1) # ignore the batch limits of max 10
|
input = body["input"]
|
||||||
|
if not input:
|
||||||
|
raise HTTPException(status_code=400, detail="Missing required argument input")
|
||||||
|
|
||||||
response = OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
|
response = OAImoderations.moderations(input)
|
||||||
|
return JSONResponse(response)
|
||||||
|
|
||||||
self.return_json(response, no_debug=True)
|
|
||||||
|
|
||||||
elif '/embeddings' in self.path:
|
@app.post("/api/v1/token-count")
|
||||||
encoding_format = body.get('encoding_format', '')
|
async def handle_token_count(request: Request):
|
||||||
|
body = await request.json()
|
||||||
|
response = token_count(body['prompt'])
|
||||||
|
return JSONResponse(response)
|
||||||
|
|
||||||
input = body.get('input', body.get('text', ''))
|
|
||||||
if not input:
|
|
||||||
raise InvalidRequestError("Missing required argument input", params='input')
|
|
||||||
|
|
||||||
if type(input) is str:
|
@app.post("/api/v1/token/encode")
|
||||||
input = [input]
|
async def handle_token_encode(request: Request):
|
||||||
|
body = await request.json()
|
||||||
|
encoding_format = body.get("encoding_format", "")
|
||||||
|
response = token_encode(body["input"], encoding_format)
|
||||||
|
return JSONResponse(response)
|
||||||
|
|
||||||
response = OAIembeddings.embeddings(input, encoding_format)
|
|
||||||
|
|
||||||
self.return_json(response, no_debug=True)
|
@app.post("/api/v1/token/decode")
|
||||||
|
async def handle_token_decode(request: Request):
|
||||||
elif '/moderations' in self.path:
|
body = await request.json()
|
||||||
input = body['input']
|
encoding_format = body.get("encoding_format", "")
|
||||||
if not input:
|
response = token_decode(body["input"], encoding_format)
|
||||||
raise InvalidRequestError("Missing required argument input", params='input')
|
return JSONResponse(response, no_debug=True)
|
||||||
|
|
||||||
response = OAImoderations.moderations(input)
|
|
||||||
|
|
||||||
self.return_json(response, no_debug=True)
|
|
||||||
|
|
||||||
elif self.path == '/api/v1/token-count':
|
|
||||||
# NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
|
|
||||||
response = token_count(body['prompt'])
|
|
||||||
|
|
||||||
self.return_json(response, no_debug=True)
|
|
||||||
|
|
||||||
elif self.path == '/api/v1/token/encode':
|
|
||||||
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
|
|
||||||
encoding_format = body.get('encoding_format', '')
|
|
||||||
|
|
||||||
response = token_encode(body['input'], encoding_format)
|
|
||||||
|
|
||||||
self.return_json(response, no_debug=True)
|
|
||||||
|
|
||||||
elif self.path == '/api/v1/token/decode':
|
|
||||||
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
|
|
||||||
encoding_format = body.get('encoding_format', '')
|
|
||||||
|
|
||||||
response = token_decode(body['input'], encoding_format)
|
|
||||||
|
|
||||||
self.return_json(response, no_debug=True)
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.send_error(404)
|
|
||||||
|
|
||||||
|
|
||||||
def run_server():
|
def run_server():
|
||||||
port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001)))
|
server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
||||||
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port)
|
port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
|
||||||
server = ThreadingHTTPServer(server_addr, Handler)
|
|
||||||
|
ssl_certfile = os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
|
||||||
ssl_certfile=os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
|
ssl_keyfile = os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
|
||||||
ssl_keyfile=os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
|
|
||||||
ssl_verify=True if (ssl_keyfile and ssl_certfile) else False
|
if shared.args.public_api:
|
||||||
if ssl_verify:
|
def on_start(public_url: str):
|
||||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
logger.info(f'OpenAI compatible API URL:\n\n{public_url}/v1\n')
|
||||||
context.load_cert_chain(ssl_certfile, ssl_keyfile)
|
|
||||||
server.socket = context.wrap_socket(server.socket, server_side=True)
|
_start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start)
|
||||||
|
|
||||||
if shared.args.share:
|
|
||||||
try:
|
|
||||||
from flask_cloudflared import _run_cloudflared
|
|
||||||
public_url = _run_cloudflared(port, port + 1)
|
|
||||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1')
|
|
||||||
except ImportError:
|
|
||||||
print('You should install flask_cloudflared manually')
|
|
||||||
else:
|
else:
|
||||||
if ssl_verify:
|
if ssl_keyfile and ssl_certfile:
|
||||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=https://{server_addr[0]}:{server_addr[1]}/v1')
|
logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}/v1\n')
|
||||||
else:
|
else:
|
||||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
|
logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}/v1\n')
|
||||||
|
|
||||||
server.serve_forever()
|
if shared.args.api_key:
|
||||||
|
logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')
|
||||||
|
|
||||||
|
uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
|
||||||
|
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
|
125
extensions/openai/typing.py
Normal file
125
extensions/openai/typing.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationOptions(BaseModel):
|
||||||
|
preset: str | None = None
|
||||||
|
temperature: float = 1
|
||||||
|
temperature_last: bool = False
|
||||||
|
top_p: float = 1
|
||||||
|
min_p: float = 0
|
||||||
|
top_k: int = 0
|
||||||
|
repetition_penalty: float = 1
|
||||||
|
presence_penalty: float = 0
|
||||||
|
frequency_penalty: float = 0
|
||||||
|
repetition_penalty_range: int = 0
|
||||||
|
typical_p: float = 1
|
||||||
|
tfs: float = 1
|
||||||
|
top_a: float = 0
|
||||||
|
epsilon_cutoff: float = 0
|
||||||
|
eta_cutoff: float = 0
|
||||||
|
guidance_scale: float = 1
|
||||||
|
negative_prompt: str = ''
|
||||||
|
penalty_alpha: float = 0
|
||||||
|
mirostat_mode: int = 0
|
||||||
|
mirostat_tau: float = 5
|
||||||
|
mirostat_eta: float = 0.1
|
||||||
|
do_sample: bool = True
|
||||||
|
seed: int = -1
|
||||||
|
encoder_repetition_penalty: float = 1
|
||||||
|
no_repeat_ngram_size: int = 0
|
||||||
|
min_length: int = 0
|
||||||
|
num_beams: int = 1
|
||||||
|
length_penalty: float = 1
|
||||||
|
early_stopping: bool = False
|
||||||
|
truncation_length: int = 0
|
||||||
|
max_tokens_second: int = 0
|
||||||
|
custom_token_bans: str = ""
|
||||||
|
auto_max_new_tokens: bool = False
|
||||||
|
ban_eos_token: bool = False
|
||||||
|
add_bos_token: bool = True
|
||||||
|
skip_special_tokens: bool = True
|
||||||
|
grammar_string: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionRequest(GenerationOptions):
|
||||||
|
model: str | None = None
|
||||||
|
prompt: str | List[str]
|
||||||
|
best_of: int | None = 1
|
||||||
|
echo: bool | None = False
|
||||||
|
frequency_penalty: float | None = 0
|
||||||
|
logit_bias: dict | None = None
|
||||||
|
logprobs: int | None = None
|
||||||
|
max_tokens: int | None = 16
|
||||||
|
n: int | None = 1
|
||||||
|
presence_penalty: int | None = 0
|
||||||
|
stop: str | List[str] | None = None
|
||||||
|
stream: bool | None = False
|
||||||
|
suffix: str | None = None
|
||||||
|
temperature: float | None = 1
|
||||||
|
top_p: float | None = 1
|
||||||
|
user: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
choices: List[dict]
|
||||||
|
created: int = int(time.time())
|
||||||
|
model: str
|
||||||
|
object: str = "text_completion"
|
||||||
|
usage: dict
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionRequest(GenerationOptions):
|
||||||
|
messages: List[dict]
|
||||||
|
model: str | None = None
|
||||||
|
frequency_penalty: float | None = 0
|
||||||
|
function_call: str | dict | None = None
|
||||||
|
functions: List[dict] | None = None
|
||||||
|
logit_bias: dict | None = None
|
||||||
|
max_tokens: int | None = None
|
||||||
|
n: int | None = 1
|
||||||
|
presence_penalty: int | None = 0
|
||||||
|
stop: str | List[str] | None = None
|
||||||
|
stream: bool | None = False
|
||||||
|
temperature: float | None = 1
|
||||||
|
top_p: float | None = 1
|
||||||
|
user: str | None = None
|
||||||
|
|
||||||
|
mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
|
||||||
|
|
||||||
|
instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/instruction-templates. If not set, the correct template will be guessed using the regex expressions in models/config.yaml.")
|
||||||
|
name1_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
|
||||||
|
name2_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
|
||||||
|
context_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
|
||||||
|
turn_template: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
|
||||||
|
|
||||||
|
character: str | None = Field(default=None, description="A character defined under text-generation-webui/characters. If not set, the default \"Assistant\" character will be used.")
|
||||||
|
name1: str | None = Field(default=None, description="Overwrites the value set by character.")
|
||||||
|
name2: str | None = Field(default=None, description="Overwrites the value set by character.")
|
||||||
|
context: str | None = Field(default=None, description="Overwrites the value set by character.")
|
||||||
|
greeting: str | None = Field(default=None, description="Overwrites the value set by character.")
|
||||||
|
|
||||||
|
chat_instruct_command: str | None = None
|
||||||
|
|
||||||
|
continue_: bool = Field(default=False, description="Makes the last bot message in the history be continued instead of starting a new message.")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
choices: List[dict]
|
||||||
|
created: int = int(time.time())
|
||||||
|
model: str
|
||||||
|
object: str = "chat.completion"
|
||||||
|
usage: dict
|
||||||
|
|
||||||
|
|
||||||
|
def to_json(obj):
|
||||||
|
return json.dumps(obj.__dict__, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def to_dict(obj):
|
||||||
|
return obj.__dict__
|
@ -1,8 +1,12 @@
|
|||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def float_list_to_base64(float_array: np.ndarray) -> str:
|
def float_list_to_base64(float_array: np.ndarray) -> str:
|
||||||
# Convert the list to a float32 array that the OpenAPI client expects
|
# Convert the list to a float32 array that the OpenAPI client expects
|
||||||
# float_array = np.array(float_list, dtype="float32")
|
# float_array = np.array(float_list, dtype="float32")
|
||||||
@ -18,13 +22,33 @@ def float_list_to_base64(float_array: np.ndarray) -> str:
|
|||||||
return ascii_string
|
return ascii_string
|
||||||
|
|
||||||
|
|
||||||
def end_line(s):
|
|
||||||
if s and s[-1] != '\n':
|
|
||||||
s = s + '\n'
|
|
||||||
return s
|
|
||||||
|
|
||||||
|
|
||||||
def debug_msg(*args, **kwargs):
|
def debug_msg(*args, **kwargs):
|
||||||
from extensions.openai.script import params
|
from extensions.openai.script import params
|
||||||
if os.environ.get("OPENEDAI_DEBUG", params.get('debug', 0)):
|
if os.environ.get("OPENEDAI_DEBUG", params.get('debug', 0)):
|
||||||
print(*args, **kwargs)
|
print(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
|
||||||
|
try:
|
||||||
|
from flask_cloudflared import _run_cloudflared
|
||||||
|
except ImportError:
|
||||||
|
print('You should install flask_cloudflared manually')
|
||||||
|
raise Exception(
|
||||||
|
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
|
||||||
|
|
||||||
|
for _ in range(max_attempts):
|
||||||
|
try:
|
||||||
|
if tunnel_id is not None:
|
||||||
|
public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id)
|
||||||
|
else:
|
||||||
|
public_url = _run_cloudflared(port, port + 1)
|
||||||
|
|
||||||
|
if on_start:
|
||||||
|
on_start(public_url)
|
||||||
|
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
time.sleep(3)
|
||||||
|
|
||||||
|
raise Exception('Could not start cloudflared.')
|
||||||
|
@ -81,7 +81,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
# Find the maximum prompt size
|
# Find the maximum prompt size
|
||||||
max_length = get_max_prompt_length(state)
|
max_length = get_max_prompt_length(state)
|
||||||
all_substrings = {
|
all_substrings = {
|
||||||
'chat': get_turn_substrings(state, instruct=False),
|
'chat': get_turn_substrings(state, instruct=False) if state['mode'] in ['chat', 'chat-instruct'] else None,
|
||||||
'instruct': get_turn_substrings(state, instruct=True)
|
'instruct': get_turn_substrings(state, instruct=True)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -237,7 +237,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||||||
for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True)):
|
for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True)):
|
||||||
|
|
||||||
# Extract the reply
|
# Extract the reply
|
||||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
visible_reply = reply
|
||||||
|
if state['mode'] in ['chat', 'chat-instruct']:
|
||||||
|
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||||
|
|
||||||
visible_reply = html.escape(visible_reply)
|
visible_reply = html.escape(visible_reply)
|
||||||
|
|
||||||
if shared.stop_everything:
|
if shared.stop_everything:
|
||||||
|
@ -71,11 +71,12 @@ def load_model(model_name, loader=None):
|
|||||||
'AutoAWQ': AutoAWQ_loader,
|
'AutoAWQ': AutoAWQ_loader,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
metadata = get_model_metadata(model_name)
|
||||||
if loader is None:
|
if loader is None:
|
||||||
if shared.args.loader is not None:
|
if shared.args.loader is not None:
|
||||||
loader = shared.args.loader
|
loader = shared.args.loader
|
||||||
else:
|
else:
|
||||||
loader = get_model_metadata(model_name)['loader']
|
loader = metadata['loader']
|
||||||
if loader is None:
|
if loader is None:
|
||||||
logger.error('The path to the model does not exist. Exiting.')
|
logger.error('The path to the model does not exist. Exiting.')
|
||||||
return None, None
|
return None, None
|
||||||
@ -95,6 +96,7 @@ def load_model(model_name, loader=None):
|
|||||||
if any((shared.args.xformers, shared.args.sdp_attention)):
|
if any((shared.args.xformers, shared.args.sdp_attention)):
|
||||||
llama_attn_hijack.hijack_llama_attention()
|
llama_attn_hijack.hijack_llama_attention()
|
||||||
|
|
||||||
|
shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
|
||||||
logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
@ -6,33 +6,32 @@ import yaml
|
|||||||
|
|
||||||
def default_preset():
|
def default_preset():
|
||||||
return {
|
return {
|
||||||
'do_sample': True,
|
|
||||||
'temperature': 1,
|
'temperature': 1,
|
||||||
'temperature_last': False,
|
'temperature_last': False,
|
||||||
'top_p': 1,
|
'top_p': 1,
|
||||||
'min_p': 0,
|
'min_p': 0,
|
||||||
'top_k': 0,
|
'top_k': 0,
|
||||||
'typical_p': 1,
|
|
||||||
'epsilon_cutoff': 0,
|
|
||||||
'eta_cutoff': 0,
|
|
||||||
'tfs': 1,
|
|
||||||
'top_a': 0,
|
|
||||||
'repetition_penalty': 1,
|
'repetition_penalty': 1,
|
||||||
'presence_penalty': 0,
|
'presence_penalty': 0,
|
||||||
'frequency_penalty': 0,
|
'frequency_penalty': 0,
|
||||||
'repetition_penalty_range': 0,
|
'repetition_penalty_range': 0,
|
||||||
|
'typical_p': 1,
|
||||||
|
'tfs': 1,
|
||||||
|
'top_a': 0,
|
||||||
|
'epsilon_cutoff': 0,
|
||||||
|
'eta_cutoff': 0,
|
||||||
|
'guidance_scale': 1,
|
||||||
|
'penalty_alpha': 0,
|
||||||
|
'mirostat_mode': 0,
|
||||||
|
'mirostat_tau': 5,
|
||||||
|
'mirostat_eta': 0.1,
|
||||||
|
'do_sample': True,
|
||||||
'encoder_repetition_penalty': 1,
|
'encoder_repetition_penalty': 1,
|
||||||
'no_repeat_ngram_size': 0,
|
'no_repeat_ngram_size': 0,
|
||||||
'min_length': 0,
|
'min_length': 0,
|
||||||
'guidance_scale': 1,
|
|
||||||
'mirostat_mode': 0,
|
|
||||||
'mirostat_tau': 5.0,
|
|
||||||
'mirostat_eta': 0.1,
|
|
||||||
'penalty_alpha': 0,
|
|
||||||
'num_beams': 1,
|
'num_beams': 1,
|
||||||
'length_penalty': 1,
|
'length_penalty': 1,
|
||||||
'early_stopping': False,
|
'early_stopping': False,
|
||||||
'custom_token_bans': '',
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,21 +39,21 @@ settings = {
|
|||||||
'max_new_tokens': 200,
|
'max_new_tokens': 200,
|
||||||
'max_new_tokens_min': 1,
|
'max_new_tokens_min': 1,
|
||||||
'max_new_tokens_max': 4096,
|
'max_new_tokens_max': 4096,
|
||||||
'seed': -1,
|
|
||||||
'negative_prompt': '',
|
'negative_prompt': '',
|
||||||
|
'seed': -1,
|
||||||
'truncation_length': 2048,
|
'truncation_length': 2048,
|
||||||
'truncation_length_min': 0,
|
'truncation_length_min': 0,
|
||||||
'truncation_length_max': 32768,
|
'truncation_length_max': 32768,
|
||||||
'custom_stopping_strings': '',
|
|
||||||
'auto_max_new_tokens': False,
|
|
||||||
'max_tokens_second': 0,
|
'max_tokens_second': 0,
|
||||||
'ban_eos_token': False,
|
'custom_stopping_strings': '',
|
||||||
'custom_token_bans': '',
|
'custom_token_bans': '',
|
||||||
|
'auto_max_new_tokens': False,
|
||||||
|
'ban_eos_token': False,
|
||||||
'add_bos_token': True,
|
'add_bos_token': True,
|
||||||
'skip_special_tokens': True,
|
'skip_special_tokens': True,
|
||||||
'stream': True,
|
'stream': True,
|
||||||
'name1': 'You',
|
|
||||||
'character': 'Assistant',
|
'character': 'Assistant',
|
||||||
|
'name1': 'You',
|
||||||
'instruction_template': 'Alpaca',
|
'instruction_template': 'Alpaca',
|
||||||
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
|
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
|
||||||
'autoload_model': False,
|
'autoload_model': False,
|
||||||
@ -167,8 +167,8 @@ parser.add_argument('--ssl-certfile', type=str, help='The path to the SSL certif
|
|||||||
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
|
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
|
||||||
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
|
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
|
||||||
parser.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
|
parser.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
|
||||||
parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.')
|
parser.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
|
||||||
parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
|
parser.add_argument('--api-key', type=str, default='', help='API authentication key.')
|
||||||
|
|
||||||
# Multimodal
|
# Multimodal
|
||||||
parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')
|
parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')
|
||||||
@ -178,6 +178,8 @@ parser.add_argument('--notebook', action='store_true', help='DEPRECATED')
|
|||||||
parser.add_argument('--chat', action='store_true', help='DEPRECATED')
|
parser.add_argument('--chat', action='store_true', help='DEPRECATED')
|
||||||
parser.add_argument('--no-stream', action='store_true', help='DEPRECATED')
|
parser.add_argument('--no-stream', action='store_true', help='DEPRECATED')
|
||||||
parser.add_argument('--mul_mat_q', action='store_true', help='DEPRECATED')
|
parser.add_argument('--mul_mat_q', action='store_true', help='DEPRECATED')
|
||||||
|
parser.add_argument('--api-blocking-port', type=int, default=5000, help='DEPRECATED')
|
||||||
|
parser.add_argument('--api-streaming-port', type=int, default=5005, help='DEPRECATED')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args_defaults = parser.parse_args([])
|
args_defaults = parser.parse_args([])
|
||||||
@ -233,10 +235,13 @@ def fix_loader_name(name):
|
|||||||
return 'AutoAWQ'
|
return 'AutoAWQ'
|
||||||
|
|
||||||
|
|
||||||
def add_extension(name):
|
def add_extension(name, last=False):
|
||||||
if args.extensions is None:
|
if args.extensions is None:
|
||||||
args.extensions = [name]
|
args.extensions = [name]
|
||||||
elif 'api' not in args.extensions:
|
elif last:
|
||||||
|
args.extensions = [x for x in args.extensions if x != name]
|
||||||
|
args.extensions.append(name)
|
||||||
|
elif name not in args.extensions:
|
||||||
args.extensions.append(name)
|
args.extensions.append(name)
|
||||||
|
|
||||||
|
|
||||||
@ -246,14 +251,15 @@ def is_chat():
|
|||||||
|
|
||||||
args.loader = fix_loader_name(args.loader)
|
args.loader = fix_loader_name(args.loader)
|
||||||
|
|
||||||
# Activate the API extension
|
|
||||||
if args.api or args.public_api:
|
|
||||||
add_extension('api')
|
|
||||||
|
|
||||||
# Activate the multimodal extension
|
# Activate the multimodal extension
|
||||||
if args.multimodal_pipeline is not None:
|
if args.multimodal_pipeline is not None:
|
||||||
add_extension('multimodal')
|
add_extension('multimodal')
|
||||||
|
|
||||||
|
# Activate the API extension
|
||||||
|
if args.api:
|
||||||
|
# add_extension('openai', last=True)
|
||||||
|
add_extension('api', last=True)
|
||||||
|
|
||||||
# Load model-specific settings
|
# Load model-specific settings
|
||||||
with Path(f'{args.model_dir}/config.yaml') as p:
|
with Path(f'{args.model_dir}/config.yaml') as p:
|
||||||
if p.exists():
|
if p.exists():
|
||||||
|
@ -56,7 +56,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
|||||||
|
|
||||||
# Find the stopping strings
|
# Find the stopping strings
|
||||||
all_stop_strings = []
|
all_stop_strings = []
|
||||||
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
for st in (stopping_strings, state['custom_stopping_strings']):
|
||||||
|
if type(st) is str:
|
||||||
|
st = ast.literal_eval(f"[{st}]")
|
||||||
|
|
||||||
if type(st) is list and len(st) > 0:
|
if type(st) is list and len(st) > 0:
|
||||||
all_stop_strings += st
|
all_stop_strings += st
|
||||||
|
|
||||||
|
@ -215,9 +215,6 @@ def load_model_wrapper(selected_model, loader, autoload=False):
|
|||||||
if 'instruction_template' in settings:
|
if 'instruction_template' in settings:
|
||||||
output += '\n\nIt seems to be an instruction-following model with template "{}". In the chat tab, instruct or chat-instruct modes should be used.'.format(settings['instruction_template'])
|
output += '\n\nIt seems to be an instruction-following model with template "{}". In the chat tab, instruct or chat-instruct modes should be used.'.format(settings['instruction_template'])
|
||||||
|
|
||||||
# Applying the changes to the global shared settings (in-memory)
|
|
||||||
shared.settings.update({k: v for k, v in settings.items() if k in shared.settings})
|
|
||||||
|
|
||||||
yield output
|
yield output
|
||||||
else:
|
else:
|
||||||
yield f"Failed to load `{selected_model}`."
|
yield f"Failed to load `{selected_model}`."
|
||||||
|
Loading…
Reference in New Issue
Block a user