import itertools
import torch
import torch.nn as nn
import mlbench_core.dataset.nlp.pytorch.wmt16.wmt16_config as config
from mlbench_core.models.pytorch.gnmt.attention import BahdanauAttention
from mlbench_core.models.pytorch.gnmt.utils import init_lstm_
class RecurrentAttention(nn.Module):
    """
    LSTM followed by a Bahdanau attention module.

    Args:
        input_size (int): number of features in input tensor
        context_size (int): number of features in output from encoder
        hidden_size (int): internal hidden size
        num_layers (int): number of layers in LSTM
        dropout (float): probability of dropout (on input to LSTM layer)
        init_weight (float): value handed to ``init_lstm_`` for the LSTM
            (documented upstream as the range of the uniform initializer)
        fusion (bool): forwarded unchanged to ``BahdanauAttention``
    """

    def __init__(
        self,
        input_size=1024,
        context_size=1024,
        hidden_size=1024,
        num_layers=1,
        dropout=0.2,
        init_weight=0.1,
        fusion=True,
    ):
        super().__init__()

        # Recurrent part; built with batch_first=False.
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=True,
            batch_first=False,
        )
        init_lstm_(self.rnn, init_weight)

        # Attention part; ``context_size`` is used for both the second and
        # the third positional argument of BahdanauAttention.
        self.attn = BahdanauAttention(
            hidden_size, context_size, context_size, normalize=True, fusion=fusion
        )

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, inputs, hidden, context, context_len):
        """
        Run dropout -> LSTM -> attention on one batch.

        Args:
            inputs: tensor with inputs to the LSTM
            hidden: hidden state for the LSTM layer (passed straight to
                ``nn.LSTM``)
            context: context tensor from encoder
            context_len: vector of encoder sequence lengths

        Returns:
            (rnn_outputs, hidden, attn_outputs, scores), where ``hidden`` is
            the updated LSTM state and the last two items are what the
            attention module returns for ``rnn_outputs`` and ``context``
        """
        # Sequences have different lengths; the mask lets the attention
        # softmax consider only the valid elements of the context.
        self.attn.set_mask(context_len, context)

        rnn_outputs, hidden = self.rnn(self.dropout(inputs), hidden)
        attn_outputs, scores = self.attn(rnn_outputs, context)

        return rnn_outputs, hidden, attn_outputs, scores
class Classifier(nn.Module):
    """
    Single fully-connected projection layer.

    Args:
        in_features (int): number of input features
        out_features (int): number of output features (size of vocabulary)
        init_weight (float): weight and bias are drawn uniformly from
            ``[-init_weight, init_weight]``
    """

    def __init__(self, in_features, out_features, init_weight=0.1):
        super().__init__()
        self.classifier = nn.Linear(in_features, out_features)

        # Re-initialize weight first, then bias, from the same uniform range.
        linear = self.classifier
        for param in (linear.weight, linear.bias):
            nn.init.uniform_(param.data, a=-init_weight, b=init_weight)

    def forward(self, x):
        """
        Apply the linear projection.

        Args:
            x (torch.tensor): input whose last dimension has
                ``in_features`` entries

        Returns:
            torch.tensor: output of the linear layer
        """
        return self.classifier(x)
class ResidualRecurrentDecoder(nn.Module):
    """
    Decoder with Embedding, LSTM layers, attention, residual connections and
    optional dropout.

    Attention implemented in this module is different than the attention
    discussed in the GNMT arxiv paper. In this model the output from the first
    LSTM layer of the decoder goes into the attention module, then the
    re-weighted context is concatenated with inputs to all subsequent LSTM
    layers in the decoder at the current timestep.

    Residual connections are applied around the third and all following LSTM
    layers; dropout is applied on inputs to LSTM layers.

    Args:
        vocab_size (int): size of vocabulary
        hidden_size (int): hidden size for LSTM layers
        num_layers (int): number of LSTM layers, counting the attention
            LSTM. ``forward`` always runs the attention LSTM plus at least
            one further layer, so values below 2 fail at call time.
        dropout (float): probability of dropout (on input to LSTM layers)
        embedder (nn.Embedding): if None constructor will create new
            embedding layer
        init_weight (float): range for the uniform initializer; applied to
            the stacked LSTM layers and to a newly created embedding
        fusion (bool): forwarded unchanged to ``RecurrentAttention``
    """

    def __init__(
        self,
        vocab_size,
        hidden_size=1024,
        num_layers=4,
        dropout=0.2,
        embedder=None,
        init_weight=0.1,
        fusion=True,
    ):
        super().__init__()

        self.num_layers = num_layers

        # Book-keeping used by init_hidden / append_hidden / package_hidden.
        # Initialized here so those public helpers are safe to call before
        # the first forward() instead of raising AttributeError.
        self.inference = False
        self.next_hidden = []

        # First decoder layer: LSTM + attention.
        # NOTE(review): ``init_weight`` is not forwarded here (nor to the
        # Classifier below), so both use their own default -- confirm that
        # this is intended before changing it.
        self.att_rnn = RecurrentAttention(
            hidden_size,
            hidden_size,
            hidden_size,
            num_layers=1,
            dropout=dropout,
            fusion=fusion,
        )

        # Remaining layers take [previous output ; attention output], hence
        # the doubled input size.
        self.rnn_layers = nn.ModuleList(
            nn.LSTM(
                2 * hidden_size,
                hidden_size,
                num_layers=1,
                bias=True,
                batch_first=False,
            )
            for _ in range(num_layers - 1)
        )
        for lstm in self.rnn_layers:
            init_lstm_(lstm, init_weight)

        if embedder is not None:
            self.embedder = embedder
        else:
            self.embedder = nn.Embedding(
                vocab_size, hidden_size, padding_idx=config.PAD
            )
            nn.init.uniform_(self.embedder.weight.data, -init_weight, init_weight)

        self.classifier = Classifier(hidden_size, vocab_size)
        self.dropout = nn.Dropout(p=dropout)

    def init_hidden(self, hidden):
        """
        Converts flattened hidden state (from sequence generator) into a tuple
        of hidden states and resets the list of collected next states.

        Args:
            hidden: None or flattened hidden state for decoder RNN layers

        Returns:
            A tuple with one ``(h, c)`` pair per layer when ``hidden`` is
            given, otherwise a list of ``num_layers`` ``None`` entries.
        """
        if hidden is not None:
            # per-layer chunks
            per_layer = hidden.chunk(self.num_layers)
            # (h, c) chunks for LSTM layer
            hidden = tuple(layer_state.chunk(2) for layer_state in per_layer)
        else:
            hidden = [None] * self.num_layers

        self.next_hidden = []
        return hidden

    def append_hidden(self, h):
        """
        Appends the hidden vector h to the list of internal hidden states.
        Does nothing unless the module is in inference mode.

        Args:
            h: hidden vector
        """
        if self.inference:
            self.next_hidden.append(h)

    def package_hidden(self):
        """
        Flattens the hidden state from all LSTM layers into one tensor (for
        the sequence generator).

        Returns:
            The concatenation of every collected state tensor in inference
            mode, otherwise None.
        """
        if not self.inference:
            return None
        # next_hidden holds one (h, c) pair per layer; flatten the pairs and
        # concatenate along the first dimension (torch.cat default).
        return torch.cat(tuple(itertools.chain.from_iterable(self.next_hidden)))

    def forward(self, inputs, context, inference=False):
        """
        Execute the decoder.

        Args:
            inputs: tensor with inputs to the decoder
            context: state of encoder, encoder sequence lengths and hidden
                state of decoder's LSTM layers
            inference: if True stores and repackages hidden state

        Returns:
            A tuple ``(logits, scores, new_context)`` where ``logits`` is the
            classifier output, ``scores`` are the attention scores returned
            by the attention layer and ``new_context`` is the list
            ``[enc_context, enc_len, hidden]`` with ``hidden`` being the
            repackaged decoder state (None when ``inference`` is False).
        """
        self.inference = inference

        enc_context, enc_len, hidden = context
        hidden = self.init_hidden(hidden)

        x = self.embedder(inputs)

        # Layer 1: LSTM + attention over the encoder context.
        x, h, attn, scores = self.att_rnn(x, hidden[0], enc_context, enc_len)
        self.append_hidden(h)

        # Layer 2: attention output is concatenated on the feature
        # dimension; no residual connection yet.
        x = torch.cat((x, attn), dim=2)
        x = self.dropout(x)
        x, h = self.rnn_layers[0](x, hidden[1])
        self.append_hidden(h)

        # Layers 3..N: same as layer 2 plus a residual connection.
        for i in range(1, len(self.rnn_layers)):
            residual = x
            x = torch.cat((x, attn), dim=2)
            x = self.dropout(x)
            x, h = self.rnn_layers[i](x, hidden[i + 1])
            self.append_hidden(h)
            x = x + residual

        x = self.classifier(x)
        hidden = self.package_hidden()

        return x, scores, [enc_context, enc_len, hidden]