# EADST
#
# Qwen-7B-Chat模型结构注释 (Qwen-7B-Chat model structure, annotated)

# 版权所有 (c) Alibaba Cloud.
# 本源代码根据根目录中的LICENSE文件的许可证进行许可。
# 注释作者: eadst
# 创建日期: 2023-10-31
# 版本: v1.0.0

import copy  # 导入copy模块,用于复制对象
import importlib  # 导入importlib模块,用于动态导入模块
import math  # 导入math模块,提供数学运算函数
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator  # 导入类型检查和类型注解相关的模块和类

import torch  # 导入PyTorch库
import torch.nn.functional as F  # 导入PyTorch的nn.functional模块,提供神经网络相关的函数
import torch.utils.checkpoint  # 导入PyTorch的checkpoint模块,用于实现梯度检查点
from torch.cuda.amp import autocast  # 导入PyTorch的autocast模块,用于自动混合精度训练

from torch.nn import CrossEntropyLoss  # 导入PyTorch的CrossEntropyLoss类,用于计算交叉熵损失
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList  # 导入transformers库的相关类
from transformers.generation.logits_process import LogitsProcessorList  # 导入transformers库的LogitsProcessorList类

if TYPE_CHECKING:  # 如果启用了类型检查
    from transformers.generation.streamers import BaseStreamer  # 导入BaseStreamer类
from transformers.generation.utils import GenerateOutput  # 导入GenerateOutput类
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)  # 导入BaseModelOutputWithPast和CausalLMOutputWithPast类
from transformers.modeling_utils import PreTrainedModel  # 导入PreTrainedModel类
from transformers.utils import logging  # 导入transformers库的logging模块

try:
    from einops import rearrange  # 尝试导入einops库的rearrange函数
except ImportError:  # 如果导入失败
    rearrange = None  # 将rearrange设置为None
from torch import nn  # 导入PyTorch的nn模块,提供神经网络相关的类和函数

# Hardware / framework capability probes, evaluated once at import time:
# - SUPPORT_CUDA: a CUDA device is available.
# - SUPPORT_BF16: CUDA is available and reports bfloat16 support.
# - SUPPORT_FP16: CUDA is available and GPU 0 has compute capability major >= 7.
# - SUPPORT_TORCH2: the installed PyTorch major version is 2 or newer.
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = torch.cuda.is_bf16_supported() if SUPPORT_CUDA else False
SUPPORT_FP16 = torch.cuda.get_device_capability(0)[0] >= 7 if SUPPORT_CUDA else False
SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".", 1)[0]) >= 2

# 导入自定义的配置和工具函数
from .configuration_qwen import QWenConfig
from .qwen_generation_utils import (
    HistoryType,
    make_context,
    decode_tokens,
    get_stop_words_ids,
    StopWordsLogitsProcessor,
)

# Module-level logger (transformers' logging wrapper, imported above as `logging`).
logger = logging.get_logger(__name__)

# Constants for generated documentation.
# NOTE(review): presumably consumed by transformers' docstring helpers elsewhere in the file.
_CHECKPOINT_FOR_DOC = "qwen"
_CONFIG_FOR_DOC = "QWenConfig"

# Names of the pretrained checkpoints associated with this model family.
QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]

# Bilingual (English/Chinese) error text for chatting with a model whose
# generation_config.chat_format is not "chatml", i.e. probably the base model.
_ERROR_BAD_CHAT_FORMAT = """\
We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
"""

# Unique marker object (a fresh object() is equal only to itself); its uses are
# outside this part of the file.
_SENTINEL = object()

# Bilingual error text for the deprecated `stream` argument of model.chat().
_ERROR_STREAM_IN_CHAT = """\
Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
"""

# Bilingual error text for when flash attention is enabled but the computation
# is running on the CPU.
_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED = """\
We detect you have activated flash attention support, but running model computation on CPU. Please make sure that your input data has been placed on GPU. If you actually want to run CPU computation, please following the readme and set device_map="cpu" to disable flash attention when loading the model (calling AutoModelForCausalLM.from_pretrained).
检测到您的模型已激活了flash attention支持,但正在执行CPU运算任务。如使用flash attention,请您确认模型输入已经传到GPU上。如果您确认要执行CPU运算,请您在载入模型(调用AutoModelForCausalLM.from_pretrained)时,按照readme说法,指定device_map="cpu"以禁用flash attention。
"""

# Optional FlashAttention kernels. Each stays None unless _import_flash_attn()
# manages to import the corresponding flash_attn component.
apply_rotary_emb_func = None  # flash_attn rotary-embedding kernel
rms_norm = None  # flash_attn fused RMSNorm kernel
flash_attn_unpadded_func = None  # flash_attn variable-length attention kernel

# Optional FlashAttention loader.
def _import_flash_attn():
    """Try to import optional FlashAttention kernels into the module globals.

    Sets ``apply_rotary_emb_func``, ``rms_norm`` and ``flash_attn_unpadded_func``
    when the corresponding flash_attn component imports successfully. The three
    imports are independent. On failure the affected global is left unchanged and
    a warning with the install link is logged.
    """
    global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
    try:
        from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
        apply_rotary_emb_func = __apply_rotary_emb_func
    except ImportError:
        # Logger.warn is a deprecated alias; use Logger.warning.
        logger.warning(
            "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
            "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
        )

    try:
        from flash_attn.ops.rms_norm import rms_norm as __rms_norm
        rms_norm = __rms_norm
    except ImportError:
        logger.warning(
            "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
            "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
        )

    try:
        import flash_attn
        # flash_attn >= 2 exposes the varlen kernel as flash_attn_varlen_func.
        # A module without __version__ is treated like 1.x (old name), as before.
        major_version = int(getattr(flash_attn, "__version__", "1").split(".")[0])
        if major_version >= 2:
            from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
        else:
            from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
        flash_attn_unpadded_func = __flash_attn_unpadded_func
    except ImportError:
        logger.warning(
            "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
            "https://github.com/Dao-AILab/flash-attention"
        )

# KV-cache quantization (uint8, asymmetric min/max).
def quantize_cache_v(fdata, bits, qmax, qmin):
    """Quantize a 4-D cache tensor to uint8 with a per-slice scale and zero point.

    Args:
        fdata: 4-D float tensor. Within this file it is called with key/value
            permuted to (batch, heads, seq, head_dim). The min/max statistics are
            taken per (dim0, dim1) slice over all remaining elements.
        bits: kept for interface compatibility but not used. The output dtype is
            always uint8, so ``qmin``/``qmax`` should describe the uint8 range.
        qmax, qmin: scalar tensors bounding the quantized range. They are moved
            to ``fdata``'s device if needed.

    Returns:
        ``(qdata, scale, zero)``:
        - ``qdata`` is a uint8 tensor shaped like ``fdata``.
        - ``scale`` and ``zero`` have shape (shape[0], shape[1], shape[2], 1).
        Dequantize with ``scale * (qdata - zero)``.
    """
    qtype = torch.uint8
    device = fdata.device
    shape = fdata.shape

    # Range per (dim0, dim1) slice over everything after the first two dims.
    fdata_cal = torch.flatten(fdata, 2)
    fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
    fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)
    if qmax.device != fmax.device:
        qmax = qmax.to(device)
        qmin = qmin.to(device)
    scale = (fmax - fmin) / (qmax - qmin)
    # A constant slice has fmax == fmin, hence scale == 0.
    # fmin / scale and fdata / scale would then be inf/NaN, and the dequantized
    # cache would be NaN. With scale 1 the constant maps to qmin and dequantizes
    # back exactly.
    scale = scale.masked_fill(scale == 0, 1)
    zero = qmin - fmin / scale
    # Broadcast the per-slice parameters along dim 2.
    scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
    zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
    res_data = fdata / scale + zero
    # NOTE: .to(uint8) truncates rather than rounds (kept for compatibility).
    qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
    return qdata.contiguous(), scale, zero

# KV-cache dequantization (inverse of the uint8 quantizer above).
def dequantize_cache_torch(qdata, scale, zero):
    """Map quantized values back to floats: ``scale * (qdata - zero)``.

    ``scale`` and ``zero`` broadcast against ``qdata``.
    """
    shifted = qdata - zero
    return scale * shifted

# Self-attention backed by FlashAttention's variable-length kernel.
class FlashSelfAttention(torch.nn.Module):
    """Attention via ``flash_attn_unpadded_func`` on packed (tokens, heads, head_dim) inputs.

    ``forward`` takes q/k/v as (batch, seq, heads, head_dim), packs them to
    (batch * seq, heads, head_dim) and drops padded key positions when a mask is
    given. It returns the output reshaped back to (batch, seq_q, heads, head_dim).
    """

    def __init__(
        self,
        causal=False,  # use a causal mask (applied unconditionally while training)
        softmax_scale=None,  # passed through to the kernel; None means the kernel default
        attention_dropout=0.0,  # dropout probability, only applied while training
    ):
        super().__init__()
        assert flash_attn_unpadded_func is not None, (
            "Please install FlashAttention first, " "e.g., with pip install flash-attn"
        )
        assert (
            rearrange is not None
        ), "Please install einops first, e.g., with pip install einops"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def unpad_input(self, hidden_states, attention_mask):
        """Drop padded positions from packed ``hidden_states``.

        Args:
            hidden_states: packed tensor of shape (batch * seq, ...).
            attention_mask: additive mask whose two singleton dims 1 and 2 are
                squeezed away, giving (batch, seq). Positions equal to 0 are valid.

        Returns:
            ``(hidden_states[valid], indices, cu_seqlens, max_seqlen)``.
            ``cu_seqlens`` is the int32 prefix sum of per-sequence lengths with a
            leading 0.
        """
        valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0)
        seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        hidden_states = hidden_states[indices]
        return hidden_states, indices, cu_seqlens, max_seqlen_in_batch

    def pad_input(self, hidden_states, indices, batch, seqlen):
        """Inverse of ``unpad_input``: scatter rows back and reshape to (batch, seqlen, ...).

        Padded positions are filled with zeros.
        """
        output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device,
                             dtype=hidden_states.dtype)
        output[indices] = hidden_states
        # `output` is freshly allocated and contiguous, so a view suffices.
        return output.view(batch, seqlen, *output.shape[1:])

    def forward(self, q, k, v, attention_mask=None):
        """Run FlashAttention.

        q, k and v are (batch, seq, heads, head_dim) CUDA tensors in fp16/bf16.
        The result has shape (batch, seq_q, heads, head_dim).
        """
        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
        assert all((i.is_cuda for i in (q, k, v)))
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]
        seqlen_out = seqlen_q

        q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q.device,
        )

        if batch_size > 1 and attention_mask is not None:
            # Remove padded keys/values. During prefill (q and v have the same
            # packed length) the queries are unpadded the same way.
            k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
            if q.size(0) == v.size(0):
                q = q[indices_k]
                cu_seqlens_q = cu_seqlens_k
                seqlen_q = seqlen_k
            v = v[indices_k]
        else:
            cu_seqlens_k = torch.arange(
                0,
                (batch_size + 1) * seqlen_k,
                step=seqlen_k,
                dtype=torch.int32,
                device=q.device,
            )

        if self.training:
            assert seqlen_k == seqlen_q
            is_causal = self.causal
            dropout_p = self.dropout_p
        else:
            # Incremental decoding (seqlen_q != seqlen_k) attends to the whole cache.
            is_causal = seqlen_q == seqlen_k
            dropout_p = 0

        output = flash_attn_unpadded_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seqlen_q,
            seqlen_k,
            dropout_p,
            softmax_scale=self.softmax_scale,
            causal=is_causal,
        )
        if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k:
            output = self.pad_input(output, indices_k, batch_size, seqlen_out)
        else:
            new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
            output = output.view(new_shape)
        return output

class QWenAttention(nn.Module):
    """Multi-head self-attention for QWen.

    Supports:
    - rotary position embeddings;
    - logn query scaling at inference;
    - an optional uint8-quantized KV cache;
    - three attention backends: FlashAttention, PyTorch 2 SDPA, or an eager fallback.
    """

    def __init__(self, config):
        super().__init__()

        # Non-persistent -1e4 buffer. NOTE(review): not referenced by this class's methods.
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        # Sequence length used as the logarithm base for logn scaling.
        self.seq_length = config.seq_length

        self.hidden_size = config.hidden_size
        # Width of each query/key/value chunk produced by c_attn (see forward()).
        self.split_size = config.hidden_size

        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads

        self.use_flash_attn = config.use_flash_attn
        self.scale_attn_weights = True

        self.projection_size = config.kv_channels * config.num_attention_heads

        assert self.projection_size % config.num_attention_heads == 0
        self.hidden_size_per_attention_head = (
            self.projection_size // config.num_attention_heads
        )

        # Fused projection producing query, key and value in one matmul.
        self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)

        # Output projection.
        # NOTE(review): declared as hidden_size -> projection_size. It is fed the
        # merged heads (num_heads * head_dim), so it assumes the two sizes are equal.
        self.c_proj = nn.Linear(
            config.hidden_size, self.projection_size, bias=not config.no_bias
        )

        # Flash attention is only used for half-precision models.
        self.is_fp32 = not (config.bf16 or config.fp16)
        if (
            self.use_flash_attn
            and flash_attn_unpadded_func is not None
            and not self.is_fp32
        ):
            self.core_attention_flash = FlashSelfAttention(
                causal=True, attention_dropout=config.attn_dropout_prob
            )
        self.bf16 = config.bf16

        self.use_dynamic_ntk = config.use_dynamic_ntk
        self.use_logn_attn = config.use_logn_attn

        # logn scaling factors: log_{seq_length}(pos) beyond the trained length, 1 otherwise.
        logn_list = [
            math.log(i, self.seq_length) if i > self.seq_length else 1
            for i in range(1, 32768)
        ]
        logn_tensor = torch.tensor(logn_list)[None, :, None, None]
        self.register_buffer("logn_tensor", logn_tensor, persistent=False)

        self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
        # Optional flags; configs that predate them default to False.
        self.softmax_in_fp32 = getattr(config, 'softmax_in_fp32', False)
        self.use_cache_quantization = getattr(config, 'use_cache_quantization', False)
        self.use_cache_kernel = getattr(config, 'use_cache_kernel', False)

        # Dtype of the quantization range bounds follows the model precision.
        cache_dtype = torch.float
        if self.bf16:
            cache_dtype = torch.bfloat16
        elif config.fp16:
            cache_dtype = torch.float16
        self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
        self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)

        # Optional custom kernels for the quantized KV cache. The import must sit
        # inside the try so a missing/unbuilt extension falls back to the PyTorch
        # dequantize path instead of crashing. The attribute is always set so
        # _attn() can test it.
        self.cache_kernels = None
        if self.use_cache_quantization and self.use_cache_kernel:
            try:
                from .cpp_kernels import cache_autogptq_cuda_256
                self.cache_kernels = cache_autogptq_cuda_256
            except ImportError:
                import warnings
                warnings.warn("Failed to import KV cache kernels; falling back to the PyTorch implementation.")

    def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
        """Eager causal attention.

        query/key/value arrive as (batch, heads, seq, head_dim), as prepared by
        forward(). With KV-cache quantization, ``key`` and ``value`` are instead
        ``(qdata, scale, zero)`` triples from ``quantize_cache_v``.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has dims 1 and 2
            swapped back, giving (batch, seq, heads, head_dim).
        """
        device = query.device
        if self.use_cache_quantization:
            qk, qk_scale, qk_zero = key
            if self.use_cache_kernel and self.cache_kernels is not None:
                # Custom kernel; presumably writes the query-key scores into
                # attn_weights in place (fp16 inputs).
                shape = query.shape[:-1] + (qk.shape[-2],)
                attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
                self.cache_kernels.vecquant8matmul_batched_faster_old(
                    query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
                    qk.transpose(-1, -2).contiguous(),
                    attn_weights,
                    qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(),
                    qk_zero.contiguous() if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous())
            else:
                # No kernel: dequantize the keys and use a regular matmul.
                key = dequantize_cache_torch(qk, qk_scale, qk_zero)
                attn_weights = torch.matmul(query, key.transpose(-1, -2))
        else:
            attn_weights = torch.matmul(query, key.transpose(-1, -2))

        # Scale scores by 1/sqrt(head_dim).
        if self.scale_attn_weights:
            if self.use_cache_quantization:
                size_temp = value[0].size(-1)
            else:
                size_temp = value.size(-1)
            attn_weights = attn_weights / torch.full(
                [],
                size_temp ** 0.5,
                dtype=attn_weights.dtype,
                device=attn_weights.device,
            )

        # Select the causal-mask rows for the newest query_length positions.
        if self.use_cache_quantization:
            query_length, key_length = query.size(-2), key[0].size(-2)
        else:
            query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = registered_causal_mask[
            :, :, key_length - query_length : key_length, :key_length
        ]
        mask_value = torch.finfo(attn_weights.dtype).min
        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
            attn_weights.device
        )
        attn_weights = torch.where(
            causal_mask, attn_weights.to(attn_weights.dtype), mask_value
        )

        # Additive padding mask, if any.
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        if self.softmax_in_fp32:
            attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
        else:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        attn_weights = attn_weights.type(query.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        if self.use_cache_quantization:
            qv, qv_scale, qv_zero = value
            if self.use_cache_kernel and self.cache_kernels is not None:
                shape = attn_weights.shape[:-1] + (query.shape[-1],)
                attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
                self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
                    attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
                    qv.contiguous(),  # NOTE(review): quantize_cache_v produces uint8 data here, not int32
                    attn_output,
                    qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(),
                    qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous())
                if attn_output.dtype != query.dtype:
                    attn_output = attn_output.to(query.dtype)
                    attn_weights = attn_weights.to(query.dtype)
            else:
                value = dequantize_cache_torch(qv, qv_scale, qv_zero)
                attn_output = torch.matmul(attn_weights, value)
        else:
            attn_output = torch.matmul(attn_weights, value)

        attn_output = attn_output.transpose(1, 2)

        return attn_output, attn_weights

    def _upcast_and_reordered_attn(
        self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None
    ):
        """Attention variant that computes scores and softmax in float32.

        NOTE(review): not called from the code visible here. Unlike _attn it does
        not transpose its output and does not support the quantized KV cache.
        """
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        attn_weights = torch.empty(
            bsz * num_heads,
            q_seq_len,
            k_seq_len,
            dtype=torch.float32,
            device=query.device,
        )

        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        # Disable autocast so the batched matmul really runs in float32.
        with autocast(enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
                -1, dk, k_seq_len
            )
            # beta=0: the uninitialized contents of attn_weights are ignored.
            attn_weights = torch.baddbmm(
                attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
            )
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = registered_causal_mask[
            :, :, key_length - query_length : key_length, :key_length
        ]
        mask_value = torch.finfo(attn_weights.dtype).min
        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
            attn_weights.device
        )
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if attn_weights.dtype != torch.float32:
            raise RuntimeError(
                "Error with upcasting, attn_weights does not have dtype torch.float32"
            )
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """Reshape (..., num_heads * attn_head_size) to (..., num_heads, attn_head_size)."""
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """Inverse of _split_heads: flatten the last two dims into one."""
        tensor = tensor.contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
        registered_causal_mask: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ):
        """Self-attention over ``hidden_states`` (presumably (batch, seq, hidden)).

        Returns ``(attn_output, present)``, extended with the attention weights
        when ``output_attentions`` is true.
        - ``present`` is the (possibly quantized) (key, value) cache when
          ``use_cache`` is true, otherwise None.
        - ``encoder_hidden_states`` and ``encoder_attention_mask`` are accepted
          for API compatibility but unused.
        """
        mixed_x_layer = self.c_attn(hidden_states)

        query, key, value = mixed_x_layer.split(self.split_size, dim=2)

        # (batch, seq, heads, head_dim)
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if rotary_pos_emb_list is not None:
            cur_len = query.shape[1]
            if len(rotary_pos_emb_list) == 1:
                rotary_pos_emb = rotary_pos_emb_list[0]
                # Keep only the positions of the current tokens.
                rotary_pos_emb = [emb[:, -cur_len:, :, :] for emb in rotary_pos_emb]
                rotary_pos_emb = (rotary_pos_emb,) * 2
                q_pos_emb, k_pos_emb = rotary_pos_emb
                query = apply_rotary_pos_emb(query, q_pos_emb)
                key = apply_rotary_pos_emb(key, k_pos_emb)
            else:
                # One rotary embedding per batch element.
                query_list = []
                key_list = []
                for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
                    rotary_pos_emb = [emb[:, -cur_len:, :, :] for emb in rotary_pos_emb]
                    rotary_pos_emb = (rotary_pos_emb,) * 2
                    q_pos_emb, k_pos_emb = rotary_pos_emb
                    query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
                    key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
                query = torch.cat(query_list, dim=0)
                key = torch.cat(key_list, dim=0)

        # Quantized cache entries are (qdata, scale, zero) triples in (batch, heads, seq, ...) layout.
        if self.use_cache_quantization:
            key = quantize_cache_v(key.permute(0, 2, 1, 3),
                                   bits=8,
                                   qmin=self.cache_qmin,
                                   qmax=self.cache_qmax)
            value = quantize_cache_v(value.permute(0, 2, 1, 3),
                                     bits=8,
                                     qmin=self.cache_qmin,
                                     qmax=self.cache_qmax)

        # Append to the cached keys/values from earlier steps.
        if layer_past is not None:
            past_key, past_value = layer_past[0], layer_past[1]
            if self.use_cache_quantization:
                key = (torch.cat((past_key[0], key[0]), dim=2),
                       torch.cat((past_key[1], key[1]), dim=2),
                       torch.cat((past_key[2], key[2]), dim=2))
                value = (torch.cat((past_value[0], value[0]), dim=2),
                         torch.cat((past_value[1], value[1]), dim=2),
                         torch.cat((past_value[2], value[2]), dim=2))
            else:
                key = torch.cat((past_key, key), dim=1)
                value = torch.cat((past_value, value), dim=1)

        if use_cache:
            present = (key, value)
        else:
            present = None

        # logn scaling of queries at absolute positions [seq_start, seq_end), inference only.
        if self.use_logn_attn and not self.training:
            if self.use_cache_quantization:
                seq_start = key[0].size(2) - query.size(1)
                seq_end = key[0].size(2)
            else:
                seq_start = key.size(1) - query.size(1)
                seq_end = key.size(1)
            logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
            query = query * logn_tensor.expand_as(query)

        if (self.use_flash_attn
                and flash_attn_unpadded_func is not None
                and not self.is_fp32
                and query.is_cuda):
            q, k, v = query, key, value
            attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
        else:
            # Eager/SDPA paths expect (batch, heads, seq, head_dim).
            query = query.permute(0, 2, 1, 3)
            if not self.use_cache_quantization:
                key = key.permute(0, 2, 1, 3)
                value = value.permute(0, 2, 1, 3)
            if (registered_causal_mask is None
                    and self.use_flash_attn
                    and flash_attn_unpadded_func is not None
                    and not self.is_fp32
                    and not query.is_cuda):
                raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)

            if not self.use_cache_quantization and SUPPORT_TORCH2:
                causal_mask = registered_causal_mask[
                    :, :, key.size(-2) - query.size(-2): key.size(-2), :key.size(-2)
                ]
                if attention_mask is not None:
                    # Fold the causal mask into the additive padding mask.
                    attention_mask = attention_mask.expand(
                        -1, -1, causal_mask.size(2), -1
                    ).masked_fill(~causal_mask, torch.finfo(query.dtype).min)
                else:
                    attention_mask = causal_mask
                attn_output = F.scaled_dot_product_attention(
                    query, key, value, attn_mask=attention_mask
                ).transpose(1, 2)
                attn_weight = None
            else:
                attn_output, attn_weight = self._attn(
                    query, key, value, registered_causal_mask, attention_mask, head_mask
                )

        context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)

        attn_output = self.c_proj(context_layer)

        outputs = (attn_output, present)
        if output_attentions:
            if (self.use_flash_attn
                    and flash_attn_unpadded_func is not None
                    and not self.is_fp32):
                raise ValueError("Cannot output attentions while using flash-attn")
            else:
                outputs += (attn_weight,)

        return outputs



class QWenMLP(nn.Module):
    """Gated feed-forward network: ``c_proj(w1(x) * silu(w2(x)))``."""

    def __init__(self, config):
        super().__init__()
        use_bias = not config.no_bias
        half_width = config.intermediate_size // 2
        # Two parallel up-projections; SiLU(w2(x)) gates w1(x).
        self.w1 = nn.Linear(config.hidden_size, half_width, bias=use_bias)
        self.w2 = nn.Linear(config.hidden_size, half_width, bias=use_bias)
        # Down-projection back to the model width.
        self.c_proj = nn.Linear(half_width, config.hidden_size, bias=use_bias)

    def forward(self, hidden_states):
        gated = self.w1(hidden_states) * F.silu(self.w2(hidden_states))
        return self.c_proj(gated)


class QWenBlock(nn.Module):
    """One pre-norm transformer layer.

    The two sub-layers are:
    - ``x + attn(ln_1(x))``;
    - then ``h + mlp(ln_2(h))``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.bf16 = config.bf16

        # Submodules are registered in the order ln_1, attn, ln_2, mlp.
        self.ln_1 = RMSNorm(width, eps=config.layer_norm_epsilon)
        self.attn = QWenAttention(config)
        self.ln_2 = RMSNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = QWenMLP(config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
        registered_causal_mask: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
        """Return ``(hidden_states, present, [attn_weights])``.

        ``present`` is omitted when ``use_cache`` is false. The encoder arguments
        are accepted but not forwarded.
        """
        attn_result = self.attn(
            self.ln_1(hidden_states),
            rotary_pos_emb_list,
            registered_causal_mask=registered_causal_mask,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        # attn_result is (attn_output, present[, attn_weights]).
        extras = attn_result[1:]
        if not use_cache:
            # Drop the `present` slot when no cache is requested.
            extras = extras[1:]

        post_attn = attn_result[0] + hidden_states
        new_hidden = post_attn + self.mlp(self.ln_2(post_attn))
        return (new_hidden,) + extras


class QWenPreTrainedModel(PreTrainedModel):
    """Base class hooking QWen models into transformers' PreTrainedModel machinery."""

    config_class = QWenConfig
    base_model_prefix = "transformer"
    is_parallelizable = False
    supports_gradient_checkpointing = True
    _no_split_modules = ["QWenBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """初始化模型权重."""
        # Linear / Embedding weights ~ N(0, initializer_range).
        # Biases and the padding row are zeroed; RMSNorm gains are set to 1.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)

        # A direct child's output projection (`c_proj.weight`) is re-drawn with the
        # smaller std initializer_range / sqrt(2 * num_hidden_layers).
        projection_weights = [
            param for param_name, param in module.named_parameters()
            if param_name == "c_proj.weight"
        ]
        for param in projection_weights:
            param.data.normal_(
                mean=0.0,
                std=self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers),
            )

    def _set_gradient_checkpointing(self, module, value=False):
        # Only the QWenModel trunk carries the gradient_checkpointing flag.
        if isinstance(module, QWenModel):
            module.gradient_checkpointing = value


class QWenModel(QWenPreTrainedModel):
    """QWen transformer trunk: token embedding, a stack of ``QWenBlock``s and a final ``RMSNorm``."""

    # Keys that may be absent from a checkpoint when loading.
    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embed_dim = config.hidden_size
        # Optional flag: configs without the attribute default to no cache quantization.
        self.use_cache_quantization = getattr(self.config, 'use_cache_quantization', False)

        self.gradient_checkpointing = False
        self.use_dynamic_ntk = config.use_dynamic_ntk
        self.seq_length = config.seq_length

        # Token embedding table.
        self.wte = nn.Embedding(self.vocab_size, self.embed_dim)

        # Dropout applied to the input embeddings.
        self.drop = nn.Dropout(config.emb_dropout_prob)

        # Rotary dimension: all of kv_channels when rotary_pct == 1.0, otherwise
        # the rotary_pct fraction of it.
        if config.rotary_pct == 1.0:
            self.rotary_ndims = None
        else:
            assert config.rotary_pct < 1
            self.rotary_ndims = int(
                config.kv_channels * config.rotary_pct
            )
        dim = (
            self.rotary_ndims
            if self.rotary_ndims is not None
            else config.kv_channels
        )
        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)

        self.use_flash_attn = config.use_flash_attn
        self.is_fp32 = not (config.bf16 or config.fp16)
        if (
            self.use_flash_attn
            and flash_attn_unpadded_func is not None
            and not self.is_fp32
        ):
            # No explicit causal-mask buffer on the flash-attention path
            # (presumably the flash kernel handles causality — confirm in QWenAttention).
            self.registered_causal_mask = None
        else:
            max_positions = config.max_position_embeddings
            # Lower-triangular boolean mask of shape (1, 1, max_pos, max_pos);
            # non-persistent, so it is not stored in the state dict.
            self.register_buffer(
                "registered_causal_mask",
                torch.tril(
                    torch.ones((max_positions, max_positions), dtype=torch.bool)
                ).view(1, 1, max_positions, max_positions),
                persistent=False,
            )

        # Decoder layers.
        self.h = nn.ModuleList(
            [
                QWenBlock(
                    config
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        # Final normalisation applied after the last layer.
        self.ln_f = RMSNorm(
            self.embed_dim,
            eps=config.layer_norm_epsilon,
        )

        # Hugging Face post-construction hook (weight init etc.).
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module."""
        self.wte = new_embeddings

    def get_ntk_alpha(self, true_seq_len):
        """Dynamic-NTK scaling factor for a sequence of ``true_seq_len`` tokens.

        Returns ``2 ** ceil(log2(true_seq_len / seq_length) + 1) - 1``, clamped
        to at least 1.
        """
        context_value = math.log(true_seq_len / self.seq_length, 2) + 1
        ntk_alpha = 2 ** math.ceil(context_value) - 1
        ntk_alpha = max(ntk_alpha, 1)
        return ntk_alpha

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the embedding, all decoder layers and the final norm.

        Args:
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            past_key_values: Per-layer cache from a previous call, or ``None``.
            attention_mask: 1 for positions to attend to, 0 for padding. It is
                flattened to ``(batch_size, -1)`` and turned into an additive
                mask (0 kept, dtype-min masked).
            token_type_ids, position_ids: Reshaped for API compatibility. They
                are not passed to the decoder layers by this method.
            head_mask: Expanded via ``get_head_mask`` and given per layer.
            inputs_embeds: Precomputed embeddings instead of ``input_ids``.
            encoder_hidden_states: Forwarded to every layer.
            encoder_attention_mask: Ignored; always replaced by ``None``.
            use_cache: Return per-layer ``presents`` (forced off under
                gradient checkpointing in training).
            output_attentions, output_hidden_states, return_dict: Standard
                flags; ``None`` falls back to the config.

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` is true. Otherwise
            a tuple of the non-``None`` values among (last hidden state,
            presents, all hidden states, attentions).

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are given, or if a mask is given with an empty batch.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        # Number of already-cached positions. The non-quantized cache stores
        # past_key_values[i][0] as bs * seq_len * head_num * dim, so the
        # length is dim 1 (dim -2 would be head_num). The quantized cache
        # keeps the length in dim 2 of past_key_values[i][0][0].
        first_layer_past = past_key_values[0]
        if first_layer_past is None:
            past_length = 0
        elif self.use_cache_quantization:
            past_length = first_layer_past[0][0].size(2)
        else:
            past_length = first_layer_past[0].size(1)

        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                input_shape[-1] + past_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            # (batch, kv_len) -> (batch, 1, 1, kv_len) additive mask.
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=self.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        # Cross-attention masks are not supported; whatever was passed is dropped.
        encoder_attention_mask = None
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        hidden_states = inputs_embeds

        # Key/value length seen by attention: new tokens plus cached ones.
        kv_seq_len = hidden_states.size()[1] + past_length

        # NTK alpha per rotary table:
        # - training or static NTK: a single alpha of 1.0;
        # - incremental decoding (cache present): reuse the alphas chosen on
        #   the prefill step;
        # - prefill beyond seq_length with a mask: one alpha per sample, sized
        #   by that sample's count of unmasked tokens.
        if self.training or not self.use_dynamic_ntk:
            ntk_alpha_list = [1.0]
        elif kv_seq_len != hidden_states.size()[1]:
            ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
        else:
            ntk_alpha_list = []
            if attention_mask is not None and kv_seq_len > self.seq_length:
                # Unmasked positions have additive value 0.
                true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
                for i in range(hidden_states.size()[0]):
                    true_seq_len = true_seq_lens[i].item()
                    ntk_alpha = self.get_ntk_alpha(true_seq_len)
                    ntk_alpha_list.append(ntk_alpha)
            else:
                ntk_alpha = self.get_ntk_alpha(kv_seq_len)
                ntk_alpha_list.append(ntk_alpha)
        self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list

        rotary_pos_emb_list = [
            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
        ]

        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                # Positional order must match QWenBlock.forward's signature;
                # layer_past is passed as None.
                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    rotary_pos_emb_list,
                    self.registered_causal_mask,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    rotary_pos_emb_list=rotary_pos_emb_list,
                    registered_causal_mask=self.registered_causal_mask,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                # Attention weights follow the present entry when caching.
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)

        # Also record the final (post-norm) hidden state.
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Same element order as BaseModelOutputWithPast.to_tuple();
            # attentions are included when requested.
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class QWenLMHeadModel(QWenPreTrainedModel):
    """QWen causal language model: ``QWenModel`` trunk plus a linear LM head.

    Also provides the chat helpers ``chat`` / ``chat_stream`` and a
    ``generate`` override that installs a stop-words logits processor.
    """

    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]

    def __init__(self, config):
        super().__init__(config)
        # At most one precision flag may be set.
        assert (
            config.bf16 + config.fp16 + config.fp32 <= 1
        ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"

        logger.warning(
            "Warning: please make sure that you are using the latest codes and checkpoints, "
            "especially if you used Qwen-7B before 09.25.2023."
            "请使用最新模型和代码,尤其如果你在9月25日前已经开始使用Qwen-7B,千万注意不要使用错误代码和模型。"
        )

        # No precision requested: pick bf16, then fp16, then fp32 by device support.
        autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
        if autoset_precision:
            if SUPPORT_BF16:
                logger.warning(
                    "The model is automatically converting to bf16 for faster inference. "
                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.bf16 = True
            elif SUPPORT_FP16:
                logger.warning(
                    "The model is automatically converting to fp16 for faster inference. "
                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.fp16 = True
            else:
                config.fp32 = True

        # Warn when the requested precision does not fit the detected device.
        if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
            logger.warning("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
        if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
            logger.warning("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
        if config.fp32:
            if SUPPORT_BF16:
                logger.warning("Your device supports faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
            elif SUPPORT_FP16:
                logger.warning("Your device supports faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")

        # "auto" enables flash attention only for half-precision models.
        if config.use_flash_attn == "auto":
            if config.bf16 or config.fp16:
                logger.warning("Try importing flash-attention for faster inference...")
                config.use_flash_attn = True
            else:
                config.use_flash_attn = False
        if config.use_flash_attn and config.fp32:
            # The flag itself is left set; QWenModel skips the flash path when fp32.
            logger.warning("Flash attention will be disabled because it does NOT support fp32.")

        if config.use_flash_attn:
            _import_flash_attn()

        self.transformer = QWenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Cast trunk and head to the selected half precision.
        if config.bf16:
            self.transformer.bfloat16()
            self.lm_head.bfloat16()
        if config.fp16:
            self.transformer.half()
            self.lm_head.half()
        self.post_init()

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        """Build the ``forward`` kwargs for one generation step.

        With a cache, only the last token (and its token type / position) is
        fed. Position ids are derived from ``attention_mask`` when one is
        given and no explicit ``position_ids`` were passed. Otherwise they are
        set to ``None``, including when the caller supplied them.
        """
        token_type_ids = kwargs.get("token_type_ids", None)
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        # Embeddings are only usable on the first step (before a cache exists).
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute LM logits and, when ``labels`` is given, the next-token loss.

        The loss is cross-entropy between logits at position t and labels at
        position t + 1 (labels are shifted inside this method).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            # Predict token t+1 from position t.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """Reorder every cached tensor along dim 0 according to ``beam_idx``.

        NOTE(review): this assumes each ``layer_past`` entry is a tensor; with
        cache quantization the entries appear to be tuples — confirm before
        using beam search with a quantized cache.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )

    def chat(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        history: Optional[HistoryType],
        system: str = "You are a helpful assistant.",
        stream: Optional[bool] = _SENTINEL,
        stop_words_ids: Optional[List[List[int]]] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> Tuple[str, HistoryType]:
        """Answer ``query`` given ``history`` and return ``(response, new_history)``.

        ``history`` is deep-copied, so the caller's object is not modified. The
        caller's ``stop_words_ids`` list is also left untouched.

        Raises:
            AssertionError: if ``stream`` is passed or the chat format is not ``'chatml'``.
        """
        generation_config = generation_config if generation_config is not None else self.generation_config
        assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT

        if history is None:
            history = []
        else:
            history = copy.deepcopy(history)

        # Work on a copy so repeated calls with the same list do not keep
        # appending the format's stop words to it.
        stop_words_ids = [] if stop_words_ids is None else list(stop_words_ids)

        max_window_size = kwargs.get('max_window_size', None)
        if max_window_size is None:
            max_window_size = generation_config.max_window_size

        raw_text, context_tokens = make_context(
            tokenizer,
            query,
            history=history,
            system=system,
            max_window_size=max_window_size,
            chat_format=generation_config.chat_format,
        )

        stop_words_ids.extend(get_stop_words_ids(
            generation_config.chat_format, tokenizer
        ))

        input_ids = torch.tensor([context_tokens]).to(self.device)
        outputs = self.generate(
            input_ids,
            stop_words_ids=stop_words_ids,
            return_dict_in_generate=False,
            generation_config=generation_config,
            **kwargs,
        )

        response = decode_tokens(
            outputs[0],
            tokenizer,
            raw_text_len=len(raw_text),
            context_length=len(context_tokens),
            chat_format=generation_config.chat_format,
            verbose=False,
            errors='replace'
        )

        history.append((query, response))
        return response, history

    def chat_stream(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        history: Optional[HistoryType],
        system: str = "You are a helpful assistant.",
        stop_words_ids: Optional[List[List[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> Generator[str, Any, None]:
        """Stream the answer to ``query``.

        Returns a generator that yields the full decoded response so far after
        each new token. The caller's ``stop_words_ids`` and
        ``logits_processor`` lists are not modified.

        Note: this patches ``generate_stream`` / ``sample_stream`` onto the
        class (``self.__class__``), not only this instance.
        """
        generation_config = generation_config if generation_config is not None else self.generation_config
        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT

        if history is None:
            history = []

        # Copy to avoid extending the caller's list in place.
        stop_words_ids = [] if stop_words_ids is None else list(stop_words_ids)

        max_window_size = kwargs.get('max_window_size', None)
        if max_window_size is None:
            max_window_size = generation_config.max_window_size

        raw_text, context_tokens = make_context(
            tokenizer,
            query,
            history=history,
            system=system,
            max_window_size=max_window_size,
            chat_format=generation_config.chat_format,
        )

        stop_words_ids.extend(get_stop_words_ids(
            generation_config.chat_format, tokenizer
        ))

        # stop_words_ids is always a list here, so the processor is always added.
        # A fresh list keeps the caller's processors list unchanged.
        stop_words_logits_processor = StopWordsLogitsProcessor(
            stop_words_ids=stop_words_ids,
            eos_token_id=generation_config.eos_token_id,
        )
        logits_processor = LogitsProcessorList(
            [] if logits_processor is None else logits_processor
        )
        logits_processor.append(stop_words_logits_processor)

        input_ids = torch.tensor([context_tokens]).to(self.device)

        from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
        self.__class__.generate_stream = NewGenerationMixin.generate
        self.__class__.sample_stream = NewGenerationMixin.sample_stream
        stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)

        def stream_generator():
            outputs = []
            for token in self.generate_stream(
                    input_ids,
                    return_dict_in_generate=False,
                    generation_config=stream_config,
                    logits_processor=logits_processor,
                    seed=-1,
                    **kwargs):
                outputs.append(token.item())
                yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')

        return stream_generator()

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """``GenerationMixin.generate`` with stop-word support.

        ``stop_words_ids`` is taken from ``kwargs`` or, failing that, from
        ``generation_config``. When present, a ``StopWordsLogitsProcessor`` is
        appended to a copy of ``logits_processor``; the caller's list is never
        mutated.
        """
        generation_config = generation_config if generation_config is not None else self.generation_config

        # Explicit kwarg wins; otherwise fall back to the generation config
        # (getattr also copes with a None config).
        stop_words_ids = kwargs.pop("stop_words_ids", None)
        if stop_words_ids is None:
            stop_words_ids = getattr(generation_config, "stop_words_ids", None)

        if stop_words_ids is not None:
            stop_words_logits_processor = StopWordsLogitsProcessor(
                stop_words_ids=stop_words_ids,
                eos_token_id=generation_config.eos_token_id,
            )
            # Fresh list: re-using the caller's list across calls would
            # otherwise stack duplicate stop-words processors.
            logits_processor = LogitsProcessorList(
                [] if logits_processor is None else logits_processor
            )
            logits_processor.append(stop_words_logits_processor)

        return super().generate(
            inputs,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            **kwargs,
        )


class RotaryEmbedding(torch.nn.Module):
    """Rotary position embedding cos/sin tables with a grow-on-demand cache.

    ``forward`` returns ``[cos, sin]``, each of shape ``(1, max_seq_len, 1, d)``
    with ``d = 2 * len(inv_freq)`` (equal to ``dim`` for even ``dim``). They
    cover positions ``offset`` up to (but excluding) ``offset + max_seq_len``.
    """

    def __init__(self, dim, base=10000):
        """Create the embedding.

        Args:
            dim: Number of rotary features.
            base: Frequency base, 10000 by default.

        Raises:
            RuntimeError: if the ``einops`` package is not installed.
        """
        # `import importlib` alone does not guarantee the `importlib.util`
        # submodule is loaded; import it explicitly before calling find_spec.
        import importlib.util

        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Non-persistent: recomputed on construction, not saved in state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        if importlib.util.find_spec("einops") is None:
            raise RuntimeError("einops is required for Rotary Embedding")

        self._rotary_pos_emb_cache = None
        self._seq_len_cached = 0
        self._ntk_alpha_cached = 1.0
        # Per-sample NTK alphas; written by QWenModel.forward.
        self._ntk_alpha_cached_list = [1.0]

    def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
        """Rebuild the cos/sin cache if it is too short or ``ntk_alpha`` changed.

        The base is rescaled to ``base * ntk_alpha ** (dim / (dim - 2))``, and
        the cache is sized to ``max(2 * (max_seq_len + offset), 16)`` positions.

        Args:
            max_seq_len: Number of positions requested.
            offset: Starting position, 0 by default.
            ntk_alpha: Dynamic-NTK scaling factor, 1.0 by default.
        """
        seqlen = max_seq_len + offset
        if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
            self.inv_freq = 1.0 / (
                base
                ** (
                    torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
                    / self.dim
                )
            )
            # Over-allocate so subsequent slightly longer requests hit the cache.
            self._seq_len_cached = max(2 * seqlen, 16)
            self._ntk_alpha_cached = ntk_alpha
            seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
            freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)

            emb = torch.cat((freqs, freqs), dim=-1)
            from einops import rearrange

            emb = rearrange(emb, "n d -> 1 n 1 d")

            cos, sin = emb.cos(), emb.sin()
            self._rotary_pos_emb_cache = [cos, sin]

    def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
        """Return ``[cos, sin]`` for positions ``offset .. offset + max_seq_len - 1``.

        Args:
            max_seq_len: Number of positions requested.
            offset: Starting position, 0 by default.
            ntk_alpha: Dynamic-NTK scaling factor, 1.0 by default.

        Returns:
            list: ``[cos, sin]`` tensors sliced along dim 1.
        """
        self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
        cos, sin = self._rotary_pos_emb_cache
        return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]]


def _rotate_half(x):
    """
    将x旋转半圈。

    参数:
        x: Tensor, 输入张量。

    返回:
        Tensor, 旋转后的张量。
    """
    from einops import rearrange

    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(t, freqs):
    """Apply rotary position embedding ``freqs`` to ``t``.

    On CUDA tensors with the fused ``apply_rotary_emb_func`` available, the
    fused kernel is used with the first half of ``cos``/``sin``. Otherwise the
    first ``cos.shape[-1]`` features of ``t`` are rotated in float32 and the
    remaining features pass through unchanged.

    Args:
        t: Input tensor.
        freqs: ``[cos, sin]`` pair, e.g. from ``RotaryEmbedding.forward``.

    Returns:
        Tensor with the same dtype as ``t``.
    """
    cos, sin = freqs
    use_fused_kernel = apply_rotary_emb_func is not None and t.is_cuda

    if not use_fused_kernel:
        rot_dim = cos.shape[-1]
        rotated = t[..., :rot_dim].float()
        passthrough = t[..., rot_dim:].float()
        rotated = (rotated * cos) + (_rotate_half(rotated) * sin)
        return torch.cat((rotated, passthrough), dim=-1).type_as(t)

    # The fused kernel takes 2-D tables holding only the first half of each
    # row (the two halves of cos/sin are duplicates of each other).
    half_width = cos.shape[-1] // 2
    cos_half = cos.squeeze(0).squeeze(1)[:, :half_width]
    half_width = sin.shape[-1] // 2
    sin_half = sin.squeeze(0).squeeze(1)[:, :half_width]
    return apply_rotary_emb_func(t.float(), cos_half, sin_half).type_as(t)


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create the layer.

        Args:
            dim: Size of the normalised (last) dimension.
            eps: Added to the mean square before the reciprocal square root; 1e-6 by default.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Return ``x / sqrt(mean(x ** 2, last dim) + eps)`` without the learned scale."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        return x * inverse_rms

    def forward(self, x):
        """Normalise ``x`` and scale it by ``weight``.

        CUDA inputs use the fused ``rms_norm`` kernel when it is available.
        Otherwise normalisation is computed in float32, cast back to ``x``'s
        dtype, and multiplied by ``weight``.
        """
        use_fused_kernel = rms_norm is not None and x.is_cuda
        if use_fused_kernel:
            return rms_norm(x, self.weight, self.eps)
        normalised = self._norm(x.float()).type_as(x)
        return normalised * self.weight

Reference

Qwen/Qwen-7B-Chat

相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
站点统计

本站现有博文242篇,共被浏览292004次

本站已经建立1782天!

热门文章
文章归档
回到顶部