diff --git a/server.py b/server.py index e940d366..6aa69c90 100644 --- a/server.py +++ b/server.py @@ -183,9 +183,9 @@ def count_tokens(text): def download_model_wrapper(repo_id): try: downloader = importlib.import_module("download-model") - - model = repo_id - branch = "main" + repo_id_parts = repo_id.split(":") + model = repo_id_parts[0] if len(repo_id_parts) > 0 else repo_id + branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main" check = False yield ("Cleaning up the model/branch names") @@ -370,7 +370,7 @@ def create_model_menus(): with gr.Row(): with gr.Column(): - shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter Hugging Face username/model path, e.g: facebook/galactica-125m") + shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main") shared.gradio['download_model_button'] = gr.Button("Download") with gr.Column():