mirror of https://github.com/tloen/alpaca-lora.git
Fix a warning (#186)
Avoids the following warning: "Overriding torch_dtype=None with `torch_dtype=torch.float16` due to requirements of `bitsandbytes` to enable model loading in mixed int8. Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning."
parent dbd04f3560
commit 69b9d9ea8b
@@ -108,6 +108,7 @@ def train(
     model = LlamaForCausalLM.from_pretrained(
         base_model,
         load_in_8bit=True,
+        torch_dtype=torch.float16,
         device_map=device_map,
     )

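For context, a minimal standalone sketch of the patched load call. The model name and device_map value below are hypothetical placeholders; in the actual script both come from the train() arguments:

import torch
from transformers import LlamaForCausalLM

# bitsandbytes forces float16 when loading in mixed int8, so passing
# torch_dtype=torch.float16 explicitly matches that behavior and
# silences the "Overriding torch_dtype=None" warning.
model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",  # hypothetical base_model value
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",  # hypothetical device_map value
)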