EADST

Attention Net with PyTorch

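In NLP tasks, an attention layer can be placed after the LSTM. It scores each time step of the LSTM output, normalizes the scores with a softmax, and returns the weighted sum of the outputs. The result is one fixed-size vector per sequence.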
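The layer computes the following for one sequence of T LSTM outputs h_1, ..., h_T, each a row vector of size d. This is a sketch of the math only. W_ω and u_ω correspond to the weights w_omega and u_omega in the code, whose attention size a is 2 there.

$$
\begin{aligned}
u_t &= \tanh\left(h_t W_\omega\right), & W_\omega &\in \mathbb{R}^{d \times a} \\
e_t &= u_t\, u_\omega, & u_\omega &\in \mathbb{R}^{a} \\
\alpha_t &= \frac{\exp(e_t)}{\sum_{k=1}^{T} \exp(e_k)} \\
s &= \sum_{t=1}^{T} \alpha_t\, h_t
\end{aligned}
$$

The softmax runs over the T time steps of one sequence, so the weights α_t sum to 1. The output s has the same size d as each h_t.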

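import torch
import torch.nn as nn
import torch.nn.functional as F

# attention layer: scores every LSTM time step and returns their weighted sum
class AttentionNet(nn.Module):
    def __init__(self, hidden_size=300, attention_size=2):
        super().__init__()
        # registered as nn.Parameter so the optimizer can learn them
        # (all-zero, non-trainable weights would give every time step the same score)
        self.w_omega = nn.Parameter(torch.randn(hidden_size, attention_size) * 0.1)
        self.u_omega = nn.Parameter(torch.randn(attention_size) * 0.1)

    def forward(self, lstm_output):
        # lstm_output: [batch, seq_len, hidden_size]
        batch_size, seq_len, hidden_size = lstm_output.size()
        output_reshape = lstm_output.reshape(-1, hidden_size)          # [batch*seq_len, hidden_size]
        u = torch.tanh(torch.mm(output_reshape, self.w_omega))        # [batch*seq_len, attention_size]
        attn_hidden_layer = torch.mm(u, self.u_omega.reshape(-1, 1))  # [batch*seq_len, 1]
        # softmax across the time steps of each sequence
        # (a softmax over the size-1 column would return all ones)
        alphas = F.softmax(attn_hidden_layer.reshape(batch_size, seq_len), dim=1)
        # lstm_output is already [batch, seq_len, hidden_size], so no further permute is needed
        attn_output = torch.sum(lstm_output * alphas.unsqueeze(-1), dim=1)  # [batch, hidden_size]
        return attn_output


# add attention layer after lstm
# in the model's __init__ (create it once so its weights are trained):
#     self.attention = AttentionNet(hidden_size=300)
# in forward(): nn.LSTM without batch_first returns lstm_out as [seq_len, batch, hidden_size];
# for a bidirectional LSTM, pass hidden_size = 2 * the LSTM's hidden size
if attention_mode:
    lstm_out = lstm_out.permute(1, 0, 2)  # -> [batch, seq_len, hidden_size]
    lstm_out = self.attention(lstm_out)   # -> [batch, hidden_size]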
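The following sketch shows the attention layer inside a complete model, assuming a simple text classifier. The class name LSTMAttentionClassifier, the embedding layer, and the sizes (vocab_size, embed_dim, num_classes) are hypothetical choices for illustration and do not come from the original post. The sketch creates the LSTM with batch_first=True, so its output is already [batch, seq_len, hidden_size] and no permute is needed. It also returns the attention weights so they can be inspected.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMAttentionClassifier(nn.Module):
    """Hypothetical classifier: embedding -> LSTM -> attention pooling -> linear."""

    def __init__(self, vocab_size, embed_dim=128, hidden_size=300,
                 attention_size=2, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # batch_first=True: LSTM output is [batch, seq_len, hidden_size]
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        # learnable attention weights (small random init is an assumption)
        self.w_omega = nn.Parameter(torch.randn(hidden_size, attention_size) * 0.1)
        self.u_omega = nn.Parameter(torch.randn(attention_size) * 0.1)
        self.fc = nn.Linear(hidden_size, num_classes)

    def attention(self, h):
        # h: [batch, seq_len, hidden_size]
        u = torch.tanh(h @ self.w_omega)                   # [batch, seq_len, attention_size]
        scores = u @ self.u_omega                          # [batch, seq_len]
        alphas = F.softmax(scores, dim=1)                  # weights over time steps
        context = (h * alphas.unsqueeze(-1)).sum(dim=1)    # [batch, hidden_size]
        return context, alphas

    def forward(self, tokens):
        h, _ = self.lstm(self.embedding(tokens))           # [batch, seq_len, hidden_size]
        context, alphas = self.attention(h)
        return self.fc(context), alphas


torch.manual_seed(0)
model = LSTMAttentionClassifier(vocab_size=1000)
tokens = torch.randint(0, 1000, (4, 7))    # 4 sequences of 7 token ids each
logits, alphas = model(tokens)
print(logits.shape, alphas.shape)          # torch.Size([4, 2]) torch.Size([4, 7])
print(alphas.sum(dim=1))                   # each row sums to (approximately) 1
```

Returning alphas is optional. It is convenient for visualizing which tokens the model attends to, for example as a heatmap over the input sequence.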

About Me
XD
Goals determine what you are going to be.