EADST

Print Transformers PyTorch Model Information

import os
import re
import torch
from safetensors import safe_open
from safetensors.torch import load_file
import glob
from collections import defaultdict
import numpy as np

# Folder that main() passes to inspect_model_weights(); edit this path to point
# at the directory holding the *.bin / *.safetensors checkpoint files.
model_dir = "/dfs/data/model_path_folder/"

def _classify_layer(key):
    """Map a tensor name to a coarse layer category using substring matches."""
    if "attention" in key or "attn" in key:
        return "attention"
    if "mlp" in key or "ffn" in key:
        return "feed_forward"
    if "embed" in key:
        return "embedding"
    if "norm" in key or "ln" in key:
        return "normalization"
    return "other"


def _load_weights(file_path):
    """Load one checkpoint file on the CPU; print the error and return None on failure."""
    try:
        if file_path.endswith(".bin"):
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code on PyTorch versions where weights_only defaults to
            # False. Only inspect checkpoints from a trusted source.
            return torch.load(file_path, map_location="cpu")
        return load_file(file_path)
    except Exception as e:
        # Best-effort tool: report the broken file and let the caller move on.
        print(f"  无法加载 {file_path}: {e}")
        return None


def _percent(part, total):
    """Return part/total as a percentage, or 0.0 when total is zero."""
    return (part / total) * 100 if total else 0.0


def inspect_model_weights(directory_path):
    """
    Find every .bin or .safetensors file in a folder and print weight statistics.

    For each file the tensor names, shapes, dtypes and parameter counts are
    printed (details only for the first 10 tensors overall, plus any tensor
    with more than 1,000,000 parameters).  A summary follows: total file size,
    total parameter count, parameters per layer category, parameters per
    dtype, and the 10 largest tensor shapes.

    Args:
        directory_path (str): Folder that contains the model files.  Only the
            folder itself is searched, not its sub-folders.

    Returns:
        None.  All results are written to standard output.
    """
    # Sort each group so sharded checkpoints are listed and processed in a
    # stable order (glob makes no ordering guarantee).
    bin_files = sorted(glob.glob(os.path.join(directory_path, "*.bin")))
    safetensors_files = sorted(glob.glob(os.path.join(directory_path, "*.safetensors")))
    all_files = bin_files + safetensors_files

    if not all_files:
        print(f"在 {directory_path} 中没有找到bin或safetensors文件")
        return

    print(f"找到 {len(all_files)} 个模型文件:")
    for idx, file_path in enumerate(all_files, start=1):
        print(f"{idx}. {os.path.basename(file_path)}")

    total_size = 0
    param_count = 0
    tensors_seen = 0
    layer_stats = defaultdict(int)
    tensor_types = defaultdict(int)
    # Keyed by the torch.Size itself (hashable), so the element count can be
    # recomputed later without parsing a string back into a shape.
    shape_info = defaultdict(list)

    for file_path in all_files:
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
        total_size += file_size

        print(f"\n检查文件: {os.path.basename(file_path)} ({file_size:.2f} MB)")

        weights = _load_weights(file_path)
        if weights is None:
            continue
        if not isinstance(weights, dict):
            # e.g. a pickled nn.Module rather than a state dict.
            print(f"  跳过 {os.path.basename(file_path)}: 不是张量字典")
            continue

        print(f"  包含 {len(weights)} 个张量")
        for key, tensor in weights.items():
            if not isinstance(tensor, torch.Tensor):
                # Training checkpoints may carry metadata (epoch, config, ...).
                print(f"  跳过非张量条目: {key}")
                continue

            # numel() is always a Python int, including for 0-dim tensors, so
            # the running totals stay integral.
            num_params = tensor.numel()
            param_count += num_params
            tensors_seen += 1

            layer_stats[_classify_layer(key)] += num_params
            tensor_types[tensor.dtype] += num_params
            shape_info[tensor.shape].append(key)

            # Per-tensor details: the first 10 tensors, plus every large one.
            if tensors_seen <= 10 or num_params > 1_000_000:
                print(f"  - {key}: 形状={tensor.shape}, 类型={tensor.dtype}, 参数数={num_params:,}")

    print("\n模型权重汇总:")
    print(f"总文件大小: {total_size:.2f} MB")
    print(f"总参数数量: {param_count:,}")

    print("\n按层类型划分的参数:")
    for layer_type, count in layer_stats.items():
        print(f"  {layer_type}: {count:,} 参数 ({_percent(count, param_count):.2f}%)")

    print("\n张量数据类型分布:")
    for dtype, count in tensor_types.items():
        print(f"  {dtype}: {count:,} 参数 ({_percent(count, param_count):.2f}%)")

    print("\n常见张量形状:")
    # Largest per-tensor element count first; only the top 10 shapes are shown.
    sorted_shapes = sorted(shape_info.items(), key=lambda item: item[0].numel(), reverse=True)
    for i, (shape, keys) in enumerate(sorted_shapes[:10]):
        num_params = shape.numel()
        percentage = _percent(num_params * len(keys), param_count)
        print(f"  {shape}: {len(keys)} 个张量, 每个 {num_params:,} 参数 (总共占 {percentage:.2f}%)")
        if i < 3:  # Example tensor names only for the first three shapes listed.
            print(f"    例如: {', '.join(keys[:3])}" + ("..." if len(keys) > 3 else ""))

def main():
    """Inspect the checkpoint folder named by the module-level ``model_dir``."""
    # To ask the user for the folder instead, read the path with input() here.
    target_dir = model_dir
    inspect_model_weights(target_dir)

# Run the inspection only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
Git Python Github Input 强化学习 Logo QWEN OpenAI Qwen Pandas PDB Hilton Template FlashAttention GPT4 Image2Text GoogLeNet Algorithm TensorRT git-lfs CAM 证件照 icon HuggingFace Paper DeepSeek AI Quantize Web JSON Jupyter 多线程 报税 v2ray Bin 公式 hf printf YOLO SQL FP16 Disk NLTK v0.dev Safetensors News BF16 RGB Website Miniforge Color Jetson CEIR Crawler FP32 logger Firewall Docker Datetime Rebuttal GGML Tensor DeepStream tar Tracking Hungarian Paddle LLAMA Streamlit uWSGI 净利润 scipy Data Hotel Quantization Tiktoken Llama UNIX Django 域名 Review Land Zip 签证 git Search Diagram OpenCV NameSilo Statistics llama.cpp 算法题 Windows CV ChatGPT CTC PyTorch Ubuntu BTC uwsgi Pickle 飞书 Markdown 论文速读 搞笑 PDF Use Augmentation Qwen2.5 Linux RAR 顶会 IndexTTS2 Michelin Clash SAM GPTQ Pytorch Excel Domain mmap WebCrawler PIP Qwen2 Claude VSCode Plate Interview Sklearn 第一性原理 LoRA Google Anaconda HaggingFace Shortcut PyCharm 多进程 腾讯云 TensorFlow EXCEL Plotly 关于博主 Ptyhon SQLite 云服务器 transformers Video Vim CC Password VGG-16 Base64 Cloudreve 递归学习法 继承 FastAPI TSV XGBoost LLM WAN Heatmap 图形思考法 SPIE C++ Conda Nginx 图标 LeetCode XML Proxy Card FP8 Agent InvalidArgumentError UI 阿里云 Dataset tqdm Pillow Knowledge Food LaTeX TTS Animate GIT Freesound diffusers Attention Breakpoint Translation Bert Vmess Magnet NLP CLAP 音频 ModelScope Numpy Distillation VPN Random torchinfo Bipartite Gemma ONNX Math Bitcoin FP64 SVR OCR Transformers 版权 Mixtral ResNet-50 CSV BeautifulSoup COCO MD5 Baidu Permission CUDA API 财报
站点统计

本站现有博文326篇,共被浏览825495次。

本站已经建立2531天!

热门文章
文章归档
回到顶部