Annotated Transformers Mixtral MoE model code: modular_mixtral.py

# coding=utf-8
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Code annotations by XD
# For the original source, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modular_mixtral.py
"""PyTorch Mixtral model."""

from functools import partial
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import DynamicCache
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
)
from ...processing_utils import Unpack
from ...utils import (
    LossKwargs,
    logging,
)
# Components imported from the Mistral model; Mixtral extends the Mistral architecture with a mixture of experts
from ..mistral.modeling_mistral import (
    MistralAttention,
    MistralForCausalLM,
    MistralForQuestionAnswering,
    MistralForSequenceClassification,
    MistralForTokenClassification,
    MistralModel,
    MistralPreTrainedModel,
    MistralRMSNorm,
    MistralRotaryEmbedding,
)
from .configuration_mixtral import MixtralConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1"
_CONFIG_FOR_DOC = "MixtralConfig"


def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in PyTorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits:
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts].
        num_experts:
            Number of experts.
        top_k:
            The number of experts each token is routed to, can also be interpreted as the `top-k` routing
            parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in the forward function,
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss.
    """
    # If gate_logits is None or not a tuple, return 0
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    # If gate_logits is a tuple, concatenate the gate logits from all layers
    if isinstance(gate_logits, tuple):
        compute_device = gate_logits[0].device
        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    # Compute the routing weights (probability assigned to each expert)
    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    # Select the top-k experts
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    # Build an expert mask indicating which experts each token is routed to
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the fraction of tokens routed to each expert
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average routing probability assigned to each expert
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        # Handle the case where an attention mask is provided
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Build an expert attention mask with the same shape as expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the fraction of tokens routed to each expert (taking the attention mask into account)
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Build a per-expert router attention mask, matching the shape used for the routing probabilities
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average routing probability assigned to each expert (taking the attention mask into account)
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    # Overall loss: sum over experts of token fraction times routing probability
    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    # Scale by the number of experts to obtain the final loss
    return overall_loss * num_experts
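
# Illustrative usage sketch (added annotation, not part of the upstream file): the loss takes one
# router-logit tensor per layer, flattened over batch and sequence. The toy sizes below are made up.
def _demo_load_balancing_loss():
    num_layers, batch_size, seq_len, num_experts = 2, 3, 5, 4
    # One (batch_size * seq_len, num_experts) logit tensor per layer, as produced by each router
    gate_logits = tuple(torch.randn(batch_size * seq_len, num_experts) for _ in range(num_layers))
    aux_loss = load_balancing_loss_func(gate_logits, num_experts=num_experts, top_k=2)
    # The loss is smallest when tokens and router probability mass are spread evenly over the experts
    return aux_loss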


# MixtralBlockSparseTop2MLP - a single sparse expert MLP block in the Mixtral model
# This class is one expert in the mixture-of-experts (MoE) system; each expert is a standard feed-forward network
# with a SwiGLU activation, made up of three projection layers and one activation function
class MixtralBlockSparseTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        super().__init__()
        # Set the feed-forward and hidden dimensions
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        # First projection: hidden dimension -> feed-forward dimension
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        # Second projection: feed-forward dimension -> hidden dimension
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        # Third projection: hidden dimension -> feed-forward dimension (used for the gating mechanism)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        # Activation function, SiLU/Swish by default
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        # SwiGLU activation: the w1 output passes through the activation and is multiplied by the w3 output (gating)
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        # Project back to the original dimension through w2
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states
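
# Illustrative usage sketch (added annotation, not part of the upstream file): running a single
# expert on a dummy hidden state. The config values below are made-up toy sizes.
def _demo_expert_mlp():
    config = MixtralConfig(hidden_size=16, intermediate_size=32)
    expert = MixtralBlockSparseTop2MLP(config)
    x = torch.randn(4, 16)      # (num_tokens, hidden_size)
    y = expert(x)               # SwiGLU: w2(act_fn(w1(x)) * w3(x))
    assert y.shape == (4, 16)   # the expert keeps tokens in hidden_size
    return y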


# MixtralSparseMoeBlock - the sparse mixture-of-experts block of the Mixtral model
# This is the core of the MoE system; it uses block-sparse operations to handle the imbalanced assignment of tokens to experts
class MixtralSparseMoeBlock(nn.Module):
    """
    This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens).
    It is faster since it formulates MoE operations in terms of block-sparse operations to accommodate
    imbalanced assignments of tokens to experts, whereas standard MoE either:
    (1) drops tokens at the cost of reduced performance, or
    (2) sets the capacity factor to the number of experts and thus wastes computation and memory on padding.
    """

    def __init__(self, config):
        super().__init__()
        # Set the hidden and feed-forward dimensions
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        # Number of local experts
        self.num_experts = config.num_local_experts
        # Number of experts selected per token (typically 2)
        self.top_k = config.num_experts_per_tok

        # Gating network - decides which experts each token should be routed to
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        # Create the list of expert modules
        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

        # Jitter parameter - adds noise during training to improve generalization
        self.jitter_noise = config.router_jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """ """
        # Get the input shape
        batch_size, sequence_length, hidden_dim = hidden_states.shape

        # During training, if jitter noise is enabled, apply a random perturbation to the hidden states
        if self.training and self.jitter_noise > 0:
            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)

        # Reshape the hidden states to a 2D tensor for easier processing
        hidden_states = hidden_states.view(-1, hidden_dim)

        # Compute the router logits: for each token, the logits assigned to each expert
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        # Compute the routing weights (softmax), in float for numerical stability
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        # Select the top-k experts and their weights
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        # Normalize the weights so they sum to 1
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # Cast the weights back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        # Allocate a zero tensor to accumulate the final hidden states
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One-hot encode the selected experts to create an expert mask,
        # which will be used to easily index which expert is going to be called
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Get the list of experts that were hit (experts with at least one token assigned)
        expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist()

        # Run the forward pass for each expert that was hit
        for expert_idx in expert_hitted:
            # Get the current expert layer
            expert_layer = self.experts[expert_idx]
            # Find the indices of the tokens assigned to the current expert and their rank (top-1 or top-2)
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert's output for them,
            # making sure to multiply the output by the `routing_weights` of the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # However, `index_add_` only supports torch tensors for indexing, so we use the `top_x` tensor here
            # to add the expert's output to the final hidden states
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

        # Reshape the final hidden states back to the original shape
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
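
# Illustrative usage sketch (added annotation, not part of the upstream file): routing a dummy
# batch through the sparse MoE block. The toy config values are made up for demonstration only.
def _demo_sparse_moe_block():
    config = MixtralConfig(hidden_size=16, intermediate_size=32, num_local_experts=4, num_experts_per_tok=2)
    moe = MixtralSparseMoeBlock(config)
    x = torch.randn(2, 5, 16)                        # (batch, seq_len, hidden_size)
    out, router_logits = moe(x)
    # The output keeps the input shape; router_logits is flattened over batch and sequence
    assert out.shape == (2, 5, 16) and router_logits.shape == (2 * 5, 4)
    return out, router_logits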


# MixtralRMSNorm - the layer normalization used by Mixtral
# Inherits directly from MistralRMSNorm with identical behavior
class MixtralRMSNorm(MistralRMSNorm):
    pass


# MixtralAttention - the attention mechanism used by Mixtral
# Inherits directly from MistralAttention with identical behavior
class MixtralAttention(MistralAttention):
    pass


# MixtralDecoderLayer - a single decoder layer of the Mixtral model
# It contains the self-attention mechanism and the sparse mixture-of-experts block
class MixtralDecoderLayer(nn.Module):
    def __init__(self, config: MixtralConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Self-attention layer
        self.self_attn = MixtralAttention(config, layer_idx)

        # Block-sparse MoE layer
        self.block_sparse_moe = MixtralSparseMoeBlock(config)
        # Pre-attention (input) layer normalization
        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Post-attention layer normalization
        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for backward compatibility
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_router_logits (`bool`, *optional*):
                Whether or not to return the logits of all the routers. They are useful for computing the router loss
                and should not be returned during inference.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code into the model
        """

        # Save the hidden states for the residual connection
        residual = hidden_states

        # Apply the input layer normalization
        hidden_states = self.input_layernorm(hidden_states)

        # Self-attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        # First residual connection: original input + attention output
        hidden_states = residual + hidden_states

        # Feed-forward network (MoE)
        # Save the hidden states for the second residual connection
        residual = hidden_states
        # Apply the post-attention layer normalization
        hidden_states = self.post_attention_layernorm(hidden_states)
        # Apply the block-sparse MoE
        hidden_states, router_logits = self.block_sparse_moe(hidden_states)
        # Second residual connection: attention output + MoE output
        hidden_states = residual + hidden_states

        # Assemble the outputs
        outputs = (hidden_states,)

        # If attention weights should be returned, add them to the output tuple
        if output_attentions:
            outputs += (self_attn_weights,)

        # If router logits should be returned, add them to the output tuple
        if output_router_logits:
            outputs += (router_logits,)

        return outputs
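
# Illustrative usage sketch (added annotation, not part of the upstream file): running one decoder
# layer on dummy inputs. The toy config values are made up, and the rotary embedding below supplies
# the (cos, sin) position embeddings the layer expects.
def _demo_decoder_layer():
    config = MixtralConfig(
        hidden_size=16, intermediate_size=32, num_attention_heads=4, num_key_value_heads=2,
        num_local_experts=4, num_experts_per_tok=2,
    )
    layer = MixtralDecoderLayer(config, layer_idx=0)
    rotary = MixtralRotaryEmbedding(config)
    x = torch.randn(1, 6, 16)                        # (batch, seq_len, hidden_size)
    position_ids = torch.arange(6).unsqueeze(0)      # (batch, seq_len)
    position_embeddings = rotary(x, position_ids)    # (cos, sin)
    outputs = layer(x, position_embeddings=position_embeddings, output_router_logits=True)
    hidden_states, router_logits = outputs[0], outputs[-1]
    return hidden_states, router_logits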


# MixtralRotaryEmbedding - the rotary position embedding used by Mixtral
# Inherits directly from MistralRotaryEmbedding with identical behavior
class MixtralRotaryEmbedding(MistralRotaryEmbedding):
    pass


# MixtralPreTrainedModel - base class for Mixtral pretrained models
# Inherits from MistralPreTrainedModel, but disables the static cache
class MixtralPreTrainedModel(MistralPreTrainedModel):
    # MoE models don't work with torch.compile (`torch.where(condition)` is not supported), so the static cache is disabled
    _supports_static_cache = False


# MixtralModel - the main body of the Mixtral architecture
# Inherits from MistralModel; the main difference is that it uses MixtralDecoderLayer instead of MistralDecoderLayer
class MixtralModel(MistralModel):
    def __init__(self, config: MixtralConfig):
        super().__init__(config)
        # Build the layer stack, using MixtralDecoderLayer instead of MistralDecoderLayer
        self.layers = nn.ModuleList(
            [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> MoeModelOutputWithPast:
        # Set the defaults for the various output options
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Check the input format: exactly one of input_ids and inputs_embeds must be provided
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # Gradient checkpointing is incompatible with caching; warn and disable the cache during training
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # If caching is enabled but no past_key_values were provided, create a dynamic cache
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # If no embeddings were provided, compute them with the token embedding layer
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Handle the cache position and position ids
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Update the causal mask
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # Initialize the hidden states with the input embeddings
        hidden_states = inputs_embeds

        # Create the position embeddings, shared across all decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Decoder layer processing
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None

        # Iterate over all decoder layers
        for decoder_layer in self.layers:
            # If all hidden states should be returned, save the current hidden states
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # Use gradient checkpointing or a direct forward pass
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            # Update the hidden states
            hidden_states = layer_outputs[0]

            # If attention weights should be returned, save the current layer's attention weights
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            # If router logits should be returned, save the current layer's router logits
            if output_router_logits:
                all_router_logits += (layer_outputs[-1],)

        # Final layer normalization
        hidden_states = self.norm(hidden_states)

        # Add the hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Return the MoE-specific output format, which includes the router logits
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )
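
# Illustrative usage sketch (added annotation, not part of the upstream file): running a tiny,
# randomly initialized MixtralModel. All config values below are made-up toy sizes.
def _demo_tiny_mixtral_model():
    config = MixtralConfig(
        vocab_size=64, hidden_size=16, intermediate_size=32, num_hidden_layers=2,
        num_attention_heads=4, num_key_value_heads=2, num_local_experts=4, num_experts_per_tok=2,
    )
    model = MixtralModel(config)
    input_ids = torch.randint(0, config.vocab_size, (1, 7))
    outputs = model(input_ids=input_ids, output_router_logits=True)
    # last_hidden_state: (1, 7, hidden_size); router_logits: one (7, num_local_experts) tensor per layer
    return outputs.last_hidden_state, outputs.router_logits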


# Combine FlashAttentionKwargs and LossKwargs into a single kwargs class for the causal language model
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


# MixtralForCausalLM - Mixtral model for causal language modeling
# Inherits from MistralForCausalLM and adds the MoE-specific functionality
class MixtralForCausalLM(MistralForCausalLM):
    # Parameters whose weights are tied with the backbone model
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        # Use MixtralModel instead of MistralModel
        self.model = MixtralModel(config)
        # Router auxiliary loss coefficient - controls the weight of the load balancing loss
        self.router_aux_loss_coef = config.router_aux_loss_coef
        # Number of experts and number of experts per token
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only the last token logits are needed for generation, and calculating them only
            for that token can save memory, which becomes pretty significant for long sequences or large vocabularies.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MixtralForCausalLM

        >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
        >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        # Set the output options
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # Run the backbone model to get the outputs
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        # Get the final hidden states
        hidden_states = outputs.last_hidden_state
        # Only compute the logits that are needed
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        # If labels are provided, compute the loss
        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        # Compute the auxiliary (load balancing) loss
        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            # If there is a main loss, add the auxiliary loss (scaled by the coefficient) to it
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure both reside on the same device

        # Return the MoE-specific output format, which includes the auxiliary loss
        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
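
# Illustrative usage sketch (added annotation, not part of the upstream file): with labels provided
# and output_router_logits=True, the returned loss already includes router_aux_loss_coef * aux_loss.
# The toy config values below are made up.
def _demo_causal_lm_aux_loss():
    config = MixtralConfig(
        vocab_size=64, hidden_size=16, intermediate_size=32, num_hidden_layers=2,
        num_attention_heads=4, num_key_value_heads=2, num_local_experts=4, num_experts_per_tok=2,
    )
    model = MixtralForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (1, 7))
    out = model(input_ids=input_ids, labels=input_ids, output_router_logits=True)
    return out.loss, out.aux_loss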


# The following classes inherit directly from their Mistral counterparts with identical behavior
# MixtralForSequenceClassification - Mixtral model for sequence classification
class MixtralForSequenceClassification(MistralForSequenceClassification):
    pass


# MixtralForTokenClassification - Mixtral model for token classification
class MixtralForTokenClassification(MistralForTokenClassification):
    pass


# MixtralForQuestionAnswering - Mixtral model for question answering
class MixtralForQuestionAnswering(MistralForQuestionAnswering):
    pass


# All public classes exported by this module
__all__ = [
    "MixtralForCausalLM",
    "MixtralForQuestionAnswering",
    "MixtralModel",
    "MixtralPreTrainedModel",
    "MixtralForSequenceClassification",
    "MixtralForTokenClassification",
]