EADST

Transformers Llama 模型代码中文注释 modeling_llama.py

# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# 此代码基于EleutherAI的GPT-NeoX库以及此库中的GPT-NeoX和OPT实现。
# 它已从原始形式修改,以适应与Meta AI团队训练模型时使用的GPT-NeoX和OPT相比的
# 微小架构差异。
#
# 根据Apache许可证2.0版("许可证")获得许可;
# 除非符合许可证,否则不得使用此文件。
# 您可以在以下位置获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则依据许可证分发的软件
# 是基于"按原样"分发的,没有任何明示或暗示的担保或条件。
# 有关许可证下特定语言的权限和限制,请参阅许可证。
# 中文代码注释 XD
# 源码请参考 https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
from typing import Callable, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    LossKwargs,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    can_return_tuple,
    is_torch_flex_attn_available,
    logging,
    replace_return_docstrings,
)
from .configuration_llama import LlamaConfig


# 检查FlexAttention是否可用,如果可用则导入相关模块
if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from ...integrations.flex_attention import make_flex_block_causal_mask

# 导入用于从hub加载内核前向传递的工具
from ...integrations import use_kernel_forward_from_hub


logger = logging.get_logger(__name__)

# 文档生成用的检查点和配置
_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
_CONFIG_FOR_DOC = "LlamaConfig"

# LlamaRMSNorm类 - 这是LLaMA模型使用的一种特殊归一化层,等同于T5LayerNorm
# 与传统的LayerNorm不同,RMSNorm只使用均方根进行归一化,不涉及均值中心化
# 计算流程:先保存输入类型,转为float32进行高精度计算,然后进行归一化,最后转回原始数据类型
@use_kernel_forward_from_hub("RMSNorm")
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm等同于T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # 保存输入的数据类型
        input_dtype = hidden_states.dtype
        # 转换为float32进行计算
        hidden_states = hidden_states.to(torch.float32)
        # 计算方差
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # 归一化操作:除以方差的平方根
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # 应用权重并转回原始数据类型
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        # 为打印模块信息提供额外的表示
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


# 将LlamaRMSNorm添加到所有层归一化层列表中
ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)

# LlamaRotaryEmbedding类 - 实现了旋转位置编码(RoPE)
# 支持多种RoPE变体,通过config中的rope_type参数指定
# 处理序列长度扩展问题,允许模型处理比训练时更长的序列
# 包含注意力缩放机制,用于优化不同长度序列的表示
class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, config: LlamaConfig, device=None):
        """
        LLaMA模型使用的旋转位置编码(RoPE)实现
        """
        super().__init__()
        # 向后兼容:处理rope_type参数
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        # 缓存的最大序列长度
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # 根据RoPE类型获取初始化函数
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        # 初始化频率反转和注意力缩放
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # 高级用户:用于高级RoPE类型(例如dynamic rope)
    def forward(self, x, position_ids):
        # 扩展反转频率到批次大小
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # 确定设备类型,特殊处理MPS设备
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # 强制使用float32
            # 计算频率
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # 合并频率
            emb = torch.cat((freqs, freqs), dim=-1)
            # 计算余弦和正弦部分,并应用注意力缩放
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        # 返回结果,转换为输入的数据类型
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

# 将张量的一半隐藏维度进行旋转,是RoPE核心操作之一
def rotate_half(x):
    """旋转输入张量的一半隐藏维度。"""
    # 将输入张量在最后一个维度上分成两半
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    # 返回旋转后的结果:(-x2, x1)
    return torch.cat((-x2, x1), dim=-1)

# 将旋转位置编码应用到查询(Q)和键(K)张量上
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    将旋转位置编码应用到查询和键张量上。

    参数:
        q (`torch.Tensor`): 查询张量。
        k (`torch.Tensor`): 键张量。
        cos (`torch.Tensor`): 旋转嵌入的余弦部分。
        sin (`torch.Tensor`): 旋转嵌入的正弦部分。
        position_ids (`torch.Tensor`, *可选*):
            已弃用且未使用。
        unsqueeze_dim (`int`, *可选*, 默认为1):
            'unsqueeze_dim'参数指定沿着哪个维度对cos[position_ids]和sin[position_ids]进行扩展,
            以便它们可以正确地广播到q和k的维度上。例如,注意cos[position_ids]和sin[position_ids]
            具有形状[batch_size, seq_len, head_dim]。然后,如果q和k具有形状
            [batch_size, heads, seq_len, head_dim],则设置unsqueeze_dim=1使
            cos[position_ids]和sin[position_ids]可广播到q和k的形状。
            同样,如果q和k具有形状[batch_size, seq_len, heads, head_dim],则设置unsqueeze_dim=2。
    返回:
        `tuple(torch.Tensor)` 包含使用旋转位置嵌入旋转后的查询和键张量。
    """
    # 扩展维度以便广播
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    # 应用旋转变换
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# LlamaMLP类 - 实现了LLaMA模型中的前馈神经网络部分
# 使用SwiGLU激活函数,这是一种门控线性单元变体
# 包含三个投影层:gate_proj、up_proj和down_proj
# 通过并行计算gate_proj和up_proj来提高效率
class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # gate_proj:将隐藏状态投影到中间维度的门控投影层
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        # up_proj:将隐藏状态投影到中间维度的上投影层
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        # down_proj:将中间表示投影回隐藏维度的下投影层
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        # 激活函数,默认为SiLU(Sigmoid Linear Unit)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # SwiGLU激活:先计算gate_proj经过激活函数的结果,再与up_proj的结果相乘
        # 然后通过down_proj投影回原始维度
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


# repeat_kv函数 - 实现了分组查询注意力(GQA)中的键值头复制操作
# 将键值头的数量从num_key_value_heads扩展到num_attention_heads
# 通过复制每个键值头以匹配注意力头的数量,实现了计算效率和内存使用的平衡
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    这相当于torch.repeat_interleave(x, dim=1, repeats=n_rep)。隐藏状态从(batch,
    num_key_value_heads, seqlen, head_dim)变为(batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    # 如果复制次数为1,直接返回原始状态
    if n_rep == 1:
        return hidden_states
    # 在原始形状中插入新维度并扩展,然后重塑以得到所需形状
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# eager_attention_forward函数 - 实现了标准的注意力前向传播计算
# 用于不使用CUDA优化(如FlashAttention)时的注意力计算
# 包含经典的注意力计算步骤:矩阵乘法、缩放、掩码、softmax和dropout
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # 应用GQA:复制键和值状态以匹配查询头的数量
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    # 计算注意力分数:查询和键的矩阵乘法,并应用缩放
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    # 如果提供了注意力掩码,将其应用到注意力分数上
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # 对注意力分数应用softmax得到注意力权重,使用float32以提高数值稳定性
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    # 在训练时应用dropout
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    # 使用注意力权重对值进行加权汇总
    attn_output = torch.matmul(attn_weights, value_states)
    # 调整输出的维度顺序并确保内存连续
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


# LlamaAttention类 - 实现了LLaMA模型中的多头注意力机制
# 支持分组查询注意力(GQA),可以减少内存使用同时保持性能
# 包含旋转位置编码(RoPE)的应用
# 支持KV缓存以加速自回归生成
class LlamaAttention(nn.Module):
    """来自'Attention Is All You Need'论文的多头注意力机制"""

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # 获取注意力头维度,如果未指定则从隐藏大小和头数计算
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # 每个键值头对应的查询头数量
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        # 缩放因子,用于缩放注意力分数
        self.scaling = self.head_dim**-0.5
        # 注意力dropout率
        self.attention_dropout = config.attention_dropout
        # 标记这是因果注意力(使用因果掩码)
        self.is_causal = True

        # 查询投影:将隐藏状态投影到查询表示空间
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        # 键投影:将隐藏状态投影到键表示空间
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        # 值投影:将隐藏状态投影到值表示空间
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        # 输出投影:将多头注意力的输出投影回原始隐藏维度
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # 记录输入形状,用于后续重塑
        input_shape = hidden_states.shape[:-1]
        # 计算多头形状
        hidden_shape = (*input_shape, -1, self.head_dim)

        # 计算查询、键、值状态并重塑为多头形式
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # 解包位置嵌入的余弦和正弦部分
        cos, sin = position_embeddings
        # 应用旋转位置编码到查询和键状态
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # 如果有过去的键值缓存,更新当前的键值状态
        if past_key_value is not None:
            # sin和cos特定于RoPE模型;静态缓存需要cache_position
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # 默认使用eager注意力计算接口
        attention_interface: Callable = eager_attention_forward

        # 根据配置选择不同的注意力实现
        if self.config._attn_implementation != "eager":
            # sdpa不支持输出注意力权重时,回退到eager实现
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                # 使用配置指定的注意力实现
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        # 调用选定的注意力接口计算注意力输出和权重
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # 将注意力输出重塑回原始形状
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        # 通过输出投影层变换注意力输出
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

# LlamaDecoderLayer类 - 实现了LLaMA模型中的单个Transformer解码器层
# 继承自GradientCheckpointingLayer,支持梯度检查点以节省内存
# 采用Pre-LayerNorm架构,即在注意力和FFN前应用层归一化
# 使用残差连接,保持信息流通并缓解梯度消失问题
class LlamaDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # 自注意力层
        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)

        # 前馈神经网络层
        self.mlp = LlamaMLP(config)
        # 输入层归一化,应用于自注意力前
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # 注意力后层归一化,应用于MLP前
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # 必要的,但为了向后兼容而保留在这里
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # 保存残差连接用的隐藏状态
        residual = hidden_states
        # 应用输入层归一化
        hidden_states = self.input_layernorm(hidden_states)

        # 自注意力计算
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        # 第一个残差连接:原始输入 + 注意力输出
        hidden_states = residual + hidden_states

        # 前馈神经网络(全连接层)
        # 保存第二个残差连接用的隐藏状态
        residual = hidden_states
        # 应用注意力后层归一化
        hidden_states = self.post_attention_layernorm(hidden_states)
        # 应用MLP
        hidden_states = self.mlp(hidden_states)
        # 第二个残差连接:注意力输出 + MLP输出
        hidden_states = residual + hidden_states

        # 准备返回值
        outputs = (hidden_states,)
        # 如果需要返回注意力权重,将其添加到输出元组
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


# LLAMA模型文档字符串开始
LLAMA_START_DOCSTRING = r"""
    该模型继承自[`PreTrainedModel`]。查看超类文档了解库为所有模型实现的通用方法
    (如下载或保存、调整输入嵌入大小、剪枝头等)

    该模型也是PyTorch的[torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)子类。
    将其作为常规PyTorch模块使用,并参考PyTorch文档了解与一般用法和行为相关的所有事项。

    参数:
        config ([`LlamaConfig`]):
            包含模型所有参数的模型配置类。使用配置文件初始化只加载模型的配置,
            而不加载与模型关联的权重。查看[`~PreTrainedModel.from_pretrained`]方法
            来加载模型权重。
"""


# 添加开始文档字符串
@add_start_docstrings(
    "裸LLaMA模型,输出原始隐藏状态,顶部没有任何特定的头。",
    LLAMA_START_DOCSTRING,
)
# LlamaPreTrainedModel类 - 所有LLAMA模型变体的基础类
# 定义了各种模型功能和支持的特性
# 实现了权重初始化等通用方法
class LlamaPreTrainedModel(PreTrainedModel):
    # 设置配置类
    config_class = LlamaConfig
    # 基础模型前缀
    base_model_prefix = "model"
    # 支持梯度检查点
    supports_gradient_checkpointing = True
    # 指定不拆分的模块
    _no_split_modules = ["LlamaDecoderLayer"]
    # 设备放置时跳过的键
    _skip_keys_device_placement = ["past_key_values"]
    # 支持Flash Attention 2优化
    _supports_flash_attn_2 = True
    # 支持缩放点积注意力(SDPA)优化
    _supports_sdpa = True
    # 支持灵活注意力优化
    _supports_flex_attn = True
    # 支持缓存类
    _supports_cache_class = True
    # 支持量化缓存
    _supports_quantized_cache = True
    # 支持静态缓存
    _supports_static_cache = True
    # 支持注意力后端
    _supports_attention_backend = True

    def _init_weights(self, module):
        # 初始化权重的标准差
        std = self.config.initializer_range
        # 线性层权重初始化
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        # 嵌入层权重初始化
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        # RMSNorm层权重初始化为1.0
        elif isinstance(module, LlamaRMSNorm):
            module.weight.data.fill_(1.0)


# LLAMA输入文档字符串
LLAMA_INPUTS_DOCSTRING = r"""
    参数:
        input_ids (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`):
            输入序列标记在词汇表中的索引。如果提供了填充,将默认忽略。

            可以使用[`AutoTokenizer`]获取索引。查看[`PreTrainedTokenizer.encode`]和
            [`PreTrainedTokenizer.__call__`]了解详情。

            [什么是输入ID?](../glossary#input-ids)
        attention_mask (`torch.Tensor` 形状为 `(batch_size, sequence_length) 或 `BlockMask`, *可选*):
            用于避免在填充标记索引上执行注意力的掩码。掩码值选择为`[0, 1]`:

            - 1表示**未被掩盖的**标记,
            - 0表示**被掩盖的**标记。

            如果模型配置为使用flex_attention,它将尝试将掩码Tensor转换为BlockMask,
            但您也可以直接在此处传递`BlockMask`对象。

            [什么是注意力掩码?](../glossary#attention-mask)

            可以使用[`AutoTokenizer`]获取索引。查看[`PreTrainedTokenizer.encode`]和
            [`PreTrainedTokenizer.__call__`]了解详情。

            如果使用了`past_key_values`,可选地只需要输入最后的`input_ids`(参见
            `past_key_values`)。

            如果您想更改填充行为,应阅读[`modeling_opt._prepare_decoder_attention_mask`]
            并根据您的需求进行修改。有关默认策略的更多信息,请参见[论文](https://arxiv.org/abs/1910.13461)中的图1。

            - 1表示头部**未被掩盖**,
            - 0表示头部**被掩盖**。
        position_ids (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`, *可选*):
            每个输入序列标记在位置嵌入中的索引。在范围`[0, config.n_positions - 1]`中选择。

            [什么是位置ID?](../glossary#position-ids)
        past_key_values (`Cache`, *可选*):
            预先计算的隐藏状态(自注意力块和交叉注意力块中的键和值),可用于加速顺序解码。
            这通常由模型在之前阶段的解码中返回的`past_key_values`组成,当`use_cache=True`或
            `config.use_cache=True`时。

            它是一个[`~cache_utils.Cache`]实例。有关更多详细信息,请参见我们的[kv缓存指南](https://huggingface.co/docs/transformers/en/kv_cache)。

            如果使用了`past_key_values`,用户可以选择只输入最后的`input_ids`(那些没有给这个模型的过去键值状态的输入),
            形状为`(batch_size, 1)`,而不是所有的`input_ids`,形状为`(batch_size, sequence_length)`。
        inputs_embeds (`torch.FloatTensor` 形状为 `(batch_size, sequence_length, hidden_size)`, *可选*):
            可选地,您可以选择直接传递嵌入表示,而不是传递`input_ids`。这在您想要更多地控制
            如何将`input_ids`索引转换为相关向量比模型的内部嵌入查找矩阵更有用。
        use_cache (`bool`, *可选*):
            如果设置为`True`,将返回`past_key_values`键值状态,可用于加速解码(参见
            `past_key_values`)。
        output_attentions (`bool`, *可选*):
            是否返回所有注意力层的注意力张量。有关更多详细信息,请参见`attentions`下的返回
            张量。
        output_hidden_states (`bool`, *可选*):
            是否返回所有层的隐藏状态。有关更多详细信息,请参见`hidden_states`下的返回
            张量。
        return_dict (`bool`, *可选*):
            是否返回[`~utils.ModelOutput`]而不是普通元组。
        cache_position (`torch.LongTensor` 形状为 `(sequence_length)`, *可选*):
            描述输入序列标记在序列中位置的索引。与`position_ids`不同,这个张量不受填充影响。
            它用于在正确的位置更新缓存并推断完整的序列长度。
"""


# 添加开始文档字符串
@add_start_docstrings(
    "裸LLaMA模型,输出原始隐藏状态,顶部没有任何特定的头。",
    LLAMA_START_DOCSTRING,
)
# LlamaModel类 - 实现了LLaMA的核心模型结构
# 由多个LlamaDecoderLayer层堆叠而成,应用RMSNorm和RoPE位置编码
# 实现了高效的注意力机制,支持KV缓存以加速推理
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer解码器,由*config.num_hidden_layers*层组成。每一层都是一个[`LlamaDecoderLayer`]

    参数:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # 设置填充索引和词汇表大小
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # 词嵌入层:将输入token转换为隐藏表示
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # 创建解码器层堆栈
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # 最终的层归一化
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # 旋转位置编码
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        # 梯度检查点标志,默认关闭
        self.gradient_checkpointing = False

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        # 设置默认值,优先使用传入参数,其次使用配置
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # 检查输入格式:input_ids和inputs_embeds必须且只能提供一个
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # 梯度检查点与缓存不兼容,训练时警告并禁用缓存
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # 检查过去键值类型兼容性
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        # 如果没有提供嵌入,通过词嵌入层获取
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # 如果启用缓存但未提供past_key_values,创建动态缓存
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # 处理缓存位置:确定序列中每个token的位置
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # 如果没有提供位置ID,从缓存位置创建
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # 更新因果掩码:确保自回归属性,即当前token只能看到过去的token
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # 设置初始隐藏状态为输入嵌入
        hidden_states = inputs_embeds

        # 创建位置嵌入,在所有解码器层之间共享
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # 解码器层处理
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # 遍历所有解码器层
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            # 如果需要输出所有隐藏状态,保存当前隐藏状态
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # 前向传播通过当前解码器层
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            # 更新隐藏状态为当前层的输出
            hidden_states = layer_outputs[0]

            # 如果需要输出注意力权重,保存当前层的注意力权重
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # 最终的层归一化
        hidden_states = self.norm(hidden_states)

        # 添加最后一个解码器层的隐藏状态
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # 返回模型输出
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """
        更新因果掩码:根据不同的注意力实现方式处理掩码
        处理不同的注意力实现(Flash Attention 2、Flex Attention、SDPA等)
        """
        # Flash Attention 2的特殊处理
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None

        # Flex Attention的特殊处理
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # 对于SDPA(缩放点积注意力),尽可能依赖其`is_causal`参数而非`attn_mask`参数
        # 这样可以调度到Flash Attention 2。此功能与静态缓存不兼容。
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # 当输出注意力为True时,sdpa实现的forward方法调用eager实现的forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        # 获取输入张量的数据类型和设备
        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]

        # 确定目标长度
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # 生成4D因果掩码
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        # SDPA的特殊处理:对于完全掩码的行,允许注意所有token
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # 注意完全掩码行中的所有token,例如使用左填充时的相关第一行
            # 这是F.scaled_dot_product_attention内存高效注意力路径所需的
            # 详情:https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        创建一个形状为`(batch_size, 1, query_length, key_value_length)`的因果4D掩码,
        从形状为`(batch_size, key_value_length)`的2D掩码创建,
        或如果输入的`attention_mask`已经是4D,则不做任何处理。

        参数:
            attention_mask (`torch.Tensor`):
                形状为`(batch_size, key_value_length)`的2D注意力掩码,
                或形状为`(batch_size, 1, query_length, key_value_length)`的4D注意力掩码。
            sequence_length (`int`):
                正在处理的序列长度。
            target_length (`int`):
                目标长度:使用静态缓存生成时,掩码应与静态缓存一样长,
                以考虑0填充、尚未填充的缓存部分。
            dtype (`torch.dtype`):
                用于4D注意力掩码的数据类型。
            device (`torch.device`):
                放置4D注意力掩码的设备。
            cache_position (`torch.Tensor`):
                描述输入序列标记在序列中位置的索引。
            batch_size (`torch.Tensor`):
                批次大小。
        """
        # 如果掩码已经是4D,直接使用
        if attention_mask is not None and attention_mask.dim() == 4:
            # 在这种情况下,我们假设掩码已经以反转形式提供,不需要反转或切片
            causal_mask = attention_mask
        else:
            # 创建填充有最小值的因果掩码
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            # 如果序列长度不是1,创建上三角掩码(对角线以上为1)
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # 应用缓存位置掩码
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            # 扩展掩码维度以匹配批次大小
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

            # 如果提供了注意力掩码,合并掩码
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # 复制到连续内存以进行原地编辑
                mask_length = attention_mask.shape[-1]
                # 合并因果掩码和注意力掩码
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                # 应用掩码
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


# 组合FlashAttentionKwargs和LossKwargs为因果语言模型创建一个新的参数类
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


# LlamaForCausalLM类 - 用于因果语言建模的LLaMA模型
# 在LlamaModel基础上添加了语言建模头,用于生成任务
# 支持高效的文本生成,包含损失计算和推理优化
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
    # 定义与主干模型共享权重的参数
    _tied_weights_keys = ["lm_head.weight"]
    # 张量并行计划:lm_head按列方向划分并复制
    _tp_plan = {"lm_head": "colwise_rep"}
    # 流水线并行计划:lm_head接收hidden_states并输出logits
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        # 创建基础模型
        self.model = LlamaModel(config)
        # 设置词汇表大小
        self.vocab_size = config.vocab_size
        # 创建语言模型头:将隐藏状态映射到词汇表大小的logits
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """获取输出嵌入层(lm_head)"""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """设置输出嵌入层(lm_head)"""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """设置解码器模型"""
        self.model = decoder

    def get_decoder(self):
        """获取解码器模型"""
        return self.model

    @can_return_tuple
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
            labels (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`, *可选*):
                用于计算掩码语言模型损失的标签。索引应该在`[0, ..., config.vocab_size]`范围内或为-100
                (参见`input_ids`文档)。设置为`-100`的索引将被忽略(掩码),
                损失仅针对标签在`[0, ..., config.vocab_size]`范围内的token计算。

            logits_to_keep (`int` 或 `torch.Tensor`, *可选*):
                如果是`int`,则为最后`logits_to_keep`个token计算logits。如果为`0`,则为所有`input_ids`计算logits(特殊情况)。
                在生成过程中,通常只需要最后一个token的logits,为该token单独计算可以节省内存,
                对于长序列或大词汇表尤为显著。
                如果是`torch.Tensor`,必须是1D张量,对应于在序列长度维度上要保留的索引。
                这在使用打包张量格式(批次和序列长度的单一维度)时很有用。

        返回:

        示例:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # 生成
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # 设置默认值
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # 通过解码器模型获取特征
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        # 获取最后的隐藏状态
        hidden_states = outputs.last_hidden_state
        # 只计算必要的logits,如果不计算损失则不将其转换为浮点类型
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        # 如果提供了标签,计算损失
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        # 返回结果
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# 添加开始文档字符串
@add_start_docstrings(
    """
    带有序列分类头的LLaMa模型transformer(线性层)。

    [`LlamaForSequenceClassification`]使用最后一个token进行分类,就像其他因果模型
    (例如GPT-2)一样。

    由于它对最后一个token进行分类,所以需要知道最后一个token的位置。如果配置中定义了
    `pad_token_id`,它会在每行中找到不是填充token的最后一个token。如果
    没有定义`pad_token_id`,它简单地取每行中的最后一个值。由于当传递`inputs_embeds`
    而不是`input_ids`时它无法猜测填充token,所以它会做同样的事情(取每行中的最后一个值)。
    """,
    LLAMA_START_DOCSTRING,
)
# LlamaForSequenceClassification类 - 用于序列分类任务的LLaMA模型
# 在LlamaModel基础上添加了分类头,用于序列级别的预测
# 通过获取序列中最后一个非填充token的表示进行分类
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # 设置标签数量
        self.num_labels = config.num_labels
        # 创建基础模型
        self.model = LlamaModel(config)
        # 创建分类得分头
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.model.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` 形状为 `(batch_size,)`, *可选*):
            用于计算序列分类/回归损失的标签。索引应该在`[0, ..., config.num_labels - 1]`范围内。
            如果`config.num_labels == 1`,则计算回归损失(均方损失);
            如果`config.num_labels > 1`,则计算分类损失(交叉熵)。
        """

        # 通过Transformer模型获取输出
        transformer_outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # 获取最后的隐藏状态
        hidden_states = transformer_outputs.last_hidden_state
        # 计算所有位置的分类得分
        logits = self.score(hidden_states)

        # 确定批次大小
        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # 处理填充token:找到每个序列中最后一个非填充token的位置
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("如果没有定义填充token,则无法处理批次大小 > 1。")
        if self.config.pad_token_id is None:
            # 如果没有定义填充token,使用最后一个位置
            last_non_pad_token = -1
        elif input_ids is not None:
            # 处理左填充和右填充,我们取不等于pad_token_id的最右边的token
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            # 找到最后一个非填充token的位置
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            # 如果没有提供input_ids,使用最后一个位置
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} 将不会在 `inputs_embeds` 中检测填充token。如果结合使用填充token和 "
                "`inputs_embeds`,结果可能会出乎意料。"
            )

        # 提取每个序列中最后一个非填充token位置的logits
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        # 如果提供了标签,计算损失
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        # 返回结果
        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


# 添加开始文档字符串
@add_start_docstrings(
    """
带有跨度分类头的Llama模型transformer,用于抽取式问答任务,如
SQuAD(在隐藏状态输出之上的线性层,用于计算`跨度开始logits`和`跨度结束logits`)。
    """,
    LLAMA_START_DOCSTRING,
)
# LlamaForQuestionAnswering类 - 用于抽取式问答任务的LLaMA模型
# 添加了跨度分类头,用于预测答案在文本中的开始和结束位置
# 适用于SQuAD等数据集上的问答任务
class LlamaForQuestionAnswering(LlamaPreTrainedModel):
    base_model_prefix = "transformer"

    # 从transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__复制,将Bloom替换为Llama
    def __init__(self, config):
        super().__init__(config)
        # 创建基础模型
        self.transformer = LlamaModel(config)
        # 创建问答输出层:将隐藏状态映射到两个输出(开始和结束位置)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.transformer.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> QuestionAnsweringModelOutput:
        r"""
        start_positions (`torch.LongTensor` 形状为 `(batch_size,)`, *可选*):
            标记跨度开始位置(索引)的标签,用于计算token分类损失。
            位置会被限制在序列长度内(`sequence_length`)。超出序列外的位置
            不会被用于计算损失。
        end_positions (`torch.LongTensor` 形状为 `(batch_size,)`, *可选*):
            标记跨度结束位置(索引)的标签,用于计算token分类损失。
            位置会被限制在序列长度内(`sequence_length`)。超出序列外的位置
            不会被用于计算损失。
        """

        # 通过Transformer模型获取输出
        outputs: BaseModelOutputWithPast = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # 获取序列输出
        sequence_output = outputs.last_hidden_state

        # 通过问答输出层计算logits
        logits = self.qa_outputs(sequence_output)
        # 将logits分割为开始和结束logits
        start_logits, end_logits = logits.split(1, dim=-1)
        # 压缩维度并确保张量连续
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        # 如果提供了开始和结束位置,计算损失
        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        # 返回问答模型输出
        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# 添加开始文档字符串
@add_start_docstrings(
    """
    带有token分类头的Llama模型transformer(在隐藏状态输出之上的线性层),
    例如用于命名实体识别(NER)任务。
    """,
    LLAMA_START_DOCSTRING,
)
# LlamaForTokenClassification类 - 用于token级别分类任务的LLaMA模型
# 添加了token分类头,用于对每个token进行分类
# 适用于命名实体识别(NER)、词性标注(POS)等任务
class LlamaForTokenClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # 设置标签数量
        self.num_labels = config.num_labels
        # 创建基础模型
        self.model = LlamaModel(config)

        # 确定分类器的dropout率
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        # 创建dropout层
        self.dropout = nn.Dropout(classifier_dropout)
        # 创建分类得分层
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.model.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` 形状为 `(batch_size,)`, *可选*):
            用于计算序列分类/回归损失的标签。索引应该在`[0, ..., config.num_labels - 1]`范围内。
            如果`config.num_labels == 1`,则计算回归损失(均方损失);
            如果`config.num_labels > 1`,则计算分类损失(交叉熵)。
        """

        # 通过基础模型获取输出
        outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # 获取序列输出
        sequence_output = outputs.last_hidden_state
        # 应用dropout
        sequence_output = self.dropout(sequence_output)
        # 计算每个token的分类logits
        logits = self.score(sequence_output)

        # 如果提供了标签,计算损失
        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        # 返回token分类器输出
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# 定义模块中所有公开的类
__all__ = [
    "LlamaForCausalLM",
    "LlamaModel",
    "LlamaPreTrainedModel",
    "LlamaForSequenceClassification",
    "LlamaForQuestionAnswering",
    "LlamaForTokenClassification",
]
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
站点统计

本站现有博文287篇,共被浏览558181

本站已经建立2144天!

热门文章
文章归档
回到顶部