Add embeddings endpoint for gpt4all-api (#1314)

* Add embeddings endpoint

* Add test for embedding endpoint
This commit is contained in:
David Okpare 2023-08-10 15:43:07 +01:00 committed by GitHub
parent 108d950874
commit 889c8d1758
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 79 additions and 0 deletions

View File

@ -0,0 +1,65 @@
from typing import List, Union
from fastapi import APIRouter
from api_v1.settings import settings
from gpt4all import Embed4All
from pydantic import BaseModel, Field
### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
class EmbeddingRequest(BaseModel):
model: str = Field(
settings.model, description="The model to generate an embedding from."
)
input: Union[str, List[str], List[int], List[List[int]]] = Field(
..., description="Input text to embed, encoded as a string or array of tokens."
)
class EmbeddingUsage(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
class Embedding(BaseModel):
index: int = 0
object: str = "embedding"
embedding: List[float]
class EmbeddingResponse(BaseModel):
object: str = "list"
model: str
data: List[Embedding]
usage: EmbeddingUsage
router = APIRouter(prefix="/embeddings", tags=["Embedding Endpoints"])
embedder = Embed4All()
def get_embedding(data: EmbeddingRequest) -> EmbeddingResponse:
"""
Calculates the embedding for the given input using a specified model.
Args:
data (EmbeddingRequest): An EmbeddingRequest object containing the input data
and model name.
Returns:
EmbeddingResponse: An EmbeddingResponse object encapsulating the calculated embedding,
usage info, and the model name.
"""
embedding = embedder.embed(data.input)
return EmbeddingResponse(
data=[Embedding(embedding=embedding)], usage=EmbeddingUsage(), model=data.model
)
@router.post("/", response_model=EmbeddingResponse)
def embeddings(data: EmbeddingRequest):
"""
Creates a GPT4All embedding
"""
return get_embedding(data)

View File

@ -1,6 +1,8 @@
"""
Use the OpenAI python API to test gpt4all models.
"""
from typing import List, get_args
import openai
openai.api_base = "http://localhost:4891/v1"
@ -43,3 +45,15 @@ def test_batched_completion():
)
assert len(response['choices'][0]['text']) > len(prompt)
assert len(response['choices']) == 3
def test_embedding():
model = "ggml-all-MiniLM-L6-v2-f16.bin"
prompt = "Who is Michael Jordan?"
response = openai.Embedding.create(model=model, input=prompt)
output = response["data"][0]["embedding"]
args = get_args(List[float])
assert response["model"] == model
assert isinstance(output, list)
assert all(isinstance(x, args) for x in output)