EADST

Print Transformers PyTorch Model Information

import os
import re
import torch
from safetensors import safe_open
from safetensors.torch import load_file
import glob
from collections import defaultdict
import numpy as np

# Folder that main() hands to inspect_model_weights(); it is scanned for
# *.bin and *.safetensors files (non-recursively).
model_dir = "/dfs/data/model_path_folder/"

def _classify_layer(key):
    """Map a tensor name to a coarse layer category by substring match."""
    if "attention" in key or "attn" in key:
        return "attention"
    if "mlp" in key or "ffn" in key:
        return "feed_forward"
    if "embed" in key:
        return "embedding"
    if "norm" in key or "ln" in key:
        return "normalization"
    return "other"


def _load_weights(file_path):
    """Load one checkpoint file; return its dict, or None after reporting why not."""
    try:
        if file_path.endswith('.bin'):
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code. Only point this at checkpoints you trust (or pass
            # weights_only=True on a PyTorch version that supports it).
            weights = torch.load(file_path, map_location='cpu')
        else:  # safetensors
            weights = load_file(file_path)
    except Exception as e:
        # Best-effort tool: report the failure and let the caller move on.
        print(f"  无法加载 {file_path}: {e}")
        return None

    # A *.bin file may hold something other than a state dict (a bare tensor,
    # pickled training arguments, ...); such files cannot be summarised.
    if not isinstance(weights, dict):
        print(f"  跳过 {file_path}: 内容不是张量字典 ({type(weights).__name__})")
        return None
    return weights


def inspect_model_weights(directory_path):
    """
    Find every ``*.bin`` / ``*.safetensors`` file directly inside a folder and
    print a summary of the tensors they contain.

    For each file the size and tensor count are printed, followed by
    per-tensor details for a subset of tensors. After all files are processed
    a summary is printed: total file size, total element count, element share
    per layer category (by name substring), element share per dtype, and the
    ten largest distinct tensor shapes.

    Files that fail to load, or that do not contain a dict, are reported and
    skipped. Dict entries that are not ``torch.Tensor`` instances are ignored.

    Args:
        directory_path (str): Path of the folder holding the model files.

    Returns:
        None. All results are written to standard output.
    """
    # Collect both checkpoint formats; .bin files are listed first.
    bin_files = glob.glob(os.path.join(directory_path, "*.bin"))
    safetensors_files = glob.glob(os.path.join(directory_path, "*.safetensors"))
    all_files = bin_files + safetensors_files

    if not all_files:
        print(f"在 {directory_path} 中没有找到bin或safetensors文件")
        return

    print(f"找到 {len(all_files)} 个模型文件:")
    for idx, file_path in enumerate(all_files, start=1):
        print(f"{idx}. {os.path.basename(file_path)}")

    total_size = 0
    param_count = 0
    layer_stats = defaultdict(int)    # layer category -> element count
    tensor_types = defaultdict(int)   # dtype -> element count
    shape_info = defaultdict(list)    # str(shape) -> tensor names with that shape
    shape_sizes = {}                  # str(shape) -> elements per tensor of that shape

    for file_path in all_files:
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
        total_size += file_size

        print(f"\n检查文件: {os.path.basename(file_path)} ({file_size:.2f} MB)")

        weights = _load_weights(file_path)
        if weights is None:
            continue

        print(f"  包含 {len(weights)} 个张量")
        for key, tensor in weights.items():
            # Checkpoint dicts can carry metadata (ints, nested dicts, ...).
            if not isinstance(tensor, torch.Tensor):
                continue

            # numel() is an exact Python int, including 1 for 0-dim tensors,
            # so the running totals stay integers.
            num_params = tensor.numel()
            param_count += num_params

            layer_stats[_classify_layer(key)] += num_params
            tensor_types[tensor.dtype] += num_params

            # Remember the size next to the printable shape so the summary
            # never has to parse the string back into a shape.
            shape_str = str(tensor.shape)
            shape_info[shape_str].append(key)
            shape_sizes[shape_str] = num_params

            # Details are printed while at most 10 distinct shapes have been
            # seen so far, and always for tensors above one million elements.
            if len(shape_info) <= 10 or num_params > 1_000_000:
                print(f"  - {key}: 形状={tensor.shape}, 类型={tensor.dtype}, 参数数={num_params:,}")

    def share(count):
        # Percentage of the total; 0 when nothing (or only empty tensors) loaded.
        return (count / param_count) * 100 if param_count else 0.0

    print("\n模型权重汇总:")
    print(f"总文件大小: {total_size:.2f} MB")
    print(f"总参数数量: {param_count:,}")

    print("\n按层类型划分的参数:")
    for layer_type, count in layer_stats.items():
        print(f"  {layer_type}: {count:,} 参数 ({share(count):.2f}%)")

    print("\n张量数据类型分布:")
    for dtype, count in tensor_types.items():
        print(f"  {dtype}: {count:,} 参数 ({share(count):.2f}%)")

    print("\n常见张量形状:")
    # Largest per-tensor element count first; only the top ten are shown.
    sorted_shapes = sorted(
        shape_info.items(), key=lambda item: shape_sizes[item[0]], reverse=True
    )
    for i, (shape, keys) in enumerate(sorted_shapes[:10]):
        num_params = shape_sizes[shape]
        percentage = share(num_params * len(keys))
        print(f"  {shape}: {len(keys)} 个张量, 每个 {num_params:,} 参数 (总共占 {percentage:.2f}%)")
        if i < 3:  # example names only for the first three shapes listed
            print(f"    例如: {', '.join(keys[:3])}" + ("..." if len(keys) > 3 else ""))

def main():
    """Inspect the weight files in the folder named by the module-level ``model_dir``."""
    # Disabled alternative: ask for the folder interactively instead of
    # using the module-level constant.
    # model_dir = input("请输入模型文件夹路径: ")
    inspect_model_weights(model_dir)

# Run the inspection only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
QWEN FlashAttention Safetensors llama.cpp MD5 Docker hf Pandas CAM Vmess LaTeX FP32 NLP C++ 关于博主 Python LeetCode TensorFlow 第一性原理 Use Quantize git-lfs CTC Transformers PyTorch COCO Permission Llama IndexTTS2 tqdm 签证 腾讯云 Baidu Breakpoint CEIR Password CV 版权 SAM 图形思考法 Hotel Jupyter 证件照 Freesound Data Datetime Markdown Algorithm Ptyhon Michelin OCR PyCharm torchinfo GPT4 Pytorch Anaconda Bipartite CLAP Image2Text GGML Gemma Clash Qwen2.5 SVR Nginx DeepStream 多线程 Knowledge 云服务器 LLM Distillation Proxy mmap Github Domain v0.dev Website 继承 GIT ONNX CSV Statistics TSV Qwen transformers FP64 SQL UNIX Heatmap OpenCV NameSilo 强化学习 Bert Vim Jetson HuggingFace 音频 OpenAI BF16 tar ResNet-50 Windows Shortcut TTS Dataset diffusers Plotly Template Web Tensor uWSGI Review Cloudreve GoogLeNet WebCrawler Plate TensorRT PDB XGBoost 多进程 Quantization Search SPIE Augmentation Input VGG-16 财报 Land Agent Bin AI Pillow Math DeepSeek WAN InvalidArgumentError Tracking SQLite 图标 Streamlit v2ray 报税 FP8 Base64 LLAMA GPTQ Firewall FP16 VSCode 顶会 JSON CUDA Disk CC logger Ubuntu Interview git RGB LoRA Random icon BeautifulSoup Excel Tiktoken EXCEL YOLO News Paper FastAPI Diagram scipy Conda 公式 Claude ModelScope VPN Attention PDF Video 阿里云 净利润 Bitcoin Animate ChatGPT Card RAR 飞书 Paddle 算法题 API BTC Django Translation Magnet Logo Mixtral Sklearn uwsgi Numpy Linux UI 搞笑 Google Qwen2 Color Pickle NLTK Food printf Zip Hungarian Git Hilton XML 递归学习法 Miniforge HaggingFace 域名 Crawler PIP
站点统计

本站现有博文322篇,共被浏览793418次

本站已经建立2490天!

热门文章
文章归档
回到顶部