This commit is contained in:
oobabooga 2023-09-11 07:57:38 -07:00
parent 92f3cd624c
commit df52dab67b
4 changed files with 4 additions and 4 deletions

View File

@ -108,7 +108,7 @@ def complex_model_load(model):
req['args']['bf16'] = True # for 24GB req['args']['bf16'] = True # for 24GB
elif '13b' in model: elif '13b' in model:
req['args']['load_in_8bit'] = True # for 24GB req['args']['load_in_8bit'] = True # for 24GB
elif 'ggml' in model: elif 'gguf' in model:
# req['args']['threads'] = 16 # req['args']['threads'] = 16
if '7b' in model: if '7b' in model:
req['args']['n_gpu_layers'] = 100 req['args']['n_gpu_layers'] = 100

View File

@ -125,7 +125,7 @@ class ModelDownloader:
if base_folder is None: if base_folder is None:
base_folder = 'models' if not is_lora else 'loras' base_folder = 'models' if not is_lora else 'loras'
# If the model is of type GGUF or GGML, save directly in the base_folder # If the model is of type GGUF, save directly in the base_folder
if is_llamacpp: if is_llamacpp:
return Path(base_folder) return Path(base_folder)

View File

@ -3,7 +3,7 @@ from pathlib import Path
import torch.nn.functional as F import torch.nn.functional as F
from torch import version as torch_version from torch import version as torch_version
from modules import RoPE, shared from modules import shared
from modules.logging_colors import logger from modules.logging_colors import logger
from modules.models import clear_torch_cache from modules.models import clear_torch_cache
from modules.text_generation import get_max_prompt_length from modules.text_generation import get_max_prompt_length

View File

@ -7,7 +7,7 @@ from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.modeling_outputs import CausalLMOutputWithPast
from modules import RoPE, shared from modules import shared
from modules.logging_colors import logger from modules.logging_colors import logger
try: try: