From f0fcd1f697d0adcb3bc96f6fa2a947b697c15209 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 25 Jun 2023 01:44:36 -0300
Subject: [PATCH] Sort some imports

---
 modules/chat.py                   |  7 +++++--
 modules/evaluate.py               |  6 ++++--
 modules/html_generator.py         |  6 ------
 modules/models.py                 | 18 +++++++++++++-----
 modules/monkey_patch_gptq_lora.py | 10 +++++++---
 modules/sampler_hijack.py         |  8 +++++---
 modules/training.py               | 21 +++++++++++++++------
 server.py                         | 17 +++++++++++------
 8 files changed, 60 insertions(+), 33 deletions(-)

diff --git a/modules/chat.py b/modules/chat.py
index ac3e52dc..6d57dbfe 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -14,8 +14,11 @@ import modules.shared as shared
 from modules.extensions import apply_extensions
 from modules.html_generator import chat_html_wrapper, make_thumbnail
 from modules.logging_colors import logger
-from modules.text_generation import (generate_reply, get_encoded_length,
-                                     get_max_prompt_length)
+from modules.text_generation import (
+    generate_reply,
+    get_encoded_length,
+    get_max_prompt_length
+)
 from modules.utils import delete_file, replace_all, save_file
 
 
diff --git a/modules/evaluate.py b/modules/evaluate.py
index 3e555a3e..d94863d9 100644
--- a/modules/evaluate.py
+++ b/modules/evaluate.py
@@ -8,8 +8,10 @@ from tqdm import tqdm
 
 from modules import shared
 from modules.models import load_model, unload_model
-from modules.models_settings import (get_model_settings_from_yamls,
-                                     update_model_parameters)
+from modules.models_settings import (
+    get_model_settings_from_yamls,
+    update_model_parameters
+)
 from modules.text_generation import encode
 
 
diff --git a/modules/html_generator.py b/modules/html_generator.py
index ceb167e0..6c83dbe7 100644
--- a/modules/html_generator.py
+++ b/modules/html_generator.py
@@ -1,9 +1,3 @@
-'''
-
-This is a library for formatting text outputs as nice HTML.
-
-'''
-
 import os
 import re
 import time
diff --git a/modules/models.py b/modules/models.py
index 28f6be66..f12e700c 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -7,9 +7,15 @@ from pathlib import Path
 import torch
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
-from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
-                          AutoModelForSeq2SeqLM, AutoTokenizer,
-                          BitsAndBytesConfig, LlamaTokenizer)
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    LlamaTokenizer
+)
 
 import modules.shared as shared
 from modules import llama_attn_hijack, sampler_hijack
@@ -21,8 +27,10 @@ transformers.logging.set_verbosity_error()
 local_rank = None
 if shared.args.deepspeed:
     import deepspeed
-    from transformers.deepspeed import (HfDeepSpeedConfig,
-                                        is_deepspeed_zero3_enabled)
+    from transformers.deepspeed import (
+        HfDeepSpeedConfig,
+        is_deepspeed_zero3_enabled
+    )
 
     from modules.deepspeed_parameters import generate_ds_config
 
diff --git a/modules/monkey_patch_gptq_lora.py b/modules/monkey_patch_gptq_lora.py
index a37e7906..bf8d478d 100644
--- a/modules/monkey_patch_gptq_lora.py
+++ b/modules/monkey_patch_gptq_lora.py
@@ -7,10 +7,14 @@ sys.path.insert(0, str(Path("repositories/alpaca_lora_4bit")))
 import autograd_4bit
 from amp_wrapper import AMPWrapper
-from autograd_4bit import (Autograd4bitQuantLinear,
-                           load_llama_model_4bit_low_ram)
+from autograd_4bit import (
+    Autograd4bitQuantLinear,
+    load_llama_model_4bit_low_ram
+)
 from monkeypatch.peft_tuners_lora_monkey_patch import (
-    Linear4bitLt, replace_peft_model_with_gptq_lora_model)
+    Linear4bitLt,
+    replace_peft_model_with_gptq_lora_model
+)
 
 from modules import shared
 from modules.GPTQ_loader import find_quantized_model_file
 
diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py
index 447c8782..bcec250a 100644
--- a/modules/sampler_hijack.py
+++ b/modules/sampler_hijack.py
@@ -3,9 +3,11 @@ import math
 import torch
 import transformers
 from transformers import LogitsWarper
-from transformers.generation.logits_process import (LogitNormalization,
-                                                     LogitsProcessorList,
-                                                     TemperatureLogitsWarper)
+from transformers.generation.logits_process import (
+    LogitNormalization,
+    LogitsProcessorList,
+    TemperatureLogitsWarper
+)
 
 
 class TailFreeLogitsWarper(LogitsWarper):
diff --git a/modules/training.py b/modules/training.py
index 65f1668a..d0018a0f 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -11,12 +11,19 @@ import gradio as gr
 import torch
 import transformers
 from datasets import Dataset, load_dataset
-from peft import (LoraConfig, get_peft_model, prepare_model_for_kbit_training,
-                  set_peft_model_state_dict)
+from peft import (
+    LoraConfig,
+    get_peft_model,
+    prepare_model_for_kbit_training,
+    set_peft_model_state_dict
+)
 
 from modules import shared, ui, utils
-from modules.evaluate import (calculate_perplexity, generate_markdown_table,
-                              save_past_evaluations)
+from modules.evaluate import (
+    calculate_perplexity,
+    generate_markdown_table,
+    save_past_evaluations
+)
 from modules.logging_colors import logger
 
 # This mapping is from a very recent commit, not yet released.
@@ -25,8 +32,9 @@ try:
     from peft.utils.other import \
         TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
         model_to_lora_modules
-    from transformers.models.auto.modeling_auto import \
+    from transformers.models.auto.modeling_auto import (
         MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+    )
     MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}
 except:
     standard_modules = ["q_proj", "v_proj"]
@@ -201,8 +209,9 @@ def clean_path(base_path: str, path: str):
 
 def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str):
     if shared.args.monkey_patch:
-        from monkeypatch.peft_tuners_lora_monkey_patch import \
+        from monkeypatch.peft_tuners_lora_monkey_patch import (
             replace_peft_model_with_gptq_lora_model
+        )
         replace_peft_model_with_gptq_lora_model()
 
     global WANT_INTERRUPT
diff --git a/server.py b/server.py
index acb62c79..a4f73bcc 100644
--- a/server.py
+++ b/server.py
@@ -38,12 +38,17 @@ from modules.github import clone_or_pull_repository
 from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, unload_model
-from modules.models_settings import (apply_model_settings_to_state,
-                                     get_model_settings_from_yamls,
-                                     save_model_settings,
-                                     update_model_parameters)
-from modules.text_generation import (generate_reply_wrapper,
-                                     get_encoded_length, stop_everything_event)
+from modules.models_settings import (
+    apply_model_settings_to_state,
+    get_model_settings_from_yamls,
+    save_model_settings,
+    update_model_parameters
+)
+from modules.text_generation import (
+    generate_reply_wrapper,
+    get_encoded_length,
+    stop_everything_event
+)
 
 
 def load_model_wrapper(selected_model, loader, autoload=False):