Hybrid search (#2969)

Signed-off-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
AT 2024-09-26 11:58:48 -04:00 committed by GitHub
parent 117a8e7faa
commit 10d2375bf3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 303 additions and 23 deletions

View File

@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
## [Unreleased]
### Added
- Add bm25 hybrid search to localdocs ([#2969](https://github.com/nomic-ai/gpt4all/pull/2969))
## [3.3.0] - 2024-09-20
### Added
@ -119,6 +124,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
- Fix several Vulkan resource management issues ([#2694](https://github.com/nomic-ai/gpt4all/pull/2694))
- Fix crash/hang when some models stop generating, by showing special tokens ([#2701](https://github.com/nomic-ai/gpt4all/pull/2701))
[Unreleased]: https://github.com/nomic-ai/gpt4all/compare/v3.3.0...HEAD
[3.3.0]: https://github.com/nomic-ai/gpt4all/compare/v3.2.1...v3.3.0
[3.2.1]: https://github.com/nomic-ai/gpt4all/compare/v3.2.0...v3.2.1
[3.2.0]: https://github.com/nomic-ai/gpt4all/compare/v3.1.1...v3.2.0

View File

@ -103,6 +103,20 @@ static const QString INIT_DB_SQL[] = {
tokens integer default 0 not null,
foreign key(document_id) references documents(id)
);
)"_s, uR"(
create virtual table chunks_fts using fts5(
id unindexed,
document_id unindexed,
chunk_text,
file,
title,
author,
subject,
keywords,
content='chunks',
content_rowid='id',
tokenize='porter'
);
)"_s, uR"(
create table collections(
id integer primary key,
@ -151,7 +165,13 @@ static const QString INSERT_CHUNK_SQL = uR"(
file, title, author, subject, keywords, page, line_from, line_to, words)
values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
returning id;
)"_s;
)"_s;
static const QString INSERT_CHUNK_FTS_SQL = uR"(
insert into chunks_fts(document_id, chunk_text,
file, title, author, subject, keywords)
values(?, ?, ?, ?, ?, ?, ?);
)"_s;
static const QString DELETE_CHUNKS_SQL[] = {
uR"(
@ -161,12 +181,14 @@ static const QString DELETE_CHUNKS_SQL[] = {
);
)"_s, uR"(
delete from chunks where document_id = ?;
)"_s, uR"(
delete from chunks_fts where document_id = ?;
)"_s,
};
static const QString SELECT_CHUNKS_BY_DOCUMENT_SQL = uR"(
select id from chunks WHERE document_id = ?;
)"_s;
)"_s;
static const QString SELECT_CHUNKS_SQL = uR"(
select c.id, d.document_time, d.document_path, c.chunk_text, c.file, c.title, c.author, c.page, c.line_from, c.line_to, co.name
@ -190,14 +212,21 @@ static const QString SELECT_UNCOMPLETED_CHUNKS_SQL = uR"(
from embeddings e
where e.chunk_id = c.id and e.model = co.embedding_model
);
)"_s;
)"_s;
static const QString SELECT_COUNT_CHUNKS_SQL = uR"(
select count(c.id)
from chunks c
join documents d on d.id = c.document_id
where d.folder_id = ?;
)"_s;
)"_s;
static const QString SELECT_CHUNKS_FTS_SQL = uR"(
select id, bm25(chunks_fts) as score
from chunks_fts
where chunks_fts match ?
order by score limit %1;
)"_s;
static bool addChunk(QSqlQuery &q, int document_id, const QString &chunk_text, const QString &file,
const QString &title, const QString &author, const QString &subject, const QString &keywords,
@ -219,6 +248,18 @@ static bool addChunk(QSqlQuery &q, int document_id, const QString &chunk_text, c
if (!q.exec() || !q.next())
return false;
*chunk_id = q.value(0).toInt();
if (!q.prepare(INSERT_CHUNK_FTS_SQL))
return false;
q.addBindValue(document_id);
q.addBindValue(chunk_text);
q.addBindValue(file);
q.addBindValue(title);
q.addBindValue(author);
q.addBindValue(subject);
q.addBindValue(keywords);
if (!q.exec())
return false;
return true;
}
@ -424,6 +465,7 @@ static bool selectAllFromCollections(QSqlQuery &q, QList<CollectionItem> *collec
return false;
break;
case 2:
case 3:
if (!q.prepare(SELECT_COLLECTIONS_SQL_V2))
return false;
break;
@ -770,6 +812,12 @@ static const QString GET_COLLECTION_EMBEDDINGS_SQL = uR"(
where co.name in ('%1');
)"_s;
static const QString GET_CHUNK_EMBEDDINGS_SQL = uR"(
select e.chunk_id, e.embedding
from embeddings e
where e.chunk_id in (%1);
)"_s;
static const QString GET_CHUNK_FILE_SQL = uR"(
select file from chunks where id = ?;
)"_s;
@ -1858,19 +1906,13 @@ void Database::removeFolderFromWatch(const QString &path)
m_watchedPaths -= QSet(children.begin(), children.end());
}
QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QList<QString> &collections, int nNeighbors)
QList<int> Database::searchEmbeddingsHelper(const std::vector<float> &query, QSqlQuery &q, int nNeighbors)
{
constexpr int BATCH_SIZE = 2048;
const int n_embd = query.size();
const us::metric_punned_t metric(n_embd, us::metric_kind_t::ip_k); // inner product
QSqlQuery q(m_db);
if (!q.exec(GET_COLLECTION_EMBEDDINGS_SQL.arg(collections.join("', '")))) {
qWarning() << "Database ERROR: Failed to exec embeddings query:" << q.lastError();
return {};
}
us::executor_default_t executor(std::thread::hardware_concurrency());
us::exact_search_t search;
@ -1882,6 +1924,7 @@ QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QLi
struct Result { int chunkId; us::distance_punned_t dist; };
QList<Result> results;
// The q parameter is expected to be the result of a QSqlQuery returning (chunk_id, embedding) pairs
while (q.at() != QSql::AfterLastRow) { // batches
batchChunkIds.clear();
batchEmbeddings.clear();
@ -1937,6 +1980,223 @@ QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QLi
return chunkIds;
}
QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QList<QString> &collections,
int nNeighbors)
{
QSqlQuery q(m_db);
if (!q.exec(GET_COLLECTION_EMBEDDINGS_SQL.arg(collections.join("', '")))) {
qWarning() << "Database ERROR: Failed to exec embeddings query:" << q.lastError();
return {};
}
return searchEmbeddingsHelper(query, q, nNeighbors);
}
QList<int> Database::scoreChunks(const std::vector<float> &query, const QList<int> &chunks)
{
QList<QString> chunkStrings;
for (int id : chunks)
chunkStrings << QString::number(id);
QSqlQuery q(m_db);
if (!q.exec(GET_CHUNK_EMBEDDINGS_SQL.arg(chunkStrings.join(", ")))) {
qWarning() << "Database ERROR: Failed to exec embeddings query:" << q.lastError();
return {};
}
return searchEmbeddingsHelper(query, q, chunks.size());
}
QList<Database::BM25Query> Database::queriesForFTS5(const QString &input)
{
// Escape double quotes by adding a second double quote
QString escapedInput = input;
escapedInput.replace("\"", "\"\"");
static QRegularExpression spaces("\\s+");
QStringList oWords = escapedInput.split(spaces, Qt::SkipEmptyParts);
QList<BM25Query> queries;
// Start by trying to match the entire input
BM25Query e;
e.isExact = true;
e.input = oWords.join(" ");
e.query = "\"" + oWords.join(" ") + "\"";
e.qlength = oWords.size();
e.ilength = oWords.size();
queries << e;
// https://github.com/igorbrigadir/stopwords?tab=readme-ov-file
// Lucene, Solr, Elastisearch
static const QSet<QString> stopWords = {
"a", "an", "and", "are", "as", "at", "be", "but", "by",
"for", "if", "in", "into", "is", "it", "no", "not", "of",
"on", "or", "such", "that", "the", "their", "then", "there",
"these", "they", "this", "to", "was", "will", "with"
};
QStringList quotedWords;
for (const QString &w : oWords)
if (!stopWords.contains(w.toLower()))
quotedWords << "\"" + w + "\"";
BM25Query b;
b.input = oWords.join(" ");
b.query = "(" + quotedWords.join(" OR ") + ")";
b.qlength = 1; // length of phrase
b.ilength = oWords.size();
b.rlength = oWords.size() - quotedWords.size();
queries << b;
return queries;
}
QList<int> Database::searchBM25(const QString &query, const QList<QString> &collections, BM25Query &bm25q, int k)
{
struct SearchResult { int chunkId; float score; };
QList<BM25Query> bm25Queries = queriesForFTS5(query);
QSqlQuery sqlQuery(m_db);
sqlQuery.prepare(SELECT_CHUNKS_FTS_SQL.arg(k));
QList<SearchResult> results;
for (auto &bm25Query : std::as_const(bm25Queries)) {
sqlQuery.addBindValue(bm25Query.query);
if (!sqlQuery.exec()) {
qWarning() << "Database ERROR: Failed to execute BM25 query:" << sqlQuery.lastError();
return {};
}
if (sqlQuery.next()) {
// Save the query that was used to produce results
bm25q = bm25Query;
break;
}
}
do {
const int chunkId = sqlQuery.value(0).toInt();
const float score = sqlQuery.value(1).toFloat();
results.append({chunkId, score});
} while (sqlQuery.next());
k = qMin(k, results.size());
std::partial_sort(
results.begin(), results.begin() + k, results.end(),
[](const SearchResult &a, const SearchResult &b) { return a.score < b.score; }
);
QList<int> chunkIds;
chunkIds.reserve(k);
for (int i = 0; i < k; i++)
chunkIds << results[i].chunkId;
return chunkIds;
}
float Database::computeBM25Weight(const Database::BM25Query &bm25q)
{
float bmWeight = 0.0f;
if (bm25q.isExact) {
bmWeight = 0.9f; // the highest we give
} else {
// qlength is the length of the phrases in the query by number of distinct words
// ilength is the length of the natural language query by number of distinct words
// rlength is the number of stop words removed from the natural language query to form the query
// calculate the query length weight based on the ratio of query terms to meaningful terms.
// this formula adjusts the weight with the empirically determined insight that BM25's
// effectiveness decreases as query length increases.
float queryLengthWeight = 1 / powf(float(bm25q.ilength - bm25q.rlength), 2);
queryLengthWeight = qBound(0.0f, queryLengthWeight, 1.0f);
// the weighting is bound between 1/4 and 3/4 which was determined empirically to work well
// with the beir nfcorpus, scifact, fiqa and trec-covid datasets along with our embedding
// model
bmWeight = 0.25f + queryLengthWeight * 0.50f;
}
#if 0
qDebug()
<< "bm25q.type" << bm25q.type
<< "bm25q.qlength" << bm25q.qlength
<< "bm25q.ilength" << bm25q.ilength
<< "bm25q.rlength" << bm25q.rlength
<< "bmWeight" << bmWeight;
#endif
return bmWeight;
}
QList<int> Database::reciprocalRankFusion(const std::vector<float> &query, const QList<int> &embeddingResults,
const QList<int> &bm25Results, const BM25Query &bm25q, int k)
{
// We default to the embedding results and augment with bm25 if any
QList<int> results = embeddingResults;
QList<int> missingScores;
QHash<int, int> bm25Ranks;
for (int i = 0; i < bm25Results.size(); ++i) {
if (!results.contains(bm25Results[i]))
missingScores.append(bm25Results[i]);
bm25Ranks[bm25Results[i]] = i + 1;
}
if (!missingScores.isEmpty()) {
QList<int> scored = scoreChunks(query, missingScores);
results << scored;
}
QHash<int, int> embeddingRanks;
for (int i = 0; i < results.size(); ++i)
embeddingRanks[results[i]] = i + 1;
const float bmWeight = bm25Results.isEmpty() ? 0 : computeBM25Weight(bm25q);
// From the paper: "Reciprocal Rank Fusion outperforms Condorcet and individual Rank Learning Methods"
// doi: 10.1145/1571941.157211
const int fusion_k = 60;
std::stable_sort(
results.begin(), results.end(),
[&](const int &a, const int &b) {
// Reciprocal Rank Fusion (RRF)
const int aBm25Rank = bm25Ranks.value(a, bm25Results.size() + 1);
const int aEmbeddingRank = embeddingRanks.value(a, embeddingResults.size() + 1);
Q_ASSERT(embeddingRanks.contains(a));
const int bBm25Rank = bm25Ranks.value(b, bm25Results.size() + 1);
const int bEmbeddingRank = embeddingRanks.value(b, embeddingResults.size() + 1);
Q_ASSERT(embeddingRanks.contains(b));
const float aBm25Score = 1.0f / (fusion_k + aBm25Rank);
const float bBm25Score = 1.0f / (fusion_k + bBm25Rank);
const float aEmbeddingScore = 1.0f / (fusion_k + aEmbeddingRank);
const float bEmbeddingScore = 1.0f / (fusion_k + bEmbeddingRank);
const float aWeightedScore = bmWeight * aBm25Score + (1.f - bmWeight) * aEmbeddingScore;
const float bWeightedScore = bmWeight * bBm25Score + (1.f - bmWeight) * bEmbeddingScore;
// Higher RRF score means better ranking, so we use greater than for sorting
return aWeightedScore > bWeightedScore;
}
);
k = qMin(k, results.size());
results.resize(k);
return results;
}
QList<int> Database::searchDatabase(const QString &query, const QList<QString> &collections, int k)
{
std::vector<float> queryEmbd = m_embLLM->generateQueryEmbedding(query);
if (queryEmbd.empty()) {
qDebug() << "ERROR: generating embeddings returned a null result";
return { };
}
const QList<int> embeddingResults = searchEmbeddings(queryEmbd, collections, k);
BM25Query bm25q;
const QList<int> bm25Results = searchBM25(query, collections, bm25q, k);
return reciprocalRankFusion(queryEmbd, embeddingResults, bm25Results, bm25q, k);
}
void Database::retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize,
QList<ResultInfo> *results)
{
@ -1944,13 +2204,7 @@ void Database::retrieveFromDB(const QList<QString> &collections, const QString &
qDebug() << "retrieveFromDB" << collections << text << retrievalSize;
#endif
std::vector<float> queryEmbd = m_embLLM->generateQueryEmbedding(text);
if (queryEmbd.empty()) {
qDebug() << "ERROR: generating embeddings returned a null result";
return;
}
QList<int> searchResults = searchEmbeddings(queryEmbd, collections, retrievalSize);
QList<int> searchResults = searchDatabase(text, collections, retrievalSize);
if (searchResults.isEmpty())
return;
@ -1960,10 +2214,9 @@ void Database::retrieveFromDB(const QList<QString> &collections, const QString &
return;
}
QHash<int, ResultInfo> tempResults;
while (q.next()) {
#if defined(DEBUG)
const int rowid = q.value(0).toInt();
#endif
const QString document_path = q.value(2).toString();
const QString chunk_text = q.value(3).toString();
const QString date = QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd");
@ -1985,12 +2238,16 @@ void Database::retrieveFromDB(const QList<QString> &collections, const QString &
info.page = page;
info.from = from;
info.to = to;
results->append(info);
tempResults.insert(rowid, info);
#if defined(DEBUG)
qDebug() << "retrieve rowid:" << rowid
<< "chunk_text:" << chunk_text;
#endif
}
for (int id : searchResults)
if (tempResults.contains(id))
results->append(tempResults.value(id));
}
// FIXME This is very slow and non-interruptible and when we close the application and we're

View File

@ -35,7 +35,7 @@ class QTimer;
// minimum supported version
static const int LOCALDOCS_MIN_VER = 1;
// current version
static const int LOCALDOCS_VERSION = 2;
static const int LOCALDOCS_VERSION = 3;
struct DocumentInfo
{
@ -206,7 +206,24 @@ private:
bool cleanDB();
void addFolderToWatch(const QString &path);
void removeFolderFromWatch(const QString &path);
QList<int> searchEmbeddings(const std::vector<float> &query, const QList<QString> &collections, int nNeighbors);
static QList<int> searchEmbeddingsHelper(const std::vector<float> &query, QSqlQuery &q, int nNeighbors);
QList<int> searchEmbeddings(const std::vector<float> &query, const QList<QString> &collections,
int nNeighbors);
struct BM25Query {
QString input;
QString query;
bool isExact = false;
int qlength = 0;
int ilength = 0;
int rlength = 0;
};
QList<Database::BM25Query> queriesForFTS5(const QString &input);
QList<int> searchBM25(const QString &query, const QList<QString> &collections, BM25Query &bm25q, int k);
QList<int> scoreChunks(const std::vector<float> &query, const QList<int> &chunks);
float computeBM25Weight(const BM25Query &bm25q);
QList<int> reciprocalRankFusion(const std::vector<float> &query, const QList<int> &embeddingResults,
const QList<int> &bm25Results, const BM25Query &bm25q, int k);
QList<int> searchDatabase(const QString &query, const QList<QString> &collections, int k);
void setStartUpdateTime(CollectionItem &item);
void setLastUpdateTime(CollectionItem &item);