EADST

Print Transformers PyTorch Model Information

import os
import re
import torch
from safetensors import safe_open
from safetensors.torch import load_file
import glob
from collections import defaultdict
import numpy as np

# Default folder inspected by main(); point this at the directory that holds
# the model's *.bin / *.safetensors weight files.
model_dir = "/dfs/data/model_path_folder/"

# Substring rules that bucket each weight name into a layer type. They are
# tried in order and the first match wins, so e.g. "attn_norm" counts as
# attention rather than normalization.
_LAYER_RULES = (
    ("attention", ("attention", "attn")),
    ("feed_forward", ("mlp", "ffn")),
    ("embedding", ("embed",)),
    ("normalization", ("norm", "ln")),
)

# The first _DETAIL_TENSOR_LIMIT tensors (counted across all files) are always
# printed in detail; after that only tensors with more than
# _LARGE_TENSOR_PARAMS elements are printed.
_DETAIL_TENSOR_LIMIT = 10
_LARGE_TENSOR_PARAMS = 1_000_000

# Number of shapes listed in the summary, and how many of them show examples.
_TOP_SHAPES = 10
_SHAPES_WITH_EXAMPLES = 3


def _classify_layer(key):
    """Return the layer-type bucket for a weight name using _LAYER_RULES."""
    for layer_type, needles in _LAYER_RULES:
        if any(needle in key for needle in needles):
            return layer_type
    return "other"


def _percent(part, whole):
    """Return part / whole as a percentage, or 0.0 when whole is 0."""
    return part / whole * 100 if whole else 0.0


def _load_weights(file_path):
    """Load one weight file and return a {name: tensor} dict, or None on failure.

    Failures (load errors, or content that is not a mapping) are reported on
    stdout and yield None so the caller can skip the file. Non-tensor values in
    the mapping are dropped; nested checkpoints (e.g. {"state_dict": {...}})
    are not unwrapped.
    """
    try:
        if file_path.endswith('.bin'):
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code; only inspect .bin files from trusted sources.
            weights = torch.load(file_path, map_location='cpu')
        else:  # safetensors
            weights = load_file(file_path)
    except Exception as e:  # best-effort: report and skip this file
        print(f"  无法加载 {file_path}: {e}")
        return None

    if not isinstance(weights, dict):
        print(f"  跳过 {file_path}: 内容不是张量字典 ({type(weights).__name__})")
        return None

    tensors = {k: v for k, v in weights.items() if isinstance(v, torch.Tensor)}
    skipped = len(weights) - len(tensors)
    if skipped:
        print(f"  跳过 {skipped} 个非张量条目")
    return tensors


def inspect_model_weights(directory_path):
    """Find every .bin / .safetensors file in a folder and print weight statistics.

    Only files directly inside directory_path are considered (no recursion).
    For each file the tensors are listed (the first few in detail, plus any
    large ones), then a summary is printed: total file size, total parameter
    count, parameters per layer type, per dtype, and the most frequent shapes.

    Args:
        directory_path (str): Folder containing the model weight files.
    """
    # Sorted so the listing and per-file output are deterministic.
    bin_files = sorted(glob.glob(os.path.join(directory_path, "*.bin")))
    safetensors_files = sorted(glob.glob(os.path.join(directory_path, "*.safetensors")))
    all_files = bin_files + safetensors_files

    if not all_files:
        print(f"在 {directory_path} 中没有找到bin或safetensors文件")
        return

    print(f"找到 {len(all_files)} 个模型文件:")
    for idx, file_path in enumerate(all_files, start=1):
        print(f"{idx}. {os.path.basename(file_path)}")

    total_size = 0
    param_count = 0
    tensors_seen = 0
    layer_stats = defaultdict(int)
    tensor_types = defaultdict(int)
    # torch.Size is a hashable tuple subclass, so it can key the dict directly
    # (no need to round-trip through str/eval).
    shape_info = defaultdict(list)

    for file_path in all_files:
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
        total_size += file_size

        print(f"\n检查文件: {os.path.basename(file_path)} ({file_size:.2f} MB)")

        weights = _load_weights(file_path)
        if weights is None:
            continue

        print(f"  包含 {len(weights)} 个张量")
        for key, tensor in weights.items():
            # numel() is an exact int (np.prod(()) would give the float 1.0
            # for scalar tensors and turn the whole count into a float).
            num_params = tensor.numel()
            param_count += num_params
            layer_stats[_classify_layer(key)] += num_params
            tensor_types[tensor.dtype] += num_params
            shape_info[tensor.shape].append(key)

            if tensors_seen < _DETAIL_TENSOR_LIMIT or num_params > _LARGE_TENSOR_PARAMS:
                print(f"  - {key}: 形状={tensor.shape}, 类型={tensor.dtype}, 参数数={num_params:,}")
            tensors_seen += 1

    print("\n模型权重汇总:")
    print(f"总文件大小: {total_size:.2f} MB")
    print(f"总参数数量: {param_count:,}")

    # _percent guards against param_count == 0 (e.g. only empty tensors).
    print("\n按层类型划分的参数:")
    for layer_type, count in layer_stats.items():
        print(f"  {layer_type}: {count:,} 参数 ({_percent(count, param_count):.2f}%)")

    print("\n张量数据类型分布:")
    for dtype, count in tensor_types.items():
        print(f"  {dtype}: {count:,} 参数 ({_percent(count, param_count):.2f}%)")

    print("\n常见张量形状:")
    # Most frequent shapes first; ties broken by per-tensor element count.
    sorted_shapes = sorted(
        shape_info.items(),
        key=lambda item: (len(item[1]), item[0].numel()),
        reverse=True,
    )
    for rank, (shape, keys) in enumerate(sorted_shapes[:_TOP_SHAPES]):
        num_params = shape.numel()
        percentage = _percent(num_params * len(keys), param_count)
        print(f"  {shape}: {len(keys)} 个张量, 每个 {num_params:,} 参数 (总共占 {percentage:.2f}%)")
        if rank < _SHAPES_WITH_EXAMPLES:
            print(f"    例如: {', '.join(keys[:3])}" + ("..." if len(keys) > 3 else ""))

def main(directory_path=None):
    """Inspect the model weight files in a folder.

    Args:
        directory_path: Folder to inspect. When None (the default, matching the
            original no-argument call), the module-level model_dir is used.
    """
    if directory_path is None:
        directory_path = model_dir
    inspect_model_weights(directory_path)

# Run the inspection only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
Anaconda TTS JSON Markdown Plate Random Windows EXCEL git 多线程 Vmess COCO PIP Tiktoken Attention Safetensors GoogLeNet v2ray Bipartite 财报 Disk Diagram Bin PyTorch Claude ChatGPT TSV HuggingFace 报税 Dataset 云服务器 SVR Freesound Statistics 域名 tqdm Algorithm Pillow v0.dev 递归学习法 OpenCV Template Food Vim CSV Shortcut 阿里云 Math Bert 多进程 Hotel BF16 OCR SPIE Streamlit transformers Paddle ModelScope 算法题 Quantization AI ONNX Pandas GIT Search 继承 Password Hilton 证件照 Plotly Image2Text VSCode PDF TensorFlow Transformers WAN VPN printf Base64 LaTeX hf Proxy Card Agent 版权 飞书 Excel TensorRT RGB Pickle Tracking Heatmap InvalidArgumentError Baidu Distillation Quantize GPTQ 腾讯云 Jetson mmap Review C++ LLAMA FP8 Linux Hungarian CUDA 关于博主 Michelin API Augmentation Use 图形思考法 GGML Qwen XML Conda Llama Google Breakpoint PDB News CEIR uWSGI Docker Jupyter PyCharm SAM CC Color 顶会 Qwen2.5 YOLO BTC WebCrawler CAM 音频 XGBoost CTC OpenAI SQL Miniforge Cloudreve uwsgi CV FlashAttention 强化学习 LLM Permission logger Sklearn Knowledge LeetCode Zip QWEN FP16 LoRA HaggingFace Website git-lfs NLP Domain Video FP32 Data UNIX Qwen2 Git Translation UI Crawler torchinfo MD5 Datetime FP64 Ubuntu SQLite Clash scipy Mixtral Web NameSilo Tensor diffusers DeepStream CLAP Input Github Pytorch GPT4 Django Ptyhon Firewall BeautifulSoup Gemma IndexTTS2 llama.cpp 搞笑 FastAPI Paper Land Animate Nginx RAR 公式 tar NLTK Numpy Magnet 签证 DeepSeek VGG-16 Interview ResNet-50 Bitcoin 净利润 Logo Python 第一性原理
站点统计

本站现有博文321篇，共被浏览776146次

本站已经建立2466天!

热门文章
文章归档
回到顶部