mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Update Training PRO (#4972)
- rolling back safetensors to bi, until it is fixed correctly - removing the ugly checkpoint detour
This commit is contained in:
parent
f1f2c4c3f4
commit
59da429cbd
@ -51,59 +51,9 @@ from modules.logging_colors import logger
|
||||
from modules.models import reload_model
|
||||
from modules.utils import natural_keys
|
||||
|
||||
|
||||
|
||||
## just temporary to avoid warning
|
||||
|
||||
import inspect
|
||||
|
||||
from typing import Callable, Optional, Tuple, ContextManager
|
||||
|
||||
|
||||
|
||||
if hasattr(torch.utils.checkpoint, 'noop_context_fn'):
|
||||
def my_checkpoint(
|
||||
function,
|
||||
*args,
|
||||
use_reentrant: Optional[bool] = None,
|
||||
context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = torch.utils.checkpoint.noop_context_fn,
|
||||
determinism_check: str = torch.utils.checkpoint._DEFAULT_DETERMINISM_MODE,
|
||||
debug: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if use_reentrant is None:
|
||||
#print ("reentran = NONE")
|
||||
use_reentrant = True
|
||||
# Hack to mix *args with **kwargs in a python 2.7-compliant way
|
||||
preserve = kwargs.pop("preserve_rng_state", True)
|
||||
if kwargs and use_reentrant:
|
||||
raise ValueError(
|
||||
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
|
||||
)
|
||||
|
||||
if use_reentrant:
|
||||
if context_fn is not torch.utils.checkpoint.noop_context_fn or debug is not False:
|
||||
raise ValueError(
|
||||
"Passing `context_fn` or `debug` is only supported when "
|
||||
"use_reentrant=False."
|
||||
)
|
||||
return torch.utils.checkpoint.CheckpointFunction.apply(function, preserve, *args)
|
||||
else:
|
||||
|
||||
print ("reentran = FALSE")
|
||||
gen = torch.utils.checkpoint._checkpoint_without_reentrant_generator(
|
||||
function, preserve, context_fn, determinism_check, debug, *args, **kwargs
|
||||
)
|
||||
# Runs pre-forward logic
|
||||
next(gen)
|
||||
ret = function(*args, **kwargs)
|
||||
# Runs post-forward logic
|
||||
try:
|
||||
next(gen)
|
||||
except StopIteration:
|
||||
return ret
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings(action = "ignore", message="torch.utils.checkpoint:")
|
||||
warnings.filterwarnings(action = "ignore", message="`do_sample` is set to `False`")
|
||||
|
||||
params = {
|
||||
"display_name": "Training PRO",
|
||||
@ -121,6 +71,7 @@ non_serialized_params = {
|
||||
"save_epochs": 0,
|
||||
"checkpoint_offset": 0,
|
||||
"epoch_offset":0,
|
||||
"safe_serialization": False,
|
||||
}
|
||||
|
||||
MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
|
||||
@ -150,7 +101,7 @@ def ui():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
# YY.MM.DD
|
||||
gr.Markdown("`Ver: 23.10.20` This is enhanced version of QLora Training. [Maintained by FP](https://github.com/FartyPants/Training_PRO/tree/main)")
|
||||
gr.Markdown("`Ver: 23.10.20 (REV2)` This is enhanced version of QLora Training. [Maintained by FP](https://github.com/FartyPants/Training_PRO/tree/main)")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=5):
|
||||
@ -290,7 +241,7 @@ def ui():
|
||||
stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
|
||||
|
||||
with gr.Column():
|
||||
max_length = gr.Slider(label='max_length', minimum=0, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
|
||||
max_length = gr.Slider(label='max_length', minimum=0, maximum=shared.settings['truncation_length_max'], value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
|
||||
|
||||
with gr.Row():
|
||||
start_current_evaluation = gr.Button("Evaluate loaded model")
|
||||
@ -713,7 +664,6 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
|
||||
train_template.clear()
|
||||
|
||||
|
||||
#reset stuff
|
||||
print(f"*** LoRA: {lora_name} ***")
|
||||
non_serialized_params.update({"stop_at_loss": stop_at_loss})
|
||||
@ -726,25 +676,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
non_serialized_params.update({"epoch_offset": 0})
|
||||
train_log_graph.clear()
|
||||
|
||||
# === once fixed, this can be removed ==============================
|
||||
if hasattr(torch.utils.checkpoint, 'noop_context_fn'):
|
||||
print("Testing Pytorch...")
|
||||
old_checkpoint_signature = inspect.signature(torch.utils.checkpoint.checkpoint)
|
||||
|
||||
# Get the signature of your new checkpoint function
|
||||
my_checkpoint_signature = inspect.signature(my_checkpoint)
|
||||
|
||||
# Check if the signatures match
|
||||
if old_checkpoint_signature.parameters == my_checkpoint_signature.parameters:
|
||||
print(F"{RED}Overriding Torch checkpoint function to avoid repeated 'use_reentrant not explicitly set' warnings{RESET}")
|
||||
#print(" - Note: Transformers need to pass use_reentrant in llama.modeling_llama in def forward, layer_outputs = torch.utils.checkpoint.checkpoint")
|
||||
#print(" Once they do, this function can be removed")
|
||||
torch.utils.checkpoint.checkpoint = my_checkpoint
|
||||
|
||||
|
||||
# END OF FPHAM SENTENCE SPLIT functions ===================
|
||||
|
||||
# == Prep the dataset, format, etc ==
|
||||
# == Prep the dataset, format, etc ==
|
||||
if raw_text_file not in ['None', '']:
|
||||
train_template["template_type"] = "raw_text"
|
||||
logger.info("Loading text file...")
|
||||
@ -1025,7 +957,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
force_save = True
|
||||
|
||||
if force_save:
|
||||
lora_model.save_pretrained(f"{lora_file_path}/{folder_save}/")
|
||||
lora_model.save_pretrained(f"{lora_file_path}/{folder_save}/", safe_serialization = non_serialized_params['safe_serialization'])
|
||||
print(f"\033[1;30;40mStep: {tracked.current_steps:6} \033[0;37;0m Saved: [{folder_save}]")
|
||||
# Save log
|
||||
with open(f"{lora_file_path}/{folder_save}/training_log.json", 'w', encoding='utf-8') as file:
|
||||
@ -1252,7 +1184,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
log_train_dataset(trainer)
|
||||
trainer.train()
|
||||
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
||||
lora_model.save_pretrained(lora_file_path)
|
||||
lora_model.save_pretrained(lora_file_path, safe_serialization = non_serialized_params['safe_serialization'])
|
||||
logger.info("LoRA training run is completed and saved.")
|
||||
# Save log
|
||||
with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
|
||||
@ -1353,7 +1285,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
|
||||
if not tracked.did_save:
|
||||
logger.info("Training complete, saving...")
|
||||
lora_model.save_pretrained(lora_file_path)
|
||||
lora_model.save_pretrained(lora_file_path, safe_serialization = non_serialized_params['safe_serialization'])
|
||||
|
||||
if WANT_INTERRUPT:
|
||||
logger.info("Training interrupted.")
|
||||
|
Loading…
Reference in New Issue
Block a user