Annotated Transformers Mixtral MoE Model Code: modular_mixtral.py
Author: XD / Published: 2025-04-24 05:11 / Updated: 2025-04-24 05:11 / Programming Notes
# coding=utf-8
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Annotations by XD
# For the original source, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modular_mixtral.py
"""PyTorch Mixtral模型。"""
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import DynamicCache
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
)
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
logging,
)
# Components are imported from the Mistral model; Mixtral extends the Mistral architecture with a mixture of experts
from ..mistral.modeling_mistral import (
MistralAttention,
MistralForCausalLM,
MistralForQuestionAnswering,
MistralForSequenceClassification,
MistralForTokenClassification,
MistralModel,
MistralPreTrainedModel,
MistralRMSNorm,
MistralRotaryEmbedding,
)
from .configuration_mixtral import MixtralConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1"
_CONFIG_FOR_DOC = "MixtralConfig"
def load_balancing_loss_func(
gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
num_experts: Optional[int] = None,
top_k=2,
attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
r"""
Computes the auxiliary load balancing loss as implemented in Switch Transformer - in PyTorch.
See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
presented in equations (4)-(6) of the paper. It aims to penalize cases where the routing between experts is
too unbalanced.
Args:
gate_logits:
Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors,
each of shape [batch_size X sequence_length, num_experts].
num_experts:
Number of experts.
top_k:
The number of experts each token is routed to; can also be interpreted as the `top-k` routing parameter.
attention_mask (`torch.Tensor`, *optional*):
The attention_mask used in the forward function,
of shape [batch_size X sequence_length] if not None.
Returns:
The auxiliary loss.
"""
# Return 0 if gate_logits is None or not a tuple
if gate_logits is None or not isinstance(gate_logits, tuple):
return 0
# If gate_logits is a tuple, concatenate the gate logits from all layers
if isinstance(gate_logits, tuple):
compute_device = gate_logits[0].device
concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
# Compute the routing weights (per-expert probabilities)
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
# Select the top-k experts
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
# Build an expert mask indicating which experts each token is routed to
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
if attention_mask is None:
# Compute the fraction of tokens routed to each expert
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
# Compute the average routing probability for each expert
router_prob_per_expert = torch.mean(routing_weights, dim=0)
else:
# Handle the case where an attention mask is provided
batch_size, sequence_length = attention_mask.shape
num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
# Build an expert attention mask with the same shape as expert_mask
expert_attention_mask = (
attention_mask[None, :, :, None, None]
.expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
.reshape(-1, top_k, num_experts)
.to(compute_device)
)
# Compute the fraction of tokens routed to each expert (accounting for the attention mask)
tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
expert_attention_mask, dim=0
)
# Build a per-expert router attention mask with the same shape as tokens_per_expert
router_per_expert_attention_mask = (
attention_mask[None, :, :, None]
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
.reshape(-1, num_experts)
.to(compute_device)
)
# Compute the average routing probability for each expert (accounting for the attention mask)
router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
router_per_expert_attention_mask, dim=0
)
# Overall loss: sum over experts of the product of token fraction and routing probability
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
# Multiply by the number of experts to obtain the final loss
return overall_loss * num_experts
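# Annotator's note (not part of the upstream file): the quantity computed above is the Switch
# Transformer auxiliary loss, aux_loss = N * sum_i f_i * P_i, where N is num_experts, f_i is the
# fraction of tokens routed to expert i (tokens_per_expert) and P_i is the mean router probability
# assigned to expert i (router_prob_per_expert). The scaling coefficient alpha
# (config.router_aux_loss_coef) is applied later, in MixtralForCausalLM.forward.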
# MixtralBlockSparseTop2MLP - a single sparse expert MLP block in the Mixtral model
# This class is one expert in the mixture-of-experts (MoE) system; each expert is a standard feed-forward network
# that uses a SwiGLU activation and consists of three projection layers plus one activation function
class MixtralBlockSparseTop2MLP(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
# Feed-forward and hidden dimensions
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size
# First projection: hidden dim -> feed-forward dim
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
# Second projection: feed-forward dim -> hidden dim
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
# Third projection: hidden dim -> feed-forward dim (used for the gating mechanism)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
# Activation function, SiLU/Swish by default
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
# SwiGLU activation: the activated w1 output is multiplied by the w3 output (gating mechanism)
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
# Project back to the original dimension through w2
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
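# Annotator's usage sketch (not part of the upstream file; the tiny config values are hypothetical,
# chosen only for illustration). A single expert keeps the hidden size unchanged:
#
# >>> import torch
# >>> cfg = MixtralConfig(hidden_size=64, intermediate_size=128, hidden_act="silu")
# >>> expert = MixtralBlockSparseTop2MLP(cfg)
# >>> x = torch.randn(4, 10, 64)   # (batch, seq_len, hidden_dim)
# >>> expert(x).shape              # w2(silu(w1(x)) * w3(x))
# torch.Size([4, 10, 64])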
# MixtralSparseMoeBlock - the sparse mixture-of-experts block in the Mixtral model
# This is the core of the MoE implementation; it uses block-sparse operations to handle the unbalanced token-to-expert assignment
class MixtralSparseMoeBlock(nn.Module):
"""
This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens).
It is faster because it uses block-sparse operations to handle the unbalanced assignment of tokens to experts,
whereas standard MoE either:
(1) drops tokens, reducing performance, or
(2) sets the capacity factor to the number of experts, wasting computation and memory on padding.
"""
def __init__(self, config):
super().__init__()
# Hidden and feed-forward dimensions
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
# Number of local experts
self.num_experts = config.num_local_experts
# Number of experts selected per token (usually 2)
self.top_k = config.num_experts_per_tok
# Gating network - decides which experts each token is routed to
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
# List of expert modules
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
# Jitter parameter - adds noise during training to improve generalization
self.jitter_noise = config.router_jitter_noise
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
# Input shape
batch_size, sequence_length, hidden_dim = hidden_states.shape
# During training, if jitter noise is enabled, apply a random perturbation to the hidden states
if self.training and self.jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
# Reshape the hidden states into a 2D tensor for easier processing
hidden_states = hidden_states.view(-1, hidden_dim)
# Compute the router logits: for each token, the logits for assigning it to each expert
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
# Compute the routing weights (softmax), in float for numerical stability
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
# Select the top-k experts and their weights
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
# Normalize the weights so they sum to 1
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# Cast the weights back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
# Zero tensor to accumulate the final hidden states
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
# One-hot encode the selected experts to create an expert mask
# This will be used to easily index which expert is going to be called
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
# List of experts that were hit (experts with at least one token assigned)
expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist()
# Run the forward pass for each expert that was hit
for expert_idx in expert_hitted:
# Current expert layer
expert_layer = self.experts[expert_idx]
# Find the token indices assigned to this expert and their ranks (top-1 or top-2)
idx, top_x = torch.where(expert_mask[expert_idx])
# Index the correct hidden states and compute the hidden states for the current expert
# Make sure to multiply the output hidden states by the `routing_weights` of the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
# However, `index_add_` only supports torch tensors for indexing, so we use the `top_x` tensor here
# to add this expert's output to the final hidden states
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
# Reshape the final hidden states back to the original shape
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
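# Annotator's usage sketch (not part of the upstream file; the small config below is hypothetical).
# The block returns both the mixed hidden states and the raw router logits:
#
# >>> cfg = MixtralConfig(hidden_size=32, intermediate_size=64,
# ...                     num_local_experts=4, num_experts_per_tok=2)
# >>> moe = MixtralSparseMoeBlock(cfg)
# >>> hidden, router_logits = moe(torch.randn(2, 5, 32))
# >>> hidden.shape, router_logits.shape   # (batch, seq, hidden) and (batch*seq, num_experts)
# (torch.Size([2, 5, 32]), torch.Size([10, 4]))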
# MixtralRMSNorm - the layer normalization used by Mixtral
# Inherits directly from MistralRMSNorm; identical behavior
class MixtralRMSNorm(MistralRMSNorm):
pass
# MixtralAttention - the attention mechanism used by Mixtral
# Inherits directly from MistralAttention; identical behavior
class MixtralAttention(MistralAttention):
pass
# MixtralDecoderLayer - a single decoder layer of the Mixtral model
# Contains the self-attention mechanism and a sparse mixture-of-experts block
class MixtralDecoderLayer(nn.Module):
def __init__(self, config: MixtralConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
# Self-attention layer
self.self_attn = MixtralAttention(config, layer_idx)
# Block-sparse MoE layer
self.block_sparse_moe = MixtralSparseMoeBlock(config)
# Input layer norm
self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Post-attention layer norm
self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for backward compatibility
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer, of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)`, where padding elements are indicated by 0.
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attention tensors of all attention layers. See `attentions` under the
returned tensors for more detail.
output_router_logits (`bool`, *optional*):
Whether or not to return the logits of all routers. They are useful for computing the router loss and
should not be returned during inference.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code into the model.
"""
# Save the hidden states for the residual connection
residual = hidden_states
# Apply the input layer norm
hidden_states = self.input_layernorm(hidden_states)
# Self-attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
# First residual connection: original input + attention output
hidden_states = residual + hidden_states
# Feed-forward network (MoE)
# Save the hidden states for the second residual connection
residual = hidden_states
# Apply the post-attention layer norm
hidden_states = self.post_attention_layernorm(hidden_states)
# Apply the block-sparse MoE
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
# Second residual connection: attention output + MoE output
hidden_states = residual + hidden_states
# Assemble the outputs
outputs = (hidden_states,)
# Append the attention weights to the output tuple if requested
if output_attentions:
outputs += (self_attn_weights,)
# Append the router logits to the output tuple if requested
if output_router_logits:
outputs += (router_logits,)
return outputs
# MixtralRotaryEmbedding - the rotary position embedding used by Mixtral
# Inherits directly from MistralRotaryEmbedding; identical behavior
class MixtralRotaryEmbedding(MistralRotaryEmbedding):
pass
# MixtralPreTrainedModel - base class for Mixtral pretrained models
# Inherits from MistralPreTrainedModel but disables the static cache
class MixtralPreTrainedModel(MistralPreTrainedModel):
# MoE models do not support static cache because they do not work with torch.compile (`torch.where(condition)` is not supported)
_supports_static_cache = False
# MixtralModel - the main Mixtral model backbone
# Inherits from MistralModel; the main difference is that MixtralDecoderLayer replaces MistralDecoderLayer
class MixtralModel(MistralModel):
def __init__(self, config: MixtralConfig):
super().__init__(config)
# Build the layer stack using MixtralDecoderLayer instead of MistralDecoderLayer
self.layers = nn.ModuleList(
[MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> MoeModelOutputWithPast:
# Resolve defaults for the various output options
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
# Validate the input format: exactly one of input_ids or inputs_embeds must be provided
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
# Gradient checkpointing is incompatible with caching; warn and disable the cache during training
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# If caching is enabled but no past_key_values were provided, create a dynamic cache
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
# If no embeddings were provided, compute them from the token embedding layer
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# Derive the cache position and position ids
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# Update the causal mask
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# Use the input embeddings as the initial hidden states
hidden_states = inputs_embeds
# Create position embeddings, shared across all decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# Decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
# Iterate over all decoder layers
for decoder_layer in self.layers:
# Save the current hidden states if all hidden states should be returned
if output_hidden_states:
all_hidden_states += (hidden_states,)
# Use gradient checkpointing or a plain forward pass
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
output_router_logits,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
# Update the hidden states
hidden_states = layer_outputs[0]
# Save this layer's attention weights if requested
if output_attentions:
all_self_attns += (layer_outputs[1],)
# Save this layer's router logits if requested
if output_router_logits:
all_router_logits += (layer_outputs[-1],)
# Final layer norm
hidden_states = self.norm(hidden_states)
# Add the hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
# Return the MoE-specific output format, including the router logits
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
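# Annotator's note (not part of the upstream file): with output_router_logits=True the model returns
# one router-logits tensor per decoder layer, which is exactly the tuple load_balancing_loss_func
# expects. A minimal sketch, assuming a hypothetical tiny config:
#
# >>> cfg = MixtralConfig(vocab_size=128, hidden_size=64, intermediate_size=128,
# ...                     num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
# ...                     num_local_experts=4, num_experts_per_tok=2)
# >>> model = MixtralModel(cfg)
# >>> out = model(input_ids=torch.randint(0, 128, (2, 8)), output_router_logits=True)
# >>> len(out.router_logits), out.router_logits[0].shape
# (2, torch.Size([16, 4]))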
# Combine FlashAttentionKwargs and LossKwargs into a single kwargs class for causal language modeling
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
# MixtralForCausalLM - the Mixtral model for causal language modeling
# Inherits from MistralForCausalLM and adds the MoE-specific functionality
class MixtralForCausalLM(MistralForCausalLM):
# Parameters whose weights are tied to the backbone model
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
# Use MixtralModel instead of MistralModel
self.model = MixtralModel(config)
# Router auxiliary loss coefficient - controls the weight of the load balancing loss
self.router_aux_loss_coef = config.router_aux_loss_coef
# Number of experts and number of experts per token
self.num_experts = config.num_local_experts
self.num_experts_per_tok = config.num_experts_per_tok
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
or -100 (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
the loss is only computed for tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, compute logits for all `input_ids` (special case).
Only the last token's logits are usually needed for generation, and computing them only for that token saves memory,
which becomes quite significant for long sequences or large vocabularies.
If a `torch.Tensor`, it must be 1D, holding the indices to keep along the sequence length dimension.
This is useful with the packed tensor format (a single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MixtralForCausalLM
>>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
# Resolve the output options
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# Run the backbone model
outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)
# Last hidden states
hidden_states = outputs.last_hidden_state
# Only compute the logits that are needed
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
# Compute the loss if labels were provided
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
# Compute the auxiliary (load balancing) loss
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
# If there is a main loss, add the auxiliary loss to it (scaled by the coefficient)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure they are on the same device
# Return the MoE-specific output format, including the auxiliary loss
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
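# Annotator's note (not part of the upstream file): during training, passing labels together with
# output_router_logits=True makes the returned loss include router_aux_loss_coef * aux_loss on top
# of the cross-entropy term. A minimal sketch, assuming the same hypothetical tiny config as above:
#
# >>> cfg = MixtralConfig(vocab_size=128, hidden_size=64, intermediate_size=128,
# ...                     num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
# ...                     num_local_experts=4, num_experts_per_tok=2)
# >>> lm = MixtralForCausalLM(cfg)
# >>> ids = torch.randint(0, 128, (2, 8))
# >>> out = lm(input_ids=ids, labels=ids, output_router_logits=True)
# >>> out.logits.shape, out.aux_loss is not None
# (torch.Size([2, 8, 128]), True)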
# The following classes inherit directly from their Mistral counterparts; identical behavior
# MixtralForSequenceClassification - Mixtral model for sequence classification
class MixtralForSequenceClassification(MistralForSequenceClassification):
pass
# MixtralForTokenClassification - Mixtral model for token classification
class MixtralForTokenClassification(MistralForTokenClassification):
pass
# MixtralForQuestionAnswering - Mixtral model for question answering
class MixtralForQuestionAnswering(MistralForQuestionAnswering):
pass
# All public classes exported by this module
__all__ = [
"MixtralForCausalLM",
"MixtralForQuestionAnswering",
"MixtralModel",
"MixtralPreTrainedModel",
"MixtralForSequenceClassification",
"MixtralForTokenClassification",
]