diff --git a/README.md b/README.md
index 9b78cf01..c8a76bb1 100644
--- a/README.md
+++ b/README.md
@@ -85,13 +85,15 @@ Then browse to
 Optionally, you can use the following command-line flags:
 
 ```
--h, --help     show this help message and exit
---model MODEL  Name of the model to load by default.
---notebook     Launch the webui in notebook mode, where the output is written
-               to the same text box as the input.
---chat         Launch the webui in chat mode.
---cpu          Use the CPU to generate text.
---listen       Make the webui reachable from your local network.
+-h, --help      show this help message and exit
+--model MODEL   Name of the model to load by default.
+--notebook      Launch the webui in notebook mode, where the output is written to the same text
+                box as the input.
+--chat          Launch the webui in chat mode.
+--cpu           Use the CPU to generate text.
+--auto-devices  Automatically split the model across the available GPU(s) and CPU.
+--load-in-8bit  Load the model with 8-bit precision.
+--listen        Make the webui reachable from your local network.
 ```
 
 ## Presets
diff --git a/server.py b/server.py
index 3f6172a1..f066f418 100644
--- a/server.py
+++ b/server.py
@@ -17,6 +17,8 @@ parser.add_argument('--model', type=str, help='Name of the model to load by defa
 parser.add_argument('--notebook', action='store_true', help='Launch the webui in notebook mode, where the output is written to the same text box as the input.')
 parser.add_argument('--chat', action='store_true', help='Launch the webui in chat mode.')
 parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
+parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
+parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
 parser.add_argument('--listen', action='store_true', help='Make the webui reachable from your local network.')
 args = parser.parse_args()
 loaded_preset = None
@@ -28,23 +30,45 @@ def load_model(model_name):
     print(f"Loading {model_name}...")
     t0 = time.time()
 
-    # Loading the model
-    if not args.cpu and Path(f"torch-dumps/{model_name}.pt").exists():
-        print("Loading in .pt format...")
-        model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
-    elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in model_name.lower() for size in ('13b', '20b', '30b')):
-        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
-    elif model_name in ['flan-t5', 't5-large']:
-        if args.cpu:
-            model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}"))
-        else:
+    # Default settings
+    if not (args.cpu or args.auto_devices or args.load_in_8bit):
+        if Path(f"torch-dumps/{model_name}.pt").exists():
+            print("Loading in .pt format...")
+            model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
+        elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in model_name.lower() for size in ('13b', '20b', '30b')):
+            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
+        elif model_name in ['flan-t5', 't5-large']:
             model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}")).cuda()
-    else:
-        if args.cpu:
-            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float32)
         else:
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
+    # Custom
+    else:
+        settings = ["low_cpu_mem_usage=True"]
+        cuda = ""
+
+        if model_name in ['flan-t5', 't5-large']:
+            command = "T5ForConditionalGeneration.from_pretrained"
+        else:
+            command = "AutoModelForCausalLM.from_pretrained"
+
+        if args.cpu:
+            settings.append("torch_dtype=torch.float32")
+        else:
+            if args.load_in_8bit:
+                settings.append("device_map='auto'")
+                settings.append("load_in_8bit=True")
+            else:
+                settings.append("torch_dtype=torch.float16")
+                if args.auto_devices:
+                    settings.append("device_map='auto'")
+                else:
+                    cuda = ".cuda()"
+
+        settings = ', '.join(settings)
+        command = f"{command}(Path(f'models/{model_name}'), {settings}){cuda}"
+        model = eval(command)
+
 
     # Loading the tokenizer
     if model_name.lower().startswith('gpt4chan') and Path(f"models/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
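
Since the `# Custom` branch builds the `from_pretrained` call as a string and runs it through `eval`, here is a minimal eval-free sketch of the same flag logic for readers tracing the behavior. The `load_model_custom` helper and its `args` parameter are illustrative names, not part of the patch; the keyword arguments mirror the strings appended above.

```python
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, T5ForConditionalGeneration

def load_model_custom(model_name, args):
    # Pick the loader class, as the branch above does via the `command` string.
    if model_name in ['flan-t5', 't5-large']:
        loader = T5ForConditionalGeneration
    else:
        loader = AutoModelForCausalLM

    # Express the string-built settings as real keyword arguments.
    settings = {'low_cpu_mem_usage': True}
    cuda = False
    if args.cpu:
        settings['torch_dtype'] = torch.float32
    elif args.load_in_8bit:
        settings['device_map'] = 'auto'
        settings['load_in_8bit'] = True
    else:
        settings['torch_dtype'] = torch.float16
        if args.auto_devices:
            settings['device_map'] = 'auto'
        else:
            cuda = True  # no device map, so move the whole model to the GPU

    model = loader.from_pretrained(Path(f'models/{model_name}'), **settings)
    return model.cuda() if cuda else model
```

Passing keyword arguments directly would keep the behavior identical while avoiding the quoting and `eval` step.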
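
As a usage sketch, the new flags combine with the existing ones at launch; `opt-13b` below is a hypothetical directory under `models/`:

```
python server.py --model opt-13b --auto-devices
python server.py --model opt-13b --load-in-8bit
```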