2024-03-07 13:03:18 -05:00
|
|
|
import random
import uuid

import chromadb
import posthog
from chromadb.config import Settings
from chromadb.utils import embedding_functions
|
2023-05-13 13:14:59 -04:00
|
|
|
|
2024-03-07 13:03:18 -05:00
|
|
|
# Intercept calls to posthog: swap its capture hook for a no-op so any event
# routed through posthog.capture is silently discarded.
def _discard_posthog_event(*args, **kwargs):
    return None


posthog.capture = _discard_posthog_event

# Sentence-transformers embedding function shared by every collection this
# module creates.
embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
    "sentence-transformers/all-mpnet-base-v2"
)
|
2023-05-13 13:14:59 -04:00
|
|
|
|
|
|
|
|
2024-03-07 13:03:18 -05:00
|
|
|
class ChromaCollector:
    """Wrapper around a Chroma collection that stores text chunks and retrieves them by similarity.

    Each chunk is stored under the id ``"id<N>"``, where ``N`` is its insertion
    index. The numeric part therefore doubles as insertion order, which is used
    for sorting results and for time weighting.
    """

    def __init__(self):
        # Random, collision-resistant collection name (32 hex chars).
        # NOTE(review): Chroma clients built with identical settings may share
        # one backend depending on the chromadb version. In that case a short
        # name from a tiny alphabet could clash with another collector's
        # collection and make create_collection fail. uuid4 also does not
        # depend on the global `random` state.
        name = uuid.uuid4().hex
        self.name = name
        self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
        self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
        # Ids of every chunk currently stored, in insertion order.
        self.ids = []

    def add(self, texts: list[str]):
        """Append *texts* to the collection.

        New chunks get ids that continue after the ones already stored. Calling
        ``add`` repeatedly (without ``clear``) therefore keeps earlier chunks
        instead of reusing their ids.

        Args:
            texts: Chunks to embed and store. An empty list is a no-op.
        """
        if not texts:
            return

        start = len(self.ids)
        new_ids = [f"id{i}" for i in range(start, start + len(texts))]
        self.collection.add(documents=texts, ids=new_ids)
        # Record the ids only once Chroma has accepted the batch.
        self.ids.extend(new_ids)

    def get_documents_ids_distances(self, search_strings: list[str], n_results: int):
        """Query the collection and return matching documents, numeric ids and distances.

        Only the results for the first entry of *search_strings* are returned.

        Args:
            search_strings: Query texts passed to Chroma.
            n_results: Maximum number of matches. It is clamped to the number of
                stored chunks; a value <= 0 yields no matches.

        Returns:
            ``(documents, ids, distances)``: three parallel lists in the order
            Chroma returns them. All three are empty when nothing can be
            returned. ``ids`` are the integer insertion indices parsed from the
            ``"id<N>"`` strings.
        """
        n_results = min(len(self.ids), n_results)
        if n_results <= 0:
            return [], [], []

        result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents', 'distances'])
        # Chroma returns one result list per query text; only the first query is used.
        documents = result['documents'][0]
        ids = [int(chunk_id[2:]) for chunk_id in result['ids'][0]]  # strip the "id" prefix
        distances = result['distances'][0]
        return documents, ids, distances

    # Get chunks by similarity
    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        """Return up to *n_results* stored chunks most similar to the query."""
        documents, _, _ = self.get_documents_ids_distances(search_strings, n_results)
        return documents

    # Get ids by similarity
    def get_ids(self, search_strings: list[str], n_results: int) -> list[int]:
        """Return the insertion indices of up to *n_results* chunks most similar to the query."""
        _, ids, _ = self.get_documents_ids_distances(search_strings, n_results)
        return ids

    # Get chunks by similarity and then sort by insertion order
    def get_sorted(self, search_strings: list[str], n_results: int) -> list[str]:
        """Return the most similar chunks, ordered by insertion order rather than similarity."""
        documents, ids, _ = self.get_documents_ids_distances(search_strings, n_results)
        # Ids are unique, so the tuples never fall back to comparing documents.
        return [document for _, document in sorted(zip(ids, documents))]

    def apply_time_weight_to_distances(self, ids: list[int], distances: list[float], time_weight: float = 1.0) -> list[float]:
        """Scale distances so that more recently inserted chunks look closer.

        Each distance is multiplied by ``1 - id / (len(self.ids) - 1) * time_weight``.
        That factor lies in ``[1 - time_weight, 1]``: the oldest chunk (id 0)
        keeps its distance, and the newest stored chunk is scaled by
        ``1 - time_weight``.

        Returns a new list; *distances* is not modified. With at most one stored
        chunk there is nothing to weight, so a copy of *distances* is returned.
        """
        if len(self.ids) <= 1:
            return distances.copy()

        newest = len(self.ids) - 1
        return [distance * (1 - _id / newest * time_weight) for _id, distance in zip(ids, distances)]

    # Get ids by similarity and then sort by insertion order
    def get_ids_sorted(self, search_strings: list[str], n_results: int, n_initial: int = None, time_weight: float = 1.0) -> list[int]:
        """Return ids of the best matches, sorted by insertion order.

        With time weighting (*time_weight* > 0 and *n_initial* given), the
        *n_initial* nearest chunks are fetched first. They are re-ranked with
        ``apply_time_weight_to_distances`` and the best *n_results* are kept.
        Otherwise the *n_results* nearest chunks are used directly.

        Args:
            search_strings: Query texts passed to Chroma.
            n_results: Number of ids to return (fewer if the collection is smaller).
            n_initial: Size of the candidate pool for time weighting. ``-1``
                means every stored chunk. Ignored when time weighting is off.
            time_weight: Strength of the recency bias; ``<= 0`` disables it.

        Raises:
            ValueError: If an explicit *n_initial* is smaller than *n_results*.
        """
        do_time_weight = time_weight > 0
        if not (do_time_weight and n_initial is not None):
            n_initial = n_results
        elif n_initial == -1:
            # "All chunks": never let the pool be smaller than n_results. Otherwise
            # a collection with fewer than n_results chunks would trip the check
            # below. The query itself is clamped to the stored chunk count.
            n_initial = max(len(self.ids), n_results)

        if n_initial < n_results:
            raise ValueError(f"n_initial {n_initial} should be >= n_results {n_results}")

        _, ids, distances = self.get_documents_ids_distances(search_strings, n_initial)
        if do_time_weight:
            weighted = self.apply_time_weight_to_distances(ids, distances, time_weight=time_weight)
            # Keep the n_results candidates with the smallest weighted distance (stable on ties).
            ranked = sorted(zip(ids, weighted), key=lambda pair: pair[1])
            ids = [_id for _id, _ in ranked[:n_results]]

        return sorted(ids)

    def clear(self):
        """Remove all chunks by dropping and recreating the underlying collection."""
        self.ids = []
        self.chroma_client.delete_collection(name=self.name)
        self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)
|
2023-05-13 13:14:59 -04:00
|
|
|
|
|
|
|
|
|
|
|
def make_collector():
    """Create and return a fresh ChromaCollector with an empty collection."""
    collector = ChromaCollector()
    return collector
|
2023-05-13 13:14:59 -04:00
|
|
|
|
|
|
|
|
|
|
|
def add_chunks_to_collector(chunks, collector):
    """Reset *collector* and store *chunks* as its new contents.

    ``collector.clear()`` is called first and then ``collector.add(chunks)``,
    so any previously stored texts are discarded.
    """
    reset, insert = collector.clear, collector.add
    reset()
    insert(chunks)
|