EADST

Print Transformers PyTorch Model Information

import os
import re
import torch
from safetensors import safe_open
from safetensors.torch import load_file
import glob
from collections import defaultdict
import numpy as np

# Folder whose *.bin / *.safetensors files main() passes to
# inspect_model_weights(); edit this to point at your own checkpoint folder.
model_dir = "/dfs/data/model_path_folder/"

# Matches GPT-2-style norm names ("ln_1", "ln_f", "h.0.ln_2", "blocks.3.ln1")
# only when "ln" is a standalone name component, not an arbitrary substring.
_LN_PATTERN = re.compile(r"(?:^|[._])ln(?:$|[._\d])")


def _classify_layer(key):
    """Return a coarse layer category for a parameter name.

    Matching is case-insensitive (so BERT's ``LayerNorm`` counts), and
    normalization is checked first so names such as
    ``post_attention_layernorm`` or ``ffn_norm`` are counted as
    normalization rather than attention / feed-forward.
    """
    name = key.lower()
    if "norm" in name or _LN_PATTERN.search(name):
        return "normalization"
    if "attention" in name or "attn" in name:
        return "attention"
    if "mlp" in name or "ffn" in name:
        return "feed_forward"
    if "embed" in name:
        return "embedding"
    return "other"


def _load_weights(file_path):
    """Load one checkpoint file and return its name -> value dict, or None.

    Prints a message and returns None when the file cannot be loaded or does
    not deserialize to a dict (e.g. a pickled ``training_args.bin`` object).
    """
    try:
        if file_path.endswith(".bin"):
            # SECURITY: torch.load unpickles arbitrary objects unless
            # weights_only=True (the default only in recent PyTorch, 2.6+).
            # Only inspect .bin checkpoints from trusted sources.
            weights = torch.load(file_path, map_location="cpu")
        else:  # .safetensors
            weights = load_file(file_path)
    except Exception as e:
        print(f"  无法加载 {file_path}: {e}")
        return None
    if not isinstance(weights, dict):
        print(f"  跳过 {os.path.basename(file_path)}: 不是张量字典 ({type(weights).__name__})")
        return None
    return weights


def _pct(part, total):
    """Return *part* as a percentage of *total*, or 0.0 when *total* is 0."""
    return (part / total) * 100 if total else 0.0


def inspect_model_weights(directory_path, *, detail_limit=10, large_threshold=1_000_000):
    """Find every .bin / .safetensors file in a folder and print weight statistics.

    Only the top level of ``directory_path`` is searched. For each file the
    size and tensor count are printed, followed by per-tensor details. A
    summary then reports total size, total parameters, and the parameter
    split by layer category, by dtype, and by the most common tensor shapes.

    Args:
        directory_path (str): Folder containing the model files.
        detail_limit (int): Per file, print details for this many leading
            entries (default 10).
        large_threshold (int): Also print details for any tensor with more
            than this many elements (default 1,000,000).
    """
    # sorted() gives a deterministic, shard-ordered listing; glob order is arbitrary.
    bin_files = sorted(glob.glob(os.path.join(directory_path, "*.bin")))
    safetensors_files = sorted(glob.glob(os.path.join(directory_path, "*.safetensors")))
    all_files = bin_files + safetensors_files

    if not all_files:
        print(f"在 {directory_path} 中没有找到bin或safetensors文件")
        return

    print(f"找到 {len(all_files)} 个模型文件:")
    for idx, file_path in enumerate(all_files, start=1):
        print(f"{idx}. {os.path.basename(file_path)}")

    total_size = 0
    param_count = 0
    layer_stats = defaultdict(int)
    tensor_types = defaultdict(int)
    shape_info = defaultdict(list)  # shape tuple -> names of tensors with that shape

    for file_path in all_files:
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
        total_size += file_size

        print(f"\n检查文件: {os.path.basename(file_path)} ({file_size:.2f} MB)")

        weights = _load_weights(file_path)
        if weights is None:
            continue

        print(f"  包含 {len(weights)} 个张量")
        for idx, (key, tensor) in enumerate(weights.items()):
            if not isinstance(tensor, torch.Tensor):
                continue  # nested dicts / metadata entries hold no parameters here

            # numel() is an exact int (np.prod returns 1.0 float for scalar shapes).
            num_params = tensor.numel()
            param_count += num_params
            layer_stats[_classify_layer(key)] += num_params
            tensor_types[tensor.dtype] += num_params
            shape_info[tuple(tensor.shape)].append(key)

            # Details for the first `detail_limit` entries of this file, plus
            # every large tensor.
            if idx < detail_limit or num_params > large_threshold:
                print(f"  - {key}: 形状={tensor.shape}, 类型={tensor.dtype}, 参数数={num_params:,}")

        # Release this file's tensors before the next file is loaded, so two
        # shards are never held in memory at once.
        del weights

    print("\n模型权重汇总:")
    print(f"总文件大小: {total_size:.2f} MB")
    print(f"总参数数量: {param_count:,}")

    print("\n按层类型划分的参数:")
    for layer_type, count in layer_stats.items():
        print(f"  {layer_type}: {count:,} 参数 ({_pct(count, param_count):.2f}%)")

    print("\n张量数据类型分布:")
    for dtype, count in tensor_types.items():
        print(f"  {dtype}: {count:,} 参数 ({_pct(count, param_count):.2f}%)")

    print("\n常见张量形状:")
    # Most frequent shapes first; ties broken by per-tensor size. Shapes are
    # stored as tuples, so no eval() of their string form is needed.
    sorted_shapes = sorted(
        shape_info.items(),
        key=lambda item: (len(item[1]), torch.Size(item[0]).numel()),
        reverse=True,
    )
    for i, (shape, keys) in enumerate(sorted_shapes[:10]):
        num_params = torch.Size(shape).numel()
        percentage = _pct(num_params * len(keys), param_count)
        print(f"  {torch.Size(shape)}: {len(keys)} 个张量, 每个 {num_params:,} 参数 (总共占 {percentage:.2f}%)")
        if i < 3:  # example names only for the 3 most common shapes
            print(f"    例如: {', '.join(keys[:3])}" + ("..." if len(keys) > 3 else ""))

def main(argv=None):
    """Print weight statistics for a model folder.

    Args:
        argv: Argument list without the program name; defaults to
            ``sys.argv[1:]``. The first entry, if present, is the folder to
            inspect; otherwise the module-level ``model_dir`` is used.
    """
    import sys

    if argv is None:
        argv = sys.argv[1:]
    inspect_model_weights(argv[0] if argv else model_dir)


if __name__ == "__main__":
    main()
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
BTC VPN Agent Conda RL Proxy MD5 论文速读 Freesound SAM Website ModelScope Tiktoken Nginx Card CV DeepSeek diffusers mmap tqdm logger RAR Translation Anaconda Statistics 阿里云 EXCEL SPIE 第一性原理 JSON HuggingFace Vim 图标 Domain CEIR GPT4 Git Clash CUDA 继承 Hilton Random Safetensors FastAPI Plate COCO SQL llama.cpp C++ Bipartite VSCode Shortcut 报税 净利润 Use Diagram Bitcoin v2ray printf TTS Qwen torchinfo FP64 QWEN Animate Base64 LLM Knowledge Hungarian GPTQ 证件照 WebCrawler CTC Excel YOLO FP16 TSV 音频 Baidu Jetson Ptyhon Augmentation SQLite Attention Search Heatmap DeepStream LLAMA Miniforge Template Algorithm PIP TensorFlow RGB FP8 CC 域名 Pillow TensorRT 搞笑 UNIX Tracking Zip HaggingFace 算法题 uWSGI Logo Password Markdown Numpy ms-swift Streamlit GoogLeNet 图形思考法 CLAP 云服务器 公式 CAM Permission InvalidArgumentError Michelin Plotly Tensor Transformers 多线程 Color ResNet-50 BeautifulSoup 关于博主 Data AI ONNX FlashAttention Sklearn WAN LeetCode Rebuttal LoRA Ubuntu 递归学习法 Quantization CSV uwsgi Input Image2Text PyTorch Paper Bert Crawler OpenAI Google Llama 签证 hf Python SVR Gemma Breakpoint Cloudreve v0.dev 飞书 transformers PDF Pytorch Food Qwen2.5 Web News Docker Interview Pandas Math PyCharm Hotel icon ChatGPT Quantize 论文 NameSilo Land Datetime Dataset NLTK git-lfs NLP GIT Firewall 版权 Distillation XGBoost 腾讯云 Video scipy Django API Pickle Windows Review FP32 财报 UI IndexTTS2 Linux Jupyter Mixtral tar OCR GGML VGG-16 Claude 强化学习 Qwen2 多进程 Bin Paddle Disk Github 顶会 LaTeX git XML OpenCV Vmess BF16 Magnet PDB
站点统计

本站现有博文332篇，共被浏览867418次

本站已经建立2575天!

热门文章
文章归档
回到顶部