EADST

PyTorch GPTQ Dequantization Function

PyTorch GPTQ Dequantization Function

The following Python code unpacks the quantized weights of a GPTQ model and saves them as a regular (dequantized) PyTorch FP16 checkpoint.

import torch

# Function: Dequantize quantized weights
def dequantization(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):
    """Unpack GPTQ int32-packed weights and return the dequantized matrix.

    Each int32 in ``qweight``/``qzeros`` holds ``32 // bits`` quantized values;
    the j-th value sits at bit offset ``bits * j``. ``qweight`` is packed along
    dim 0 (input rows) and ``qzeros`` along dim 1 (output columns). Every
    element is reconstructed as ``scale[g] * (q - (zero[g] + 1))``, where ``g``
    is the group of its input row.

    Args:
        qweight: int32 tensor of shape
            ``(in_features // (32 // bits), out_features)``.
        qzeros: int32 tensor of shape
            ``(n_groups, out_features // (32 // bits))``.
        scales: float tensor viewed as ``(n_groups, out_features)``.
        g_idx: group index of every input row (length ``in_features``), as
            stored in GPTQ checkpoints; honouring it makes act-order
            (``desc_act``) models dequantize correctly. ``None`` means row
            ``i`` belongs to group ``i // group_size``.
        bits: quantization bit width; must be 2, 4 or 8.
        group_size: rows per group; only used when ``g_idx`` is ``None``.
        device: unused, kept for backward compatibility. Work is done on
            ``qweight.device``; all inputs are assumed to share that device.

    Returns:
        A contiguous tensor of shape ``(out_features, in_features)`` (the
        ``nn.Linear.weight`` layout). Its dtype follows ``scales``.

    Raises:
        ValueError: if ``bits`` is not 2, 4 or 8.
    """
    # Only widths that divide 32 evenly are packed this way.
    # Other widths (e.g. 3) would need a different unpacking scheme.
    if bits not in (2, 4, 8):
        raise ValueError(f"unsupported bits={bits}; expected 2, 4 or 8")

    mask = (1 << bits) - 1
    # The narrow dtype must hold (q - (zero + 1)).
    # With 8 bits that ranges over [-256, 254], so int16 is needed; int8 suffices below that.
    narrow = torch.int16 if bits == 8 else torch.int8
    # Bit offsets of the values packed in one int32.
    # Built on the input device so CUDA inputs work.
    wf = torch.arange(0, 32, bits, dtype=torch.int32, device=qweight.device)

    # Zero points: (G, N/pack) -> (G, N/pack, pack) -> (G, N).
    # Mask while still int32, so that narrowing never depends on wraparound.
    zeros = torch.bitwise_right_shift(qzeros.unsqueeze(2), wf.view(1, 1, -1))
    zeros = torch.bitwise_and(zeros, mask).to(narrow)
    # Stored zero points are offset by one in the packing convention this
    # script targets (NOTE(review): confirm for your checkpoint format).
    zeros = (zeros + 1).reshape(zeros.shape[0], -1)

    # Weights: (K/pack, N) -> (K/pack, pack, N) -> (K, N), so row r*pack + j holds slot j.
    weight = torch.bitwise_right_shift(qweight.unsqueeze(1), wf.view(1, -1, 1))
    weight = torch.bitwise_and(weight, mask).to(narrow)
    weight = weight.reshape(-1, weight.shape[2])

    # Pick each input row's scale and zero point through its group index.
    # This is correct both for sequential groups and for act-order permutations.
    if g_idx is None:
        g_idx = torch.arange(weight.shape[0], device=weight.device) // group_size
    g_idx = g_idx.to(device=weight.device, dtype=torch.long)

    scales = scales.reshape(-1, scales.shape[-1])
    weight = scales[g_idx] * (weight - zeros[g_idx])

    # (K, N) -> (N, K); materialized so downstream serializers get a dense tensor.
    return weight.transpose(0, 1).contiguous()

# Function: Load quantized model and perform dequantization
def get_pytorch_bin(
    path="./your_model_folder/gptq_model-4bit-128g.bin",
    output_path="./your_model_folder/pytorch_model.bin",
    bits=4,
    group_size=128,
):
    """Dequantize every GPTQ layer in a checkpoint and save a dense state dict.

    For each key containing ``.qweight``, the companion ``.qzeros``,
    ``.scales`` and (if present) ``.g_idx`` tensors are passed to
    ``dequantization``. The result is stored under the same key with
    ``.qweight`` replaced by ``.weight``; the companion tensors themselves
    are not written. Bias entries of the ``o_proj``/``up_proj``/
    ``down_proj``/``gate_proj`` projections are dropped. Every other entry
    is copied unchanged.

    Args:
        path: input checkpoint, loaded with ``torch.load`` onto the CPU.
        output_path: destination of the converted state dict (``torch.save``).
        bits: quantization bit width forwarded to ``dequantization``.
        group_size: group size forwarded to ``dequantization``. The defaults
            match the ``4bit-128g`` name of the default input file.
    """
    # Companion tensors consumed together with their ``.qweight``.
    quant_aux = (".qzeros", ".scales", ".g_idx")
    # Bias entries intentionally left out of the output.
    # NOTE(review): presumably the target architecture has no such parameters; confirm.
    dropped_bias = ("o_proj.bias", "up_proj.bias", "down_proj.bias", "gate_proj.bias")

    # NOTE(security): torch.load unpickles the file; only convert trusted
    # checkpoints (or pass weights_only=True on PyTorch versions that support it).
    state = torch.load(path, map_location="cpu")

    tensors = {}
    for key, value in state.items():
        if any(tag in key for tag in quant_aux) or any(tag in key for tag in dropped_bias):
            continue

        if ".qweight" in key:
            value = dequantization(
                value,
                state[key.replace(".qweight", ".qzeros")],
                state[key.replace(".qweight", ".scales")],
                # Some GPTQ checkpoints carry no g_idx; None means sequential groups.
                state.get(key.replace(".qweight", ".g_idx")),
                bits=bits,
                group_size=group_size,
            )
            key = key.replace(".qweight", ".weight")

        tensors[key] = value

    # Report how many tensors were written, then save the converted checkpoint.
    print(len(tensors))
    torch.save(tensors, output_path)

# Main program entry point
if __name__ == "__main__":
    # Script mode: convert the checkpoint at the default paths.
    get_pytorch_bin()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
Firewall 算法题 Clash 云服务器 证件照 diffusers 继承 多进程 FP32 GPT4 SVR FlashAttention Card Markdown CSV PyTorch HaggingFace BeautifulSoup Claude CAM 财报 MD5 LeetCode Permission Tiktoken Random scipy SQL Distillation PIP Bitcoin 净利润 Augmentation Hotel Freesound transformers Ubuntu Qwen printf Mixtral Shortcut TensorRT FastAPI 签证 ONNX RAR Miniforge 公式 CEIR Jetson Use Video Math Image2Text ChatGPT Ptyhon OCR Statistics SAM Tensor 多线程 Bipartite Baidu PDF 音频 CC Magnet Hungarian LLAMA Quantize TensorFlow IndexTTS2 Anaconda Sklearn Search git Vmess Django TSV Linux VPN Cloudreve Google Tracking Qwen2.5 ResNet-50 LaTeX 第一性原理 UI Zip C++ Qwen2 Bin WebCrawler Diagram Data OpenAI API Llama NLTK Password Excel Pandas Translation FP16 Bert Quantization Base64 HuggingFace Website Color NLP Breakpoint Git LLM FP64 Food Land BF16 CLAP 图形思考法 tar ModelScope CUDA SPIE Github Algorithm hf DeepStream Review WAN VGG-16 CV BTC llama.cpp mmap AI Web NameSilo Nginx Vim Template COCO OpenCV Plate Proxy git-lfs 递归学习法 QWEN CTC Conda Docker Heatmap Hilton SQLite InvalidArgumentError EXCEL Michelin DeepSeek Input Transformers VSCode v2ray 版权 关于博主 GIT 腾讯云 logger 强化学习 Domain torchinfo FP8 Paper JSON 阿里云 uwsgi LoRA Knowledge uWSGI 域名 飞书 Streamlit Agent Plotly Dataset Attention PDB PyCharm Pickle Jupyter Pytorch Crawler GoogLeNet 搞笑 Safetensors XGBoost Paddle Interview News Numpy XML Logo 顶会 Gemma YOLO Disk GPTQ tqdm UNIX TTS v0.dev Animate Python Pillow RGB Datetime GGML 报税 Windows
站点统计

本站现有博文321篇,共被浏览779433

本站已经建立2471天!

热门文章
文章归档
回到顶部