EADST

PyTorch GPTQ Dequantizing Function

PyTorch GPTQ Dequantizing Function

Here is Python code that dequantizes the packed weights of a GPTQ model into PyTorch FP16 tensors.

import torch

# Function: Dequantize quantized weights
def dequantization(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):
    """Unpack GPTQ-packed integer weights and rescale them to floating point.

    Each int32 element of ``qweight`` / ``qzeros`` carries ``32 // bits``
    quantized values, lowest bits first. The result is
    ``scales[group] * (q - (zero + 1))`` for every weight, where ``group`` is
    the quantization group of the weight's input row.

    Args:
        qweight: 2-D int32 tensor of packed weights. Unpacking multiplies the
            number of rows by ``32 // bits``; columns are kept as-is.
        qzeros: 2-D int32 tensor of packed zero points, one row per group.
            Unpacking multiplies the number of columns by ``32 // bits``.
        scales: Per-group scales; reshaped to ``(groups, scales.shape[-1])``.
            The output dtype follows this tensor (presumably FP16 for GPTQ
            checkpoints - confirm against the checkpoint in use).
        g_idx: Optional 1-D integer tensor mapping every unpacked input row to
            its group. When given (and not an all-zero placeholder) it is
            used, so act-order checkpoints and group sizes other than
            ``group_size`` are handled. Pass ``None`` to use consecutive
            groups of ``group_size`` rows.
        bits: Width of one quantized value; must be 1, 2, 4 or 8.
        group_size: Rows per group; only used for the sequential fallback.
        device: Unused. Kept so existing callers keep working; the computation
            runs on the device that holds ``qweight``.

    Returns:
        A transposed view with shape ``(qweight.shape[1], unpacked_rows)``.

    Raises:
        ValueError: If ``bits`` is not one of 1, 2, 4, 8.
    """
    if bits not in (1, 2, 4, 8):
        # Other widths do not pack evenly into 32 bits, or would overflow the
        # narrow integer dtype used for the unpacked values below.
        raise ValueError(f"bits must be one of 1, 2, 4, 8; got {bits!r}")

    pack_factor = 32 // bits
    mask = (1 << bits) - 1
    unpacked_dtype = torch.int16 if bits == 8 else torch.int8

    # Per-slot shift amounts, created on the input's device so that CUDA
    # inputs do not hit a CPU/GPU mismatch.
    shifts = torch.arange(0, 32, bits, dtype=torch.int32, device=qweight.device)

    # Unpack zero points along the column axis. The right shift of a negative
    # int32 sign-extends, but the mask keeps only the low ``bits`` bits.
    zeros = torch.bitwise_right_shift(
        qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
        shifts.to(qzeros.device).view(1, 1, -1),
    ).to(unpacked_dtype)
    # The stored value is offset by one, so add it back after masking.
    zeros = torch.bitwise_and(zeros, mask) + 1
    zeros = zeros.reshape(zeros.shape[0], -1)  # (groups, columns)

    # Unpack weights along the row axis.
    weight = torch.bitwise_right_shift(
        qweight.unsqueeze(1).expand(-1, pack_factor, -1),
        shifts.view(1, -1, 1),
    ).to(unpacked_dtype)
    weight = torch.bitwise_and(weight, mask)
    weight = weight.reshape(-1, weight.shape[2])  # (rows, columns)

    scales = scales.reshape(-1, scales.shape[-1])  # (groups, columns)

    # NOTE(review): some older exporters are reported to store an all-zero
    # g_idx placeholder. Taken literally it would assign every row to group 0,
    # so with more than one group fall back to the sequential layout.
    use_sequential_groups = g_idx is None or (
        scales.shape[0] > 1 and not bool(g_idx.any())
    )

    if use_sequential_groups:
        # Consecutive blocks of ``group_size`` rows share one scale/zero row;
        # broadcasting over the middle axis avoids materialising them per row.
        columns = weight.shape[1]
        grouped = weight.reshape(-1, group_size, columns)
        weight = scales.unsqueeze(1) * (grouped - zeros.unsqueeze(1))
        weight = weight.reshape(-1, columns)
    else:
        # Gather the scale/zero row that g_idx assigns to each input row.
        row_group = g_idx.to(device=weight.device, dtype=torch.long)
        weight = scales[row_group] * (weight - zeros[row_group])

    # Return the transposed weight
    return weight.transpose(0, 1)

# Function: Load quantized model and perform dequantization
def get_pytorch_bin(
    model_path="./your_model_folder/gptq_model-4bit-128g.bin",
    output_path="./your_model_folder/pytorch_model.bin",
    bits=4,
    group_size=128,
):
    """Convert a GPTQ checkpoint into a plain state dict and save it.

    Every ``*.qweight`` entry is dequantized with its sibling ``.qzeros``,
    ``.scales`` and (if present) ``.g_idx`` tensors and stored under the
    matching ``*.weight`` key. Auxiliary quantization tensors and the
    o/up/down/gate projection biases are dropped; all other entries are copied
    unchanged. The number of saved tensors is printed.

    Args:
        model_path: Path of the GPTQ checkpoint to read.
        output_path: Path the converted state dict is written to.
        bits: Quantization bit width forwarded to ``dequantization``.
        group_size: Group size forwarded to ``dequantization``.

    Raises:
        KeyError: If a ``.qweight`` entry has no ``.qzeros`` or ``.scales``
            sibling.
    """
    # Substrings marking tensors that are consumed by dequantization and must
    # not be written to the output on their own.
    quant_aux_markers = (".qzeros", ".scales", ".g_idx")
    # Biases that are dropped from the output - presumably placeholders added
    # by the quantized linear layers; verify against the target architecture.
    skipped_bias_markers = (
        "o_proj.bias",
        "up_proj.bias",
        "down_proj.bias",
        "gate_proj.bias",
    )

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this on checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location="cpu")

    # Dictionary to store processed weights
    tensors = {}

    for key, value in state_dict.items():
        # Skip non-weight entries
        if any(marker in key for marker in quant_aux_markers):
            continue
        if any(marker in key for marker in skipped_bias_markers):
            continue

        # Process quantized weights
        if ".qweight" in key:
            value = dequantization(
                value,
                state_dict[key.replace(".qweight", ".qzeros")],
                state_dict[key.replace(".qweight", ".scales")],
                # g_idx is optional: tolerate checkpoints that do not store it.
                state_dict.get(key.replace(".qweight", ".g_idx")),
                bits=bits,
                group_size=group_size,
            )
            key = key.replace(".qweight", ".weight")

        # Add processed weight to the dictionary
        tensors[key] = value

    # Print the number of processed weights and save as a new model file
    print(len(tensors))
    torch.save(tensors, output_path)

# Script entry point: run the conversion only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    get_pytorch_bin()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
Python LaTeX 版权 Hungarian RGB Ptyhon Quantize Cloudreve AI ModelScope Card Firewall PyCharm WebCrawler GIT Bin git GGML 财报 DeepSeek Numpy CLAP Zip 关于博主 Review v2ray OpenCV CEIR Shortcut BeautifulSoup Pillow FP8 UNIX XML PDB Template Pandas 证件照 FP16 签证 v0.dev Streamlit Breakpoint Docker logger Gemma Magnet HuggingFace Random QWEN 域名 Search VSCode scipy XGBoost git-lfs TSV mmap FlashAttention Vmess Pickle Anaconda Baidu Tracking Logo Proxy YOLO 强化学习 GPT4 Color Image2Text Django Permission Tiktoken Plate SQLite printf Attention Ubuntu Linux Statistics Bert NLP VGG-16 Domain Web Augmentation Video OCR Land tar 图标 SAM Distillation Windows Disk 继承 算法题 PIP Math Algorithm Qwen2.5 WAN icon API diffusers SVR BTC 多进程 Git UI Password Animate 公式 第一性原理 Rebuttal Nginx Agent hf 云服务器 ChatGPT ONNX Markdown IndexTTS2 NLTK FastAPI Sklearn 顶会 多线程 Bitcoin 音频 uwsgi Pytorch COCO LLM CAM Safetensors LLAMA Data CTC Input tqdm LeetCode 阿里云 Hilton llama.cpp Michelin Transformers 腾讯云 OpenAI EXCEL LoRA Paddle DeepStream CC Freesound Qwen2 Google 飞书 NameSilo Claude TTS 图形思考法 MD5 Website Use 报税 TensorRT GoogLeNet PDF VPN HaggingFace Dataset uWSGI TensorFlow Interview 净利润 PyTorch Crawler SQL 递归学习法 Vim 搞笑 transformers CSV Mixtral JSON Translation CUDA Plotly Diagram RAR Paper BF16 SPIE Food Clash CV Qwen GPTQ Conda Base64 Knowledge torchinfo Excel FP64 ResNet-50 Llama Heatmap News C++ InvalidArgumentError FP32 Hotel Github 论文速读 Tensor Bipartite Datetime Jetson Miniforge Jupyter Quantization
站点统计

本站现有博文326篇,共被浏览825373

本站已经建立2531天!

热门文章
文章归档
回到顶部