Source code for mlbench_core.dataset.nlp.pytorch.wmt16.wmt16_tokenizer

import os
from collections import defaultdict
from functools import partial

import torch

from mlbench_core.dataset.nlp.pytorch.wmt16 import wmt16_config


def _pad_vocabulary(vocab, math):
    """
    Pads vocabulary so that its size is a multiple of `pad`.

    Args:
        vocab (list): list with vocabulary
        math (str): Math precision. Either `fp16`, `manual_fp16` or `fp32`

    Returns:
        list: padded vocabulary
    """
    if math == "fp16" or math == "manual_fp16":
        pad = 8
    elif math == "fp32":
        pad = 1
    else:
        raise NotImplementedError()

    vocab_size = len(vocab)
    padded_vocab_size = (vocab_size + pad - 1) // pad * pad
    for i in range(0, padded_vocab_size - vocab_size):
        token = f"madeupword{i:04d}"
        vocab.append(token)
    assert len(vocab) % pad == 0
    return vocab
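
# A minimal sketch of the padding behavior (not part of the original module):
# with fp16 math, pad = 8, so a 6-token vocabulary grows to 8 entries by
# appending filler tokens. Doctest-style, assuming this module is importable:
#
#     >>> _pad_vocabulary(["a", "b", "c", "d", "e", "f"], math="fp16")
#     ['a', 'b', 'c', 'd', 'e', 'f', 'madeupword0000', 'madeupword0001']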


class WMT16Tokenizer:
    """Tokenizer Class for WMT16 that uses the whole vocabulary

    Args:
        base_dir (str): Base directory for files
        math_precision (str): Math precision. Either `fp16`, `manual_fp16` or `fp32`
        separator (str): BPE separator token
    """

    def __init__(
        self,
        base_dir,
        math_precision=None,
        separator="@@",
    ):
        self.separator = separator
        vocab = [
            wmt16_config.PAD_TOKEN,
            wmt16_config.UNK_TOKEN,
            wmt16_config.BOS_TOKEN,
            wmt16_config.EOS_TOKEN,
        ]

        vocab_fname = os.path.join(base_dir, wmt16_config.VOCAB_FNAME)
        with open(vocab_fname, encoding="utf-8") as vfile:
            for line in vfile:
                vocab.append(line.strip())

        vocab = _pad_vocabulary(vocab, math_precision)
        self.vocab_size = len(vocab)

        # Out-of-vocabulary tokens default to the UNK index
        self.tok2idx = defaultdict(partial(int, wmt16_config.UNK))
        for idx, token in enumerate(vocab):
            self.tok2idx[token] = idx

        self.idx2tok = {}
        for key, value in self.tok2idx.items():
            self.idx2tok[value] = key
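    # Usage sketch (the data path is hypothetical): the base directory must
    # contain the vocabulary file named by wmt16_config.VOCAB_FNAME. With
    # fp16 padding, the resulting vocab_size is a multiple of 8.
    #
    #     >>> tokenizer = WMT16Tokenizer("/path/to/data", math_precision="fp16")
    #     >>> tokenizer.vocab_size % 8
    #     0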
    def segment(self, line):
        """
        Tokenizes a single sentence and adds the special BOS and EOS tokens.

        :param line: sentence
        :returns: list representing the tokenized sentence
        """
        line = line.strip().split()
        entry = [self.tok2idx[i] for i in line]
        entry = [wmt16_config.BOS] + entry + [wmt16_config.EOS]
        return entry
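    # Example sketch with hypothetical token ids: given the vocabulary order
    # above (PAD, UNK, BOS, EOS), BOS and EOS presumably sit at indices 2 and
    # 3. If the BPE tokens "he@@" and "llo" map to ids 104 and 2345, then:
    #
    #     >>> tokenizer.segment("he@@ llo")
    #     [2, 104, 2345, 3]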
    def detokenize(self, inputs, delim=" "):
        """
        Detokenizes a single sentence and removes token separator characters.

        :param inputs: sequence of token indices
        :param delim: tokenization delimiter
        :returns: string representing the detokenized sentence
        """
        detok = delim.join([self.idx2tok[idx] for idx in inputs])
        detok = detok.replace(self.separator + " ", "")
        detok = detok.replace(self.separator, "")

        detok = detok.replace(wmt16_config.BOS_TOKEN, "")
        detok = detok.replace(wmt16_config.EOS_TOKEN, "")
        detok = detok.replace(wmt16_config.PAD_TOKEN, "")
        detok = detok.strip()
        return detok
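
# Round-trip sketch (hypothetical tokenizer instance and ids): detokenize()
# drops the BOS/EOS/PAD tokens inserted by segment() and undoes BPE by
# removing the "@@ " separator, so an in-vocabulary sentence survives the
# segment/detokenize round trip:
#
#     >>> ids = tokenizer.segment("he@@ llo world")
#     >>> tokenizer.detokenize(ids)
#     'hello world'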