mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Handle edge cases when generating embeddings (#1215)
* Handle edge cases when generating embeddings * Improve Python handling & add llmodel_c.h note - In the Python bindings fail fast with a ValueError when text is empty - Advise other bindings' authors to do likewise in llmodel_c.h
This commit is contained in:
parent
1e74171a7b
commit
2d02c65177
@ -168,10 +168,14 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
|
||||
|
||||
float *llmodel_embedding(llmodel_model model, const char *text, size_t *embedding_size)
|
||||
{
|
||||
if (model == nullptr || text == nullptr || !strlen(text)) {
|
||||
*embedding_size = 0;
|
||||
return nullptr;
|
||||
}
|
||||
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
|
||||
std::vector<float> embeddingVector = wrapper->llModel->embedding(text);
|
||||
float *embedding = (float *)malloc(embeddingVector.size() * sizeof(float));
|
||||
if(embedding == nullptr) {
|
||||
if (embedding == nullptr) {
|
||||
*embedding_size = 0;
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -173,6 +173,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
|
||||
|
||||
/**
|
||||
* Generate an embedding using the model.
|
||||
* NOTE: If given NULL pointers for the model or text, or an empty text, a NULL pointer will be
|
||||
* returned. Bindings should signal an error when NULL is the return value.
|
||||
* @param model A pointer to the llmodel_model instance.
|
||||
* @param text A string representing the text to generate an embedding for.
|
||||
* @param embedding_size A pointer to a size_t type that will be set by the call indicating the length
|
||||
|
@ -251,6 +251,8 @@ class LLModel:
|
||||
self,
|
||||
text: str
|
||||
) -> list[float]:
|
||||
if not text:
|
||||
raise ValueError("Text must not be None or empty")
|
||||
embedding_size = ctypes.c_size_t()
|
||||
c_text = ctypes.c_char_p(text.encode('utf-8'))
|
||||
embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
|
||||
|
@ -3,6 +3,7 @@ from io import StringIO
|
||||
|
||||
from gpt4all import GPT4All, Embed4All
|
||||
import time
|
||||
import pytest
|
||||
|
||||
def test_inference():
|
||||
model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
|
||||
@ -107,3 +108,9 @@ def test_embedding():
|
||||
#for i, value in enumerate(output):
|
||||
#print(f'Value at index {i}: {value}')
|
||||
assert len(output) == 384
|
||||
|
||||
def test_empty_embedding():
    # Embedding an empty string is an error: the binding must fail fast
    # with ValueError rather than pass empty text to the C API.
    embedder = Embed4All()
    empty_text = ''
    with pytest.raises(ValueError):
        output = embedder.embed(empty_text)
|
||||
|
Loading…
Reference in New Issue
Block a user