Source code for mlbench_core.models.pytorch.transformer.decoder
import math
import torch
import torch.nn.functional as F
from torch import nn
from mlbench_core.models.pytorch.transformer.modules import (
PositionalEmbedding,
TransformerDecoderLayer,
)


class TransformerDecoder(nn.Module):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args: Arguments of model. All arguments should be accessible via `__getattribute__` method
dictionary (:obj:`mlbench_core.dataset.nlp.pytorch.wmt17.Dictionary`): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
left_pad (bool): Pad targets to the left (`True`) or right (`False`). Default: `False`
"""
def __init__(
self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False
):
super().__init__()
self.dictionary = dictionary
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
embed_dim = embed_tokens.embedding_dim
padding_idx = embed_tokens.padding_idx
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = (
PositionalEmbedding(
args.max_target_positions,
embed_dim,
padding_idx,
left_pad=left_pad,
learned=args.decoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
self.layers = nn.ModuleList(
[
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
]
)
if not self.share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
self.normalize = args.decoder_normalize_before
if self.normalize:
self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
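        """
        Args:
            prev_output_tokens (torch.Tensor): Previous decoder outputs of shape
                `(batch, tgt_len)`, used as decoder inputs
            encoder_out (dict, optional): Output of the encoder, with keys
                `encoder_out` and `encoder_padding_mask`
            incremental_state (dict, optional): Cache used for fast step-by-step
                generation; when given, only the last token of
                `prev_output_tokens` is processed

        Returns:
            (torch.Tensor, torch.Tensor): The decoder output of shape
            `(batch, tgt_len, vocab_size)` and the attention weights of the
            last decoder layer
        """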
# embed positions
positions = (
self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
)
if self.embed_positions is not None
else None
)
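        # during incremental (step-by-step) generation only the newest token
        # needs to be fed through the layers; previous time steps are handled
        # via ``incremental_state``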
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if positions is not None:
x += positions
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
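        # make sure the (now time-major) tensor is contiguous before it enters
        # the decoder layers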
if x.size(1) == 1:
if x.is_contiguous():
x = x.view(x.size(0), x.size(1), x.size(2))
else:
x = x.contiguous()
else:
x = x.contiguous()
attn = None
# decoder layers
for layer in self.layers:
x, attn = layer(
x,
encoder_out["encoder_out"] if encoder_out is not None else None,
encoder_out["encoder_padding_mask"]
if encoder_out is not None
else None,
incremental_state,
)
if self.normalize:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
# project back to size of vocabulary
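        # (the input embedding matrix is reused when
        # ``share_decoder_input_output_embed`` is set)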
if self.share_input_output_embed:
x = F.linear(x, self.embed_tokens.weight)
else:
x = F.linear(x, self.embed_out)
return x, attn

    def max_positions(self):
"""Maximum output length supported by the decoder."""
if self.embed_positions is None:
return self.max_target_positions
return min(self.max_target_positions, self.embed_positions.max_positions())

    def reorder_incremental_state(self, incremental_state, new_order):
"""Reorder incremental state.
This should be called when the order of the input has changed from the
previous time step. A typical use case is beam search, where the input
order changes between time steps based on the selection of beams.
"""
def apply_reorder_incremental_state(module):
if module != self and hasattr(module, "reorder_incremental_state"):
module.reorder_incremental_state(
incremental_state,
new_order,
)

        self.apply(apply_reorder_incremental_state)
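

# The helper below is not part of the library: it is a minimal, hypothetical
# greedy-decoding sketch showing how the decoder above is typically driven with
# an ``incremental_state`` dict. It only relies on what is visible in this file
# (the ``forward`` signature, the ``encoder_out`` / ``encoder_padding_mask``
# keys and the ``(batch, tgt_len, vocab)`` shaped output). The helper name, the
# ``bos_idx`` / ``eos_idx`` arguments and the assumption that the encoder
# output is laid out as ``T x B x C`` are illustrative assumptions, not library
# API. During beam search one would additionally call
# ``decoder.reorder_incremental_state`` whenever the beam order changes.
def _greedy_decode_sketch(decoder, encoder_out, bos_idx, eos_idx, max_len=50):
    # assumed layout: encoder_out["encoder_out"] is T x B x C, so dim 1 is batch
    device = encoder_out["encoder_out"].device
    batch_size = encoder_out["encoder_out"].size(1)
    tokens = torch.full((batch_size, 1), bos_idx, dtype=torch.long, device=device)
    incremental_state = {}
    for _ in range(max_len):
        # forward() keeps only the last token when an incremental_state is
        # given, so the growing prefix can be passed unchanged at every step
        logits, _ = decoder(
            tokens, encoder_out=encoder_out, incremental_state=incremental_state
        )
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
        if (next_token == eos_idx).all():
            break
    return tokens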