mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Print the softprompt metadata when it is loaded
This commit is contained in:
parent
f79805f4a4
commit
8c9dd95d55
13
server.py
13
server.py
@ -173,7 +173,19 @@ def load_soft_prompt(name):
|
||||
else:
|
||||
with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
|
||||
zf.extract('tensor.npy')
|
||||
zf.extract('meta.json')
|
||||
j = json.loads(open('meta.json', 'r').read())
|
||||
print(f"\nLoading the softprompt \"{name}\".")
|
||||
for field in j:
|
||||
if field != 'name':
|
||||
if type(j[field]) is list:
|
||||
print(f"{field}: {', '.join(j[field])}")
|
||||
else:
|
||||
print(f"{field}: {j[field]}")
|
||||
print()
|
||||
tensor = np.load('tensor.npy')
|
||||
Path('tensor.npy').unlink()
|
||||
Path('meta.json').unlink()
|
||||
tensor = torch.Tensor(tensor).to(device=model.device, dtype=model.dtype)
|
||||
tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
|
||||
|
||||
@ -187,6 +199,7 @@ def upload_soft_prompt(file):
|
||||
zf.extract('meta.json')
|
||||
j = json.loads(open('meta.json', 'r').read())
|
||||
name = j['name']
|
||||
Path('meta.json').unlink()
|
||||
|
||||
with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
|
||||
f.write(file)
|
||||
|
Loading…
Reference in New Issue
Block a user