diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index b4f4dac7..55c84ad5 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -123,20 +123,14 @@ def find_quantized_model_file(model_name): # If the model hasn't been found with a well-behaved name, pick the last .pt # or the last .safetensors found in its folder as a last resort if not pt_path: - found_pts = list(path_to_model.glob("*.pt")) - found_safetensors = list(path_to_model.glob("*.safetensors")) - pt_path = None + for ext in ['.pt', '.safetensors']: + found = list(path_to_model.glob(f"*{ext}")) + if len(found) > 0: + if len(found) > 1: + logging.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.') - if len(found_pts) > 0: - if len(found_pts) > 1: - logging.warning('More than one .pt model has been found. The last one will be selected. It could be wrong.') - - pt_path = found_pts[-1] - elif len(found_safetensors) > 0: - if len(found_safetensors) > 1: - logging.warning('More than one .safetensors model has been found. The last one will be selected. It could be wrong.') - - pt_path = found_safetensors[-1] + pt_path = found[-1] + break return pt_path