EADST

Print Transformers PyTorch Model Information

import os
import re
import torch
from safetensors import safe_open
from safetensors.torch import load_file
import glob
from collections import defaultdict
import numpy as np

# Folder that main() passes to inspect_model_weights(); edit this path to
# point at the directory holding the *.bin / *.safetensors files to inspect.
model_dir = "/dfs/data/model_path_folder/"

def _classify_layer(key):
    """Return a coarse layer category for a parameter name.

    Matching is by substring, checked in priority order: attention,
    feed-forward, embedding, normalization. Anything else is "other".
    """
    if "attention" in key or "attn" in key:
        return "attention"
    if "mlp" in key or "ffn" in key:
        return "feed_forward"
    if "embed" in key:
        return "embedding"
    if "norm" in key or "ln" in key:
        return "normalization"
    return "other"


def _percent(part, total):
    """Return ``part / total`` as a percentage, or 0.0 when ``total`` is zero."""
    return (part / total) * 100 if total else 0.0


def inspect_model_weights(directory_path):
    """
    Scan a folder for ``*.bin`` / ``*.safetensors`` files and print a weight report.

    Every matching file is loaded on CPU. For each tensor the element count,
    dtype and shape are recorded; the first 10 tensors of each file and every
    tensor with more than 1,000,000 elements are printed in detail. After all
    files are processed a summary is printed: total file size, total parameter
    count, parameters per layer category, parameters per dtype, and the 10
    shapes with the largest per-tensor element count.

    Files that fail to load, or whose content is not a dict, are reported and
    skipped. Non-tensor entries inside a loaded dict are ignored.

    Args:
        directory_path (str): Folder containing the model weight files
            (only the top level is searched).

    Returns:
        None. The report is written to stdout.
    """
    # Sort each group so the report order is stable across runs/filesystems.
    bin_files = sorted(glob.glob(os.path.join(directory_path, "*.bin")))
    safetensors_files = sorted(glob.glob(os.path.join(directory_path, "*.safetensors")))

    all_files = bin_files + safetensors_files

    if not all_files:
        print(f"在 {directory_path} 中没有找到bin或safetensors文件")
        return

    print(f"找到 {len(all_files)} 个模型文件:")
    for idx, file_path in enumerate(all_files, start=1):
        print(f"{idx}. {os.path.basename(file_path)}")

    total_size = 0
    param_count = 0
    layer_stats = defaultdict(int)
    tensor_types = defaultdict(int)
    # Maps a shape tuple -> names of the tensors having that shape. Keeping the
    # shape as a tuple (rather than its string form) avoids eval() later on.
    shape_info = defaultdict(list)

    for file_path in all_files:
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
        total_size += file_size

        print(f"\n检查文件: {os.path.basename(file_path)} ({file_size:.2f} MB)")

        try:
            if file_path.endswith('.bin'):
                # SECURITY: .bin checkpoints are pickles; weights_only=True
                # restricts unpickling to tensors and plain containers so a
                # malicious file cannot execute code. Files needing a full
                # unpickle are reported below as unloadable.
                weights = torch.load(file_path, map_location='cpu', weights_only=True)
            else:  # safetensors
                weights = load_file(file_path)
        except Exception as e:
            # Best-effort tool: report the failure and move on to the next file.
            print(f"  无法加载 {file_path}: {e}")
            continue

        if not isinstance(weights, dict):
            # e.g. a pickled config object saved with a .bin extension.
            print(f"  跳过 {file_path}: 内容不是张量字典")
            continue

        tensors = {
            str(name): value
            for name, value in weights.items()
            if isinstance(value, torch.Tensor)
        }

        print(f"  包含 {len(tensors)} 个张量")
        for position, (key, tensor) in enumerate(tensors.items()):
            # numel() is always a Python int (np.prod(()) would yield 1.0 for
            # 0-d tensors and turn every running total into a float).
            num_params = tensor.numel()
            param_count += num_params

            layer_stats[_classify_layer(key)] += num_params
            tensor_types[tensor.dtype] += num_params
            shape_info[tuple(tensor.shape)].append(key)

            # Detail lines: the first 10 tensors of this file, plus any large one.
            if position < 10 or num_params > 1_000_000:
                print(f"  - {key}: 形状={tensor.shape}, 类型={tensor.dtype}, 参数数={num_params:,}")

    print("\n模型权重汇总:")
    print(f"总文件大小: {total_size:.2f} MB")
    print(f"总参数数量: {param_count:,}")

    print("\n按层类型划分的参数:")
    for layer_type, count in layer_stats.items():
        print(f"  {layer_type}: {count:,} 参数 ({_percent(count, param_count):.2f}%)")

    print("\n张量数据类型分布:")
    for dtype, count in tensor_types.items():
        print(f"  {dtype}: {count:,} 参数 ({_percent(count, param_count):.2f}%)")

    print("\n常见张量形状:")
    # Largest per-tensor element count first; ties keep first-seen order.
    sorted_shapes = sorted(
        shape_info.items(),
        key=lambda item: torch.Size(item[0]).numel(),
        reverse=True,
    )
    for i, (shape, keys) in enumerate(sorted_shapes[:10]):
        size = torch.Size(shape)
        num_params = size.numel()
        percentage = _percent(num_params * len(keys), param_count)
        print(f"  {size}: {len(keys)} 个张量, 每个 {num_params:,} 参数 (总共占 {percentage:.2f}%)")
        if i < 3:  # show example tensor names for the top 3 shapes only
            print(f"    例如: {', '.join(keys[:3])}" + ("..." if len(keys) > 3 else ""))

def main():
    """Script entry point: report on the weights stored under ``model_dir``."""
    # To prompt for the folder instead, read it with input() and assign it here.
    target_directory = model_dir
    inspect_model_weights(target_directory)

# Run the inspection only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
SQLite ModelScope News Anaconda Github uwsgi NLP FP16 GIT Zip RAR Bipartite Tensor Translation Color Docker 音频 Proxy Password Statistics 递归学习法 Knowledge 第一性原理 Disk ChatGPT printf Quantization NameSilo WAN LoRA RGB uWSGI Attention LLAMA Hotel Freesound Bitcoin TSV Python CTC Hungarian OpenAI 继承 Qwen2 LLM Gemma SQL Distillation CV Pytorch Use 云服务器 icon transformers FlashAttention Logo Plotly Review OpenCV Ptyhon AI ONNX Quantize Shortcut TensorRT Pillow 阿里云 Food GoogLeNet 版权 VPN CAM Diagram scipy YOLO Git Mixtral Claude 飞书 CUDA 关于博主 图标 PDB Markdown Qwen2.5 LeetCode Math Input Algorithm WebCrawler Base64 FP8 多进程 SVR 域名 Django Video Bin Permission SAM CSV diffusers Animate EXCEL LaTeX COCO llama.cpp InvalidArgumentError 净利润 DeepSeek Nginx Windows Jetson Hilton Heatmap 签证 证件照 Search FastAPI Domain Card NLTK 顶会 Dataset 财报 v2ray Paddle GPTQ Firewall DeepStream BTC Linux Plate 强化学习 Web hf 多线程 Qwen Safetensors MD5 TensorFlow JSON IndexTTS2 Agent Interview Magnet Pickle PyCharm 公式 Sklearn BeautifulSoup Data UNIX API 报税 CLAP GGML Datetime TTS CC Crawler Land tar Template XGBoost logger Augmentation FP32 Baidu v0.dev Excel Ubuntu 算法题 Llama UI ResNet-50 C++ PIP Transformers Breakpoint Numpy Tracking Streamlit Image2Text FP64 Random Website XML 搞笑 BF16 Bert git HuggingFace torchinfo git-lfs VSCode GPT4 Clash PDF VGG-16 Pandas OCR Conda QWEN Paper 图形思考法 PyTorch Michelin tqdm Cloudreve CEIR Jupyter Vim Tiktoken Google Miniforge HaggingFace SPIE mmap 腾讯云 Vmess
站点统计

本站现有博文322篇,共被浏览791083

本站已经建立2487天!

热门文章
文章归档
回到顶部