diff --git a/download-model.py b/download-model.py index 540f94c6..e2a951cb 100644 --- a/download-model.py +++ b/download-model.py @@ -194,18 +194,25 @@ class ModelDownloader: r = self.s.get(url, stream=True, headers=headers, timeout=20) with open(output_path, mode) as f: total_size = int(r.headers.get('content-length', 0)) - block_size = 1024 + # Every 4MB we report an update + block_size = 4*1024*1024 + with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t: + count = 0 for data in r.iter_content(block_size): t.update(len(data)) f.write(data) + if self.progress_bar is not None: + count += len(data) + self.progress_bar(float(count)/float(total_size), f"Downloading {filename}") def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1): thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True) - def download_model_files(self, model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1): + def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar = None, start_from_scratch=False, threads=1): + self.progress_bar = progress_bar # Creating the folder and writing the metadata if not output_folder.exists(): output_folder.mkdir(parents=True, exist_ok=True) diff --git a/server.py b/server.py index 4198dd69..b6699f14 100644 --- a/server.py +++ b/server.py @@ -122,7 +122,7 @@ def count_tokens(text): return 'Couldn\'t count the number of tokens. Is a tokenizer loaded?' -def download_model_wrapper(repo_id): +def download_model_wrapper(repo_id, progress=gr.Progress()): try: downloader_module = importlib.import_module("download-model") downloader = downloader_module.ModelDownloader() @@ -131,6 +131,7 @@ def download_model_wrapper(repo_id): branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main" check = False + progress(0.0) yield ("Cleaning up the model/branch names") model, branch = downloader.sanitize_model_and_branch_names(model, branch) @@ -141,13 +142,16 @@ def download_model_wrapper(repo_id): output_folder = downloader.get_output_folder(model, branch, is_lora) if check: + progress(0.5) yield ("Checking previously downloaded files") downloader.check_model_files(model, branch, links, sha256, output_folder) + progress(1.0) else: yield (f"Downloading files to {output_folder}") - downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1) + downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=1) yield ("Done!") except: + progress(1.0) yield traceback.format_exc() @@ -276,7 +280,7 @@ def create_model_menus(): save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['model_status'], show_progress=False) shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False) - shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False) + shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=True) shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load)