diff --git a/modules/evaluate.py b/modules/evaluate.py
index 61e30261..866d7f90 100644
--- a/modules/evaluate.py
+++ b/modules/evaluate.py
@@ -82,7 +82,13 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
         yield cumulative_log + "Tokenizing the input dataset...\n\n"
         encodings = encode(text, add_special_tokens=False)
         seq_len = encodings.shape[1]
-        max_length = _max_length or shared.model.config.max_position_embeddings
+        if _max_length:
+            max_length = _max_length
+        elif hasattr(shared.model.config, 'max_position_embeddings'):
+            max_length = shared.model.config.max_position_embeddings
+        else:
+            max_length = 2048
+
         nlls = []
         prev_end_loc = 0
         for begin_loc in tqdm(range(0, seq_len, stride)):