diff --git a/modules/training.py b/modules/training.py index 96cd6e7c..e2be18e8 100644 --- a/modules/training.py +++ b/modules/training.py @@ -74,7 +74,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le } def generate_prompt(data_point: dict[str, str]): for options, data in formatData.items(): - if set(options.split(',')) == set(data_point.keys()): + if set(options.split(',')) == set(x[0] for x in data_point.items() if len(x[1].strip()) > 0): for key, val in data_point.items(): data = data.replace(f'%{key}%', val) return data