llmodel: add CUDA to the DLL search path if CUDA_PATH is set (#2357)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel 2024-05-16 17:39:49 -04:00 committed by GitHub
parent a92d266cea
commit 2025d2d15b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 3 deletions

View File

@ -58,14 +58,15 @@ public:
#include <string>
#include <exception>
#include <stdexcept>
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
# define NOMINMAX
#endif
#include <windows.h>
#include <libloaderapi.h>
class Dlhandle {
HMODULE chandle;

View File

@ -15,8 +15,16 @@
#include <unordered_map>
#include <vector>
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
#endif
#ifdef _MSC_VER
#include <intrin.h>
# include <intrin.h>
#endif
#ifndef __APPLE__
@ -85,6 +93,20 @@ static bool isImplementation(const Dlhandle &dl) {
return dl.get<bool(uint32_t)>("is_g4a_backend_model_implementation");
}
// Add the CUDA Toolkit to the DLL search path on Windows.
// This is necessary for chat.exe to find CUDA when started from Qt Creator.
static void addCudaSearchPath() {
#ifdef _WIN32
if (const auto *cudaPath = _wgetenv(L"CUDA_PATH")) {
auto libDir = std::wstring(cudaPath) + L"\\bin";
if (!AddDllDirectory(libDir.c_str())) {
auto err = GetLastError();
std::wcerr << L"AddDllDirectory(\"" << libDir << L"\") failed with error 0x" << std::hex << err << L"\n";
}
}
#endif
}
const std::vector<LLModel::Implementation> &LLModel::Implementation::implementationList() {
if (cpu_supports_avx() == 0) {
throw std::runtime_error("CPU does not support AVX");
@ -95,6 +117,8 @@ const std::vector<LLModel::Implementation> &LLModel::Implementation::implementat
static auto* libs = new std::vector<Implementation>([] () {
std::vector<Implementation> fres;
addCudaSearchPath();
std::string impl_name_re = "(gptj|llamamodel-mainline)-(cpu|metal|kompute|vulkan|cuda)";
if (cpu_supports_avx2() == 0) {
impl_name_re += "-avxonly";