metrics run on configs now
commit 4e8e7e7300 (parent d5769d7614)
@@ -5,7 +5,7 @@ lora: true
 lora_path: "nomic-ai/vicuna-lora-1024"

 max_new_tokens: 512
-temperature: .25
+temperature: 0.001
 prompt: |
   #this code prints a string reversed
   my_string = "hello how are you"

@@ -4,7 +4,7 @@ tokenizer_name: "zpn/llama-7b"


 max_new_tokens: 512
-temperature: 0
+temperature: 0.001
 prompt: |
   #this code prints a string reversed
   my_string = "hello how are you"

@@ -4,7 +4,7 @@ tokenizer_name: "zpn/llama-7b"


 max_new_tokens: 512
-temperature: 0
+temperature: 0.001
 prompt: |
   #this code prints a string reversed
   my_string = "hello how are you"

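[Editor's note] All three configs move temperature to a small positive value. A likely reason, stated here as an assumption rather than taken from the commit message: the eval script below generates with do_sample=True, and Hugging Face generate divides logits by the temperature when sampling, so a value of exactly 0 is invalid and .25 samples more freely than intended; 0.001 keeps sampling enabled while behaving almost greedily. A minimal sketch, using a tiny placeholder model rather than the repo's checkpoints:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

config = {"max_new_tokens": 32, "temperature": 0.001}  # the repo configs use max_new_tokens: 512
tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # placeholder model for illustration
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tokenizer("#this code prints a string reversed", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs,
                         max_new_tokens=config["max_new_tokens"],
                         temperature=config["temperature"],  # must be strictly positive when do_sample=True
                         do_sample=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))
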
@@ -1,6 +1,8 @@
 import json
 import torch
+import pickle
+import numpy as np
 from tqdm import tqdm
 from read import read_config
 from argparse import ArgumentParser
 from peft import PeftModelForCausalLM

@@ -22,7 +24,7 @@ def setup_model(config):
     if added_tokens > 0:
         model.resize_token_embeddings(len(tokenizer))

-    if config["lora"]:
+    if 'lora' in config and config['lora']:
         model = PeftModelForCausalLM.from_pretrained(model, config["lora_path"], device_map="auto", torch_dtype=torch.float16, return_hidden_states=True)
         model.to(dtype=torch.float16)

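[Editor's note] The new guard makes the LoRA branch safe for configs that omit the lora key entirely, instead of raising a KeyError. A tiny sketch with made-up config dicts (not from the repo):

cfg_no_lora = {"model_name": "zpn/llama-7b"}                 # 'lora' key absent
cfg_lora = {"model_name": "zpn/llama-7b", "lora": True}

for cfg in (cfg_no_lora, cfg_lora):
    use_lora = 'lora' in cfg and cfg['lora']                 # guard introduced by this commit
    print(use_lora)                                          # falsy for the first config, True for the second
# For boolean flags, cfg.get('lora', False) is an equivalent idiom.
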
@@ -33,10 +35,8 @@ def setup_model(config):

-
-

 def eval_example(model, tokenizer, example, config):

     #set up data
     prompt = example['instruction'] + ' ' + example['instances'][0]['input']
     gt = prompt + ' ' + example['instances'][0]['output']

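[Editor's note] eval_example builds the prompt and the ground-truth string from a self-instruct style record with an 'instruction' field and an 'instances' list of input/output pairs, as the two lines above show. A sketch with a made-up record (the field layout matches the code; the text is invented):

example = {
    "instruction": "Reverse the given string.",
    "instances": [{"input": "hello how are you", "output": "uoy era woh olleh"}],
}

prompt = example['instruction'] + ' ' + example['instances'][0]['input']
gt = prompt + ' ' + example['instances'][0]['output']
print(prompt)  # what the model is asked to continue
print(gt)      # prompt plus the reference answer, scored later for perplexity
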
@@ -45,29 +45,40 @@ def eval_example(model, tokenizer, example, config):
     input = {k: v.to(model.device) for k, v in input.items()}

     continuations = []
+    tokenized_continuations = []
     trajectories = []
-    for i in range(5):
-        print(i)
-        outputs = model.generate(input_ids=input['input_ids'],
-                                 max_new_tokens=config["max_new_tokens"],
-                                 temperature=config["temperature"])
+    for i in range(3):
+        with torch.no_grad():
+            outputs = model.generate(input_ids=input['input_ids'],
+                                     max_new_tokens=config["max_new_tokens"],
+                                     min_new_tokens=5,
+                                     temperature=config["temperature"],
+                                     repetition_penalty=1.0,
+                                     do_sample=True)
         decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

         y = model(input_ids=outputs)
         trajectory = y.hidden_states[0].detach().cpu().numpy()[0]
         trajectory = trajectory / np.linalg.norm(trajectory, axis=1, keepdims=True)
         trajectory = np.cumsum(trajectory, axis=0) / np.arange(1, trajectory.shape[0]+1).reshape(-1, 1)

         trajectories.append(trajectory)
-        continuations.append(decoded[len(prompt):])
+        continuations.append(decoded)
+        tokenized_continuations.append(tokenizer.tokenize(decoded))

     #compute the ground truth perplexity
     gt_input = tokenizer(gt, return_tensors="pt")
     gt_input = {k: v.to(model.device) for k, v in gt_input.items()}

     nlls = []
     prev_end_loc = 0
-    for begin_loc in tqdm(range(len(prompt), len(gt), 1)):
-        end_loc = min(begin_loc + max_length, seq_len)
+    stride = 512
+    seq_len = gt_input['input_ids'].size(1)
+
+    for begin_loc in tqdm(range(input['input_ids'].size(1), gt_input['input_ids'].size(1), stride)):
+        end_loc = min(begin_loc + stride, seq_len)
         trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
-        input_ids = input['input_ids'][:, begin_loc:end_loc].to(model.device)
+        input_ids = gt_input['input_ids'][:, begin_loc:end_loc].to(model.device)
         target_ids = input_ids.clone()
         target_ids[:, :-trg_len] = -100

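[Editor's note] The "trajectory" built in the loop above is, reading the code, a running average of L2-normalized hidden states over the generated sequence. A self-contained numpy sketch with random stand-in data (the shapes are illustrative; in the script the matrix comes from y.hidden_states[0] of the forward pass):

import numpy as np

hidden = np.random.randn(10, 4096)                               # (seq_len, hidden_dim) stand-in data
hidden = hidden / np.linalg.norm(hidden, axis=1, keepdims=True)  # unit-normalize each token's state
trajectory = np.cumsum(hidden, axis=0) / np.arange(1, hidden.shape[0] + 1).reshape(-1, 1)
print(trajectory.shape)  # (10, 4096): running mean of the normalized states up to each position
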
@@ -80,26 +91,36 @@ def eval_example(model, tokenizer, example, config):
         if end_loc == seq_len:
             break

-    ppl = torch.exp(torch.stack(nlls).sum() / end_loc)
-    print('ppl: ', ppl)
-    print('trajectories: ', trajectories)
-    print('continuations: ', continuations)
-    print(prompt)
-    print(80*'-')
-    for continuation in continuations:
-        print(continuation)
-        print(80*'-')
-
-    raise
+    ppl = torch.exp(torch.stack(nlls).sum() / end_loc).item()
+
+    print('perplexity: ', ppl)

-    return ppl, trajectories, continuations
+    return ppl, trajectories, continuations, tokenized_continuations


 def do_eval(config):
     eval_data = read_jsonl_file('eval_data/user_oriented_instructions.jsonl')
     model, tokenizer = setup_model(config)
-    trajectories = []
-    perplexities = []
-    continuations = []
-    for example in eval_data:
-        gt_perplexity, trajectories, continuations = eval_example(model, tokenizer, example, config)
+
+    all_trajectories = []
+    all_perplexities = []
+    all_continuations = []
+    all_tokenized_continuations = []
+    for example in tqdm(eval_data):
+        gt_perplexity, trajectories, continuations, tokenized_continuations = eval_example(model, tokenizer, example, config)
+        all_trajectories.append(trajectories)
+        all_perplexities.append(gt_perplexity)
+        all_continuations.append(continuations)
+
+    with open('eval_data/eval__model-{}__lora-{}.pkl'.format(config['model_name'].replace('/', '_'), config['lora_path'].replace('/', '_')), 'wb') as f:
+        r = {'trajectories': all_trajectories,
+             'perplexities': all_perplexities,
+             'continuations': all_continuations,
+             'tokenized_continuations': all_tokenized_continuations}
+        pickle.dump(r, f)


 if __name__ == '__main__':
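[Editor's note] The perplexity block follows the usual sliding-window recipe: walk over the ground-truth tokens in fixed strides, mask everything outside the current window's new tokens with -100 so only they contribute to the loss, accumulate negative log-likelihoods, and exponentiate the average. The middle of the loop is not visible in this diff, so the accumulation below (loss scaled by the window length) is an assumption about what it does, and the tiny model name is a placeholder used only to make the sketch runnable. Unlike the commit, which starts scoring at the prompt length, this sketch starts at token 0.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

gt_input = tokenizer("hello how are you " * 50, return_tensors="pt")
stride = 32
seq_len = gt_input['input_ids'].size(1)

nlls, prev_end_loc = [], 0
for begin_loc in range(0, seq_len, stride):
    end_loc = min(begin_loc + stride, seq_len)
    trg_len = end_loc - prev_end_loc          # may be shorter than stride on the last window
    input_ids = gt_input['input_ids'][:, begin_loc:end_loc]
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100           # mask tokens already scored (a no-op when windows do not overlap, as here)
    with torch.no_grad():
        loss = model(input_ids, labels=target_ids).loss
    nlls.append(loss * trg_len)               # assumed accumulation; the hunk does not show this line
    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

ppl = torch.exp(torch.stack(nlls).sum() / end_loc).item()
print('perplexity:', ppl)
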
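[Editor's note] do_eval now aggregates per-example results and pickles them to a path derived from the model and LoRA names. Note that, as far as this hunk shows, all_tokenized_continuations is saved but never appended to inside the loop, so that list will be empty in the output file. A short sketch of reading such a file back; the path below only illustrates the naming pattern and is not a file that necessarily exists:

import pickle

# Path pattern from the commit: eval_data/eval__model-{model}__lora-{lora}.pkl
path = 'eval_data/eval__model-my_model__lora-my_lora.pkl'  # hypothetical example path

with open(path, 'rb') as f:
    r = pickle.load(f)

print(r.keys())                # trajectories, perplexities, continuations, tokenized_continuations
print(len(r['perplexities']))  # one ground-truth perplexity per eval example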