Source code for mlbench_core.models.pytorch.transformer.transformer

from torch import nn

from mlbench_core.models.pytorch.transformer.decoder import TransformerDecoder
from mlbench_core.models.pytorch.transformer.encoder import TransformerEncoder
from mlbench_core.models.pytorch.transformer.modules import build_embedding

DEFAULT_MAX_SOURCE_POSITIONS = 256
DEFAULT_MAX_TARGET_POSITIONS = 256


class TransformerModel(nn.Module):
    """Transformer model.

    This model uses MultiHeadAttention as described in :cite:`NIPS2017_7181`.

    Args:
        args: Arguments of the model. All arguments should be accessible via the
            `__getattribute__` method.
        src_dict (:obj:`mlbench_core.dataset.nlp.pytorch.wmt17.Dictionary`): Source dictionary
        trg_dict (:obj:`mlbench_core.dataset.nlp.pytorch.wmt17.Dictionary`): Target dictionary
    """

    def __init__(self, args, src_dict, trg_dict):
        super().__init__()
        self._is_generation_fast = False

        # Fall back to the module-level defaults when the caller does not set
        # explicit position limits.
        if not hasattr(args, "max_source_positions"):
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if not hasattr(args, "max_target_positions"):
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        # Define embedding layers. With `share_all_embeddings`, the encoder and
        # decoder reuse a single embedding matrix, which requires a joined
        # dictionary and matching embedding dimensions.
        if args.share_all_embeddings:
            if src_dict != trg_dict:
                raise ValueError("share_all_embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "share_all_embeddings requires encoder_embed_dim to match decoder_embed_dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "share_all_embeddings not compatible with decoder_embed_path"
                )
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = build_embedding(
                trg_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        self.encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
        self.decoder = TransformerDecoder(args, trg_dict, decoder_embed_tokens)

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        """Run the forward pass of the transformer model.

        Args:
            src_tokens (:obj:`torch.Tensor`): Source tokens
            src_lengths (:obj:`torch.Tensor`): Source sentence lengths
                (kept for interface compatibility; not used by the encoder here)
            prev_output_tokens (:obj:`torch.Tensor`): Previous output tokens

        Returns:
            (:obj:`torch.Tensor`, Optional[:obj:`torch.Tensor`]): The model
                output, and attention weights if needed
        """
        encoder_out = self.encoder(src_tokens)
        decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out)
        return decoder_out

    def max_positions(self):
        """Maximum lengths supported by the model.

        Returns:
            (int, int): Maximum source and target positions
        """
        return self.encoder.max_positions(), self.decoder.max_positions()

    def max_decoder_positions(self):
        """Maximum length supported by the decoder.

        Returns:
            (int)
        """
        return self.decoder.max_positions()
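
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). The snippet below is a
# minimal, runnable illustration of the weight tying that
# ``share_all_embeddings`` enables in ``TransformerModel.__init__`` above:
# encoder and decoder reuse one ``nn.Embedding``, so both sides read and
# update the same weight matrix. The vocabulary size, embedding dimension,
# and padding index are arbitrary values chosen for the demo.

if __name__ == "__main__":
    import torch

    vocab_size, embed_dim, pad_idx = 1000, 512, 1
    shared = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
    encoder_embed_tokens = shared
    decoder_embed_tokens = shared  # same object, hence same parameters

    tokens = torch.tensor([[4, 5, 6]])
    # Both lookups return identical activations backed by one weight tensor.
    assert encoder_embed_tokens(tokens).equal(decoder_embed_tokens(tokens))
    assert (
        encoder_embed_tokens.weight.data_ptr()
        == decoder_embed_tokens.weight.data_ptr()
    )

    # Instantiating the full model additionally requires the encoder/decoder
    # hyperparameters (layer counts, attention heads, dropout, ...) on
    # ``args``, plus WMT17 ``Dictionary`` instances. With those in place, a
    # forward pass would look like (hypothetical):
    #
    #     model = TransformerModel(args, src_dict, trg_dict)
    #     decoder_out = model(src_tokens, src_lengths, prev_output_tokens)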