Allow for download of models hosted on third party hosts.

This commit is contained in:
Adam Treat 2023-06-04 19:02:43 -04:00
parent 5073630759
commit bdba2e8de6
2 changed files with 13 additions and 6 deletions

View File

@ -228,8 +228,11 @@ void Download::downloadModel(const QString &modelFile)
tempFile->seek(incomplete_size); tempFile->seek(incomplete_size);
} }
ModelInfo info = m_modelMap.value(modelFile);
QString url = !info.url.isEmpty() ? info.url : "http://gpt4all.io/models/" + modelFile;
Network::globalInstance()->sendDownloadStarted(modelFile); Network::globalInstance()->sendDownloadStarted(modelFile);
QNetworkRequest request("http://gpt4all.io/models/" + modelFile); QNetworkRequest request(url);
request.setAttribute(QNetworkRequest::User, modelFile);
request.setRawHeader("range", QString("bytes=%1-").arg(incomplete_size).toUtf8()); request.setRawHeader("range", QString("bytes=%1-").arg(incomplete_size).toUtf8());
QSslConfiguration conf = request.sslConfiguration(); QSslConfiguration conf = request.sslConfiguration();
conf.setPeerVerifyMode(QSslSocket::VerifyNone); conf.setPeerVerifyMode(QSslSocket::VerifyNone);
@ -370,6 +373,7 @@ void Download::parseModelsJsonFile(const QByteArray &jsonData)
QString modelFilename = obj["filename"].toString(); QString modelFilename = obj["filename"].toString();
QString modelFilesize = obj["filesize"].toString(); QString modelFilesize = obj["filesize"].toString();
QString requiresVersion = obj["requires"].toString(); QString requiresVersion = obj["requires"].toString();
QString url = obj["url"].toString();
QByteArray modelMd5sum = obj["md5sum"].toString().toLatin1().constData(); QByteArray modelMd5sum = obj["md5sum"].toString().toLatin1().constData();
bool isDefault = obj.contains("isDefault") && obj["isDefault"] == QString("true"); bool isDefault = obj.contains("isDefault") && obj["isDefault"] == QString("true");
bool bestGPTJ = obj.contains("bestGPTJ") && obj["bestGPTJ"] == QString("true"); bool bestGPTJ = obj.contains("bestGPTJ") && obj["bestGPTJ"] == QString("true");
@ -409,6 +413,7 @@ void Download::parseModelsJsonFile(const QByteArray &jsonData)
modelInfo.bestMPT = bestMPT; modelInfo.bestMPT = bestMPT;
modelInfo.description = description; modelInfo.description = description;
modelInfo.requiresVersion = requiresVersion; modelInfo.requiresVersion = requiresVersion;
modelInfo.url = url;
m_modelMap.insert(modelInfo.filename, modelInfo); m_modelMap.insert(modelInfo.filename, modelInfo);
} }
@ -500,7 +505,7 @@ void Download::handleErrorOccurred(QNetworkReply::NetworkError code)
if (!modelReply) if (!modelReply)
return; return;
QString modelFilename = modelReply->url().fileName(); QString modelFilename = modelReply->request().attribute(QNetworkRequest::User).toString();
qWarning() << "ERROR: Network error occurred attempting to download" qWarning() << "ERROR: Network error occurred attempting to download"
<< modelFilename << modelFilename
<< "code:" << code << "code:" << code
@ -523,7 +528,7 @@ void Download::handleDownloadProgress(qint64 bytesReceived, qint64 bytesTotal)
bytesTotal = contentTotalSize.toLongLong(); bytesTotal = contentTotalSize.toLongLong();
} }
QString modelFilename = modelReply->url().fileName(); QString modelFilename = modelReply->request().attribute(QNetworkRequest::User).toString();
emit downloadProgress(tempFile->pos(), bytesTotal, modelFilename); emit downloadProgress(tempFile->pos(), bytesTotal, modelFilename);
} }
@ -539,7 +544,7 @@ void HashAndSaveFile::hashAndSave(const QString &expectedHash, const QString &sa
QFile *tempFile, QNetworkReply *modelReply) QFile *tempFile, QNetworkReply *modelReply)
{ {
Q_ASSERT(!tempFile->isOpen()); Q_ASSERT(!tempFile->isOpen());
QString modelFilename = modelReply->url().fileName(); QString modelFilename = modelReply->request().attribute(QNetworkRequest::User).toString();
// Reopen the tempFile for hashing // Reopen the tempFile for hashing
if (!tempFile->open(QIODevice::ReadOnly)) { if (!tempFile->open(QIODevice::ReadOnly)) {
@ -608,7 +613,7 @@ void Download::handleModelDownloadFinished()
if (!modelReply) if (!modelReply)
return; return;
QString modelFilename = modelReply->url().fileName(); QString modelFilename = modelReply->request().attribute(QNetworkRequest::User).toString();
QFile *tempFile = m_activeDownloads.value(modelReply); QFile *tempFile = m_activeDownloads.value(modelReply);
m_activeDownloads.remove(modelReply); m_activeDownloads.remove(modelReply);
@ -638,7 +643,7 @@ void Download::handleHashAndSaveFinished(bool success,
{ {
// The hash and save should send back with tempfile closed // The hash and save should send back with tempfile closed
Q_ASSERT(!tempFile->isOpen()); Q_ASSERT(!tempFile->isOpen());
QString modelFilename = modelReply->url().fileName(); QString modelFilename = modelReply->request().attribute(QNetworkRequest::User).toString();
Network::globalInstance()->sendDownloadFinished(modelFilename, success); Network::globalInstance()->sendDownloadFinished(modelFilename, success);
ModelInfo info = m_modelMap.value(modelFilename); ModelInfo info = m_modelMap.value(modelFilename);

View File

@ -23,6 +23,7 @@ struct ModelInfo {
Q_PROPERTY(bool isChatGPT MEMBER isChatGPT) Q_PROPERTY(bool isChatGPT MEMBER isChatGPT)
Q_PROPERTY(QString description MEMBER description) Q_PROPERTY(QString description MEMBER description)
Q_PROPERTY(QString requiresVersion MEMBER requiresVersion) Q_PROPERTY(QString requiresVersion MEMBER requiresVersion)
Q_PROPERTY(QString url MEMBER url)
public: public:
QString filename; QString filename;
@ -37,6 +38,7 @@ public:
bool isChatGPT = false; bool isChatGPT = false;
QString description; QString description;
QString requiresVersion; QString requiresVersion;
QString url;
}; };
Q_DECLARE_METATYPE(ModelInfo) Q_DECLARE_METATYPE(ModelInfo)