
Print Transformers PyTorch Model Information
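The script below scans a local model directory, loads every .bin and .safetensors shard on CPU, and prints each tensor's shape, dtype, and parameter count, followed by summary statistics: total file size, total parameter count, and the parameter breakdown by layer type, dtype, and tensor shape.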

import os
import glob
from collections import defaultdict

import numpy as np
import torch
from safetensors.torch import load_file

model_dir = "/dfs/data/model_path_folder/"

def inspect_model_weights(directory_path):
    """
    检索文件夹中所有的bin或safetensors文件并打印模型权重信息

    参数:
        directory_path (str): 包含模型文件的文件夹路径
    """
    # 查找所有bin和safetensors文件
    bin_files = glob.glob(os.path.join(directory_path, "*.bin"))
    safetensors_files = glob.glob(os.path.join(directory_path, "*.safetensors"))

    all_files = bin_files + safetensors_files

    if not all_files:
        print(f"在 {directory_path} 中没有找到bin或safetensors文件")
        return

    print(f"找到 {len(all_files)} 个模型文件:")
    for idx, file_path in enumerate(all_files):
        print(f"{idx+1}. {os.path.basename(file_path)}")

    total_size = 0
    param_count = 0
    layer_stats = defaultdict(int)
    tensor_types = defaultdict(int)
    shape_info = defaultdict(list)

    # Process each file
    for file_path in all_files:
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
        total_size += file_size

        print(f"\n检查文件: {os.path.basename(file_path)} ({file_size:.2f} MB)")

        # Load the weights according to the file extension
        if file_path.endswith('.bin'):
            try:
                weights = torch.load(file_path, map_location='cpu')
            except Exception as e:
                print(f"  Failed to load {file_path}: {e}")
                continue
        else:  # safetensors
            try:
                weights = load_file(file_path)
            except Exception as e:
                print(f"  Failed to load {file_path}: {e}")
                continue

        # Analyze the weights
        print(f"  Contains {len(weights)} tensors")
        for key, tensor in weights.items():
            # Some .bin checkpoints also store non-tensor entries; skip them
            if not torch.is_tensor(tensor):
                continue

            # Count parameters
            num_params = tensor.numel()
            param_count += num_params

            # Classify the layer type from the parameter name
            layer_type = "other"
            if "attention" in key or "attn" in key:
                layer_type = "attention"
            elif "mlp" in key or "ffn" in key:
                layer_type = "feed_forward"
            elif "embed" in key:
                layer_type = "embedding"
            elif "norm" in key or "ln" in key:
                layer_type = "normalization"
            layer_stats[layer_type] += num_params

            # Count parameters per dtype
            tensor_types[tensor.dtype] += num_params

            # Record shape information (use a tuple so shapes can be grouped and sorted)
            shape_key = tuple(tensor.shape)
            shape_info[shape_key].append(key)

            # Print details while only a few distinct shapes have been seen, and for any large tensor
            if len(shape_info) <= 10 or num_params > 1_000_000:
                print(f"  - {key}: shape={shape_key}, dtype={tensor.dtype}, params={num_params:,}")

    # Print the summary
    print("\nModel weight summary:")
    print(f"Total file size: {total_size:.2f} MB")
    print(f"Total parameter count: {param_count:,}")

    if param_count == 0:
        return  # avoid division by zero when nothing could be loaded

    print("\nParameters by layer type:")
    for layer_type, count in layer_stats.items():
        percentage = (count / param_count) * 100
        print(f"  {layer_type}: {count:,} params ({percentage:.2f}%)")

    print("\nTensor dtype distribution:")
    for dtype, count in tensor_types.items():
        percentage = (count / param_count) * 100
        print(f"  {dtype}: {count:,} params ({percentage:.2f}%)")

    print("\nTop tensor shapes (by parameter count):")
    sorted_shapes = sorted(shape_info.items(), key=lambda x: np.prod(x[0]), reverse=True)
    for i, (shape, keys) in enumerate(sorted_shapes[:10]):
        num_params = int(np.prod(shape))
        percentage = (num_params * len(keys) / param_count) * 100
        print(f"  {shape}: {len(keys)} tensors, {num_params:,} params each ({percentage:.2f}% of total)")
        if i < 3:  # show example parameter names only for the three largest shapes
            print(f"    e.g. {', '.join(keys[:3])}" + ("..." if len(keys) > 3 else ""))

def main():
    # model_dir = input("Enter the model folder path: ")
    inspect_model_weights(model_dir)

if __name__ == "__main__":
    main()
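
As a side note, .safetensors shards keep each tensor's shape and dtype in the file header, so they can be inspected without loading any tensor data into memory. Below is a minimal sketch of that header-only approach, assuming a recent safetensors release; the helper name is mine, not part of the original script.

from safetensors import safe_open

def list_safetensors_headers(file_path):
    """Print name, shape, and dtype for every tensor without materializing it."""
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            lazy = f.get_slice(key)  # lazy view: header info only, no data copied
            print(f"{key}: shape={lazy.get_shape()}, dtype={lazy.get_dtype()}")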