extensions/openai: load extension settings via settings.yaml (#3953)

This commit is contained in:
Chenxiao Wang 2023-09-18 09:39:29 +08:00 committed by GitHub
parent cc8eda298a
commit 347aed4254
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 48 additions and 16 deletions

View File

@ -44,6 +44,18 @@ OPENAI_API_BASE=http://0.0.0.0:5001/v1
If needed, replace 0.0.0.0 with the IP/port of your server. If needed, replace 0.0.0.0 with the IP/port of your server.
## Settings
To adjust your default settings, you can add the following to your `settings.yaml` file.
```
openai-port: 5002
openai-embedding_device: cuda
openai-sd_webui_url: http://127.0.0.1:7861
openai-debug: 1
```
### Models ### Models
This has been successfully tested with Alpaca, Koala, Vicuna, WizardLM and their variants, (ex. gpt4-x-alpaca, GPT4all-snoozy, stable-vicuna, wizard-vicuna, etc.) and many others. Models that have been trained for **Instruction Following** work best. If you test with other models please let me know how it goes. Less than satisfying results (so far) from: RWKV-4-Raven, llama, mpt-7b-instruct/chat. This has been successfully tested with Alpaca, Koala, Vicuna, WizardLM and their variants, (ex. gpt4-x-alpaca, GPT4all-snoozy, stable-vicuna, wizard-vicuna, etc.) and many others. Models that have been trained for **Instruction Following** work best. If you test with other models please let me know how it goes. Less than satisfying results (so far) from: RWKV-4-Raven, llama, mpt-7b-instruct/chat.

View File

@ -6,6 +6,7 @@
import os import os
import sentence_transformers import sentence_transformers
from extensions.openai.script import params
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2" st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", params.get('embedding_model', 'all-mpnet-base-v2'))
model = sentence_transformers.SentenceTransformer(st_model) model = sentence_transformers.SentenceTransformer(st_model)

View File

@ -5,15 +5,25 @@ from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.utils import debug_msg, float_list_to_base64 from extensions.openai.utils import debug_msg, float_list_to_base64
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2" embeddings_params_initialized = False
embeddings_model = None # using 'lazy loading' to avoid circular import
# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone # so this function will be executed only once
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", "cpu") def initialize_embedding_params():
if embeddings_device.lower() == 'auto': global embeddings_params_initialized
embeddings_device = None if not embeddings_params_initialized:
global st_model, embeddings_model, embeddings_device
from extensions.openai.script import params
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", params.get('embedding_model', 'all-mpnet-base-v2'))
embeddings_model = None
# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", params.get('embedding_device', 'cpu'))
if embeddings_device.lower() == 'auto':
embeddings_device = None
embeddings_params_initialized = True
def load_embedding_model(model: str) -> SentenceTransformer: def load_embedding_model(model: str) -> SentenceTransformer:
initialize_embedding_params()
global embeddings_device, embeddings_model global embeddings_device, embeddings_model
try: try:
embeddings_model = 'loading...' # flag embeddings_model = 'loading...' # flag
@ -29,6 +39,7 @@ def load_embedding_model(model: str) -> SentenceTransformer:
def get_embeddings_model() -> SentenceTransformer: def get_embeddings_model() -> SentenceTransformer:
initialize_embedding_params()
global embeddings_model, st_model global embeddings_model, st_model
if st_model and not embeddings_model: if st_model and not embeddings_model:
embeddings_model = load_embedding_model(st_model) # lazy load the model embeddings_model = load_embedding_model(st_model) # lazy load the model
@ -36,6 +47,7 @@ def get_embeddings_model() -> SentenceTransformer:
def get_embeddings_model_name() -> str: def get_embeddings_model_name() -> str:
initialize_embedding_params()
global st_model global st_model
return st_model return st_model

View File

@ -49,9 +49,9 @@ def generations(prompt: str, size: str, response_format: str, n: int):
'created': int(time.time()), 'created': int(time.time()),
'data': [] 'data': []
} }
from extensions.openai.script import params
# TODO: support SD_WEBUI_AUTH username:password pair. # TODO: support SD_WEBUI_AUTH username:password pair.
sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img" sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img"
response = requests.post(url=sd_url, json=payload) response = requests.post(url=sd_url, json=payload)
r = response.json() r = response.json()

View File

@ -25,10 +25,16 @@ import speech_recognition as sr
from pydub import AudioSegment from pydub import AudioSegment
params = { params = {
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001, # default params
'port': 5001,
'embedding_device': 'cpu',
'embedding_model': 'all-mpnet-base-v2',
# optional params
'sd_webui_url': '',
'debug': 0
} }
class Handler(BaseHTTPRequestHandler): class Handler(BaseHTTPRequestHandler):
def send_access_control_headers(self): def send_access_control_headers(self):
self.send_header("Access-Control-Allow-Origin", "*") self.send_header("Access-Control-Allow-Origin", "*")
@ -251,7 +257,7 @@ class Handler(BaseHTTPRequestHandler):
self.return_json(response) self.return_json(response)
elif '/images/generations' in self.path: elif '/images/generations' in self.path:
if 'SD_WEBUI_URL' not in os.environ: if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.") raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
prompt = body['prompt'] prompt = body['prompt']
@ -313,12 +319,13 @@ class Handler(BaseHTTPRequestHandler):
def run_server(): def run_server():
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port']) port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001)))
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port)
server = ThreadingHTTPServer(server_addr, Handler) server = ThreadingHTTPServer(server_addr, Handler)
if shared.args.share: if shared.args.share:
try: try:
from flask_cloudflared import _run_cloudflared from flask_cloudflared import _run_cloudflared
public_url = _run_cloudflared(params['port'], params['port'] + 1) public_url = _run_cloudflared(port, port + 1)
print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1') print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1')
except ImportError: except ImportError:
print('You should install flask_cloudflared manually') print('You should install flask_cloudflared manually')

View File

@ -3,7 +3,6 @@ import os
import numpy as np import numpy as np
def float_list_to_base64(float_array: np.ndarray) -> str: def float_list_to_base64(float_array: np.ndarray) -> str:
# Convert the list to a float32 array that the OpenAPI client expects # Convert the list to a float32 array that the OpenAPI client expects
# float_array = np.array(float_list, dtype="float32") # float_array = np.array(float_list, dtype="float32")
@ -26,5 +25,6 @@ def end_line(s):
def debug_msg(*args, **kwargs): def debug_msg(*args, **kwargs):
if 'OPENEDAI_DEBUG' in os.environ: from extensions.openai.script import params
if os.environ.get("OPENEDAI_DEBUG", params.get('debug', 0)):
print(*args, **kwargs) print(*args, **kwargs)