Fix a warning (#186)

Avoids the warning:
"Overriding torch_dtype=None with `torch_dtype=torch.float16` due to requirements of `bitsandbytes` to enable model loading in mixed int8. Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning."
Angainor Development 2023-03-27 21:13:35 +02:00 committed by GitHub
parent dbd04f3560
commit 69b9d9ea8b


@@ -108,6 +108,7 @@ def train(
     model = LlamaForCausalLM.from_pretrained(
         base_model,
         load_in_8bit=True,
+        torch_dtype=torch.float16,
         device_map=device_map,
     )
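
For reference, a minimal sketch of the full loading call after this change. It assumes `transformers` with `bitsandbytes` and `accelerate` installed and a CUDA device; the checkpoint name and `device_map="auto"` are illustrative stand-ins for the script's `base_model` and `device_map` variables:

    import torch
    from transformers import LlamaForCausalLM

    # Mixed-int8 loading via bitsandbytes runs the non-quantized modules
    # in float16, so transformers warns unless torch_dtype=torch.float16
    # is passed explicitly, which is the line this commit adds.
    model = LlamaForCausalLM.from_pretrained(
        "decapoda-research/llama-7b-hf",  # hypothetical checkpoint name
        load_in_8bit=True,                # requires bitsandbytes and a CUDA GPU
        torch_dtype=torch.float16,        # suppresses the dtype-override warning
        device_map="auto",
    )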