diff --git a/modules/presets.py b/modules/presets.py index 14c5a777..06e2be70 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -44,7 +44,7 @@ def default_preset(): 'dry_base': 1.75, 'dry_allowed_length': 2, 'dry_sequence_breakers': '"\\n", ":", "\\"", "*"', - 'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nencoder_repetition_penalty\nno_repeat_ngram' + 'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nencoder_repetition_penalty\nno_repeat_ngram' } diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index 9e865bb3..65503c12 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -537,6 +537,7 @@ def get_logits_processor_patch(self, **kwargs): 'RepetitionPenaltyLogitsProcessorWithRange': 'repetition_penalty', 'PresencePenaltyLogitsProcessor': 'presence_penalty', 'FrequencyPenaltyLogitsProcessor': 'frequency_penalty', + 'DRYLogitsProcessor': 'dry', 'EncoderRepetitionPenaltyLogitsProcessor': 'encoder_repetition_penalty', 'NoRepeatNGramLogitsProcessor': 'no_repeat_ngram', } @@ -588,7 +589,7 @@ def generation_config_init_patch(self, **kwargs): self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2) self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"') self.temperature_last = kwargs.pop("temperature_last", False) - self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'encoder_repetition_penalty', 'no_repeat_ngram']) + self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'encoder_repetition_penalty', 'no_repeat_ngram']) def hijack_samplers():