python: connection resume and MSVC support (#1535)

This commit is contained in:
cebtenzzre 2023-10-19 12:06:38 -04:00 committed by GitHub
parent 017c3a9649
commit 5fbeeb1cb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 37 deletions

View File

@ -11,7 +11,9 @@ from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Union
import requests
from requests.exceptions import ChunkedEncodingError
from tqdm import tqdm
from urllib3.exceptions import IncompleteRead, ProtocolError
from . import pyllmodel
@ -217,35 +219,61 @@ class GPT4All:
download_path = os.path.join(model_path, model_filename).replace("\\", "\\\\")
download_url = get_download_url(model_filename)
response = requests.get(download_url, stream=True)
if response.status_code != 200:
raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
def make_request(offset=None):
    """Open a streaming GET for the model file, optionally resuming mid-file.

    Uses ``download_url`` from the enclosing scope.

    Args:
        offset: Number of bytes already received. When truthy, a ``Range``
            header asks the server to continue from that byte position;
            ``None`` or ``0`` requests the file from the start.

    Returns:
        The streaming response object returned by ``requests.get``.

    Raises:
        ValueError: If the HTTP status is neither 200 nor 206, or if a resume
            was requested and the server did not reply with a 206 whose
            ``Content-Range`` begins at ``offset``.
    """
    headers = {}
    if offset:
        print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
        headers['Range'] = f'bytes={offset}-'  # resume incomplete response
    response = requests.get(download_url, stream=True, headers=headers)
    if response.status_code not in (200, 206):
        response.close()  # streamed responses hold the connection until closed
        raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
    if offset:
        # Content-Range has the form "bytes <first>-<last>/<total>". Compare the
        # first-byte position exactly: a plain substring test would also accept
        # e.g. offset 100 inside "bytes 0-2100/2101" and corrupt the file.
        content_range = ' '.join(response.headers.get('Content-Range', '').split()).lower()
        if response.status_code != 206 or not content_range.startswith(f'bytes {offset}-'):
            response.close()
            raise ValueError('Connection was interrupted and server does not support range requests')
    return response
response = make_request()
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 2**20 # 1 MB
with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
with open(download_path, "wb") as file, \
tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
try:
with open(download_path, "wb") as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
while True:
last_progress = progress_bar.n
try:
for data in response.iter_content(block_size):
file.write(data)
progress_bar.update(len(data))
except ChunkedEncodingError as cee:
if cee.args and isinstance(pe := cee.args[0], ProtocolError):
if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead):
assert progress_bar.n <= ir.partial # urllib3 may be ahead of us but never behind
# the socket was closed during a read - retry
response = make_request(progress_bar.n)
continue
raise
if total_size_in_bytes != 0 and progress_bar.n < total_size_in_bytes:
if progress_bar.n == last_progress:
raise RuntimeError('Download not making progress, aborting.')
# server closed connection prematurely - retry
response = make_request(progress_bar.n)
continue
break
except Exception:
if os.path.exists(download_path):
if verbose:
print("Cleaning up the interrupted download...")
if verbose:
print("Cleaning up the interrupted download...", file=sys.stderr)
try:
os.remove(download_path)
except OSError:
pass
raise
# Validate download was successful
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
raise RuntimeError("An error occurred during download. Downloaded file may not work.")
# Sleep for a little bit so OS can remove file lock
time.sleep(2)
if os.name == 'nt':
time.sleep(2) # Sleep for a little bit so Windows can remove file lock
if verbose:
print("Model downloaded at: ", download_path)
print("Model downloaded at:", download_path, file=sys.stderr)
return download_path
def generate(

View File

@ -23,28 +23,20 @@ MODEL_LIB_PATH = file_manager.enter_context(importlib.resources.as_file(
importlib.resources.files("gpt4all") / "llmodel_DO_NOT_MODIFY" / "build",
))
def load_llmodel_library():
    """Load the bundled llmodel shared library through ctypes.

    The file is looked up in ``MODEL_LIB_PATH``. The name ``libllmodel.<ext>``
    is tried first, with ``<ext>`` chosen from ``platform.system()``. On
    Windows only, a ``FileNotFoundError`` triggers a second attempt with
    ``llmodel.dll`` (the name without the ``lib`` prefix, as built by MSVC).

    Returns:
        The loaded ``ctypes.CDLL`` handle.

    Raises:
        RuntimeError: If ``platform.system()`` is not Darwin, Linux or Windows.
        FileNotFoundError: If the library is missing (on Windows, only after
            both names have been tried).
    """
    system = platform.system()
    extensions = {"Darwin": "dylib", "Linux": "so", "Windows": "dll"}
    try:
        ext = extensions[system]
    except KeyError:
        # A bare KeyError from the lookup would not tell the user what went wrong.
        raise RuntimeError(f"Operating System not supported: {system}") from None

    try:
        # Linux, macOS, MinGW
        lib = ctypes.CDLL(str(MODEL_LIB_PATH / f"libllmodel.{ext}"))
    except FileNotFoundError:
        if ext != 'dll':
            raise
        # MSVC
        lib = ctypes.CDLL(str(MODEL_LIB_PATH / "llmodel.dll"))

    return lib
llmodel = load_llmodel_library()