# Source code for mlbench_core.dataset.nlp.pytorch.wmt16_dataset

import os
from copy import deepcopy

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from mlbench_core.dataset.util.tools import maybe_download_and_extract_tar_gz

from .wmt16 import (
    WMT16Tokenizer,
    filter_data,
    filter_raw_data,
    get_data_dtype,
    process_data,
    process_raw_data,
    wmt16_config,
)


class WMT16Dataset(Dataset):
    """Dataset for WMT16 en to de translation.

    Supports three loading modes, selected by the constructor flags:

    * eager (default): both sides are tokenized up front via ``process_data``;
    * ``lazy=True``: raw sentences are kept and tokenized per item in
      ``__getitem__``;
    * ``preprocessed=True``: samples are read on demand from a packed binary
      file written by :meth:`write_as_preprocessed` (``prepare`` must be
      called before indexing, since ``__init__`` only records the path).

    Args:
        root (str): Root folder where to download files
        lang (tuple): Language translation pair, used as file suffixes
            (e.g. ``train.en`` / ``train.de``)
        math_precision (str): One of `fp16` or `fp32`. The precision used
            during training (forwarded to the tokenizer)
        download (bool): Download the dataset from source
        train (bool): Load train set
        validation (bool): Load validation set
        lazy (bool): Load the dataset in a lazy format
        preprocessed (bool): Read samples from the packed ``.bin`` file
            instead of the text files
        sort (bool): Sort samples by descending source length at init
        min_len (int): Minimum sentence length
        max_len (int | None): Maximum sentence length
        max_size (int | None): Maximum dataset size
    """

    # (remote archive URL, local archive file name) pairs
    urls = [
        (
            "https://storage.googleapis.com/mlbench-datasets/translation/wmt16_en_de.tar.gz",
            "wmt16_en_de.tar.gz",
        )
    ]

    def __init__(
        self,
        root,
        lang=("en", "de"),
        math_precision=None,
        download=True,
        train=False,
        validation=False,
        lazy=False,
        preprocessed=False,
        sort=False,
        min_len=0,
        max_len=None,
        max_size=None,
    ):
        super(WMT16Dataset, self).__init__()
        if download:
            url, file_name = self.urls[0]
            maybe_download_and_extract_tar_gz(root, file_name, url)

        # Exactly one of train/validation must be requested; `train` wins
        # if both are set.
        if train:
            self.path = os.path.join(root, wmt16_config.TRAIN_FNAME)
        elif validation:
            self.path = os.path.join(root, wmt16_config.VAL_FNAME)
        else:
            raise NotImplementedError()

        self.min_len, self.max_len = min_len, max_len
        self.lazy = lazy
        self.preprocessed = preprocessed
        self.tokenizer = WMT16Tokenizer(root, math_precision=math_precision)

        # Data is not tokenized already
        if not preprocessed:
            # Per-language files share the base path: e.g. "<path>.en", "<path>.de".
            src_path, trg_path = tuple(
                os.path.expanduser(self.path + "." + x) for x in lang
            )
            if lazy:
                # Keep raw text; tokenization happens per item in __getitem__.
                raw_src = process_raw_data(src_path, max_size)
                raw_trg = process_raw_data(trg_path, max_size)
                assert len(raw_src) == len(
                    raw_trg
                ), "Source and target have different lengths"

                # NOTE(review): with the default max_len=None this raises
                # TypeError (None - 2); lazy callers appear to always pass
                # an int max_len — confirm.
                raw_src, src_lens, raw_trg, trg_lens = filter_raw_data(
                    raw_src, raw_trg, min_len - 2, max_len - 2
                )
                assert len(raw_src) == len(
                    raw_trg
                ), "Source and target have different lengths"

                # Adding 2 because EOS and BOS are added later during tokenization
                src_lengths = [i + 2 for i in src_lens]
                trg_lengths = [i + 2 for i in trg_lens]

                self.src = raw_src
                self.trg = raw_trg
            else:
                # Eager mode: tokenize everything up front.
                src = process_data(src_path, self.tokenizer, max_size)
                trg = process_data(trg_path, self.tokenizer, max_size)
                assert len(src) == len(trg), "Source and target have different lengths"
                src, trg = filter_data(src, trg, min_len, max_len)
                assert len(src) == len(trg), "Source and target have different lengths"

                src_lengths = [len(s) for s in src]
                trg_lengths = [len(t) for t in trg]
                self.src, self.trg = src, trg
        else:
            # Preprocessed mode: read the binary header now; the payload is
            # read lazily in __getitem__ (requires a prior call to prepare()
            # to open the file, since self.file stays None here).
            self.path = "{}.bin".format(self.path)
            self.file = None
            (
                file_max_len,
                file_min_len,
                file_vocab_size,
                src_lengths,
                trg_lengths,
            ) = self._extract_preprocessed_data()
            # The file must have been written with the same limits and vocab,
            # otherwise the fixed-stride reads below would be wrong.
            assert file_max_len == self.max_len
            assert file_min_len == self.min_len
            assert file_vocab_size == self.vocab_size

            # Each stored sample is two rows (src, trg) of max_len tokens of
            # the integer dtype chosen for this vocab size.
            self.dtype, _ = get_data_dtype(self.vocab_size)
            itemsize = np.iinfo(self.dtype).dtype.itemsize
            self.item_stride = itemsize * self.max_len * 2

        self.src_lengths = torch.tensor(src_lengths)
        self.trg_lengths = torch.tensor(trg_lengths)
        # self.lengths = self.src_lengths + self.trg_lengths

        self.sorted = False
        if sort:
            # NOTE(review): sorting reorders the length tensors (and src/trg
            # lists), but in preprocessed mode __getitem__ still seeks by raw
            # index — sort=True with preprocessed=True would desynchronize
            # lengths from file offsets, and self.src does not exist there.
            # Confirm sort is only used with in-memory data.
            self._sort_by_src_length()
            self.sorted = True

    def _extract_preprocessed_data(self):
        """Read the header of the packed binary file.

        Layout (all int64, matching write_as_preprocessed):
        length, vocab_size, min_len, max_len,
        src_lengths[length], trg_lengths[length],
        offset (byte position where sample payload starts; stored on self).

        Returns:
            (file_max_len, file_min_len, file_vocab_size,
             src_lengths, trg_lengths)

        Raises:
            ValueError: if ``self.path`` does not exist or is not a file.
        """
        if not (os.path.exists(self.path) and os.path.isfile(self.path)):
            raise ValueError("File {} not found".format(self.path))
        with open(self.path, "rb") as f:
            length = int(np.fromfile(f, np.int64, 1))
            file_vocab_size = int(np.fromfile(f, np.int64, 1))
            file_min_len = int(np.fromfile(f, np.int64, 1))
            file_max_len = int(np.fromfile(f, np.int64, 1))
            src_lengths = np.fromfile(f, np.int64, length)
            trg_lengths = np.fromfile(f, np.int64, length)
            # Byte offset of the first sample; used by __getitem__ seeks.
            self.offset = int(np.fromfile(f, np.int64, 1))
        return file_max_len, file_min_len, file_vocab_size, src_lengths, trg_lengths

    def _sort_by_src_length(self):
        """Sort samples in place by descending source length.

        Reorders src, trg and trg_lengths with the permutation produced by
        sorting src_lengths. Note trg_lengths becomes a plain list of
        0-d tensors after this call (list comprehension over tensor indices).
        """
        self.src_lengths, indices = self.src_lengths.sort(descending=True)
        self.src = [self.src[i] for i in indices]
        self.trg = [self.trg[i] for i in indices]
        self.trg_lengths = [self.trg_lengths[i] for i in indices]

    def prepare(self):
        """Open the packed binary file for reading (preprocessed mode only).

        Must be called before indexing a preprocessed dataset; typically
        invoked per worker so each gets its own file handle.
        """
        if self.preprocessed:
            self.file = open(self.path, "rb")

    def __getitem__(self, idx):
        """Return the ``(source, target)`` tensor pair for sample ``idx``."""
        if self.preprocessed:
            # Seek to the fixed-stride slot for this sample and read both
            # rows (src then trg), each padded to max_len tokens.
            offset = self.offset + self.item_stride * idx
            self.file.seek(offset, os.SEEK_SET)
            data = np.fromfile(self.file, self.dtype, self.max_len * 2)
            # Upcast from the compact storage dtype to int64 token ids.
            data = data.astype(np.int64)
            src_len = self.src_lengths[idx]
            tgt_len = self.trg_lengths[idx]
            # Trim padding using the true lengths from the header.
            src = torch.tensor(data[0:src_len])
            tgt = torch.tensor(data[self.max_len : self.max_len + tgt_len])
            return src, tgt
        else:
            if self.lazy:
                # Tokenize on demand from the raw sentence strings.
                src = torch.tensor(self.tokenizer.segment(self.src[idx]))
                trg = torch.tensor(self.tokenizer.segment(self.trg[idx]))
            else:
                src, trg = self.src[idx], self.trg[idx]
            return src, trg

    def write_as_preprocessed(
        self, collate_fn, min_len=0, max_len=50, num_workers=2, batch_size=1024
    ):
        """Serialize the dataset to ``<self.path>.bin`` in the packed format.

        Iterates the whole dataset through a DataLoader (collate_fn is
        expected to yield ``(src, src_len), (tgt, tgt_len)`` batches of
        max_len-padded rows), then writes the header described in
        ``_extract_preprocessed_data`` followed by the concatenated
        ``[src | tgt]`` payload.

        Args:
            collate_fn (callable): Batch collation producing padded tensors
            min_len (int): Minimum length recorded in the header
            max_len (int): Maximum (padded) length recorded in the header
            num_workers (int): DataLoader workers
            batch_size (int): DataLoader batch size
        """
        loader = DataLoader(
            self,
            batch_size=batch_size,
            collate_fn=collate_fn,
            num_workers=num_workers,
            drop_last=False,
        )
        srcs = []
        tgts = []
        src_lengths = []
        tgt_lengths = []
        # Compact storage dtype chosen from the vocab size.
        dtype, torch_dtype = get_data_dtype(self.vocab_size)
        for (src, src_len), (tgt, tgt_len) in tqdm(loader):
            # deepcopy detaches the batch tensors from DataLoader-owned
            # storage so the originals can be freed eagerly below.
            src_lengths.append(deepcopy(src_len.to(torch_dtype)))
            tgt_lengths.append(deepcopy(tgt_len.to(torch_dtype)))
            srcs.append(deepcopy(src))
            tgts.append(deepcopy(tgt))
            del src
            del src_len
            del tgt
            del tgt_len
        del loader
        import gc

        gc.collect()
        srcs = torch.cat(srcs)
        tgts = torch.cat(tgts)
        src_lengths = torch.cat(src_lengths)
        tgt_lengths = torch.cat(tgt_lengths)
        assert len(srcs) == len(tgts) == len(src_lengths) == len(tgt_lengths)

        length = len(srcs)
        offset = 0
        dest_fname = "{}.bin".format(self.path)
        with open(dest_fname, "wb") as f:
            # Header: length, vocab_size, min_len, max_len, the two length
            # arrays, then the payload start offset. `offset` accumulates the
            # bytes written so far...
            offset += f.write((np.array(length, dtype=np.int64)))
            offset += f.write((np.array(self.vocab_size, dtype=np.int64)))
            offset += f.write((np.array(min_len, dtype=np.int64)))
            offset += f.write((np.array(max_len, dtype=np.int64)))
            offset += f.write((np.array(src_lengths, dtype=np.int64)))
            offset += f.write((np.array(tgt_lengths, dtype=np.int64)))
            # ...plus the 8 bytes of the offset field itself, so the stored
            # value points just past the header.
            offset += np.iinfo(np.int64).dtype.itemsize
            f.write((np.array(offset, dtype=np.int64)))
            # Free the header arrays before materializing the full payload.
            del src_lengths
            del tgt_lengths
            gc.collect()
            # Payload: each row is [src tokens | tgt tokens], 2*max_len wide.
            data = torch.cat((srcs, tgts), dim=1).numpy()
            f.write((np.array(data, dtype=dtype)))

    @property
    def vocab_size(self):
        # Delegates to the tokenizer's vocabulary size.
        return self.tokenizer.vocab_size

    def __len__(self):
        # One length entry per sample in every loading mode.
        return len(self.src_lengths)