Print Transformers PyTorch Model Information
作者:XD / 发表: 2025年4月23日 04:15 / 更新: 2025年4月23日 04:15 / 编程笔记 / 阅读量:24
import os
import re
import torch
from safetensors import safe_open
from safetensors.torch import load_file
import glob
from collections import defaultdict
import numpy as np
# Default model directory inspected by main(). It is hard-coded, so edit it to
# point at the folder that holds your .bin / .safetensors checkpoint files.
model_dir = "/dfs/data/model_path_folder/"
# Matches "ln" only as a standalone name component (e.g. "ln_1", "ln_f", ".ln."),
# so keys that merely contain the letters "ln" are not counted as norm layers.
_NORM_TOKEN = re.compile(r"(?:^|[._])ln(?:[._\d]|$)")


def _classify_layer(key):
    """Return a coarse layer category for a state-dict key.

    Matching is case-insensitive, so ``LayerNorm`` is recognized. Norms are
    checked first so that keys like ``post_attention_layernorm`` count as
    normalization rather than attention.
    """
    name = str(key).lower()
    if "norm" in name or _NORM_TOKEN.search(name):
        return "normalization"
    if "attention" in name or "attn" in name:
        return "attention"
    if "mlp" in name or "ffn" in name or "feed_forward" in name:
        return "feed_forward"
    if "embed" in name:
        return "embedding"
    return "other"


def _load_weights(file_path):
    """Load one checkpoint file and return its top-level dict.

    Returns ``None`` (after printing the reason) when the file cannot be read
    or does not contain a dict, so the caller can skip it.
    """
    try:
        if file_path.endswith('.bin'):
            # SECURITY NOTE: torch.load unpickles the file, which can execute
            # arbitrary code. Only inspect checkpoints from trusted sources
            # (or pass weights_only=True on PyTorch versions that support it).
            weights = torch.load(file_path, map_location='cpu')
        else:  # safetensors
            weights = load_file(file_path)
    except Exception as e:
        print(f" 无法加载 {file_path}: {e}")
        return None
    if not isinstance(weights, dict):
        # For example, an arbitrary pickled object saved with a .bin suffix.
        print(f" 跳过 {os.path.basename(file_path)}: 不是张量字典 ({type(weights).__name__})")
        return None
    return weights


def _pct(part, total):
    """Return ``part`` as a percentage of ``total`` (0.0 when total is 0)."""
    return (part / total) * 100 if total else 0.0


def inspect_model_weights(directory_path, detail_limit=10, large_tensor_threshold=1_000_000):
    """Find every .bin / .safetensors file in a folder and print weight statistics.

    For each file, prints the number of tensors. It also prints a detail line
    for the first ``detail_limit`` tensors and for any tensor with more than
    ``large_tensor_threshold`` elements. It then prints totals: on-disk size
    and parameter count, plus breakdowns by layer category, dtype and shape.

    Every matching file is processed, so a folder that holds the same
    checkpoint in both formats has its parameters counted twice.

    Args:
        directory_path (str): Folder containing the model files (searched
            non-recursively).
        detail_limit (int): Number of tensors per file that always get a
            detail line.
        large_tensor_threshold (int): Tensors with more elements than this
            also get a detail line.
    """
    # Sort so output order is deterministic (glob order depends on the filesystem).
    bin_files = sorted(glob.glob(os.path.join(directory_path, "*.bin")))
    safetensors_files = sorted(glob.glob(os.path.join(directory_path, "*.safetensors")))
    all_files = bin_files + safetensors_files
    if not all_files:
        print(f"在 {directory_path} 中没有找到bin或safetensors文件")
        return
    print(f"找到 {len(all_files)} 个模型文件:")
    for idx, file_path in enumerate(all_files, start=1):
        print(f"{idx}. {os.path.basename(file_path)}")

    total_size = 0.0  # MB
    param_count = 0
    layer_stats = defaultdict(int)
    tensor_types = defaultdict(int)
    # torch.Size -> keys with that shape (torch.Size is a hashable tuple subclass).
    shape_info = defaultdict(list)

    for file_path in all_files:
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
        total_size += file_size
        print(f"\n检查文件: {os.path.basename(file_path)} ({file_size:.2f} MB)")
        weights = _load_weights(file_path)
        if weights is None:
            continue
        # A loaded dict may hold non-tensor values; only tensors are counted.
        tensors = [(k, v) for k, v in weights.items() if isinstance(v, torch.Tensor)]
        print(f" 包含 {len(tensors)} 个张量")
        skipped = len(weights) - len(tensors)
        if skipped:
            print(f" 跳过 {skipped} 个非张量条目")
        for index, (key, tensor) in enumerate(tensors):
            # numel() is an exact int, also for 0-d tensors (np.prod(()) gives 1.0).
            num_params = tensor.numel()
            param_count += num_params
            layer_stats[_classify_layer(key)] += num_params
            tensor_types[tensor.dtype] += num_params
            shape_info[tensor.shape].append(key)
            # Detail line for the first tensors of each file and for large tensors.
            if index < detail_limit or num_params > large_tensor_threshold:
                print(f" - {key}: 形状={tensor.shape}, 类型={tensor.dtype}, 参数数={num_params:,}")

    print("\n模型权重汇总:")
    print(f"总文件大小: {total_size:.2f} MB")
    print(f"总参数数量: {param_count:,}")
    print("\n按层类型划分的参数:")
    for layer_type, count in layer_stats.items():
        print(f" {layer_type}: {count:,} 参数 ({_pct(count, param_count):.2f}%)")
    print("\n张量数据类型分布:")
    for dtype, count in tensor_types.items():
        print(f" {dtype}: {count:,} 参数 ({_pct(count, param_count):.2f}%)")
    print("\n常见张量形状:")
    # NOTE(review): ordered by per-tensor element count, largest first, not by
    # how often a shape occurs, despite the "common shapes" heading.
    sorted_shapes = sorted(shape_info.items(), key=lambda item: item[0].numel(), reverse=True)
    for i, (shape, keys) in enumerate(sorted_shapes[:10]):
        per_tensor = shape.numel()
        share = _pct(per_tensor * len(keys), param_count)
        print(f" {shape}: {len(keys)} 个张量, 每个 {per_tensor:,} 参数 (总共占 {share:.2f}%)")
        if i < 3:  # example keys only for the first 3 listed shapes
            print(f" 例如: {', '.join(map(str, keys[:3]))}" + ("..." if len(keys) > 3 else ""))
def main(argv=None):
    """Script entry point: inspect the directory given as the first CLI argument.

    Falls back to the module-level ``model_dir`` when no argument is given.

    Args:
        argv (list[str] | None): Arguments excluding the program name;
            defaults to ``sys.argv[1:]``.
    """
    import sys  # local import: only needed to read the command line

    args = sys.argv[1:] if argv is None else argv
    directory = args[0] if args else model_dir
    inspect_model_weights(directory)


if __name__ == "__main__":
    main()
相关标签