Bump transformers (16-bit llama must be reconverted/redownloaded)

oobabooga 2023-04-06 16:04:03 -03:00
parent 5f4f38ca5d
commit 113f94b61e
3 changed files with 8 additions and 2 deletions

modules/models.py

@@ -10,7 +10,7 @@ import torch
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
-                          BitsAndBytesConfig)
+                          BitsAndBytesConfig, LlamaTokenizer)
 
 import modules.shared as shared
@@ -172,6 +172,8 @@ def load_model(model_name):
     # Loading the tokenizer
     if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
+    elif type(model) is transformers.LlamaForCausalLM:
+        tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
     tokenizer.truncation_side = 'left'
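
The practical effect of this change: load_model previously went through AutoTokenizer for every architecture, while it now routes llama checkpoints to the slow SentencePiece-backed LlamaTokenizer and sets clean_up_tokenization_spaces=True explicitly instead of relying on the default. A minimal standalone sketch of the same dispatch, assuming a hypothetical models/llama-7b directory containing converted HF llama weights:

from pathlib import Path

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer

model_dir = Path("models/llama-7b")  # hypothetical local path to converted weights

# Load the model first, then pick the tokenizer based on its class,
# mirroring the branch added to load_model() above.
model = AutoModelForCausalLM.from_pretrained(model_dir)

if type(model) is transformers.LlamaForCausalLM:
    tokenizer = LlamaTokenizer.from_pretrained(model_dir, clean_up_tokenization_spaces=True)
else:
    tokenizer = AutoTokenizer.from_pretrained(model_dir)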

modules/text_generation.py

@@ -28,6 +28,10 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
         return input_ids
     else:
         input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
+
+        if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
+            input_ids = input_ids[:,1:]
+
     if shared.args.cpu:
         return input_ids
     elif shared.args.flexgen:
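
A note on the magic number: in the llama SentencePiece vocabulary, id 29871 is the bare "▁" (space) piece. After the tokenizer rework in upstream transformers, encoding a prompt can come back with that dummy space as the very first id when special tokens are disabled, so the webui strips it before generation. A rough illustration, again with a hypothetical local model path:

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("models/llama-7b")  # hypothetical path

input_ids = tokenizer.encode("Hello", return_tensors='pt', add_special_tokens=False)

# If the first id is the lone "▁" space piece (29871), drop it, mirroring
# the check added to encode() above.
if input_ids[0][0] == 29871:
    input_ids = input_ids[:, 1:]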

requirements.txt

@@ -13,4 +13,4 @@ safetensors==0.3.0
 sentencepiece
 pyyaml
 tqdm
-git+https://github.com/huggingface/transformers@9eae4aa57650c1dbe1becd4e0979f6ad1e572ac0
+git+https://github.com/huggingface/transformers
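
Unpinning the transformers dependency means pip now installs from the tip of main, which is what triggers the warning in the commit title: 16-bit llama weights converted with the older upstream script must be reconverted with the current convert_llama_weights_to_hf.py from the transformers repository, or redownloaded in already-converted form. Existing installs pick up the change with something like pip install -r requirements.txt --upgrade.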