From ce45efe6053cc38bb104e3e7a130b11d3a12917f Mon Sep 17 00:00:00 2001 From: nathan Date: Wed, 3 Jul 2024 12:04:52 +0800 Subject: [PATCH] load model library when using pyinstaller to build executable file in windows Signed-off-by: nathan --- gpt4all-bindings/python/gpt4all/_pyllmodel.py | 34 +++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py index 892d72e7..4e578064 100644 --- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py +++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py @@ -5,6 +5,7 @@ import os import platform import re import subprocess +from pathlib import Path import sys import threading from enum import Enum @@ -54,16 +55,37 @@ MODEL_LIB_PATH = importlib_resources.files("gpt4all") / "llmodel_DO_NOT_MODIFY" def load_llmodel_library(): + """ + Loads the llmodel shared library based on the current operating system. + + This function attempts to load the shared library using the appropriate file + extension for the operating system. It first tries to load the library with the + 'lib' prefix (common for macOS, Linux, and MinGW on Windows). If the file is not + found and the operating system is Windows, it attempts to load the library without + the 'lib' prefix (common for MSVC on Windows). + + Returns: + ctypes.CDLL: The loaded shared library. + + Raises: + OSError: If the shared library cannot be found. + """ + # Determine the appropriate file extension for the shared library based on the platform ext = {"Darwin": "dylib", "Linux": "so", "Windows": "dll"}[platform.system()] + # Define library names with and without the 'lib' prefix + library_name_with_lib_prefix = f"libllmodel.{ext}" + library_name_without_lib_prefix = "llmodel.dll" + base_path = MODEL_LIB_PATH + try: - # macOS, Linux, MinGW - lib = ctypes.CDLL(str(MODEL_LIB_PATH / f"libllmodel.{ext}")) - except FileNotFoundError: - if ext != 'dll': + # Attempt to load the shared library with the 'lib' prefix (common for macOS, Linux, and MinGW) + lib = ctypes.CDLL(str(base_path / library_name_with_lib_prefix)) + except OSError: # OSError is more general and includes FileNotFoundError + if ext != "dll": raise - # MSVC - lib = ctypes.CDLL(str(MODEL_LIB_PATH / "llmodel.dll")) + # For Windows (ext == 'dll'), attempt to load the shared library without the 'lib' prefix (common for MSVC) + lib = ctypes.CDLL(str(base_path / library_name_without_lib_prefix)) return lib