EADST

Attention Net with PyTorch

In NLP tasks, an attention layer can be placed after the LSTM to pool its per-step outputs into a single vector.

import torch
import torch.nn as nn
from torch.autograd import Variable

# attention layer
def attention_net(lstm_output, w_omega=None, u_omega=None):
    """Pool LSTM outputs over the sequence axis using additive attention.

    A score is computed for every time step as
    ``tanh(h_t @ w_omega) @ u_omega``; the scores are normalised with a
    softmax over the sequence axis and used to take a weighted sum of the
    time-step vectors.

    Args:
        lstm_output: Tensor of shape ``(batch, seq_len, hidden_size)``.
        w_omega: Optional projection matrix of shape
            ``(hidden_size, attention_size)``. Defaults to zeros of shape
            ``(hidden_size, 2)``. Pass a trainable tensor (e.g. an
            ``nn.Parameter``) to make the attention learnable.
        u_omega: Optional context vector of shape ``(attention_size,)``.
            Defaults to zeros.

    Returns:
        Tensor of shape ``(batch, hidden_size)``. With the default all-zero
        weights every time step gets the same weight, so the result is the
        mean over the sequence axis.

    Raises:
        ValueError: If ``lstm_output`` is not 3-dimensional.
    """
    if lstm_output.dim() != 3:
        raise ValueError(
            "lstm_output must have shape (batch, seq_len, hidden_size), "
            f"got {tuple(lstm_output.size())}"
        )
    batch_size, sequence_length, hidden_size = lstm_output.size()

    # Create the default weights on the input's device/dtype so the function
    # also works for CUDA or half-precision inputs.
    if w_omega is None:
        w_omega = lstm_output.new_zeros(hidden_size, 2)
    if u_omega is None:
        u_omega = lstm_output.new_zeros(w_omega.size(1))

    # Score every (batch, step) position: (batch * seq_len, 1).
    output_reshape = lstm_output.reshape(-1, hidden_size)
    u = torch.tanh(torch.mm(output_reshape, w_omega))
    attn_hidden_layer = torch.mm(u, u_omega.reshape(-1, 1))

    # Normalise across the time steps of each sequence. Applying softmax to
    # the (N, 1) score column directly would yield all ones.
    scores = attn_hidden_layer.reshape(batch_size, sequence_length)
    alphas = nn.functional.softmax(scores, dim=1)

    # Weighted sum over the sequence axis; alphas broadcasts over hidden.
    attn_output = torch.sum(lstm_output * alphas.unsqueeze(-1), dim=1)
    return attn_output


# Optionally pool the LSTM output with the attention layer.
# NOTE(review): `attetion_mode` (sic) and `lstm_out` are defined outside this
# snippet; the spelling must match wherever the flag is set.
if attetion_mode == True:
    # Swap the first two axes before pooling — presumably turning
    # (seq_len, batch, hidden) into (batch, seq_len, hidden); TODO confirm
    # against the LSTM's batch_first setting.
    lstm_out = attention_net(lstm_out.permute(1, 0, 2))

相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
Python XGBoost 多进程 git Crawler ONNX Hilton SAM Review Search Jupyter 音频 git-lfs Video Freesound SQLite Bin tqdm FP16 递归学习法 Password v2ray Michelin Interview YOLO 报税 Github 净利润 LaTeX Base64 FP32 BeautifulSoup Gemma 阿里云 VGG-16 FP8 Dataset 继承 Jetson 多线程 InvalidArgumentError Miniforge v0.dev Image2Text Datetime Baidu Algorithm Llama diffusers Logo GIT GGML UI llama.cpp VSCode Bitcoin BF16 scipy Food NameSilo Firewall ms-swift Anaconda 搞笑 Template Disk UNIX XML OCR Data Ubuntu tar GoogLeNet OpenAI Knowledge CAM ModelScope News Permission PDF 财报 Conda Qwen2.5 Linux 关于博主 图形思考法 CLAP Quantization Docker Numpy Claude CV Ptyhon CSV Windows LLAMA 签证 Plotly GPTQ Quantize Statistics Cloudreve JSON WAN Zip DeepSeek RGB Bert icon FlashAttention Card printf Qwen Land Rebuttal Use RL Animate NLTK PyTorch PDB Paper Input WebCrawler Django ResNet-50 RAR 论文 Random CEIR Distillation Shortcut uwsgi HaggingFace Heatmap MD5 Nginx Markdown SQL 证件照 Tracking Google Breakpoint 强化学习 PIP Website Pytorch Mixtral Domain PyCharm hf SVR Tensor 飞书 DeepStream TTS Plate torchinfo Math 版权 logger 图标 TSV LLM IndexTTS2 LeetCode FastAPI VPN Attention transformers Sklearn Hotel Streamlit Pillow Qwen2 GPT4 mmap Color Web Vim Paddle 腾讯云 BTC Agent 算法题 Diagram 第一性原理 AI 公式 Translation Transformers Clash CUDA uWSGI API Vmess LoRA 云服务器 Git FP64 Tiktoken Hungarian 域名 Pandas QWEN CC ChatGPT SPIE Excel TensorFlow Augmentation 顶会 Proxy HuggingFace EXCEL TensorRT Bipartite Magnet COCO C++ 论文速读 NLP Pickle CTC OpenCV Safetensors
站点统计

本站现有博文332篇,共被浏览868910

本站已经建立2577天!

热门文章
文章归档
回到顶部