Source code for mlbench_core.models.pytorch.transformer.encoder
import math

import torch
import torch.nn.functional as F
from torch import nn

from mlbench_core.models.pytorch.transformer.modules import (
    PositionalEmbedding,
    TransformerEncoderLayer,
)


class TransformerEncoder(nn.Module):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args: Arguments of the model. All arguments must be accessible via the
            `__getattribute__` method.
        dictionary (:obj:`mlbench_core.dataset.nlp.pytorch.wmt17.Dictionary`): Encoding dictionary
        embed_tokens (torch.nn.Embedding): Input embedding
        left_pad (bool): Pad sources to the left (`True`) or right (`False`). Default: `True`
    """

    def __init__(self, args, dictionary, embed_tokens, left_pad=True):
        super().__init__()
        self.dictionary = dictionary
        self.dropout = args.dropout
        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions
        self.softmax_type = args.softmax_type
        self.embed_tokens = embed_tokens
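        # scale token embeddings by sqrt(embed_dim), as in the original Transformer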
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                embed_dim,
                self.padding_idx,
                left_pad=left_pad,
                learned=args.encoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
        )

        self.normalize = args.encoder_normalize_before
        if self.normalize:
            self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, src_tokens):
        """Forward pass of the encoder.

        Args:
            src_tokens (:obj:`torch.Tensor`): Source tokens of shape `(batch, src_len)`

        Returns:
            (dict): Dictionary with `encoder_out` (:obj:`torch.Tensor` of shape `T x B x C`)
                and `encoder_padding_mask` (:obj:`torch.Tensor` of shape `B x T`)
        """
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C; make the transposed tensor contiguous
        # before it is passed to the encoder layers
        x = x.transpose(0, 1)
        if x.size(1) == 1:
            if x.is_contiguous():
                x = x.view(x.size(0), x.size(1), x.size(2))
            else:
                x = x.contiguous()
        else:
            x = x.contiguous()

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None
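        # for the "fast_fill" softmax variant, turn the boolean mask into an
        # additive float mask: 0 for real tokens, -inf at padded positions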
        if (self.softmax_type == "fast_fill") and (encoder_padding_mask is not None):
            encoder_padding_mask = torch.zeros_like(
                encoder_padding_mask, dtype=x.dtype
            ).masked_fill_(encoder_padding_mask, float("-inf"))

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)

        if self.normalize:
            x = self.layer_norm(x)

        return {
            "encoder_out": x,  # T x B x C
            "encoder_padding_mask": encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
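        """Reorder the encoder output according to `new_order` (e.g. for beam search).

        Args:
            encoder_out (dict): Output dictionary from :func:`forward`
            new_order (:obj:`torch.LongTensor`): Desired order of the batch dimension

        Returns:
            (dict): The reordered `encoder_out` dictionary
        """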
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
                1, new_order
            )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions())
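

# ---------------------------------------------------------------------------
# Illustrative, standalone sketch (not part of the original module). It
# reproduces the padding-mask construction used in ``forward`` above with
# plain PyTorch, so the two mask variants can be inspected in isolation.
# The padding index (1) and the toy batch below are assumptions made purely
# for illustration.
if __name__ == "__main__":
    padding_idx = 1
    # two source sentences of length 4; the second one ends in two pad tokens
    src_tokens = torch.tensor([[5, 6, 7, 8], [5, 6, padding_idx, padding_idx]])

    # boolean mask: True where the token is padding (B x T)
    encoder_padding_mask = src_tokens.eq(padding_idx)

    # "fast_fill" variant: additive float mask with -inf at padded positions,
    # suitable for adding to attention scores before the softmax
    additive_mask = torch.zeros_like(
        encoder_padding_mask, dtype=torch.float32
    ).masked_fill_(encoder_padding_mask, float("-inf"))

    print(encoder_padding_mask)
    print(additive_mask)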