diff --git a/modules/models.py b/modules/models.py
index 3ec4cd9d..d639ca65 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -193,15 +193,25 @@ def load_model(model_name):
     if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
     elif type(model) is transformers.LlamaForCausalLM:
-        tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
-        # Leaving this here until the LLaMA tokenizer gets figured out.
-        # For some people this fixes things, for others it causes an error.
-        try:
-            tokenizer.eos_token_id = 2
-            tokenizer.bos_token_id = 1
-            tokenizer.pad_token_id = 0
-        except:
-            pass
+        tokenizer = None
+
+        # Try to load a universal LLaMA tokenizer
+        for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
+            if p.exists():
+                print(f"Loading the universal LLaMA tokenizer from {p}...")
+                tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
+                break
+
+        # Otherwise, load it from the model folder and hope that these
+        # are not outdated tokenizer files.
+        if tokenizer is None:
+            tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
+            try:
+                tokenizer.eos_token_id = 2
+                tokenizer.bos_token_id = 1
+                tokenizer.pad_token_id = 0
+            except:
+                pass
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), trust_remote_code=trust_remote_code)

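
For reference, the fallback behavior this hunk introduces can be sketched as a self-contained function. The model directory, model name, and the function name below are illustrative placeholders, not identifiers from the PR:

from pathlib import Path

from transformers import LlamaTokenizer

MODEL_DIR = Path("models")      # stand-in for shared.args.model_dir
MODEL_NAME = "my-llama-13b"     # stand-in for shared.model_name


def load_llama_tokenizer():
    # Shared "universal" tokenizer directories, checked in priority order.
    # If either exists, it is used for every LLaMA-family model.
    candidates = [
        MODEL_DIR / "llama-tokenizer",
        MODEL_DIR / "oobabooga_llama-tokenizer",
    ]
    for p in candidates:
        if p.exists():
            print(f"Loading the universal LLaMA tokenizer from {p}...")
            return LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)

    # Fallback: the tokenizer files shipped with the model itself, which
    # may come from an outdated conversion. Patch the special token IDs
    # to the values the original LLaMA vocabulary uses.
    tokenizer = LlamaTokenizer.from_pretrained(MODEL_DIR / MODEL_NAME, clean_up_tokenization_spaces=True)
    try:
        tokenizer.eos_token_id = 2
        tokenizer.bos_token_id = 1
        tokenizer.pad_token_id = 0
    except Exception:
        # Some tokenizer versions reject direct ID assignment; in that
        # case the files are used as-is, matching the old behavior.
        pass
    return tokenizer

The design point is that one known-good copy of the LLaMA tokenizer can serve every converted LLaMA checkpoint, so the fragile per-model try/except ID patching now only runs as a last resort instead of on every load.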