EADST

Attention Net with PyTorch

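An attention layer can be placed after the LSTM in an NLP model. It scores each time step of the LSTM output and returns the weighted sum of those outputs as a single vector per sequence.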
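As a reading of the code in this post rather than a formula stated by the author, the intended computation can be written as follows. Here h_t is the LSTM output at time step t, T is the sequence length, and W_ω and u_ω correspond to `w_omega` and `u_omega` in the code.

```latex
\begin{aligned}
u_t      &= \tanh\left(W_\omega^{\top} h_t\right), \\
e_t      &= u_\omega^{\top} u_t, \\
\alpha_t &= \frac{\exp(e_t)}{\sum_{k=1}^{T} \exp(e_k)}, \\
s        &= \sum_{t=1}^{T} \alpha_t \, h_t .
\end{aligned}
```

The weights α_t are normalized over the time steps of one sequence, so they sum to 1 and s has the same size as a single h_t.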

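import torch
import torch.nn as nn

# attention layer
def attention_net(lstm_output):
    # lstm_output: (batch_size, sequence_length, hidden_size)
    hidden_size = 300  # must match the feature size of the LSTM output
    attention_size = 2
    sequence_length = lstm_output.size(1)
    # Zero weights give uniform attention; in a real model these should be
    # trainable nn.Parameter tensors owned by the module.
    w_omega = torch.zeros(hidden_size, attention_size)
    u_omega = torch.zeros(attention_size)
    # Flatten batch and time so every time step is scored by one matmul.
    output_reshape = lstm_output.reshape(-1, hidden_size)
    u = torch.tanh(torch.mm(output_reshape, w_omega))
    attn_hidden_layer = torch.mm(u, u_omega.reshape(-1, 1))
    # Normalize the scores over the time steps of each sequence.
    scores = attn_hidden_layer.reshape(-1, sequence_length)
    alphas = nn.functional.softmax(scores, dim=1)
    alphas_reshape = alphas.reshape(-1, sequence_length, 1)
    # Weighted sum of the LSTM outputs over time: (batch_size, hidden_size)
    attn_output = torch.sum(lstm_output * alphas_reshape, 1)
    return attn_output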


# add attention layer after lstm
if attention_mode:
    # nn.LSTM returns (seq_len, batch, hidden) by default; make it batch-first
    lstm_out = lstm_out.permute(1, 0, 2)
    lstm_out = attention_net(lstm_out)
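
For training, the attention weights need to be learnable parameters rather than tensors created on every call. The following is a minimal sketch of the same layer as an `nn.Module`, not the author's implementation. The class name `AttentionPool`, the `attention_size` argument, the small random initialization, and the toy LSTM sizes are all assumptions made for illustration.

```python
import torch
import torch.nn as nn


class AttentionPool(nn.Module):
    """Hypothetical module form of the attention layer with trainable weights."""

    def __init__(self, hidden_size, attention_size=2):
        super().__init__()
        # Small random init (an assumption). If both tensors start at zero,
        # neither of them receives a gradient and the weights stay uniform.
        self.w_omega = nn.Parameter(torch.randn(hidden_size, attention_size) * 0.1)
        self.u_omega = nn.Parameter(torch.randn(attention_size) * 0.1)

    def forward(self, lstm_output):
        # lstm_output: (batch, seq_len, hidden_size)
        u = torch.tanh(lstm_output @ self.w_omega)   # (batch, seq_len, attention_size)
        scores = u @ self.u_omega                    # (batch, seq_len)
        alphas = torch.softmax(scores, dim=1)        # normalize over time steps
        # Weighted sum over the time dimension: (batch, hidden_size)
        context = torch.sum(lstm_output * alphas.unsqueeze(-1), dim=1)
        return context, alphas


# Usage with a toy LSTM; all sizes here are arbitrary.
torch.manual_seed(0)
lstm = nn.LSTM(input_size=50, hidden_size=300)       # seq-first by default
x = torch.randn(7, 4, 50)                            # (seq_len, batch, features)
lstm_out, _ = lstm(x)                                # (seq_len, batch, hidden)
pool = AttentionPool(hidden_size=300)
context, alphas = pool(lstm_out.permute(1, 0, 2))    # batch-first input
print(context.shape, alphas.shape)
```

Returning `alphas` alongside the pooled vector is optional; it is kept here because the per-time-step weights are handy for inspecting which tokens the layer attends to.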
