EADST

Attention Net with PyTorch

In NLP tasks, an attention layer can be placed after the LSTM: it scores each time step of the LSTM output and returns the attention-weighted sum over the sequence.

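import torch
import torch.nn as nn

# attention layer
# lstm_output: [batch_size, sequence_length, hidden_size]
# returns:     [batch_size, hidden_size]
def attention_net(lstm_output):
    sequence_length, hidden_size = lstm_output.size()[1], lstm_output.size()[2]
    attention_size = 2
    # Zero-initialised placeholders: with all-zero weights every time step gets the
    # same score, so the attention is uniform. In a real model these should be
    # nn.Parameter attributes of the module so they are created once and trained.
    w_omega = lstm_output.new_zeros((hidden_size, attention_size))
    u_omega = lstm_output.new_zeros((attention_size,))
    # flatten to [batch_size * sequence_length, hidden_size]
    output_reshape = lstm_output.reshape(-1, hidden_size)
    u = torch.tanh(torch.mm(output_reshape, w_omega))
    # one score per time step: [batch_size * sequence_length, 1]
    attn_hidden_layer = torch.mm(u, u_omega.reshape(-1, 1))
    # normalise over the time steps of each sequence, not over the size-1 dimension
    scores = attn_hidden_layer.reshape(-1, sequence_length)
    alphas = nn.functional.softmax(scores, dim=1)
    alphas_reshape = alphas.reshape(-1, sequence_length, 1)
    # weighted sum over time; the input is already batch-first, so no permute here
    attn_output = torch.sum(lstm_output * alphas_reshape, 1)
    return attn_output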
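
In equation form, the layer is meant to compute the following. Here $h_t$ is the LSTM output at time step $t$ (a row vector of length `hidden_size`), $W_\omega$ and $u_\omega$ correspond to `w_omega` and `u_omega`, and $T$ is `sequence_length`. The symbols $e_t$ and $s$ are labels introduced here for the per-step score and the pooled output.

$$
\begin{aligned}
u_t &= \tanh(h_t W_\omega) \\
e_t &= u_t \, u_\omega^{\top} \\
\alpha_t &= \frac{\exp(e_t)}{\sum_{k=1}^{T} \exp(e_k)} \\
s &= \sum_{t=1}^{T} \alpha_t \, h_t
\end{aligned}
$$

The softmax runs over the time steps of one sequence, so the weights $\alpha_t$ of each sequence sum to 1.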


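# add attention layer after lstm
# assumes lstm_out comes from nn.LSTM with batch_first=False:
# [sequence_length, batch_size, hidden_size]
if attention_mode:
    lstm_out = lstm_out.permute(1, 0, 2)  # -> [batch_size, sequence_length, hidden_size]
    lstm_out = attention_net(lstm_out)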
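
To show where this step sits in a full model, here is a minimal sketch of an LSTM text classifier with attention pooling. It is an illustration under assumptions, not the original model: the class name `LSTMAttentionClassifier`, the embedding layer, the final linear layer, the initialisation scheme and all sizes are hypothetical. The attention weights are stored as `nn.Parameter` so that they are created once and updated during training.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMAttentionClassifier(nn.Module):
    """Hypothetical example: embedding -> LSTM -> attention pooling -> linear."""

    def __init__(self, vocab_size, embed_dim, hidden_size, num_classes, attention_size=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size)  # batch_first=False by default
        # Trainable attention weights, registered with the module.
        self.w_omega = nn.Parameter(torch.empty(hidden_size, attention_size))
        self.u_omega = nn.Parameter(torch.empty(attention_size))
        nn.init.xavier_uniform_(self.w_omega)    # assumed initialisation
        nn.init.normal_(self.u_omega, std=0.1)   # assumed initialisation
        self.fc = nn.Linear(hidden_size, num_classes)

    def attention_net(self, lstm_output):
        # lstm_output: [batch_size, sequence_length, hidden_size]
        u = torch.tanh(lstm_output @ self.w_omega)   # [batch, seq, attention_size]
        scores = u @ self.u_omega                    # [batch, seq]
        alphas = F.softmax(scores, dim=1)            # sums to 1 over time steps
        pooled = torch.sum(lstm_output * alphas.unsqueeze(-1), dim=1)
        return pooled, alphas                        # [batch, hidden], [batch, seq]

    def forward(self, tokens):
        # tokens: [sequence_length, batch_size] tensor of token ids
        lstm_out, _ = self.lstm(self.embedding(tokens))  # [seq, batch, hidden]
        lstm_out = lstm_out.permute(1, 0, 2)             # [batch, seq, hidden]
        pooled, alphas = self.attention_net(lstm_out)
        return self.fc(pooled), alphas


model = LSTMAttentionClassifier(vocab_size=100, embed_dim=16, hidden_size=32, num_classes=3)
tokens = torch.randint(0, 100, (7, 4))  # sequence_length=7, batch_size=4
logits, alphas = model(tokens)
print(logits.shape)  # torch.Size([4, 3])
print(alphas.shape)  # torch.Size([4, 7])
```

Returning `alphas` alongside the logits is a design choice for this sketch: it makes it easy to inspect which time steps the model weights most heavily.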
