Pytorch GPTQ Dequantizing Function

Here is the Python code for dequantizing a GPTQ model's packed 4-bit weights back into a standard PyTorch FP16 state dict.

import torch

# Function: Dequantize GPTQ-packed weights back into a dense floating-point matrix
def dequantization(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):
    # Note: g_idx and device are not used by this implementation (sequential groups are assumed)
    # Shift offsets (0, bits, 2*bits, ..., 32 - bits) for unpacking the values packed into each int32
    wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0)

    # Apply bitwise right shift and convert qzeros to the appropriate type
    zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
    torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)

    # The packed format stores zero points offset by one, so add it back, then reshape per group
    zeros = zeros + 1
    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])

    # Reshape the scales tensor
    scales = scales.reshape(-1, 1, scales.shape[-1])

    # Unpack qweight with the same right-shift trick, then regroup rows by group_size
    weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8)
    torch.bitwise_and(weight, (2 ** bits) - 1, out=weight)
    weight = weight.reshape(-1, group_size, weight.shape[2])

    # Apply dequantization formula and reshape the final weight
    weight = (scales * (weight - zeros))
    weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

    # Return the transposed weight, shaped (out_features, in_features) like an nn.Linear weight
    return weight.transpose(0, 1)
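
As a standalone sanity check (separate from the conversion script below), the function can be fed randomly packed tensors with the shapes a 4-bit, group-size-128 GPTQ linear layer would have. The sizes in_features=256 and out_features=512 are illustrative assumptions only; a real layer takes its shapes from the checkpoint.

# Illustrative layer sizes (assumed values, not from a real checkpoint)
in_features, out_features, bits, group_size = 256, 512, 4, 128
pack = 32 // bits  # number of values packed into each int32 (8 for 4-bit)

# Randomly packed tensors shaped like one GPTQ linear layer
qweight = torch.randint(0, 2 ** 31 - 1, (in_features // pack, out_features), dtype=torch.int32)
qzeros = torch.randint(0, 2 ** 31 - 1, (in_features // group_size, out_features // pack), dtype=torch.int32)
scales = torch.rand(in_features // group_size, out_features).half()
g_idx = torch.arange(in_features, dtype=torch.int32) // group_size  # not used by this implementation

w = dequantization(qweight, qzeros, scales, g_idx, bits=bits, group_size=group_size)
print(w.shape, w.dtype)  # torch.Size([512, 256]) torch.float16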

# Function: Load quantized model and perform dequantization
def get_pytorch_bin():
    # Specify model file path
    path = "./your_model_folder/gptq_model-4bit-128g.bin"

    # Dictionary to store processed weights
    tensors = {}

    # Load the model file
    f = torch.load(path, map_location="cpu")

    # Iterate through keys in the model
    for k in f.keys():
        ori_w = f[k]  # Original weight
        keys = k  # Original key name

        # Skip quantization metadata; it is consumed when the matching .qweight key is processed
        if ".qzeros" in k or ".scales" in k or ".g_idx" in k:
            continue
        # Skip bias entries of these projection layers, which are not needed in the output model
        if "o_proj.bias" in k or "up_proj.bias" in k or "down_proj.bias" in k or "gate_proj.bias" in k:
            continue

        # Process quantized weights
        if ".qweight" in k:
            qweight = f[k]  # Quantized weight
            qzeros = f[k.replace(".qweight", ".qzeros")]  # Zero points
            scales = f[k.replace(".qweight", ".scales")]  # Scales
            g_idx = f[k.replace(".qweight", ".g_idx")]  # Group index
            ori_w = dequantization(qweight, qzeros, scales, g_idx)  # Perform dequantization
            keys = k.replace(".qweight", ".weight")  # Update key name

        # Add processed weight to the dictionary
        tensors[keys] = ori_w

    # Print the number of processed weights and save as a new model file
    print(len(tensors))
    torch.save(tensors, "./your_model_folder/pytorch_model.bin")

# Main program entry point
if __name__ == '__main__':
    get_pytorch_bin()
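
The resulting pytorch_model.bin is an ordinary dense state dict. As a rough check (using the same placeholder path as above), the saved file can be reloaded to inspect the converted tensor names, shapes, and dtypes:

import torch

state_dict = torch.load("./your_model_folder/pytorch_model.bin", map_location="cpu")
# Print the first few entries to confirm the .qweight keys became dense .weight tensors
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape), tensor.dtype)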