Don't show .pt models in the list

This commit is contained in:
oobabooga 2023-03-09 21:54:50 -03:00
parent 1a3d25f75d
commit 9849aac0f1
2 changed files with 4 additions and 1 deletions

View File

@ -105,6 +105,9 @@ def load_model(model_name):
if not Path(f"models/{pt_model}").exists(): if not Path(f"models/{pt_model}").exists():
print(f"Could not find models/{pt_model}, exiting...") print(f"Could not find models/{pt_model}, exiting...")
exit() exit()
elif pt_model == '':
print(f"Could not find the .pt model for {model_name}, exiting...")
exit()
model = load_quant(path_to_model, Path(f"models/{pt_model}"), 4) model = load_quant(path_to_model, Path(f"models/{pt_model}"), 4)
model = model.to(torch.device('cuda:0')) model = model.to(torch.device('cuda:0'))

View File

@ -37,7 +37,7 @@ def get_available_models():
if shared.args.flexgen: if shared.args.flexgen:
return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower) return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
else: else:
return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower) return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt'))], key=str.lower)
def get_available_presets(): def get_available_presets():
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)