EADST

Print Transformers PyTorch Model Information

import os
import re
import torch
from safetensors import safe_open
from safetensors.torch import load_file
import glob
from collections import defaultdict
import numpy as np

# Folder that main() passes to inspect_model_weights(); point it at the
# directory holding the model's *.bin / *.safetensors weight files.
model_dir = "/dfs/data/model_path_folder/"

def _classify_layer(key):
    """Return a coarse layer category for a parameter name via substring match."""
    if "attention" in key or "attn" in key:
        return "attention"
    if "mlp" in key or "ffn" in key:
        return "feed_forward"
    if "embed" in key:
        return "embedding"
    if "norm" in key or "ln" in key:
        return "normalization"
    return "other"


def _percent(part, whole):
    """Return ``part / whole`` as a percentage, or 0.0 when ``whole`` is zero."""
    return (part / whole) * 100 if whole else 0.0


def inspect_model_weights(directory_path):
    """
    Print a report on every ``*.bin`` / ``*.safetensors`` file in a folder.

    For each file the tensor count and per-tensor details (name, shape, dtype,
    element count) are printed for the first 10 tensors encountered overall
    and for every tensor with more than 1,000,000 elements. A summary follows:
    total file size, total parameter count, parameters grouped by a
    name-based layer category, parameters grouped by dtype, and the ten
    distinct shapes with the most elements per tensor.

    Every matching file is counted, so a folder that holds the same weights
    in both formats is reported twice over.

    Args:
        directory_path (str): Folder that contains the model files. Only the
            folder itself is searched, not its sub-folders.

    Returns:
        None. All results are written to standard output.
    """
    # Sort so the listing and the "first 10 tensors" are reproducible;
    # glob.glob returns entries in arbitrary order.
    bin_files = sorted(glob.glob(os.path.join(directory_path, "*.bin")))
    safetensors_files = sorted(glob.glob(os.path.join(directory_path, "*.safetensors")))

    all_files = bin_files + safetensors_files

    if not all_files:
        print(f"在 {directory_path} 中没有找到bin或safetensors文件")
        return

    print(f"找到 {len(all_files)} 个模型文件:")
    for idx, file_path in enumerate(all_files):
        print(f"{idx+1}. {os.path.basename(file_path)}")

    total_size = 0
    param_count = 0
    tensors_seen = 0
    layer_stats = defaultdict(int)
    tensor_types = defaultdict(int)
    shape_info = defaultdict(list)   # shape string -> parameter names
    shape_numel = {}                 # shape string -> elements per tensor

    for file_path in all_files:
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
        total_size += file_size

        print(f"\n检查文件: {os.path.basename(file_path)} ({file_size:.2f} MB)")

        # Pick the loader from the file extension. A file that cannot be
        # read is reported and skipped so the remaining files are still
        # inspected.
        try:
            if file_path.endswith('.bin'):
                # NOTE(security): torch.load unpickles the file; only run
                # this on checkpoints from a trusted source (or pass
                # weights_only=True where the installed torch supports it).
                weights = torch.load(file_path, map_location='cpu')
            else:  # safetensors
                weights = load_file(file_path)
        except Exception as e:
            print(f"  无法加载 {file_path}: {e}")
            continue

        # A *.bin file need not be a state dict (it may hold any pickled
        # object); skip it instead of crashing on .items().
        if not isinstance(weights, dict):
            print(f"  跳过 {os.path.basename(file_path)}: 不是权重字典 ({type(weights).__name__})")
            continue

        # Ignore non-tensor entries (metadata, nested containers, ...).
        tensors = {
            key: value for key, value in weights.items()
            if isinstance(value, torch.Tensor)
        }

        print(f"  包含 {len(tensors)} 个张量")
        for key, tensor in tensors.items():
            # numel() is an exact Python int, including 1 for 0-dim tensors
            # (np.prod of an empty shape would yield the float 1.0).
            num_params = tensor.numel()
            param_count += num_params
            tensors_seen += 1

            layer_stats[_classify_layer(key)] += num_params
            tensor_types[tensor.dtype] += num_params

            shape_str = str(tensor.shape)
            shape_info[shape_str].append(key)
            shape_numel[shape_str] = num_params

            # Details for the first 10 tensors and for every large tensor.
            if tensors_seen <= 10 or num_params > 1_000_000:
                print(f"  - {key}: 形状={tensor.shape}, 类型={tensor.dtype}, 参数数={num_params:,}")

    # Summary over all files that were loaded successfully.
    print("\n模型权重汇总:")
    print(f"总文件大小: {total_size:.2f} MB")
    print(f"总参数数量: {param_count:,}")

    print("\n按层类型划分的参数:")
    for layer_type, count in layer_stats.items():
        percentage = _percent(count, param_count)
        print(f"  {layer_type}: {count:,} 参数 ({percentage:.2f}%)")

    print("\n张量数据类型分布:")
    for dtype, count in tensor_types.items():
        percentage = _percent(count, param_count)
        print(f"  {dtype}: {count:,} 参数 ({percentage:.2f}%)")

    print("\n常见张量形状:")
    # Rank shapes by elements per tensor, largest first; the stored count
    # replaces re-parsing the shape string with eval().
    sorted_shapes = sorted(
        shape_info.items(),
        key=lambda item: shape_numel[item[0]],
        reverse=True,
    )
    for i, (shape, keys) in enumerate(sorted_shapes[:10]):
        num_params = shape_numel[shape]
        percentage = _percent(num_params * len(keys), param_count)
        print(f"  {shape}: {len(keys)} 个张量, 每个 {num_params:,} 参数 (总共占 {percentage:.2f}%)")
        if i < 3:  # example names only for the three top-ranked shapes
            print(f"    例如: {', '.join(keys[:3])}" + ("..." if len(keys) > 3 else ""))

def main():
    """Script entry point: inspect the folder named by the module-level ``model_dir``."""
    # To ask the user for the folder instead, read it with input() here.
    target_dir = model_dir
    inspect_model_weights(target_dir)


if __name__ == "__main__":
    main()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
QWEN Safetensors scipy Vim Bitcoin WAN News 递归学习法 Linux CSV 报税 Card UI XGBoost CV Crawler Tiktoken Baidu Website COCO Search 域名 EXCEL 证件照 LLAMA Nginx 强化学习 torchinfo Firewall CC Ptyhon Qwen2.5 Zip Agent FastAPI Tracking GoogLeNet 论文 LeetCode Disk Clash Augmentation BTC Datetime Quantization Llama Data Heatmap ONNX DeepStream Anaconda Windows Docker Paper icon RAR Streamlit PyTorch Hungarian Diagram 音频 VGG-16 Qwen Web ChatGPT Breakpoint YOLO Image2Text Use 算法题 LoRA OpenCV SPIE Jetson CTC CLAP Gemma 版权 MD5 公式 OCR NLP uwsgi Proxy Logo Knowledge Base64 Color diffusers SVR Python Input Land GPTQ Random Magnet NLTK Cloudreve HaggingFace llama.cpp BF16 PDB VSCode Conda Domain 财报 Dataset GGML Django HuggingFace PyCharm Bert FP64 Freesound Pickle v0.dev 签证 SQLite Pillow XML Pandas Claude tar Markdown 关于博主 DeepSeek Math Paddle Statistics Animate 论文速读 Plate Food Hilton Tensor hf Numpy Review SAM 飞书 FlashAttention Rebuttal 多进程 Google SQL Shortcut Excel 第一性原理 Hotel ModelScope UNIX JSON TensorFlow Sklearn ResNet-50 uWSGI CEIR 云服务器 GIT LLM Distillation PIP API CAM git 图形思考法 C++ git-lfs CUDA FP16 Github 搞笑 Password Algorithm 顶会 Vmess logger TensorRT tqdm Attention printf TTS Transformers Git Permission mmap 继承 净利润 Plotly NameSilo OpenAI Miniforge Quantize Pytorch Mixtral Jupyter TSV FP32 Translation Ubuntu Video Qwen2 RGB Interview transformers Michelin BeautifulSoup IndexTTS2 Template FP8 v2ray 多线程 Bin InvalidArgumentError PDF 腾讯云 AI LaTeX WebCrawler 阿里云 图标 GPT4 Bipartite VPN
站点统计

本站现有博文328篇,共被浏览845080次。

本站已经建立2550天!

热门文章
文章归档
回到顶部