EADST

Transformers DeepSeek V3 模型代码中文注释 modeling_deepseek_v3.py

#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           此文件是从src/transformers/models/deepseek_v3/modular_deepseek_v3.py自动生成的。
#               请勿手动编辑此文件,因为任何编辑都将在生成文件时被覆盖。
#             如果需要进行任何更改,请直接修改modular_deepseek_v3.py文件。我们的CI会强制执行这一点。
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
'''
DeepSeekV3是一个创新的大型语言模型架构,具有多项独特的技术特点:

混合专家系统(MoE)架构

在深层部分使用DeepseekV3MoE替代传统的MLP
通过TopkRouter路由器实现高效的专家选择
结合共享专家和路由专家,平衡计算效率和模型容量


特殊的注意力机制设计

使用不同维度的头部:qk_rope_head_dim、qk_nope_head_dim、v_head_dim
应用LoRA结构进行参数高效的计算
支持交错旋转位置编码,提高计算效率


高级优化技术

支持Flash Attention 2、SDPA、FlexAttention等高效注意力实现
使用分组注意力优化内存使用
实现滑动窗口和缓存机制,加速长文本生成



DeepSeekV3模型的主要组件包括:

DeepseekV3RMSNorm: 高效的层归一化实现
DeepseekV3RotaryEmbedding: 旋转位置编码,支持多种变体
DeepseekV3MLP: 前馈神经网络,使用SwiGLU激活
DeepseekV3TopkRouter: 专家路由器,实现组级别的专家选择
DeepseekV3MoE: 混合专家系统,结合路由和共享专家
DeepseekV3Attention: 复杂的多头注意力机制,包含多种优化
DeepseekV3DecoderLayer: 单个Transformer解码器层,根据层索引选择MLP或MoE
DeepseekV3Model: 主体模型架构,包含完整的Transformer解码器堆栈
DeepseekV3ForCausalLM: 用于生成任务的因果语言模型

这个模型结合了当代大语言模型的多项技术进步,尤其是在混合专家系统和注意力机制优化方面有独特设计,允许模型在保持推理效率的同时大幅增加参数规模,提高模型能力。
'''
# 中文代码注释 XD
# 源码请参考 https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
import math
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
    LossKwargs,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    can_return_tuple,
    is_torch_flex_attn_available,
    logging,
    replace_return_docstrings,
)
from .configuration_deepseek_v3 import DeepseekV3Config


# 检查FlexAttention是否可用
if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from ...integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DeepseekV3Config"


# DeepseekV3RMSNorm类 - 实现了DeepseekV3模型中的层归一化
# 使用RMS归一化代替LayerNorm,不进行均值中心化,只进行方差归一化
@use_kernel_forward_from_hub("RMSNorm")
class DeepseekV3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV3RMSNorm等同于T5LayerNorm
        """
        super().__init__()
        # 初始化权重参数为全1向量
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # 设置方差中的epsilon值,防止除零错误
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # 保存输入的数据类型
        input_dtype = hidden_states.dtype
        # 转换为float32进行高精度计算
        hidden_states = hidden_states.to(torch.float32)
        # 计算方差(平方的均值)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # 归一化:除以方差的平方根
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # 乘以权重并转回原始数据类型
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        # 为打印模块信息提供额外的表示
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


# DeepseekV3RotaryEmbedding类 - 实现了DeepseekV3模型中的旋转位置编码
# 支持多种RoPE变体,用于处理位置信息
class DeepseekV3RotaryEmbedding(nn.Module):
    def __init__(self, config: DeepseekV3Config, device=None):
        super().__init__()
        # 向后兼容:处理rope_type参数
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        # 缓存的最大序列长度
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # 获取RoPE初始化函数
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        # 初始化频率反转和注意力缩放
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # 高级用户:用于高级RoPE类型(例如dynamic rope)
    def forward(self, x, position_ids):
        # 扩展反转频率到批次大小
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # 确定设备类型,特殊处理MPS设备
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # 强制使用float32
            # 计算频率
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # 合并频率
            emb = torch.cat((freqs, freqs), dim=-1)
            # 计算余弦和正弦部分,并应用注意力缩放
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        # 返回结果,转换为输入的数据类型
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# DeepseekV3MLP类 - 实现了DeepseekV3模型中的前馈神经网络部分
# 使用SwiGLU激活函数,包含三个投影层:gate_proj、up_proj和down_proj
class DeepseekV3MLP(nn.Module):
    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        # 设置隐藏维度,如果未指定则使用配置中的值
        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
        # 设置中间维度,如果未指定则使用配置中的值
        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size

        # gate_proj:将隐藏状态投影到中间维度的门控投影层
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # up_proj:将隐藏状态投影到中间维度的上投影层
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # down_proj:将中间表示投影回隐藏维度的下投影层
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # 激活函数,默认为SiLU(Sigmoid Linear Unit)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # SwiGLU激活:先计算gate_proj经过激活函数的结果,再与up_proj的结果相乘
        # 然后通过down_proj投影回原始维度
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


# DeepseekV3TopkRouter类 - 实现了DeepseekV3模型中的TopK路由器
# 用于混合专家系统(MoE)中选择专家的路由机制
class DeepseekV3TopkRouter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # 每个token选择的专家数量
        self.top_k = config.num_experts_per_tok
        # 路由专家的数量
        self.n_routed_experts = config.n_routed_experts
        # 路由缩放因子
        self.routed_scaling_factor = config.routed_scaling_factor
        # 专家组数量
        self.n_group = config.n_group
        # 每个组选择的topk数量
        self.topk_group = config.topk_group
        # 是否规范化topk概率
        self.norm_topk_prob = config.norm_topk_prob

        # 路由权重
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
        # 专家分数修正偏置
        self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts)))

    @torch.no_grad()
    def get_topk_indices(self, scores):
        # 构建选择分数,添加修正偏置
        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
        # 将专家分组,计算每组的顶部得分和
        group_scores = (
            scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        # 选择顶部的组
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        # 创建组掩码
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        # 扩展掩码到专家级别
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(-1, self.n_routed_experts)
        )
        # 掩盖不需要的专家分数
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
        # 获取最终的topk专家索引
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
        return topk_indices

    def forward(self, hidden_states):
        # 重塑隐藏状态
        hidden_states = hidden_states.view(-1, self.config.hidden_size)
        # 计算路由logits
        router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
        # 应用sigmoid激活
        scores = router_logits.sigmoid()
        # 获取topk索引
        topk_indices = self.get_topk_indices(scores)
        # 获取topk权重
        topk_weights = scores.gather(1, topk_indices)
        # 如果需要规范化topk概率
        if self.norm_topk_prob:
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        # 应用路由缩放因子
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights


# DeepseekV3MoE类 - 实现了DeepseekV3模型中的混合专家系统
# 包含路由专家和共享专家,通过门控机制选择专家
class DeepseekV3MoE(nn.Module):
    """
    包含共享专家的混合专家模块。
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # 创建专家模块列表
        self.experts = nn.ModuleList(
            [
                DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
                for _ in range(config.n_routed_experts)
            ]
        )
        # 创建门控网络
        self.gate = DeepseekV3TopkRouter(config)
        # 创建共享专家
        self.shared_experts = DeepseekV3MLP(
            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
        )

    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
        r"""
        贡献请求!我现在没有时间优化这个,但专家权重需要被融合,
        以避免在这里进行循环(deepseek有256个专家,所以是的)。
        """
        # 创建最终隐藏状态的零张量
        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
        # 创建专家掩码
        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
        expert_mask = expert_mask.permute(2, 0, 1)

        # 遍历所有专家
        for expert_idx in range(len(self.experts)):
            expert = self.experts[expert_idx]
            mask = expert_mask[expert_idx]
            # 找出分配给当前专家的token索引和权重索引
            token_indices, weight_indices = torch.where(mask)

            # 如果有token分配给当前专家
            if token_indices.numel() > 0:
                # 获取专家权重
                expert_weights = topk_weights[token_indices, weight_indices]
                # 获取专家输入
                expert_input = hidden_states[token_indices]
                # 计算专家输出
                expert_output = expert(expert_input)
                # 加权输出
                weighted_output = expert_output * expert_weights.unsqueeze(-1)
                # 将加权输出添加到最终隐藏状态
                final_hidden_states.index_add_(0, token_indices, weighted_output)

        # 在原始deepseek中,专家的输出在离开此模块时被收集
        # 因此moe模块本身是一个IsolatedParallel模块
        # 所有专家都是"本地的",意味着我们分片但不收集
        return final_hidden_states.type(hidden_states.dtype)

    def forward(self, hidden_states):
        # 保存残差连接
        residuals = hidden_states
        # 保存原始形状
        orig_shape = hidden_states.shape
        # 获取topk索引和权重
        topk_indices, topk_weights = self.gate(hidden_states)
        # 重塑隐藏状态
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        # 应用MoE
        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
        # 添加共享专家的输出
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states


# rotate_half函数 - 旋转输入张量的一半隐藏维度
def rotate_half(x):
    """旋转输入张量的一半隐藏维度。"""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# apply_rotary_pos_emb函数 - 应用旋转位置编码到查询和键张量
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    将旋转位置编码应用到查询和键张量上。

    参数:
        q (`torch.Tensor`): 查询张量。
        k (`torch.Tensor`): 键张量。
        cos (`torch.Tensor`): 旋转嵌入的余弦部分。
        sin (`torch.Tensor`): 旋转嵌入的正弦部分。
        position_ids (`torch.Tensor`, *可选*):
            已弃用且未使用。
        unsqueeze_dim (`int`, *可选*, 默认为1):
            'unsqueeze_dim'参数指定沿着哪个维度对cos[position_ids]和sin[position_ids]进行扩展,
            以便它们可以正确地广播到q和k的维度上。例如,注意cos[position_ids]和sin[position_ids]
            具有形状[batch_size, seq_len, head_dim]。然后,如果q和k具有形状
            [batch_size, heads, seq_len, head_dim],则设置unsqueeze_dim=1使
            cos[position_ids]和sin[position_ids]可广播到q和k的形状。
            同样,如果q和k具有形状[batch_size, seq_len, heads, head_dim],则设置unsqueeze_dim=2。
    返回:
        `tuple(torch.Tensor)` 包含使用旋转位置嵌入旋转后的查询和键张量。
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# repeat_kv函数 - 实现了分组查询注意力(GQA)中的键值头复制操作
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    这相当于torch.repeat_interleave(x, dim=1, repeats=n_rep)。隐藏状态从(batch,
    num_key_value_heads, seqlen, head_dim)变为(batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# eager_attention_forward函数 - 实现了标准的注意力前向传播计算
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # 应用GQA:复制键和值状态以匹配查询头的数量
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    # 计算注意力分数:查询和键的矩阵乘法,并应用缩放
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    # 如果提供了注意力掩码,将其应用到注意力分数上
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # 对注意力分数应用softmax得到注意力权重,使用float32以提高数值稳定性
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    # 在训练时应用dropout
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    # 使用注意力权重对值进行加权汇总
    attn_output = torch.matmul(attn_weights, value_states)
    # 调整输出的维度顺序并确保内存连续
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


# apply_rotary_pos_emb_interleave函数 - 应用交错权重的旋转位置编码
def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    r"""
    TODO 让我们使用原始的freqcis计算来避免视图
    转置+重塑!这不是优化的!
    将旋转位置编码应用到查询和键张量上。

    参数:
        q (`torch.Tensor`): 查询张量。
        k (`torch.Tensor`): 键张量。
        cos (`torch.Tensor`): 旋转嵌入的余弦部分。
        sin (`torch.Tensor`): 旋转嵌入的正弦部分。
        position_ids (`torch.Tensor`):
            对应于查询和键张量的标记的位置索引。例如,这可以用于
            在使用KV缓存时传递偏移的位置ID。
        unsqueeze_dim (`int`, *可选*, 默认为1):
            'unsqueeze_dim'参数指定沿着哪个维度对cos[position_ids]和sin[position_ids]进行扩展,
            以便它们可以正确地广播到q和k的维度上。例如,注意cos[position_ids]和sin[position_ids]
            具有形状[batch_size, seq_len, head_dim]。然后,如果q和k具有形状
            [batch_size, heads, seq_len, head_dim],则设置unsqueeze_dim=1使
            cos[position_ids]和sin[position_ids]可广播到q和k的形状。
            同样,如果q和k具有形状[batch_size, seq_len, heads, head_dim],则设置unsqueeze_dim=2。
    返回:
        `tuple(torch.Tensor)` 包含使用旋转位置嵌入旋转后的查询和键张量。
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # 对查询进行交错重塑
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    # 对键进行交错重塑
    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    # 应用旋转位置编码
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# yarn_get_mscale函数 - 计算YaRN缩放的比例因子
def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# DeepseekV3Attention类 - 实现DeepseekV3模型中的多头注意力机制
# 包含多种特殊设计,如不同的头部维度、LoRA等优化
class DeepseekV3Attention(nn.Module):
    """来自'Attention Is All You Need'论文的多头注意力机制"""

    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # 每个键值头对应的查询头数量
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        # 注意力dropout率
        self.attention_dropout = config.attention_dropout
        # 注意力头数量
        self.num_heads = config.num_attention_heads
        # RoPE的theta参数
        self.rope_theta = config.rope_theta
        # 查询的LoRA秩
        self.q_lora_rank = config.q_lora_rank
        # 应用RoPE的查询键头维度
        self.qk_rope_head_dim = config.qk_rope_head_dim
        # 键值的LoRA秩
        self.kv_lora_rank = config.kv_lora_rank
        # 值头维度
        self.v_head_dim = config.v_head_dim
        # 不应用RoPE的查询键头维度
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # 查询键总头维度
        self.qk_head_dim = config.qk_head_dim

        # 标记这是因果注意力
        self.is_causal = True
        # 查询的LoRA投影A
        self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
        # 查询的LoRA层归一化
        self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
        # 查询的LoRA投影B
        self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)

        # 键值的LoRA投影A与多查询注意力
        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        # 键值的LoRA层归一化
        self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
        # 键值的LoRA投影B
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        # 输出投影
        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        # 缩放因子
        self.scaling = self.qk_head_dim ** (-0.5)
        # 如果启用了RoPE缩放
        if self.config.rope_scaling is not None:
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.scaling = self.scaling * mscale * mscale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # 获取批次大小和序列长度
        batch_size, seq_length = hidden_states.shape[:-1]
        # 定义查询形状和键形状
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        # 查询状态计算:通过LoRA投影A和B,加上层归一化
        q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
        # 将查询分为不应用RoPE的部分和应用RoPE的部分
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # 压缩的键值计算
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        # 分离键和旋转部分
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        # 键值状态计算:通过LoRA投影B和层归一化
        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        # 分离键和值状态
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # 重塑键旋转部分
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        # 获取位置嵌入的余弦和正弦部分
        cos, sin = position_embeddings
        # 根据配置选择RoPE应用方式
        if self.config.rope_interleave:  # 支持使用交错权重来提高效率
            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        else:
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        # 扩展键旋转部分到与键传递部分相同的形状
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        # 合并查询和键的各部分
        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        # 如果有过去的键值缓存,更新当前的键值状态
        if past_key_value is not None:
            # sin和cos特定于RoPE模型;静态缓存需要cache_position
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # 对于flash_attention_2,如果query和value的维度不同,需要padding
        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        # 选择注意力计算接口
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        # 调用选定的注意力接口计算注意力输出和权重
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # 如果使用flash_attention_2,需要裁剪输出
        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        # 重塑注意力输出并应用输出投影
        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


# DeepseekV3DecoderLayer类 - 实现DeepseekV3模型中的单个Transformer解码器层
# 继承自GradientCheckpointingLayer,支持梯度检查点以节省内存
class DeepseekV3DecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # 自注意力层
        self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)

        # 根据层索引选择MLP或MoE
        if layer_idx >= config.first_k_dense_replace:
            # 对于后面的层,使用混合专家系统
            self.mlp = DeepseekV3MoE(config)
        else:
            # 对于前面的层,使用标准MLP
            self.mlp = DeepseekV3MLP(config)

        # 输入层归一化,应用于自注意力前
        self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # 注意力后层归一化,应用于MLP前
        self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # 必要的,但为了向后兼容而保留在这里
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # 保存残差连接用的隐藏状态
        residual = hidden_states
        # 应用输入层归一化
        hidden_states = self.input_layernorm(hidden_states)

        # 自注意力计算
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        # 第一个残差连接:原始输入 + 注意力输出
        hidden_states = residual + hidden_states

        # 前馈神经网络或MoE
        # 保存第二个残差连接用的隐藏状态
        residual = hidden_states
        # 应用注意力后层归一化
        hidden_states = self.post_attention_layernorm(hidden_states)
        # 应用MLP或MoE
        hidden_states = self.mlp(hidden_states)
        # 第二个残差连接:注意力输出 + MLP/MoE输出
        hidden_states = residual + hidden_states

        # 准备返回值
        outputs = (hidden_states,)
        # 如果需要返回注意力权重,将其添加到输出元组
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


# DEEPSEEK_V3模型文档字符串开始
DEEPSEEK_V3_START_DOCSTRING = r"""
    该模型继承自[`PreTrainedModel`]。查看超类文档了解库为所有模型实现的通用方法
    (如下载或保存、调整输入嵌入大小、剪枝头等)

    该模型也是PyTorch的[torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)子类。
    将其作为常规PyTorch模块使用,并参考PyTorch文档了解与一般用法和行为相关的所有事项。

    参数:
        config ([`DeepseekV3Config`]):
            包含模型所有参数的模型配置类。使用配置文件初始化只加载模型的配置,
            而不加载与模型关联的权重。查看[`~PreTrainedModel.from_pretrained`]方法
            来加载模型权重。
"""


# 添加开始文档字符串
@add_start_docstrings(
    "裸DeepseekV3模型,输出原始隐藏状态,顶部没有任何特定的头。",
    DEEPSEEK_V3_START_DOCSTRING,
)
# DeepseekV3PreTrainedModel类 - 所有DeepseekV3模型变体的基础类
# 定义了各种模型功能和支持的特性
class DeepseekV3PreTrainedModel(PreTrainedModel):
    # 设置配置类
    config_class = DeepseekV3Config
    # 基础模型前缀
    base_model_prefix = "model"
    # 支持梯度检查点
    supports_gradient_checkpointing = True
    # 指定不拆分的模块
    _no_split_modules = ["DeepseekV3DecoderLayer"]
    # 设备放置时跳过的键
    _skip_keys_device_placement = ["past_key_values"]
    # 支持Flash Attention 2优化
    _supports_flash_attn_2 = True
    # 支持缩放点积注意力(SDPA)优化
    _supports_sdpa = True
    # 支持灵活注意力优化
    _supports_flex_attn = True
    # 支持缓存类
    _supports_cache_class = True
    # 支持量化缓存
    _supports_quantized_cache = True
    # 支持静态缓存
    _supports_static_cache = True
    # 支持注意力后端
    _supports_attention_backend = True

    def _init_weights(self, module):
        # 初始化权重的标准差
        std = self.config.initializer_range
        # 线性层权重初始化
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        # 嵌入层权重初始化
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        # RMSNorm层权重初始化为1.0
        elif isinstance(module, DeepseekV3RMSNorm):
            module.weight.data.fill_(1.0)
        # TopkRouter权重初始化
        elif isinstance(module, DeepseekV3TopkRouter):
            module.weight.data.normal_(mean=0.0, std=std)


# DEEPSEEK_V3输入文档字符串
DEEPSEEK_V3_INPUTS_DOCSTRING = r"""
    参数:
        input_ids (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`):
            输入序列标记在词汇表中的索引。如果提供了填充,将默认忽略。

            可以使用[`AutoTokenizer`]获取索引。查看[`PreTrainedTokenizer.encode`]和
            [`PreTrainedTokenizer.__call__`]了解详情。

            [什么是输入ID?](../glossary#input-ids)
        attention_mask (`torch.Tensor` 形状为 `(batch_size, sequence_length) 或 `BlockMask`, *可选*):
            用于避免在填充标记索引上执行注意力的掩码。掩码值选择为`[0, 1]`:

            - 1表示**未被掩盖的**标记,
            - 0表示**被掩盖的**标记。

            如果模型配置为使用flex_attention,它将尝试将掩码Tensor转换为BlockMask,
            但您也可以直接在此处传递`BlockMask`对象。

            [什么是注意力掩码?](../glossary#attention-mask)

            可以使用[`AutoTokenizer`]获取索引。查看[`PreTrainedTokenizer.encode`]和
            [`PreTrainedTokenizer.__call__`]了解详情。

            如果使用了`past_key_values`,可选地只需要输入最后的`input_ids`(参见
            `past_key_values`)。

            如果您想更改填充行为,应阅读[`modeling_opt._prepare_decoder_attention_mask`]
            并根据您的需求进行修改。有关默认策略的更多信息,请参见[论文](https://arxiv.org/abs/1910.13461)中的图1。

            - 1表示头部**未被掩盖**,
            - 0表示头部**被掩盖**。
        position_ids (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`, *可选*):
            每个输入序列标记在位置嵌入中的索引。在范围`[0, config.n_positions - 1]`中选择。

            [什么是位置ID?](../glossary#position-ids)
        past_key_values (`Cache`, *可选*):
            预先计算的隐藏状态(自注意力块和交叉注意力块中的键和值),可用于加速顺序解码。
            这通常由模型在之前阶段的解码中返回的`past_key_values`组成,当`use_cache=True`或
            `config.use_cache=True`时。

            它是一个[`~cache_utils.Cache`]实例。有关更多详细信息,请参见我们的[kv缓存指南](https://huggingface.co/docs/transformers/en/kv_cache)。

            如果使用了`past_key_values`,用户可以选择只输入最后的`input_ids`(那些没有给这个模型的过去键值状态的输入),
            形状为`(batch_size, 1)`,而不是所有的`input_ids`,形状为`(batch_size, sequence_length)`。
        inputs_embeds (`torch.FloatTensor` 形状为 `(batch_size, sequence_length, hidden_size)`, *可选*):
            可选地,您可以选择直接传递嵌入表示,而不是传递`input_ids`。这在您想要更多地控制
            如何将`input_ids`索引转换为相关向量比模型的内部嵌入查找矩阵更有用。
        use_cache (`bool`, *可选*):
            如果设置为`True`,将返回`past_key_values`键值状态,可用于加速解码(参见
            `past_key_values`)。
        output_attentions (`bool`, *可选*):
            是否返回所有注意力层的注意力张量。有关更多详细信息,请参见`attentions`下的返回
            张量。
        output_hidden_states (`bool`, *可选*):
            是否返回所有层的隐藏状态。有关更多详细信息,请参见`hidden_states`下的返回
            张量。
        return_dict (`bool`, *可选*):
            是否返回[`~utils.ModelOutput`]而不是普通元组。
        cache_position (`torch.LongTensor` 形状为 `(sequence_length)`, *可选*):
            描述输入序列标记在序列中位置的索引。与`position_ids`不同,这个张量不受填充影响。
            它用于在正确的位置更新缓存并推断完整的序列长度。
"""


# 添加开始文档字符串
@add_start_docstrings(
    "裸DeepseekV3模型,输出原始隐藏状态,顶部没有任何特定的头。",
    DEEPSEEK_V3_START_DOCSTRING,
)
# DeepseekV3Model类 - DeepseekV3的主体模型结构
# 包含完整的Transformer解码器堆栈,实现了自回归生成的核心功能
class DeepseekV3Model(DeepseekV3PreTrainedModel):
    """
    Transformer解码器,由*config.num_hidden_layers*层组成。每一层都是一个[`DeepseekV3DecoderLayer`]

    参数:
        config: DeepseekV3Config
    """

    # 加载时忽略意外的键
    _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]

    def __init__(self, config: DeepseekV3Config):
        super().__init__(config)
        # 设置填充索引和词汇表大小
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # 词嵌入层:将输入token转换为隐藏表示
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # 创建解码器层堆栈
        self.layers = nn.ModuleList(
            [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # 最终的层归一化
        self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # 旋转位置编码
        self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
        # 梯度检查点标志,默认关闭
        self.gradient_checkpointing = False

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        # 设置默认值,优先使用传入参数,其次使用配置
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # 检查输入格式:input_ids和inputs_embeds必须且只能提供一个
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # 梯度检查点与缓存不兼容,训练时警告并禁用缓存
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # 检查过去键值类型兼容性
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        # 如果没有提供嵌入,通过词嵌入层获取
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # 如果启用缓存但未提供past_key_values,创建动态缓存
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # 处理缓存位置:确定序列中每个token的位置
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # 如果没有提供位置ID,从缓存位置创建
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # 更新因果掩码:确保自回归属性,即当前token只能看到过去的token
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # 设置初始隐藏状态为输入嵌入
        hidden_states = inputs_embeds

        # 创建位置嵌入,在所有解码器层之间共享
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # 解码器层处理
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # 遍历所有解码器层
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            # 如果需要输出所有隐藏状态,保存当前隐藏状态
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # 前向传播通过当前解码器层
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            # 更新隐藏状态为当前层的输出
            hidden_states = layer_outputs[0]

            # 如果需要输出注意力权重,保存当前层的注意力权重
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # 最终的层归一化
        hidden_states = self.norm(hidden_states)

        # 添加最后一个解码器层的隐藏状态
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # 返回模型输出
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """
        更新因果掩码:根据不同的注意力实现方式和缓存类型处理掩码
        处理各种特殊情况,包括Flash Attention 2、滑动窗口等
        """
        # Flash Attention 2的特殊处理
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None

        # Flex Attention的特殊处理
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # 对于SDPA(缩放点积注意力),尽可能依赖其`is_causal`参数而非`attn_mask`参数
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # 当输出注意力为True时,sdpa实现的forward方法调用eager实现的forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        # 获取输入张量的数据类型和设备
        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]

        # 确定目标长度
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # 生成4D因果掩码
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        # SDPA的特殊处理:对于完全掩码的行,允许注意所有token
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # 注意完全掩码行中的所有token,例如使用左填充时的相关第一行
            # 这是F.scaled_dot_product_attention内存高效注意力路径所需的
            # 详情:https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        创建一个形状为`(batch_size, 1, query_length, key_value_length)`的因果4D掩码,
        从形状为`(batch_size, key_value_length)`的2D掩码创建,
        或如果输入的`attention_mask`已经是4D,则不做任何处理。

        参数:
            attention_mask (`torch.Tensor`):
                形状为`(batch_size, key_value_length)`的2D注意力掩码,
                或形状为`(batch_size, 1, query_length, key_value_length)`的4D注意力掩码。
            sequence_length (`int`):
                正在处理的序列长度。
            target_length (`int`):
                目标长度:使用静态缓存生成时,掩码应与静态缓存一样长,
                以考虑0填充、尚未填充的缓存部分。
            dtype (`torch.dtype`):
                用于4D注意力掩码的数据类型。
            device (`torch.device`):
                放置4D注意力掩码的设备。
            cache_position (`torch.Tensor`):
                描述输入序列标记在序列中位置的索引。
            batch_size (`torch.Tensor`):
                批次大小。
        """
        # 如果掩码已经是4D,直接使用
        if attention_mask is not None and attention_mask.dim() == 4:
            # 在这种情况下,我们假设掩码已经以反转形式提供,不需要反转或切片
            causal_mask = attention_mask
        else:
            # 创建填充有最小值的因果掩码
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            # 如果序列长度不是1,创建上三角掩码(对角线以上为1)
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # 应用缓存位置掩码
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            # 扩展掩码维度以匹配批次大小
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

            # 如果提供了注意力掩码,合并掩码
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # 复制到连续内存以进行原地编辑
                mask_length = attention_mask.shape[-1]
                # 合并因果掩码和注意力掩码
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                # 应用掩码
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


# 组合FlashAttentionKwargs和LossKwargs为因果语言模型创建一个新的参数类
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


# DeepseekV3ForCausalLM类 - 用于因果语言建模的DeepseekV3模型
# 在基础DeepseekV3Model上添加了语言模型头,用于生成任务
class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
    # 定义与主干模型共享权重的参数
    _tied_weights_keys = ["lm_head.weight"]
    # 张量并行计划:lm_head按列方向划分并复制
    _tp_plan = {"lm_head": "colwise_rep"}
    # 流水线并行计划:lm_head接收hidden_states并输出logits
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        # 创建基础模型
        self.model = DeepseekV3Model(config)
        # 设置词汇表大小
        self.vocab_size = config.vocab_size
        # 创建语言模型头:将隐藏状态映射到词汇表大小的logits
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """获取输出嵌入层(lm_head)"""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """设置输出嵌入层(lm_head)"""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """设置解码器模型"""
        self.model = decoder

    def get_decoder(self):
        """获取解码器模型"""
        return self.model

    @can_return_tuple
    @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
            labels (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`, *可选*):
                用于计算掩码语言模型损失的标签。索引应该在`[0, ..., config.vocab_size]`范围内或为-100
                (参见`input_ids`文档)。设置为`-100`的索引将被忽略(掩码),
                损失仅针对标签在`[0, ..., config.vocab_size]`范围内的token计算。

            logits_to_keep (`int` 或 `torch.Tensor`, *可选*):
                如果是`int`,则为最后`logits_to_keep`个token计算logits。如果为`0`,则为所有`input_ids`计算logits(特殊情况)。
                在生成过程中,通常只需要最后一个token的logits,为该token单独计算可以节省内存,
                对于长序列或大词汇表尤为显著。
                如果是`torch.Tensor`,必须是1D张量,对应于在序列长度维度上要保留的索引。
                这在使用打包张量格式(批次和序列长度的单一维度)时很有用。

        返回:

        示例:

        ```python
        >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM

        >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # 生成
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # 设置输出选项
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # 通过基础模型获取特征
        # 解码器输出包含(dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        # 获取最后的隐藏状态
        hidden_states = outputs.last_hidden_state
        # 只计算必要的logits,如果不计算损失则不将其转换为浮点类型
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        # 如果提供了标签,计算损失
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        # 返回因果语言模型输出,包含损失、logits、过去键值状态等
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# 定义模块中所有公开的类
__all__ = ["DeepseekV3PreTrainedModel", "DeepseekV3Model", "DeepseekV3ForCausalLM"]
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
FP32 Card Animate CUDA HuggingFace Miniforge v2ray Pytorch 签证 TensorRT Python Shortcut Qwen2 COCO Template CV 证件照 GPT4 Crawler UI LoRA Firewall Github Ubuntu Input Disk Bitcoin Michelin Zip Pickle Freesound 算法题 Base64 DeepStream Food Windows ONNX Hotel BeautifulSoup Bin Attention Logo Excel AI Statistics 继承 UNIX Interview RGB Plotly Linux Diagram Distillation Google CLAP OpenAI FlashAttention scipy Data v0.dev CAM 飞书 uWSGI tqdm logger Proxy Color OCR DeepSeek printf OpenCV GGML SAM Quantize Anaconda Random FP8 Streamlit Baidu YOLO 音频 QWEN VGG-16 Augmentation C++ 阿里云 多进程 多线程 WAN Gemma Django 公式 GoogLeNet PyCharm NLTK Knowledge ChatGPT torchinfo Safetensors Pandas git Datetime hf Use Image2Text Llama Tensor Sklearn Password Permission Nginx LLAMA Qwen2.5 Claude GIT Numpy 财报 Markdown Quantization 域名 Vim Plate Cloudreve PDF Bipartite NameSilo RAR Heatmap InvalidArgumentError SVR CEIR uwsgi llama.cpp Git mmap TSV VPN API FP16 BF16 FP64 Magnet tar Pillow Conda WebCrawler Mixtral Algorithm TTS Vmess Breakpoint Tiktoken TensorFlow Review CTC Jetson ModelScope Dataset Ptyhon LeetCode 搞笑 Bert 关于博主 NLP Tracking Docker VSCode 腾讯云 CC SPIE LaTeX Web transformers Qwen LLM JSON SQL Hungarian PyTorch Transformers BTC PDB Paper HaggingFace Jupyter XGBoost Clash PIP XML SQLite IndexTTS2 Domain Hilton MD5 diffusers 净利润 Video Math Translation 版权 Paddle Land EXCEL ResNet-50 报税 CSV FastAPI git-lfs Website GPTQ
站点统计

本站现有博文309篇,共被浏览730468

本站已经建立2367天!

热门文章
文章归档
回到顶部