[extensions/openai] Support undocumented base64 'encoding_format' param for compatibility with official OpenAI client (#1876)

This commit is contained in:
Jeffrey Lin 2023-05-08 18:31:34 -07:00 committed by GitHub
parent d78b04f0b4
commit 791a38bad1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,6 @@
import base64
import json import json
import numpy as np
import os import os
import time import time
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
@ -45,6 +47,20 @@ def clamp(value, minvalue, maxvalue):
return max(minvalue, min(value, maxvalue)) return max(minvalue, min(value, maxvalue))
def float_list_to_base64(float_list):
# Convert the list to a float32 array that the OpenAPI client expects
float_array = np.array(float_list, dtype="float32")
# Get raw bytes
bytes_array = float_array.tobytes()
# Encode bytes into base64
encoded_bytes = base64.b64encode(bytes_array)
# Turn raw base64 encoded bytes into ASCII
ascii_string = encoded_bytes.decode('ascii')
return ascii_string
class Handler(BaseHTTPRequestHandler): class Handler(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self):
if self.path.startswith('/v1/models'): if self.path.startswith('/v1/models'):
@ -435,7 +451,13 @@ class Handler(BaseHTTPRequestHandler):
embeddings = embedding_model.encode(input).tolist() embeddings = embedding_model.encode(input).tolist()
data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(embeddings)] def enc_emb(emb):
# If base64 is specified, encode. Otherwise, do nothing.
if body.get("encoding_format", "") == "base64":
return float_list_to_base64(emb)
else:
return emb
data = [{"object": "embedding", "embedding": enc_emb(emb), "index": n} for n, emb in enumerate(embeddings)]
response = json.dumps({ response = json.dumps({
"object": "list", "object": "list",