diff --git a/extensions/multimodal/abstract_pipeline.py b/extensions/multimodal/abstract_pipeline.py
index 58421941..9c49935a 100644
--- a/extensions/multimodal/abstract_pipeline.py
+++ b/extensions/multimodal/abstract_pipeline.py
@@ -3,6 +3,7 @@ from typing import List, Optional
 
 import torch
 from PIL import Image
+from transformers import is_torch_xpu_available
 
 
 class AbstractMultimodalPipeline(ABC):
@@ -55,7 +56,7 @@ class AbstractMultimodalPipeline(ABC):
 
     def _get_device(self, setting_name: str, params: dict):
         if params[setting_name] is None:
-            return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+            return torch.device("cuda:0" if torch.cuda.is_available() else "xpu:0" if is_torch_xpu_available() else "cpu")
         return torch.device(params[setting_name])
 
     def _get_dtype(self, setting_name: str, params: dict):
diff --git a/modules/AutoGPTQ_loader.py b/modules/AutoGPTQ_loader.py
index 987f5ba7..8623cf8d 100644
--- a/modules/AutoGPTQ_loader.py
+++ b/modules/AutoGPTQ_loader.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 
+from accelerate import is_xpu_available
 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 
 import modules.shared as shared
@@ -41,7 +42,7 @@ def load_quantized(model_name):
     # Define the params for AutoGPTQForCausalLM.from_quantized
     params = {
         'model_basename': pt_path.stem,
-        'device': "cuda:0" if not shared.args.cpu else "cpu",
+        'device': "xpu:0" if is_xpu_available() else "cuda:0" if not shared.args.cpu else "cpu",
         'use_triton': shared.args.triton,
         'inject_fused_attention': not shared.args.no_inject_fused_attention,
         'inject_fused_mlp': not shared.args.no_inject_fused_mlp,
diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index bc528b18..cdbcf08b 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -5,15 +5,15 @@ from pathlib import Path
 import accelerate
 import torch
 import transformers
+from accelerate import is_xpu_available
+from gptq_for_llama import llama_inference_offload
+from gptq_for_llama.modelutils import find_layers
+from gptq_for_llama.quant import make_quant
 from transformers import AutoConfig, AutoModelForCausalLM
 
 import modules.shared as shared
 from modules.logging_colors import logger
 
-from gptq_for_llama import llama_inference_offload
-from gptq_for_llama.modelutils import find_layers
-from gptq_for_llama.quant import make_quant
-
 
 # This function is a replacement for the load_quant function in the
 # GPTQ-for_LLaMa repository. It supports more models and branches.
@@ -144,7 +144,7 @@ def load_quantized(model_name):
         model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
 
         # accelerate offload (doesn't work properly)
-        if shared.args.gpu_memory or torch.cuda.device_count() > 1:
+        if shared.args.gpu_memory or torch.cuda.device_count() > 1 or (is_xpu_available() and torch.xpu.device_count() > 1):
             if shared.args.gpu_memory:
                 memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
                 max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
@@ -163,6 +163,9 @@ def load_quantized(model_name):
 
         # No offload
         elif not shared.args.cpu:
-            model = model.to(torch.device('cuda:0'))
+            if is_xpu_available():
+                model = model.to(torch.device("xpu:0"))
+            else:
+                model = model.to(torch.device('cuda:0'))
 
     return model
diff --git a/modules/LoRA.py b/modules/LoRA.py
index 1f1156cf..4b119994 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -2,6 +2,7 @@ from pathlib import Path
 
 import torch
 from peft import PeftModel
+from transformers import is_torch_xpu_available
 
 import modules.shared as shared
 from modules.logging_colors import logger
@@ -179,6 +180,9 @@ def add_lora_transformers(lora_names):
                 if torch.backends.mps.is_available():
                     device = torch.device('mps')
                     shared.model = shared.model.to(device)
+                elif is_torch_xpu_available():
+                    device = torch.device("xpu:0")
+                    shared.model = shared.model.to(device)
                 else:
                     shared.model = shared.model.cuda()
 
diff --git a/modules/RWKV.py b/modules/RWKV.py
index 39487c66..8a15e540 100644
--- a/modules/RWKV.py
+++ b/modules/RWKV.py
@@ -9,6 +9,7 @@ from pathlib import Path
 
 import numpy as np
 from tokenizers import Tokenizer
+from transformers import is_torch_xpu_available
 
 import modules.shared as shared
 from modules.callbacks import Iteratorize
@@ -27,7 +28,7 @@ class RWKVModel:
         pass
 
     @classmethod
-    def from_pretrained(self, path, dtype="fp16", device="cuda"):
+    def from_pretrained(self, path, dtype="bf16" if is_torch_xpu_available() else "fp16", device="xpu" if is_torch_xpu_available() else "cuda"):
         tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
         if shared.args.rwkv_strategy is None:
             model = RWKV(model=str(path), strategy=f'{device} {dtype}')
diff --git a/modules/callbacks.py b/modules/callbacks.py
index e29e397d..bb979a6c 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -5,6 +5,7 @@ from threading import Thread
 
 import torch
 import transformers
+from transformers import is_torch_xpu_available
 
 import modules.shared as shared
 
@@ -92,4 +93,7 @@ class Iteratorize:
 def clear_torch_cache():
     gc.collect()
     if not shared.args.cpu:
-        torch.cuda.empty_cache()
+        if is_torch_xpu_available():
+            torch.xpu.empty_cache()
+        else:
+            torch.cuda.empty_cache()
diff --git a/modules/logits.py b/modules/logits.py
index 6fc5bf60..e356a986 100644
--- a/modules/logits.py
+++ b/modules/logits.py
@@ -1,4 +1,5 @@
 import torch
+from transformers import is_torch_xpu_available
 
 from modules import sampler_hijack, shared
 from modules.logging_colors import logger
@@ -32,13 +33,19 @@ def get_next_logits(prompt, state, use_samplers, previous):
         scores = sampler_hijack.global_scores[-1]
     else:
         if is_non_hf_exllamav2 or is_non_hf_exllamav1:
-            tokens = shared.tokenizer.encode(prompt).cuda()
+            if is_torch_xpu_available():
+                tokens = shared.tokenizer.encode(prompt).to("xpu:0")
+            else:
+                tokens = shared.tokenizer.encode(prompt).cuda()
             scores = shared.model.get_logits(tokens)[-1][-1]
         elif is_non_hf_llamacpp:
             tokens = shared.tokenizer.encode(prompt)
             scores = shared.model.get_logits(tokens)[-1][-1]
         else:
-            tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
+            if is_torch_xpu_available():
+                tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("xpu:0")
+            else:
+                tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
             output = shared.model(input_ids=tokens)
             scores = output['logits'][-1][-1]
 
diff --git a/modules/models.py b/modules/models.py
index 71203152..de160022 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -7,7 +7,12 @@ from pathlib import Path
 
 import torch
 import transformers
-from accelerate import infer_auto_device_map, init_empty_weights
+from accelerate import (
+    infer_auto_device_map,
+    init_empty_weights,
+    is_ccl_available,
+    is_xpu_available
+)
 from transformers import (
     AutoConfig,
     AutoModel,
@@ -38,8 +43,12 @@ if shared.args.deepspeed:
     # Distributed setup
     local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
     world_size = int(os.getenv("WORLD_SIZE", "1"))
-    torch.cuda.set_device(local_rank)
-    deepspeed.init_distributed()
+    if is_xpu_available() and is_ccl_available():
+        torch.xpu.set_device(local_rank)
+        deepspeed.init_distributed(backend="ccl")
+    else:
+        torch.cuda.set_device(local_rank)
+        deepspeed.init_distributed()
     ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
     dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration
 
@@ -137,8 +146,9 @@ def huggingface_loader(model_name):
         if torch.backends.mps.is_available():
             device = torch.device('mps')
             model = model.to(device)
-        elif hasattr(torch, 'xpu') and torch.xpu.is_available():
-            model = model.to('xpu')
+        elif is_xpu_available():
+            device = torch.device("xpu")
+            model = model.to(device)
         else:
             model = model.cuda()
 
@@ -151,15 +161,10 @@ def huggingface_loader(model_name):
 
     # Load with quantization and/or offloading
     else:
-        conditions = [
-            shared.args.cpu,
-            torch.cuda.is_available(),
-            torch.backends.mps.is_available(),
-            hasattr(torch, 'xpu') and torch.xpu.is_available(),
-        ]
 
-        if not any(conditions):
-            logger.warning('No GPU has been detected by Pytorch. Falling back to CPU mode.')
+        if not any((shared.args.cpu, torch.cuda.is_available(), is_xpu_available(), torch.backends.mps.is_available())):
+            logger.warning('torch.cuda.is_available() and is_xpu_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.')
+
             shared.args.cpu = True
 
     if shared.args.cpu:
@@ -362,7 +367,12 @@ def RWKV_loader(model_name):
     '''
     from modules.RWKV import RWKVModel, RWKVTokenizer
 
-    model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
+    model = RWKVModel.from_pretrained(
+        Path(f'{shared.args.model_dir}/{model_name}'),
+        dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16",
+        device="cpu" if shared.args.cpu else "xpu" if is_xpu_available() else "cuda"
+    )
+
     tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
     return model, tokenizer
 
@@ -380,7 +390,10 @@ def get_max_memory_dict():
     # If --auto-devices is provided standalone, try to get a reasonable value
     # for the maximum memory of device :0
     elif shared.args.auto_devices:
-        total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
+        if is_xpu_available():
+            total_mem = (torch.xpu.get_device_properties(0).total_memory / (1024 * 1024))
+        else:
+            total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
         suggestion = round((total_mem - 1000) / 1000) * 1000
         if total_mem - suggestion < 800:
             suggestion -= 1000
@@ -395,7 +408,10 @@ def get_max_memory_dict():
 def clear_torch_cache():
     gc.collect()
     if not shared.args.cpu:
-        torch.cuda.empty_cache()
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+        else:
+            torch.cuda.empty_cache()
 
 
 def unload_model():
diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py
index f8546fa0..56bee2a6 100644
--- a/modules/sampler_hijack.py
+++ b/modules/sampler_hijack.py
@@ -2,7 +2,7 @@ import math
 
 import torch
 import transformers
-from transformers import LogitsWarper
+from transformers import LogitsWarper, is_torch_xpu_available
 from transformers.generation.logits_process import (
     LogitNormalization,
     LogitsProcessor,
@@ -106,9 +106,12 @@ class MirostatLogitsWarper(LogitsWarper):
                 break
 
         # Normalize the probabilities of the remaining words
-        prob_topk = torch.softmax(sorted_logits, dim=0).to('cuda')
-
-        prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda')
+        if is_torch_xpu_available():
+            prob_topk = torch.softmax(sorted_logits, dim=0).to("xpu")
+            prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to("xpu")
+        else:
+            prob_topk = torch.softmax(sorted_logits, dim=0).to('cuda')
+            prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda')
 
         observed_surprise = -math.log2(prob_topk[prev_i])
         self.e = observed_surprise - self.mirostat_tau
diff --git a/modules/text_generation.py b/modules/text_generation.py
index c178a53a..bd8f97ef 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -9,7 +9,7 @@ import traceback
 import numpy as np
 import torch
 import transformers
-from transformers import LogitsProcessorList
+from transformers import LogitsProcessorList, is_torch_xpu_available
 
 import modules.shared as shared
 from modules.callbacks import (
@@ -132,8 +132,8 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     elif torch.backends.mps.is_available():
         device = torch.device('mps')
         return input_ids.to(device)
-    elif hasattr(torch, 'xpu') and torch.xpu.is_available():
-        return input_ids.to('xpu')
+    elif is_torch_xpu_available():
+        return input_ids.to("xpu:0")
     else:
         return input_ids.cuda()
 
@@ -238,7 +238,8 @@ def set_manual_seed(seed):
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
-
+    elif is_torch_xpu_available():
+        torch.xpu.manual_seed_all(seed)
     return seed
 
 
diff --git a/modules/training.py b/modules/training.py
index ab5bad24..cc1df37b 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -26,6 +26,7 @@ from peft import (
 )
 from peft.utils.other import \
     TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as model_to_lora_modules
+from transformers import is_torch_xpu_available
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 )
@@ -626,6 +627,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
             # TODO: Enable multi-device support
             ddp_find_unused_parameters=None,
             no_cuda=shared.args.cpu,
+            use_ipex=True if is_torch_xpu_available() and not shared.args.cpu else False
         ),
         data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
         callbacks=list([Callbacks()])
diff --git a/modules/ui.py b/modules/ui.py
index 53a86bea..9d87bad6 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -4,10 +4,10 @@ from pathlib import Path
 import gradio as gr
 import torch
 import yaml
+from transformers import is_torch_xpu_available
 
 from modules import shared
 
-
 with open(Path(__file__).resolve().parent / '../css/NotoSans/stylesheet.css', 'r') as f:
     css = f.read()
 with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
@@ -85,9 +85,12 @@ def list_model_elements():
         'rope_freq_base',
         'numa',
     ]
-
-    for i in range(torch.cuda.device_count()):
-        elements.append(f'gpu_memory_{i}')
+    if is_torch_xpu_available():
+        for i in range(torch.xpu.device_count()):
+            elements.append(f'gpu_memory_{i}')
+    else:
+        for i in range(torch.cuda.device_count()):
+            elements.append(f'gpu_memory_{i}')
 
     return elements
 
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index 5d9b6cb6..8208b63a 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -8,6 +8,7 @@ from pathlib import Path
 import gradio as gr
 import psutil
 import torch
+from transformers import is_torch_xpu_available
 
 from modules import loaders, shared, ui, utils
 from modules.logging_colors import logger
@@ -27,8 +28,12 @@ def create_ui():
 
     # Finding the default values for the GPU and CPU memories
     total_mem = []
-    for i in range(torch.cuda.device_count()):
-        total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)))
+    if is_torch_xpu_available():
+        for i in range(torch.xpu.device_count()):
+            total_mem.append(math.floor(torch.xpu.get_device_properties(i).total_memory / (1024 * 1024)))
+    else:
+        for i in range(torch.cuda.device_count()):
+            total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)))
 
     default_gpu_mem = []
     if shared.args.gpu_memory is not None and len(shared.args.gpu_memory) > 0:
diff --git a/one_click.py b/one_click.py
index 169d328b..2f3dc172 100644
--- a/one_click.py
+++ b/one_click.py
@@ -56,6 +56,19 @@ def cpu_has_avx2():
         return True
 
 
+def cpu_has_amx():
+    try:
+        import cpuinfo
+
+        info = cpuinfo.get_cpu_info()
+        if 'amx' in info['flags']:
+            return True
+        else:
+            return False
+    except:
+        return True
+
+
 def torch_version():
     site_packages_path = None
     for sitedir in site.getsitepackages():