Attention Net with PyTorch
Author: XD / Published: December 29, 2020 07:47 / Updated: December 29, 2020 07:49 / Programming Notes / Views: 2868
In NLP tasks, an attention layer can be placed after the LSTM: it scores the hidden state of every time step, normalizes the scores with a softmax, and returns the weighted sum of the hidden states as a fixed-size representation of the sequence.
import torch
import torch.nn as nn
import torch.nn.functional as F

# attention layer: lstm_output is expected as (batch_size, seq_len, hidden_size)
def attention_net(lstm_output):
    hidden_size = 300      # must match the hidden size of the LSTM
    attention_size = 2     # size of the attention projection space
    # Projection parameters. In a real model these should be randomly initialized
    # nn.Parameters registered on the module so they are trained with the network;
    # here they are plain tensors, as in the original snippet.
    w_omega = torch.zeros(hidden_size, attention_size)
    u_omega = torch.zeros(attention_size, 1)

    batch_size, sequence_length, _ = lstm_output.size()
    # score each time step: u_t = tanh(h_t W), score_t = u_t . u
    output_reshape = lstm_output.reshape(-1, hidden_size)            # (batch*seq, hidden)
    u = torch.tanh(torch.mm(output_reshape, w_omega))                # (batch*seq, attn)
    attn_scores = torch.mm(u, u_omega).reshape(batch_size, sequence_length)
    # normalize the scores over the time dimension
    alphas = F.softmax(attn_scores, dim=1)                           # (batch, seq)
    alphas_reshape = alphas.reshape(batch_size, sequence_length, 1)
    # weighted sum of the LSTM hidden states
    attn_output = torch.sum(lstm_output * alphas_reshape, 1)         # (batch, hidden)
    return attn_output
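A quick shape check of the function above; the batch size and sequence length below are made-up values for illustration, not from the original post.

    # made-up sizes: batch=4, seq_len=10, hidden=300
    dummy = torch.randn(4, 10, 300)
    print(attention_net(dummy).shape)   # torch.Size([4, 300])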
# add the attention layer after the LSTM
if attention_mode:
    # nn.LSTM with batch_first=False returns (seq_len, batch, hidden);
    # attention_net expects (batch, seq_len, hidden)
    lstm_out = lstm_out.permute(1, 0, 2)
    lstm_out = attention_net(lstm_out)
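For context, here is a minimal sketch of how the snippet above could be wired into a text classifier. The TextClassifier class, the vocabulary size, embedding size, and number of classes are illustrative assumptions, not part of the original post; only the hidden size of 300 is taken from attention_net above.

    import torch
    import torch.nn as nn

    class TextClassifier(nn.Module):
        # assumed sizes; hidden_size must stay 300 to match attention_net
        def __init__(self, vocab_size=10000, embed_size=128, hidden_size=300, num_classes=2):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, embed_size)
            self.lstm = nn.LSTM(embed_size, hidden_size)       # batch_first=False by default
            self.fc = nn.Linear(hidden_size, num_classes)

        def forward(self, x):                                  # x: (seq_len, batch) token ids
            embedded = self.embedding(x)                       # (seq_len, batch, embed)
            lstm_out, _ = self.lstm(embedded)                  # (seq_len, batch, hidden)
            lstm_out = lstm_out.permute(1, 0, 2)               # (batch, seq_len, hidden)
            attn_out = attention_net(lstm_out)                 # (batch, hidden)
            return self.fc(attn_out)                           # (batch, num_classes)

    model = TextClassifier()
    tokens = torch.randint(0, 10000, (20, 4))                  # seq_len=20, batch=4
    logits = model(tokens)
    print(logits.shape)                                        # torch.Size([4, 2])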