diff --git a/modules/text_generation.py b/modules/text_generation.py
index 1f6a2819..4b6de6a7 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -1,3 +1,4 @@
+import gc
 import re
 import time
 
@@ -73,7 +74,9 @@ def formatted_outputs(reply, model_name):
     return reply
 
 def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
-    torch.cuda.empty_cache()
+    gc.collect()
+    if not shared.args.cpu:
+        torch.cuda.empty_cache()
 
     original_question = question
     if not (shared.args.chat or shared.args.cai_chat):