EADST

Pytorch GPTQ Dequantizing Function

Pytorch GPTQ Dequantizing Function

The following Python code dequantizes the weights of a GPTQ-quantized model and saves them as a plain PyTorch FP16 checkpoint.

import torch

# Function: Dequantize quantized weights
def dequantization(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):
    """Unpack GPTQ-packed weights and dequantize them to a dense matrix.

    Each int32 word packs ``32 // bits`` unsigned values, least-significant
    field first. The dequantized value for input row ``i`` and output column
    ``o`` is ``scales[g, o] * (w[i, o] - (z[g, o] + 1))`` with ``g = g_idx[i]``.
    The ``+ 1`` assumes the legacy GPTQ packing in which stored zero points
    are offset by one -- confirm this matches the exporter of your checkpoint.

    Args:
        qweight: int32 tensor of shape (in_features * bits // 32, out_features).
        qzeros: int32 tensor of shape (num_groups, out_features * bits // 32).
        scales: floating-point tensor of shape (num_groups, out_features); its
            dtype (e.g. float16) becomes the dtype of the result.
        g_idx: integer tensor of length in_features giving each input row's
            group index (non-sequential values, e.g. from act-order
            checkpoints, are honoured), or None to assume consecutive groups
            of ``group_size`` rows.
        bits: width of each quantized value; one of 1, 2, 4, 8.
        group_size: rows per group, used only when ``g_idx`` is None; a value
            <= 0 (e.g. -1) means a single group covering all rows.
        device: unused; kept for backward compatibility. All work happens on
            ``qweight.device``.

    Returns:
        Contiguous tensor of shape (out_features, in_features).

    Raises:
        ValueError: if ``bits`` is unsupported or ``g_idx`` has the wrong length.
    """
    if bits not in (1, 2, 4, 8):
        raise ValueError(f"unsupported bits={bits}; expected one of 1, 2, 4, 8")

    # Shift amount for every field of an int32 word. Built on the inputs'
    # device so CUDA tensors do not hit a CPU/GPU device mismatch.
    shifts = torch.arange(0, 32, bits, dtype=torch.int32, device=qweight.device)
    mask = (1 << bits) - 1
    # Narrowest signed type that holds the values, zero points + 1 and their difference.
    narrow = torch.int16 if bits == 8 else torch.int8

    # Mask while still int32: right-shifting a negative word sign-extends, and
    # masking first avoids relying on overflowing int32 -> int8/int16 casts.
    zeros = torch.bitwise_right_shift(qzeros.unsqueeze(2), shifts.view(1, 1, -1))
    zeros = torch.bitwise_and(zeros, mask).to(narrow)
    zeros = zeros.reshape(zeros.shape[0], -1) + 1  # (num_groups, out_features)

    weight = torch.bitwise_right_shift(qweight.unsqueeze(1), shifts.view(1, -1, 1))
    weight = torch.bitwise_and(weight, mask).to(narrow)
    weight = weight.reshape(-1, weight.shape[2])  # (in_features, out_features)

    in_features = weight.shape[0]
    if g_idx is None:
        if group_size is None or group_size <= 0:
            g_idx = torch.zeros(in_features, dtype=torch.long, device=weight.device)
        else:
            g_idx = torch.arange(in_features, device=weight.device) // group_size
    else:
        g_idx = g_idx.to(device=weight.device, dtype=torch.long)
    if g_idx.numel() != in_features:
        raise ValueError(
            f"g_idx has {g_idx.numel()} entries but qweight unpacks to {in_features} rows"
        )

    # Gather each row's group parameters: identical to a per-group reshape when
    # g_idx is sequential, and still correct when g_idx is permuted.
    weight = scales[g_idx] * (weight - zeros[g_idx])
    return weight.t().contiguous()

# Function: Load quantized model and perform dequantization
def get_pytorch_bin(
    path="./your_model_folder/gptq_model-4bit-128g.bin",
    save_path="./your_model_folder/pytorch_model.bin",
    bits=4,
    group_size=128,
):
    """Convert a GPTQ checkpoint into a plain dequantized PyTorch state dict.

    Every ``<prefix>.qweight`` entry is dequantized together with its
    ``.qzeros``, ``.scales`` and (if present) ``.g_idx`` companions and stored
    as ``<prefix>.weight``. The companion entries themselves are dropped.
    Entries whose key contains ``o_proj.bias``, ``up_proj.bias``,
    ``down_proj.bias`` or ``gate_proj.bias`` are dropped too (presumably
    because the target architecture has no biases on these layers -- confirm
    for your model). All other entries are copied unchanged.

    Args:
        path: input checkpoint readable by ``torch.load``.
        save_path: destination of the converted state dict (``torch.save``).
        bits: quantization bit width passed to ``dequantization``.
        group_size: group size passed to ``dequantization``.

    Returns:
        The converted state dict that was saved to ``save_path``.
    """
    # NOTE(review): torch.load unpickles arbitrary objects unless
    # weights_only=True (the default only since torch 2.6); only convert
    # checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location="cpu")

    # Quantization side tensors are consumed together with their .qweight.
    skip_parts = (".qzeros", ".scales", ".g_idx")
    skip_biases = ("o_proj.bias", "up_proj.bias", "down_proj.bias", "gate_proj.bias")

    tensors = {}
    for key, value in checkpoint.items():
        if any(part in key for part in skip_parts) or any(b in key for b in skip_biases):
            continue

        if ".qweight" in key:
            value = dequantization(
                value,
                checkpoint[key.replace(".qweight", ".qzeros")],
                checkpoint[key.replace(".qweight", ".scales")],
                # Some exports omit g_idx; dequantization then uses group_size.
                checkpoint.get(key.replace(".qweight", ".g_idx")),
                bits=bits,
                group_size=group_size,
            )
            key = key.replace(".qweight", ".weight")

        tensors[key] = value

    # Report how many tensors were kept, then write the converted checkpoint.
    print(len(tensors))
    torch.save(tensors, save_path)
    return tensors

# Run the checkpoint conversion when this file is executed as a script.
if __name__ == "__main__":
    get_pytorch_bin()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
Attention Card OpenCV Qwen2.5 MD5 XML PIP WAN Food 图标 Knowledge ModelScope 腾讯云 云服务器 Magnet Hilton ResNet-50 CTC Bitcoin 多进程 Plotly FP8 Search logger NameSilo Web Qwen2 Cloudreve Git diffusers NLTK CSV Input Rebuttal Quantize QWEN Tiktoken Augmentation LLAMA Template ChatGPT Jetson 音频 TSV WebCrawler Numpy git COCO BF16 tqdm PDF LLM JSON CLAP VSCode Website Disk CV LoRA 证件照 Datetime Bin Ptyhon PyCharm Pytorch CC Crawler Statistics mmap DeepSeek Image2Text LeetCode 搞笑 Paddle BeautifulSoup hf OpenAI Agent 版权 Llama 关于博主 News FlashAttention 财报 VPN 继承 FP16 printf Logo scipy DeepStream Conda Vim API Pandas Shortcut Hungarian 域名 Tracking 报税 净利润 Michelin CUDA 递归学习法 顶会 C++ GoogLeNet TensorRT OCR SPIE Python Freesound transformers GPTQ Permission LaTeX UI Use Windows AI FP32 Breakpoint InvalidArgumentError GIT Qwen Ubuntu UNIX Safetensors Django IndexTTS2 公式 PyTorch Pillow v2ray Firewall Base64 Streamlit TensorFlow Plate 多线程 Miniforge Google Sklearn 第一性原理 Hotel Land CEIR 飞书 XGBoost Data v0.dev Proxy HuggingFace FastAPI Diagram SQLite Quantization Dataset Animate Random Bert Linux Claude Algorithm Pickle ONNX 算法题 VGG-16 git-lfs Math Clash Nginx tar Anaconda 图形思考法 RAR TTS 强化学习 Gemma Paper Mixtral Video Interview Color Heatmap Jupyter HaggingFace Tensor llama.cpp Docker NLP uWSGI SVR FP64 GGML torchinfo Password icon CAM Baidu 阿里云 PDB Excel Translation Domain SAM Vmess GPT4 uwsgi YOLO BTC RGB SQL Markdown Transformers Bipartite Review Distillation Zip EXCEL Github 签证
站点统计

本站现有博文324篇,共被浏览819277

本站已经建立2523天!

热门文章
文章归档
回到顶部