EADST

Merge a LLaMA Model with LoRA Weights into a Single Model

Merge a LLaMA Model with LoRA Weights into a Single Model

"""
Usage: 
python merge_llama_with_chinese_lora.py \
    --base_model path/to/llama/model \
    --lora_model path/to/first/lora/model[,path/to/second/lora/model] \
    --output_type [pth|huggingface] \
    --output_dir path/to/output/dir
"""
import argparse
import json
import os
import gc
import torch
import peft
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer
from huggingface_hub import hf_hub_download

# Command-line interface of the merge script.
parser = argparse.ArgumentParser()

# Path (or hub id) of the base LLaMA checkpoint; mandatory.
parser.add_argument(
    '--base_model',
    type=str,
    required=True,
    default=None,
    help="Please specify a base_model",
)

# One or more LoRA checkpoints, comma separated, merged in the given order.
parser.add_argument(
    '--lora_model',
    type=str,
    required=True,
    default=None,
    help="Please specify LoRA models to be merged (ordered); use commas to separate multiple LoRA models.",
)

# Optional scratch folder; when given, the base model is loaded with offloading.
parser.add_argument(
    '--offload_dir',
    type=str,
    default=None,
    help="(Optional) Please specify a temp folder for offloading (useful for low-RAM machines). Default None (disable offload).",
)

# Output layout: sharded .pth files or a Hugging Face checkpoint.
parser.add_argument(
    '--output_type',
    type=str,
    choices=['pth', 'huggingface'],
    default='pth',
    help="save the merged model in pth or huggingface format.",
)

# Destination directory for the merged model and tokenizer.
parser.add_argument(
    '--output_dir',
    type=str,
    default='./',
)


# Hidden (embedding) width of a checkpoint -> model size label.
emb_to_model_size = {4096: '7B', 5120: '13B', 6656: '33B', 8192: '65B'}

# Number of consolidated.NN.pth shards written for each model size.
num_shards_of_models = dict(zip(('7B', '13B', '33B', '65B'), (1, 2, 4, 8)))

# Contents of params.json for each model size. Key order is kept stable
# because the dict is dumped to JSON as-is.
params_of_models = {
    size_label: {
        "dim": hidden_dim,
        "multiple_of": 256,
        "n_heads": head_count,
        "n_layers": layer_count,
        "norm_eps": eps,
        "vocab_size": -1,
    }
    for size_label, hidden_dim, head_count, layer_count, eps in (
        ('7B', 4096, 32, 32, 1e-06),
        ('13B', 5120, 40, 40, 1e-06),
        ('33B', 6656, 52, 60, 1e-06),
        ('65B', 8192, 64, 80, 1e-05),
    )
}

def transpose(weight, fan_in_fan_out):
    """Return ``weight.T`` when ``fan_in_fan_out`` is truthy, otherwise ``weight`` itself."""
    if fan_in_fan_out:
        return weight.T
    return weight

# Borrowed and modified from https://github.com/tloen/alpaca-lora
def translate_state_dict_key(k):
    """Map a Hugging Face-style LLaMA parameter name to its ``.pth`` name.

    Every ``"base_model.model."`` occurrence is stripped from ``k`` first.

    Returns:
        The translated name, or ``None`` for per-layer keys that end with
        ``rotary_emb.inv_freq`` or contain ``"lora"`` (these are dropped by
        the caller).

    Raises:
        NotImplementedError: The key is not recognised. The offending key is
            printed before raising.
    """
    # Names that live outside the transformer layers.
    top_level_names = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }
    # (suffix of the source key, name inside "layers.<n>." of the target key)
    per_layer_names = (
        (".self_attn.q_proj.weight", "attention.wq.weight"),
        (".self_attn.k_proj.weight", "attention.wk.weight"),
        (".self_attn.v_proj.weight", "attention.wv.weight"),
        (".self_attn.o_proj.weight", "attention.wo.weight"),
        (".mlp.gate_proj.weight", "feed_forward.w1.weight"),
        (".mlp.down_proj.weight", "feed_forward.w2.weight"),
        (".mlp.up_proj.weight", "feed_forward.w3.weight"),
        (".input_layernorm.weight", "attention_norm.weight"),
        (".post_attention_layernorm.weight", "ffn_norm.weight"),
    )

    k = k.replace("base_model.model.", "")

    if k in top_level_names:
        return top_level_names[k]

    if not k.startswith("model.layers."):
        print(k)
        raise NotImplementedError

    # "model.layers.<n>.<...>" -> "<n>"
    layer = k.split(".")[2]
    for suffix, target in per_layer_names:
        if k.endswith(suffix):
            return f"layers.{layer}.{target}"

    # Rotary frequency buffers and LoRA adapter tensors have no counterpart.
    if k.endswith("rotary_emb.inv_freq") or "lora" in k:
        return None

    print(layer, k)
    raise NotImplementedError


def unpermute(w):
    """Regroup the rows of ``w`` by swapping two head-level axes.

    Reads the module-level ``n_heads`` and ``dim`` (assigned by the script's
    ``__main__`` section), so it can only be called after those are set.
    ``w`` is viewed as ``(n_heads, 2, dim // n_heads // 2, dim)``, axes 1 and 2
    are swapped, and the result is reshaped to ``(dim, dim)``.

    NOTE(review): presumably this undoes the q/k row permutation applied when
    the original weights were converted to the Hugging Face layout - confirm
    against that conversion script.
    """
    half_head_dim = dim // n_heads // 2
    grouped = w.view(n_heads, 2, half_head_dim, dim)
    return grouped.transpose(1, 2).reshape(dim, dim)


def save_shards(model_sd, num_shards: int):
    """Write ``model_sd`` as ``consolidated.NN.pth`` shard files plus ``params.json``.

    Keys are renamed with ``translate_state_dict_key`` (keys translated to
    ``None`` are dropped) and the ``wq``/``wk`` weights are passed through
    ``unpermute``. Files go to the module-level ``output_dir``; ``params.json``
    is dumped from the module-level ``params``.

    Args:
        model_sd: Mapping from parameter name to tensor. When ``num_shards`` is
            not 1, every entry is deleted from it while it is processed, so the
            caller's mapping ends up empty.
        num_shards: Number of ``.pth`` files to write.

    Raises:
        ValueError: A translated key matches none of the split rules
            (multi-shard path only).
        AssertionError: The second dimension of ``tok_embeddings.weight`` is not
            divisible by ``num_shards``.
        NotImplementedError: Propagated from ``translate_state_dict_key``.
    """
    # (substring of the translated key, axis to split along); probed in order.
    axis_by_marker = (
        ('w1.weight', 0),
        ('w2.weight', 1),
        ('w3.weight', 0),
        ('wo.weight', 1),
        ('wv.weight', 0),
    )

    # Pure tensor reshuffling; no autograd bookkeeping wanted.
    with torch.no_grad():
        if num_shards == 1:
            merged = {}
            for hf_key, tensor in model_sd.items():
                pth_key = translate_state_dict_key(hf_key)
                if pth_key is None:
                    continue
                needs_unpermute = "wq" in pth_key or "wk" in pth_key
                merged[pth_key] = unpermute(tensor) if needs_unpermute else tensor

            os.makedirs(output_dir, exist_ok=True)
            print(f"Saving shard 1 of {num_shards} into {output_dir}/consolidated.00.pth")
            torch.save(merged, output_dir + "/consolidated.00.pth")
            with open(output_dir + "/params.json", "w") as f:
                json.dump(params, f)
            return

        shard_dicts = [dict() for _ in range(num_shards)]
        # Iterate over a snapshot of the keys: entries are removed as we go.
        for hf_key in list(model_sd.keys()):
            tensor = model_sd[hf_key]
            pth_key = translate_state_dict_key(hf_key)
            if pth_key is not None:
                if pth_key == 'tok_embeddings.weight':
                    print(f"Processing {pth_key}")
                    assert tensor.size(1) % num_shards == 0
                    pieces = tensor.split(tensor.size(1) // num_shards, dim=1)
                elif pth_key == 'output.weight':
                    print(f"Processing {pth_key}")
                    per_shard, leftover = divmod(tensor.size(0), num_shards)
                    if leftover == 0:
                        pieces = tensor.split(per_shard, dim=0)
                    else:
                        # Uneven row count: the last shard takes the remainder.
                        row_counts = [per_shard] * num_shards
                        row_counts[-1] += leftover
                        pieces = tensor.split(row_counts, dim=0)  # 13B: row_counts == [24976,24977]
                elif (
                    pth_key == 'norm.weight'
                    or 'ffn_norm.weight' in pth_key
                    or 'attention_norm.weight' in pth_key
                ):
                    # Norm weights are not split; every shard gets the full tensor.
                    print(f"Processing {pth_key}")
                    pieces = [tensor] * num_shards
                else:
                    axis = next(
                        (ax for marker, ax in axis_by_marker if marker in pth_key),
                        None,
                    )
                    if axis is not None:
                        print(f"Processing {pth_key}")
                        pieces = tensor.split(tensor.size(axis) // num_shards, dim=axis)
                    elif "wq.weight" in pth_key or "wk.weight" in pth_key:
                        print(f"Processing {pth_key}")
                        tensor = unpermute(tensor)
                        pieces = tensor.split(tensor.size(0) // num_shards, dim=0)
                    else:
                        print(f"Unexpected key {pth_key}")
                        raise ValueError

                # clone() gives each shard its own storage so the source tensor
                # can be released below.
                for shard_dict, piece in zip(shard_dicts, pieces):
                    shard_dict[pth_key] = piece.clone()
                    del piece
                del pieces
            del model_sd[hf_key], tensor
            gc.collect()    # Effectively enforce garbage collection

        os.makedirs(output_dir, exist_ok=True)
        for shard_index, shard_dict in enumerate(shard_dicts):
            print(f"Saving shard {shard_index+1} of {num_shards} into {output_dir}/consolidated.0{shard_index}.pth")
            torch.save(shard_dict, output_dir + f"/consolidated.0{shard_index}.pth")
        with open(output_dir + "/params.json", "w") as f:
            print(f"Saving params.json into {output_dir}/params.json")
            json.dump(params, f)


if __name__ == '__main__':
    # NOTE: this section deliberately stays at module scope instead of living
    # in a main() function: save_shards() and unpermute() read `output_dir`,
    # `params`, `n_heads` and `dim` as module-level globals.

    args = parser.parse_args()
    base_model_path = args.base_model
    lora_model_paths = [s.strip() for s in args.lora_model.split(',') if s.strip()]
    output_dir = args.output_dir
    output_type = args.output_type
    offload_dir = args.offload_dir

    # Without at least one LoRA path the merge loop never runs and `tokenizer`
    # would be undefined when it is saved further down.
    if not lora_model_paths:
        parser.error("--lora_model must contain at least one non-empty LoRA model path")

    print(f"Base model: {base_model_path}")
    print(f"LoRA model(s) {lora_model_paths}:")

    # The model is always loaded on CPU in fp16.
    load_kwargs = dict(
        load_in_8bit=False,
        torch_dtype=torch.float16,
        device_map={"": "cpu"},
    )
    if offload_dir is not None:
        # Load with offloading, which is useful for low-RAM machines.
        # Note that if you have enough RAM, please use original method instead, as it is faster.
        load_kwargs.update(
            offload_folder=offload_dir,
            offload_state_dict=True,
            low_cpu_mem_usage=True,
        )
    base_model = LlamaForCausalLM.from_pretrained(base_model_path, **load_kwargs)

    ## infer the model size from the checkpoint
    embedding_size = base_model.get_input_embeddings().weight.size(1)
    if embedding_size not in emb_to_model_size:
        raise ValueError(
            f"Unsupported embedding size {embedding_size}; "
            f"expected one of {sorted(emb_to_model_size)}"
        )
    model_size = emb_to_model_size[embedding_size]
    print(f"Peft version: {peft.__version__}")
    print(f"Loading LoRA for {model_size} model")

    lora_model = None
    lora_model_sd = None
    # LoRA checkpoints are merged one after another, in the order given.
    for lora_index, lora_model_path in enumerate(lora_model_paths):
        print(f"Loading LoRA {lora_model_path}...")
        tokenizer = LlamaTokenizer.from_pretrained(lora_model_path)
        print(f"base_model vocab size: {base_model.get_input_embeddings().weight.size(0)}")
        print(f"tokenizer vocab size: {len(tokenizer)}")

        model_vocab_size = base_model.get_input_embeddings().weight.size(0)
        assert len(tokenizer) >= model_vocab_size, \
        (f"The vocab size of the tokenizer {len(tokenizer)} is smaller than the vocab size of the base model {model_vocab_size}\n"
        "This is not the intended use. Please check your model and tokenizer.")
        if model_vocab_size != len(tokenizer):
            # Grow the embedding matrices to the tokenizer's vocabulary size.
            base_model.resize_token_embeddings(len(tokenizer))
            print(f"Extended vocabulary size to {len(tokenizer)}")

        # Keep a live reference to one weight plus a snapshot of its current
        # values; comparing the two later shows whether the merge changed it.
        first_weight = base_model.model.layers[0].self_attn.q_proj.weight
        first_weight_old = first_weight.clone()

        print(f"Loading LoRA weights")
        if hasattr(peft.LoraModel, 'merge_and_unload'):
            # This peft version can merge the adapter itself.
            try:
                lora_model = PeftModel.from_pretrained(
                    base_model,
                    lora_model_path,
                    device_map={"": "cpu"},
                    torch_dtype=torch.float16,
                )
            except RuntimeError as e:
                if '[49953, 4096]' in str(e):
                    print("The vocab size of the tokenizer does not match the vocab size of the LoRA weight. \n"
                           "Did you misuse the LLaMA tokenizer with the Alpaca-LoRA weight?\n"
                           "Make sure that you use LLaMA tokenizer with the LLaMA-LoRA weight and Alpaca tokenizer with the Alpaca-LoRA weight!")
                # Re-raise the active exception unchanged.
                raise
            # Attaching the adapter alone must not have touched the base weight.
            assert torch.allclose(first_weight_old, first_weight)
            print(f"Merging with merge_and_unload...")
            base_model = lora_model.merge_and_unload()
        else:
            # Fallback: merge the adapter tensors into the state dict by hand.
            base_model_sd = base_model.state_dict()
            try:
                lora_model_sd = torch.load(os.path.join(lora_model_path, 'adapter_model.bin'), map_location='cpu')
            except FileNotFoundError:
                print("Cannot find lora model on the disk. Downloading lora model from hub...")
                filename = hf_hub_download(repo_id=lora_model_path, filename='adapter_model.bin')
                lora_model_sd = torch.load(filename, map_location='cpu')
            if 'base_model.model.model.embed_tokens.weight' in lora_model_sd:
                assert lora_model_sd['base_model.model.model.embed_tokens.weight'].shape[0]==len(tokenizer), \
                ("The vocab size of the tokenizer does not match the vocab size of the LoRA weight. \n"
                "Did you misuse the LLaMA tokenizer with the Alpaca-LoRA weight?\n"
                "Make sure that you use LLaMA tokenizer with the LLaMA-LoRA weight and Alpaca tokenizer with the Alpaca-LoRA weight!")

            lora_config = peft.LoraConfig.from_pretrained(lora_model_path)
            lora_scaling = lora_config.lora_alpha / lora_config.r
            fan_in_fan_out = lora_config.fan_in_fan_out
            lora_keys = [k for k in lora_model_sd if 'lora_A' in k]
            non_lora_keys = [k for k in lora_model_sd if 'lora_' not in k]

            # Tensors stored whole in the adapter file overwrite the base ones.
            for k in non_lora_keys:
                print(f"merging {k}")
                original_k = k.replace('base_model.model.', '')
                base_model_sd[original_k].copy_(lora_model_sd[k])

            # Low-rank pairs: add (B @ A) * scaling onto the matching base weight.
            for k in lora_keys:
                print(f"merging {k}")
                original_key = k.replace('.lora_A', '').replace('base_model.model.', '')
                assert original_key in base_model_sd
                lora_a_key = k
                lora_b_key = k.replace('lora_A', 'lora_B')
                base_model_sd[original_key] += (
                    transpose(lora_model_sd[lora_b_key].float() @ lora_model_sd[lora_a_key].float(), fan_in_fan_out) * lora_scaling
                )
                # The in-place add must leave the base weight in fp16.
                assert base_model_sd[original_key].dtype == torch.float16

        # did we do anything?
        assert not torch.allclose(first_weight_old, first_weight)

    # The tokenizer of the last LoRA in the list is the one that gets saved.
    tokenizer.save_pretrained(output_dir)

    if output_type == 'huggingface':
        print("Saving to Hugging Face format...")
        # Called through the class, exactly as the original script did.
        LlamaForCausalLM.save_pretrained(base_model, output_dir)
    else:  # output_type == 'pth'
        print("Saving to pth format...")

        base_model_sd = base_model.state_dict()
        # Drop the model objects; only the state dict is needed from here on.
        del lora_model, base_model, lora_model_sd

        # Globals consumed by save_shards() / unpermute().
        params = params_of_models[model_size]
        num_shards = num_shards_of_models[model_size]
        n_heads = params["n_heads"]
        dim = params["dim"]

        save_shards(model_sd=base_model_sd, num_shards=num_shards)

Reference:

merge_llama_with_chinese_lora.py

相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
InvalidArgumentError Qwen2 FP32 FP16 TSV GPT4 Michelin SAM Docker Cloudreve Web llama.cpp Pandas Safetensors PyCharm Magnet CTC Qwen2.5 飞书 报税 Bitcoin Search Sklearn Land 强化学习 SQL Github 公式 版权 VPN Attention OCR tar Claude 音频 Qwen Base64 FlashAttention Diagram 搞笑 Markdown NLTK News Crawler CEIR XGBoost Bipartite YOLO Data XML UI SQLite 论文速读 BF16 diffusers 阿里云 torchinfo Template Domain Algorithm Logo 递归学习法 Numpy 算法题 Plotly Ptyhon Color 域名 LaTeX Random Pillow API Datetime EXCEL Zip Rebuttal UNIX JSON 财报 Permission CV scipy PDB Hotel ONNX GPTQ Python Miniforge Bert Paper Dataset 第一性原理 Password C++ Card Gemma mmap FastAPI OpenAI Translation GIT DeepSeek HaggingFace Clash Git 净利润 DeepStream Bin SVR Firewall Tracking Freesound Baidu ResNet-50 多线程 Statistics ModelScope 证件照 Agent FP8 Disk Image2Text VGG-16 继承 Breakpoint Quantization WAN Jetson 多进程 Mixtral CC AI Review printf 签证 Jupyter Augmentation TensorFlow RAR Proxy Interview WebCrawler Vim v2ray Pickle HuggingFace git Hilton CSV Nginx transformers hf Tensor 图标 Transformers Ubuntu 顶会 云服务器 Animate ChatGPT Knowledge NameSilo PyTorch FP64 SPIE LLM Video 关于博主 Llama Anaconda CLAP 图形思考法 Excel LLAMA CUDA logger TTS IndexTTS2 uwsgi QWEN BeautifulSoup Use Conda Streamlit Windows COCO MD5 Google Shortcut Pytorch PIP git-lfs Input BTC 腾讯云 Tiktoken Vmess Plate Distillation OpenCV Website 论文 v0.dev LoRA GGML RGB tqdm Hungarian Math Django LeetCode uWSGI Food Quantize Linux Paddle Heatmap GoogLeNet VSCode TensorRT icon NLP CAM PDF
站点统计

本站现有博文328篇,共被浏览847265

本站已经建立2553天!

热门文章
文章归档
回到顶部