EADST

PyTorch GPTQ Dequantization Function

PyTorch GPTQ Dequantization Function

The following Python code dequantizes the weights of a GPTQ-quantized model and saves them as a regular PyTorch FP16 checkpoint.

import torch

# Function: Dequantize quantized weights
def dequantization(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):
    """Unpack GPTQ-packed integer weights and rescale them to floating point.

    Each int32 word of ``qweight`` and ``qzeros`` stores ``32 // bits`` values,
    lowest bits first. The stored zero points are offset by one, so the value
    computed for every element is ``scale * (q - (stored_zero + 1))``.

    All tensor arguments are expected to live on the same device; the
    computation runs on that device.

    Args:
        qweight: 2-D int32 tensor, packed along dim 0
            (``rows // (32 // bits)`` x ``cols``).
        qzeros: 2-D int32 tensor, packed along dim 1
            (``n_groups`` x ``cols // (32 // bits)``).
        scales: Per-group scales whose last dimension is ``cols``; reshaped
            to ``(n_groups, cols)``.
        g_idx: 1-D integer tensor giving the group index of every unpacked
            row, or ``None``. When given, it selects the scale / zero point of
            each row, which also covers checkpoints whose rows are not grouped
            contiguously. For a sequential ``g_idx`` (``row // group_size``)
            the result is identical to contiguous grouping.
        bits: Bit width of the packed values; one of 1, 2, 4 or 8.
        group_size: Number of consecutive rows per group. Only used when
            ``g_idx`` is ``None``.
        device: Ignored. Kept only so existing callers that pass it keep
            working.

    Returns:
        A ``(cols, rows)`` tensor in the dtype of ``scales``. It is a
        transposed view, hence not contiguous.

    Raises:
        ValueError: If ``bits`` is unsupported or ``g_idx`` does not have one
            entry per unpacked row.
    """
    if bits not in (1, 2, 4, 8):
        raise ValueError(f"bits must be one of 1, 2, 4 or 8, got {bits!r}")

    mask = (1 << bits) - 1
    # 8-bit values plus the +1 zero-point offset can reach 256, which does not
    # fit in int8; every smaller width does.
    unpacked_dtype = torch.int16 if bits == 8 else torch.int8

    # Zero points: (n_groups, cols / pack) -> (n_groups, cols).
    # The shift amounts are created on the input's device so that CUDA inputs
    # do not hit a CPU/CUDA device mismatch.
    zero_shifts = torch.arange(0, 32, bits, dtype=torch.int32, device=qzeros.device)
    zeros = torch.bitwise_right_shift(qzeros.unsqueeze(2), zero_shifts.view(1, 1, -1))
    zeros = torch.bitwise_and(zeros, mask).to(unpacked_dtype)
    zeros = (zeros + 1).flatten(1)

    scales = scales.reshape(-1, scales.shape[-1])

    # Weights: (rows / pack, cols) -> (rows, cols).
    weight_shifts = torch.arange(0, 32, bits, dtype=torch.int32, device=qweight.device)
    weight = torch.bitwise_right_shift(qweight.unsqueeze(1), weight_shifts.view(1, -1, 1))
    weight = torch.bitwise_and(weight, mask).to(unpacked_dtype)
    n_cols = weight.shape[2]
    weight = weight.reshape(-1, n_cols)

    if g_idx is not None:
        group_of_row = g_idx.reshape(-1).long()
        if group_of_row.numel() != weight.shape[0]:
            raise ValueError(
                f"g_idx has {group_of_row.numel()} entries but the unpacked "
                f"weight has {weight.shape[0]} rows"
            )
        # Look up each row's own scale / zero point.
        weight = scales[group_of_row] * (weight - zeros[group_of_row])
    else:
        # No group index available: assume groups are contiguous row blocks.
        grouped = weight.reshape(-1, group_size, n_cols)
        weight = scales.unsqueeze(1) * (grouped - zeros.unsqueeze(1))
        weight = weight.reshape(-1, n_cols)

    # Return the transposed weight
    return weight.transpose(0, 1)

# Function: Load quantized model and perform dequantization
def get_pytorch_bin(path="./your_model_folder/gptq_model-4bit-128g.bin",
                    save_path="./your_model_folder/pytorch_model.bin"):
    """Convert a GPTQ checkpoint into a checkpoint of dequantized weights.

    Every ``*.qweight`` entry is dequantized with ``dequantization`` (default
    bit width and group size) and stored under the matching ``*.weight`` key.
    The auxiliary ``.qzeros`` / ``.scales`` / ``.g_idx`` entries and the
    ``o_proj`` / ``up_proj`` / ``down_proj`` / ``gate_proj`` biases are
    dropped; all other entries are copied unchanged. The number of written
    tensors is printed.

    Args:
        path: Checkpoint to read (loaded onto the CPU).
        save_path: Where the converted checkpoint is written.

    Raises:
        KeyError: If a ``.qweight`` entry has no matching ``.qzeros`` or
            ``.scales`` entry.
    """
    # Auxiliary quantization tensors are consumed together with their
    # ".qweight" entry and never copied on their own.
    quant_aux_markers = (".qzeros", ".scales", ".g_idx")
    # NOTE(review): presumably these biases only exist in the quantized export
    # and not in the original model -- confirm before changing this list.
    dropped_bias_markers = ("o_proj.bias", "up_proj.bias", "down_proj.bias", "gate_proj.bias")
    skip_markers = quant_aux_markers + dropped_bias_markers

    # NOTE(security): torch.load unpickles the file; only use it on
    # checkpoints from a trusted source.
    state_dict = torch.load(path, map_location="cpu")

    tensors = {}
    for key, value in state_dict.items():
        if any(marker in key for marker in skip_markers):
            continue

        if ".qweight" in key:
            value = dequantization(
                value,
                state_dict[key.replace(".qweight", ".qzeros")],
                state_dict[key.replace(".qweight", ".scales")],
                # Some checkpoints carry no group index; pass None then
                # instead of failing with a KeyError.
                state_dict.get(key.replace(".qweight", ".g_idx")),
            )
            key = key.replace(".qweight", ".weight")

        tensors[key] = value

    # Print the number of processed weights and save as a new model file
    print(len(tensors))
    torch.save(tensors, save_path)

# Script entry point: run the conversion only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    get_pytorch_bin()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
Transformers v2ray 签证 SPIE Nginx NameSilo Paper 版权 Pandas git-lfs FP64 Numpy Quantization 域名 LLAMA Vmess tqdm Web Image2Text API Land PyCharm 多线程 音频 BeautifulSoup Dataset WAN IndexTTS2 继承 PDF CLAP Markdown XML Anaconda EXCEL Attention Breakpoint 净利润 Ubuntu ResNet-50 QWEN DeepStream GoogLeNet ChatGPT FP16 Sklearn VPN Animate Firewall TensorFlow Pytorch LLM GPTQ SAM Git CEIR OpenAI llama.cpp SQL 关于博主 Hilton MD5 uwsgi diffusers JSON Pickle Password Color transformers Distillation GGML ONNX Video 财报 hf LeetCode Streamlit Hungarian Baidu Safetensors FP8 Tiktoken Github Bipartite CV Domain XGBoost 腾讯云 OpenCV Cloudreve RAR Vim Website Tracking NLP TTS 报税 FP32 Quantize Google Use Math Python Qwen2.5 Conda tar 算法题 Ptyhon Random 飞书 Pillow v0.dev Bitcoin Qwen Gemma RGB Algorithm Card Jupyter HaggingFace FastAPI UI PIP Miniforge Clash CSV Translation CUDA 公式 BF16 YOLO GPT4 Jetson Knowledge CC ModelScope GIT 阿里云 Linux LoRA git PDB Paddle SVR OCR TensorRT Interview printf Diagram uWSGI WebCrawler Template mmap C++ CTC 搞笑 Logo Plate Michelin CAM FlashAttention Datetime Data Bert VGG-16 Hotel LaTeX Tensor HuggingFace TSV Claude BTC Input InvalidArgumentError 证件照 COCO Freesound torchinfo UNIX Review Crawler Zip Base64 Augmentation Docker logger Django Plotly Proxy DeepSeek scipy Excel Disk Food Magnet NLTK Statistics 多进程 Bin Shortcut Qwen2 VSCode Mixtral PyTorch Llama SQLite Windows Permission Heatmap AI
站点统计

本站现有博文309篇,共被浏览730469

本站已经建立2367天!

热门文章
文章归档
回到顶部