EADST

Attention Net with PyTorch

In NLP tasks, an attention layer can be placed after the LSTM. It combines the LSTM's per-timestep outputs into a single attention-weighted vector.

import torch
import torch.nn as nn
from torch.autograd import Variable

# attention layer
def attention_net(lstm_output, w_omega=None, u_omega=None, attention_size=2):
    """Pool batch-first LSTM outputs over time with additive (tanh) attention.

    For every time step t a score ``e_t = tanh(h_t @ w_omega) @ u_omega`` is
    computed. The scores are normalised with a softmax across the sequence
    axis, and the attention-weighted sum of the hidden states is returned.

    Args:
        lstm_output: Tensor of shape ``(batch, seq_len, hidden_size)``
            (batch-first; the caller permutes the LSTM output into this
            layout before calling).
        w_omega: Optional projection matrix of shape
            ``(hidden_size, attention_size)``. Defaults to zeros.
        u_omega: Optional context vector with ``w_omega.size(1)`` elements.
            Defaults to zeros.
        attention_size: Width of the default ``w_omega``. It is only used
            when ``w_omega`` is not supplied.

    Returns:
        Tensor of shape ``(batch, hidden_size)``: one pooled vector per
        sequence.

    Raises:
        ValueError: if ``lstm_output`` is not 3-D.

    Note:
        With the default all-zero weights every score is 0. The softmax is
        then uniform, so the result equals ``lstm_output.mean(dim=1)``. A
        plain function cannot own trainable weights. For learned attention,
        pass tensors registered as ``nn.Parameter`` on the enclosing module.
    """
    batch_size, sequence_length, hidden_size = lstm_output.shape
    # Default weights follow the input's dtype/device so GPU and
    # half-precision inputs work.
    if w_omega is None:
        w_omega = torch.zeros(hidden_size, attention_size,
                              dtype=lstm_output.dtype, device=lstm_output.device)
    if u_omega is None:
        u_omega = torch.zeros(w_omega.size(1),
                              dtype=lstm_output.dtype, device=lstm_output.device)

    # Flatten batch and time so one matmul scores every time step. Explicit
    # sizes (not -1) keep this valid for empty batches.
    output_reshape = lstm_output.reshape(batch_size * sequence_length, hidden_size)
    u = torch.tanh(torch.mm(output_reshape, w_omega))      # (batch*seq, attn)
    scores = torch.mm(u, u_omega.reshape(-1, 1))            # (batch*seq, 1)

    # Un-flatten before the softmax so weights are normalised over time
    # within each sequence, not over a singleton axis.
    scores = scores.reshape(batch_size, sequence_length)
    alphas = nn.functional.softmax(scores, dim=1)           # (batch, seq)

    # Input is already batch-first, so no permute is needed here.
    attn_output = torch.sum(lstm_output * alphas.unsqueeze(-1), dim=1)
    return attn_output


# Optionally pool the LSTM outputs with the attention layer.
# NOTE: 'attetion_mode' is misspelled; it is kept as-is because it must match
# the flag's definition, which lives outside this snippet.
if attetion_mode:
    # NOTE(review): assumes lstm_out is (seq_len, batch, hidden), nn.LSTM's
    # output layout when batch_first=False. Confirm against the model.
    # attention_net reads dim 1 as the sequence axis, hence this permute to
    # batch-first.
    lstm_out = lstm_out.permute(1, 0, 2)
    lstm_out = attention_net(lstm_out)

相关标签
About Me
XD
Goals determine what you are going to be.
Category
标签云
Baidu Tracking Password UNIX LoRA Template CLAP QWEN FP64 Knowledge CC Qwen2 继承 Anaconda Shortcut CUDA PDF v0.dev Card git Distillation llama.cpp PyTorch PIP Mixtral transformers NLTK Jetson VPN Interview Augmentation 图形思考法 Breakpoint COCO 阿里云 MD5 Bipartite XGBoost Hilton Bitcoin OpenAI PDB 腾讯云 Dataset SAM GPTQ Michelin Pandas 第一性原理 FastAPI Plate Animate GGML TensorFlow Heatmap Hungarian 搞笑 EXCEL InvalidArgumentError Paper Paddle SQLite Review HaggingFace Image2Text IndexTTS2 OpenCV Markdown Algorithm Github 多进程 GPT4 v2ray FP8 Tiktoken Ubuntu Hotel GIT Search JSON LLM Miniforge PyCharm Agent hf CSV SPIE uWSGI FlashAttention FP16 HuggingFace Safetensors LeetCode Attention Vmess Bin logger Ptyhon mmap CTC Conda Excel Permission Website Django 递归学习法 Python ModelScope 飞书 Color WebCrawler C++ SQL Bert 域名 Diagram Freesound Llama RGB Domain tar Cloudreve Random tqdm diffusers LLAMA Quantization Quantize 净利润 Translation Firewall Transformers NameSilo Linux Zip Pickle DeepStream CV printf TensorRT ChatGPT 财报 Pytorch DeepSeek uwsgi OCR AI Use 签证 Input VGG-16 报税 Qwen2.5 XML Data Proxy FP32 BeautifulSoup Qwen TSV NLP TTS Logo CEIR VSCode Video ResNet-50 证件照 多线程 Sklearn Jupyter SVR Base64 BF16 Web git-lfs Nginx Google RAR BTC Docker torchinfo Clash 版权 LaTeX GoogLeNet Claude Tensor scipy Food Windows Plotly WAN Pillow API Vim 算法题 UI 顶会 Math Disk Statistics Crawler Numpy Streamlit Land CAM Datetime Git 音频 Magnet ONNX YOLO 公式 Gemma 关于博主
站点统计

本站现有博文318篇,共被浏览749014

本站已经建立2400天!

热门文章
文章归档
回到顶部