diff --git a/gpt4all-chat/database.cpp b/gpt4all-chat/database.cpp index 0c6ff83b..a9c6e001 100644 --- a/gpt4all-chat/database.cpp +++ b/gpt4all-chat/database.cpp @@ -44,8 +44,8 @@ const auto SELECT_SQL = QLatin1String(R"( join folders ON documents.folder_id = folders.id join collections ON folders.id = collections.folder_id where chunks_fts match ? and collections.collection_name in (%1) - order by bm25(chunks_fts) desc - limit 3; + order by bm25(chunks_fts) + limit %2; )"); bool addChunk(QSqlQuery &q, int document_id, int chunk_id, const QString &chunk_text, int embedding_id, @@ -120,7 +120,7 @@ QStringList generateGrams(const QString &input, int N) return ngrams; } -bool selectChunk(QSqlQuery &q, const QList &collection_names, const QString &chunk_text) +bool selectChunk(QSqlQuery &q, const QList &collection_names, const QString &chunk_text, int retrievalSize) { const int N_WORDS = chunk_text.split(QRegularExpression("\\s+")).size(); for (int N = N_WORDS; N > 2; N--) { @@ -128,7 +128,7 @@ bool selectChunk(QSqlQuery &q, const QList &collection_names, const QSt QList text = generateGrams(chunk_text, N); QString orText = text.join(" OR "); const QString collection_names_str = collection_names.join("', '"); - const QString formatted_query = SELECT_SQL.arg("'" + collection_names_str + "'"); + const QString formatted_query = SELECT_SQL.arg("'" + collection_names_str + "'", QString::number(retrievalSize)); /* one multi-argument arg() call substitutes %1 and %2 in a single pass, so a "%2" inside a collection name is left untouched */ if (!q.prepare(formatted_query)) return false; q.addBindValue(orText); @@ -480,9 +480,10 @@ QSqlError initDb() return QSqlError(); } -Database::Database() +Database::Database(int chunkSize) : QObject(nullptr) , m_watcher(new QFileSystemWatcher(this)) + , m_chunkSize(chunkSize) { moveToThread(&m_dbThread); connect(&m_dbThread, &QThread::started, this, &Database::start); @@ -500,7 +501,6 @@ void Database::handleDocumentErrorAndScheduleNext(const QString &errorMessage, void Database::chunkStream(QTextStream &stream, int document_id) { - const int chunkSize = 256; 
int chunk_id = 0; int charCount = 0; QList words; @@ -510,7 +510,7 @@ void Database::chunkStream(QTextStream &stream, int document_id) stream >> word; charCount += word.length(); words.append(word); - if (charCount + words.size() - 1 >= chunkSize || stream.atEnd()) { + if (charCount + words.size() - 1 >= m_chunkSize || stream.atEnd()) { const QString chunk = words.join(" "); QSqlQuery q; if (!addChunk(q, @@ -752,9 +752,7 @@ void Database::addFolder(const QString &collection, const QString &path) return; } - if (!addFolderToWatch(path)) - return; - + addFolderToWatch(path); scanDocuments(folder_id, path); updateCollectionList(); } @@ -869,14 +867,14 @@ bool Database::removeFolderFromWatch(const QString &path) return true; } -void Database::retrieveFromDB(const QList &collections, const QString &text) +void Database::retrieveFromDB(const QList &collections, const QString &text, int retrievalSize) { #if defined(DEBUG) - qDebug() << "retrieveFromDB" << collections << text; + qDebug() << "retrieveFromDB" << collections << text << retrievalSize; #endif QSqlQuery q; - if (!selectChunk(q, collections, text)) { + if (!selectChunk(q, collections, text, retrievalSize)) { qDebug() << "ERROR: selecting chunks:" << q.lastError().text(); return; } @@ -957,6 +955,45 @@ void Database::cleanDB() updateCollectionList(); } +void Database::changeChunkSize(int chunkSize) +{ + if (chunkSize == m_chunkSize) + return; + +#if defined(DEBUG) + qDebug() << "changeChunkSize" << chunkSize; +#endif + + m_chunkSize = chunkSize; + + QSqlQuery q; + // Select every document in the db so its chunks and document row can be removed below + if (!q.prepare(SELECT_ALL_DOCUMENTS_SQL)) { + qWarning() << "ERROR: Cannot prepare sql for select all documents" << q.lastError(); + return; + } + + if (!q.exec()) { + qWarning() << "ERROR: Cannot exec sql for select all documents" << q.lastError(); + return; + } + + while (q.next()) { + int document_id = q.value(0).toInt(); + QString document_path = q.value(1).toString(); + // Remove all chunks 
and documents to change the chunk size + QSqlQuery query; + if (!removeChunksByDocumentId(query, document_id)) { + qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << query.lastError(); + } + + if (!removeDocument(query, document_id)) { + qWarning() << "ERROR: Cannot remove document_id" << document_id << query.lastError(); + } + } + addCurrentFolders(); +} + void Database::directoryChanged(const QString &path) { #if defined(DEBUG) diff --git a/gpt4all-chat/database.h b/gpt4all-chat/database.h index 0e13ee98..76d8bf57 100644 --- a/gpt4all-chat/database.h +++ b/gpt4all-chat/database.h @@ -25,15 +25,16 @@ class Database : public QObject { Q_OBJECT public: - Database(); + explicit Database(int chunkSize); public Q_SLOTS: void scanQueue(); void scanDocuments(int folder_id, const QString &folder_path); void addFolder(const QString &collection, const QString &path); void removeFolder(const QString &collection, const QString &path); - void retrieveFromDB(const QList &collections, const QString &text); + void retrieveFromDB(const QList &collections, const QString &text, int retrievalSize); void cleanDB(); + void changeChunkSize(int chunkSize); Q_SIGNALS: void docsToScanChanged(); @@ -55,6 +56,7 @@ private: int document_id, const QString &document_path, const QSqlError &error); private: + int m_chunkSize; QQueue m_docsToScan; QList m_retrieve; QThread m_dbThread; diff --git a/gpt4all-chat/localdocs.cpp b/gpt4all-chat/localdocs.cpp index b340b9c4..a7905306 100644 --- a/gpt4all-chat/localdocs.cpp +++ b/gpt4all-chat/localdocs.cpp @@ -10,14 +10,24 @@ LocalDocs *LocalDocs::globalInstance() LocalDocs::LocalDocs() : QObject(nullptr) , m_localDocsModel(new LocalDocsModel(this)) - , m_database(new Database) + , m_database(nullptr) { + QSettings settings; + settings.sync(); + m_chunkSize = settings.value("localdocs/chunkSize", 256).toInt(); + m_retrievalSize = settings.value("localdocs/retrievalSize", 3).toInt(); + + // Create the DB with the chunk size from settings + 
m_database = new Database(m_chunkSize); + connect(this, &LocalDocs::requestAddFolder, m_database, &Database::addFolder, Qt::QueuedConnection); connect(this, &LocalDocs::requestRemoveFolder, m_database, &Database::removeFolder, Qt::QueuedConnection); connect(this, &LocalDocs::requestRetrieveFromDB, m_database, &Database::retrieveFromDB, Qt::QueuedConnection); + connect(this, &LocalDocs::requestChunkSizeChange, m_database, + &Database::changeChunkSize, Qt::QueuedConnection); connect(m_database, &Database::retrieveResult, this, &LocalDocs::handleRetrieveResult, Qt::QueuedConnection); connect(m_database, &Database::collectionListUpdated, @@ -42,7 +52,36 @@ void LocalDocs::removeFolder(const QString &collection, const QString &path) void LocalDocs::requestRetrieve(const QList &collections, const QString &text) { m_retrieveResult = QList(); - emit requestRetrieveFromDB(collections, text); + emit requestRetrieveFromDB(collections, text, m_retrievalSize); +} + +int LocalDocs::chunkSize() const +{ + return m_chunkSize; +} + +void LocalDocs::setChunkSize(int chunkSize) +{ + if (m_chunkSize == chunkSize) + return; + + m_chunkSize = chunkSize; + emit chunkSizeChanged(); + emit requestChunkSizeChange(chunkSize); /* delivered to Database::changeChunkSize via the queued connection made in the constructor */ +} + +int LocalDocs::retrievalSize() const +{ + return m_retrievalSize; +} + +void LocalDocs::setRetrievalSize(int retrievalSize) +{ + if (m_retrievalSize == retrievalSize) + return; + + m_retrievalSize = retrievalSize; /* no message to the database: requestRetrieve() passes the current value with every request */ + emit retrievalSizeChanged(); } void LocalDocs::handleRetrieveResult(const QList &result) diff --git a/gpt4all-chat/localdocs.h b/gpt4all-chat/localdocs.h index fae76b4a..9395655a 100644 --- a/gpt4all-chat/localdocs.h +++ b/gpt4all-chat/localdocs.h @@ -10,6 +10,8 @@ class LocalDocs : public QObject { Q_OBJECT Q_PROPERTY(LocalDocsModel *localDocsModel READ localDocsModel NOTIFY localDocsModelChanged) + Q_PROPERTY(int chunkSize READ chunkSize WRITE setChunkSize NOTIFY chunkSizeChanged) + Q_PROPERTY(int retrievalSize READ retrievalSize WRITE 
setRetrievalSize NOTIFY retrievalSizeChanged) public: static LocalDocs *globalInstance(); @@ -22,17 +24,28 @@ public: QList result() const { return m_retrieveResult; } void requestRetrieve(const QList &collections, const QString &text); + int chunkSize() const; + void setChunkSize(int chunkSize); + + int retrievalSize() const; + void setRetrievalSize(int retrievalSize); + Q_SIGNALS: void requestAddFolder(const QString &collection, const QString &path); void requestRemoveFolder(const QString &collection, const QString &path); - void requestRetrieveFromDB(const QList &collections, const QString &text); + void requestRetrieveFromDB(const QList &collections, const QString &text, int retrievalSize); + void requestChunkSizeChange(int chunkSize); void receivedResult(); void localDocsModelChanged(); + void chunkSizeChanged(); + void retrievalSizeChanged(); private Q_SLOTS: void handleRetrieveResult(const QList &result); private: + int m_chunkSize; + int m_retrievalSize; LocalDocsModel *m_localDocsModel; Database *m_database; QList m_retrieveResult; diff --git a/gpt4all-chat/qml/LocalDocs.qml b/gpt4all-chat/qml/LocalDocs.qml index 30a91304..47c321f5 100644 --- a/gpt4all-chat/qml/LocalDocs.qml +++ b/gpt4all-chat/qml/LocalDocs.qml @@ -8,9 +8,31 @@ import localdocs Item { id: root + property string collection: "" property string folder_path: "" + property int defaultChunkSize: 256 + property int defaultRetrievalSize: 3 + + property alias chunkSize: settings.chunkSize + property alias retrievalSize: settings.retrievalSize + + Settings { + id: settings + category: "localdocs" + property int chunkSize: root.defaultChunkSize + property int retrievalSize: root.defaultRetrievalSize + } + + function restoreLocalDocsDefaults() { + settings.chunkSize = root.defaultChunkSize + settings.retrievalSize = root.defaultRetrievalSize + LocalDocs.chunkSize = settings.chunkSize + LocalDocs.retrievalSize = settings.retrievalSize + settings.sync() + } + FolderDialog { id: folderDialog title: "Please choose a 
directory" @@ -188,6 +210,21 @@ Item { Layout.column: 1 ToolTip.text: qsTr("Number of characters per document snippet.\nNOTE: larger numbers increase likelihood of factual responses, but also result in slower generation.") ToolTip.visible: hovered + text: settings.chunkSize.toString() + validator: IntValidator { + bottom: 1 + } + onEditingFinished: { + var val = parseInt(text) + if (!isNaN(val)) { + settings.chunkSize = val + settings.sync() + focus = false + LocalDocs.chunkSize = settings.chunkSize /* presumably the C++ LocalDocs object: its setChunkSize() emits requestChunkSizeChange towards Database::changeChunkSize - confirm */ + } else { + text = settings.chunkSize.toString() /* input did not parse as a number: show the stored value again */ + } + } } Label { @@ -203,6 +240,21 @@ Item { Layout.column: 1 ToolTip.text: qsTr("Best N matches of retrieved document snippets to add to the context for prompt.\nNOTE: larger numbers increase likelihood of factual responses, but also result in slower generation.") ToolTip.visible: hovered + text: settings.retrievalSize.toString() + validator: IntValidator { + bottom: 1 + } + onEditingFinished: { + var val = parseInt(text) + if (!isNaN(val)) { + settings.retrievalSize = val + settings.sync() + focus = false + LocalDocs.retrievalSize = settings.retrievalSize + } else { + text = settings.retrievalSize.toString() /* input did not parse as a number: show the stored value again */ + } + } } MyButton { @@ -215,7 +267,7 @@ Item { Accessible.name: text Accessible.description: qsTr("Restores the settings dialog to a default state") onClicked: { - // settingsDialog.restoreGenerationDefaults() + root.restoreLocalDocsDefaults(); } } }