EADST

Transformers Llama 模型代码中文注释 modeling_llama.py

# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# 此代码基于EleutherAI的GPT-NeoX库以及此库中的GPT-NeoX和OPT实现。
# 它已从原始形式修改,以适应与Meta AI团队训练模型时使用的GPT-NeoX和OPT相比的
# 微小架构差异。
#
# 根据Apache许可证2.0版("许可证")获得许可;
# 除非符合许可证,否则不得使用此文件。
# 您可以在以下位置获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则依据许可证分发的软件
# 是基于"按原样"分发的,没有任何明示或暗示的担保或条件。
# 有关许可证下特定语言的权限和限制,请参阅许可证。
# 中文代码注释 XD
# 源码请参考 https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
from typing import Callable, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    LossKwargs,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    can_return_tuple,
    is_torch_flex_attn_available,
    logging,
    replace_return_docstrings,
)
from .configuration_llama import LlamaConfig


# 检查FlexAttention是否可用,如果可用则导入相关模块
if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from ...integrations.flex_attention import make_flex_block_causal_mask

# 导入用于从hub加载内核前向传递的工具
from ...integrations import use_kernel_forward_from_hub


logger = logging.get_logger(__name__)

# 文档生成用的检查点和配置
_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
_CONFIG_FOR_DOC = "LlamaConfig"

# LlamaRMSNorm类 - 这是LLaMA模型使用的一种特殊归一化层,等同于T5LayerNorm
# 与传统的LayerNorm不同,RMSNorm只使用均方根进行归一化,不涉及均值中心化
# 计算流程:先保存输入类型,转为float32进行高精度计算,然后进行归一化,最后转回原始数据类型
@use_kernel_forward_from_hub("RMSNorm")
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm等同于T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # 保存输入的数据类型
        input_dtype = hidden_states.dtype
        # 转换为float32进行计算
        hidden_states = hidden_states.to(torch.float32)
        # 计算方差
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # 归一化操作:除以方差的平方根
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # 应用权重并转回原始数据类型
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        # 为打印模块信息提供额外的表示
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


# 将LlamaRMSNorm添加到所有层归一化层列表中
ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)

# LlamaRotaryEmbedding类 - 实现了旋转位置编码(RoPE)
# 支持多种RoPE变体,通过config中的rope_type参数指定
# 处理序列长度扩展问题,允许模型处理比训练时更长的序列
# 包含注意力缩放机制,用于优化不同长度序列的表示
class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, config: LlamaConfig, device=None):
        """
        LLaMA模型使用的旋转位置编码(RoPE)实现
        """
        super().__init__()
        # 向后兼容:处理rope_type参数
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        # 缓存的最大序列长度
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # 根据RoPE类型获取初始化函数
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        # 初始化频率反转和注意力缩放
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # 高级用户:用于高级RoPE类型(例如dynamic rope)
    def forward(self, x, position_ids):
        # 扩展反转频率到批次大小
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # 确定设备类型,特殊处理MPS设备
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # 强制使用float32
            # 计算频率
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # 合并频率
            emb = torch.cat((freqs, freqs), dim=-1)
            # 计算余弦和正弦部分,并应用注意力缩放
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        # 返回结果,转换为输入的数据类型
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

# 将张量的一半隐藏维度进行旋转,是RoPE核心操作之一
def rotate_half(x):
    """旋转输入张量的一半隐藏维度。"""
    # 将输入张量在最后一个维度上分成两半
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    # 返回旋转后的结果:(-x2, x1)
    return torch.cat((-x2, x1), dim=-1)

# 将旋转位置编码应用到查询(Q)和键(K)张量上
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    将旋转位置编码应用到查询和键张量上。

    参数:
        q (`torch.Tensor`): 查询张量。
        k (`torch.Tensor`): 键张量。
        cos (`torch.Tensor`): 旋转嵌入的余弦部分。
        sin (`torch.Tensor`): 旋转嵌入的正弦部分。
        position_ids (`torch.Tensor`, *可选*):
            已弃用且未使用。
        unsqueeze_dim (`int`, *可选*, 默认为1):
            'unsqueeze_dim'参数指定沿着哪个维度对cos[position_ids]和sin[position_ids]进行扩展,
            以便它们可以正确地广播到q和k的维度上。例如,注意cos[position_ids]和sin[position_ids]
            具有形状[batch_size, seq_len, head_dim]。然后,如果q和k具有形状
            [batch_size, heads, seq_len, head_dim],则设置unsqueeze_dim=1使
            cos[position_ids]和sin[position_ids]可广播到q和k的形状。
            同样,如果q和k具有形状[batch_size, seq_len, heads, head_dim],则设置unsqueeze_dim=2。
    返回:
        `tuple(torch.Tensor)` 包含使用旋转位置嵌入旋转后的查询和键张量。
    """
    # 扩展维度以便广播
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    # 应用旋转变换
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# LlamaMLP类 - 实现了LLaMA模型中的前馈神经网络部分
# 使用SwiGLU激活函数,这是一种门控线性单元变体
# 包含三个投影层:gate_proj、up_proj和down_proj
# 通过并行计算gate_proj和up_proj来提高效率
class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # gate_proj:将隐藏状态投影到中间维度的门控投影层
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        # up_proj:将隐藏状态投影到中间维度的上投影层
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        # down_proj:将中间表示投影回隐藏维度的下投影层
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        # 激活函数,默认为SiLU(Sigmoid Linear Unit)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # SwiGLU激活:先计算gate_proj经过激活函数的结果,再与up_proj的结果相乘
        # 然后通过down_proj投影回原始维度
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


# repeat_kv函数 - 实现了分组查询注意力(GQA)中的键值头复制操作
# 将键值头的数量从num_key_value_heads扩展到num_attention_heads
# 通过复制每个键值头以匹配注意力头的数量,实现了计算效率和内存使用的平衡
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    这相当于torch.repeat_interleave(x, dim=1, repeats=n_rep)。隐藏状态从(batch,
    num_key_value_heads, seqlen, head_dim)变为(batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    # 如果复制次数为1,直接返回原始状态
    if n_rep == 1:
        return hidden_states
    # 在原始形状中插入新维度并扩展,然后重塑以得到所需形状
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# eager_attention_forward函数 - 实现了标准的注意力前向传播计算
# 用于不使用CUDA优化(如FlashAttention)时的注意力计算
# 包含经典的注意力计算步骤:矩阵乘法、缩放、掩码、softmax和dropout
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # 应用GQA:复制键和值状态以匹配查询头的数量
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    # 计算注意力分数:查询和键的矩阵乘法,并应用缩放
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    # 如果提供了注意力掩码,将其应用到注意力分数上
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # 对注意力分数应用softmax得到注意力权重,使用float32以提高数值稳定性
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    # 在训练时应用dropout
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    # 使用注意力权重对值进行加权汇总
    attn_output = torch.matmul(attn_weights, value_states)
    # 调整输出的维度顺序并确保内存连续
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


# LlamaAttention类 - 实现了LLaMA模型中的多头注意力机制
# 支持分组查询注意力(GQA),可以减少内存使用同时保持性能
# 包含旋转位置编码(RoPE)的应用
# 支持KV缓存以加速自回归生成
class LlamaAttention(nn.Module):
    """来自'Attention Is All You Need'论文的多头注意力机制"""

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # 获取注意力头维度,如果未指定则从隐藏大小和头数计算
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # 每个键值头对应的查询头数量
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        # 缩放因子,用于缩放注意力分数
        self.scaling = self.head_dim**-0.5
        # 注意力dropout率
        self.attention_dropout = config.attention_dropout
        # 标记这是因果注意力(使用因果掩码)
        self.is_causal = True

        # 查询投影:将隐藏状态投影到查询表示空间
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        # 键投影:将隐藏状态投影到键表示空间
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        # 值投影:将隐藏状态投影到值表示空间
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        # 输出投影:将多头注意力的输出投影回原始隐藏维度
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # 记录输入形状,用于后续重塑
        input_shape = hidden_states.shape[:-1]
        # 计算多头形状
        hidden_shape = (*input_shape, -1, self.head_dim)

        # 计算查询、键、值状态并重塑为多头形式
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # 解包位置嵌入的余弦和正弦部分
        cos, sin = position_embeddings
        # 应用旋转位置编码到查询和键状态
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # 如果有过去的键值缓存,更新当前的键值状态
        if past_key_value is not None:
            # sin和cos特定于RoPE模型;静态缓存需要cache_position
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # 默认使用eager注意力计算接口
        attention_interface: Callable = eager_attention_forward

        # 根据配置选择不同的注意力实现
        if self.config._attn_implementation != "eager":
            # sdpa不支持输出注意力权重时,回退到eager实现
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                # 使用配置指定的注意力实现
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        # 调用选定的注意力接口计算注意力输出和权重
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # 将注意力输出重塑回原始形状
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        # 通过输出投影层变换注意力输出
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

# LlamaDecoderLayer类 - 实现了LLaMA模型中的单个Transformer解码器层
# 继承自GradientCheckpointingLayer,支持梯度检查点以节省内存
# 采用Pre-LayerNorm架构,即在注意力和FFN前应用层归一化
# 使用残差连接,保持信息流通并缓解梯度消失问题
class LlamaDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # 自注意力层
        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)

        # 前馈神经网络层
        self.mlp = LlamaMLP(config)
        # 输入层归一化,应用于自注意力前
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # 注意力后层归一化,应用于MLP前
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # 必要的,但为了向后兼容而保留在这里
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # 保存残差连接用的隐藏状态
        residual = hidden_states
        # 应用输入层归一化
        hidden_states = self.input_layernorm(hidden_states)

        # 自注意力计算
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        # 第一个残差连接:原始输入 + 注意力输出
        hidden_states = residual + hidden_states

        # 前馈神经网络(全连接层)
        # 保存第二个残差连接用的隐藏状态
        residual = hidden_states
        # 应用注意力后层归一化
        hidden_states = self.post_attention_layernorm(hidden_states)
        # 应用MLP
        hidden_states = self.mlp(hidden_states)
        # 第二个残差连接:注意力输出 + MLP输出
        hidden_states = residual + hidden_states

        # 准备返回值
        outputs = (hidden_states,)
        # 如果需要返回注意力权重,将其添加到输出元组
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


# LLAMA模型文档字符串开始
LLAMA_START_DOCSTRING = r"""
    该模型继承自[`PreTrainedModel`]。查看超类文档了解库为所有模型实现的通用方法
    (如下载或保存、调整输入嵌入大小、剪枝头等)

    该模型也是PyTorch的[torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)子类。
    将其作为常规PyTorch模块使用,并参考PyTorch文档了解与一般用法和行为相关的所有事项。

    参数:
        config ([`LlamaConfig`]):
            包含模型所有参数的模型配置类。使用配置文件初始化只加载模型的配置,
            而不加载与模型关联的权重。查看[`~PreTrainedModel.from_pretrained`]方法
            来加载模型权重。
"""


# 添加开始文档字符串
@add_start_docstrings(
    "裸LLaMA模型,输出原始隐藏状态,顶部没有任何特定的头。",
    LLAMA_START_DOCSTRING,
)
# LlamaPreTrainedModel类 - 所有LLAMA模型变体的基础类
# 定义了各种模型功能和支持的特性
# 实现了权重初始化等通用方法
class LlamaPreTrainedModel(PreTrainedModel):
    # 设置配置类
    config_class = LlamaConfig
    # 基础模型前缀
    base_model_prefix = "model"
    # 支持梯度检查点
    supports_gradient_checkpointing = True
    # 指定不拆分的模块
    _no_split_modules = ["LlamaDecoderLayer"]
    # 设备放置时跳过的键
    _skip_keys_device_placement = ["past_key_values"]
    # 支持Flash Attention 2优化
    _supports_flash_attn_2 = True
    # 支持缩放点积注意力(SDPA)优化
    _supports_sdpa = True
    # 支持灵活注意力优化
    _supports_flex_attn = True
    # 支持缓存类
    _supports_cache_class = True
    # 支持量化缓存
    _supports_quantized_cache = True
    # 支持静态缓存
    _supports_static_cache = True
    # 支持注意力后端
    _supports_attention_backend = True

    def _init_weights(self, module):
        # 初始化权重的标准差
        std = self.config.initializer_range
        # 线性层权重初始化
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        # 嵌入层权重初始化
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        # RMSNorm层权重初始化为1.0
        elif isinstance(module, LlamaRMSNorm):
            module.weight.data.fill_(1.0)


# LLAMA输入文档字符串
LLAMA_INPUTS_DOCSTRING = r"""
    参数:
        input_ids (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`):
            输入序列标记在词汇表中的索引。如果提供了填充,将默认忽略。

            可以使用[`AutoTokenizer`]获取索引。查看[`PreTrainedTokenizer.encode`]和
            [`PreTrainedTokenizer.__call__`]了解详情。

            [什么是输入ID?](../glossary#input-ids)
        attention_mask (`torch.Tensor` 形状为 `(batch_size, sequence_length) 或 `BlockMask`, *可选*):
            用于避免在填充标记索引上执行注意力的掩码。掩码值选择为`[0, 1]`:

            - 1表示**未被掩盖的**标记,
            - 0表示**被掩盖的**标记。

            如果模型配置为使用flex_attention,它将尝试将掩码Tensor转换为BlockMask,
            但您也可以直接在此处传递`BlockMask`对象。

            [什么是注意力掩码?](../glossary#attention-mask)

            可以使用[`AutoTokenizer`]获取索引。查看[`PreTrainedTokenizer.encode`]和
            [`PreTrainedTokenizer.__call__`]了解详情。

            如果使用了`past_key_values`,可选地只需要输入最后的`input_ids`(参见
            `past_key_values`)。

            如果您想更改填充行为,应阅读[`modeling_opt._prepare_decoder_attention_mask`]
            并根据您的需求进行修改。有关默认策略的更多信息,请参见[论文](https://arxiv.org/abs/1910.13461)中的图1。

            - 1表示头部**未被掩盖**,
            - 0表示头部**被掩盖**。
        position_ids (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`, *可选*):
            每个输入序列标记在位置嵌入中的索引。在范围`[0, config.n_positions - 1]`中选择。

            [什么是位置ID?](../glossary#position-ids)
        past_key_values (`Cache`, *可选*):
            预先计算的隐藏状态(自注意力块和交叉注意力块中的键和值),可用于加速顺序解码。
            这通常由模型在之前阶段的解码中返回的`past_key_values`组成,当`use_cache=True`或
            `config.use_cache=True`时。

            它是一个[`~cache_utils.Cache`]实例。有关更多详细信息,请参见我们的[kv缓存指南](https://huggingface.co/docs/transformers/en/kv_cache)。

            如果使用了`past_key_values`,用户可以选择只输入最后的`input_ids`(那些没有给这个模型的过去键值状态的输入),
            形状为`(batch_size, 1)`,而不是所有的`input_ids`,形状为`(batch_size, sequence_length)`。
        inputs_embeds (`torch.FloatTensor` 形状为 `(batch_size, sequence_length, hidden_size)`, *可选*):
            可选地,您可以选择直接传递嵌入表示,而不是传递`input_ids`。这在您想要更多地控制
            如何将`input_ids`索引转换为相关向量比模型的内部嵌入查找矩阵更有用。
        use_cache (`bool`, *可选*):
            如果设置为`True`,将返回`past_key_values`键值状态,可用于加速解码(参见
            `past_key_values`)。
        output_attentions (`bool`, *可选*):
            是否返回所有注意力层的注意力张量。有关更多详细信息,请参见`attentions`下的返回
            张量。
        output_hidden_states (`bool`, *可选*):
            是否返回所有层的隐藏状态。有关更多详细信息,请参见`hidden_states`下的返回
            张量。
        return_dict (`bool`, *可选*):
            是否返回[`~utils.ModelOutput`]而不是普通元组。
        cache_position (`torch.LongTensor` 形状为 `(sequence_length)`, *可选*):
            描述输入序列标记在序列中位置的索引。与`position_ids`不同,这个张量不受填充影响。
            它用于在正确的位置更新缓存并推断完整的序列长度。
"""


# 添加开始文档字符串
@add_start_docstrings(
    "裸LLaMA模型,输出原始隐藏状态,顶部没有任何特定的头。",
    LLAMA_START_DOCSTRING,
)
# LlamaModel类 - 实现了LLaMA的核心模型结构
# 由多个LlamaDecoderLayer层堆叠而成,应用RMSNorm和RoPE位置编码
# 实现了高效的注意力机制,支持KV缓存以加速推理
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer解码器,由*config.num_hidden_layers*层组成。每一层都是一个[`LlamaDecoderLayer`]

    参数:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # 设置填充索引和词汇表大小
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # 词嵌入层:将输入token转换为隐藏表示
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # 创建解码器层堆栈
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # 最终的层归一化
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # 旋转位置编码
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        # 梯度检查点标志,默认关闭
        self.gradient_checkpointing = False

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        # 设置默认值,优先使用传入参数,其次使用配置
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # 检查输入格式:input_ids和inputs_embeds必须且只能提供一个
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # 梯度检查点与缓存不兼容,训练时警告并禁用缓存
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # 检查过去键值类型兼容性
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        # 如果没有提供嵌入,通过词嵌入层获取
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # 如果启用缓存但未提供past_key_values,创建动态缓存
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # 处理缓存位置:确定序列中每个token的位置
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # 如果没有提供位置ID,从缓存位置创建
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # 更新因果掩码:确保自回归属性,即当前token只能看到过去的token
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # 设置初始隐藏状态为输入嵌入
        hidden_states = inputs_embeds

        # 创建位置嵌入,在所有解码器层之间共享
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # 解码器层处理
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # 遍历所有解码器层
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            # 如果需要输出所有隐藏状态,保存当前隐藏状态
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # 前向传播通过当前解码器层
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            # 更新隐藏状态为当前层的输出
            hidden_states = layer_outputs[0]

            # 如果需要输出注意力权重,保存当前层的注意力权重
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # 最终的层归一化
        hidden_states = self.norm(hidden_states)

        # 添加最后一个解码器层的隐藏状态
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # 返回模型输出
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """
        更新因果掩码:根据不同的注意力实现方式处理掩码
        处理不同的注意力实现(Flash Attention 2、Flex Attention、SDPA等)
        """
        # Flash Attention 2的特殊处理
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None

        # Flex Attention的特殊处理
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # 对于SDPA(缩放点积注意力),尽可能依赖其`is_causal`参数而非`attn_mask`参数
        # 这样可以调度到Flash Attention 2。此功能与静态缓存不兼容。
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # 当输出注意力为True时,sdpa实现的forward方法调用eager实现的forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        # 获取输入张量的数据类型和设备
        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]

        # 确定目标长度
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # 生成4D因果掩码
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        # SDPA的特殊处理:对于完全掩码的行,允许注意所有token
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # 注意完全掩码行中的所有token,例如使用左填充时的相关第一行
            # 这是F.scaled_dot_product_attention内存高效注意力路径所需的
            # 详情:https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        创建一个形状为`(batch_size, 1, query_length, key_value_length)`的因果4D掩码,
        从形状为`(batch_size, key_value_length)`的2D掩码创建,
        或如果输入的`attention_mask`已经是4D,则不做任何处理。

        参数:
            attention_mask (`torch.Tensor`):
                形状为`(batch_size, key_value_length)`的2D注意力掩码,
                或形状为`(batch_size, 1, query_length, key_value_length)`的4D注意力掩码。
            sequence_length (`int`):
                正在处理的序列长度。
            target_length (`int`):
                目标长度:使用静态缓存生成时,掩码应与静态缓存一样长,
                以考虑0填充、尚未填充的缓存部分。
            dtype (`torch.dtype`):
                用于4D注意力掩码的数据类型。
            device (`torch.device`):
                放置4D注意力掩码的设备。
            cache_position (`torch.Tensor`):
                描述输入序列标记在序列中位置的索引。
            batch_size (`torch.Tensor`):
                批次大小。
        """
        # 如果掩码已经是4D,直接使用
        if attention_mask is not None and attention_mask.dim() == 4:
            # 在这种情况下,我们假设掩码已经以反转形式提供,不需要反转或切片
            causal_mask = attention_mask
        else:
            # 创建填充有最小值的因果掩码
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            # 如果序列长度不是1,创建上三角掩码(对角线以上为1)
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # 应用缓存位置掩码
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            # 扩展掩码维度以匹配批次大小
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

            # 如果提供了注意力掩码,合并掩码
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # 复制到连续内存以进行原地编辑
                mask_length = attention_mask.shape[-1]
                # 合并因果掩码和注意力掩码
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                # 应用掩码
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


# 组合FlashAttentionKwargs和LossKwargs为因果语言模型创建一个新的参数类
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


# LlamaForCausalLM类 - 用于因果语言建模的LLaMA模型
# 在LlamaModel基础上添加了语言建模头,用于生成任务
# 支持高效的文本生成,包含损失计算和推理优化
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
    # 定义与主干模型共享权重的参数
    _tied_weights_keys = ["lm_head.weight"]
    # 张量并行计划:lm_head按列方向划分并复制
    _tp_plan = {"lm_head": "colwise_rep"}
    # 流水线并行计划:lm_head接收hidden_states并输出logits
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        # 创建基础模型
        self.model = LlamaModel(config)
        # 设置词汇表大小
        self.vocab_size = config.vocab_size
        # 创建语言模型头:将隐藏状态映射到词汇表大小的logits
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """获取输出嵌入层(lm_head)"""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """设置输出嵌入层(lm_head)"""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """设置解码器模型"""
        self.model = decoder

    def get_decoder(self):
        """获取解码器模型"""
        return self.model

    @can_return_tuple
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
            labels (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`, *可选*):
                用于计算掩码语言模型损失的标签。索引应该在`[0, ..., config.vocab_size]`范围内或为-100
                (参见`input_ids`文档)。设置为`-100`的索引将被忽略(掩码),
                损失仅针对标签在`[0, ..., config.vocab_size]`范围内的token计算。

            logits_to_keep (`int` 或 `torch.Tensor`, *可选*):
                如果是`int`,则为最后`logits_to_keep`个token计算logits。如果为`0`,则为所有`input_ids`计算logits(特殊情况)。
                在生成过程中,通常只需要最后一个token的logits,为该token单独计算可以节省内存,
                对于长序列或大词汇表尤为显著。
                如果是`torch.Tensor`,必须是1D张量,对应于在序列长度维度上要保留的索引。
                这在使用打包张量格式(批次和序列长度的单一维度)时很有用。

        返回:

        示例:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # 生成
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # 设置默认值
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # 通过解码器模型获取特征
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        # 获取最后的隐藏状态
        hidden_states = outputs.last_hidden_state
        # 只计算必要的logits,如果不计算损失则不将其转换为浮点类型
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        # 如果提供了标签,计算损失
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        # 返回结果
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# 添加开始文档字符串
@add_start_docstrings(
    """
    带有序列分类头的LLaMa模型transformer(线性层)。

    [`LlamaForSequenceClassification`]使用最后一个token进行分类,就像其他因果模型
    (例如GPT-2)一样。

    由于它对最后一个token进行分类,所以需要知道最后一个token的位置。如果配置中定义了
    `pad_token_id`,它会在每行中找到不是填充token的最后一个token。如果
    没有定义`pad_token_id`,它简单地取每行中的最后一个值。由于当传递`inputs_embeds`
    而不是`input_ids`时它无法猜测填充token,所以它会做同样的事情(取每行中的最后一个值)。
    """,
    LLAMA_START_DOCSTRING,
)
# LlamaForSequenceClassification类 - 用于序列分类任务的LLaMA模型
# 在LlamaModel基础上添加了分类头,用于序列级别的预测
# 通过获取序列中最后一个非填充token的表示进行分类
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # 设置标签数量
        self.num_labels = config.num_labels
        # 创建基础模型
        self.model = LlamaModel(config)
        # 创建分类得分头
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.model.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` 形状为 `(batch_size,)`, *可选*):
            用于计算序列分类/回归损失的标签。索引应该在`[0, ..., config.num_labels - 1]`范围内。
            如果`config.num_labels == 1`,则计算回归损失(均方损失);
            如果`config.num_labels > 1`,则计算分类损失(交叉熵)。
        """

        # 通过Transformer模型获取输出
        transformer_outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # 获取最后的隐藏状态
        hidden_states = transformer_outputs.last_hidden_state
        # 计算所有位置的分类得分
        logits = self.score(hidden_states)

        # 确定批次大小
        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # 处理填充token:找到每个序列中最后一个非填充token的位置
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("如果没有定义填充token,则无法处理批次大小 > 1。")
        if self.config.pad_token_id is None:
            # 如果没有定义填充token,使用最后一个位置
            last_non_pad_token = -1
        elif input_ids is not None:
            # 处理左填充和右填充,我们取不等于pad_token_id的最右边的token
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            # 找到最后一个非填充token的位置
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            # 如果没有提供input_ids,使用最后一个位置
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} 将不会在 `inputs_embeds` 中检测填充token。如果结合使用填充token和 "
                "`inputs_embeds`,结果可能会出乎意料。"
            )

        # 提取每个序列中最后一个非填充token位置的logits
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        # 如果提供了标签,计算损失
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        # 返回结果
        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


# 添加开始文档字符串
@add_start_docstrings(
    """
带有跨度分类头的Llama模型transformer,用于抽取式问答任务,如
SQuAD(在隐藏状态输出之上的线性层,用于计算`跨度开始logits`和`跨度结束logits`)。
    """,
    LLAMA_START_DOCSTRING,
)
# LlamaForQuestionAnswering类 - 用于抽取式问答任务的LLaMA模型
# 添加了跨度分类头,用于预测答案在文本中的开始和结束位置
# 适用于SQuAD等数据集上的问答任务
class LlamaForQuestionAnswering(LlamaPreTrainedModel):
    base_model_prefix = "transformer"

    # 从transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__复制,将Bloom替换为Llama
    def __init__(self, config):
        super().__init__(config)
        # 创建基础模型
        self.transformer = LlamaModel(config)
        # 创建问答输出层:将隐藏状态映射到两个输出(开始和结束位置)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.transformer.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> QuestionAnsweringModelOutput:
        r"""
        start_positions (`torch.LongTensor` 形状为 `(batch_size,)`, *可选*):
            标记跨度开始位置(索引)的标签,用于计算token分类损失。
            位置会被限制在序列长度内(`sequence_length`)。超出序列外的位置
            不会被用于计算损失。
        end_positions (`torch.LongTensor` 形状为 `(batch_size,)`, *可选*):
            标记跨度结束位置(索引)的标签,用于计算token分类损失。
            位置会被限制在序列长度内(`sequence_length`)。超出序列外的位置
            不会被用于计算损失。
        """

        # 通过Transformer模型获取输出
        outputs: BaseModelOutputWithPast = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # 获取序列输出
        sequence_output = outputs.last_hidden_state

        # 通过问答输出层计算logits
        logits = self.qa_outputs(sequence_output)
        # 将logits分割为开始和结束logits
        start_logits, end_logits = logits.split(1, dim=-1)
        # 压缩维度并确保张量连续
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        # 如果提供了开始和结束位置,计算损失
        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        # 返回问答模型输出
        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# 添加开始文档字符串
@add_start_docstrings(
    """
    带有token分类头的Llama模型transformer(在隐藏状态输出之上的线性层),
    例如用于命名实体识别(NER)任务。
    """,
    LLAMA_START_DOCSTRING,
)
# LlamaForTokenClassification类 - 用于token级别分类任务的LLaMA模型
# 添加了token分类头,用于对每个token进行分类
# 适用于命名实体识别(NER)、词性标注(POS)等任务
class LlamaForTokenClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # 设置标签数量
        self.num_labels = config.num_labels
        # 创建基础模型
        self.model = LlamaModel(config)

        # 确定分类器的dropout率
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        # 创建dropout层
        self.dropout = nn.Dropout(classifier_dropout)
        # 创建分类得分层
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        """获取输入嵌入层"""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """设置输入嵌入层"""
        self.model.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` 形状为 `(batch_size,)`, *可选*):
            用于计算序列分类/回归损失的标签。索引应该在`[0, ..., config.num_labels - 1]`范围内。
            如果`config.num_labels == 1`,则计算回归损失(均方损失);
            如果`config.num_labels > 1`,则计算分类损失(交叉熵)。
        """

        # 通过基础模型获取输出
        outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # 获取序列输出
        sequence_output = outputs.last_hidden_state
        # 应用dropout
        sequence_output = self.dropout(sequence_output)
        # 计算每个token的分类logits
        logits = self.score(sequence_output)

        # 如果提供了标签,计算损失
        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        # 返回token分类器输出
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# 定义模块中所有公开的类
__all__ = [
    "LlamaForCausalLM",
    "LlamaModel",
    "LlamaPreTrainedModel",
    "LlamaForSequenceClassification",
    "LlamaForQuestionAnswering",
    "LlamaForTokenClassification",
]
相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
OpenCV Jetson Breakpoint Vmess Template Paper ModelScope Conda FP8 XML Permission SVR git-lfs Pytorch Git Pickle LLM SAM CAM Llama IndexTTS2 Linux CUDA NameSilo LLAMA Knowledge Food uWSGI ONNX Freesound 继承 FP64 COCO Docker Hungarian Distillation FlashAttention InvalidArgumentError Color TSV C++ TensorRT CTC Animate Vim Translation Heatmap TensorFlow Bipartite Attention RGB 签证 VSCode Cloudreve Bitcoin transformers Zip PIP 报税 diffusers mmap CSV 算法题 Github SQLite WAN PyCharm Jupyter Review Excel Pandas Paddle v2ray YOLO Bert LeetCode scipy CEIR API Video Tensor Password 公式 域名 WebCrawler Land EXCEL 腾讯云 Firewall Ptyhon NLTK Baidu CV Proxy GPT4 Google Michelin Gemma LoRA Windows Dataset v0.dev Sklearn Magnet Disk printf Nginx LaTeX XGBoost Anaconda Tracking RAR tar TTS QWEN JSON 飞书 Clash BeautifulSoup PDB Qwen2 SPIE GoogLeNet FP16 OpenAI Datetime Interview VGG-16 ChatGPT torchinfo Tiktoken DeepStream Mixtral Website Plotly git BTC llama.cpp GPTQ 搞笑 MD5 BF16 Math HaggingFace Image2Text AI PyTorch Claude FP32 Web Shortcut Qwen2.5 Django HuggingFace Statistics CLAP Bin ResNet-50 Streamlit Random Input SQL Ubuntu DeepSeek Markdown Crawler 财报 UI GIT 多线程 Quantize Qwen Hilton Base64 GGML Card Quantization Domain Algorithm NLP Python uwsgi Logo 版权 FastAPI logger Pillow Use 音频 Miniforge Diagram OCR 证件照 Numpy Plate 净利润 VPN 阿里云 hf Augmentation 关于博主 CC PDF Safetensors 多进程 Transformers tqdm Hotel UNIX Data
站点统计

本站现有博文311篇,共被浏览742117

本站已经建立2381天!

热门文章
文章归档
回到顶部