From 69b9d9ea8ba5b69c0a90485addac40a00105ca90 Mon Sep 17 00:00:00 2001
From: Angainor Development <54739135+AngainorDev@users.noreply.github.com>
Date: Mon, 27 Mar 2023 21:13:35 +0200
Subject: [PATCH] Fix a warning (#186)

Avoids the "Overriding torch_dtype=None with `torch_dtype=torch.float16`
due to requirements of `bitsandbytes` to enable model loading in mixed
int8. Either pass torch_dtype=torch.float16 or don't pass this argument
at all to remove this warning." warning
---
 finetune.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/finetune.py b/finetune.py
index 883a663..67f122c 100644
--- a/finetune.py
+++ b/finetune.py
@@ -108,6 +108,7 @@ def train(
     model = LlamaForCausalLM.from_pretrained(
         base_model,
         load_in_8bit=True,
+        torch_dtype=torch.float16,
         device_map=device_map,
     )
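
For reference, a minimal sketch of the patched loading call, assuming the Hugging Face
transformers and bitsandbytes setup that finetune.py already relies on; the base model
path and device map below are placeholder values, not part of this patch:

    import torch
    from transformers import LlamaForCausalLM

    base_model = "path/to/llama-base"  # placeholder checkpoint, not from the patch
    device_map = "auto"                # placeholder; finetune.py passes its own device_map

    # Passing torch_dtype=torch.float16 explicitly (rather than leaving it None)
    # matches what bitsandbytes forces for 8-bit loading anyway, so the
    # "Overriding torch_dtype=None" warning is no longer emitted.
    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )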