EADST

Transformers DeepSeek V3 模型代码中文注释 modeling_deepseek_v3.py

#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           此文件是从src/transformers/models/deepseek_v3/modular_deepseek_v3.py自动生成的。
#               请勿手动编辑此文件,因为任何编辑都将在生成文件时被覆盖。
#             如果需要进行任何更改,请直接修改modular_deepseek_v3.py文件。我们的CI会强制执行这一点。
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
'''
DeepSeekV3是一个创新的大型语言模型架构,具有多项独特的技术特点:

混合专家系统(MoE)架构

在深层部分使用DeepseekV3MoE替代传统的MLP
通过TopkRouter路由器实现高效的专家选择
结合共享专家和路由专家,平衡计算效率和模型容量


特殊的注意力机制设计

使用不同维度的头部:qk_rope_head_dim、qk_nope_head_dim、v_head_dim
应用LoRA结构进行参数高效的计算
支持交错旋转位置编码,提高计算效率


高级优化技术

支持Flash Attention 2、SDPA、FlexAttention等高效注意力实现
使用分组注意力优化内存使用
实现滑动窗口和缓存机制,加速长文本生成



DeepSeekV3模型的主要组件包括:

DeepseekV3RMSNorm: 高效的层归一化实现
DeepseekV3RotaryEmbedding: 旋转位置编码,支持多种变体
DeepseekV3MLP: 前馈神经网络,使用SwiGLU激活
DeepseekV3TopkRouter: 专家路由器,实现组级别的专家选择
DeepseekV3MoE: 混合专家系统,结合路由和共享专家
DeepseekV3Attention: 复杂的多头注意力机制,包含多种优化
DeepseekV3DecoderLayer: 单个Transformer解码器层,根据层索引选择MLP或MoE
DeepseekV3Model: 主体模型架构,包含完整的Transformer解码器堆栈
DeepseekV3ForCausalLM: 用于生成任务的因果语言模型

这个模型结合了当代大语言模型的多项技术进步,尤其是在混合专家系统和注意力机制优化方面有独特设计,允许模型在保持推理效率的同时大幅增加参数规模,提高模型能力。
'''
# 中文代码注释 XD
# 源码请参考 https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
import math
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
    LossKwargs,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    can_return_tuple,
    is_torch_flex_attn_available,
    logging,
    replace_return_docstrings,
)
from .configuration_deepseek_v3 import DeepseekV3Config


# 检查FlexAttention是否可用
if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from ...integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DeepseekV3Config"


# DeepseekV3RMSNorm类 - 实现了DeepseekV3模型中的层归一化
# 使用RMS归一化代替LayerNorm,不进行均值中心化,只进行方差归一化
@use_kernel_forward_from_hub("RMSNorm")
class DeepseekV3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV3RMSNorm等同于T5LayerNorm
        """
        super().__init__()
        # 初始化权重参数为全1向量
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # 设置方差中的epsilon值,防止除零错误
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # 保存输入的数据类型
        input_dtype = hidden_states.dtype
        # 转换为float32进行高精度计算
        hidden_states = hidden_states.to(torch.float32)
        # 计算方差(平方的均值)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # 归一化:除以方差的平方根
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # 乘以权重并转回原始数据类型
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        # 为打印模块信息提供额外的表示
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


# DeepseekV3RotaryEmbedding类 - 实现了DeepseekV3模型中的旋转位置编码
# 支持多种RoPE变体,用于处理位置信息
class DeepseekV3RotaryEmbedding(nn.Module):
    def __init__(self, config: DeepseekV3Config, device=None):
        super().__init__()
        # 向后兼容:处理rope_type参数
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        # 缓存的最大序列长度
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # 获取RoPE初始化函数
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        # 初始化频率反转和注意力缩放
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # 高级用户:用于高级RoPE类型(例如dynamic rope)
    def forward(self, x, position_ids):
        # 扩展反转频率到批次大小
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # 确定设备类型,特殊处理MPS设备
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # 强制使用float32
            # 计算频率
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # 合并频率
            emb = torch.cat((freqs, freqs), dim=-1)
            # 计算余弦和正弦部分,并应用注意力缩放
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        # 返回结果,转换为输入的数据类型
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# DeepseekV3MLP类 - 实现了DeepseekV3模型中的前馈神经网络部分
# 使用SwiGLU激活函数,包含三个投影层:gate_proj、up_proj和down_proj
class DeepseekV3MLP(nn.Module):
    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        # 设置隐藏维度,如果未指定则使用配置中的值
        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
        # 设置中间维度,如果未指定则使用配置中的值
        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size

        # gate_proj:将隐藏状态投影到中间维度的门控投影层
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # up_proj:将隐藏状态投影到中间维度的上投影层
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # down_proj:将中间表示投影回隐藏维度的下投影层
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # 激活函数,默认为SiLU(Sigmoid Linear Unit)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # SwiGLU激活:先计算gate_proj经过激活函数的结果,再与up_proj的结果相乘
        # 然后通过down_proj投影回原始维度
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


# DeepseekV3TopkRouter类 - 实现了DeepseekV3模型中的TopK路由器
# 用于混合专家系统(MoE)中选择专家的路由机制
class DeepseekV3TopkRouter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # 每个token选择的专家数量
        self.top_k = config.num_experts_per_tok
        # 路由专家的数量
        self.n_routed_experts = config.n_routed_experts
        # 路由缩放因子
        self.routed_scaling_factor = config.routed_scaling_factor
        # 专家组数量
        self.n_group = config.n_group
        # 每个组选择的topk数量
        self.topk_group = config.topk_group
        # 是否规范化topk概率
        self.norm_topk_prob = config.norm_topk_prob

        # 路由权重
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
        # 专家分数修正偏置
        self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts)))

    @torch.no_grad()
    def get_topk_indices(self, scores):
        # 构建选择分数,添加修正偏置
        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
        # 将专家分组,计算每组的顶部得分和
        group_scores = (
            scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        # 选择顶部的组
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        # 创建组掩码
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        # 扩展掩码到专家级别
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(-1, self.n_routed_experts)
        )
        # 掩盖不需要的专家分数
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
        # 获取最终的topk专家索引
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
        return topk_indices

    def forward(self, hidden_states):
        # 重塑隐藏状态
        hidden_states = hidden_states.view(-1, self.config.hidden_size)
        # 计算路由logits
        router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
        # 应用sigmoid激活
        scores = router_logits.sigmoid()
        # 获取topk索引
        topk_indices = self.get_topk_indices(scores)
        # 获取topk权重
        topk_weights = scores.gather(1, topk_indices)
        # 如果需要规范化topk概率
        if self.norm_topk_prob:
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        # 应用路由缩放因子
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights


# DeepseekV3MoE类 - 实现了DeepseekV3模型中的混合专家系统
# 包含路由专家和共享专家,通过门控机制选择专家
class DeepseekV3MoE(nn.Module):
    """
    包含共享专家的混合专家模块。
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # 创建专家模块列表
        self.experts = nn.ModuleList(
            [
                DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
                for _ in range(config.n_routed_experts)
            ]
        )
        # 创建门控网络
        self.gate = DeepseekV3TopkRouter(config)
        # 创建共享专家
        self.shared_experts = DeepseekV3MLP(
            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
        )

    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
        r"""
        贡献请求!我现在没有时间优化这个,但专家权重需要被融合,
        以避免在这里进行循环(deepseek有256个专家,所以是的)。
        """
        # 创建最终隐藏状态的零张量
        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
        # 创建专家掩码
        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
        expert_mask = expert_mask.permute(2, 0, 1)

        # 遍历所有专家
        for expert_idx in range(len(self.experts)):
            expert = self.experts[expert_idx]
            mask = expert_mask[expert_idx]
            # 找出分配给当前专家的token索引和权重索引
            token_indices, weight_indices = torch.where(mask)

            # 如果有token分配给当前专家
            if token_indices.numel() > 0:
                # 获取专家权重
                expert_weights = topk_weights[token_indices, weight_indices]
                # 获取专家输入
                expert_input = hidden_states[token_indices]
                # 计算专家输出
                expert_output = expert(expert_input)
                # 加权输出
                weighted_output = expert_output * expert_weights.unsqueeze(-1)
                # 将加权输出添加到最终隐藏状态
                final_hidden_states.index_add_(0, token_indices, weighted_output)

        # 在原始deepseek中,专家的输出在离开此模块时被收集
        # 因此moe模块本身是一个IsolatedParallel模块
        # 所有专家都是"本地的",意味着我们分片但不收集
        return final_hidden_states.type(hidden_states.dtype)

    def forward(self, hidden_states):
        # 保存残差连接
        residuals = hidden_states
        # 保存原始形状
        orig_shape = hidden_states.shape
        # 获取topk索引和权重
        topk_indices, topk_weights = self.gate(hidden_states)
        # 重塑隐藏状态
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        # 应用MoE
        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
        # 添加共享专家的输出
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states


# rotate_half函数 - 旋转输入张量的一半隐藏维度
def rotate_half(x):
    """旋转输入张量的一半隐藏维度。"""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# apply_rotary_pos_emb函数 - 应用旋转位置编码到查询和键张量
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    将旋转位置编码应用到查询和键张量上。

    参数:
        q (`torch.Tensor`): 查询张量。
        k (`torch.Tensor`): 键张量。
        cos (`torch.Tensor`): 旋转嵌入的余弦部分。
        sin (`torch.Tensor`): 旋转嵌入的正弦部分。
        position_ids (`torch.Tensor`, *可选*):
            已弃用且未使用。
        unsqueeze_dim (`int`, *可选*, 默认为1):
            'unsqueeze_dim'参数指定沿着哪个维度对cos[position_ids]和sin[position_ids]进行扩展,
            以便它们可以正确地广播到q和k的维度上。例如,注意cos[position_ids]和sin[position_ids]
            具有形状[batch_size, seq_len, head_dim]。然后,如果q和k具有形状
            [batch_size, heads, seq_len, head_dim],则设置unsqueeze_dim=1使
            cos[position_ids]和sin[position_ids]可广播到q和k的形状。
            同样,如果q和k具有形状[batch_size, seq_len, heads, head_dim],则设置unsqueeze_dim=2。
    返回:
        `tuple(torch.Tensor)` 包含使用旋转位置嵌入旋转后的查询和键张量。
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# repeat_kv函数 - 实现了分组查询注意力(GQA)中的键值头复制操作
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    这相当于torch.repeat_interleave(x, dim=1, repeats=n_rep)。隐藏状态从(batch,
    num_key_value_heads, seqlen, head_dim)变为(batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# eager_attention_forward函数 - 实现了标准的注意力前向传播计算
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # 应用GQA:复制键和值状态以匹配查询头的数量
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    # 计算注意力分数:查询和键的矩阵乘法,并应用缩放
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    # 如果提供了注意力掩码,将其应用到注意力分数上
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # 对注意力分数应用softmax得到注意力权重,使用float32以提高数值稳定性
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    # 在训练时应用dropout
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    # 使用注意力权重对值进行加权汇总
    attn_output = torch.matmul(attn_weights, value_states)
    # 调整输出的维度顺序并确保内存连续
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


# apply_rotary_pos_emb_interleave函数 - 应用交错权重的旋转位置编码
def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    r"""
    TODO 让我们使用原始的freqcis计算来避免视图
    转置+重塑!这不是优化的!
    将旋转位置编码应用到查询和键张量上。

    参数:
        q (`torch.Tensor`): 查询张量。
        k (`torch.Tensor`): 键张量。
        cos (`torch.Tensor`): 旋转嵌入的余弦部分。
        sin (`torch.Tensor`): 旋转嵌入的正弦部分。
        position_ids (`torch.Tensor`):
            对应于查询和键张量的标记的位置索引。例如,这可以用于
            在使用KV缓存时传递偏移的位置ID。
        unsqueeze_dim (`int`, *可选*, 默认为1):
            'unsqueeze_dim'参数指定沿着哪个维度对cos[position_ids]和sin[position_ids]进行扩展,
            以便它们可以正确地广播到q和k的维度上。例如,注意cos[position_ids]和sin[position_ids]
            具有形状[batch_size, seq_len, head_dim]。然后,如果q和k具有形状
            [batch_size, heads, seq_len, head_dim],则设置unsqueeze_dim=1使
            cos[position_ids]和sin[position_ids]可广播到q和k的形状。
            同样,如果q和k具有形状[batch_size, seq_len, heads, head_dim],则设置unsqueeze_dim=2。
    返回:
        `tuple(torch.Tensor)` 包含使用旋转位置嵌入旋转后的查询和键张量。
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # 对查询进行交错重塑
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    # 对键进行交错重塑
    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    # 应用旋转位置编码
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# yarn_get_mscale函数 - 计算YaRN缩放的比例因子
def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# DeepseekV3Attention类 - 实现DeepseekV3模型中的多头注意力机制
# 包含多种特殊设计,如不同的头部维度、LoRA等优化
class DeepseekV3Attention(nn.Module):
    """来自'Attention Is All You Need'论文的多头注意力机制"""

    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # 每个键值头对应的查询头数量
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        # 注意力dropout率
        self.attention_dropout = config.attention_dropout
        # 注意力头数量
        self.num_heads = config.num_attention_heads
        # RoPE的theta参数
        self.rope_theta = config.rope_theta
        # 查询的LoRA秩
        self.q_lora_rank = config.q_lora_rank
        # 应用RoPE的查询键头维度
        self.qk_rope_head_dim = config.qk_rope_head_dim
        # 键值的LoRA秩
        self.kv_lora_rank = config.kv_lora_rank
        # 值头维度
        self.v_head_dim = config.v_head_dim
        # 不应用RoPE的查询键头维度
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # 查询键总头维度
        self.qk_head_dim = config.qk_head_dim

        # 标记这是因果注意力
        self.is_causal = True
        # 查询的LoRA投影A
        self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
        # 查询的LoRA层归一化
        self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
        # 查询的LoRA投影B
        self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)

        # 键值的LoRA投影A与多查询注意力
        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        # 键值的LoRA层归一化
        self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
        # 键值的LoRA投影B
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        # 输出投影
        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        # 缩放因子
        self.scaling = self.qk_head_dim ** (-0.5)
        # 如果启用了RoPE缩放
        if self.config.rope_scaling is not None:
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.scaling = self.scaling * mscale * mscale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # 获取批次大小和序列长度
        batch_size, seq_length = hidden_states.shape[:-1]
        # 定义查询形状和键形状
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        # 查询状态计算:通过LoRA投影A和B,加上层归一化
        q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
        # 将查询分为不应用RoPE的部分和应用RoPE的部分
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # 压缩的键值计算
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        # 分离键和旋转部分
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        # 键值状态计算:通过LoRA投影B和层归一化
        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        # 分离键和值状态
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # 重塑键旋转部分
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        # 获取位置嵌入的余弦和正弦部分
        cos, sin = position_embeddings
        # 根据配置选择RoPE应用方式
        if self.config.rope_interleave:  # 支持使用交错权重来提高效率
            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        else:
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        # 扩展键旋转部分到与键传递部分相同的形状
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        # 合并查询和键的各部分
        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        # 如果有过去的键值缓存,更新当前的键值状态
        if past_key_value is not None:
            # sin和cos特定于RoPE模型;静态缓存需要cache_position
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # 对于flash_attention_2,如果query和value的维度不同,需要padding
        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        # 选择注意力计算接口
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        # 调用选定的注意力接口计算注意力输出和权重
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # 如果使用flash_attention_2,需要裁剪输出
        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        # 重塑注意力输出并应用输出投影
        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


# DeepseekV3DecoderLayer类 - 实现DeepseekV3模型中的单个Transformer解码器层
# 继承自GradientCheckpointingLayer,支持梯度检查点以节省内存
class DeepseekV3DecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # 自注意力层
        self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)

        # 根据层索引选择MLP或MoE
        if layer_idx >= config.first_k_dense_replace:
            # 对于后面的层,使用混合专家系统
            self.mlp = DeepseekV3MoE(config)
        else:
            # 对于前面的层,使用标准MLP
            self.mlp = DeepseekV3MLP(config)

        # 输入层归一化,应用于自注意力前
        self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # 注意力后层归一化,应用于MLP前
        self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # 必要的,但为了向后兼容而保留在这里
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # 保存残差连接用的隐藏状态
        residual = hidden_states
        # 应用输入层归一化
        hidden_states = self.input_layernorm(hidden_states)

        # 自注意力计算
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        # 第一个残差连接:原始输入 + 注意力输出
        hidden_states = residual + hidden_states

        # 前馈神经网络或MoE
        # 保存第二个残差连接用的隐藏状态
        residual = hidden_states
        # 应用注意力后层归一化
        hidden_states = self.post_attention_layernorm(hidden_states)
        # 应用MLP或MoE
        hidden_states = self.mlp(hidden_states)
        # 第二个残差连接:注意力输出 + MLP/MoE输出
        hidden_states = residual + hidden_states

        # 准备返回值
        outputs = (hidden_states,)
        # 如果需要返回注意力权重,将其添加到输出元组
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


# DEEPSEEK_V3模型文档字符串开始
DEEPSEEK_V3_START_DOCSTRING = r"""
    该模型继承自[`PreTrainedModel`]。查看超类文档了解库为所有模型实现的通用方法
    (如下载或保存、调整输入嵌入大小、剪枝头等)

    该模型也是PyTorch的[torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)子类。
    将其作为常规PyTorch模块使用,并参考PyTorch文档了解与一般用法和行为相关的所有事项。

    参数:
        config ([`DeepseekV3Config`]):
            包含模型所有参数的模型配置类。使用配置文件初始化只加载模型的配置,
            而不加载与模型关联的权重。查看[`~PreTrainedModel.from_pretrained`]方法
            来加载模型权重。
"""


# 添加开始文档字符串
@add_start_docstrings(
    "裸DeepseekV3模型,输出原始隐藏状态,顶部没有任何特定的头。",
    DEEPSEEK_V3_START_DOCSTRING,
)
# DeepseekV3PreTrainedModel类 - 所有DeepseekV3模型变体的基础类
# 定义了各种模型功能和支持的特性
class DeepseekV3PreTrainedModel(PreTrainedModel):
    # 设置配置类
    config_class = DeepseekV3Config
    # 基础模型前缀
    base_model_prefix = "model"
    # 支持梯度检查点
    supports_gradient_checkpointing = True
    # 指定不拆分的模块
    _no_split_modules = ["DeepseekV3DecoderLayer"]
    # 设备放置时跳过的键
    _skip_keys_device_placement = ["past_key_values"]
    # 支持Flash Attention 2优化
    _supports_flash_attn_2 = True
    # 支持缩放点积注意力(SDPA)优化
    _supports_sdpa = True
    # 支持灵活注意力优化
    _supports_flex_attn = True
    # 支持缓存类
    _supports_cache_class = True
    # 支持量化缓存
    _supports_quantized_cache = True
    # 支持静态缓存
    _supports_static_cache = True
    # 支持注意力后端
    _supports_attention_backend = True

    def _init_weights(self, module):
        # 初始化权重的标准差
        std = self.config.initializer_range
        # 线性层权重初始化
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        # 嵌入层权重初始化
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        # RMSNorm层权重初始化为1.0
        elif isinstance(module, DeepseekV3RMSNorm):
            module.weight.data.fill_(1.0)
        # TopkRouter权重初始化
        elif isinstance(module, DeepseekV3TopkRouter):
            module.weight.data.normal_(mean=0.0, std=std)


# DEEPSEEK_V3输入文档字符串
DEEPSEEK_V3_INPUTS_DOCSTRING = r"""
    参数:
        input_ids (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`):
            输入序列标记在词汇表中的索引。如果提供了填充,将默认忽略。

            可以使用[`AutoTokenizer`]获取索引。查看[`PreTrainedTokenizer.encode`]和
            [`PreTrainedTokenizer.__call__`]了解详情。

            [什么是输入ID?](../glossary#input-ids)
        attention_mask (`torch.Tensor` 形状为 `(batch_size, sequence_length) 或 `BlockMask`, *可选*):
            用于避免在填充标记索引上执行注意力的掩码。掩码值选择为`[0, 1]`:

            - 1表示**未被掩盖的**标记,
            - 0表示**被掩盖的**标记。

            如果模型配置为使用flex_attention,它将尝试将掩码Tensor转换为BlockMask,
            但您也可以直接在此处传递`BlockMask`对象。

            [什么是注意力掩码?](../glossary#attention-mask)

            可以使用[`AutoTokenizer`]获取索引。查看[`PreTrainedTokenizer.encode`]和
            [`PreTrainedTokenizer.__call__`]了解详情。

            如果使用了`past_key_values`,可选地只需要输入最后的`input_ids`(参见
            `past_key_values`)。

            如果您想更改填充行为,应阅读[`modeling_opt._prepare_decoder_attention_mask`]
            并根据您的需求进行修改。有关默认策略的更多信息,请参见[论文](https://arxiv.org/abs/1910.13461)中的图1。

            - 1表示头部**未被掩盖**,
            - 0表示头部**被掩盖**。
        position_ids (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`, *可选*):
            每个输入序列标记在位置嵌入中的索引。在范围`[0, config.n_positions - 1]`中选择。

            [什么是位置ID?](../glossary#position-ids)
        past_key_values (`Cache`, *可选*):
            预先计算的隐藏状态(自注意力块和交叉注意力块中的键和值),可用于加速顺序解码。
            这通常由模型在之前阶段的解码中返回的`past_key_values`组成,当`use_cache=True`或
            `config.use_cache=True`时。

            它是一个[`~cache_utils.Cache`]实例。有关更多详细信息,请参见我们的[kv缓存指南](https://huggingface.co/docs/transformers/en/kv_cache)。

            如果使用了`past_key_values`,用户可以选择只输入最后的`input_ids`(那些没有给这个模型的过去键值状态的输入),
            形状为`(batch_size, 1)`,而不是所有的`input_ids`,形状为`(batch_size, sequence_length)`。
        inputs_embeds (`torch.FloatTensor` 形状为 `(batch_size, sequence_length, hidden_size)`, *可选*):
            可选地,您可以选择直接传递嵌入表示,而不是传递`input_ids`。这在您想要更多地控制
            如何将`input_ids`索引转换为相关向量比模型的内部嵌入查找矩阵更有用。
        use_cache (`bool`, *可选*):
            如果设置为`True`,将返回`past_key_values`键值状态,可用于加速解码(参见
            `past_key_values`)。
        output_attentions (`bool`, *可选*):
            是否返回所有注意力层的注意力张量。有关更多详细信息,请参见`attentions`下的返回
            张量。
        output_hidden_states (`bool`, *可选*):
            是否返回所有层的隐藏状态。有关更多详细信息,请参见`hidden_states`下的返回
            张量。
        return_dict (`bool`, *可选*):
            是否返回[`~utils.ModelOutput`]而不是普通元组。
        cache_position (`torch.LongTensor` 形状为 `(sequence_length)`, *可选*):
            描述输入序列标记在序列中位置的索引。与`position_ids`不同,这个张量不受填充影响。
            它用于在正确的位置更新缓存并推断完整的序列长度。
"""


# 添加开始文档字符串
@add_start_docstrings(
    "裸DeepseekV3模型,输出原始隐藏状态,顶部没有任何特定的头。",
    DEEPSEEK_V3_START_DOCSTRING,
)
# DeepseekV3Model类 - DeepseekV3的主体模型结构
# 包含完整的Transformer解码器堆栈,实现了自回归生成的核心功能
class DeepseekV3Model(DeepseekV3PreTrainedModel):
    """
    Transformer解码器,由*config.num_hidden_layers*层组成。每一层都是一个[`DeepseekV3DecoderLayer`]

    参数:
        config: DeepseekV3Config
    """

    # 加载时忽略意外的键
    _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]

    def __init__(self, config: DeepseekV3Config):
        super().__init__(config)
        # 设置填充索引和词汇表大小
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # 词嵌入层:将输入token转换为隐藏表示
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # 创建解码器层堆栈
        self.layers = nn.ModuleList(
            [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # 最终的层归一化
        self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # 旋转位置编码
        self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
        # 梯度检查点标志,默认关闭
        self.gradient_checkpointing = False

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        # 设置默认值,优先使用传入参数,其次使用配置
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # 检查输入格式:input_ids和inputs_embeds必须且只能提供一个
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # 梯度检查点与缓存不兼容,训练时警告并禁用缓存
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # 检查过去键值类型兼容性
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        # 如果没有提供嵌入,通过词嵌入层获取
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # 如果启用缓存但未提供past_key_values,创建动态缓存
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # 处理缓存位置:确定序列中每个token的位置
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # 如果没有提供位置ID,从缓存位置创建
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # 更新因果掩码:确保自回归属性,即当前token只能看到过去的token
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # 设置初始隐藏状态为输入嵌入
        hidden_states = inputs_embeds

        # 创建位置嵌入,在所有解码器层之间共享
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # 解码器层处理
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # 遍历所有解码器层
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            # 如果需要输出所有隐藏状态,保存当前隐藏状态
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # 前向传播通过当前解码器层
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            # 更新隐藏状态为当前层的输出
            hidden_states = layer_outputs[0]

            # 如果需要输出注意力权重,保存当前层的注意力权重
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # 最终的层归一化
        hidden_states = self.norm(hidden_states)

        # 添加最后一个解码器层的隐藏状态
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # 返回模型输出
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """
        更新因果掩码:根据不同的注意力实现方式和缓存类型处理掩码
        处理各种特殊情况,包括Flash Attention 2、滑动窗口等
        """
        # Flash Attention 2的特殊处理
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None

        # Flex Attention的特殊处理
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # 对于SDPA(缩放点积注意力),尽可能依赖其`is_causal`参数而非`attn_mask`参数
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # 当输出注意力为True时,sdpa实现的forward方法调用eager实现的forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        # 获取输入张量的数据类型和设备
        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]

        # 确定目标长度
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # 生成4D因果掩码
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        # SDPA的特殊处理:对于完全掩码的行,允许注意所有token
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # 注意完全掩码行中的所有token,例如使用左填充时的相关第一行
            # 这是F.scaled_dot_product_attention内存高效注意力路径所需的
            # 详情:https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        创建一个形状为`(batch_size, 1, query_length, key_value_length)`的因果4D掩码,
        从形状为`(batch_size, key_value_length)`的2D掩码创建,
        或如果输入的`attention_mask`已经是4D,则不做任何处理。

        参数:
            attention_mask (`torch.Tensor`):
                形状为`(batch_size, key_value_length)`的2D注意力掩码,
                或形状为`(batch_size, 1, query_length, key_value_length)`的4D注意力掩码。
            sequence_length (`int`):
                正在处理的序列长度。
            target_length (`int`):
                目标长度:使用静态缓存生成时,掩码应与静态缓存一样长,
                以考虑0填充、尚未填充的缓存部分。
            dtype (`torch.dtype`):
                用于4D注意力掩码的数据类型。
            device (`torch.device`):
                放置4D注意力掩码的设备。
            cache_position (`torch.Tensor`):
                描述输入序列标记在序列中位置的索引。
            batch_size (`torch.Tensor`):
                批次大小。
        """
        # 如果掩码已经是4D,直接使用
        if attention_mask is not None and attention_mask.dim() == 4:
            # 在这种情况下,我们假设掩码已经以反转形式提供,不需要反转或切片
            causal_mask = attention_mask
        else:
            # 创建填充有最小值的因果掩码
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            # 如果序列长度不是1,创建上三角掩码(对角线以上为1)
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # 应用缓存位置掩码
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            # 扩展掩码维度以匹配批次大小
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

            # 如果提供了注意力掩码,合并掩码
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # 复制到连续内存以进行原地编辑
                mask_length = attention_mask.shape[-1]
                # 合并因果掩码和注意力掩码
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                # 应用掩码
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


# 组合FlashAttentionKwargs和LossKwargs为因果语言模型创建一个新的参数类
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


# DeepseekV3ForCausalLM类 - 用于因果语言建模的DeepseekV3模型
# 在基础DeepseekV3Model上添加了语言模型头,用于生成任务
class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
    # 定义与主干模型共享权重的参数
    _tied_weights_keys = ["lm_head.weight"]
    # 张量并行计划:lm_head按列方向划分并复制
    _tp_plan = {"lm_head": "colwise_rep"}
    # 流水线并行计划:lm_head接收hidden_states并输出logits
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        # 创建基础模型
        self.model = DeepseekV3Model(config)
        # 设置词汇表大小
        self.vocab_size = config.vocab_size
        # 创建语言模型头:将隐藏状态映射到词汇表大小的logits
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """获取输出嵌入层(lm_head)"""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """设置输出嵌入层(lm_head)"""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """设置解码器模型"""
        self.model = decoder

    def get_decoder(self):
        """获取解码器模型"""
        return self.model

    @can_return_tuple
    @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
            labels (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`, *可选*):
                用于计算掩码语言模型损失的标签。索引应该在`[0, ..., config.vocab_size]`范围内或为-100
                (参见`input_ids`文档)。设置为`-100`的索引将被忽略(掩码),
                损失仅针对标签在`[0, ..., config.vocab_size]`范围内的token计算。

            logits_to_keep (`int` 或 `torch.Tensor`, *可选*):
                如果是`int`,则为最后`logits_to_keep`个token计算logits。如果为`0`,则为所有`input_ids`计算logits(特殊情况)。
                在生成过程中,通常只需要最后一个token的logits,为该token单独计算可以节省内存,
                对于长序列或大词汇表尤为显著。
                如果是`torch.Tensor`,必须是1D张量,对应于在序列长度维度上要保留的索引。
                这在使用打包张量格式(批次和序列长度的单一维度)时很有用。

        返回:

        示例:

        ```python
        >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM

        >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # 生成
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # 设置输出选项
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # 通过基础模型获取特征
        # 解码器输出包含(dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        # 获取最后的隐藏状态
        hidden_states = outputs.last_hidden_state
        # 只计算必要的logits,如果不计算损失则不将其转换为浮点类型
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        # 如果提供了标签,计算损失
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        # 返回因果语言模型输出,包含损失、logits、过去键值状态等
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# 定义模块中所有公开的类
__all__ = ["DeepseekV3PreTrainedModel", "DeepseekV3Model", "DeepseekV3ForCausalLM"]
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
站点统计

本站现有博文287篇,共被浏览558183

本站已经建立2144天!

热门文章
文章归档
回到顶部