EADST

Pytorch GPTQ Dequantizing Function

Pytorch GPTQ Dequantizing Function

The following Python code dequantizes the weights of a GPTQ-quantized model checkpoint and saves them as a regular PyTorch FP16 checkpoint.

import torch

# Function: Dequantize quantized weights
def dequantization(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):
    """Unpack GPTQ-packed integer weights and rescale them to floating point.

    Every int32 element of ``qweight`` and ``qzeros`` holds ``32 // bits``
    values of ``bits`` bits each. The values are extracted with a right shift
    followed by a mask, then combined as ``scales * (weight - (zero + 1))``
    group by group.

    Args:
        qweight: 2-D int32 tensor, packed along dim 0. After unpacking it has
            ``qweight.shape[0] * (32 // bits)`` rows.
        qzeros: 2-D int32 tensor of packed zero points, packed along the last
            dim, one row per group. One is added to every unpacked zero point
            before it is subtracted from the weights.
        scales: Per-group scales; reshaped to ``(-1, 1, scales.shape[-1])`` and
            broadcast against the unpacked weights.
        g_idx: Accepted for interface compatibility but not read. Groups are
            taken to be contiguous runs of ``group_size`` input rows.
            NOTE(review): checkpoints whose g_idx is not ``i // group_size``
            (e.g. act-order quantization) would presumably need it - confirm.
        bits: Bit width of each packed value; must divide 32 and be at most 8.
        group_size: Number of unpacked input rows per group. ``-1`` means a
            single group spanning all input rows.
        device: Not read; kept for backward compatibility. The computation
            runs on the device the input tensors already live on.

    Returns:
        The dequantized weight as a transposed (non-contiguous) 2-D view of
        shape ``(qweight.shape[1], qweight.shape[0] * (32 // bits))``. Its
        dtype is the result of multiplying ``scales`` by an integer tensor.

    Raises:
        ValueError: If ``bits`` is not one of 1, 2, 4 or 8.
    """
    if bits not in (1, 2, 4, 8):
        raise ValueError(f"bits must be one of 1, 2, 4 or 8, got {bits!r}")

    pack_factor = 32 // bits
    mask = (2 ** bits) - 1
    # For 8-bit values, zero + 1 can reach 256, which does not fit in int8.
    unpacked_dtype = torch.int16 if bits == 8 else torch.int8

    # Shift amount for each value packed in one int32; created on the input's
    # device so the function works for both CPU and GPU tensors.
    shifts = torch.arange(0, 32, bits, dtype=torch.int32, device=qweight.device)

    # Unpack the zero points along the last dimension.
    zeros = torch.bitwise_right_shift(
        torch.unsqueeze(qzeros, 2).expand(-1, -1, pack_factor),
        shifts.to(qzeros.device).reshape(1, 1, -1),
    ).to(unpacked_dtype)
    zeros = torch.bitwise_and(zeros, mask)
    zeros = zeros + 1
    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])

    # One row of scales per group, broadcast over the rows of that group.
    scales = scales.reshape(-1, 1, scales.shape[-1])

    # Unpack the weights along the first dimension.
    weight = torch.bitwise_right_shift(
        torch.unsqueeze(qweight, 1).expand(-1, pack_factor, -1),
        shifts.reshape(1, -1, 1),
    ).to(unpacked_dtype)
    weight = torch.bitwise_and(weight, mask)

    in_features = weight.shape[0] * weight.shape[1]
    if group_size == -1:
        # -1 means "no grouping": the whole input dimension is one group.
        group_size = in_features
    weight = weight.reshape(-1, group_size, weight.shape[2])

    # Apply the dequantization formula and flatten the groups again.
    weight = scales * (weight - zeros)
    weight = weight.reshape(in_features, weight.shape[2])

    # Return the transposed weight.
    return weight.transpose(0, 1)

# Function: Load quantized model and perform dequantization
def get_pytorch_bin(model_path="./your_model_folder/gptq_model-4bit-128g.bin",
                    output_path="./your_model_folder/pytorch_model.bin",
                    bits=4, group_size=128):
    """Convert a GPTQ checkpoint into a checkpoint of dequantized weights.

    Loads the state dict at ``model_path``, replaces every ``*.qweight`` entry
    with a dequantized ``*.weight`` entry, drops the quantization side tensors
    and selected bias entries, copies everything else unchanged, prints the
    number of resulting entries and saves them to ``output_path``.

    Args:
        model_path: Path of the quantized checkpoint to read.
        output_path: Path the converted checkpoint is written to.
        bits: Bit width forwarded to ``dequantization``.
        group_size: Group size forwarded to ``dequantization``.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only convert checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location="cpu")

    # Side tensors that are consumed together with their ".qweight" entry.
    quant_markers = (".qzeros", ".scales", ".g_idx")
    # Biases that are dropped from the output.
    # NOTE(review): presumably these exist only on the quantized layers and
    # have no counterpart in the target model - confirm.
    dropped_biases = ("o_proj.bias", "up_proj.bias", "down_proj.bias", "gate_proj.bias")

    tensors = {}
    for key, value in state_dict.items():
        if any(marker in key for marker in quant_markers):
            continue
        if any(marker in key for marker in dropped_biases):
            continue

        if ".qweight" in key:
            qzeros = state_dict[key.replace(".qweight", ".qzeros")]
            scales = state_dict[key.replace(".qweight", ".scales")]
            # g_idx is not present in every checkpoint, so a missing entry
            # is passed on as None instead of raising KeyError.
            g_idx = state_dict.get(key.replace(".qweight", ".g_idx"))
            value = dequantization(value, qzeros, scales, g_idx,
                                   bits=bits, group_size=group_size)
            key = key.replace(".qweight", ".weight")

        tensors[key] = value

    # Print the number of processed weights and save as a new model file
    print(len(tensors))
    torch.save(tensors, output_path)

# Main program entry point
if __name__ == "__main__":
    # Run the conversion only when this file is executed as a script.
    get_pytorch_bin()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
Pillow Qwen2 CUDA Hotel 签证 Ubuntu Hilton HuggingFace 公式 Docker Safetensors Numpy Proxy TSV Heatmap Land Diagram Algorithm PDB FP64 JSON Magnet 云服务器 图形思考法 NameSilo Vmess transformers GGML FastAPI Disk 多线程 Pickle 腾讯云 COCO Qwen Ptyhon 报税 LLM Statistics Agent AI git-lfs icon CTC Gemma Food Quantization Excel LLAMA 搞笑 tar Video Breakpoint 算法题 Bipartite FlashAttention Pytorch Crawler ResNet-50 UNIX Bitcoin IndexTTS2 OCR Llama git Windows NLP Linux ChatGPT PyTorch Permission Freesound GPT4 音频 Nginx Dataset 阿里云 ModelScope 净利润 Datetime Conda Augmentation Website Sklearn SAM Streamlit Bin TensorFlow Tiktoken Web DeepSeek InvalidArgumentError Base64 DeepStream Paper VPN Qwen2.5 Math Domain CAM llama.cpp OpenCV 论文速读 证件照 Translation PIP FP8 Firewall SQLite 版权 C++ QWEN Claude 强化学习 Random VGG-16 Attention YOLO RGB Logo Interview 继承 tqdm Vim hf 顶会 Plate GoogLeNet 财报 VSCode CEIR Google GPTQ uWSGI SQL BTC Zip LoRA LeetCode 域名 论文 SPIE Markdown BeautifulSoup SVR 多进程 News Password FP32 Distillation RAR 关于博主 Django Template TTS WAN Use Michelin GIT BF16 Tracking Animate Hungarian MD5 XGBoost API CV HaggingFace PDF FP16 Data ONNX UI CLAP Mixtral Cloudreve Search Github Clash mmap PyCharm Baidu Bert printf uwsgi torchinfo Input Python Pandas Jetson scipy Miniforge Rebuttal Tensor Paddle v0.dev Plotly v2ray Review TensorRT OpenAI 飞书 Color EXCEL XML CC 递归学习法 Jupyter diffusers Quantize Knowledge Shortcut Image2Text 第一性原理 LaTeX Git Transformers CSV 图标 logger WebCrawler NLTK Anaconda Card
站点统计

本站现有博文328篇,共被浏览854034

本站已经建立2561天!

热门文章
文章归档
回到顶部