EADST

Save the LLaMA Model with LoRA to One Model

The script below merges one or more LoRA adapters into a LLaMA base model and saves the result as a single model.

"""
Usage: 
python merge_llama_with_chinese_lora.py \
    --base_model path/to/llama/model \
    --lora_model path/to/first/lora/model [path/to/second/lora/model] \
    --output_type [pth|huggingface] \
    --output_dir path/to/output/dir
"""
import argparse
import json
import os
import gc
import torch
import peft
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer
from huggingface_hub import hf_hub_download

# Command-line interface. --base_model and --lora_model are mandatory; the
# remaining options fall back to the defaults given here.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--base_model',
    type=str,
    required=True,
    default=None,
    help="Please specify a base_model",
)
parser.add_argument(
    '--lora_model',
    type=str,
    required=True,
    default=None,
    help="Please specify LoRA models to be merged (ordered); use commas to separate multiple LoRA models.",
)
parser.add_argument(
    '--offload_dir',
    type=str,
    default=None,
    help="(Optional) Please specify a temp folder for offloading (useful for low-RAM machines). Default None (disable offload).",
)
parser.add_argument(
    '--output_type',
    type=str,
    choices=['pth', 'huggingface'],
    default='pth',
    help="save the merged model in pth or huggingface format.",
)
parser.add_argument(
    '--output_dir',
    type=str,
    default='./',
)


# One row per supported model size:
# (size label, hidden dim, attention heads, layers, norm epsilon, shard count)
_LLAMA_VARIANTS = (
    ('7B', 4096, 32, 32, 1e-06, 1),
    ('13B', 5120, 40, 40, 1e-06, 2),
    ('33B', 6656, 52, 60, 1e-06, 4),
    ('65B', 8192, 64, 80, 1e-05, 8),
)

# Embedding width -> size label; used to infer the model size from a checkpoint.
emb_to_model_size = {row[1]: row[0] for row in _LLAMA_VARIANTS}

# Size label -> number of consolidated.*.pth shards written for that size.
num_shards_of_models = {row[0]: row[5] for row in _LLAMA_VARIANTS}

# Size label -> dictionary dumped verbatim to params.json.
params_of_models = {
    label: {
        "dim": width,
        "multiple_of": 256,
        "n_heads": heads,
        "n_layers": layers,
        "norm_eps": eps,
        "vocab_size": -1,
    }
    for label, width, heads, layers, eps, _shards in _LLAMA_VARIANTS
}

def transpose(weight, fan_in_fan_out):
    """Return ``weight.T`` when ``fan_in_fan_out`` is truthy, else ``weight`` itself."""
    if fan_in_fan_out:
        return weight.T
    return weight

# Borrowed and modified from https://github.com/tloen/alpaca-lora
# Keys that are renamed as a whole (no layer index involved).
_TOP_LEVEL_KEY_MAP = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
}

# (suffix of a "model.layers.<n>..." key, name it gets below "layers.<n>.").
_LAYER_SUFFIX_MAP = (
    (".self_attn.q_proj.weight", "attention.wq.weight"),
    (".self_attn.k_proj.weight", "attention.wk.weight"),
    (".self_attn.v_proj.weight", "attention.wv.weight"),
    (".self_attn.o_proj.weight", "attention.wo.weight"),
    (".mlp.gate_proj.weight", "feed_forward.w1.weight"),
    (".mlp.down_proj.weight", "feed_forward.w2.weight"),
    (".mlp.up_proj.weight", "feed_forward.w3.weight"),
    (".input_layernorm.weight", "attention_norm.weight"),
    (".post_attention_layernorm.weight", "ffn_norm.weight"),
)


def translate_state_dict_key(k):
    """Translate a model state-dict key into the name used in the .pth output.

    Every occurrence of ``"base_model.model."`` is removed from ``k`` first.

    Args:
        k: State-dict key, e.g. ``"model.layers.3.self_attn.q_proj.weight"``.

    Returns:
        The translated key (e.g. ``"layers.3.attention.wq.weight"``), or
        ``None`` for per-layer keys that are not written out: those ending in
        ``"rotary_emb.inv_freq"`` and those containing ``"lora"``.

    Raises:
        NotImplementedError: if the key matches none of the known patterns.
            The message names the offending key.
    """
    k = k.replace("base_model.model.", "")

    if k in _TOP_LEVEL_KEY_MAP:
        return _TOP_LEVEL_KEY_MAP[k]

    if not k.startswith("model.layers."):
        raise NotImplementedError(f"Unsupported state-dict key: {k}")

    # "model.layers.<n>.rest" -> the layer index is the third dotted field.
    layer = k.split(".")[2]
    for suffix, target in _LAYER_SUFFIX_MAP:
        if k.endswith(suffix):
            return f"layers.{layer}.{target}"

    if k.endswith("rotary_emb.inv_freq") or "lora" in k:
        return None

    raise NotImplementedError(f"Unsupported key in layer {layer}: {k}")


def unpermute(w, n_heads=None, dim=None):
    """Reorder the rows of a ``(dim, dim)`` projection weight.

    ``w`` is viewed as ``(n_heads, 2, dim // n_heads // 2, dim)``, axes 1 and 2
    are swapped, and the result is reshaped back to ``(dim, dim)``.
    NOTE(review): presumably this reverses the q/k row permutation applied by
    the Hugging Face LLaMA conversion script -- confirm against that script.

    Args:
        w: Tensor holding ``dim * dim`` elements.
        n_heads: Number of attention heads. When omitted, the module-level
            ``n_heads`` set by the script's ``__main__`` section is used.
        dim: Model width. When omitted, the module-level ``dim`` is used.

    Returns:
        A ``(dim, dim)`` tensor with the rows reordered.

    Raises:
        NameError: if a value is omitted and the matching module-level
            variable has not been defined.
    """
    if n_heads is None or dim is None:
        # Backward-compatible path: callers that pass only ``w`` rely on the
        # globals assigned when this file runs as a script.
        module_vars = globals()
        try:
            if n_heads is None:
                n_heads = module_vars["n_heads"]
            if dim is None:
                dim = module_vars["dim"]
        except KeyError as err:
            raise NameError(f"name {err.args[0]!r} is not defined") from err

    half_head_dim = dim // n_heads // 2
    grouped = w.view(n_heads, 2, half_head_dim, dim)
    return grouped.transpose(1, 2).reshape(dim, dim)


def _split_for_shards(new_k, v, num_shards):
    """Return the per-shard pieces of tensor ``v`` stored under key ``new_k``.

    Norm weights are replicated to every shard; all other weights are split
    along dim 0 or dim 1 depending on the key.

    Raises:
        ValueError: if ``new_k`` is not a recognised key.
    """
    if new_k == 'tok_embeddings.weight':
        assert v.size(1) % num_shards == 0
        return v.split(v.size(1) // num_shards, dim=1)

    if new_k == 'output.weight':
        rows = v.size(0)
        if rows % num_shards == 0:
            return v.split(rows // num_shards, dim=0)
        # Uneven row count: the last shard takes the remainder,
        # e.g. 13B: size_list == [24976, 24977].
        size_list = [rows // num_shards] * num_shards
        size_list[-1] += rows % num_shards
        return v.split(size_list, dim=0)

    if (new_k == 'norm.weight'
            or 'ffn_norm.weight' in new_k
            or 'attention_norm.weight' in new_k):
        return [v] * num_shards

    if 'w1.weight' in new_k or 'w3.weight' in new_k or 'wv.weight' in new_k:
        return v.split(v.size(0) // num_shards, dim=0)

    if 'w2.weight' in new_k or 'wo.weight' in new_k:
        return v.split(v.size(1) // num_shards, dim=1)

    if 'wq.weight' in new_k or 'wk.weight' in new_k:
        v = unpermute(v)
        return v.split(v.size(0) // num_shards, dim=0)

    raise ValueError(f"Unexpected key {new_k}")


def save_shards(model_sd, num_shards: int):
    """Write ``model_sd`` as ``consolidated.NN.pth`` shard files plus ``params.json``.

    Keys are renamed with ``translate_state_dict_key`` (keys it maps to
    ``None`` are dropped) and the wq/wk weights go through ``unpermute``.
    Files are written into the module-level ``output_dir``; the module-level
    ``params`` dict is dumped to ``params.json``. Both globals are assigned by
    the script's ``__main__`` section.

    Args:
        model_sd: Mapping from state-dict key to tensor. In the multi-shard
            case its entries are deleted one by one while they are processed,
            so the mapping is empty when the function returns.
        num_shards: Number of shard files to write.

    Raises:
        ValueError: if a key is unexpected, or a tensor does not split into
            exactly ``num_shards`` pieces.
        NotImplementedError: propagated from ``translate_state_dict_key``.
    """
    # Only tensor bookkeeping happens here; autograd tracking is not wanted.
    with torch.no_grad():
        if num_shards == 1:
            new_state_dict = {}
            for k, v in model_sd.items():
                new_k = translate_state_dict_key(k)
                if new_k is None:
                    continue
                if "wq" in new_k or "wk" in new_k:
                    v = unpermute(v)
                new_state_dict[new_k] = v

            os.makedirs(output_dir, exist_ok=True)
            shard_path = os.path.join(output_dir, "consolidated.00.pth")
            print(f"Saving shard 1 of {num_shards} into {shard_path}")
            torch.save(new_state_dict, shard_path)
        else:
            new_state_dicts = [dict() for _ in range(num_shards)]
            # Iterate over a snapshot of the keys: entries are removed from
            # model_sd inside the loop to release memory as early as possible.
            for k in list(model_sd.keys()):
                v = model_sd[k]
                new_k = translate_state_dict_key(k)
                if new_k is not None:
                    print(f"Processing {new_k}")
                    splits = _split_for_shards(new_k, v, num_shards)
                    # zip() below would silently drop surplus pieces (a
                    # dimension not divisible by num_shards produces one
                    # extra chunk), so insist on an exact match.
                    if len(splits) != num_shards:
                        raise ValueError(
                            f"{new_k} with shape {tuple(v.shape)} was split into "
                            f"{len(splits)} pieces, expected {num_shards}"
                        )
                    for sd, split in zip(new_state_dicts, splits):
                        # clone() detaches the piece from the full tensor's
                        # storage so the full tensor can be freed.
                        sd[new_k] = split.clone()
                        del split
                    del splits
                del model_sd[k], v
                gc.collect()    # Effectively enforce garbage collection

            os.makedirs(output_dir, exist_ok=True)
            for i, new_state_dict in enumerate(new_state_dicts):
                # Zero-pad to two digits so names stay well-formed past 9 shards.
                shard_path = os.path.join(output_dir, f"consolidated.{i:02d}.pth")
                print(f"Saving shard {i+1} of {num_shards} into {shard_path}")
                torch.save(new_state_dict, shard_path)

        params_path = os.path.join(output_dir, "params.json")
        print(f"Saving params.json into {params_path}")
        with open(params_path, "w") as f:
            json.dump(params, f)


if __name__=='__main__':
    # Script entry point: load the base model, merge every LoRA model into it
    # in the order given, then save tokenizer and merged weights to output_dir.

    args = parser.parse_args()
    base_model_path = args.base_model
    lora_model_paths = [s.strip() for s in args.lora_model.split(',') if len(s.strip())!=0]
    output_dir = args.output_dir
    output_type = args.output_type
    offload_dir = args.offload_dir

    # A value such as "," yields no paths. Reject it before the (expensive)
    # base-model load; otherwise the script only fails at the very end with a
    # NameError on `tokenizer`.
    if not lora_model_paths:
        parser.error("--lora_model must name at least one LoRA model")

    print(f"Base model: {base_model_path}")
    print(f"LoRA model(s) {lora_model_paths}:")

    if offload_dir is not None:
        # Load with offloading, which is useful for low-RAM machines.
        # Note that if you have enough RAM, please use original method instead, as it is faster.
        base_model = LlamaForCausalLM.from_pretrained(
            base_model_path,
            load_in_8bit=False,
            torch_dtype=torch.float16,
            offload_folder=offload_dir,
            offload_state_dict=True,
            low_cpu_mem_usage=True,
            device_map={"": "cpu"},
        )
    else:
        # Original method without offloading
        base_model = LlamaForCausalLM.from_pretrained(
            base_model_path,
            load_in_8bit=False,
            torch_dtype=torch.float16,
            device_map={"": "cpu"},
        )

    ## infer the model size from the checkpoint
    embedding_size = base_model.get_input_embeddings().weight.size(1)
    model_size = emb_to_model_size[embedding_size]
    print(f"Peft version: {peft.__version__}")
    print(f"Loading LoRA for {model_size} model")

    lora_model = None
    lora_model_sd = None
    for lora_index, lora_model_path in enumerate(lora_model_paths):
        print(f"Loading LoRA {lora_model_path}...")
        tokenizer = LlamaTokenizer.from_pretrained(lora_model_path)
        print(f"base_model vocab size: {base_model.get_input_embeddings().weight.size(0)}")
        print(f"tokenizer vocab size: {len(tokenizer)}")

        model_vocab_size = base_model.get_input_embeddings().weight.size(0)
        assert len(tokenizer) >= model_vocab_size, \
        (f"The vocab size of the tokenizer {len(tokenizer)} is smaller than the vocab size of the base model {model_vocab_size}\n"
        "This is not the intended use. Please check your model and tokenizer.")
        if model_vocab_size != len(tokenizer):
            base_model.resize_token_embeddings(len(tokenizer))
            print(f"Extended vocabulary size to {len(tokenizer)}")

        # Snapshot one weight so we can check below that merging changed it.
        first_weight = base_model.model.layers[0].self_attn.q_proj.weight
        first_weight_old = first_weight.clone()

        print("Loading LoRA weights")
        if hasattr(peft.LoraModel,'merge_and_unload'):
            # This peft version can merge by itself.
            try:
                lora_model = PeftModel.from_pretrained(
                    base_model,
                    lora_model_path,
                    device_map={"": "cpu"},
                    torch_dtype=torch.float16,
                )
            except RuntimeError as e:
                if '[49953, 4096]' in str(e):
                    print("The vocab size of the tokenizer does not match the vocab size of the LoRA weight. \n"
                           "Did you misuse the LLaMA tokenizer with the Alpaca-LoRA weight?\n"
                           "Make sure that you use LLaMA tokenizer with the LLaMA-LoRA weight and Alpaca tokenizer with the Alpaca-LoRA weight!")
                raise
            # Wrapping the model must not have altered the base weight yet.
            assert torch.allclose(first_weight_old, first_weight)
            print("Merging with merge_and_unload...")
            base_model = lora_model.merge_and_unload()
        else:
            # Fallback: merge manually from the adapter's state dict.
            base_model_sd = base_model.state_dict()
            try:
                lora_model_sd = torch.load(os.path.join(lora_model_path,'adapter_model.bin'),map_location='cpu')
            except FileNotFoundError:
                print("Cannot find lora model on the disk. Downloading lora model from hub...")
                filename = hf_hub_download(repo_id=lora_model_path,filename='adapter_model.bin')
                lora_model_sd = torch.load(filename,map_location='cpu')
            if 'base_model.model.model.embed_tokens.weight' in lora_model_sd:
                assert lora_model_sd['base_model.model.model.embed_tokens.weight'].shape[0]==len(tokenizer), \
                ("The vocab size of the tokenizer does not match the vocab size of the LoRA weight. \n"
                "Did you misuse the LLaMA tokenizer with the Alpaca-LoRA weight?\n"
                "Make sure that you use LLaMA tokenizer with the LLaMA-LoRA weight and Alpaca tokenizer with the Alpaca-LoRA weight!")

            lora_config = peft.LoraConfig.from_pretrained(lora_model_path)
            lora_scaling = lora_config.lora_alpha / lora_config.r
            fan_in_fan_out = lora_config.fan_in_fan_out
            lora_keys = [k for k in lora_model_sd if 'lora_A' in k]
            non_lora_keys = [k for k in lora_model_sd if not 'lora_' in k]

            # Tensors stored in full in the adapter overwrite the base tensors.
            for k in non_lora_keys:
                print(f"merging {k}")
                original_k = k.replace('base_model.model.','')
                base_model_sd[original_k].copy_(lora_model_sd[k])

            # Low-rank pairs: add (B @ A) * scaling onto the base weight in place.
            for k in lora_keys:
                print(f"merging {k}")
                original_key = k.replace('.lora_A','').replace('base_model.model.','')
                assert original_key in base_model_sd
                lora_a_key = k
                lora_b_key = k.replace('lora_A','lora_B')
                base_model_sd[original_key] += (
                    transpose(lora_model_sd[lora_b_key].float() @ lora_model_sd[lora_a_key].float(),fan_in_fan_out) * lora_scaling
                )
                assert base_model_sd[original_key].dtype == torch.float16

        # did we do anything?
        assert not torch.allclose(first_weight_old, first_weight)

    # The tokenizer of the last LoRA model in the list is the one saved.
    tokenizer.save_pretrained(output_dir)

    if output_type=='huggingface':
        print("Saving to Hugging Face format...")
        LlamaForCausalLM.save_pretrained(base_model, output_dir) #, state_dict=deloreanized_sd)
    else: # output_type=='pth
        print("Saving to pth format...")

        base_model_sd = base_model.state_dict()
        del lora_model, base_model, lora_model_sd

        # `params`, `n_heads` and `dim` (and `output_dir` above) must stay
        # module-level names: save_shards() and unpermute() read them as globals.
        params = params_of_models[model_size]
        num_shards = num_shards_of_models[model_size]
        n_heads = params["n_heads"]
        dim = params["dim"]

        save_shards(model_sd=base_model_sd, num_shards=num_shards)

Reference:

merge_llama_with_chinese_lora.py

相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
Google 阿里云 Baidu 搞笑 Qwen 算法题 Transformers Breakpoint BTC 云服务器 LLM XGBoost AI Input HaggingFace v2ray transformers git-lfs Tiktoken Template ModelScope NLP Gemma mmap Pickle Math HuggingFace Ptyhon NameSilo Llama 论文速读 Land Animate COCO 图标 C++ Bitcoin Interview Data Password Crawler Tracking Shortcut Web Algorithm Quantization BF16 Paddle Windows SVR Pytorch 域名 强化学习 GPTQ v0.dev LaTeX Docker Ubuntu Streamlit 证件照 财报 Sklearn Website NLTK Datetime 关于博主 Logo diffusers Proxy Distillation TensorFlow FP64 torchinfo 继承 Github UNIX OpenAI Firewall MD5 Freesound Jupyter Anaconda 音频 OCR printf XML PIP Food 公式 Git hf Translation llama.cpp DeepSeek RAR Image2Text LLAMA GPT4 Qwen2 WAN CLAP ChatGPT GoogLeNet Rebuttal 腾讯云 tqdm FastAPI icon Bipartite YOLO Permission VPN TensorRT Tensor VSCode Numpy Claude Knowledge VGG-16 Linux SQLite FlashAttention ResNet-50 LeetCode Domain Excel Pillow Magnet Quantize FP32 CSV Statistics Disk Color SPIE TSV uWSGI Pandas CAM git LoRA SAM Vim Hungarian CV 飞书 Mixtral Jetson FP8 Miniforge ONNX PyTorch GIT Markdown Michelin 第一性原理 OpenCV Review IndexTTS2 API InvalidArgumentError Nginx WebCrawler Agent 净利润 Clash CEIR Cloudreve TTS Heatmap BeautifulSoup QWEN Zip 多线程 Diagram JSON Python Dataset Base64 CTC Django 多进程 Video uwsgi PDF Plate Bert 版权 EXCEL Plotly Conda DeepStream 递归学习法 Attention CUDA Safetensors CC Use logger PyCharm Bin Qwen2.5 Random Hotel PDB Card SQL GGML 签证 tar RGB News 顶会 报税 Paper UI FP16 Vmess 图形思考法 scipy Augmentation Search Hilton
站点统计

本站现有博文326篇,共被浏览825510

本站已经建立2531天!

热门文章
文章归档
回到顶部