From dbd04f35600627609a47188e1b8c05c8f8e2c591 Mon Sep 17 00:00:00 2001
From: "Eric J. Wang"
Date: Mon, 27 Mar 2023 14:34:23 -0400
Subject: [PATCH] Fix linters (#185)

* install isort

* isort .

* whoops

* fix black
---
 .github/workflows/lint.yml |  6 +++---
 finetune.py                | 20 ++++++++++----------
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 1539332..7d63340 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -25,10 +25,10 @@ jobs:
           python-version: 3.8
 
       - name: Install Python dependencies
-        run: pip install black black[jupyter] flake8
+        run: pip install black black[jupyter] flake8 isort
 
       - name: lint isort
-        run: isort --check --diff
+        run: isort --check --diff .
 
       - name: lint black
-        run: black --check --diff
+        run: black --check --diff .
diff --git a/finetune.py b/finetune.py
index 03175ae..883a663 100644
--- a/finetune.py
+++ b/finetune.py
@@ -54,8 +54,8 @@ def train(
     # wandb params
     wandb_project: str = "",
     wandb_run_name: str = "",
-    wandb_watch: str = "", # options: false | gradients | all
-    wandb_log_model: str = "", # options: false | true
+    wandb_watch: str = "",  # options: false | gradients | all
+    wandb_log_model: str = "",  # options: false | true
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
 ):
     print(
@@ -94,16 +94,16 @@ def train(
         gradient_accumulation_steps = gradient_accumulation_steps // world_size
 
     # Check if parameter passed or if set within environ
-    use_wandb = len(wandb_project) > 0 or \
-        ("WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
+    use_wandb = len(wandb_project) > 0 or (
+        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
+    )
     # Only overwrite environ if wandb param passed
-    if len(wandb_project) > 0:
-        os.environ['WANDB_PROJECT'] = wandb_project
+    if len(wandb_project) > 0:
+        os.environ["WANDB_PROJECT"] = wandb_project
     if len(wandb_watch) > 0:
-        os.environ['WANDB_WATCH'] = wandb_watch
+        os.environ["WANDB_WATCH"] = wandb_watch
     if len(wandb_log_model) > 0:
-        os.environ['WANDB_LOG_MODEL'] = wandb_log_model
-
+        os.environ["WANDB_LOG_MODEL"] = wandb_log_model
 
     model = LlamaForCausalLM.from_pretrained(
         base_model,
@@ -231,7 +231,7 @@ def train(
             ddp_find_unused_parameters=False if ddp else None,
             group_by_length=group_by_length,
             report_to="wandb" if use_wandb else None,
-            run_name=wandb_run_name if use_wandb else None
+            run_name=wandb_run_name if use_wandb else None,
         ),
         data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
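
Note: the updated CI checks should be reproducible locally with the same
commands the workflow runs, invoked from the repository root -- a minimal
sketch, assuming a Python 3.8 environment to match the workflow (flake8 is
installed by the workflow, but its invocation is not shown in this patch):

    pip install black black[jupyter] flake8 isort
    isort --check --diff .
    black --check --diff .

Dropping the --check flags makes isort and black rewrite the files in place
instead of only reporting differences.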