diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py index 73570f18..1cc66c49 100644 --- a/gpt4all-bindings/python/gpt4all/gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/gpt4all.py @@ -125,16 +125,21 @@ class GPT4All(): download_path = os.path.join(model_path, model_filename).replace("\\", "\\\\") download_url = get_download_url(model_filename) - # TODO: Find good way of safely removing file that got interrupted. response = requests.get(download_url, stream=True) total_size_in_bytes = int(response.headers.get("content-length", 0)) block_size = 2 ** 20 # 1 MB - progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) - with open(download_path, "wb") as file: - for data in response.iter_content(block_size): - progress_bar.update(len(data)) - file.write(data) - progress_bar.close() + + with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar: + try: + with open(download_path, "wb") as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + except Exception: + if os.path.exists(download_path): + print('Cleaning up the interrupted download...') + os.remove(download_path) + raise # Validate download was successful if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: