# Source code for mlbench_core.utils.pytorch.inference.beam_search

# Taken from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Translation/GNMT
import torch

import mlbench_core.dataset.nlp.pytorch.wmt16.wmt16_config as wmt16_config


class SequenceGenerator:
    """Generator for autoregressive inference with beam search decoding.

    Beam search decoding supports coverage penalty and length
    normalization. For details, refer to Section 7 of the GNMT paper
    (https://arxiv.org/pdf/1609.08144.pdf).

    Args:
        model: model which implements generate method
        beam_size (int): decoder beam size
        max_seq_len (int): maximum decoder sequence length
        len_norm_factor (float): length normalization factor
        len_norm_const (float): length normalization constant
        cov_penalty_factor (float): coverage penalty factor
    """

    def __init__(
        self,
        model,
        beam_size=5,
        max_seq_len=100,
        len_norm_factor=0.6,
        len_norm_const=5,
        cov_penalty_factor=0.1,
    ):
        # Only stores the model handle and decoding hyper-parameters;
        # no computation happens at construction time.
        self.model = model
        self.beam_size = beam_size
        self.max_seq_len = max_seq_len
        self.len_norm_factor = len_norm_factor
        self.len_norm_const = len_norm_const
        self.cov_penalty_factor = cov_penalty_factor