2023-10-24 09:28:21 -04:00
|
|
|
#!/usr/bin/env python3
|
2023-03-27 17:50:08 -04:00
|
|
|
import json
|
|
|
|
import torch
|
2023-03-27 20:09:47 -04:00
|
|
|
import pickle
|
2023-03-27 17:50:08 -04:00
|
|
|
import numpy as np
|
2023-03-27 20:09:47 -04:00
|
|
|
from tqdm import tqdm
|
2023-03-27 17:50:08 -04:00
|
|
|
from read import read_config
|
|
|
|
from argparse import ArgumentParser
|
|
|
|
from peft import PeftModelForCausalLM
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
2023-03-28 14:47:38 -04:00
|
|
|
'''
|
|
|
|
Evaluates model perplexity on the reference outputs in:
|
|
|
|
https://github.com/yizhongw/self-instruct/blob/main/human_eval/user_oriented_instructions.jsonl
|
|
|
|
'''
|
|
|
|
|
2023-03-27 17:50:08 -04:00
|
|
|
def read_jsonl_file(file_path):
    """Parse a JSON Lines file into a list of decoded objects.

    Each non-blank line is decoded independently with ``json.loads``.
    Blank or whitespace-only lines (e.g. stray empty lines between records)
    are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        file_path: path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        list of the decoded JSON values, in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        # json.loads tolerates the trailing newline, so no explicit strip is needed
        # for decoding; strip() is only used to detect blank lines.
        return [json.loads(line) for line in file if line.strip()]
|
|
|
|
|
|
|
|
def setup_model(config):
    """Load the causal LM and tokenizer described by ``config``.

    The base model is loaded in float16 with ``device_map="auto"``. The
    tokenizer gets ``<s>``/``</s>``/``<pad>`` registered as special tokens,
    and the embedding matrix is resized if any of them were new. When the
    config enables LoRA, the adapter at ``config["lora_path"]`` is wrapped
    around the base model and the result is cast to float16.

    Args:
        config: mapping with "model_name" and "tokenizer_name", plus optional
            "lora" (truthy to enable) and "lora_path".

    Returns:
        (model, tokenizer) tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        config["model_name"],
        device_map="auto",
        torch_dtype=torch.float16,
        output_hidden_states=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])

    special_tokens = {"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>"}
    num_new_tokens = tokenizer.add_special_tokens(special_tokens)
    if num_new_tokens > 0:
        # New ids were appended to the vocab; grow the embeddings to match.
        model.resize_token_embeddings(len(tokenizer))

    wants_lora = 'lora' in config and config['lora']
    if wants_lora:
        model = PeftModelForCausalLM.from_pretrained(
            model,
            config["lora_path"],
            device_map="auto",
            torch_dtype=torch.float16,
            return_hidden_states=True,
        )
        model.to(dtype=torch.float16)

    footprint_gb = model.get_memory_footprint() / 1024 / 1024 / 1024
    print(f"Mem needed: {footprint_gb:.2f} GB")

    return model, tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def eval_example(model, tokenizer, example, config, *, window=512):
    """Return the perplexity of an example's reference output given its prompt.

    The prompt is ``instruction + ' ' + instances[0]['input']`` and the
    reference text is ``prompt + ' ' + instances[0]['output']``. The whole
    reference sequence (prompt included) is fed to the model so that output
    tokens are scored with the prompt as context. Prompt positions are set to
    ``-100`` in the labels so they do not contribute to the loss.

    Sequences longer than ``window`` tokens are scored in consecutive,
    non-overlapping windows. The first token of each window has no in-window
    context and is therefore not scored.

    Args:
        model: causal LM called as ``model(input_ids, labels=labels)``.
            Assumes the Hugging Face convention: labels are shifted by one
            internally, ``-100`` is ignored, and ``.loss`` is the mean over
            the remaining targets.
        tokenizer: callable ``tokenizer(text, return_tensors="pt")`` returning
            a mapping whose ``'input_ids'`` has shape (1, seq_len).
        example: dict with 'instruction' and a non-empty 'instances' list whose
            first element has 'input' and 'output'.
        config: unused; kept for interface compatibility with callers.
        window: maximum number of tokens per forward pass.

    Returns:
        Perplexity as a float, or ``float('nan')`` when the output contributes
        no scorable tokens (e.g. an empty output).
    """
    prompt = example['instruction'] + ' ' + example['instances'][0]['input']
    gt = prompt + ' ' + example['instances'][0]['output']

    # Only the prompt's length is needed, to know which positions to mask.
    # NOTE(review): assumes the prompt's tokens are a prefix of the gt
    # tokenization (true unless the joining ' ' merges across the boundary).
    prompt_len = tokenizer(prompt, return_tensors="pt")['input_ids'].size(1)
    gt_ids = tokenizer(gt, return_tensors="pt")['input_ids'].to(model.device)
    seq_len = gt_ids.size(1)

    nlls = []
    total_scored = 0
    for begin_loc in range(0, seq_len, window):
        end_loc = min(begin_loc + window, seq_len)
        input_ids = gt_ids[:, begin_loc:end_loc]
        target_ids = input_ids.clone()
        # Mask whatever part of the prompt falls inside this window.
        target_ids[:, :max(0, prompt_len - begin_loc)] = -100

        # Labels are shifted inside the model, so label position 0 of the
        # window is never a target; count only the real targets.
        n_scored = int((target_ids[:, 1:] != -100).sum().item())
        if n_scored == 0:
            continue

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)

        # .loss is a mean over targets; convert back to a sum (in float32 to
        # avoid fp16 overflow) so windows combine into a token-level average.
        nlls.append(outputs.loss.float() * n_scored)
        total_scored += n_scored

    if total_scored == 0:
        ppl = float('nan')
    else:
        ppl = torch.exp(torch.stack(nlls).sum() / total_scored).item()
    print('ppl: ', ppl)

    print(prompt)
    print(80*'-')

    return ppl
|
2023-03-27 17:50:08 -04:00
|
|
|
|
|
|
|
def do_eval(config):
    """Score every evaluation example and pickle the per-example perplexities.

    Reads ``eval_data/user_oriented_instructions.jsonl``, loads the model and
    tokenizer via ``setup_model``, and computes one perplexity per example with
    ``eval_example``. The result ``{'perplexities': [...]}`` is written to a
    pickle under ``eval_data/`` whose name encodes the model name and, when
    LoRA is enabled, the adapter path (with '/' replaced by '_').

    Args:
        config: mapping with "model_name" and optional "lora"/"lora_path",
            plus whatever ``setup_model`` reads.
    """
    eval_data = read_jsonl_file('eval_data/user_oriented_instructions.jsonl')
    model, tokenizer = setup_model(config)

    all_perplexities = [
        eval_example(model, tokenizer, example, config)
        for example in tqdm(eval_data)
    ]

    # Same LoRA check as setup_model, so a config without a 'lora' key does
    # not raise KeyError here after the whole evaluation has already run.
    use_lora = 'lora' in config and config['lora']
    lora_suffix = '__lora-' + config['lora_path'].replace('/', '_') if use_lora else ''
    name = f"eval_data/eval__model-{config['model_name'].replace('/', '_')}{lora_suffix}.pkl"

    with open(name, 'wb') as f:
        r = {'perplexities': all_perplexities}
        pickle.dump(r, f)
|
2023-03-27 17:50:08 -04:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: a single required --config path is loaded
    # with read_config and handed to do_eval.
    cli = ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    cli_args = cli.parse_args()

    do_eval(read_config(cli_args.config))
|