Fix torch.compile call on windows (#81)

* Windows not support compile

* Fix code style
This commit is contained in:
Kohaku-Blueleaf 2023-03-20 11:16:02 +08:00 committed by GitHub
parent 81eb72f707
commit 450206caaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,5 @@
import os
import sys
import torch
import torch.nn as nn
@ -195,7 +196,7 @@ model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
if torch.__version__ >= "2":
if torch.__version__ >= "2" and sys.platform != 'win32':
model = torch.compile(model)
trainer.train()