diff --git a/finetune.py b/finetune.py index 883a663..67f122c 100644 --- a/finetune.py +++ b/finetune.py @@ -108,6 +108,7 @@ def train( model = LlamaForCausalLM.from_pretrained( base_model, load_in_8bit=True, + torch_dtype=torch.float16, device_map=device_map, )