From 46996f6519d483841a49c8a857104a6bf6e2ac7a Mon Sep 17 00:00:00 2001
From: RandoInternetPreson
Date: Fri, 27 Sep 2024 23:26:03 -0400
Subject: [PATCH] ExllamaV2 tensor parallelism to increase multi gpu inference speeds (#6356)

---
 modules/exllamav2.py    | 26 ++++++++++++++++++--------
 modules/exllamav2_hf.py | 28 +++++++++++++++++++---------
 modules/shared.py       |  1 +
 3 files changed, 38 insertions(+), 17 deletions(-)

diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index a770e342..42b9ade1 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -7,6 +7,7 @@ from exllamav2 import (
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
+    ExLlamaV2Cache_TP,
     ExLlamaV2Config,
     ExLlamaV2Tokenizer
 )
@@ -54,21 +55,30 @@ class Exllamav2Model:
 
         model = ExLlamaV2(config)
 
-        if not shared.args.autosplit:
-            split = None
-            if shared.args.gpu_split:
-                split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+        split = None
+        if shared.args.gpu_split:
+            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
 
+        if shared.args.enable_tp:
+            model.load_tp(split)
+        elif not shared.args.autosplit:
             model.load(split)
 
+        # Determine the correct cache type
         if shared.args.cache_8bit:
-            cache = ExLlamaV2Cache_8bit(model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_8bit
         elif shared.args.cache_4bit:
-            cache = ExLlamaV2Cache_Q4(model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_Q4
         else:
-            cache = ExLlamaV2Cache(model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache
 
-        if shared.args.autosplit:
+        # Use TP if specified
+        if shared.args.enable_tp:
+            cache = ExLlamaV2Cache_TP(model, base=cache_type)
+        else:
+            cache = cache_type(model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit and not shared.args.enable_tp:
             model.load_autosplit(cache)
 
         tokenizer = ExLlamaV2Tokenizer(config)
diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index 53143d9a..febb2c64 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -9,6 +9,7 @@ from exllamav2 import (
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
+    ExLlamaV2Cache_TP,
     ExLlamaV2Config
 )
 from torch.nn import CrossEntropyLoss
@@ -42,21 +43,30 @@ class Exllamav2HF(PreTrainedModel):
 
         self.ex_model = ExLlamaV2(config)
 
-        if not shared.args.autosplit:
-            split = None
-            if shared.args.gpu_split:
-                split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+        split = None
+        if shared.args.gpu_split:
+            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
 
+        if shared.args.enable_tp:
+            self.ex_model.load_tp(split)
+        elif not shared.args.autosplit:
             self.ex_model.load(split)
 
+        # Determine the correct cache type
         if shared.args.cache_8bit:
-            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_8bit
         elif shared.args.cache_4bit:
-            self.ex_cache = ExLlamaV2Cache_Q4(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache_Q4
         else:
-            self.ex_cache = ExLlamaV2Cache(self.ex_model, lazy=shared.args.autosplit)
+            cache_type = ExLlamaV2Cache
 
-        if shared.args.autosplit:
+        # Use TP if specified
+        if shared.args.enable_tp:
+            self.ex_cache = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)
+        else:
+            self.ex_cache = cache_type(self.ex_model, lazy=shared.args.autosplit)
+
+        if shared.args.autosplit and not shared.args.enable_tp:
             self.ex_model.load_autosplit(self.ex_cache)
 
         self.past_seq = None
diff --git a/modules/shared.py b/modules/shared.py
index 43533a14..894ed6fe 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -146,6 +146,7 @@ group.add_argument('--no_sdpa', action='store_true', help='Force Torch SDPA to n
 group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
 group.add_argument('--cache_4bit', action='store_true', help='Use Q4 cache to save VRAM.')
 group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
+group.add_argument('--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) in ExLlamaV2.')
 
 # AutoGPTQ
 group = parser.add_argument_group('AutoGPTQ')
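
Usage note: with this patch applied, enabling --enable_tp on the ExLlamaV2 loaders routes model loading through ExLlamaV2's tensor-parallel path (load_tp) and wraps the selected cache class in ExLlamaV2Cache_TP via its base argument; an explicit per-GPU allocation can still be supplied through the existing gpu_split setting, and the autosplit path is skipped while TP is active. The sketch below is a minimal standalone illustration of that load path, not part of the patch: the model directory and split values are placeholders, and the base cache is assumed to be the default FP16 ExLlamaV2Cache rather than the 8-bit or Q4 variants.

    from exllamav2 import (
        ExLlamaV2,
        ExLlamaV2Cache,
        ExLlamaV2Cache_TP,
        ExLlamaV2Config
    )

    # Placeholder path; point this at a real EXL2-quantized model directory.
    config = ExLlamaV2Config()
    config.model_dir = "models/example-exl2"
    config.prepare()

    model = ExLlamaV2(config)

    # Optional per-GPU VRAM budget in GB, mirroring a gpu_split of "20,20";
    # passing None lets ExLlamaV2 allocate across GPUs on its own.
    split = [20.0, 20.0]

    # Tensor-parallel load, as the patched loaders do when --enable_tp is set.
    model.load_tp(split)

    # With TP enabled, the cache wraps the chosen base cache class.
    cache = ExLlamaV2Cache_TP(model, base=ExLlamaV2Cache)

Resolving cache_type before choosing between ExLlamaV2Cache_TP and a direct instantiation is what lets the TP cache reuse whichever cache variant (FP16, 8-bit, or Q4) the user selected.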