mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Disable kernel threshold for gpt-j
This commit is contained in:
parent
1ac003d41c
commit
41ec682834
@ -14,7 +14,7 @@ import llama_inference_offload
|
|||||||
from quant import make_quant
|
from quant import make_quant
|
||||||
from modelutils import find_layers
|
from modelutils import find_layers
|
||||||
|
|
||||||
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head']):
|
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
|
||||||
config = AutoConfig.from_pretrained(model)
|
config = AutoConfig.from_pretrained(model)
|
||||||
def noop(*args, **kwargs):
|
def noop(*args, **kwargs):
|
||||||
pass
|
pass
|
||||||
@ -32,7 +32,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
|||||||
for name in exclude_layers:
|
for name in exclude_layers:
|
||||||
if name in layers:
|
if name in layers:
|
||||||
del layers[name]
|
del layers[name]
|
||||||
make_quant(model, layers, wbits, groupsize, faster=faster_kernel)
|
make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
|
||||||
|
|
||||||
del layers
|
del layers
|
||||||
|
|
||||||
@ -109,7 +109,8 @@ def load_quantized(model_name):
|
|||||||
if shared.args.pre_layer:
|
if shared.args.pre_layer:
|
||||||
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
|
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
|
||||||
else:
|
else:
|
||||||
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize)
|
threshold = False if model_type == 'gptj' else 128
|
||||||
|
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
|
||||||
|
|
||||||
# accelerate offload (doesn't work properly)
|
# accelerate offload (doesn't work properly)
|
||||||
if shared.args.gpu_memory:
|
if shared.args.gpu_memory:
|
||||||
|
Loading…
Reference in New Issue
Block a user