Source code for mlbench_core.dataset.nlp.pytorch.wmt17.wmt17_dictionary

import torch


class Dictionary:
    """Dictionary class for the WMT17 dataset.

    Essentially a mapping from symbols to consecutive integers. Index 0 is
    taken by a ``"<lua_index_compat>"`` placeholder so that real symbols start
    at index 1 (kept for consistency with Lua). The pad and end-of-sentence
    symbols are not added by the constructor; their indices are hard-coded to
    1 and 2, which assumes the dictionary file read by :meth:`load` lists them
    as its first two entries.

    Args:
        pad (str): Padding symbol to use (stored as ``pad_word``).
        eos (str): End of String symbol to use (stored as ``eos_word``).
    """

    def __init__(self, pad="<pad>_", eos="<EOS>_"):
        self.pad_word, self.eos_word = pad, eos
        self.symbols = []  # index -> symbol
        self.count = []  # index -> count
        self.indices = {}  # symbol -> index
        # Dictionary indexing starts at 1 for consistency with Lua, so index 0
        # is occupied by a placeholder symbol.
        self.add_symbol("<lua_index_compat>")
        # pad/eos are already present in the dictionary files, so they are not
        # added here and their indices are hard-coded instead.
        self.pad_index = 1
        self.eos_index = 2
        self.nspecial = 3

    def __eq__(self, other):
        """Return True when both dictionaries map symbols to the same indices."""
        try:
            other_indices = other.indices
        except AttributeError:
            # Unrelated type: let Python try the reflected comparison (and
            # fall back to identity) instead of crashing.
            return NotImplemented
        return self.indices == other_indices

    def __getitem__(self, idx):
        """Return the symbol stored at index ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range. Raising IndexError (rather
                than failing an ``assert``) keeps the check active under
                ``python -O`` and lets ``for sym in dictionary`` terminate.
        """
        try:
            return self.symbols[idx]
        except IndexError:
            raise IndexError(
                "symbol index {} out of range for dictionary of size {}".format(
                    idx, len(self.symbols)
                )
            ) from None

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def index(self, sym):
        """Return the index of the symbol ``sym``.

        Raises:
            KeyError: If ``sym`` is not in the dictionary.
        """
        return self.indices[sym]

    def string(self, tensor, bpe_symbol=None):
        """Convert a sequence of token indices into a space-separated string.

        A 2-D tensor is treated as a batch: each row is converted on its own
        (with the same ``bpe_symbol``) and the rows are joined with newlines.
        Tokens equal to the end-of-sentence index are omitted.

        Args:
            tensor: 1-D tensor / iterable of token indices, or a 2-D tensor
                whose rows are such sequences.
            bpe_symbol (str, optional): BPE marker to delete from the joined
                string, e.g. ``"@@ "``.

        Returns:
            str: The decoded sentence(s).
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            # Forward bpe_symbol so every row gets the same post-processing.
            return "\n".join(self.string(row, bpe_symbol) for row in tensor)

        eos = self.eos()
        sent = " ".join(self[i] for i in tensor if i != eos)
        if bpe_symbol is not None:
            # The appended space lets a marker that ends in a space
            # (e.g. "@@ ") also match on the last token.
            sent = (sent + " ").replace(bpe_symbol, "").rstrip()
        return sent

    def add_symbol(self, word, n=1):
        """Add ``word`` to the dictionary, or increase its count if present.

        Args:
            word (str): Symbol to add.
            n (int): Amount added to the symbol's count.

        Returns:
            int: The index of ``word``.
        """
        idx = self.indices.get(word)
        if idx is not None:
            self.count[idx] = self.count[idx] + n
            return idx
        idx = len(self.symbols)
        self.indices[word] = idx
        self.symbols.append(word)
        self.count.append(n)
        return idx

    def update(self, new_dict):
        """Merge the symbols and counts of ``new_dict`` into this dictionary.

        Counts of symbols present in both dictionaries are summed; symbols only
        in ``new_dict`` are appended with their count from ``new_dict``.
        """
        for word in new_dict.symbols:
            self.add_symbol(word, new_dict.count[new_dict.indices[word]])

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    @classmethod
    def load(cls, f, ignore_utf_errors=False, padded_size=33712):
        """Load the dictionary from a text file with one symbol per line::

            <symbol0>
            <symbol1>
            ...

        The first and last character of every whitespace-stripped line are
        dropped before the symbol is stored. Loaded symbols start at index 1,
        after the ``"<lua_index_compat>"`` placeholder, so the first two lines
        occupy indices 1 and 2 (``pad_index`` / ``eos_index``). The vocabulary
        is then filled up to ``padded_size`` entries with placeholder symbols
        ``"<pad0000>"``, ``"<pad0001>"``, ...

        Args:
            f (str or file object): Dictionary file name, or an already opened
                text file (any iterable of lines).
            ignore_utf_errors (bool): Ignore UTF-8 decoding errors when ``f``
                is a file name.
            padded_size (int): Total number of entries to pad the vocabulary
                to; nothing is added if it is already at least this large.
                Defaults to 33712, the value previously hard-coded here.

        Returns:
            Dictionary: The loaded dictionary.

        Raises:
            FileNotFoundError: If ``f`` is a file name that does not exist.
            Exception: If the file is not valid UTF-8 and
                ``ignore_utf_errors`` is False.
        """
        if isinstance(f, str):
            errors = "ignore" if ignore_utf_errors else "strict"
            try:
                with open(f, "r", encoding="utf-8", errors=errors) as fd:
                    return cls.load(fd, padded_size=padded_size)
            except UnicodeError as err:
                # Only decoding failures are reported as an encoding problem;
                # FileNotFoundError and other OS errors propagate unchanged.
                raise Exception(
                    "Incorrect encoding detected in {}, please rebuild the dataset".format(
                        f
                    )
                ) from err

        d = cls()
        for line in f:
            # Drop the first and last character -- presumably the quotes
            # around each symbol in the dictionary files; confirm.
            word = line.strip()[1:-1]
            # Indices are assigned directly (not via add_symbol) so every line
            # takes its own position; a repeated symbol is appended again and
            # its `indices` entry points to the last occurrence.
            d.indices[word] = len(d.symbols)
            d.symbols.append(word)
            d.count.append(1)

        # Pad to a fixed vocabulary size -- presumably to match a fixed
        # embedding/output layer size; confirm against the model config.
        for i in range(padded_size - len(d.symbols)):
            pad_str = "<pad000" + str(i) + ">"
            d.indices[pad_str] = len(d.symbols)
            d.symbols.append(pad_str)
            d.count.append(1)

        return d