Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2024-10-01 01:06:10 -04:00)
python: connection resume and MSVC support (#1535)
commit 5fbeeb1cb4 (parent 017c3a9649)
@@ -11,7 +11,9 @@ from pathlib import Path
 from typing import Any, Dict, Iterable, List, Optional, Union
 
 import requests
+from requests.exceptions import ChunkedEncodingError
 from tqdm import tqdm
+from urllib3.exceptions import IncompleteRead, ProtocolError
 
 from . import pyllmodel
 
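The two new imports exist only for the failure path: when a streamed download dies mid-body, requests raises a ChunkedEncodingError wrapping a urllib3 ProtocolError, which in turn wraps an IncompleteRead carrying the number of bytes actually received. A minimal sketch of that unwrapping, separate from the diff (the helper name is hypothetical):

    import requests
    from requests.exceptions import ChunkedEncodingError
    from urllib3.exceptions import IncompleteRead, ProtocolError

    def bytes_received_or_reraise(exc: ChunkedEncodingError) -> int:
        # Hypothetical helper: dig the partial byte count out of the nested
        # exception requests raises when a chunked body is cut short.
        if exc.args and isinstance(pe := exc.args[0], ProtocolError):
            if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead):
                return ir.partial  # bytes urllib3 had read before the drop
        raise exc  # some other protocol failure - let it propagate

The download loop below performs exactly this unwrapping inline to decide whether a retry with a Range request is safe.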
@@ -217,35 +219,61 @@ class GPT4All:
         download_path = os.path.join(model_path, model_filename).replace("\\", "\\\\")
         download_url = get_download_url(model_filename)
 
-        response = requests.get(download_url, stream=True)
-        if response.status_code != 200:
-            raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
+        def make_request(offset=None):
+            headers = {}
+            if offset:
+                print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
+                headers['Range'] = f'bytes={offset}-'  # resume incomplete response
+            response = requests.get(download_url, stream=True, headers=headers)
+            if response.status_code not in (200, 206):
+                raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
+            if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')):
+                raise ValueError('Connection was interrupted and server does not support range requests')
+            return response
+
+        response = make_request()
+
         total_size_in_bytes = int(response.headers.get("content-length", 0))
         block_size = 2**20  # 1 MB
 
-        with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
+        with open(download_path, "wb") as file, \
+                tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
             try:
-                with open(download_path, "wb") as file:
-                    for data in response.iter_content(block_size):
-                        progress_bar.update(len(data))
-                        file.write(data)
+                while True:
+                    last_progress = progress_bar.n
+                    try:
+                        for data in response.iter_content(block_size):
+                            file.write(data)
+                            progress_bar.update(len(data))
+                    except ChunkedEncodingError as cee:
+                        if cee.args and isinstance(pe := cee.args[0], ProtocolError):
+                            if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead):
+                                assert progress_bar.n <= ir.partial  # urllib3 may be ahead of us but never behind
+                                # the socket was closed during a read - retry
+                                response = make_request(progress_bar.n)
+                                continue
+                        raise
+                    if total_size_in_bytes != 0 and progress_bar.n < total_size_in_bytes:
+                        if progress_bar.n == last_progress:
+                            raise RuntimeError('Download not making progress, aborting.')
+                        # server closed connection prematurely - retry
+                        response = make_request(progress_bar.n)
+                        continue
+                    break
             except Exception:
-                if os.path.exists(download_path):
-                    if verbose:
-                        print("Cleaning up the interrupted download...")
+                if verbose:
+                    print("Cleaning up the interrupted download...", file=sys.stderr)
+                try:
                     os.remove(download_path)
+                except OSError:
+                    pass
                 raise
 
         # Validate download was successful
         if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
             raise RuntimeError("An error occurred during download. Downloaded file may not work.")
 
-        # Sleep for a little bit so OS can remove file lock
-        time.sleep(2)
+        if os.name == 'nt':
+            time.sleep(2)  # Sleep for a little bit so Windows can remove file lock
 
         if verbose:
-            print("Model downloaded at: ", download_path)
+            print("Model downloaded at:", download_path, file=sys.stderr)
 
         return download_path
 
     def generate(
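The new make_request helper is a standard HTTP range-request resume: ask the server for bytes=<offset>- and then verify, via the 206 status and the Content-Range header, that the range was actually honored. The same pattern as a standalone sketch, resuming from the size of a partial file on disk rather than from the in-memory progress counter the diff uses (URL and filename are placeholders):

    import os
    import requests

    url = "https://example.com/model.bin"  # placeholder
    path = "model.bin"                     # placeholder

    # Resume from however much is already on disk.
    offset = os.path.getsize(path) if os.path.exists(path) else 0
    headers = {'Range': f'bytes={offset}-'} if offset else {}

    resp = requests.get(url, stream=True, headers=headers)
    if offset and resp.status_code != 206:
        raise ValueError('server ignored the Range header; restart from scratch')

    with open(path, 'ab' if offset else 'wb') as f:  # append when resuming
        for chunk in resp.iter_content(2**20):       # 1 MiB blocks, as in the diff
            f.write(chunk)

The diff resumes from progress_bar.n rather than the on-disk size because its retries happen inside the same open file handle, so no partial file ever needs to be reopened.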
@@ -23,28 +23,20 @@ MODEL_LIB_PATH = file_manager.enter_context(importlib.resources.as_file(
     importlib.resources.files("gpt4all") / "llmodel_DO_NOT_MODIFY" / "build",
 ))
 
 
 def load_llmodel_library():
-    system = platform.system()
+    ext = {"Darwin": "dylib", "Linux": "so", "Windows": "dll"}[platform.system()]
 
-    def get_c_shared_lib_extension():
-        if system == "Darwin":
-            return "dylib"
-        elif system == "Linux":
-            return "so"
-        elif system == "Windows":
-            return "dll"
-        else:
-            raise Exception("Operating System not supported")
-
-    c_lib_ext = get_c_shared_lib_extension()
-
-    llmodel_file = "libllmodel" + "." + c_lib_ext
-
-    llmodel_dir = str(MODEL_LIB_PATH / llmodel_file).replace("\\", r"\\")
-
-    llmodel_lib = ctypes.CDLL(llmodel_dir)
-
-    return llmodel_lib
+    try:
+        # Linux, Windows, MinGW
+        lib = ctypes.CDLL(str(MODEL_LIB_PATH / f"libllmodel.{ext}"))
+    except FileNotFoundError:
+        if ext != 'dll':
+            raise
+        # MSVC
+        lib = ctypes.CDLL(str(MODEL_LIB_PATH / "llmodel.dll"))
+
+    return lib
 
 
 llmodel = load_llmodel_library()
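The MSVC change is a naming-convention fallback: GNU-style toolchains (including MinGW) produce libllmodel.dll / libllmodel.so / libllmodel.dylib, while MSVC emits llmodel.dll without the lib prefix. On Windows (Python 3.8+), ctypes.CDLL raises FileNotFoundError for a missing DLL, so trying one name and falling back to the other is enough. The same idea as a standalone sketch (library name and directory are placeholders):

    import ctypes
    import platform
    from pathlib import Path

    lib_dir = Path("build")  # placeholder: wherever the shared library lives
    ext = {"Darwin": "dylib", "Linux": "so", "Windows": "dll"}[platform.system()]

    try:
        # GNU naming: libfoo.so / libfoo.dylib / libfoo.dll (MinGW)
        lib = ctypes.CDLL(str(lib_dir / f"libfoo.{ext}"))
    except FileNotFoundError:
        if ext != "dll":
            raise  # only Windows has a second naming convention to try
        # MSVC naming: no "lib" prefix
        lib = ctypes.CDLL(str(lib_dir / "foo.dll"))

The dict lookup also replaces the old if/elif chain: an unsupported platform now fails with a KeyError rather than the explicit Exception, which is terser at the cost of a less descriptive error message.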