mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
parent
117a8e7faa
commit
10d2375bf3
@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Add bm25 hybrid search to localdocs ([#2969](https://github.com/nomic-ai/gpt4all/pull/2969))
|
||||
|
||||
## [3.3.0] - 2024-09-20
|
||||
|
||||
### Added
|
||||
@ -119,6 +124,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
|
||||
- Fix several Vulkan resource management issues ([#2694](https://github.com/nomic-ai/gpt4all/pull/2694))
|
||||
- Fix crash/hang when some models stop generating, by showing special tokens ([#2701](https://github.com/nomic-ai/gpt4all/pull/2701))
|
||||
|
||||
[Unreleased]: https://github.com/nomic-ai/gpt4all/compare/v3.3.0...HEAD
|
||||
[3.3.0]: https://github.com/nomic-ai/gpt4all/compare/v3.2.1...v3.3.0
|
||||
[3.2.1]: https://github.com/nomic-ai/gpt4all/compare/v3.2.0...v3.2.1
|
||||
[3.2.0]: https://github.com/nomic-ai/gpt4all/compare/v3.1.1...v3.2.0
|
||||
|
@ -103,6 +103,20 @@ static const QString INIT_DB_SQL[] = {
|
||||
tokens integer default 0 not null,
|
||||
foreign key(document_id) references documents(id)
|
||||
);
|
||||
)"_s, uR"(
|
||||
create virtual table chunks_fts using fts5(
|
||||
id unindexed,
|
||||
document_id unindexed,
|
||||
chunk_text,
|
||||
file,
|
||||
title,
|
||||
author,
|
||||
subject,
|
||||
keywords,
|
||||
content='chunks',
|
||||
content_rowid='id',
|
||||
tokenize='porter'
|
||||
);
|
||||
)"_s, uR"(
|
||||
create table collections(
|
||||
id integer primary key,
|
||||
@ -151,7 +165,13 @@ static const QString INSERT_CHUNK_SQL = uR"(
|
||||
file, title, author, subject, keywords, page, line_from, line_to, words)
|
||||
values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
returning id;
|
||||
)"_s;
|
||||
)"_s;
|
||||
|
||||
static const QString INSERT_CHUNK_FTS_SQL = uR"(
|
||||
insert into chunks_fts(document_id, chunk_text,
|
||||
file, title, author, subject, keywords)
|
||||
values(?, ?, ?, ?, ?, ?, ?);
|
||||
)"_s;
|
||||
|
||||
static const QString DELETE_CHUNKS_SQL[] = {
|
||||
uR"(
|
||||
@ -161,12 +181,14 @@ static const QString DELETE_CHUNKS_SQL[] = {
|
||||
);
|
||||
)"_s, uR"(
|
||||
delete from chunks where document_id = ?;
|
||||
)"_s, uR"(
|
||||
delete from chunks_fts where document_id = ?;
|
||||
)"_s,
|
||||
};
|
||||
|
||||
static const QString SELECT_CHUNKS_BY_DOCUMENT_SQL = uR"(
|
||||
select id from chunks WHERE document_id = ?;
|
||||
)"_s;
|
||||
)"_s;
|
||||
|
||||
static const QString SELECT_CHUNKS_SQL = uR"(
|
||||
select c.id, d.document_time, d.document_path, c.chunk_text, c.file, c.title, c.author, c.page, c.line_from, c.line_to, co.name
|
||||
@ -190,14 +212,21 @@ static const QString SELECT_UNCOMPLETED_CHUNKS_SQL = uR"(
|
||||
from embeddings e
|
||||
where e.chunk_id = c.id and e.model = co.embedding_model
|
||||
);
|
||||
)"_s;
|
||||
)"_s;
|
||||
|
||||
static const QString SELECT_COUNT_CHUNKS_SQL = uR"(
|
||||
select count(c.id)
|
||||
from chunks c
|
||||
join documents d on d.id = c.document_id
|
||||
where d.folder_id = ?;
|
||||
)"_s;
|
||||
)"_s;
|
||||
|
||||
static const QString SELECT_CHUNKS_FTS_SQL = uR"(
|
||||
select id, bm25(chunks_fts) as score
|
||||
from chunks_fts
|
||||
where chunks_fts match ?
|
||||
order by score limit %1;
|
||||
)"_s;
|
||||
|
||||
static bool addChunk(QSqlQuery &q, int document_id, const QString &chunk_text, const QString &file,
|
||||
const QString &title, const QString &author, const QString &subject, const QString &keywords,
|
||||
@ -219,6 +248,18 @@ static bool addChunk(QSqlQuery &q, int document_id, const QString &chunk_text, c
|
||||
if (!q.exec() || !q.next())
|
||||
return false;
|
||||
*chunk_id = q.value(0).toInt();
|
||||
|
||||
if (!q.prepare(INSERT_CHUNK_FTS_SQL))
|
||||
return false;
|
||||
q.addBindValue(document_id);
|
||||
q.addBindValue(chunk_text);
|
||||
q.addBindValue(file);
|
||||
q.addBindValue(title);
|
||||
q.addBindValue(author);
|
||||
q.addBindValue(subject);
|
||||
q.addBindValue(keywords);
|
||||
if (!q.exec())
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -424,6 +465,7 @@ static bool selectAllFromCollections(QSqlQuery &q, QList<CollectionItem> *collec
|
||||
return false;
|
||||
break;
|
||||
case 2:
|
||||
case 3:
|
||||
if (!q.prepare(SELECT_COLLECTIONS_SQL_V2))
|
||||
return false;
|
||||
break;
|
||||
@ -770,6 +812,12 @@ static const QString GET_COLLECTION_EMBEDDINGS_SQL = uR"(
|
||||
where co.name in ('%1');
|
||||
)"_s;
|
||||
|
||||
static const QString GET_CHUNK_EMBEDDINGS_SQL = uR"(
|
||||
select e.chunk_id, e.embedding
|
||||
from embeddings e
|
||||
where e.chunk_id in (%1);
|
||||
)"_s;
|
||||
|
||||
static const QString GET_CHUNK_FILE_SQL = uR"(
|
||||
select file from chunks where id = ?;
|
||||
)"_s;
|
||||
@ -1858,19 +1906,13 @@ void Database::removeFolderFromWatch(const QString &path)
|
||||
m_watchedPaths -= QSet(children.begin(), children.end());
|
||||
}
|
||||
|
||||
QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QList<QString> &collections, int nNeighbors)
|
||||
QList<int> Database::searchEmbeddingsHelper(const std::vector<float> &query, QSqlQuery &q, int nNeighbors)
|
||||
{
|
||||
constexpr int BATCH_SIZE = 2048;
|
||||
|
||||
const int n_embd = query.size();
|
||||
const us::metric_punned_t metric(n_embd, us::metric_kind_t::ip_k); // inner product
|
||||
|
||||
QSqlQuery q(m_db);
|
||||
if (!q.exec(GET_COLLECTION_EMBEDDINGS_SQL.arg(collections.join("', '")))) {
|
||||
qWarning() << "Database ERROR: Failed to exec embeddings query:" << q.lastError();
|
||||
return {};
|
||||
}
|
||||
|
||||
us::executor_default_t executor(std::thread::hardware_concurrency());
|
||||
us::exact_search_t search;
|
||||
|
||||
@ -1882,6 +1924,7 @@ QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QLi
|
||||
struct Result { int chunkId; us::distance_punned_t dist; };
|
||||
QList<Result> results;
|
||||
|
||||
// The q parameter is expected to be the result of a QSqlQuery returning (chunk_id, embedding) pairs
|
||||
while (q.at() != QSql::AfterLastRow) { // batches
|
||||
batchChunkIds.clear();
|
||||
batchEmbeddings.clear();
|
||||
@ -1937,6 +1980,223 @@ QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QLi
|
||||
return chunkIds;
|
||||
}
|
||||
|
||||
QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QList<QString> &collections,
|
||||
int nNeighbors)
|
||||
{
|
||||
QSqlQuery q(m_db);
|
||||
if (!q.exec(GET_COLLECTION_EMBEDDINGS_SQL.arg(collections.join("', '")))) {
|
||||
qWarning() << "Database ERROR: Failed to exec embeddings query:" << q.lastError();
|
||||
return {};
|
||||
}
|
||||
return searchEmbeddingsHelper(query, q, nNeighbors);
|
||||
}
|
||||
|
||||
QList<int> Database::scoreChunks(const std::vector<float> &query, const QList<int> &chunks)
|
||||
{
|
||||
QList<QString> chunkStrings;
|
||||
for (int id : chunks)
|
||||
chunkStrings << QString::number(id);
|
||||
QSqlQuery q(m_db);
|
||||
if (!q.exec(GET_CHUNK_EMBEDDINGS_SQL.arg(chunkStrings.join(", ")))) {
|
||||
qWarning() << "Database ERROR: Failed to exec embeddings query:" << q.lastError();
|
||||
return {};
|
||||
}
|
||||
return searchEmbeddingsHelper(query, q, chunks.size());
|
||||
}
|
||||
|
||||
QList<Database::BM25Query> Database::queriesForFTS5(const QString &input)
|
||||
{
|
||||
// Escape double quotes by adding a second double quote
|
||||
QString escapedInput = input;
|
||||
escapedInput.replace("\"", "\"\"");
|
||||
|
||||
static QRegularExpression spaces("\\s+");
|
||||
QStringList oWords = escapedInput.split(spaces, Qt::SkipEmptyParts);
|
||||
|
||||
QList<BM25Query> queries;
|
||||
|
||||
// Start by trying to match the entire input
|
||||
BM25Query e;
|
||||
e.isExact = true;
|
||||
e.input = oWords.join(" ");
|
||||
e.query = "\"" + oWords.join(" ") + "\"";
|
||||
e.qlength = oWords.size();
|
||||
e.ilength = oWords.size();
|
||||
queries << e;
|
||||
|
||||
// https://github.com/igorbrigadir/stopwords?tab=readme-ov-file
|
||||
// Lucene, Solr, Elastisearch
|
||||
static const QSet<QString> stopWords = {
|
||||
"a", "an", "and", "are", "as", "at", "be", "but", "by",
|
||||
"for", "if", "in", "into", "is", "it", "no", "not", "of",
|
||||
"on", "or", "such", "that", "the", "their", "then", "there",
|
||||
"these", "they", "this", "to", "was", "will", "with"
|
||||
};
|
||||
|
||||
QStringList quotedWords;
|
||||
for (const QString &w : oWords)
|
||||
if (!stopWords.contains(w.toLower()))
|
||||
quotedWords << "\"" + w + "\"";
|
||||
|
||||
BM25Query b;
|
||||
b.input = oWords.join(" ");
|
||||
b.query = "(" + quotedWords.join(" OR ") + ")";
|
||||
b.qlength = 1; // length of phrase
|
||||
b.ilength = oWords.size();
|
||||
b.rlength = oWords.size() - quotedWords.size();
|
||||
queries << b;
|
||||
return queries;
|
||||
}
|
||||
|
||||
QList<int> Database::searchBM25(const QString &query, const QList<QString> &collections, BM25Query &bm25q, int k)
|
||||
{
|
||||
struct SearchResult { int chunkId; float score; };
|
||||
QList<BM25Query> bm25Queries = queriesForFTS5(query);
|
||||
|
||||
QSqlQuery sqlQuery(m_db);
|
||||
sqlQuery.prepare(SELECT_CHUNKS_FTS_SQL.arg(k));
|
||||
|
||||
QList<SearchResult> results;
|
||||
for (auto &bm25Query : std::as_const(bm25Queries)) {
|
||||
sqlQuery.addBindValue(bm25Query.query);
|
||||
|
||||
if (!sqlQuery.exec()) {
|
||||
qWarning() << "Database ERROR: Failed to execute BM25 query:" << sqlQuery.lastError();
|
||||
return {};
|
||||
}
|
||||
|
||||
if (sqlQuery.next()) {
|
||||
// Save the query that was used to produce results
|
||||
bm25q = bm25Query;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
do {
|
||||
const int chunkId = sqlQuery.value(0).toInt();
|
||||
const float score = sqlQuery.value(1).toFloat();
|
||||
results.append({chunkId, score});
|
||||
} while (sqlQuery.next());
|
||||
|
||||
k = qMin(k, results.size());
|
||||
std::partial_sort(
|
||||
results.begin(), results.begin() + k, results.end(),
|
||||
[](const SearchResult &a, const SearchResult &b) { return a.score < b.score; }
|
||||
);
|
||||
|
||||
QList<int> chunkIds;
|
||||
chunkIds.reserve(k);
|
||||
for (int i = 0; i < k; i++)
|
||||
chunkIds << results[i].chunkId;
|
||||
return chunkIds;
|
||||
}
|
||||
|
||||
float Database::computeBM25Weight(const Database::BM25Query &bm25q)
|
||||
{
|
||||
float bmWeight = 0.0f;
|
||||
if (bm25q.isExact) {
|
||||
bmWeight = 0.9f; // the highest we give
|
||||
} else {
|
||||
// qlength is the length of the phrases in the query by number of distinct words
|
||||
// ilength is the length of the natural language query by number of distinct words
|
||||
// rlength is the number of stop words removed from the natural language query to form the query
|
||||
|
||||
// calculate the query length weight based on the ratio of query terms to meaningful terms.
|
||||
// this formula adjusts the weight with the empirically determined insight that BM25's
|
||||
// effectiveness decreases as query length increases.
|
||||
float queryLengthWeight = 1 / powf(float(bm25q.ilength - bm25q.rlength), 2);
|
||||
queryLengthWeight = qBound(0.0f, queryLengthWeight, 1.0f);
|
||||
|
||||
// the weighting is bound between 1/4 and 3/4 which was determined empirically to work well
|
||||
// with the beir nfcorpus, scifact, fiqa and trec-covid datasets along with our embedding
|
||||
// model
|
||||
bmWeight = 0.25f + queryLengthWeight * 0.50f;
|
||||
}
|
||||
|
||||
#if 0
|
||||
qDebug()
|
||||
<< "bm25q.type" << bm25q.type
|
||||
<< "bm25q.qlength" << bm25q.qlength
|
||||
<< "bm25q.ilength" << bm25q.ilength
|
||||
<< "bm25q.rlength" << bm25q.rlength
|
||||
<< "bmWeight" << bmWeight;
|
||||
#endif
|
||||
|
||||
return bmWeight;
|
||||
}
|
||||
|
||||
QList<int> Database::reciprocalRankFusion(const std::vector<float> &query, const QList<int> &embeddingResults,
|
||||
const QList<int> &bm25Results, const BM25Query &bm25q, int k)
|
||||
{
|
||||
// We default to the embedding results and augment with bm25 if any
|
||||
QList<int> results = embeddingResults;
|
||||
|
||||
QList<int> missingScores;
|
||||
QHash<int, int> bm25Ranks;
|
||||
for (int i = 0; i < bm25Results.size(); ++i) {
|
||||
if (!results.contains(bm25Results[i]))
|
||||
missingScores.append(bm25Results[i]);
|
||||
bm25Ranks[bm25Results[i]] = i + 1;
|
||||
}
|
||||
|
||||
if (!missingScores.isEmpty()) {
|
||||
QList<int> scored = scoreChunks(query, missingScores);
|
||||
results << scored;
|
||||
}
|
||||
|
||||
QHash<int, int> embeddingRanks;
|
||||
for (int i = 0; i < results.size(); ++i)
|
||||
embeddingRanks[results[i]] = i + 1;
|
||||
|
||||
const float bmWeight = bm25Results.isEmpty() ? 0 : computeBM25Weight(bm25q);
|
||||
|
||||
// From the paper: "Reciprocal Rank Fusion outperforms Condorcet and individual Rank Learning Methods"
|
||||
// doi: 10.1145/1571941.157211
|
||||
const int fusion_k = 60;
|
||||
|
||||
std::stable_sort(
|
||||
results.begin(), results.end(),
|
||||
[&](const int &a, const int &b) {
|
||||
// Reciprocal Rank Fusion (RRF)
|
||||
const int aBm25Rank = bm25Ranks.value(a, bm25Results.size() + 1);
|
||||
const int aEmbeddingRank = embeddingRanks.value(a, embeddingResults.size() + 1);
|
||||
Q_ASSERT(embeddingRanks.contains(a));
|
||||
|
||||
const int bBm25Rank = bm25Ranks.value(b, bm25Results.size() + 1);
|
||||
const int bEmbeddingRank = embeddingRanks.value(b, embeddingResults.size() + 1);
|
||||
Q_ASSERT(embeddingRanks.contains(b));
|
||||
|
||||
const float aBm25Score = 1.0f / (fusion_k + aBm25Rank);
|
||||
const float bBm25Score = 1.0f / (fusion_k + bBm25Rank);
|
||||
const float aEmbeddingScore = 1.0f / (fusion_k + aEmbeddingRank);
|
||||
const float bEmbeddingScore = 1.0f / (fusion_k + bEmbeddingRank);
|
||||
const float aWeightedScore = bmWeight * aBm25Score + (1.f - bmWeight) * aEmbeddingScore;
|
||||
const float bWeightedScore = bmWeight * bBm25Score + (1.f - bmWeight) * bEmbeddingScore;
|
||||
|
||||
// Higher RRF score means better ranking, so we use greater than for sorting
|
||||
return aWeightedScore > bWeightedScore;
|
||||
}
|
||||
);
|
||||
|
||||
k = qMin(k, results.size());
|
||||
results.resize(k);
|
||||
return results;
|
||||
}
|
||||
|
||||
QList<int> Database::searchDatabase(const QString &query, const QList<QString> &collections, int k)
|
||||
{
|
||||
std::vector<float> queryEmbd = m_embLLM->generateQueryEmbedding(query);
|
||||
if (queryEmbd.empty()) {
|
||||
qDebug() << "ERROR: generating embeddings returned a null result";
|
||||
return { };
|
||||
}
|
||||
|
||||
const QList<int> embeddingResults = searchEmbeddings(queryEmbd, collections, k);
|
||||
BM25Query bm25q;
|
||||
const QList<int> bm25Results = searchBM25(query, collections, bm25q, k);
|
||||
return reciprocalRankFusion(queryEmbd, embeddingResults, bm25Results, bm25q, k);
|
||||
}
|
||||
|
||||
void Database::retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize,
|
||||
QList<ResultInfo> *results)
|
||||
{
|
||||
@ -1944,13 +2204,7 @@ void Database::retrieveFromDB(const QList<QString> &collections, const QString &
|
||||
qDebug() << "retrieveFromDB" << collections << text << retrievalSize;
|
||||
#endif
|
||||
|
||||
std::vector<float> queryEmbd = m_embLLM->generateQueryEmbedding(text);
|
||||
if (queryEmbd.empty()) {
|
||||
qDebug() << "ERROR: generating embeddings returned a null result";
|
||||
return;
|
||||
}
|
||||
|
||||
QList<int> searchResults = searchEmbeddings(queryEmbd, collections, retrievalSize);
|
||||
QList<int> searchResults = searchDatabase(text, collections, retrievalSize);
|
||||
if (searchResults.isEmpty())
|
||||
return;
|
||||
|
||||
@ -1960,10 +2214,9 @@ void Database::retrieveFromDB(const QList<QString> &collections, const QString &
|
||||
return;
|
||||
}
|
||||
|
||||
QHash<int, ResultInfo> tempResults;
|
||||
while (q.next()) {
|
||||
#if defined(DEBUG)
|
||||
const int rowid = q.value(0).toInt();
|
||||
#endif
|
||||
const QString document_path = q.value(2).toString();
|
||||
const QString chunk_text = q.value(3).toString();
|
||||
const QString date = QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd");
|
||||
@ -1985,12 +2238,16 @@ void Database::retrieveFromDB(const QList<QString> &collections, const QString &
|
||||
info.page = page;
|
||||
info.from = from;
|
||||
info.to = to;
|
||||
results->append(info);
|
||||
tempResults.insert(rowid, info);
|
||||
#if defined(DEBUG)
|
||||
qDebug() << "retrieve rowid:" << rowid
|
||||
<< "chunk_text:" << chunk_text;
|
||||
#endif
|
||||
}
|
||||
|
||||
for (int id : searchResults)
|
||||
if (tempResults.contains(id))
|
||||
results->append(tempResults.value(id));
|
||||
}
|
||||
|
||||
// FIXME This is very slow and non-interruptible and when we close the application and we're
|
||||
|
@ -35,7 +35,7 @@ class QTimer;
|
||||
// minimum supported version
|
||||
static const int LOCALDOCS_MIN_VER = 1;
|
||||
// current version
|
||||
static const int LOCALDOCS_VERSION = 2;
|
||||
static const int LOCALDOCS_VERSION = 3;
|
||||
|
||||
struct DocumentInfo
|
||||
{
|
||||
@ -206,7 +206,24 @@ private:
|
||||
bool cleanDB();
|
||||
void addFolderToWatch(const QString &path);
|
||||
void removeFolderFromWatch(const QString &path);
|
||||
QList<int> searchEmbeddings(const std::vector<float> &query, const QList<QString> &collections, int nNeighbors);
|
||||
static QList<int> searchEmbeddingsHelper(const std::vector<float> &query, QSqlQuery &q, int nNeighbors);
|
||||
QList<int> searchEmbeddings(const std::vector<float> &query, const QList<QString> &collections,
|
||||
int nNeighbors);
|
||||
struct BM25Query {
|
||||
QString input;
|
||||
QString query;
|
||||
bool isExact = false;
|
||||
int qlength = 0;
|
||||
int ilength = 0;
|
||||
int rlength = 0;
|
||||
};
|
||||
QList<Database::BM25Query> queriesForFTS5(const QString &input);
|
||||
QList<int> searchBM25(const QString &query, const QList<QString> &collections, BM25Query &bm25q, int k);
|
||||
QList<int> scoreChunks(const std::vector<float> &query, const QList<int> &chunks);
|
||||
float computeBM25Weight(const BM25Query &bm25q);
|
||||
QList<int> reciprocalRankFusion(const std::vector<float> &query, const QList<int> &embeddingResults,
|
||||
const QList<int> &bm25Results, const BM25Query &bm25q, int k);
|
||||
QList<int> searchDatabase(const QString &query, const QList<QString> &collections, int k);
|
||||
|
||||
void setStartUpdateTime(CollectionItem &item);
|
||||
void setLastUpdateTime(CollectionItem &item);
|
||||
|
Loading…
Reference in New Issue
Block a user