Simplify GPTQ_loader.py

This commit is contained in:
oobabooga 2023-05-17 16:22:56 -03:00
parent ef10ffc6b4
commit b667ffa51d

View File

@@ -123,20 +123,14 @@ def find_quantized_model_file(model_name):
 
     # If the model hasn't been found with a well-behaved name, pick the last .pt
     # or the last .safetensors found in its folder as a last resort
     if not pt_path:
-        found_pts = list(path_to_model.glob("*.pt"))
-        found_safetensors = list(path_to_model.glob("*.safetensors"))
-        pt_path = None
-
-        if len(found_pts) > 0:
-            if len(found_pts) > 1:
-                logging.warning('More than one .pt model has been found. The last one will be selected. It could be wrong.')
-
-            pt_path = found_pts[-1]
-        elif len(found_safetensors) > 0:
-            if len(found_safetensors) > 1:
-                logging.warning('More than one .safetensors model has been found. The last one will be selected. It could be wrong.')
-
-            pt_path = found_safetensors[-1]
+        for ext in ['.pt', '.safetensors']:
+            found = list(path_to_model.glob(f"*{ext}"))
+            if len(found) > 0:
+                if len(found) > 1:
+                    logging.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
+
+                pt_path = found[-1]
+                break
 
     return pt_path