EADST

Attention Net with PyTorch

In an NLP task, an attention layer can be placed after the LSTM to pool the per-time-step outputs into a single weighted vector per sequence.
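As a rough summary of what the code below is meant to compute (the symbols h_t for the LSTM output at time step t, T for the sequence length, and s for the pooled vector are notation introduced here, not names from the post):

```latex
u_t = \tanh\left(h_t W_\omega\right), \qquad
\alpha_t = \frac{\exp\left(u_t \cdot u_\omega\right)}{\sum_{k=1}^{T} \exp\left(u_k \cdot u_\omega\right)}, \qquad
s = \sum_{t=1}^{T} \alpha_t \, h_t
```

Here W_\omega and u_\omega correspond to `w_omega` and `u_omega` in the code, and the softmax is taken over the time steps of each sequence.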

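import torch
import torch.nn as nn

# attention layer
def attention_net(lstm_output, attention_size=2):
    # lstm_output: (batch, sequence_length, hidden_size)
    sequence_length, hidden_size = lstm_output.size(1), lstm_output.size(2)
    # Zero weights as in the original snippet; they yield uniform attention.
    # In a real model, register them as nn.Parameter so they can be trained.
    w_omega = lstm_output.new_zeros(hidden_size, attention_size)
    u_omega = lstm_output.new_zeros(attention_size)
    # Score every time step: (batch * sequence_length, 1)
    output_reshape = lstm_output.reshape(-1, hidden_size)
    u = torch.tanh(torch.mm(output_reshape, w_omega))
    attn_hidden_layer = torch.mm(u, u_omega.reshape(-1, 1))
    # Normalize the scores over the time steps of each sequence
    scores = attn_hidden_layer.reshape(-1, sequence_length)
    alphas = nn.functional.softmax(scores, dim=1)
    alphas_reshape = alphas.reshape(-1, sequence_length, 1)
    # Weighted sum over time; the input is already batch-first, so no permute
    attn_output = torch.sum(lstm_output * alphas_reshape, 1)
    return attn_output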


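# add attention layer after lstm
# lstm_out is assumed to be (sequence_length, batch, hidden_size), the default nn.LSTM layout
if attention_mode:
    lstm_out = lstm_out.permute(1, 0, 2)  # -> (batch, sequence_length, hidden_size)
    lstm_out = attention_net(lstm_out)    # -> (batch, hidden_size)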
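
The function above creates its weights as zeros on every call, so it cannot learn which time steps matter. A minimal sketch of the same computation as a trainable module might look like the following. The class name `AttentionPool`, the random initialization, and the toy dimensions are assumptions made for illustration and are not part of the original post.

```python
import torch
import torch.nn as nn


class AttentionPool(nn.Module):
    """Hypothetical trainable variant of attention_net (names are illustrative)."""

    def __init__(self, hidden_size, attention_size=2):
        super().__init__()
        # Small random init (an assumption) so the scores differ from the start.
        self.w_omega = nn.Parameter(torch.randn(hidden_size, attention_size) * 0.1)
        self.u_omega = nn.Parameter(torch.randn(attention_size) * 0.1)

    def forward(self, lstm_output):
        # lstm_output: (batch, seq_len, hidden_size)
        u = torch.tanh(lstm_output @ self.w_omega)   # (batch, seq_len, attention_size)
        scores = u @ self.u_omega                    # (batch, seq_len)
        alphas = torch.softmax(scores, dim=1)        # weights over the time steps
        # Weighted sum over the time dimension -> (batch, hidden_size)
        return torch.sum(lstm_output * alphas.unsqueeze(-1), dim=1)


# Toy dimensions, chosen only for this demo.
seq_len, batch_size, embed_dim, hidden_size = 7, 4, 16, 300
lstm = nn.LSTM(embed_dim, hidden_size)               # default layout: (seq_len, batch, feature)
attention = AttentionPool(hidden_size)

x = torch.randn(seq_len, batch_size, embed_dim)
lstm_out, _ = lstm(x)                                # (seq_len, batch, hidden_size)
pooled = attention(lstm_out.permute(1, 0, 2))        # batch-first for the attention layer
print(pooled.shape)                                  # torch.Size([4, 300])
```

Because `w_omega` and `u_omega` are registered as `nn.Parameter`, they appear in `attention.parameters()` and are updated by the optimizer together with the LSTM weights.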
