From aeecd8c0add705c8c5037f5039aed7ebb4522674 Mon Sep 17 00:00:00 2001
From: Xuechen Li
Date: Thu, 13 Apr 2023 10:52:08 -0700
Subject: [PATCH] add weight diff conversion.

---
 weight_diff.py | 157 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 157 insertions(+)
 create mode 100644 weight_diff.py

diff --git a/weight_diff.py b/weight_diff.py
new file mode 100644
index 0000000..108d4a0
--- /dev/null
+++ b/weight_diff.py
@@ -0,0 +1,157 @@
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import fire
+import torch
+import tqdm
+import transformers
+from train import smart_tokenizer_and_embedding_resize
+
+
+@torch.inference_mode()
+def make_diff(
+    path_raw: str, path_tuned: str, path_diff: str, device="cpu",  # "cuda" or "cpu"
+):
+    """Make the weight diff.
+
+    This function is provided for full transparency about how the weight diff was created.
+
+    Run:
+        python weight_diff.py make_diff --path_raw <path_raw> --path_tuned <path_tuned> --path_diff <path_diff>
+    """
+    model_tuned: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
+        path_tuned,
+        device_map={"": torch.device(device)},
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True,
+    )
+    model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
+        path_raw,
+        device_map={"": torch.device(device)},
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True,
+    )
+
+    tokenizer_tuned: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
+        path_tuned
+    )
+    tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
+        path_raw
+    )
+    if tokenizer_raw.pad_token is None:
+        smart_tokenizer_and_embedding_resize(
+            special_tokens_dict=dict(pad_token="[PAD]"),
+            model=model_raw,
+            tokenizer=tokenizer_raw,
+        )
+
+    state_dict_tuned = model_tuned.state_dict()
+    state_dict_raw = model_raw.state_dict()
+    for key in tqdm.tqdm(state_dict_tuned):
+        state_dict_tuned[key].add_(-state_dict_raw[key])
+
+    model_tuned.save_pretrained(path_diff)
+    tokenizer_tuned.save_pretrained(path_diff)
+
+
+@torch.inference_mode()
+def recover(
+    path_raw,
+    path_diff,
+    path_tuned: Optional[str] = None,
+    device="cpu",
+    test_inference=True,
+    check_integrity_naively=True,
+):
+    """Recover the original weights from the released weight diff.
+
+    This function is intended for you to run.
+
+    Things to do before running this:
+    1. Convert Meta's released weights into huggingface format. Follow this guide:
+        https://huggingface.co/docs/transformers/main/model_doc/llama
+    2. Make sure you cloned the released weight diff onto your local machine. The weight diff is located at:
+        https://huggingface.co/tatsu-lab/alpaca-7b/tree/main
+    3. Run this function with the correct paths. E.g.,
+        python weight_diff.py recover --path_raw <path_raw> --path_diff <path_diff>
+
+    Additional notes:
+    - If things run too slowly, and you have an 80G GPU lying around, let GPU go brrr by setting `--device "cuda"`.
+    - If you want to save the recovered weights, set `--path_tuned <path_tuned>`.
+        Next time you can load the recovered weights directly from `<path_tuned>`.
+    """
+    model_raw: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
+        path_raw,
+        device_map={"": torch.device(device)},
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True,
+    )
+    model_recovered: transformers.PreTrainedModel = transformers.AutoModelForCausalLM.from_pretrained(
+        path_diff,
+        device_map={"": torch.device(device)},
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True,
+    )
+
+    tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
+        path_raw
+    )
+    if tokenizer_raw.pad_token is None:
+        smart_tokenizer_and_embedding_resize(
+            special_tokens_dict=dict(pad_token="[PAD]"),
+            model=model_raw,
+            tokenizer=tokenizer_raw,
+        )
+    tokenizer_recovered: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
+        path_diff
+    )
+
+    state_dict_recovered = model_recovered.state_dict()
+    state_dict_raw = model_raw.state_dict()
+    for key in tqdm.tqdm(state_dict_recovered):
+        state_dict_recovered[key].add_(state_dict_raw[key])
+
+    if check_integrity_naively:
+        # This is not a rigorous, cryptographically strong integrity check :)
+        allsum = sum(state_dict_recovered[key].sum() for key in state_dict_recovered)
+        assert torch.allclose(
+            allsum, torch.full_like(allsum, fill_value=50637.1836), atol=1e-2, rtol=0
+        ), "Naive integrity check failed. This could imply that some of the checkpoint files are corrupted."
+
+    if path_tuned is not None:
+        model_recovered.save_pretrained(path_tuned)
+
+    if test_inference:
+        input_text = (
+            "Below is an instruction that describes a task. "
+            "Write a response that appropriately completes the request.\r\n\r\n"
+            "### Instruction:\r\nList three technologies that make life easier.\r\n\r\n### Response:"
+        )
+        inputs = tokenizer_recovered(input_text, return_tensors="pt")
+        out = model_recovered.generate(inputs=inputs.input_ids, max_new_tokens=100)
+        output_text = tokenizer_recovered.batch_decode(out, skip_special_tokens=True)[0]
+        output_text = output_text[len(input_text) :]
+        print(f"Input: {input_text}\nCompletion: {output_text}")
+
+    return model_recovered, tokenizer_recovered
+
+
+def main(task, **kwargs):
+    globals()[task](**kwargs)
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
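
Note: the core of both functions in this patch is an elementwise subtract-then-add over the model state dict. Below is a minimal, self-contained sketch of that round trip on toy tensors; the names toy_raw and toy_tuned are illustrative only, and the real functions walk full transformers state dicts rather than this two-entry dict:

    import torch

    # Toy stand-ins for the raw and fine-tuned state dicts.
    toy_raw = {"w": torch.randn(4, 4), "b": torch.randn(4)}
    toy_tuned = {k: v + 0.01 * torch.randn_like(v) for k, v in toy_raw.items()}

    # make_diff: subtract the raw weights in place; what remains is the diff.
    diff = {k: v.clone() for k, v in toy_tuned.items()}
    for k in diff:
        diff[k].add_(-toy_raw[k])

    # recover: add the raw weights back onto the released diff.
    recovered = {k: v.clone() for k, v in diff.items()}
    for k in recovered:
        recovered[k].add_(toy_raw[k])

    # The round trip reproduces the tuned weights up to float32 rounding.
    for k in toy_tuned:
        assert torch.allclose(recovered[k], toy_tuned[k], atol=1e-6)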
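
Likewise, recover() can be driven from Python instead of the fire CLI, using the signature defined above. A sketch under the assumption that the HF-format LLaMA weights and the cloned diff already exist locally; all paths are placeholders:

    from weight_diff import recover

    # All paths below are placeholders; point them at your local copies.
    model, tokenizer = recover(
        path_raw="./llama-7b-hf",      # Meta weights converted to HF format (step 1).
        path_diff="./alpaca-7b-diff",  # local clone of tatsu-lab/alpaca-7b (step 2).
        path_tuned="./alpaca-7b",      # optional: also save the recovered weights here.
        device="cpu",
        test_inference=False,          # skip the sample generation at the end.
    )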