EADST

PyTorch GPTQ Dequantization Function

PyTorch GPTQ Dequantization Function

Here is Python code that dequantizes a GPTQ-quantized model back to PyTorch FP16 format.

import torch

# Function: Dequantize quantized weights
def dequantization(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):
    # Create a tensor for bitwise right shift operation
    wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0)

    # Apply bitwise right shift and convert qzeros to the appropriate type
    zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
    torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)

    # Reshape the zeros tensor
    zeros = zeros + 1
    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])

    # Reshape the scales tensor
    scales = scales.reshape(-1, 1, scales.shape[-1])

    # Similar bitwise right shift operation for qweight and reshape
    weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8)
    torch.bitwise_and(weight, (2 ** bits) - 1, out=weight)
    weight = weight.reshape(-1, group_size, weight.shape[2])

    # Apply dequantization formula and reshape the final weight
    weight = (scales * (weight - zeros))
    weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

    # Return the transposed weight
    return weight.transpose(0, 1)

# Function: Load quantized model and perform dequantization
def get_pytorch_bin(path="./your_model_folder/gptq_model-4bit-128g.bin",
                    save_path="./your_model_folder/pytorch_model.bin"):
    """Convert a GPTQ checkpoint into a plain FP dense state dict on disk.

    Parameters
    ----------
    path : str
        Input GPTQ .bin checkpoint.  Default matches the previously
        hard-coded path, so existing callers are unaffected.
    save_path : str
        Where to write the converted state dict.
    """
    # Dictionary holding the converted weights.
    tensors = {}

    # NOTE(review): torch.load unpickles arbitrary objects — only run this on
    # trusted checkpoints (weights_only=True is safer on recent PyTorch).
    state = torch.load(path, map_location="cpu")

    for key, value in state.items():
        # Companion tensors are consumed while handling .qweight below.
        if any(tag in key for tag in (".qzeros", ".scales", ".g_idx")):
            continue
        # These bias entries are placeholders in GPTQ exports; drop them.
        if any(tag in key for tag in ("o_proj.bias", "up_proj.bias",
                                      "down_proj.bias", "gate_proj.bias")):
            continue

        if ".qweight" in key:
            # Rebuild the dense weight from the packed representation.
            value = dequantization(
                value,
                state[key.replace(".qweight", ".qzeros")],
                state[key.replace(".qweight", ".scales")],
                state[key.replace(".qweight", ".g_idx")],
            )
            key = key.replace(".qweight", ".weight")

        tensors[key] = value

    # Report how many tensors were converted, then save the new checkpoint.
    print(len(tensors))
    torch.save(tensors, save_path)

# Script entry point: convert the GPTQ checkpoint at the default path into a
# plain FP16 state dict (pytorch_model.bin).
if __name__ == '__main__':
    get_pytorch_bin()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
签证 mmap 继承 Color 多进程 Pickle Use NLP CEIR C++ 强化学习 Bert Card CUDA torchinfo Math Firewall Bin Animate Algorithm SQL uwsgi Ubuntu Qwen2 WebCrawler Gemma 图形思考法 Paddle LeetCode hf TensorRT 搞笑 FP64 News Augmentation Breakpoint Bitcoin 域名 Docker 财报 LLAMA Crawler Random Cloudreve Hungarian Vmess ChatGPT ONNX SVR XGBoost Hotel Land Pandas FP32 TTS Proxy Michelin CV Statistics Disk HuggingFace Numpy Safetensors GIT Plotly SQLite GPT4 Magnet BF16 Llama UI Windows RGB Quantize Knowledge Qwen2.5 Miniforge Pytorch Python GPTQ Transformers DeepStream llama.cpp Pillow Datetime Hilton 云服务器 Domain NLTK Linux CTC Excel diffusers NameSilo 算法题 Review LaTeX Claude Search Password FP16 tar CLAP PIP Interview ResNet-50 CSV Tracking HaggingFace Ptyhon YOLO Conda scipy Base64 tqdm Git GoogLeNet Django FP8 Permission 公式 OpenCV Shortcut BTC Paper RAR uWSGI Diagram git-lfs git TensorFlow UNIX PDB XML Tensor AI DeepSeek Plate Dataset WAN Qwen COCO Freesound Sklearn VSCode 报税 Template transformers SPIE Quantization Translation VPN Clash ModelScope JSON v0.dev Agent MD5 证件照 Markdown TSV 净利润 递归学习法 CAM Tiktoken PyCharm Image2Text Anaconda Bipartite 第一性原理 Input 多线程 关于博主 Github printf Baidu 音频 logger PDF Website VGG-16 OCR IndexTTS2 QWEN Logo Nginx EXCEL SAM FlashAttention Distillation 飞书 GGML 阿里云 OpenAI Data PyTorch InvalidArgumentError FastAPI Web LLM 版权 API 腾讯云 BeautifulSoup Google Mixtral LoRA Heatmap Jetson Attention 顶会 CC Video Food Zip Jupyter v2ray Vim Streamlit
站点统计

本站现有博文321篇,共被浏览770826

本站已经建立2457天!

热门文章
文章归档
回到顶部