From 659bb76722fb0dec8932839e339070e80ae2c987 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 1 Mar 2023 12:08:55 -0300 Subject: [PATCH] Add RWKVModel class --- modules/RWKV.py | 19 ++++++++++++++----- modules/models.py | 6 ++++-- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/modules/RWKV.py b/modules/RWKV.py index b7388ea7..c4481043 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -16,11 +16,20 @@ os.environ["RWKV_CUDA_ON"] = '0' # '1' : use CUDA kernel for seq mode (much fas from rwkv.model import RWKV from rwkv.utils import PIPELINE, PIPELINE_ARGS +class RWKVModel: + def __init__(self): + pass -def load_RWKV_model(path): - print(f'strategy={"cpu" if shared.args.cpu else "cuda"} {"fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16"}') + @classmethod + def from_pretrained(self, path, dtype="fp16", device="cuda"): + tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json") - model = RWKV(model=path.as_posix(), strategy=f'{"cpu" if shared.args.cpu else "cuda"} {"fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16"}') - pipeline = PIPELINE(model, Path("models/20B_tokenizer.json").as_posix()) + model = RWKV(model=path.as_posix(), strategy=f'{device} {dtype}') + pipeline = PIPELINE(model, tokenizer_path.as_posix()) - return pipeline + result = self() + result.model = pipeline + return result + + def generate(self, context, **kwargs): + return self.model.generate(context, **kwargs) diff --git a/modules/models.py b/modules/models.py index b3e4b8e0..955ade0b 100644 --- a/modules/models.py +++ b/modules/models.py @@ -79,9 +79,11 @@ def load_model(model_name): # RMKV model (not on HuggingFace) elif shared.is_RWKV: - from modules.RWKV import load_RWKV_model + from modules.RWKV import RWKVModel - return load_RWKV_model(Path(f'models/{model_name}')), None + model = RWKVModel.from_pretrained(Path(f'models/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda") + + return model, None # Custom else: