gemma-3n-E4B-it 模型框架图+简要解析
作者:XD / 发表: 2025年6月30日 04:08 / 更新: 2025年6月30日 04:20 / 科研学习 / 阅读量:244
================================================================================
1. 输入层 (INPUT STAGE)
================================================================================
+------------------+ +------------------+ +--------------------+
| 图像输入 | | 文本输入 | | 音频输入 |
| (pixel_values) | | (input_ids) | | (input_features) |
+------------------+ +------------------+ +--------------------+
| | |
v v v
================================================================================
2. 编码层 (ENCODING STAGE)
================================================================================
+------------------+ +-----------------------------+ +---------------------------+
| Vision Tower | | 文本嵌入模块 | | Audio Tower |
| (MobileNetV5) | | 1. 生成标准词元嵌入 | | 1. 子采样卷积 (SubSample)|
+------------------+ | 2. 准备逐层输入嵌入 | | 2. Conformer 模块 (x12) |
| +-----------------------------+ +---------------------------+
v | v
+------------------+ | +------------------+
| 视觉软词元 | | | 音频软词元 |
+------------------+ | +------------------+
| | |
+-------------------------------------> | <--------------------------------------+
|
v
================================================================================
3. 融合层 (FUSION STAGE)
================================================================================
+-----------------------------------------+
| 特征融合模块 |
| (将视觉/音频软词元插入文本嵌入序列) |
+-----------------------------------------+
|
| [生成统一的多模态序列]
v
================================================================================
4. 核心处理层 (CORE PROCESSING STAGE)
(以下为单个解码器层的详细流程,该流程共重复 35 次)
================================================================================
+-----------------------------------------+
| 输入:来自上一层的隐藏状态 |
| + 并行的“逐层输入”数据 |
+--------------------+--------------------+
|
v
+-----------------------------------------+
| Step 4.1: AltUp 预测 & Pre-Attn 处理 |
| (包含 RMSNorm, LAUREL) |
+--------------------+--------------------+
|
v
+-----------------------------------------+
| Step 4.2: 自注意力模块 (滑动或全注意力) |
| (包含 GQA, RoPE 位置编码) |
+--------------------+--------------------+
|
v
+-----------------------------------------+
| Step 4.3: 注意力后处理 |
| (包含残差连接, LAUREL融合, RMSNorm) |
+--------------------+--------------------+
|
v
+-----------------------------------------+
| Step 4.4: 前馈网络 / MLP |
| (门控结构, 可能有稀疏激活) |
+--------------------+--------------------+
|
v
+-----------------------------------------+
| Step 4.5: MLP 后处理与 AltUp 校正 |
| (残差连接, Norm, AltUp Correct) |
+--------------------+--------------------+
|
v
+-----------------------------------------+
| 输出:向下一层传递隐藏状态 |
+-----------------------------------------+
| (循环34次后进入下一阶段)
v
================================================================================
5. 输出层 (OUTPUT STAGE)
================================================================================
+-----------------------------------------+
| 最终层 RMSNorm |
+-----------------------------------------+
|
v
+-----------------------------------------+
| 输出头 (LM Head) |
+-----------------------------------------+
|
v
+-----------------------------------------+
| 生成的文本 |
+-----------------------------------------+
Gemma-3N 模型框架详细流程
该模型的整体架构被定义为Gemma3nForConditionalGeneration
,它将三个独立的编码塔(视觉、音频、文本)与一个强大的核心语言模型相结合,以实现多模态的条件生成任务。
阶段一:输入模态的独立编码
模型首先并行处理三种不同类型的输入。
1.1 视觉处理层 (Vision Tower)
- 输入: 接收原始图像数据,形式为
pixel_values
张量。 - 骨干网络编码: 图像被送入一个基于
mobilenetv5_300m_enc
架构的视觉编码器。model.safetensors.index.json
中的大量model.vision_tower.timm_model...
权重条目证实了其存在和复杂性。 - 特征提取: 经过视觉塔处理后,图像被转换成一个固定长度的特征序列,其隐藏维度为2048。
- 投影与生成软词元:
- 该特征序列被送入一个专用的
Gemma3nMultimodalEmbedder
模块。 - 在这个模块中,视觉特征首先经过
RMSNorm
归一化,然后通过一个线性投影层,最后再经过一次RMSNorm
,最终生成256个代表图像内容的“视觉软词元”。这些软词元已经处于语言模型可以理解的语义空间中。
- 该特征序列被送入一个专用的
1.2 音频处理层 (Audio Tower)
- 输入: 接收音频特征,通常是梅尔频谱图,形式为
input_features
张量,特征维度为128。 - 子采样卷积层 (
Gemma3nAudioSubSampleConvProjection
):- 为了降低序列长度并提取初步特征,音频特征首先通过一个包含两层2D卷积的子采样模块。
- 这两层卷积的步长均为
[2, 2]
,有效降低了时间和频率维度的分辨率。
- Conformer 编码层 (
Gemma3nAudioEncoder
):- 降采样后的特征被送入一个由12个
Gemma3nAudioConformerBlock
堆叠而成的Conformer编码器。 - 每个Conformer模块内部都包含自注意力(Self-Attention)、一维卷积(1D Convolution)和前馈网络(Feed-Forward Network),这是一种非常适合处理音频等时序信号的结构。
- 降采样后的特征被送入一个由12个
- 投影与生成软词元:
- 经过12层Conformer模块处理后,输出的隐藏状态(维度为1536)同样被送入一个专用的
Gemma3nMultimodalEmbedder
模块。 - 经过归一化和线性投影,最终生成188个代表音频内容的“音频软词元”。
- 经过12层Conformer模块处理后,输出的隐藏状态(维度为1536)同样被送入一个专用的
1.3 文本处理层 (Text Embedding)
- 输入: 接收文本序列,形式为
input_ids
(一串token ID)。 - 双嵌入流:
- 主嵌入流:
input_ids
通过Gemma3nTextScaledWordEmbedding
层被转换为隐藏维度为2048的词嵌入序列。这个序列中包含了特殊的占位符ID(如图像的image_token_id
: 262145),用于后续的多模态信息插入。 - 逐层输入流: 同时,
input_ids
被送入另一个embed_tokens_per_layer
嵌入层,为后续的35个解码器层分别生成一个独立的、并行的输入序列,即“逐层输入”(Per-Layer Inputs)。
- 主嵌入流:
阶段二:多模态特征融合
- 序列拼接: 这是一个关键步骤。模型将阶段1.1生成的视觉软词元和阶段1.2生成的音频软词元,通过
masked_scatter
操作,精确地“粘贴”到阶段1.3生成的主嵌入流序列中对应的占位符位置上。 - 统一序列形成: 这个操作完成后,形成了一个单一的、无缝融合了文本、视觉和音频信息的多模态嵌入序列。
阶段三:核心语言模型处理 (35个解码器层)
融合后的多模态序列,连同并行的“逐层输入”序列,被送入由35个Gemma3nTextDecoderLayer
组成的解码器核心进行深度处理。下面是单个解码器层的详细流程:
- AltUp 预测与前置处理:
AltUp
模块首先对该层的输出进行预测。- 预测结果经过一次RMSNorm归一化,并被送入
LAUREL
低秩适配器模块。
- 自注意力计算:
- 这是信息交互的核心。模块内部使用分组查询注意力 (GQA)(8个查询头,2个键/值头)以提高效率。
- 根据当前层的索引,注意力机制可能是滑动窗口注意力(
sliding_window
大小为512)或全注意力,这由config.json
中的layer_types
列表定义。
- 注意力后处理与融合:
- 注意力模块的输出与输入进行残差连接。
- 结果与
LAUREL
模块的输出融合,形成增强的残差流,并通过一次RMSNorm归一化。
- 前馈网络 (MLP):
- 数据通过一个门控MLP(Gated MLP)进行非线性变换。
- 对于前10层,MLP的激活函数会应用稀疏化处理(稀疏度为0.95),而后续层则为密集计算。
- 最终处理与AltUp校正:
- MLP的输出再次与输入进行残差连接和归一化。
- 最终结果被送入
AltUp
模块的校正步骤,并与该层的“逐层输入”融合,形成最终输出。 - 此输出被传递给下一个解码器层。
这个包含5个步骤的复杂流程会重复35次,每一层都在逐步深化模型对多模态输入的理解。
阶段四:输出生成
- 最终归一化: 经过全部35层解码器处理后,最终的隐藏状态序列通过一个顶层的
RMSNorm
层进行最后一次归一化。 - 语言模型头 (
lm_head
):- 归一化后的向量被送入一个线性层,即
lm_head
。 - 该层将2048维的隐藏向量投影到大小为262400的词汇表空间上,生成logits(即每个词元的得分)。
- 归一化后的向量被送入一个线性层,即
- 生成文本: 这些logits经过soft-capping(上限为30.0)和softmax函数后,可以用来预测下一个最有可能的token,从而实现自回归式的文本生成。
参考链接
相关标签