EADST

Attention Net with PyTorch

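In NLP models, an attention layer can be placed after the LSTM. It scores the LSTM output at every time step, turns the scores into weights with a softmax over the sequence, and returns the weighted sum of the outputs as a single vector for the rest of the network.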
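Read off the code, the layer computes the following. Here h_t is the LSTM output at step t (a row vector of size H, 300 in the code), W_omega (H x A) and u_omega (size A, with A = 2 in the code) are the weights `w_omega` and `u_omega`, and T is the sequence length. This is one reading of the code written as equations, not notation taken from the original post.

```latex
u_t = \tanh\left(h_t W_\omega\right), \qquad
e_t = u_t \, u_\omega, \qquad
\alpha_t = \frac{\exp(e_t)}{\sum_{k=1}^{T} \exp(e_k)}, \qquad
c = \sum_{t=1}^{T} \alpha_t \, h_t
```

The softmax runs over the T time steps of each sequence, so the weights alpha_t of one sequence sum to 1 and c has the same size H as each h_t.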

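import torch
import torch.nn as nn
import torch.nn.functional as F

# attention layer: scores every LSTM time step and returns the weighted sum
class AttentionNet(nn.Module):
    def __init__(self, hidden_size=300, attention_size=2):
        super().__init__()
        # learnable weights; an all-zero init would give every step the same
        # score and zero gradients, so the attention could never be learned
        self.w_omega = nn.Parameter(torch.randn(hidden_size, attention_size) * 0.1)
        self.u_omega = nn.Parameter(torch.randn(attention_size) * 0.1)

    def forward(self, lstm_output):
        # lstm_output: (batch, seq_len, hidden_size)
        batch_size, sequence_length, hidden_size = lstm_output.size()
        output_reshape = lstm_output.reshape(-1, hidden_size)         # (batch*seq_len, hidden)
        u = torch.tanh(torch.mm(output_reshape, self.w_omega))        # (batch*seq_len, attn)
        attn_hidden_layer = torch.mm(u, self.u_omega.reshape(-1, 1))  # (batch*seq_len, 1)
        # softmax across the time steps of each sequence, not across a size-1 dim
        alphas = F.softmax(attn_hidden_layer.reshape(batch_size, sequence_length), dim=1)
        alphas_reshape = alphas.reshape(batch_size, sequence_length, 1)
        attn_output = torch.sum(lstm_output * alphas_reshape, 1)      # (batch, hidden)
        return attn_output


# add attention layer after lstm (the sizes below are illustrative only)
attention_mode = True
lstm = nn.LSTM(input_size=100, hidden_size=300)
attention = AttentionNet(hidden_size=300)
x = torch.randn(7, 4, 100)             # (seq_len, batch, input_size)
lstm_out, _ = lstm(x)                  # nn.LSTM default layout: (seq_len, batch, hidden)
if attention_mode:
    lstm_out = lstm_out.permute(1, 0, 2)   # -> (batch, seq_len, hidden)
    lstm_out = attention(lstm_out)         # -> (batch, hidden)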
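To check the weighting logic in isolation, here is a minimal functional sketch of the same scoring: tanh of the outputs times `w_omega`, a dot product with `u_omega`, a softmax over time, then a weighted sum. `attention_pool` is a hypothetical helper name and the tensor sizes are illustrative assumptions, not part of the original post. The second half shows what happens when `w_omega` and `u_omega` start as `torch.zeros`: every score is 0, the softmax gives uniform weights, and the "attention" collapses into a plain mean over time steps.

```python
import torch
import torch.nn.functional as F


def attention_pool(h, w_omega, u_omega):
    """Hypothetical helper: attention pooling over time.

    h:       (batch, seq_len, hidden) LSTM outputs, batch first
    w_omega: (hidden, attention_size)
    u_omega: (attention_size,)
    returns  pooled (batch, hidden) and alphas (batch, seq_len)
    """
    u = torch.tanh(h @ w_omega)                      # (batch, seq_len, attention_size)
    scores = u @ u_omega                             # (batch, seq_len)
    alphas = F.softmax(scores, dim=1)                # normalize over time steps
    pooled = (h * alphas.unsqueeze(-1)).sum(dim=1)   # weighted sum over time
    return pooled, alphas


torch.manual_seed(0)
batch, seq_len, hidden, attention_size = 4, 7, 300, 2  # illustrative sizes
h = torch.randn(batch, seq_len, hidden)

# small random weights: each sequence gets its own weights, summing to 1
w_rand = torch.randn(hidden, attention_size) * 0.1
u_rand = torch.randn(attention_size) * 0.1
pooled, alphas = attention_pool(h, w_rand, u_rand)
print(pooled.shape)         # torch.Size([4, 300])
print(alphas.sum(dim=1))    # each entry is (close to) 1

# all-zero weights: uniform alphas, so the output is just the mean over time
pooled0, alphas0 = attention_pool(
    h, torch.zeros(hidden, attention_size), torch.zeros(attention_size)
)
print(torch.allclose(pooled0, h.mean(dim=1), atol=1e-5))  # True
```

Using batch-first tensors throughout avoids the permute bookkeeping; with `nn.LSTM(..., batch_first=True)` the LSTM output could be passed in directly.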
