EADST

Print Transformers PyTorch Model Information

import os
import re
import torch
from safetensors import safe_open
from safetensors.torch import load_file
import glob
from collections import defaultdict
import numpy as np

# Folder scanned by main(); it is passed to inspect_model_weights(), which
# looks for *.bin and *.safetensors files directly inside it.
model_dir = "/dfs/data/model_path_folder/"

def _classify_layer(key):
    """Map a tensor name to a coarse layer category by substring match."""
    if "attention" in key or "attn" in key:
        return "attention"
    if "mlp" in key or "ffn" in key:
        return "feed_forward"
    if "embed" in key:
        return "embedding"
    if "norm" in key or "ln" in key:
        return "normalization"
    return "other"


def _percent(part, whole):
    """Return ``part / whole`` as a percentage, or 0.0 when ``whole`` is zero."""
    return (part / whole) * 100 if whole else 0.0


def inspect_model_weights(directory_path):
    """
    Scan a folder for ``*.bin`` / ``*.safetensors`` files and print a summary
    of the weights they contain.

    For every file the function prints its size, the number of entries, and a
    detail line for the first 10 tensors of that file plus every tensor with
    more than 1,000,000 elements.  After all files are processed it prints the
    total file size, the total parameter count, and the parameter distribution
    by layer category (substring match on the tensor name), by dtype, and by
    tensor shape (the 10 largest shapes, ordered by elements per tensor).

    Files that cannot be loaded, files whose top-level object is not a dict,
    and dict entries that are not tensors are reported and skipped.

    Args:
        directory_path (str): Folder that directly contains the model files.

    Returns:
        None. All results are written to stdout.
    """
    # Sort each group so the listing and processing order is deterministic.
    bin_files = sorted(glob.glob(os.path.join(directory_path, "*.bin")))
    safetensors_files = sorted(glob.glob(os.path.join(directory_path, "*.safetensors")))
    all_files = bin_files + safetensors_files

    if not all_files:
        print(f"在 {directory_path} 中没有找到bin或safetensors文件")
        return

    print(f"找到 {len(all_files)} 个模型文件:")
    for idx, file_path in enumerate(all_files, start=1):
        print(f"{idx}. {os.path.basename(file_path)}")

    total_size = 0
    param_count = 0
    layer_stats = defaultdict(int)
    tensor_types = defaultdict(int)
    # Keyed by the shape object itself (hashable), so nothing has to be
    # parsed back from a string later on.
    shape_info = defaultdict(list)
    shape_numel = {}

    for file_path in all_files:
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
        total_size += file_size

        print(f"\n检查文件: {os.path.basename(file_path)} ({file_size:.2f} MB)")

        # Best-effort load: a broken shard is reported and skipped so the
        # remaining files are still summarised.
        try:
            if file_path.endswith('.bin'):
                # NOTE: torch.load unpickles the file. Only inspect checkpoints
                # from trusted sources (or pass weights_only=True on PyTorch
                # versions that support it).
                weights = torch.load(file_path, map_location='cpu')
            else:  # safetensors
                weights = load_file(file_path)
        except Exception as e:
            print(f"  无法加载 {file_path}: {e}")
            continue

        if not isinstance(weights, dict):
            print(f"  跳过 {file_path}: 文件内容不是张量字典")
            continue

        print(f"  包含 {len(weights)} 个张量")
        tensors_seen = 0
        for key, tensor in weights.items():
            if not torch.is_tensor(tensor):
                # Checkpoints may carry metadata (epoch, config, ...) next to
                # the weights; those entries have no shape/dtype to report.
                print(f"  跳过非张量条目: {key}")
                continue
            tensors_seen += 1

            # numel() is a plain int and is 1 for 0-dim tensors.
            num_params = tensor.numel()
            param_count += num_params

            layer_stats[_classify_layer(str(key))] += num_params
            tensor_types[tensor.dtype] += num_params

            shape = tensor.shape
            shape_info[shape].append(str(key))
            shape_numel[shape] = num_params

            # Detail line for the first 10 tensors of this file and for every
            # large tensor.
            if tensors_seen <= 10 or num_params > 1_000_000:
                print(f"  - {key}: 形状={tensor.shape}, 类型={tensor.dtype}, 参数数={num_params:,}")

    print("\n模型权重汇总:")
    print(f"总文件大小: {total_size:.2f} MB")
    print(f"总参数数量: {param_count:,}")

    print("\n按层类型划分的参数:")
    for layer_type, count in layer_stats.items():
        print(f"  {layer_type}: {count:,} 参数 ({_percent(count, param_count):.2f}%)")

    print("\n张量数据类型分布:")
    for dtype, count in tensor_types.items():
        print(f"  {dtype}: {count:,} 参数 ({_percent(count, param_count):.2f}%)")

    print("\n常见张量形状:")
    # Ordered by elements per tensor (largest first), not by frequency.
    sorted_shapes = sorted(
        shape_info.items(), key=lambda item: shape_numel[item[0]], reverse=True
    )
    for i, (shape, keys) in enumerate(sorted_shapes[:10]):
        num_params = shape_numel[shape]
        percentage = _percent(num_params * len(keys), param_count)
        print(f"  {shape}: {len(keys)} 个张量, 每个 {num_params:,} 参数 (总共占 {percentage:.2f}%)")
        if i < 3:  # example tensor names only for the first three shapes listed
            print(f"    例如: {', '.join(keys[:3])}" + ("..." if len(keys) > 3 else ""))

def main():
    """Inspect the weights stored in the module-level ``model_dir`` folder."""
    # The folder is fixed by ``model_dir``; swap in an input() prompt here if
    # an interactive path is preferred.
    target_dir = model_dir
    inspect_model_weights(target_dir)

# Run the inspection only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
域名 PyCharm 签证 Ubuntu 多线程 CC Michelin Animate 财报 Qwen2.5 Jupyter Paddle CUDA PIP Math llama.cpp Card Translation Permission TSV Interview FastAPI ChatGPT Tiktoken uwsgi Transformers 音频 Pandas YOLO Ptyhon 多进程 Docker HaggingFace MD5 VSCode 关于博主 Excel TensorRT Pillow git-lfs ONNX Python Heatmap Jetson UI Base64 hf 强化学习 算法题 Agent Review Github RAR ResNet-50 Tracking GIT Breakpoint Algorithm Vmess Google CTC Use EXCEL 图标 Quantize NameSilo BTC Attention GPTQ Statistics v2ray Plotly SQLite Distillation NLP AI Qwen2 PDF Gemma TTS Augmentation 公式 Vim BF16 OpenCV 云服务器 Password Firewall GGML SQL C++ Random uWSGI GPT4 Clash 腾讯云 Diagram 版权 icon PDB Datetime Zip 证件照 Pickle FP8 Land Bipartite tqdm XGBoost Git Pytorch Hilton ModelScope UNIX 阿里云 mmap Bitcoin Numpy Mixtral Search scipy Freesound 飞书 QWEN OpenAI 第一性原理 CV IndexTTS2 DeepSeek Dataset 搞笑 API 顶会 Conda COCO 净利润 Django git CLAP Claude SVR Image2Text Hungarian Crawler Template LLM Linux Safetensors Input Color FlashAttention Website Plate Anaconda transformers Shortcut Windows FP64 Llama PyTorch printf Miniforge Disk Proxy LLAMA XML 报税 SPIE LoRA 图形思考法 CAM Tensor FP32 Paper LeetCode Knowledge FP16 TensorFlow 递归学习法 NLTK Rebuttal Web Hotel Streamlit 继承 OCR diffusers Baidu CEIR Quantization VGG-16 logger WAN Bert Magnet Cloudreve BeautifulSoup v0.dev Logo Video LaTeX Bin Sklearn CSV Data InvalidArgumentError Nginx JSON GoogLeNet HuggingFace RGB Domain News torchinfo Food tar VPN DeepStream Qwen SAM WebCrawler Markdown
站点统计

本站现有博文324篇,共被浏览808737次。

本站已经建立2511天!

热门文章
文章归档
回到顶部