Source code for mlbench_core.models.pytorch.transformer.decoder

import math

import torch
import torch.nn.functional as F
from torch import nn

from mlbench_core.models.pytorch.transformer.modules import (
    PositionalEmbedding,
    TransformerDecoderLayer,
)


class TransformerDecoder(nn.Module):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args: Arguments of model. All arguments should be accessible via the
            `__getattribute__` method.
        dictionary (:obj:`mlbench_core.dataset.nlp.pytorch.wmt17.Dictionary`):
            decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to omit attention over
            encoder outputs (default: False).
        left_pad (bool): Pad targets to the left (`True`) or right (`False`).
            Default: `False`
    """

    def __init__(
        self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False
    ):
        super().__init__()
        self.dictionary = dictionary
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed

        embed_dim = embed_tokens.embedding_dim
        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)

        self.embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                left_pad=left_pad,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        self.layers = nn.ModuleList(
            [
                TransformerDecoderLayer(args, no_encoder_attn)
                for _ in range(args.decoder_layers)
            ]
        )

        if not self.share_input_output_embed:
            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)

        self.normalize = args.decoder_normalize_before
        if self.normalize:
            self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
        # embed positions
        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        # during incremental decoding, only the last generated token is fed in
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        if x.size(1) == 1:
            if x.is_contiguous():
                x = x.view(x.size(0), x.size(1), x.size(2))
            else:
                x = x.contiguous()
        else:
            x = x.contiguous()
        attn = None

        # decoder layers
        for layer in self.layers:
            x, attn = layer(
                x,
                encoder_out["encoder_out"] if encoder_out is not None else None,
                encoder_out["encoder_padding_mask"]
                if encoder_out is not None
                else None,
                incremental_state,
            )

        if self.normalize:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # project back to size of vocabulary
        if self.share_input_output_embed:
            x = F.linear(x, self.embed_tokens.weight)
        else:
            x = F.linear(x, self.embed_out)

        return x, attn

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions())

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder incremental state.

        This should be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.
        """

        def apply_reorder_incremental_state(module):
            if module != self and hasattr(module, "reorder_incremental_state"):
                module.reorder_incremental_state(
                    incremental_state,
                    new_order,
                )

        self.apply(apply_reorder_incremental_state)