mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Add embeddings endpoint for gpt4all-api (#1314)
* Add embeddings endpoint * Add test for embedding endpoint
This commit is contained in:
parent
108d950874
commit
889c8d1758
65
gpt4all-api/gpt4all_api/app/api_v1/routes/embeddings.py
Normal file
65
gpt4all-api/gpt4all_api/app/api_v1/routes/embeddings.py
Normal file
@ -0,0 +1,65 @@
|
||||
from typing import List, Union
|
||||
from fastapi import APIRouter
|
||||
from api_v1.settings import settings
|
||||
from gpt4all import Embed4All
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
    """Request body for the embeddings endpoint, mirroring the OpenAI schema."""

    # Falls back to the server-configured model when the client omits one.
    model: str = Field(
        default=settings.model,
        description="The model to generate an embedding from.",
    )
    # A single text, a batch of texts, or pre-tokenized input.
    input: Union[str, List[str], List[int], List[List[int]]] = Field(
        default=...,
        description="Input text to embed, encoded as a string or array of tokens.",
    )
|
||||
|
||||
|
||||
class EmbeddingUsage(BaseModel):
    """Token accounting for an embedding call, following the OpenAI schema."""

    # Both counters default to zero; the endpoint does not track token usage.
    prompt_tokens: int = Field(default=0)
    total_tokens: int = Field(default=0)
|
||||
|
||||
|
||||
class Embedding(BaseModel):
    """A single embedding vector entry, following the OpenAI response schema."""

    # Position of this entry within the response's `data` list.
    index: int = Field(default=0)
    # Object tag required by the OpenAI wire format.
    object: str = Field(default="embedding")
    embedding: List[float]
|
||||
|
||||
|
||||
class EmbeddingResponse(BaseModel):
    """Top-level response envelope for the embeddings endpoint."""

    # Object tag required by the OpenAI wire format.
    object: str = Field(default="list")
    model: str
    data: List[Embedding]
    usage: EmbeddingUsage
|
||||
|
||||
|
||||
# Router for the OpenAI-compatible embeddings API, mounted under /embeddings.
router = APIRouter(prefix="/embeddings", tags=["Embedding Endpoints"])

# Module-level embedder created once at import time, so every request shares
# the same loaded model instance.
embedder = Embed4All()
|
||||
|
||||
|
||||
def get_embedding(data: EmbeddingRequest) -> EmbeddingResponse:
    """
    Calculates the embedding(s) for the given input using a specified model.

    Args:
        data (EmbeddingRequest): An EmbeddingRequest object containing the
            input data (a single string or a list of strings) and model name.

    Returns:
        EmbeddingResponse: An EmbeddingResponse object encapsulating one
        `Embedding` per input text (with `index` matching its position in
        the request), usage info, and the model name.
    """
    # Normalize to a list so single-string and batched inputs share one path.
    # Previously a list input was passed straight to Embed4All.embed, which
    # expects a single string, so batch requests produced a broken embedding.
    texts = [data.input] if isinstance(data.input, str) else data.input
    # NOTE(review): the schema also admits token arrays (List[int]), but
    # Embed4All.embed only handles text — confirm before relying on tokens.
    results = [
        Embedding(index=i, embedding=embedder.embed(text))
        for i, text in enumerate(texts)
    ]
    # Embed4All reports no token counts, so usage stays at its zero defaults.
    return EmbeddingResponse(
        data=results, usage=EmbeddingUsage(), model=data.model
    )
|
||||
|
||||
|
||||
@router.post("/", response_model=EmbeddingResponse)
def embeddings(data: EmbeddingRequest):
    """
    Creates a GPT4All embedding
    """
    # Delegate to the shared helper so the logic is reusable outside the route.
    response = get_embedding(data)
    return response
|
@ -1,6 +1,8 @@
|
||||
"""
|
||||
Use the OpenAI python API to test gpt4all models.
|
||||
"""
|
||||
from typing import List, get_args
|
||||
|
||||
import openai
|
||||
|
||||
openai.api_base = "http://localhost:4891/v1"
|
||||
@ -43,3 +45,15 @@ def test_batched_completion():
|
||||
)
|
||||
assert len(response['choices'][0]['text']) > len(prompt)
|
||||
assert len(response['choices']) == 3
|
||||
|
||||
|
||||
def test_embedding():
    """The embeddings endpoint returns a non-empty list of floats for the model."""
    model = "ggml-all-MiniLM-L6-v2-f16.bin"
    prompt = "Who is Michael Jordan?"
    response = openai.Embedding.create(model=model, input=prompt)
    output = response["data"][0]["embedding"]
    # (float,) — the runtime types a List[float] element may have.
    args = get_args(List[float])

    assert response["model"] == model
    assert isinstance(output, list)
    # Guard against a vacuous pass: all() over an empty vector is True,
    # so an empty embedding would previously slip through undetected.
    assert len(output) > 0
    assert all(isinstance(x, args) for x in output)
|
||||
|
Loading…
Reference in New Issue
Block a user