EADST

Pytorch GPTQ Dequantizing Function

Pytorch GPTQ Dequantizing Function

Here is Python code for dequantizing a GPTQ-quantized model back into a plain PyTorch FP16 weight format.

import torch

# Function: Dequantize quantized weights
def dequantization(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):
    # Create a tensor for bitwise right shift operation
    wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0)

    # Apply bitwise right shift and convert qzeros to the appropriate type
    zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
    torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)

    # Reshape the zeros tensor
    zeros = zeros + 1
    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])

    # Reshape the scales tensor
    scales = scales.reshape(-1, 1, scales.shape[-1])

    # Similar bitwise right shift operation for qweight and reshape
    weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8)
    torch.bitwise_and(weight, (2 ** bits) - 1, out=weight)
    weight = weight.reshape(-1, group_size, weight.shape[2])

    # Apply dequantization formula and reshape the final weight
    weight = (scales * (weight - zeros))
    weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

    # Return the transposed weight
    return weight.transpose(0, 1)

# Function: Load quantized model and perform dequantization
# Function: Load quantized model and perform dequantization
def get_pytorch_bin(path="./your_model_folder/gptq_model-4bit-128g.bin",
                    out_path="./your_model_folder/pytorch_model.bin"):
    """Convert a GPTQ checkpoint into a plain PyTorch state dict on disk.

    Each ``*.qweight`` entry is dequantized (using its sibling ``.qzeros``,
    ``.scales`` and ``.g_idx`` tensors) and stored under ``*.weight``; the
    quantization sidecar tensors and known bias placeholders are dropped.
    All other entries are copied through unchanged.

    Args:
        path: source GPTQ checkpoint (``torch.save`` format).
        out_path: destination file for the converted state dict.
    """
    # Dictionary accumulating the converted state dict.
    tensors = {}

    # Load everything onto the CPU so conversion works without a GPU.
    state = torch.load(path, map_location="cpu")

    for key, value in state.items():
        # Sidecar tensors are consumed while processing their .qweight
        # partner below; skip them as standalone entries.
        if ".qzeros" in key or ".scales" in key or ".g_idx" in key:
            continue
        # These projection layers carry no bias in the source model;
        # drop any placeholder bias entries.
        if any(b in key for b in ("o_proj.bias", "up_proj.bias",
                                  "down_proj.bias", "gate_proj.bias")):
            continue

        # Replace packed quantized weights with their dequantized form.
        if ".qweight" in key:
            value = dequantization(
                value,
                state[key.replace(".qweight", ".qzeros")],
                state[key.replace(".qweight", ".scales")],
                state[key.replace(".qweight", ".g_idx")],
            )
            key = key.replace(".qweight", ".weight")

        tensors[key] = value

    # Report how many entries were written, then save the new checkpoint.
    print(len(tensors))
    torch.save(tensors, out_path)

# Main program entry point: run the GPTQ -> plain-PyTorch conversion
# with the default hard-coded file paths.
if __name__ == '__main__':
    get_pytorch_bin()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
BF16 Plotly GPTQ 算法题 Password Video Nginx Diagram LLM printf CV Bipartite LoRA Knowledge NameSilo Michelin Breakpoint 域名 Llama Jetson COCO Tensor Disk Paddle PDB 报税 scipy UI 证件照 PyTorch git transformers Jupyter 云服务器 Cloudreve Linux Image2Text Markdown Logo YOLO GoogLeNet Tiktoken VSCode CSV GGML Website Use uWSGI Magnet LaTeX Qwen 多进程 Translation Docker Git Attention Data Pytorch Heatmap News Algorithm CAM mmap Anaconda Safetensors Paper C++ MD5 Permission OpenAI tar Card Sklearn CC Random v0.dev PDF Excel Web Gemma icon 搞笑 Quantization AI Hilton Land ChatGPT Vmess OCR EXCEL FP64 Quantize GPT4 Rebuttal Datetime ModelScope ONNX Plate Firewall Mixtral Color Python Base64 BeautifulSoup Search Conda Agent Bitcoin 版权 Pandas VGG-16 TensorRT GIT Hungarian NLTK Hotel PIP TensorFlow BTC HuggingFace 强化学习 FP32 Math llama.cpp TSV 继承 Bert uwsgi CTC SQLite XML SPIE Freesound 净利润 Shortcut git-lfs 多线程 Food LeetCode 财报 Augmentation Transformers 阿里云 CLAP Bin Streamlit 关于博主 logger Baidu 音频 FP16 WebCrawler Domain InvalidArgumentError Review Django JSON TTS Distillation 顶会 Pickle UNIX RGB CUDA Statistics 公式 Clash 图形思考法 Windows Numpy SQL NLP DeepSeek Tracking Dataset Claude torchinfo 递归学习法 Proxy FastAPI SVR LLAMA Ubuntu v2ray SAM Template FP8 XGBoost VPN tqdm IndexTTS2 API RAR 腾讯云 CEIR HaggingFace 第一性原理 ResNet-50 Input Google Vim WAN Crawler diffusers DeepStream PyCharm 签证 Qwen2.5 QWEN Animate Github 飞书 Interview FlashAttention hf Pillow Qwen2 Ptyhon Zip Miniforge 图标 OpenCV
站点统计

本站现有博文323篇,共被浏览801526

本站已经建立2500天!

热门文章
文章归档
回到顶部