EADST

Attention Net with PyTorch

In an NLP model, an attention layer can be placed after the LSTM to pool its per-time-step outputs into a single weighted vector.
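Written out, the computation the code is aiming for looks roughly like the following. This is read off from the code rather than stated by the author: h_t stands for the LSTM output at time step t (a row vector), T for the sequence length, and W_ω and u_ω for w_omega and u_omega.

```latex
u_t = \tanh\left(h_t W_\omega\right), \qquad
\alpha_t = \frac{\exp\left(u_t u_\omega\right)}{\sum_{k=1}^{T} \exp\left(u_k u_\omega\right)}, \qquad
s = \sum_{t=1}^{T} \alpha_t h_t
```

Here s is the vector returned as attn_output.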

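import torch
import torch.nn as nn

# attention layer
def attention_net(lstm_output):
    # lstm_output: (batch_size, sequence_length, hidden_size)
    sequence_length = lstm_output.size(1)
    hidden_size = lstm_output.size(2)  # 300 in the original setup
    attention_size = 2
    # Zero weights created on every call give uniform attention; in a real
    # model, register them once as trainable nn.Parameter tensors instead.
    # new_zeros keeps the device and dtype of the input.
    w_omega = lstm_output.new_zeros((hidden_size, attention_size))
    u_omega = lstm_output.new_zeros((attention_size,))
    # Score every time step: (batch_size * sequence_length, 1)
    output_reshape = lstm_output.reshape(-1, hidden_size)
    u = torch.tanh(torch.mm(output_reshape, w_omega))
    attn_hidden_layer = torch.mm(u, u_omega.reshape(-1, 1))
    # Softmax over the time dimension, not over the singleton score dimension
    scores = attn_hidden_layer.reshape(-1, sequence_length)
    alphas = nn.functional.softmax(scores, dim=1)
    alphas_reshape = alphas.reshape(-1, sequence_length, 1)
    # The input is already batch-first, so no further permute is needed.
    # Weighted sum of the LSTM outputs over time: (batch_size, hidden_size)
    attn_output = torch.sum(lstm_output * alphas_reshape, 1)
    return attn_output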
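
One thing worth noticing about the function above: because w_omega and u_omega are created as zeros, every score is zero, so a softmax over the time steps would assign each step the same weight and the output would reduce to a plain mean over time. A quick check of that claim, using made-up tensor sizes:

```python
import torch

torch.manual_seed(0)
h = torch.randn(4, 7, 300)                 # (batch, seq_len, hidden_size), made-up sizes
scores = torch.zeros(4, 7)                 # what zero w_omega / u_omega produce
alphas = torch.softmax(scores, dim=1)      # every weight equals 1 / seq_len
pooled = (h * alphas.unsqueeze(-1)).sum(dim=1)

# With uniform weights the attention output is just the mean over time steps.
is_mean = torch.allclose(pooled, h.mean(dim=1), atol=1e-5)
print(is_mean)
```

For the layer to learn anything, the weights would need to be trainable parameters that persist between calls.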


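# add attention layer after lstm
# Assumes lstm_out comes from an nn.LSTM without batch_first:
# (sequence_length, batch_size, hidden_size)
if attention_mode:
    lstm_out = lstm_out.permute(1, 0, 2)  # -> (batch_size, sequence_length, hidden_size)
    lstm_out = attention_net(lstm_out)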
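
As a rough sketch of how this could look inside a model, the same idea can be wrapped in an nn.Module so that the attention weights are trainable. The class name AttentionPool, the small random initialisation, and all tensor sizes below are assumptions made for illustration, not part of the original post.

```python
import torch
import torch.nn as nn


class AttentionPool(nn.Module):
    """Hypothetical module form of attention_net with trainable weights."""

    def __init__(self, hidden_size, attention_size=2):
        super().__init__()
        # Small random init (an assumption) so the scores are not all equal.
        self.w_omega = nn.Parameter(torch.randn(hidden_size, attention_size) * 0.1)
        self.u_omega = nn.Parameter(torch.randn(attention_size) * 0.1)

    def forward(self, lstm_output):
        # lstm_output: (batch, seq_len, hidden_size)
        u = torch.tanh(lstm_output @ self.w_omega)   # (batch, seq_len, attention_size)
        scores = u @ self.u_omega                    # (batch, seq_len)
        alphas = torch.softmax(scores, dim=1)        # weights over time steps
        pooled = (lstm_output * alphas.unsqueeze(-1)).sum(dim=1)  # (batch, hidden_size)
        return pooled, alphas


torch.manual_seed(0)
batch, seq_len, embed_dim, hidden_size = 4, 7, 16, 300
lstm = nn.LSTM(embed_dim, hidden_size)   # default layout: (seq_len, batch, feature)
attention = AttentionPool(hidden_size)

x = torch.randn(seq_len, batch, embed_dim)
lstm_out, _ = lstm(x)                    # (seq_len, batch, hidden_size)
lstm_out = lstm_out.permute(1, 0, 2)     # (batch, seq_len, hidden_size)
pooled, alphas = attention(lstm_out)
print(pooled.shape, alphas.shape)
```

Returning alphas alongside the pooled vector is a design choice that makes it easy to inspect which time steps the layer attends to; a downstream classifier would only consume pooled.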
