EADST

Attention Net with PyTorch

An attention layer can be placed after the LSTM in an NLP model to pool the LSTM's per-time-step outputs into a single weighted vector per sequence.

import torch
import torch.nn as nn
from torch.autograd import Variable

# attention layer
def attention_net(lstm_output, w_omega=None, u_omega=None):
    """Pool a batch-first LSTM output into one vector per sequence with attention.

    For every time step ``h_t`` a score ``tanh(h_t @ w_omega) @ u_omega`` is
    computed. The scores are normalized with a softmax over the time axis and
    used to take a weighted sum of the time steps.

    Args:
        lstm_output: tensor of shape (batch, seq_len, hidden_size).
        w_omega: optional projection of shape (hidden_size, attention_size).
            Defaults to zeros of shape (hidden_size, 2), matching the
            previous hard-coded (untrained) weights.
        u_omega: optional context vector of shape (attention_size,).
            Defaults to zeros.

    With the default zero weights every score is 0, so the softmax is uniform
    and the result is the mean over time steps. To actually learn the
    attention, pass trainable tensors (e.g. ``nn.Parameter`` objects owned by
    the enclosing module). Tensors created inside this function are rebuilt
    on every call and are never updated.

    Returns:
        Tensor of shape (batch, hidden_size).

    Raises:
        ValueError: if ``lstm_output`` is not 3-D.
    """
    if lstm_output.dim() != 3:
        raise ValueError(
            "lstm_output must be 3-D (batch, seq_len, hidden_size), "
            "got shape {}".format(tuple(lstm_output.shape))
        )
    batch_size, sequence_length, hidden_size = lstm_output.size()

    # Default weights follow the input's dtype/device (instead of CPU
    # float32), and hidden_size is read from the input instead of fixed at 300.
    if w_omega is None:
        w_omega = lstm_output.new_zeros(hidden_size, 2)
    if u_omega is None:
        u_omega = lstm_output.new_zeros(w_omega.size(1))

    # Flatten to (batch * seq_len, hidden_size). Rows are batch-major, so the
    # per-step scores can be folded back into (batch, seq_len) below.
    output_reshape = lstm_output.reshape(-1, hidden_size)
    u = torch.tanh(torch.mm(output_reshape, w_omega))
    attn_scores = torch.mm(u, u_omega.reshape(-1, 1))

    # Softmax across the time axis. Applying it to the (N, 1) score column
    # with dim=1 would normalize a size-1 axis and yield all-ones weights.
    alphas = nn.functional.softmax(
        attn_scores.reshape(batch_size, sequence_length), dim=1
    )

    # Weight the batch-first input directly. Permuting it to
    # (seq_len, batch, hidden) would pair the (batch, seq_len, 1) weights
    # with the wrong entries, and would fail to broadcast unless the sizes
    # happen to line up.
    return torch.sum(lstm_output * alphas.unsqueeze(-1), dim=1)


# Add the attention layer after the LSTM.
# NOTE(review): ``attetion_mode`` looks like a misspelling of "attention_mode".
# It is kept as-is because the flag is defined outside this snippet.
# Assumes lstm_out is (seq_len, batch, hidden), i.e. nn.LSTM's default layout
# (batch_first=False) — TODO confirm. It is permuted to the batch-first
# (batch, seq_len, hidden) layout that attention_net expects.
if attetion_mode:
    lstm_out = lstm_out.permute(1, 0, 2)
    lstm_out = attention_net(lstm_out)

相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
Video 版权 Safetensors Food Land Clash Input Tracking Qwen Hotel AI Ubuntu 搞笑 OpenAI Sklearn PyTorch Color Interview Shortcut 算法题 Mixtral icon Template BF16 XML Augmentation LaTeX DeepSeek Pandas Review Anaconda Docker TTS Pickle CUDA CTC VPN Paper Card Miniforge 报税 Ptyhon Numpy Bipartite Zip Vmess TSV VSCode 第一性原理 uwsgi Hungarian Algorithm Git JSON Agent GoogLeNet Excel Plate Attention Knowledge GGML Diagram Streamlit 净利润 C++ Breakpoint 公式 NameSilo PIP WebCrawler LeetCode RAR Github 多线程 阿里云 继承 Baidu scipy FP64 Password HaggingFace DeepStream Datetime Dataset 证件照 llama.cpp 关于博主 Tiktoken Rebuttal SPIE NLP IndexTTS2 Data 顶会 ModelScope FastAPI EXCEL torchinfo tar FP16 Django v2ray Bin Bitcoin 财报 Crawler Heatmap 图形思考法 QWEN Llama Paddle ResNet-50 Michelin Permission GIT Quantize Firewall v0.dev FP8 Google COCO API Image2Text Search 强化学习 Plotly News tqdm 递归学习法 Math Logo XGBoost Tensor FlashAttention Transformers 腾讯云 Web PDB Translation BeautifulSoup OpenCV Claude git Animate LLM 图标 飞书 CAM VGG-16 Python OCR Hilton Gemma mmap BTC NLTK MD5 printf CSV Magnet CC Qwen2 uWSGI Domain Pytorch Distillation ONNX Nginx Cloudreve 多进程 Vim Jetson transformers Use diffusers SAM FP32 hf Markdown TensorFlow RGB Website UNIX Linux 签证 CV LLAMA logger CEIR 域名 PyCharm Qwen2.5 Base64 WAN ChatGPT Windows Jupyter Quantization Conda HuggingFace SQL Proxy Random PDF GPT4 SVR UI CLAP 云服务器 GPTQ LoRA Disk Pillow git-lfs InvalidArgumentError Bert SQLite Freesound TensorRT Statistics YOLO 音频
站点统计

本站现有博文324篇，共被浏览808773次

本站已经建立2511天!

热门文章
文章归档
回到顶部