EADST

PyTorch GPTQ Dequantization Function

PyTorch GPTQ Dequantization Function

The following Python code dequantizes the weights of a GPTQ-quantized model and saves them as a regular PyTorch FP16 checkpoint.

import torch

# Function: Dequantize quantized weights
def dequantization(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):
    """Unpack GPTQ-packed integer weights and dequantize them.

    Each int32 element of ``qweight`` / ``qzeros`` holds ``32 // bits`` values
    of ``bits`` bits each, lowest bits first. ``qweight`` is packed along dim 0
    (input rows) and ``qzeros`` along dim 1 (output columns).

    Args:
        qweight: int32 tensor of shape ``(in_features * bits // 32, out_features)``.
        qzeros: int32 tensor of shape ``(n_groups, out_features * bits // 32)``.
        scales: tensor of shape ``(n_groups, out_features)``; its dtype
            (e.g. FP16) determines the dtype of the result.
        g_idx: 1-D tensor of length ``in_features`` giving the quantization
            group of every input row. Honoring it makes act-order
            (``desc_act=True``) checkpoints dequantize correctly. If ``None``,
            groups are assumed contiguous: ``row // group_size``.
        bits: bit width of each packed value (must divide 32).
        group_size: rows per group, used only when ``g_idx`` is ``None``;
            a value <= 0 (e.g. ``-1``) means one group spanning all rows.
        device: unused; kept for backward compatibility. All work happens on
            ``qweight``'s device.

    Returns:
        Contiguous tensor of shape ``(out_features, in_features)`` holding
        ``scale * (q - zero)`` for every element.
    """
    pack_factor = 32 // bits
    mask = (2 ** bits) - 1
    # 8-bit values (0..255) do not fit in signed int8, so they need int16.
    # Casting before masking is fine: a narrowing integer cast keeps the
    # low-order bits, and the mask then clears everything above `bits`.
    unpack_dtype = torch.int16 if bits == 8 else torch.int8
    # Shift amount for each packed slot, created on the inputs' device so the
    # shift never mixes CPU and GPU tensors.
    wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=qweight.device)

    # Zero points: (n_groups, packed_cols) -> (n_groups, packed_cols, pack_factor)
    zeros = torch.bitwise_right_shift(
        qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
        wf.view(1, 1, -1),
    ).to(unpack_dtype)
    torch.bitwise_and(zeros, mask, out=zeros)
    # The checkpoint stores zero points minus one (presumably AutoGPTQ's legacy
    # format -- confirm for the exporter that produced the checkpoint).
    zeros = (zeros + 1).reshape(zeros.shape[0], -1)  # (n_groups, out_features)

    scales = scales.reshape(-1, scales.shape[-1])  # (n_groups, out_features)

    # Weights: (packed_rows, out) -> (packed_rows, pack_factor, out) -> (in, out)
    weight = torch.bitwise_right_shift(
        qweight.unsqueeze(1).expand(-1, pack_factor, -1),
        wf.view(1, -1, 1),
    ).to(unpack_dtype)
    torch.bitwise_and(weight, mask, out=weight)
    weight = weight.reshape(-1, weight.shape[2])  # (in_features, out_features)

    in_features = weight.shape[0]
    if g_idx is None:
        rows_per_group = group_size if group_size > 0 else in_features
        g_idx = torch.arange(in_features, device=weight.device) // rows_per_group
    else:
        g_idx = g_idx.to(device=weight.device, dtype=torch.long)

    # Look up each row's scale and zero point through g_idx instead of assuming
    # contiguous blocks of group_size rows. For a sequential g_idx the result
    # is identical; for act-order checkpoints only this mapping is correct.
    weight = scales[g_idx] * (weight - zeros[g_idx])

    # (out_features, in_features) is the nn.Linear layout. The result is made
    # contiguous so serializers that reject strided views accept it.
    return weight.transpose(0, 1).contiguous()

# Function: Load quantized model and perform dequantization
def get_pytorch_bin(path="./your_model_folder/gptq_model-4bit-128g.bin",
                    output_path="./your_model_folder/pytorch_model.bin",
                    bits=4, group_size=128):
    """Dequantize a GPTQ checkpoint and save it as a plain PyTorch state dict.

    Every ``<name>.qweight`` entry is replaced by a dequantized ``<name>.weight``
    built from its sibling ``.qzeros``, ``.scales`` and (if present) ``.g_idx``
    entries. The auxiliary ``.qzeros`` / ``.scales`` / ``.g_idx`` entries and
    the ``o_proj`` / ``up_proj`` / ``down_proj`` / ``gate_proj`` biases are
    dropped. All other entries are copied unchanged. The number of saved
    entries is printed.

    Args:
        path: input GPTQ checkpoint, read with ``torch.load`` onto the CPU.
        output_path: file the resulting state dict is written to.
        bits: bit width passed to ``dequantization``.
        group_size: group size passed to ``dequantization``.
    """
    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code. Only run this on checkpoints from a trusted source (or
    # pass weights_only=True on torch versions that support it).
    checkpoint = torch.load(path, map_location="cpu")

    aux_parts = (".qzeros", ".scales", ".g_idx")
    # These projection biases are dropped as in the original script;
    # presumably they are unused (all-zero) in the export -- confirm per model.
    dropped_biases = ("o_proj.bias", "up_proj.bias", "down_proj.bias", "gate_proj.bias")

    tensors = {}
    for key, value in checkpoint.items():
        if any(part in key for part in aux_parts):
            continue
        if any(part in key for part in dropped_biases):
            continue

        out_key = key
        if ".qweight" in key:
            value = dequantization(
                value,
                checkpoint[key.replace(".qweight", ".qzeros")],
                checkpoint[key.replace(".qweight", ".scales")],
                # Older checkpoints may lack g_idx; pass None rather than
                # raising KeyError.
                checkpoint.get(key.replace(".qweight", ".g_idx")),
                bits=bits,
                group_size=group_size,
            )
            out_key = key.replace(".qweight", ".weight")

        tensors[out_key] = value

    print(len(tensors))
    torch.save(tensors, output_path)

# Main program entry point
if __name__ == "__main__":
    # Run the GPTQ-to-FP16 checkpoint conversion when executed as a script.
    get_pytorch_bin()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
Use Mixtral 公式 LLAMA Claude RAR Conda mmap Sklearn OpenCV Heatmap tar ModelScope FlashAttention FP32 音频 Video Miniforge TTS Freesound XGBoost Python RGB Card v0.dev Pillow printf Input GoogLeNet CAM Streamlit Paddle Diagram scipy Nginx Michelin Template Hotel Bipartite Github Markdown VPN 第一性原理 Dataset UNIX Hilton FP16 GPTQ 继承 VSCode Pandas git 算法题 Qwen2.5 Excel PDB Llama XML GGML 搞笑 Base64 NLP 多进程 Color hf ResNet-50 Animate Interview COCO Safetensors UI FP64 阿里云 WebCrawler Disk Magnet Plate FastAPI PyCharm Google InvalidArgumentError SVR Quantization GIT QWEN logger HaggingFace 腾讯云 签证 Ptyhon Logo 域名 CLAP git-lfs TensorFlow Math Statistics Paper Breakpoint diffusers BTC Knowledge Vmess HuggingFace Web Website BF16 Cloudreve Ubuntu 多线程 ChatGPT 报税 版权 SPIE tqdm CSV Git Tiktoken IndexTTS2 MD5 v2ray Qwen 递归学习法 CUDA OCR C++ Docker DeepStream Review Data CTC PIP Zip ONNX Linux PyTorch Tensor EXCEL uwsgi LLM NLTK uWSGI BeautifulSoup Random Plotly Land YOLO 关于博主 Bitcoin Transformers WAN Qwen2 Pickle Bert CV TensorRT 顶会 CC NameSilo LoRA 图形思考法 Distillation Crawler Password Domain CEIR OpenAI Django Tracking PDF Translation SQLite Datetime API Attention Augmentation Jetson Pytorch DeepSeek Permission Anaconda Firewall Quantize Windows Bin Agent Gemma Vim Proxy JSON Algorithm transformers Food 净利润 强化学习 VGG-16 Baidu torchinfo TSV Clash AI 飞书 Jupyter 证件照 FP8 Search LeetCode Image2Text Numpy SQL 财报 llama.cpp Hungarian Shortcut GPT4 SAM LaTeX
站点统计

本站现有博文319篇,共被浏览751724

本站已经建立2408天!

热门文章
文章归档
回到顶部