Fix LoRA weight merging

Kevin Kwok 2023-03-16 00:50:24 -07:00 committed by GitHub
parent b8c32be806
commit dde89950f3

@@ -21,7 +21,12 @@ lora_model = PeftModel.from_pretrained(
     torch_dtype=torch.float16,
 )
 
-lora_model.eval()
+# merge weights
+for layer in lora_model.base_model.model.model.layers:
+    layer.self_attn.q_proj.merge_weights = True
+    layer.self_attn.v_proj.merge_weights = True
+
+lora_model.train(False)
 
 lora_model_sd = lora_model.state_dict()
 
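
For context: in the peft API of this era, setting merge_weights = True arms a LoRA Linear layer so that the next train(False) call folds the low-rank update into the frozen base weight, W ← W + (α/r)·B·A. Calling eval() alone did not merge because merge_weights defaulted to False, which is what this fix works around. A minimal sketch of the merge arithmetic with toy shapes (illustrative names, not the library source):

    import torch

    torch.manual_seed(0)
    in_f, out_f, r, alpha = 16, 16, 4, 8
    scaling = alpha / r  # LoRA scaling factor

    W = torch.randn(out_f, in_f)      # frozen base weight
    A = torch.randn(r, in_f) * 0.01   # lora_A: (r, in_features)
    B = torch.randn(out_f, r) * 0.01  # lora_B: (out_features, r)
    x = torch.randn(2, in_f)

    # Unmerged forward: base path plus the low-rank adapter path.
    y_unmerged = x @ W.T + (x @ A.T @ B.T) * scaling

    # Merged forward: fold the adapter into the base weight once.
    W_merged = W + (B @ A) * scaling
    y_merged = x @ W_merged.T

    assert torch.allclose(y_unmerged, y_merged, atol=1e-5)

Because the two forward passes are equivalent, taking state_dict() after train(False) yields plain weights with the adapters baked in, ready to export as an ordinary checkpoint.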