EADST

Attention Net with PyTorch

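In NLP tasks, an attention layer can be placed after the LSTM. It assigns a weight to the hidden state at each time step and combines the hidden states into a single weighted-sum vector, which the rest of the model (for example, a classifier head) can then use.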
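Written out, the attention step is meant to compute the following. This is a sketch of the intended computation: $h_t \in \mathbb{R}^d$ is the LSTM hidden state at time step $t$, $d$ is `hidden_size`, and $W_\omega \in \mathbb{R}^{d \times a}$ and $u_\omega \in \mathbb{R}^{a}$ correspond to `w_omega` and `u_omega` (attention size $a = 2$ in the code). The softmax normalizes the scores over the $T$ time steps, so the weights $\alpha_t$ sum to 1 and $c$ is a weighted average of the hidden states.

```latex
u_t = \tanh\left(W_\omega^\top h_t\right) \in \mathbb{R}^{a}, \qquad
e_t = u_\omega^\top u_t, \qquad
\alpha_t = \frac{\exp(e_t)}{\sum_{k=1}^{T} \exp(e_k)}, \qquad
c = \sum_{t=1}^{T} \alpha_t h_t
```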

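import torch
import torch.nn as nn
import torch.nn.functional as F


# attention layer: scores each time step of the LSTM output and
# returns the attention-weighted sum of the hidden states
class AttentionNet(nn.Module):
    def __init__(self, hidden_size=300, attention_size=2):
        super().__init__()
        # learnable weights with small random init (all-zero weights would give
        # uniform attention and zero gradients, so the layer could never learn)
        self.w_omega = nn.Parameter(torch.randn(hidden_size, attention_size) * 0.1)
        self.u_omega = nn.Parameter(torch.randn(attention_size) * 0.1)

    def forward(self, lstm_output):
        # lstm_output: (batch, seq_len, hidden_size)
        u = torch.tanh(torch.matmul(lstm_output, self.w_omega))  # (batch, seq_len, attention_size)
        scores = torch.matmul(u, self.u_omega)                    # (batch, seq_len)
        alphas = F.softmax(scores, dim=1)                         # normalize over time steps
        # weight each hidden state and sum over the sequence dimension
        attn_output = torch.sum(lstm_output * alphas.unsqueeze(-1), dim=1)
        return attn_output                                        # (batch, hidden_size)


# add attention layer after lstm
# lstm_out comes from nn.LSTM with the default batch_first=False: (seq_len, batch, hidden_size)
# attention_mode is a boolean flag from the model config
# in a real model, create attention_net in __init__ so its parameters are trained
attention_net = AttentionNet(hidden_size=300)
if attention_mode:
    lstm_out = lstm_out.permute(1, 0, 2)  # -> (batch, seq_len, hidden_size)
    lstm_out = attention_net(lstm_out)    # -> (batch, hidden_size)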
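A minimal, self-contained sketch of how attention after an LSTM can fit into a complete text classifier. The class name `LSTMAttentionClassifier`, the vocabulary size, embedding size and number of classes are hypothetical values chosen for illustration. It uses `batch_first=True` on `nn.LSTM`, so the LSTM output is already `(batch, seq_len, hidden_size)` and no `permute` is needed; the attention weights are returned as well so they can be inspected.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMAttentionClassifier(nn.Module):
    """Hypothetical example model: embedding -> LSTM -> attention -> linear."""

    def __init__(self, vocab_size=1000, embed_dim=100, hidden_size=300,
                 attention_size=2, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # batch_first=True: LSTM output is (batch, seq_len, hidden_size)
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        # learnable attention parameters (small random init)
        self.w_omega = nn.Parameter(torch.randn(hidden_size, attention_size) * 0.1)
        self.u_omega = nn.Parameter(torch.randn(attention_size) * 0.1)
        self.fc = nn.Linear(hidden_size, num_classes)

    def attention(self, lstm_out):
        u = torch.tanh(lstm_out @ self.w_omega)      # (batch, seq_len, attention_size)
        alphas = F.softmax(u @ self.u_omega, dim=1)  # (batch, seq_len), sums to 1 over time
        context = (lstm_out * alphas.unsqueeze(-1)).sum(dim=1)  # (batch, hidden_size)
        return context, alphas

    def forward(self, token_ids):
        lstm_out, _ = self.lstm(self.embedding(token_ids))
        context, alphas = self.attention(lstm_out)
        return self.fc(context), alphas


torch.manual_seed(0)
model = LSTMAttentionClassifier()
tokens = torch.randint(0, 1000, (8, 20))  # dummy batch: 8 sentences of 20 token ids
logits, alphas = model(tokens)
print(logits.shape)  # torch.Size([8, 2])
print(alphas.shape)  # torch.Size([8, 20])
```

Returning `alphas` alongside the logits is a design choice for inspection: it lets you see which tokens the model weighted most for each prediction.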
