Clear the torch cache while evaluating

This commit is contained in:
oobabooga 2023-10-16 10:52:50 -07:00
parent 388d1864a6
commit 2d44adbb76

View File

@@ -7,7 +7,7 @@ from datasets import load_dataset
from tqdm import tqdm
from modules import shared
-from modules.models import load_model, unload_model
+from modules.models import clear_torch_cache, load_model, unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.text_generation import encode
@@ -97,7 +97,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
input_ids = encodings[:, begin_loc:end_loc]
target_ids = input_ids.clone()
target_ids[:, :-trg_len] = -100
+clear_torch_cache()
with torch.no_grad():
outputs = shared.model(input_ids=input_ids, labels=target_ids)
@@ -107,7 +107,6 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
neg_log_likelihood = outputs.loss
nlls.append(neg_log_likelihood)
prev_end_loc = end_loc
if end_loc == seq_len:
break