Fix linters (#185)

* install isort

* isort .

* whoops

* fix black
This commit is contained in:
Eric J. Wang 2023-03-27 14:34:23 -04:00 committed by GitHub
parent 69b31e0fed
commit dbd04f3560
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 13 deletions

View File

@ -25,10 +25,10 @@ jobs:
python-version: 3.8
- name: Install Python dependencies
run: pip install black black[jupyter] flake8
run: pip install black black[jupyter] flake8 isort
- name: lint isort
run: isort --check --diff
run: isort --check --diff .
- name: lint black
run: black --check --diff
run: black --check --diff .

View File

@ -54,8 +54,8 @@ def train(
# wandb params
wandb_project: str = "",
wandb_run_name: str = "",
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
):
print(
@ -94,16 +94,16 @@ def train(
gradient_accumulation_steps = gradient_accumulation_steps // world_size
# Check if parameter passed or if set within environ
use_wandb = len(wandb_project) > 0 or \
("WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
use_wandb = len(wandb_project) > 0 or (
"WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
)
# Only overwrite environ if wandb param passed
if len(wandb_project) > 0:
os.environ['WANDB_PROJECT'] = wandb_project
if len(wandb_project) > 0:
os.environ["WANDB_PROJECT"] = wandb_project
if len(wandb_watch) > 0:
os.environ['WANDB_WATCH'] = wandb_watch
os.environ["WANDB_WATCH"] = wandb_watch
if len(wandb_log_model) > 0:
os.environ['WANDB_LOG_MODEL'] = wandb_log_model
os.environ["WANDB_LOG_MODEL"] = wandb_log_model
model = LlamaForCausalLM.from_pretrained(
base_model,
@ -231,7 +231,7 @@ def train(
ddp_find_unused_parameters=False if ddp else None,
group_by_length=group_by_length,
report_to="wandb" if use_wandb else None,
run_name=wandb_run_name if use_wandb else None
run_name=wandb_run_name if use_wandb else None,
),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True