EADST

Pytorch GPTQ Dequantizing Function

Pytorch GPTQ Dequantizing Function

Here is Python code that dequantizes the weights of a GPTQ-quantized model and saves them as a regular PyTorch (FP16) checkpoint.

import torch

# Function: Dequantize quantized weights
def dequantization(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):
    """Unpack GPTQ-style packed weights and return the dequantized matrix.

    Every 32-bit word of ``qweight`` / ``qzeros`` holds ``32 // bits``
    unsigned values. Weights are unpacked along dim 0 of ``qweight`` and
    zero-points along dim 1 of ``qzeros``. The result is
    ``scales * (weight - (zero + 1))``, applied per block of ``group_size``
    consecutive unpacked rows.

    Args:
        qweight: 2-D packed weight tensor (int32 in GPTQ checkpoints --
            assumed, not checked here).
        qzeros: 2-D packed zero-point tensor, one row per group.
        scales: Per-group scales; reshaped to ``(-1, 1, scales.shape[-1])``.
        g_idx: Accepted for interface compatibility but NOT used. Rows are
            grouped contiguously, so checkpoints whose group index is not
            simply ``row // group_size`` are presumably not reconstructed
            correctly -- TODO confirm against the quantizer used.
        bits: Bit width of each packed value.
        group_size: Number of unpacked rows sharing one scale/zero-point.
            ``-1`` means a single group spanning all unpacked rows.
        device: Unused; kept so existing callers keep working. All work
            happens on the device of the input tensors.

    Returns:
        A transposed view of shape
        ``(qweight.shape[1], qweight.shape[0] * (32 // bits))``. Its dtype
        follows type promotion with ``scales``.
    """
    values_per_word = 32 // bits
    value_mask = (2 ** bits) - 1
    # 8-bit zero-points can reach 256 after the +1 below, which overflows int8
    unpacked_dtype = torch.int16 if bits == 8 else torch.int8

    # GPTQ convention: group_size == -1 -> one group over every input row
    if group_size == -1:
        group_size = qweight.shape[0] * values_per_word

    # Bit offsets of the packed values inside a word, created on the input's
    # device so GPU-resident tensors do not hit a CPU/GPU mismatch
    shifts = torch.arange(0, 32, bits, dtype=torch.int32, device=qweight.device)

    # Unpack zero-points: (groups, cols / values_per_word) -> (groups, 1, cols)
    zeros = torch.bitwise_right_shift(
        torch.unsqueeze(qzeros, 2).expand(-1, -1, values_per_word),
        shifts.to(qzeros.device).view(1, 1, -1),
    ).to(unpacked_dtype)
    torch.bitwise_and(zeros, value_mask, out=zeros)
    # Stored zero-points are used with an offset of one (presumably the
    # packing convention of the exporting library -- verify for new formats)
    zeros = zeros + 1
    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])

    # One scale row per group, broadcast over the rows of that group
    scales = scales.reshape(-1, 1, scales.shape[-1])

    # Unpack weights: (rows / values_per_word, cols) -> (groups, group_size, cols)
    weight = torch.bitwise_right_shift(
        torch.unsqueeze(qweight, 1).expand(-1, values_per_word, -1),
        shifts.view(1, -1, 1),
    ).to(unpacked_dtype)
    torch.bitwise_and(weight, value_mask, out=weight)
    weight = weight.reshape(-1, group_size, weight.shape[2])

    # Dequantize per group, then flatten the group axis back into rows
    weight = scales * (weight - zeros)
    weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

    # Return the transposed weight
    return weight.transpose(0, 1)

# Function: Load quantized model and perform dequantization
def get_pytorch_bin(path="./your_model_folder/gptq_model-4bit-128g.bin",
                    save_path="./your_model_folder/pytorch_model.bin",
                    bits=4, group_size=128):
    """Convert a GPTQ checkpoint into a plain checkpoint of dequantized weights.

    Loads the checkpoint at ``path`` onto the CPU, replaces every
    ``*.qweight`` entry with a dequantized ``*.weight`` entry, copies all
    other kept entries unchanged, prints how many tensors were kept, and
    saves the result to ``save_path``.

    Args:
        path: Checkpoint file to read with ``torch.load``.
        save_path: Destination passed to ``torch.save``.
        bits: Bit width forwarded to ``dequantization``.
        group_size: Group size forwarded to ``dequantization``.
    """
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    state_dict = torch.load(path, map_location="cpu")

    # Quantization side tensors are consumed together with their ".qweight"
    # entry, so they are never copied on their own
    aux_markers = (".qzeros", ".scales", ".g_idx")
    # These projection biases are dropped from the output (presumably they
    # exist only on the quantized layers -- verify against the target model)
    dropped_biases = ("o_proj.bias", "up_proj.bias", "down_proj.bias", "gate_proj.bias")

    # Dictionary to store processed weights
    tensors = {}
    for name, tensor in state_dict.items():
        if any(marker in name for marker in aux_markers):
            continue
        if any(marker in name for marker in dropped_biases):
            continue

        # Process quantized weights
        if ".qweight" in name:
            qzeros = state_dict[name.replace(".qweight", ".qzeros")]  # Zero points
            scales = state_dict[name.replace(".qweight", ".scales")]  # Scales
            # The group index is not consumed by dequantization, so a
            # checkpoint without it should not abort the conversion
            g_idx = state_dict.get(name.replace(".qweight", ".g_idx"))
            tensor = dequantization(tensor, qzeros, scales, g_idx,
                                    bits=bits, group_size=group_size)
            name = name.replace(".qweight", ".weight")  # Update key name

        tensors[name] = tensor

    # Print the number of processed weights and save as a new model file
    print(len(tensors))
    torch.save(tensors, save_path)

# Main program entry point
if __name__ == "__main__":
    # Run the conversion only when this file is executed as a script
    get_pytorch_bin()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
tar FP16 VPN Math CLAP Bitcoin Review DeepSeek AI LLM Password VSCode OCR Cloudreve Web XML Card Docker Numpy Image2Text LLAMA Qwen 继承 tqdm Translation Ptyhon BF16 RGB Hotel Miniforge RAR Nginx Tiktoken PyCharm Python Permission 财报 Search Bin 签证 Hungarian Diagram CSV Algorithm Vim Firewall MD5 C++ 飞书 Clash CUDA Github Animate 第一性原理 Streamlit 搞笑 HaggingFace 报税 COCO scipy Ubuntu Vmess GPT4 Pillow ModelScope logger Pickle Color Distillation GIT Random BTC Paper Template Anaconda Logo Data QWEN TensorRT mmap LaTeX 算法题 uWSGI Pytorch PDB CC TTS JSON News Qwen2 CV LeetCode Website Proxy YOLO Qwen2.5 API InvalidArgumentError v0.dev TSV 阿里云 Video Attention 净利润 FP32 Dataset VGG-16 OpenCV FP8 Linux CAM 顶会 UI Safetensors Datetime Django WAN Base64 GGML Food Interview git-lfs 多线程 WebCrawler DeepStream Jetson Magnet uwsgi Heatmap IndexTTS2 Mixtral Zip GPTQ Plotly Claude 音频 Paddle Git FlashAttention llama.cpp Sklearn PDF Statistics FastAPI OpenAI 多进程 Pandas FP64 Michelin transformers 证件照 v2ray Tensor HuggingFace Tracking NameSilo Disk EXCEL Hilton ONNX Gemma diffusers Freesound Markdown Land Baidu Use PyTorch 腾讯云 强化学习 版权 Excel Bert Bipartite UNIX Google git NLP Llama Breakpoint 公式 Transformers Knowledge 递归学习法 Shortcut Windows hf torchinfo SQL Plate Quantize TensorFlow 域名 Agent SAM 图形思考法 LoRA NLTK Domain printf Crawler SQLite BeautifulSoup 关于博主 Jupyter GoogLeNet Input XGBoost SPIE Augmentation PIP CTC CEIR SVR Quantization Conda ResNet-50 ChatGPT
站点统计

本站现有博文320篇,共被浏览760652

本站已经建立2432天!

热门文章
文章归档
回到顶部