Lint the openai extension

oobabooga 2023-09-15 20:11:16 -07:00
parent 760510db52
commit 8f97e87cac
12 changed files with 79 additions and 69 deletions

View File

@@ -3,6 +3,9 @@
# Dockerfile:
# ENV OPENEDAI_EMBEDDING_MODEL=all-mpnet-base-v2 # Optional
# RUN python3 cache_embedded_model.py
import os, sentence_transformers
import os
import sentence_transformers
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
model = sentence_transformers.SentenceTransformer(st_model)
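The model name above is resolved from the OPENEDAI_EMBEDDING_MODEL environment variable with a fallback. A minimal sketch of the same lookup written with os.environ.get (equivalent behavior, not the extension's exact code):

import os

# Fall back to "all-mpnet-base-v2" when the variable is unset.
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
print(st_model)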

View File

@@ -1,18 +1,15 @@
import time
import yaml
import tiktoken
import torch
import torch.nn.functional as F
from math import log, exp
from transformers import LogitsProcessor, LogitsProcessorList
import yaml
from extensions.openai.defaults import clamp, default, get_default_req_params
from extensions.openai.errors import InvalidRequestError
from extensions.openai.utils import debug_msg, end_line
from modules import shared
from modules.text_generation import encode, decode, generate_reply
from extensions.openai.defaults import get_default_req_params, default, clamp
from extensions.openai.utils import end_line, debug_msg
from extensions.openai.errors import *
from modules.text_generation import decode, encode, generate_reply
from transformers import LogitsProcessor, LogitsProcessorList
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
@@ -36,6 +33,7 @@ class LogitsBiasProcessor(LogitsProcessor):
def __repr__(self):
return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
class LogprobProcessor(LogitsProcessor):
def __init__(self, logprobs=None):
self.logprobs = logprobs
@@ -64,6 +62,7 @@ def convert_logprobs_to_tiktoken(model, logprobs):
# except KeyError:
# # assume native tokens if we can't find the tokenizer
# return logprobs
return logprobs
@@ -146,7 +145,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
if not 'messages' in body:
if 'messages' not in body:
raise InvalidRequestError(message="messages is required", param='messages')
messages = body['messages']
@@ -159,7 +158,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
'prompt': 'Assistant:',
}
if not 'stopping_strings' in req_params:
if 'stopping_strings' not in req_params:
req_params['stopping_strings'] = []
# Instruct models can be much better
@@ -439,7 +438,7 @@ def completions(body: dict, is_legacy: bool = False):
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
prompt_str = 'context' if is_legacy else 'prompt'
if not prompt_str in body:
if prompt_str not in body:
raise InvalidRequestError("Missing required input", param=prompt_str)
prompt_arg = body[prompt_str]
@@ -538,10 +537,12 @@ def stream_completions(body: dict, is_legacy: bool = False):
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
prompt_str = 'context' if is_legacy else 'prompt'
if not prompt_str in body:
if prompt_str not in body:
raise InvalidRequestError("Missing required input", param=prompt_str)
prompt = body[prompt_str]
req_params = marshal_common_params(body)
requested_model = req_params.pop('requested_model')
if isinstance(prompt, list):
if prompt and isinstance(prompt[0], int):
try:
@@ -553,12 +554,10 @@ def stream_completions(body: dict, is_legacy: bool = False):
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
# common params
req_params = marshal_common_params(body)
req_params['stream'] = True
max_tokens_str = 'length' if is_legacy else 'max_tokens'
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
req_params['max_new_tokens'] = max_tokens
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
stopping_strings = req_params.pop('stopping_strings', [])
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
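Several of the changes in this file rewrite membership tests from "not 'key' in body" to "'key' not in body". The two forms evaluate identically at runtime; the second is the idiomatic spelling that pycodestyle enforces (E713, as far as I recall). A tiny self-contained illustration, not taken from the commit:

body = {'prompt': 'Hello'}

# Both expressions give the same result; 'not in' reads as a single operator,
# while the first spelling is what the linter flags.
assert (not 'messages' in body) == ('messages' not in body)
print('messages' not in body)  # True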

View File

@@ -51,9 +51,11 @@ def get_default_req_params():
return copy.deepcopy(default_req_params)
# little helper to get defaults if arg is present but None and should be the same type as default.
def default(dic, key, default):
val = dic.get(key, default)
if type(val) != type(default):
if not isinstance(val, type(default)):
# maybe it's just something like 1 instead of 1.0
try:
v = type(default)(val)
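The hunk above swaps a direct type comparison for isinstance, which is what pycodestyle's E721 check asks for. The function body is truncated in the diff; a minimal sketch of how such a coercing helper can behave, assuming it falls back to the supplied default when conversion fails (the helper name and that fallback are illustrative, not taken from the diff):

def default_value(dic, key, fallback):
    # Return dic[key] coerced to the fallback's type, or the fallback itself.
    val = dic.get(key, fallback)
    if not isinstance(val, type(fallback)):
        try:
            val = type(fallback)(val)  # e.g. 1 -> 1.0 when the default is a float
        except (TypeError, ValueError):
            val = fallback
    return val

print(default_value({'temperature': 1}, 'temperature', 1.0))  # 1.0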

View File

@@ -1,10 +1,10 @@
import time
import yaml
import os
from modules import shared
from extensions.openai.defaults import get_default_req_params
from extensions.openai.errors import InvalidRequestError
from extensions.openai.utils import debug_msg
from extensions.openai.errors import *
from modules import shared
from modules.text_generation import encode, generate_reply
@@ -74,7 +74,6 @@ def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
answer = ''
for a in generator:
answer = a

View File

@@ -1,8 +1,9 @@
import os
from sentence_transformers import SentenceTransformer
import numpy as np
from extensions.openai.utils import float_list_to_base64, debug_msg
from extensions.openai.errors import *
from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.utils import debug_msg, float_list_to_base64
from sentence_transformers import SentenceTransformer
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
embeddings_model = None
@@ -11,6 +12,7 @@ embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", "cpu")
if embeddings_device.lower() == 'auto':
embeddings_device = None
def load_embedding_model(model: str) -> SentenceTransformer:
global embeddings_device, embeddings_model
try:
@@ -41,6 +43,7 @@ def get_embeddings_model_name() -> str:
def get_embeddings(input: list) -> np.ndarray:
return get_embeddings_model().encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False, device=embeddings_device)
def embeddings(input: list, encoding_format: str) -> dict:
embeddings = get_embeddings(input)
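get_embeddings above asks SentenceTransformer.encode for normalized numpy vectors. A hedged usage sketch (the model name matches the default above; the first run downloads the model):

import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")
vecs = model.encode(["hello world"], convert_to_numpy=True,
                    normalize_embeddings=True, convert_to_tensor=False)
# all-mpnet-base-v2 produces 768-dimensional unit-length vectors.
print(vecs.shape, round(float(np.linalg.norm(vecs[0])), 3))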

View File

@@ -1,7 +1,8 @@
import os
import time
import requests
from extensions.openai.errors import *
from extensions.openai.errors import ServiceUnavailableError
def generations(prompt: str, size: str, response_format: str, n: int):
@@ -14,7 +15,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
# At this point I will not add the edits and variations endpoints (ie. img2img) because they
# require changing the form data handling to accept multipart form data, also to properly support
# url return types will require file management and a web serving files... Perhaps later!
base_model_size = 512 if not 'SD_BASE_MODEL_SIZE' in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
base_model_size = 512 if 'SD_BASE_MODEL_SIZE' not in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
sd_defaults = {
'sampler_name': 'DPM++ 2M Karras', # vast improvement
'steps': 30,
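The rewritten base_model_size line keeps the original ternary; since os.environ.get already takes a default, the same behavior can be expressed more directly (a sketch, not part of the commit):

import os

# 512 when SD_BASE_MODEL_SIZE is unset; otherwise the variable parsed as an int.
base_model_size = int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
print(base_model_size)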

View File

@@ -1,7 +1,8 @@
from extensions.openai.embeddings import get_embeddings_model_name
from extensions.openai.errors import *
from extensions.openai.errors import OpenAIError
from modules import shared
from modules.models import load_model, unload_model
from modules.models import load_model as _load_model
from modules.models import unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.utils import get_available_models
@@ -38,7 +39,7 @@ def load_model(model_name: str) -> dict:
if shared.settings['mode'] != 'instruct':
shared.settings['instruction_template'] = None
shared.model, shared.tokenizer = load_model(shared.model_name)
shared.model, shared.tokenizer = _load_model(shared.model_name)
if not shared.model: # load failed.
shared.model_name = "None"
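The rename to _load_model matters because this file defines its own load_model wrapper; importing the module-level loader under the same name would make the call inside the wrapper resolve to the wrapper itself. A small generic sketch of that shadowing pattern (the names here are illustrative, not from the repo):

from math import sqrt as _sqrt  # aliased so the local wrapper below does not shadow it

def sqrt(x):
    # Calling plain sqrt(x) here would recurse into this wrapper forever.
    return _sqrt(x)

print(sqrt(9.0))  # 3.0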

View File

@@ -1,8 +1,8 @@
import time
import numpy as np
from numpy.linalg import norm
from extensions.openai.embeddings import get_embeddings
import numpy as np
from extensions.openai.embeddings import get_embeddings
from numpy.linalg import norm
moderations_disabled = False # return 0/false
category_embeddings = None

View File

@@ -4,19 +4,21 @@ import traceback
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
from modules import shared
from extensions.openai.tokens import token_count, token_encode, token_decode
import extensions.openai.models as OAImodels
import extensions.openai.completions as OAIcompletions
import extensions.openai.edits as OAIedits
import extensions.openai.embeddings as OAIembeddings
import extensions.openai.images as OAIimages
import extensions.openai.models as OAImodels
import extensions.openai.moderations as OAImoderations
import extensions.openai.completions as OAIcompletions
from extensions.openai.errors import *
from extensions.openai.defaults import clamp, default, get_default_req_params
from extensions.openai.errors import (
InvalidRequestError,
OpenAIError,
ServiceUnavailableError
)
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import debug_msg
from extensions.openai.defaults import (get_default_req_params, default, clamp)
from modules import shared
params = {
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
@@ -209,7 +211,7 @@ class Handler(BaseHTTPRequestHandler):
self.return_json(response)
elif '/images/generations' in self.path:
if not 'SD_WEBUI_URL' in os.environ:
if 'SD_WEBUI_URL' not in os.environ:
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
prompt = body['prompt']
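Across these files the commit also replaces "from extensions.openai.errors import *" with explicit names. Star imports are flagged by pyflakes (F403/F405, if memory serves) because they hide where a name comes from. A generic runnable sketch of the explicit style, using the standard library rather than the extension's modules:

# Instead of: from json import *
from json import dumps, loads  # every name used below is visibly imported

print(loads(dumps({"ok": True})))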

View File

@@ -1,6 +1,5 @@
from extensions.openai.utils import float_list_to_base64
from modules.text_generation import encode, decode
import numpy as np
from modules.text_generation import decode, encode
def token_count(prompt):
tokens = encode(prompt)[0]

View File

@@ -1,5 +1,6 @@
import os
import base64
import os
import numpy as np