Fix the way we're injecting the context back into the model for web search.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
Adam Treat 2024-07-27 11:11:41 -04:00
parent c78c95ab42
commit dda59a97a6
3 changed files with 12 additions and 5 deletions

View File

@ -116,6 +116,7 @@ void Chat::resetResponseState()
if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
return;
m_sourceExcerpts = QList<SourceExcerpt>();
m_generatedQuestions = QList<QString>();
emit generatedQuestionsChanged();
m_tokenSpeed = QString();
@ -136,6 +137,7 @@ void Chat::prompt(const QString &prompt)
void Chat::regenerateResponse()
{
const int index = m_chatModel->count() - 1;
m_sourceExcerpts = QList<SourceExcerpt>();
m_chatModel->updateSources(index, QList<SourceExcerpt>());
emit regenerateResponseRequested();
}

View File

@ -869,13 +869,13 @@ bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32
static QRegularExpression re(R"(brave_search\.call\(query=\"([^\"]+)\"\))");
QRegularExpressionMatch match = re.match(toolCall);
QString prompt("<|start_header_id|>ipython<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2");
QString promptTemplate("<|start_header_id|>ipython<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2");
QString query;
if (match.hasMatch()) {
query = match.captured(1);
} else {
qWarning() << "WARNING: Could not find the tool for " << toolCall;
return promptInternal(QList<QString>()/*collectionList*/, prompt.arg(QString()), QString("%1") /*promptTemplate*/,
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, promptTemplate,
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens);
}
@ -887,7 +887,7 @@ bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32
emit sourceExcerptsChanged(braveResponse.second);
return promptInternal(QList<QString>()/*collectionList*/, prompt.arg(braveResponse.first), QString("%1") /*promptTemplate*/,
return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, promptTemplate,
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens);
}

View File

@ -219,8 +219,13 @@ public:
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
item.sources = sources;
item.consolidatedSources = consolidateSources(sources);
if (sources.isEmpty()) {
item.sources.clear();
item.consolidatedSources.clear();
} else {
item.sources << sources;
item.consolidatedSources << consolidateSources(sources);
}
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole});
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole});
}