mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Update to gpt4all version 1.0.1. Implement the Streaming version of the completions endpoint. Implemented an openai python client test for the new streaming functionality. (#1129)
Co-authored-by: Brandon <bbeiler@ridgelineintl.com>
This commit is contained in:
parent
affd0af51f
commit
fb576fbd7e
@ -1,6 +1,9 @@
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, Response, Security, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Iterable, AsyncIterable
|
||||
import logging
|
||||
from uuid import uuid4
|
||||
from api_v1.settings import settings
|
||||
@ -10,6 +13,7 @@ import time
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
@ -28,10 +32,13 @@ class CompletionChoice(BaseModel):
|
||||
logprobs: float
|
||||
finish_reason: str
|
||||
|
||||
|
||||
class CompletionUsage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
id: str
|
||||
object: str = 'text_completion'
|
||||
@ -41,46 +48,81 @@ class CompletionResponse(BaseModel):
|
||||
usage: CompletionUsage
|
||||
|
||||
|
||||
class CompletionStreamResponse(BaseModel):
|
||||
id: str
|
||||
object: str = 'text_completion'
|
||||
created: int
|
||||
model: str
|
||||
choices: List[CompletionChoice]
|
||||
|
||||
|
||||
router = APIRouter(prefix="/completions", tags=["Completion Endpoints"])
|
||||
|
||||
|
||||
def stream_completion(output: Iterable, base_response: CompletionStreamResponse):
|
||||
"""
|
||||
Streams a GPT4All output to the client.
|
||||
|
||||
Args:
|
||||
output: The output of GPT4All.generate(), which is an iterable of tokens.
|
||||
base_response: The base response object, which is cloned and modified for each token.
|
||||
|
||||
Returns:
|
||||
A Generator of CompletionStreamResponse objects, which are serialized to JSON Event Stream format.
|
||||
"""
|
||||
for token in output:
|
||||
chunk = base_response.copy()
|
||||
chunk.choices = [dict(CompletionChoice(
|
||||
text=token,
|
||||
index=0,
|
||||
logprobs=-1,
|
||||
finish_reason=''
|
||||
))]
|
||||
yield f"data: {json.dumps(dict(chunk))}\n\n"
|
||||
|
||||
|
||||
@router.post("/", response_model=CompletionResponse)
|
||||
async def completions(request: CompletionRequest):
|
||||
'''
|
||||
Completes a GPT4All model response.
|
||||
'''
|
||||
|
||||
# global model
|
||||
if request.stream:
|
||||
raise NotImplementedError("Streaming is not yet implements")
|
||||
|
||||
model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)
|
||||
|
||||
output = model.generate(prompt=request.prompt,
|
||||
n_predict = request.max_tokens,
|
||||
top_k = 20,
|
||||
top_p = request.top_p,
|
||||
temp=request.temperature,
|
||||
n_batch = 1024,
|
||||
repeat_penalty = 1.2,
|
||||
repeat_last_n = 10,
|
||||
context_erase = 0)
|
||||
|
||||
|
||||
return CompletionResponse(
|
||||
id=str(uuid4()),
|
||||
created=time.time(),
|
||||
model=request.model,
|
||||
choices=[dict(CompletionChoice(
|
||||
text=output,
|
||||
index=0,
|
||||
logprobs=-1,
|
||||
finish_reason='stop'
|
||||
))],
|
||||
usage={
|
||||
'prompt_tokens': 0, #TODO how to compute this?
|
||||
'completion_tokens': 0,
|
||||
'total_tokens': 0
|
||||
}
|
||||
)
|
||||
|
||||
n_predict=request.max_tokens,
|
||||
streaming=request.stream,
|
||||
top_k=20,
|
||||
top_p=request.top_p,
|
||||
temp=request.temperature,
|
||||
n_batch=1024,
|
||||
repeat_penalty=1.2,
|
||||
repeat_last_n=10)
|
||||
|
||||
# If streaming, we need to return a StreamingResponse
|
||||
if request.stream:
|
||||
base_chunk = CompletionStreamResponse(
|
||||
id=str(uuid4()),
|
||||
created=time.time(),
|
||||
model=request.model,
|
||||
choices=[]
|
||||
)
|
||||
return StreamingResponse((response for response in stream_completion(output, base_chunk)),
|
||||
media_type="text/event-stream")
|
||||
else:
|
||||
return CompletionResponse(
|
||||
id=str(uuid4()),
|
||||
created=time.time(),
|
||||
model=request.model,
|
||||
choices=[dict(CompletionChoice(
|
||||
text=output,
|
||||
index=0,
|
||||
logprobs=-1,
|
||||
finish_reason='stop'
|
||||
))],
|
||||
usage={
|
||||
'prompt_tokens': 0, #TODO how to compute this?
|
||||
'completion_tokens': 0,
|
||||
'total_tokens': 0
|
||||
}
|
||||
)
|
||||
|
@ -23,6 +23,25 @@ def test_completion():
|
||||
assert len(response['choices'][0]['text']) > len(prompt)
|
||||
print(response)
|
||||
|
||||
|
||||
def test_streaming_completion():
|
||||
model = "gpt4all-j-v1.3-groovy"
|
||||
prompt = "Who is Michael Jordan?"
|
||||
tokens = []
|
||||
for resp in openai.Completion.create(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
max_tokens=50,
|
||||
temperature=0.28,
|
||||
top_p=0.95,
|
||||
n=1,
|
||||
echo=True,
|
||||
stream=True):
|
||||
tokens.append(resp.choices[0].text)
|
||||
|
||||
assert (len(tokens) > 0)
|
||||
assert (len("".join(tokens)) > len(prompt))
|
||||
|
||||
# def test_chat_completions():
|
||||
# model = "gpt4all-j-v1.3-groovy"
|
||||
# prompt = "Who is Michael Jordan?"
|
||||
@ -30,6 +49,3 @@ def test_completion():
|
||||
# model=model,
|
||||
# messages=[]
|
||||
# )
|
||||
|
||||
|
||||
|
||||
|
@ -5,6 +5,6 @@ requests>=2.24.0
|
||||
ujson>=2.0.2
|
||||
fastapi>=0.95.0
|
||||
Jinja2>=3.0
|
||||
gpt4all==0.2.3
|
||||
gpt4all==1.0.1
|
||||
pytest
|
||||
openai
|
Loading…
Reference in New Issue
Block a user