Handle edge cases when generating embeddings (#1215)

* Handle edge cases when generating embeddings
* Improve Python handling & add llmodel_c.h note
- In the Python bindings fail fast with a ValueError when text is empty
- Advise other bindings authors to do likewise in llmodel_c.h
This commit is contained in:
cosmic-snow 2023-07-17 22:21:03 +02:00 committed by GitHub
parent 1e74171a7b
commit 2d02c65177
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 1 deletions

View File

@ -168,10 +168,14 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
float *llmodel_embedding(llmodel_model model, const char *text, size_t *embedding_size)
{
if (model == nullptr || text == nullptr || !strlen(text)) {
*embedding_size = 0;
return nullptr;
}
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
std::vector<float> embeddingVector = wrapper->llModel->embedding(text);
float *embedding = (float *)malloc(embeddingVector.size() * sizeof(float));
if(embedding == nullptr) {
if (embedding == nullptr) {
*embedding_size = 0;
return nullptr;
}

View File

@ -173,6 +173,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
/**
* Generate an embedding using the model.
* NOTE: If given NULL pointers for the model or text, or an empty text, a NULL pointer will be
* returned. Bindings should signal an error when NULL is the return value.
* @param model A pointer to the llmodel_model instance.
* @param text A string representing the text to generate an embedding for.
* @param embedding_size A pointer to a size_t type that will be set by the call indicating the length

View File

@ -251,6 +251,8 @@ class LLModel:
self,
text: str
) -> list[float]:
if not text:
raise ValueError("Text must not be None or empty")
embedding_size = ctypes.c_size_t()
c_text = ctypes.c_char_p(text.encode('utf-8'))
embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))

View File

@ -3,6 +3,7 @@ from io import StringIO
from gpt4all import GPT4All, Embed4All
import time
import pytest
def test_inference():
model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
@ -107,3 +108,9 @@ def test_embedding():
#for i, value in enumerate(output):
#print(f'Value at index {i}: {value}')
assert len(output) == 384
def test_empty_embedding():
    # Embedding an empty string is rejected by the bindings with a ValueError.
    embedder = Embed4All()
    empty_text = ''
    with pytest.raises(ValueError):
        embedder.embed(empty_text)