Read GGUF metadata (#3873)

This commit is contained in:
oobabooga 2023-09-11 18:49:30 -03:00 committed by GitHub
parent 39f4800d94
commit 9331ab4798
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 154 additions and 52 deletions

View File

@ -7,10 +7,7 @@ from modules import shared
from modules.chat import generate_chat_reply
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import (
get_model_settings_from_yamls,
update_model_parameters
)
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.text_generation import (
encode,
generate_reply,
@ -132,7 +129,7 @@ class Handler(BaseHTTPRequestHandler):
shared.model_name = model_name
unload_model()
model_settings = get_model_settings_from_yamls(shared.model_name)
model_settings = get_model_metadata(shared.model_name)
shared.settings.update(model_settings)
update_model_parameters(model_settings, initial=True)

View File

@ -1,11 +1,9 @@
from modules import shared
from modules.utils import get_available_models
from modules.models import load_model, unload_model
from modules.models_settings import (get_model_settings_from_yamls,
update_model_parameters)
from extensions.openai.embeddings import get_embeddings_model_name
from extensions.openai.errors import *
from modules import shared
from modules.models import load_model, unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.utils import get_available_models
def get_current_model_list() -> list:
@ -33,7 +31,7 @@ def load_model(model_name: str) -> dict:
shared.model_name = model_name
unload_model()
model_settings = get_model_settings_from_yamls(shared.model_name)
model_settings = get_model_metadata(shared.model_name)
shared.settings.update(model_settings)
update_model_parameters(model_settings, initial=True)

View File

@ -8,10 +8,7 @@ from tqdm import tqdm
from modules import shared
from modules.models import load_model, unload_model
from modules.models_settings import (
get_model_settings_from_yamls,
update_model_parameters
)
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.text_generation import encode
@ -69,7 +66,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
if model != 'current model':
try:
yield cumulative_log + f"Loading {model}...\n\n"
model_settings = get_model_settings_from_yamls(model)
model_settings = get_model_metadata(model)
shared.settings.update(model_settings) # hijacking the interface defaults
update_model_parameters(model_settings) # hijacking the command-line arguments
shared.model_name = model

84
modules/metadata_gguf.py Normal file
View File

@ -0,0 +1,84 @@
import struct
from enum import IntEnum
class GGUFValueType(IntEnum):
    """Integer tags identifying the type of a GGUF metadata value."""

    # 8-bit and 16-bit integers
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3

    # 32-bit integers and floats, plus booleans
    UINT32 = 4
    INT32 = 5
    FLOAT32 = 6
    BOOL = 7

    # Variable-length values
    STRING = 8
    ARRAY = 9

    # 64-bit integers and floats
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12
# ``struct`` format string for each fixed-width value type. STRING and ARRAY
# have no entry here; get_single() reads STRING values itself.
_simple_value_packing = {
    GGUFValueType.UINT8: "<B",
    GGUFValueType.INT8: "<b",
    GGUFValueType.UINT16: "<H",
    GGUFValueType.INT16: "<h",
    GGUFValueType.UINT32: "<I",
    GGUFValueType.INT32: "<i",
    GGUFValueType.FLOAT32: "<f",
    GGUFValueType.UINT64: "<Q",
    GGUFValueType.INT64: "<q",
    GGUFValueType.FLOAT64: "<d",
    GGUFValueType.BOOL: "?",
}

# Number of bytes to read for each fixed-width value type. Computed from the
# format strings above so the two tables always agree.
value_type_info = {
    value_type: struct.calcsize(fmt)
    for value_type, fmt in _simple_value_packing.items()
}
def get_single(value_type, file):
    """Read a single metadata value of ``value_type`` from ``file``.

    Args:
        value_type: ``GGUFValueType`` member describing the value stored at
            the current file position.
        file: Binary file object positioned at the start of the value.

    Returns:
        A ``str`` for ``STRING`` values, otherwise the scalar produced by
        ``struct.unpack`` for the corresponding format.

    Raises:
        ValueError: If ``value_type`` has no fixed-width encoding in the
            lookup tables (for example ``ARRAY``).
        struct.error: If the file ends before the value is complete.
    """
    if value_type == GGUFValueType.STRING:
        # A string is a little-endian uint64 byte count followed by the bytes.
        value_length = struct.unpack("<Q", file.read(8))[0]
        # A strict decode would raise UnicodeDecodeError on a single malformed
        # string and abort the whole metadata read; substitute U+FFFD instead.
        return file.read(value_length).decode('utf-8', errors='replace')

    type_str = _simple_value_packing.get(value_type)
    bytes_length = value_type_info.get(value_type)
    if type_str is None or bytes_length is None:
        # Without this guard, file.read(None) would pull the entire remainder
        # of the file into memory before struct.unpack failed on a None format.
        raise ValueError(f"Unsupported GGUF value type: {value_type!r}")

    return struct.unpack(type_str, file.read(bytes_length))[0]
def load_metadata(fname):
    """Read the non-array key/value metadata from a GGUF file.

    Array-valued entries are parsed only to advance the file position past
    them; they are not included in the returned dictionary.

    Args:
        fname: Path of the file to read.

    Returns:
        A dict mapping each metadata key (``str``) to its decoded value.

    Raises:
        ValueError: If the file does not start with the ``GGUF`` magic, or if
            it uses the version-1 layout that this parser cannot read.
        struct.error: If the file is truncated.
    """
    metadata = {}
    with open(fname, 'rb') as file:
        # Reject non-GGUF input up front instead of interpreting arbitrary
        # bytes as counts and lengths.
        if file.read(4) != b'GGUF':
            raise ValueError(f"{fname} is not a GGUF file (bad magic number)")

        version = struct.unpack("<I", file.read(4))[0]
        if version < 2:
            # GGUF v1 stores counts and lengths as 32-bit values, which the
            # 64-bit reads below would misinterpret.
            raise ValueError(f"Unsupported GGUF version {version} in {fname}")

        # Tensor-info count: read only to advance; the value is not needed.
        struct.unpack("<Q", file.read(8))
        kv_data_count = struct.unpack("<Q", file.read(8))[0]

        for _ in range(kv_data_count):
            key_length = struct.unpack("<Q", file.read(8))[0]
            key = file.read(key_length)

            value_type = GGUFValueType(struct.unpack("<I", file.read(4))[0])
            if value_type == GGUFValueType.ARRAY:
                # An array is: element type tag, element count, then the elements.
                ltype = GGUFValueType(struct.unpack("<I", file.read(4))[0])
                length = struct.unpack("<Q", file.read(8))[0]
                for _element in range(length):
                    # Consume and discard each element.
                    get_single(ltype, file)
            else:
                metadata[key.decode()] = get_single(value_type, file)

    return metadata

View File

@ -18,9 +18,9 @@ from transformers import (
)
import modules.shared as shared
from modules import llama_attn_hijack, RoPE, sampler_hijack
from modules import RoPE, llama_attn_hijack, sampler_hijack
from modules.logging_colors import logger
from modules.models_settings import infer_loader
from modules.models_settings import get_model_metadata
transformers.logging.set_verbosity_error()
@ -62,15 +62,11 @@ def load_model(model_name, loader=None):
'ctransformers': ctransformers_loader,
}
p = Path(model_name)
if p.exists():
model_name = p.parts[-1]
if loader is None:
if shared.args.loader is not None:
loader = shared.args.loader
else:
loader = infer_loader(model_name)
loader = get_model_metadata(model_name)['loader']
if loader is None:
logger.error('The path to the model does not exist. Exiting.')
return None, None

View File

@ -3,23 +3,57 @@ from pathlib import Path
import yaml
from modules import loaders, shared, ui
from modules import loaders, metadata_gguf, shared, ui
def get_model_settings_from_yamls(model):
settings = shared.model_config
def get_fallback_settings():
    """Build the fallback model settings.

    Returns:
        A new dict of default model parameters. Three entries are copied from
        the current ``shared.settings``; the rest are fixed defaults.
    """
    fallback = {
        'wbits': 'None',
        'model_type': 'None',
        'groupsize': 'None',
        'pre_layer': 0,
    }

    # These values follow whatever is currently set in shared.settings.
    for inherited_key in ('skip_special_tokens', 'custom_stopping_strings', 'truncation_length'):
        fallback[inherited_key] = shared.settings[inherited_key]

    fallback['n_ctx'] = 2048
    fallback['rope_freq_base'] = 0
    return fallback
def get_model_metadata(model):
    """Collect the settings that apply to ``model``.

    Settings come from three places, in order: every pattern in
    ``shared.model_config`` that matches the model name, an inferred loader
    when none is configured, and the GGUF file's own metadata for the
    llama.cpp-family loaders.

    Args:
        model: Name of the model, relative to ``shared.args.model_dir``.

    Returns:
        A dict of settings. It always contains a ``'loader'`` key (which may
        be ``None`` if ``infer_loader`` returned ``None``). ``'n_ctx'`` is
        overwritten with the file's ``llama.context_length`` when available.
    """
    model_settings = {}

    # Get settings from models/config.yaml and models/config-user.yaml
    settings = shared.model_config
    for pat in settings:
        if re.match(pat.lower(), model.lower()):
            # Later matching patterns override earlier ones key by key.
            model_settings.update(settings[pat])

    if 'loader' not in model_settings:
        loader = infer_loader(model, model_settings)
        wbits = model_settings.get('wbits')
        # Exact type check on purpose: the placeholder 'None' string and bools
        # must not count as a configured bit width.
        if type(wbits) is int and wbits > 0:
            loader = 'AutoGPTQ'

        model_settings['loader'] = loader

    # Read GGUF metadata
    if model_settings['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
        path = Path(f'{shared.args.model_dir}/{model}')
        if path.is_file():
            model_file = path
        else:
            # The directory may hold no .gguf file (or may not exist); skip
            # the metadata step in that case instead of raising IndexError.
            model_file = next(path.glob('*.gguf'), None)

        if model_file is not None:
            metadata = metadata_gguf.load_metadata(model_file)
            if 'llama.context_length' in metadata:
                model_settings['n_ctx'] = metadata['llama.context_length']

    return model_settings
def infer_loader(model_name):
def infer_loader(model_name, model_settings):
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
model_settings = get_model_settings_from_yamls(model_name)
if not path_to_model.exists():
loader = None
elif Path(f'{shared.args.model_dir}/{model_name}/quantize_config.json').exists() or ('wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0):
@ -85,11 +119,9 @@ def update_model_parameters(state, initial=False):
# UI: update the state variable with the model settings
def apply_model_settings_to_state(model, state):
model_settings = get_model_settings_from_yamls(model)
if 'loader' not in model_settings:
loader = infer_loader(model)
if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
loader = 'AutoGPTQ'
model_settings = get_model_metadata(model)
if 'loader' in model_settings:
loader = model_settings.pop('loader')
# If the user is using an alternative loader for the same model type, let them keep using it
if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlama', 'ExLlama_HF']) and not (loader == 'llama.cpp' and state['loader'] in ['llamacpp_HF', 'ctransformers']):

View File

@ -15,7 +15,7 @@ from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import (
apply_model_settings_to_state,
get_model_settings_from_yamls,
get_model_metadata,
save_model_settings,
update_model_parameters
)
@ -196,7 +196,7 @@ def load_model_wrapper(selected_model, loader, autoload=False):
if shared.model is not None:
output = f"Successfully loaded `{selected_model}`."
settings = get_model_settings_from_yamls(selected_model)
settings = get_model_metadata(selected_model)
if 'instruction_template' in settings:
output += '\n\nIt seems to be an instruction-following model with template "{}". In the chat tab, instruct or chat-instruct modes should be used.'.format(settings['instruction_template'])

View File

@ -1,8 +1,8 @@
import os
import warnings
from modules.logging_colors import logger
from modules.block_requests import OpenMonkeyPatch, RequestBlocker
from modules.logging_colors import logger
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
@ -12,6 +12,7 @@ with RequestBlocker():
import gradio as gr
import matplotlib
matplotlib.use('Agg') # This fixes LaTeX rendering on some systems
import json
@ -37,13 +38,14 @@ from modules import (
ui_notebook,
ui_parameters,
ui_session,
utils,
utils
)
from modules.extensions import apply_extensions
from modules.LoRA import add_lora_to_model
from modules.models import load_model
from modules.models_settings import (
get_model_settings_from_yamls,
get_fallback_settings,
get_model_metadata,
update_model_parameters
)
from modules.utils import gradio
@ -169,17 +171,7 @@ if __name__ == "__main__":
shared.settings.update(new_settings)
# Fallback settings for models
shared.model_config['.*'] = {
'wbits': 'None',
'model_type': 'None',
'groupsize': 'None',
'pre_layer': 0,
'skip_special_tokens': shared.settings['skip_special_tokens'],
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
'truncation_length': shared.settings['truncation_length'],
'rope_freq_base': 0,
}
shared.model_config['.*'] = get_fallback_settings()
shared.model_config.move_to_end('.*', last=False) # Move to the beginning
# Activate the extensions listed on settings.yaml
@ -213,12 +205,18 @@ if __name__ == "__main__":
# If any model has been selected, load it
if shared.model_name != 'None':
model_settings = get_model_settings_from_yamls(shared.model_name)
p = Path(shared.model_name)
if p.exists():
model_name = p.parts[-1]
else:
model_name = shared.model_name
model_settings = get_model_metadata(model_name)
shared.settings.update(model_settings) # hijacking the interface defaults
update_model_parameters(model_settings, initial=True) # hijacking the command-line arguments
# Load the model
shared.model, shared.tokenizer = load_model(shared.model_name)
shared.model, shared.tokenizer = load_model(model_name)
if shared.args.lora:
add_lora_to_model(shared.args.lora)