Add --rwkv-strategy parameter

This commit is contained in:
oobabooga 2023-03-01 20:02:48 -03:00
parent 99dc95e14e
commit a2a3e8f797
2 changed files with 5 additions and 1 deletions

View File

@ -25,7 +25,10 @@ class RWKVModel:
def from_pretrained(self, path, dtype="fp16", device="cuda"): def from_pretrained(self, path, dtype="fp16", device="cuda"):
tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json") tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
model = RWKV(model=os.path.abspath(path), strategy=f'{device} {dtype}') if shared.args.rwkv_strategy is None:
model = RWKV(model=os.path.abspath(path), strategy=f'{device} {dtype}')
else:
model = RWKV(model=os.path.abspath(path), strategy=shared.args.rwkv_strategy)
pipeline = PIPELINE(model, os.path.abspath(tokenizer_path)) pipeline = PIPELINE(model, os.path.abspath(tokenizer_path))
result = self() result = self()

View File

@ -63,6 +63,7 @@ parser.add_argument("--compress-weight", action="store_true", help="FlexGen: act
parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.') parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.') parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.') parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
parser.add_argument('--rwkv-strategy', type=str, default=None, help='The strategy to use while loading RWKV models. Examples: "cpu fp32", "cuda fp16", "cuda fp16 *30 -> cpu fp32".')
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.') parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This improves the text generation performance.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.') parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.') parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')