mirror of https://github.com/tloen/alpaca-lora.git
synced 2024-10-01 01:05:56 -04:00
parent 69b31e0fed
commit dbd04f3560
.github/workflows/lint.yml (vendored): 6 changed lines
@@ -25,10 +25,10 @@ jobs:
          python-version: 3.8

      - name: Install Python dependencies
-       run: pip install black black[jupyter] flake8
+       run: pip install black black[jupyter] flake8 isort

      - name: lint isort
-       run: isort --check --diff
+       run: isort --check --diff .

      - name: lint black
-       run: black --check --diff
+       run: black --check --diff .
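For context, the three commands above can be mirrored locally. A minimal sketch using isort's and black's Python APIs; the helper name, the target path, and the check logic are illustrative, not part of this repository:

# check_format.py - rough local equivalent of the two lint steps above.
# Assumes `pip install black black[jupyter] flake8 isort` has been run.
from pathlib import Path

import black
import isort


def check_file(path: Path) -> bool:
    src = path.read_text()
    # Same intent as `isort --check --diff .`, limited to one file.
    imports_ok = isort.check_code(src, show_diff=True)
    # Same intent as `black --check --diff .`: black changes nothing if the
    # file is already formatted, so compare its output with the original.
    black_ok = black.format_str(src, mode=black.FileMode()) == src
    return imports_ok and black_ok


if __name__ == "__main__":
    print(check_file(Path("finetune.py")))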
finetune.py: 20 changed lines
@@ -54,8 +54,8 @@ def train(
     # wandb params
     wandb_project: str = "",
     wandb_run_name: str = "",
-    wandb_watch: str = "", # options: false | gradients | all
-    wandb_log_model: str = "", # options: false | true
+    wandb_watch: str = "",  # options: false | gradients | all
+    wandb_log_model: str = "",  # options: false | true
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
 ):
     print(
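These keyword arguments are how wandb behaviour is selected per run. A hypothetical direct call, assuming train() is importable from finetune.py and that base_model is one of its other parameters; the values shown are placeholders:

from finetune import train

train(
    base_model="decapoda-research/llama-7b-hf",  # assumed parameter and value
    wandb_project="alpaca-lora",      # exported as WANDB_PROJECT (see next hunk)
    wandb_run_name="lora-test-run",
    wandb_watch="gradients",          # options: false | gradients | all
    wandb_log_model="true",           # options: false | true
)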
@@ -94,16 +94,16 @@ def train(
         gradient_accumulation_steps = gradient_accumulation_steps // world_size

     # Check if parameter passed or if set within environ
-    use_wandb = len(wandb_project) > 0 or \
-        ("WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
+    use_wandb = len(wandb_project) > 0 or (
+        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
+    )
     # Only overwrite environ if wandb param passed
-    if len(wandb_project) > 0:
-        os.environ['WANDB_PROJECT'] = wandb_project
+    if len(wandb_project) > 0:
+        os.environ["WANDB_PROJECT"] = wandb_project
     if len(wandb_watch) > 0:
-        os.environ['WANDB_WATCH'] = wandb_watch
+        os.environ["WANDB_WATCH"] = wandb_watch
     if len(wandb_log_model) > 0:
-        os.environ['WANDB_LOG_MODEL'] = wandb_log_model
-
+        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

     model = LlamaForCausalLM.from_pretrained(
         base_model,
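The block above follows one pattern for each wandb setting: an explicit argument overrides the environment, while an already-exported variable is respected when the argument is empty. Distilled into a standalone sketch (the helper name is illustrative):

import os


def resolve_setting(cli_value: str, env_name: str) -> str:
    """Return the effective value for a setting and sync it into os.environ."""
    if len(cli_value) > 0:
        # Explicit argument wins and is exported for downstream libraries.
        os.environ[env_name] = cli_value
        return cli_value
    # Otherwise fall back to whatever is already exported (possibly nothing).
    return os.environ.get(env_name, "")


project = resolve_setting("", "WANDB_PROJECT")
use_wandb = len(project) > 0  # mirrors the use_wandb check above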
@@ -231,7 +231,7 @@ def train(
             ddp_find_unused_parameters=False if ddp else None,
             group_by_length=group_by_length,
             report_to="wandb" if use_wandb else None,
-            run_name=wandb_run_name if use_wandb else None
+            run_name=wandb_run_name if use_wandb else None,
         ),
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
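The trailing comma lands inside the conditional wandb reporting setup. A self-contained illustration of that pattern, with placeholder values for everything except report_to and run_name, which appear in the hunk:

import transformers

use_wandb = False   # would come from the check in the earlier hunk
wandb_run_name = ""

args = transformers.TrainingArguments(
    output_dir="./lora-alpaca",        # placeholder output directory
    per_device_train_batch_size=4,     # placeholder batch size
    report_to="wandb" if use_wandb else None,  # only report when wandb is configured
    run_name=wandb_run_name if use_wandb else None,
)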