Transformers Llama 模型代码中文注释 modeling_llama.py
作者:XD / 发表: 2025年4月23日 04:50 / 更新: 2025年4月23日 04:50 / 编程笔记 / 阅读量:34
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# 此代码基于EleutherAI的GPT-NeoX库以及此库中的GPT-NeoX和OPT实现。
# 它已从原始形式修改,以适应与Meta AI团队训练模型时使用的GPT-NeoX和OPT相比的
# 微小架构差异。
#
# 根据Apache许可证2.0版("许可证")获得许可;
# 除非符合许可证,否则不得使用此文件。
# 您可以在以下位置获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则依据许可证分发的软件
# 是基于"按原样"分发的,没有任何明示或暗示的担保或条件。
# 有关许可证下特定语言的权限和限制,请参阅许可证。
# 中文代码注释 XD
# 源码请参考 https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
from typing import Callable, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
)
from .configuration_llama import LlamaConfig
# 检查FlexAttention是否可用,如果可用则导入相关模块
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
# 导入用于从hub加载内核前向传递的工具
from ...integrations import use_kernel_forward_from_hub
logger = logging.get_logger(__name__)
# 文档生成用的检查点和配置
_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
_CONFIG_FOR_DOC = "LlamaConfig"
# LlamaRMSNorm类 - 这是LLaMA模型使用的一种特殊归一化层,等同于T5LayerNorm
# 与传统的LayerNorm不同,RMSNorm只使用均方根进行归一化,不涉及均值中心化
# 计算流程:先保存输入类型,转为float32进行高精度计算,然后进行归一化,最后转回原始数据类型
@use_kernel_forward_from_hub("RMSNorm")
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm等同于T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
# 保存输入的数据类型
input_dtype = hidden_states.dtype
# 转换为float32进行计算
hidden_states = hidden_states.to(torch.float32)
# 计算方差
variance = hidden_states.pow(2).mean(-1, keepdim=True)
# 归一化操作:除以方差的平方根
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# 应用权重并转回原始数据类型
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
# 为打印模块信息提供额外的表示
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
# 将LlamaRMSNorm添加到所有层归一化层列表中
ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
# LlamaRotaryEmbedding类 - 实现了旋转位置编码(RoPE)
# 支持多种RoPE变体,通过config中的rope_type参数指定
# 处理序列长度扩展问题,允许模型处理比训练时更长的序列
# 包含注意力缩放机制,用于优化不同长度序列的表示
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, config: LlamaConfig, device=None):
"""
LLaMA模型使用的旋转位置编码(RoPE)实现
"""
super().__init__()
# 向后兼容:处理rope_type参数
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
# 缓存的最大序列长度
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
# 根据RoPE类型获取初始化函数
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
# 初始化频率反转和注意力缩放
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # 高级用户:用于高级RoPE类型(例如dynamic rope)
def forward(self, x, position_ids):
# 扩展反转频率到批次大小
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
# 确定设备类型,特殊处理MPS设备
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # 强制使用float32
# 计算频率
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
# 合并频率
emb = torch.cat((freqs, freqs), dim=-1)
# 计算余弦和正弦部分,并应用注意力缩放
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
# 返回结果,转换为输入的数据类型
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# 将张量的一半隐藏维度进行旋转,是RoPE核心操作之一
def rotate_half(x):
"""旋转输入张量的一半隐藏维度。"""
# 将输入张量在最后一个维度上分成两半
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
# 返回旋转后的结果:(-x2, x1)
return torch.cat((-x2, x1), dim=-1)
# 将旋转位置编码应用到查询(Q)和键(K)张量上
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""
将旋转位置编码应用到查询和键张量上。
参数:
q (`torch.Tensor`): 查询张量。
k (`torch.Tensor`): 键张量。
cos (`torch.Tensor`): 旋转嵌入的余弦部分。
sin (`torch.Tensor`): 旋转嵌入的正弦部分。
position_ids (`torch.Tensor`, *可选*):
已弃用且未使用。
unsqueeze_dim (`int`, *可选*, 默认为1):
'unsqueeze_dim'参数指定沿着哪个维度对cos[position_ids]和sin[position_ids]进行扩展,
以便它们可以正确地广播到q和k的维度上。例如,注意cos[position_ids]和sin[position_ids]
具有形状[batch_size, seq_len, head_dim]。然后,如果q和k具有形状
[batch_size, heads, seq_len, head_dim],则设置unsqueeze_dim=1使
cos[position_ids]和sin[position_ids]可广播到q和k的形状。
同样,如果q和k具有形状[batch_size, seq_len, heads, head_dim],则设置unsqueeze_dim=2。
返回:
`tuple(torch.Tensor)` 包含使用旋转位置嵌入旋转后的查询和键张量。
"""
# 扩展维度以便广播
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
# 应用旋转变换
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# LlamaMLP类 - 实现了LLaMA模型中的前馈神经网络部分
# 使用SwiGLU激活函数,这是一种门控线性单元变体
# 包含三个投影层:gate_proj、up_proj和down_proj
# 通过并行计算gate_proj和up_proj来提高效率
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
# gate_proj:将隐藏状态投影到中间维度的门控投影层
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
# up_proj:将隐藏状态投影到中间维度的上投影层
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
# down_proj:将中间表示投影回隐藏维度的下投影层
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
# 激活函数,默认为SiLU(Sigmoid Linear Unit)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
# SwiGLU激活:先计算gate_proj经过激活函数的结果,再与up_proj的结果相乘
# 然后通过down_proj投影回原始维度
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
# repeat_kv函数 - 实现了分组查询注意力(GQA)中的键值头复制操作
# 将键值头的数量从num_key_value_heads扩展到num_attention_heads
# 通过复制每个键值头以匹配注意力头的数量,实现了计算效率和内存使用的平衡
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
这相当于torch.repeat_interleave(x, dim=1, repeats=n_rep)。隐藏状态从(batch,
num_key_value_heads, seqlen, head_dim)变为(batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
# 如果复制次数为1,直接返回原始状态
if n_rep == 1:
return hidden_states
# 在原始形状中插入新维度并扩展,然后重塑以得到所需形状
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
# eager_attention_forward函数 - 实现了标准的注意力前向传播计算
# 用于不使用CUDA优化(如FlashAttention)时的注意力计算
# 包含经典的注意力计算步骤:矩阵乘法、缩放、掩码、softmax和dropout
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
# 应用GQA:复制键和值状态以匹配查询头的数量
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
# 计算注意力分数:查询和键的矩阵乘法,并应用缩放
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
# 如果提供了注意力掩码,将其应用到注意力分数上
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# 对注意力分数应用softmax得到注意力权重,使用float32以提高数值稳定性
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
# 在训练时应用dropout
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
# 使用注意力权重对值进行加权汇总
attn_output = torch.matmul(attn_weights, value_states)
# 调整输出的维度顺序并确保内存连续
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
# LlamaAttention类 - 实现了LLaMA模型中的多头注意力机制
# 支持分组查询注意力(GQA),可以减少内存使用同时保持性能
# 包含旋转位置编码(RoPE)的应用
# 支持KV缓存以加速自回归生成
class LlamaAttention(nn.Module):
"""来自'Attention Is All You Need'论文的多头注意力机制"""
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
# 获取注意力头维度,如果未指定则从隐藏大小和头数计算
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
# 每个键值头对应的查询头数量
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
# 缩放因子,用于缩放注意力分数
self.scaling = self.head_dim**-0.5
# 注意力dropout率
self.attention_dropout = config.attention_dropout
# 标记这是因果注意力(使用因果掩码)
self.is_causal = True
# 查询投影:将隐藏状态投影到查询表示空间
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
# 键投影:将隐藏状态投影到键表示空间
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
# 值投影:将隐藏状态投影到值表示空间
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
# 输出投影:将多头注意力的输出投影回原始隐藏维度
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# 记录输入形状,用于后续重塑
input_shape = hidden_states.shape[:-1]
# 计算多头形状
hidden_shape = (*input_shape, -1, self.head_dim)
# 计算查询、键、值状态并重塑为多头形式
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
# 解包位置嵌入的余弦和正弦部分
cos, sin = position_embeddings
# 应用旋转位置编码到查询和键状态
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# 如果有过去的键值缓存,更新当前的键值状态
if past_key_value is not None:
# sin和cos特定于RoPE模型;静态缓存需要cache_position
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# 默认使用eager注意力计算接口
attention_interface: Callable = eager_attention_forward
# 根据配置选择不同的注意力实现
if self.config._attn_implementation != "eager":
# sdpa不支持输出注意力权重时,回退到eager实现
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
# 使用配置指定的注意力实现
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
# 调用选定的注意力接口计算注意力输出和权重
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
# 将注意力输出重塑回原始形状
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
# 通过输出投影层变换注意力输出
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
# LlamaDecoderLayer类 - 实现了LLaMA模型中的单个Transformer解码器层
# 继承自GradientCheckpointingLayer,支持梯度检查点以节省内存
# 采用Pre-LayerNorm架构,即在注意力和FFN前应用层归一化
# 使用残差连接,保持信息流通并缓解梯度消失问题
class LlamaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
# 自注意力层
self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
# 前馈神经网络层
self.mlp = LlamaMLP(config)
# 输入层归一化,应用于自注意力前
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# 注意力后层归一化,应用于MLP前
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # 必要的,但为了向后兼容而保留在这里
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
# 保存残差连接用的隐藏状态
residual = hidden_states
# 应用输入层归一化
hidden_states = self.input_layernorm(hidden_states)
# 自注意力计算
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
# 第一个残差连接:原始输入 + 注意力输出
hidden_states = residual + hidden_states
# 前馈神经网络(全连接层)
# 保存第二个残差连接用的隐藏状态
residual = hidden_states
# 应用注意力后层归一化
hidden_states = self.post_attention_layernorm(hidden_states)
# 应用MLP
hidden_states = self.mlp(hidden_states)
# 第二个残差连接:注意力输出 + MLP输出
hidden_states = residual + hidden_states
# 准备返回值
outputs = (hidden_states,)
# 如果需要返回注意力权重,将其添加到输出元组
if output_attentions:
outputs += (self_attn_weights,)
return outputs
# LLAMA模型文档字符串开始
LLAMA_START_DOCSTRING = r"""
该模型继承自[`PreTrainedModel`]。查看超类文档了解库为所有模型实现的通用方法
(如下载或保存、调整输入嵌入大小、剪枝头等)
该模型也是PyTorch的[torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)子类。
将其作为常规PyTorch模块使用,并参考PyTorch文档了解与一般用法和行为相关的所有事项。
参数:
config ([`LlamaConfig`]):
包含模型所有参数的模型配置类。使用配置文件初始化只加载模型的配置,
而不加载与模型关联的权重。查看[`~PreTrainedModel.from_pretrained`]方法
来加载模型权重。
"""
# 添加开始文档字符串
@add_start_docstrings(
"裸LLaMA模型,输出原始隐藏状态,顶部没有任何特定的头。",
LLAMA_START_DOCSTRING,
)
# LlamaPreTrainedModel类 - 所有LLAMA模型变体的基础类
# 定义了各种模型功能和支持的特性
# 实现了权重初始化等通用方法
class LlamaPreTrainedModel(PreTrainedModel):
# 设置配置类
config_class = LlamaConfig
# 基础模型前缀
base_model_prefix = "model"
# 支持梯度检查点
supports_gradient_checkpointing = True
# 指定不拆分的模块
_no_split_modules = ["LlamaDecoderLayer"]
# 设备放置时跳过的键
_skip_keys_device_placement = ["past_key_values"]
# 支持Flash Attention 2优化
_supports_flash_attn_2 = True
# 支持缩放点积注意力(SDPA)优化
_supports_sdpa = True
# 支持灵活注意力优化
_supports_flex_attn = True
# 支持缓存类
_supports_cache_class = True
# 支持量化缓存
_supports_quantized_cache = True
# 支持静态缓存
_supports_static_cache = True
# 支持注意力后端
_supports_attention_backend = True
def _init_weights(self, module):
# 初始化权重的标准差
std = self.config.initializer_range
# 线性层权重初始化
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
# 嵌入层权重初始化
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
# RMSNorm层权重初始化为1.0
elif isinstance(module, LlamaRMSNorm):
module.weight.data.fill_(1.0)
# LLAMA输入文档字符串
LLAMA_INPUTS_DOCSTRING = r"""
参数:
input_ids (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`):
输入序列标记在词汇表中的索引。如果提供了填充,将默认忽略。
可以使用[`AutoTokenizer`]获取索引。查看[`PreTrainedTokenizer.encode`]和
[`PreTrainedTokenizer.__call__`]了解详情。
[什么是输入ID?](../glossary#input-ids)
attention_mask (`torch.Tensor` 形状为 `(batch_size, sequence_length) 或 `BlockMask`, *可选*):
用于避免在填充标记索引上执行注意力的掩码。掩码值选择为`[0, 1]`:
- 1表示**未被掩盖的**标记,
- 0表示**被掩盖的**标记。
如果模型配置为使用flex_attention,它将尝试将掩码Tensor转换为BlockMask,
但您也可以直接在此处传递`BlockMask`对象。
[什么是注意力掩码?](../glossary#attention-mask)
可以使用[`AutoTokenizer`]获取索引。查看[`PreTrainedTokenizer.encode`]和
[`PreTrainedTokenizer.__call__`]了解详情。
如果使用了`past_key_values`,可选地只需要输入最后的`input_ids`(参见
`past_key_values`)。
如果您想更改填充行为,应阅读[`modeling_opt._prepare_decoder_attention_mask`]
并根据您的需求进行修改。有关默认策略的更多信息,请参见[论文](https://arxiv.org/abs/1910.13461)中的图1。
- 1表示头部**未被掩盖**,
- 0表示头部**被掩盖**。
position_ids (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`, *可选*):
每个输入序列标记在位置嵌入中的索引。在范围`[0, config.n_positions - 1]`中选择。
[什么是位置ID?](../glossary#position-ids)
past_key_values (`Cache`, *可选*):
预先计算的隐藏状态(自注意力块和交叉注意力块中的键和值),可用于加速顺序解码。
这通常由模型在之前阶段的解码中返回的`past_key_values`组成,当`use_cache=True`或
`config.use_cache=True`时。
它是一个[`~cache_utils.Cache`]实例。有关更多详细信息,请参见我们的[kv缓存指南](https://huggingface.co/docs/transformers/en/kv_cache)。
如果使用了`past_key_values`,用户可以选择只输入最后的`input_ids`(那些没有给这个模型的过去键值状态的输入),
形状为`(batch_size, 1)`,而不是所有的`input_ids`,形状为`(batch_size, sequence_length)`。
inputs_embeds (`torch.FloatTensor` 形状为 `(batch_size, sequence_length, hidden_size)`, *可选*):
可选地,您可以选择直接传递嵌入表示,而不是传递`input_ids`。这在您想要更多地控制
如何将`input_ids`索引转换为相关向量比模型的内部嵌入查找矩阵更有用。
use_cache (`bool`, *可选*):
如果设置为`True`,将返回`past_key_values`键值状态,可用于加速解码(参见
`past_key_values`)。
output_attentions (`bool`, *可选*):
是否返回所有注意力层的注意力张量。有关更多详细信息,请参见`attentions`下的返回
张量。
output_hidden_states (`bool`, *可选*):
是否返回所有层的隐藏状态。有关更多详细信息,请参见`hidden_states`下的返回
张量。
return_dict (`bool`, *可选*):
是否返回[`~utils.ModelOutput`]而不是普通元组。
cache_position (`torch.LongTensor` 形状为 `(sequence_length)`, *可选*):
描述输入序列标记在序列中位置的索引。与`position_ids`不同,这个张量不受填充影响。
它用于在正确的位置更新缓存并推断完整的序列长度。
"""
# 添加开始文档字符串
@add_start_docstrings(
"裸LLaMA模型,输出原始隐藏状态,顶部没有任何特定的头。",
LLAMA_START_DOCSTRING,
)
# LlamaModel类 - 实现了LLaMA的核心模型结构
# 由多个LlamaDecoderLayer层堆叠而成,应用RMSNorm和RoPE位置编码
# 实现了高效的注意力机制,支持KV缓存以加速推理
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer解码器,由*config.num_hidden_layers*层组成。每一层都是一个[`LlamaDecoderLayer`]
参数:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
# 设置填充索引和词汇表大小
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
# 词嵌入层:将输入token转换为隐藏表示
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
# 创建解码器层堆栈
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
# 最终的层归一化
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# 旋转位置编码
self.rotary_emb = LlamaRotaryEmbedding(config=config)
# 梯度检查点标志,默认关闭
self.gradient_checkpointing = False
# 初始化权重并应用最终处理
self.post_init()
def get_input_embeddings(self):
"""获取输入嵌入层"""
return self.embed_tokens
def set_input_embeddings(self, value):
"""设置输入嵌入层"""
self.embed_tokens = value
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
# 设置默认值,优先使用传入参数,其次使用配置
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
# 检查输入格式:input_ids和inputs_embeds必须且只能提供一个
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
# 梯度检查点与缓存不兼容,训练时警告并禁用缓存
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
# 检查过去键值类型兼容性
if not isinstance(past_key_values, (type(None), Cache)):
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
# 如果没有提供嵌入,通过词嵌入层获取
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# 如果启用缓存但未提供past_key_values,创建动态缓存
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
# 处理缓存位置:确定序列中每个token的位置
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
# 如果没有提供位置ID,从缓存位置创建
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# 更新因果掩码:确保自回归属性,即当前token只能看到过去的token
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# 设置初始隐藏状态为输入嵌入
hidden_states = inputs_embeds
# 创建位置嵌入,在所有解码器层之间共享
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# 解码器层处理
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
# 遍历所有解码器层
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
# 如果需要输出所有隐藏状态,保存当前隐藏状态
if output_hidden_states:
all_hidden_states += (hidden_states,)
# 前向传播通过当前解码器层
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
# 更新隐藏状态为当前层的输出
hidden_states = layer_outputs[0]
# 如果需要输出注意力权重,保存当前层的注意力权重
if output_attentions:
all_self_attns += (layer_outputs[1],)
# 最终的层归一化
hidden_states = self.norm(hidden_states)
# 添加最后一个解码器层的隐藏状态
if output_hidden_states:
all_hidden_states += (hidden_states,)
# 返回模型输出
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
"""
更新因果掩码:根据不同的注意力实现方式处理掩码
处理不同的注意力实现(Flash Attention 2、Flex Attention、SDPA等)
"""
# Flash Attention 2的特殊处理
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
# Flex Attention的特殊处理
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# 对于SDPA(缩放点积注意力),尽可能依赖其`is_causal`参数而非`attn_mask`参数
# 这样可以调度到Flash Attention 2。此功能与静态缓存不兼容。
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# 当输出注意力为True时,sdpa实现的forward方法调用eager实现的forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
# 获取输入张量的数据类型和设备
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
# 确定目标长度
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# 生成4D因果掩码
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
# SDPA的特殊处理:对于完全掩码的行,允许注意所有token
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# 注意完全掩码行中的所有token,例如使用左填充时的相关第一行
# 这是F.scaled_dot_product_attention内存高效注意力路径所需的
# 详情:https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
创建一个形状为`(batch_size, 1, query_length, key_value_length)`的因果4D掩码,
从形状为`(batch_size, key_value_length)`的2D掩码创建,
或如果输入的`attention_mask`已经是4D,则不做任何处理。
参数:
attention_mask (`torch.Tensor`):
形状为`(batch_size, key_value_length)`的2D注意力掩码,
或形状为`(batch_size, 1, query_length, key_value_length)`的4D注意力掩码。
sequence_length (`int`):
正在处理的序列长度。
target_length (`int`):
目标长度:使用静态缓存生成时,掩码应与静态缓存一样长,
以考虑0填充、尚未填充的缓存部分。
dtype (`torch.dtype`):
用于4D注意力掩码的数据类型。
device (`torch.device`):
放置4D注意力掩码的设备。
cache_position (`torch.Tensor`):
描述输入序列标记在序列中位置的索引。
batch_size (`torch.Tensor`):
批次大小。
"""
# 如果掩码已经是4D,直接使用
if attention_mask is not None and attention_mask.dim() == 4:
# 在这种情况下,我们假设掩码已经以反转形式提供,不需要反转或切片
causal_mask = attention_mask
else:
# 创建填充有最小值的因果掩码
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
# 如果序列长度不是1,创建上三角掩码(对角线以上为1)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
# 应用缓存位置掩码
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
# 扩展掩码维度以匹配批次大小
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
# 如果提供了注意力掩码,合并掩码
if attention_mask is not None:
causal_mask = causal_mask.clone() # 复制到连续内存以进行原地编辑
mask_length = attention_mask.shape[-1]
# 合并因果掩码和注意力掩码
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
# 应用掩码
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
# 组合FlashAttentionKwargs和LossKwargs为因果语言模型创建一个新的参数类
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
# LlamaForCausalLM类 - 用于因果语言建模的LLaMA模型
# 在LlamaModel基础上添加了语言建模头,用于生成任务
# 支持高效的文本生成,包含损失计算和推理优化
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
# 定义与主干模型共享权重的参数
_tied_weights_keys = ["lm_head.weight"]
# 张量并行计划:lm_head按列方向划分并复制
_tp_plan = {"lm_head": "colwise_rep"}
# 流水线并行计划:lm_head接收hidden_states并输出logits
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config):
super().__init__(config)
# 创建基础模型
self.model = LlamaModel(config)
# 设置词汇表大小
self.vocab_size = config.vocab_size
# 创建语言模型头:将隐藏状态映射到词汇表大小的logits
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# 初始化权重并应用最终处理
self.post_init()
def get_input_embeddings(self):
"""获取输入嵌入层"""
return self.model.embed_tokens
def set_input_embeddings(self, value):
"""设置输入嵌入层"""
self.model.embed_tokens = value
def get_output_embeddings(self):
"""获取输出嵌入层(lm_head)"""
return self.lm_head
def set_output_embeddings(self, new_embeddings):
"""设置输出嵌入层(lm_head)"""
self.lm_head = new_embeddings
def set_decoder(self, decoder):
"""设置解码器模型"""
self.model = decoder
def get_decoder(self):
"""获取解码器模型"""
return self.model
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` 形状为 `(batch_size, sequence_length)`, *可选*):
用于计算掩码语言模型损失的标签。索引应该在`[0, ..., config.vocab_size]`范围内或为-100
(参见`input_ids`文档)。设置为`-100`的索引将被忽略(掩码),
损失仅针对标签在`[0, ..., config.vocab_size]`范围内的token计算。
logits_to_keep (`int` 或 `torch.Tensor`, *可选*):
如果是`int`,则为最后`logits_to_keep`个token计算logits。如果为`0`,则为所有`input_ids`计算logits(特殊情况)。
在生成过程中,通常只需要最后一个token的logits,为该token单独计算可以节省内存,
对于长序列或大词汇表尤为显著。
如果是`torch.Tensor`,必须是1D张量,对应于在序列长度维度上要保留的索引。
这在使用打包张量格式(批次和序列长度的单一维度)时很有用。
返回:
示例:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # 生成
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
# 设置默认值
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# 通过解码器模型获取特征
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
# 获取最后的隐藏状态
hidden_states = outputs.last_hidden_state
# 只计算必要的logits,如果不计算损失则不将其转换为浮点类型
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
# 如果提供了标签,计算损失
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
# 返回结果
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# 添加开始文档字符串
@add_start_docstrings(
"""
带有序列分类头的LLaMa模型transformer(线性层)。
[`LlamaForSequenceClassification`]使用最后一个token进行分类,就像其他因果模型
(例如GPT-2)一样。
由于它对最后一个token进行分类,所以需要知道最后一个token的位置。如果配置中定义了
`pad_token_id`,它会在每行中找到不是填充token的最后一个token。如果
没有定义`pad_token_id`,它简单地取每行中的最后一个值。由于当传递`inputs_embeds`
而不是`input_ids`时它无法猜测填充token,所以它会做同样的事情(取每行中的最后一个值)。
""",
LLAMA_START_DOCSTRING,
)
# LlamaForSequenceClassification类 - 用于序列分类任务的LLaMA模型
# 在LlamaModel基础上添加了分类头,用于序列级别的预测
# 通过获取序列中最后一个非填充token的表示进行分类
class LlamaForSequenceClassification(LlamaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
# 设置标签数量
self.num_labels = config.num_labels
# 创建基础模型
self.model = LlamaModel(config)
# 创建分类得分头
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# 初始化权重并应用最终处理
self.post_init()
def get_input_embeddings(self):
"""获取输入嵌入层"""
return self.model.embed_tokens
def set_input_embeddings(self, value):
"""设置输入嵌入层"""
self.model.embed_tokens = value
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> SequenceClassifierOutputWithPast:
r"""
labels (`torch.LongTensor` 形状为 `(batch_size,)`, *可选*):
用于计算序列分类/回归损失的标签。索引应该在`[0, ..., config.num_labels - 1]`范围内。
如果`config.num_labels == 1`,则计算回归损失(均方损失);
如果`config.num_labels > 1`,则计算分类损失(交叉熵)。
"""
# 通过Transformer模型获取输出
transformer_outputs: BaseModelOutputWithPast = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# 获取最后的隐藏状态
hidden_states = transformer_outputs.last_hidden_state
# 计算所有位置的分类得分
logits = self.score(hidden_states)
# 确定批次大小
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
# 处理填充token:找到每个序列中最后一个非填充token的位置
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("如果没有定义填充token,则无法处理批次大小 > 1。")
if self.config.pad_token_id is None:
# 如果没有定义填充token,使用最后一个位置
last_non_pad_token = -1
elif input_ids is not None:
# 处理左填充和右填充,我们取不等于pad_token_id的最右边的token
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
# 找到最后一个非填充token的位置
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
# 如果没有提供input_ids,使用最后一个位置
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} 将不会在 `inputs_embeds` 中检测填充token。如果结合使用填充token和 "
"`inputs_embeds`,结果可能会出乎意料。"
)
# 提取每个序列中最后一个非填充token位置的logits
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
# 如果提供了标签,计算损失
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
# 返回结果
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
# 添加开始文档字符串
@add_start_docstrings(
"""
带有跨度分类头的Llama模型transformer,用于抽取式问答任务,如
SQuAD(在隐藏状态输出之上的线性层,用于计算`跨度开始logits`和`跨度结束logits`)。
""",
LLAMA_START_DOCSTRING,
)
# LlamaForQuestionAnswering类 - 用于抽取式问答任务的LLaMA模型
# 添加了跨度分类头,用于预测答案在文本中的开始和结束位置
# 适用于SQuAD等数据集上的问答任务
class LlamaForQuestionAnswering(LlamaPreTrainedModel):
base_model_prefix = "transformer"
# 从transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__复制,将Bloom替换为Llama
def __init__(self, config):
super().__init__(config)
# 创建基础模型
self.transformer = LlamaModel(config)
# 创建问答输出层:将隐藏状态映射到两个输出(开始和结束位置)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
# 初始化权重并应用最终处理
self.post_init()
def get_input_embeddings(self):
"""获取输入嵌入层"""
return self.transformer.embed_tokens
def set_input_embeddings(self, value):
"""设置输入嵌入层"""
self.transformer.embed_tokens = value
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> QuestionAnsweringModelOutput:
r"""
start_positions (`torch.LongTensor` 形状为 `(batch_size,)`, *可选*):
标记跨度开始位置(索引)的标签,用于计算token分类损失。
位置会被限制在序列长度内(`sequence_length`)。超出序列外的位置
不会被用于计算损失。
end_positions (`torch.LongTensor` 形状为 `(batch_size,)`, *可选*):
标记跨度结束位置(索引)的标签,用于计算token分类损失。
位置会被限制在序列长度内(`sequence_length`)。超出序列外的位置
不会被用于计算损失。
"""
# 通过Transformer模型获取输出
outputs: BaseModelOutputWithPast = self.transformer(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# 获取序列输出
sequence_output = outputs.last_hidden_state
# 通过问答输出层计算logits
logits = self.qa_outputs(sequence_output)
# 将logits分割为开始和结束logits
start_logits, end_logits = logits.split(1, dim=-1)
# 压缩维度并确保张量连续
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
# 如果提供了开始和结束位置,计算损失
loss = None
if start_positions is not None and end_positions is not None:
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
# 返回问答模型输出
return QuestionAnsweringModelOutput(
loss=loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# 添加开始文档字符串
@add_start_docstrings(
"""
带有token分类头的Llama模型transformer(在隐藏状态输出之上的线性层),
例如用于命名实体识别(NER)任务。
""",
LLAMA_START_DOCSTRING,
)
# LlamaForTokenClassification类 - 用于token级别分类任务的LLaMA模型
# 添加了token分类头,用于对每个token进行分类
# 适用于命名实体识别(NER)、词性标注(POS)等任务
class LlamaForTokenClassification(LlamaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
# 设置标签数量
self.num_labels = config.num_labels
# 创建基础模型
self.model = LlamaModel(config)
# 确定分类器的dropout率
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
# 创建dropout层
self.dropout = nn.Dropout(classifier_dropout)
# 创建分类得分层
self.score = nn.Linear(config.hidden_size, config.num_labels)
# 初始化权重并应用最终处理
self.post_init()
def get_input_embeddings(self):
"""获取输入嵌入层"""
return self.model.embed_tokens
def set_input_embeddings(self, value):
"""设置输入嵌入层"""
self.model.embed_tokens = value
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> TokenClassifierOutput:
r"""
labels (`torch.LongTensor` 形状为 `(batch_size,)`, *可选*):
用于计算序列分类/回归损失的标签。索引应该在`[0, ..., config.num_labels - 1]`范围内。
如果`config.num_labels == 1`,则计算回归损失(均方损失);
如果`config.num_labels > 1`,则计算分类损失(交叉熵)。
"""
# 通过基础模型获取输出
outputs: BaseModelOutputWithPast = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# 获取序列输出
sequence_output = outputs.last_hidden_state
# 应用dropout
sequence_output = self.dropout(sequence_output)
# 计算每个token的分类logits
logits = self.score(sequence_output)
# 如果提供了标签,计算损失
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.config)
# 返回token分类器输出
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# 定义模块中所有公开的类
__all__ = [
"LlamaForCausalLM",
"LlamaModel",
"LlamaPreTrainedModel",
"LlamaForSequenceClassification",
"LlamaForQuestionAnswering",
"LlamaForTokenClassification",
]
相关标签