Source code for mlbench_core.models.pytorch.gnmt.models

import torch.nn as nn
from torch.nn.functional import log_softmax

from mlbench_core.dataset.nlp.pytorch.wmt16 import wmt16_config
from mlbench_core.models.pytorch.gnmt.decoder import ResidualRecurrentDecoder
from mlbench_core.models.pytorch.gnmt.encoder import ResidualRecurrentEncoder


class Seq2Seq(nn.Module):
    """
    Generic Seq2Seq module, with an encoder and a decoder.
    Args:
        encoder (Encoder): Model encoder
        decoder (Decoder): Model decoder
    """

    def __init__(self, encoder=None, decoder=None):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def encode(self, inputs, lengths):
        """
        Applies the encoder to inputs with given input sequence lengths.

        Args:
            inputs (torch.tensor): tensor with inputs (seq_len, batch)
            lengths: vector with sequence lengths (excluding padding)

        Returns:
            torch.tensor
        """
        return self.encoder(inputs, lengths)

    def decode(self, inputs, context, inference=False):
        """
        Applies the decoder to inputs, given the context from the encoder.

        Args:
            inputs (torch.tensor): tensor with inputs (seq_len, batch)
            context: context from the encoder
            inference: if True, inference mode; if False, training mode

        Returns:
            torch.tensor
        """
        return self.decoder(inputs, context, inference)

    def generate(self, inputs, context, beam_size):
        """
        Autoregressive generator, works with the SequenceGenerator class.
        Runs the decoder in inference mode, then applies log_softmax and topK
        for beam search decoding.

        Args:
            inputs: tensor with inputs to the decoder
            context: context from the encoder
            beam_size: beam size for the generator

        Returns:
            (words, logprobs, scores, new_context)
            words: indices of topK tokens
            logprobs: log probabilities of topK tokens
            scores: scores from the attention module (for coverage penalty)
            new_context: new decoder context, includes new hidden states for
            decoder RNN cells
        """
        logits, scores, new_context = self.decode(inputs, context, True)
        logprobs = log_softmax(logits, dim=-1)
        logprobs, words = logprobs.topk(beam_size, dim=-1)
        return words, logprobs, scores, new_context


class GNMT(Seq2Seq):
    """
    GNMT v2 model

    Args:
        vocab_size (int): size of vocabulary (number of tokens)
        hidden_size (int): internal hidden size of the model
        num_layers (int): number of layers, applies to both encoder and
            decoder
        dropout (float): probability of dropout (in encoder and decoder)
        share_embedding (bool): if True, embeddings are shared between
            encoder and decoder
        fusion (bool): fusion flag, forwarded to the decoder
    """

    def __init__(
        self,
        vocab_size,
        hidden_size=1024,
        num_layers=4,
        dropout=0.2,
        share_embedding=True,
        fusion=True,
    ):
        super(GNMT, self).__init__()

        if share_embedding:
            embedder = nn.Embedding(
                vocab_size, hidden_size, padding_idx=wmt16_config.PAD
            )
            nn.init.uniform_(embedder.weight.data, -0.1, 0.1)
        else:
            embedder = None

        self.encoder = ResidualRecurrentEncoder(
            vocab_size, hidden_size, num_layers, dropout, embedder
        )

        self.decoder = ResidualRecurrentDecoder(
            vocab_size, hidden_size, num_layers, dropout, embedder, fusion=fusion
        )

    def forward(self, input_encoder, input_enc_len, input_decoder):
        context = self.encode(input_encoder, input_enc_len)
        context = (context, input_enc_len, None)
        output, _, _ = self.decode(input_decoder, context)

        return output
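

# ---------------------------------------------------------------------------
# Usage sketch (illustrative): one way the GNMT model above might be
# instantiated and run for a single training-style forward pass. The
# vocabulary size, sequence length and batch size are assumptions chosen only
# to demonstrate the (seq_len, batch) tensor layout documented in
# Seq2Seq.encode / Seq2Seq.decode; default hyperparameters are used.
if __name__ == "__main__":
    import torch

    vocab_size = 32000  # assumed vocabulary size
    seq_len, batch_size = 10, 4  # illustrative dimensions

    model = GNMT(vocab_size=vocab_size)

    # Source / target token indices, shaped (seq_len, batch).
    src = torch.randint(0, vocab_size, (seq_len, batch_size))
    tgt = torch.randint(0, vocab_size, (seq_len, batch_size))
    # Source lengths (excluding padding), one entry per batch element.
    src_len = torch.full((batch_size,), seq_len, dtype=torch.int64)

    output = model(src, src_len, tgt)
    # With the layout above, the output is expected to be logits of shape
    # (seq_len, batch, vocab_size).
    print(output.shape)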