# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
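"""Beam-search sequence generation and BLEU tokenization utilities for the
WMT17 Transformer model (adapted from fairseq's SequenceGenerator)."""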
import math
import re
import sys
import unicodedata
import six
import torch
from torch.nn import functional as F
from mlbench_core.models.pytorch.transformer.decoder import TransformerDecoder
class UnicodeRegex(object):
"""Ad-hoc hack to recognize all punctuation and symbols."""
def __init__(self):
punctuation = self.property_chars("P")
self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
self.symbol_re = re.compile("([" + self.property_chars("S") + "])")
def property_chars(self, prefix):
return "".join(
six.unichr(x)
for x in range(sys.maxunicode)
if unicodedata.category(six.unichr(x)).startswith(prefix)
)
uregex = UnicodeRegex()
def strip_pad(tensor, pad):
"""Remove all padding entries from a tensor."""
return tensor[tensor.ne(pad)]
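# Example (illustrative):
#   strip_pad(torch.tensor([4, 10, 7, 1, 1]), pad=1) -> tensor([4, 10, 7])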
def post_process_prediction(hypo_tokens, alignment, align_dict, tgt_dict, remove_bpe):
"""Convert hypothesis tokens into a string via the target dictionary."""
hypo_str = tgt_dict.string(hypo_tokens, remove_bpe)
assert not align_dict, "align_dict is not supported"
return hypo_tokens, hypo_str, alignment
def detokenize_subtokenized_sentence(subtokenized_sentence):
"""Detokenize a subtokenized ("_"-separated) sentence into plain text."""
l1 = " ".join("".join(subtokenized_sentence.strip().split()).split("_"))
l1 = l1.replace(" ,", ",")
l1 = l1.replace(" .", ".")
l1 = l1.replace(" !", "!")
l1 = l1.replace(" ?", "?")
l1 = l1.replace(" ' ", "'")
l1 = l1.replace(" - ", "-")
l1 = l1.strip()
return l1
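# Example (illustrative): "_" separates words in the subtokenized format, so
#   detokenize_subtokenized_sentence("Hel lo_world_!")
#   -> "Hello world!"   (spaces removed, "_" -> " ", then " !" -> "!")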
def bleu_tokenize(string):
r"""Tokenize a string following the official BLEU implementation.
See https://github.com/moses-smt/mosesdecoder/'
'blob/master/scripts/generic/mteval-v14.pl#L954-L983
In our case, the input string is expected to be just one line
and no HTML entities de-escaping is needed.
So we just tokenize on punctuation and symbols,
except when a punctuation is preceded and followed by a digit
(e.g. a comma/dot as a thousand/decimal separator).
Note that a numer (e.g. a year) followed by a dot at the end of sentence
is NOT tokenized,
i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
does not match this case (unless we add a space after each sentence).
However, this error is already in the original mteval-v14.pl
and we want to be consistent with it.
Args:
string: the input string
Returns:
a list of tokens
"""
string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
string = uregex.symbol_re.sub(r" \1 ", string)
return string
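# Example (illustrative): digit-adjacent punctuation and a trailing dot after
# a number stay attached; everything else is split off:
#   bleu_tokenize("Hello, world! It was 1,234.").split()
#   -> ['Hello', ',', 'world', '!', 'It', 'was', '1,234.']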
class SequenceGenerator(object):
"""Generates translations of a given source sentence.
Args:
model (:obj:`torch.nn.Module`): The model to predict on. Should be instance of `TransformerModel`
src_dict (:obj:`mlbench_core.dataset.nlp.pytorch.wmt17.Dictionary`): Source dictionary
trg_dict (:obj:`mlbench_core.dataset.nlp.pytorch.wmt17.Dictionary`): Target dictionary
beam_size (int): Size of the beam. Default 1
minlen (int): Minimum generation length. Default 1
maxlen (int): Maximum generation length. If `None`, takes value of model.max_decoder_positions().
Default `None`
stop_early (bool): Stop generation immediately after we finalize beam_size
hypotheses, even though longer hypotheses might have better
normalized scores. Default `True`
normalize_scores (bool): Normalize scores by the length of the output. Default `True`
len_penalty (float): length penalty: <1.0 favors shorter, >1.0 favors longer sentences.
Default 1
retain_dropout (bool): Keep dropout layers. Default `False`
sampling (bool): sample hypotheses instead of using beam search. Default `False`
sampling_topk (int): sample from top K likely next words instead of all words. Default -1
sampling_temperature (int): temperature for random sampling. Default 1
"""
def __init__(
self,
model,
src_dict,
trg_dict,
beam_size=1,
minlen=1,
maxlen=None,
stop_early=True,
normalize_scores=True,
len_penalty=1,
retain_dropout=False,
sampling=False,
sampling_topk=-1,
sampling_temperature=1,
):
self.model = model
self.pad = trg_dict.pad()
self.eos = trg_dict.eos()
self.vocab_size = len(trg_dict)
self.src_dict = src_dict
self.trg_dict = trg_dict
self.beam_size = beam_size
self.minlen = minlen
max_decoder_len = self.model.max_decoder_positions()
max_decoder_len -= 1 # we define maxlen not including the EOS marker
self.maxlen = (
max_decoder_len if maxlen is None else min(maxlen, max_decoder_len)
)
self.stop_early = stop_early
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
self.retain_dropout = retain_dropout
self.sampling = sampling
self.sampling_topk = sampling_topk
self.sampling_temperature = sampling_temperature
def generate_batch_translations(
self,
batch,
maxlen_a=0.0,
maxlen_b=None,
prefix_size=0,
):
"""Yield individual translations of a batch.
Args:
batch (dict): The model input batch. Must have keys `net_input`, `target` and `ntokens`
maxlen_a (float):
maxlen_b (Optional[int]): Generate sequences of max lengths `maxlen_a*x + maxlen_b` where `x = input sentence length`
prefix_size (int): Prefix size
"""
if maxlen_b is None:
maxlen_b = self.maxlen
if "net_input" not in batch:
return
input = batch["net_input"]
srclen = input["src_tokens"].size(1)
with torch.no_grad():
hypos = self.generate(
input["src_tokens"],
input["src_lengths"],
maxlen=int(maxlen_a * srclen + maxlen_b),
prefix_tokens=batch["target"][:, :prefix_size]
if prefix_size > 0
else None,
)
for i, sample_id in enumerate(batch["id"].data):
# remove padding
src = strip_pad(input["src_tokens"].data[i, :], self.pad)
ref = (
strip_pad(batch["target"].data[i, :], self.pad)
if batch["target"] is not None
else None
)
yield sample_id, src, ref, hypos[i]
def generate(self, src_tokens, src_lengths, maxlen=None, prefix_tokens=None):
"""Generate a batch of translations."""
with torch.no_grad():
return self._generate(src_tokens, src_lengths, maxlen, prefix_tokens)
def _generate(self, src_tokens, src_lengths, maxlen=None, prefix_tokens=None):
bsz, srclen = src_tokens.size()
maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
# the max beam size is the dictionary size - 1, since we never select pad
beam_size = self.beam_size
beam_size = min(beam_size, self.vocab_size - 1)
incremental_state = None
if not self.retain_dropout:
self.model.eval()
if isinstance(self.model.decoder, TransformerDecoder):
incremental_state = {}
# compute the encoder output for each beam
encoder_out = self.model.encoder(
src_tokens.repeat(1, beam_size).view(-1, srclen)
)
# initialize buffers
scores = src_tokens.data.new(bsz * beam_size, maxlen + 1).float().fill_(0)
scores_buf = scores.clone()
tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
tokens_buf = tokens.clone()
tokens[:, 0] = self.eos
attn, attn_buf = None, None
nonpad_idxs = None
# list of completed sentences
finalized = [[] for i in range(bsz)]
finished = [False for i in range(bsz)]
worst_finalized = [{"idx": None, "score": -math.inf} for i in range(bsz)]
num_remaining_sent = bsz
# number of candidate hypos per step
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
# offset arrays for converting between different indexing schemes
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
cand_offsets = torch.arange(0, cand_size).type_as(tokens)
# helper function for allocating buffers on the fly
buffers = {}
def buffer(name, type_of=tokens): # noqa
if name not in buffers:
buffers[name] = type_of.new()
return buffers[name]
def is_finished(sent, step, unfinalized_scores=None):
"""
Check whether we've finished generation for a given sentence, by
comparing the worst score among finalized hypotheses to the best
possible score among unfinalized hypotheses.
"""
assert len(finalized[sent]) <= beam_size
if len(finalized[sent]) == beam_size:
if self.stop_early or step == maxlen or unfinalized_scores is None:
return True
# stop if the best unfinalized score is worse than the worst
# finalized one
best_unfinalized_score = unfinalized_scores[sent].max()
if self.normalize_scores:
# GNMT length penalty (Wu et al., 2016): lp = ((5 + length) / 6) ** alpha
best_unfinalized_score /= ((maxlen + 5) / 6) ** self.len_penalty
if worst_finalized[sent]["score"] >= best_unfinalized_score:
return True
return False
def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
"""
Finalize the given hypotheses at this step, while keeping the total
number of finalized hypotheses per sentence <= beam_size.
Note: the input must be in the desired finalization order, so that
hypotheses that appear earlier in the input are preferred to those
that appear later.
Args:
step: current time step
bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
indicating which hypotheses to finalize
eos_scores: A vector of the same size as bbsz_idx containing
scores for each hypothesis
unfinalized_scores: A vector containing scores for all
unfinalized hypotheses
"""
assert bbsz_idx.numel() == eos_scores.numel()
# clone relevant token and attention tensors
tokens_clone = tokens.index_select(0, bbsz_idx)
tokens_clone = tokens_clone[
:, 1 : step + 2
] # skip the first index, which is EOS
tokens_clone[:, step] = self.eos
attn_clone = (
attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
if attn is not None
else None
)
# compute scores per token position
pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
pos_scores[:, step] = eos_scores
# convert from cumulative to per-position scores
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
# normalize sentence-level scores
if self.normalize_scores:
# GNMT length penalty (Wu et al., 2016): lp = ((5 + length) / 6) ** alpha
eos_scores /= (((step + 1) + 5) / 6) ** self.len_penalty
cum_unfin = []
prev = 0
for f in finished:
if f:
prev += 1
else:
cum_unfin.append(prev)
sents_seen = set()
for i, (idx, score) in enumerate(
zip(bbsz_idx.tolist(), eos_scores.tolist())
):
unfin_idx = idx // beam_size
sent = unfin_idx + cum_unfin[unfin_idx]
sents_seen.add((sent, unfin_idx))
def get_hypo():
if attn_clone is not None:
# remove padding tokens from attn scores
hypo_attn = attn_clone[i][nonpad_idxs[sent]]
_, alignment = hypo_attn.max(dim=0)
else:
hypo_attn = None
alignment = None
return {
"tokens": tokens_clone[i],
"score": score,
"attention": hypo_attn, # src_len x tgt_len
"alignment": alignment,
"positional_scores": pos_scores[i],
}
if len(finalized[sent]) < beam_size:
finalized[sent].append(get_hypo())
elif not self.stop_early and score > worst_finalized[sent]["score"]:
# replace worst hypo for this sentence with new/better one
worst_idx = worst_finalized[sent]["idx"]
if worst_idx is not None:
finalized[sent][worst_idx] = get_hypo()
# find new worst finalized hypo for this sentence
idx, s = min(
enumerate(finalized[sent]), key=lambda r: r[1]["score"]
)
worst_finalized[sent] = {
"score": s["score"],
"idx": idx,
}
newly_finished = []
for sent, unfin_idx in sents_seen:
# check termination conditions for this sentence
if not finished[sent] and is_finished(sent, step, unfinalized_scores):
finished[sent] = True
newly_finished.append(unfin_idx)
return newly_finished
reorder_state = None
batch_idxs = None
for step in range(maxlen + 1): # one extra step for EOS marker
if reorder_state is not None:
if batch_idxs is not None:
# update beam indices to take into account removed sentences
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
batch_idxs
)
reorder_state.view(-1, beam_size).add_(
corr.unsqueeze(-1) * beam_size
)
if isinstance(self.model.decoder, TransformerDecoder):
self.model.decoder.reorder_incremental_state(
incremental_state, reorder_state
)
encoder_out = self.model.encoder.reorder_encoder_out(
encoder_out, reorder_state
)
probs, avg_attn_scores = self._decode_one(
tokens[:, : step + 1],
self.model,
encoder_out,
incremental_state,
log_probs=True,
)
if step == 0:
# at the first step all hypotheses are equally likely, so use
# only the first beam
probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
scores = scores.type_as(probs)
scores_buf = scores_buf.type_as(probs)
elif not self.sampling:
# make probs contain cumulative scores for each hypothesis
probs.add_(scores[:, step - 1].view(-1, 1))
probs[:, self.pad] = -math.inf # never select pad
# Record attention scores
if avg_attn_scores is not None:
if attn is None:
attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
attn_buf = attn.clone()
nonpad_idxs = src_tokens.ne(self.pad)
attn[:, :, step + 1].copy_(avg_attn_scores)
cand_scores = buffer("cand_scores", type_of=scores)
cand_indices = buffer("cand_indices")
cand_beams = buffer("cand_beams")
eos_bbsz_idx = buffer("eos_bbsz_idx")
eos_scores = buffer("eos_scores", type_of=scores)
if step < maxlen:
if prefix_tokens is not None and step < prefix_tokens.size(1):
probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
cand_scores = torch.gather(
probs_slice,
dim=1,
index=prefix_tokens[:, step].view(-1, 1).data,
).expand(-1, cand_size)
cand_indices = (
prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
)
cand_beams.resize_as_(cand_indices).fill_(0)
elif self.sampling:
assert (
self.pad == 1
), "sampling assumes the first two symbols can be ignored"
if self.sampling_topk > 0:
values, indices = probs[:, 2:].topk(self.sampling_topk)
exp_probs = values.div_(self.sampling_temperature).exp()
if step == 0:
torch.multinomial(
exp_probs, beam_size, replacement=True, out=cand_indices
)
else:
torch.multinomial(
exp_probs, 1, replacement=True, out=cand_indices
)
torch.gather(
exp_probs, dim=1, index=cand_indices, out=cand_scores
)
torch.gather(
indices, dim=1, index=cand_indices, out=cand_indices
)
cand_indices.add_(2)
else:
exp_probs = (
probs.div_(self.sampling_temperature)
.exp_()
.view(-1, self.vocab_size)
)
if step == 0:
# we exclude the first two vocab items, one of which is pad
torch.multinomial(
exp_probs[:, 2:],
beam_size,
replacement=True,
out=cand_indices,
)
else:
torch.multinomial(
exp_probs[:, 2:], 1, replacement=True, out=cand_indices
)
cand_indices.add_(2)
torch.gather(
exp_probs, dim=1, index=cand_indices, out=cand_scores
)
cand_scores.log_()
cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)
if step == 0:
cand_beams = torch.zeros(bsz, cand_size).type_as(cand_indices)
else:
cand_beams = (
torch.arange(0, beam_size)
.repeat(bsz, 2)
.type_as(cand_indices)
)
# make scores cumulative
cand_scores.add_(
torch.gather(
scores[:, step - 1].view(bsz, beam_size),
dim=1,
index=cand_beams,
)
)
else:
# take the best 2 x beam_size predictions. We'll choose the first
# beam_size of these which don't predict eos to continue with.
torch.topk(
probs.view(bsz, -1),
k=min(
cand_size, probs.view(bsz, -1).size(1) - 1
), # -1 so we never select pad
out=(cand_scores, cand_indices),
)
torch.floor_divide(
cand_indices, self.vocab_size, out=cand_beams.resize_(0)
)
cand_indices.fmod_(self.vocab_size)
else:
# finalize all active hypotheses once we hit maxlen
# pick the hypothesis with the highest prob of EOS right now
torch.sort(
probs[:, self.eos],
descending=True,
out=(eos_scores, eos_bbsz_idx),
)
num_remaining_sent -= len(
finalize_hypos(step, eos_bbsz_idx, eos_scores)
)
assert num_remaining_sent == 0
break
# cand_bbsz_idx contains beam indices for the top candidate
# hypotheses, with a range of values: [0, bsz*beam_size),
# and dimensions: [bsz, cand_size]
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
# finalize hypotheses that end in eos
eos_mask = cand_indices.eq(self.eos)
finalized_sents = set()
if step >= self.minlen:
# only consider eos when it's among the top beam_size indices
torch.masked_select(
cand_bbsz_idx[:, :beam_size],
mask=eos_mask[:, :beam_size],
out=eos_bbsz_idx.resize_(0),
)
if eos_bbsz_idx.numel() > 0:
torch.masked_select(
cand_scores[:, :beam_size],
mask=eos_mask[:, :beam_size],
out=eos_scores.resize_(0),
)
finalized_sents = finalize_hypos(
step, eos_bbsz_idx, eos_scores, cand_scores
)
num_remaining_sent -= len(finalized_sents)
assert num_remaining_sent >= 0
if num_remaining_sent == 0:
break
assert step < maxlen
if len(finalized_sents) > 0:
new_bsz = bsz - len(finalized_sents)
# construct batch_idxs which holds indices of batches to keep for the next pass
batch_mask = torch.ones(bsz).type_as(cand_indices)
batch_mask[cand_indices.new(finalized_sents)] = 0
batch_idxs = torch.nonzero(batch_mask).squeeze(-1)
eos_mask = eos_mask[batch_idxs]
cand_beams = cand_beams[batch_idxs]
bbsz_offsets.resize_(new_bsz, 1)
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
cand_scores = cand_scores[batch_idxs]
cand_indices = cand_indices[batch_idxs]
if prefix_tokens is not None:
prefix_tokens = prefix_tokens[batch_idxs]
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
scores_buf.resize_as_(scores)
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
tokens_buf.resize_as_(tokens)
if attn is not None:
attn = attn.view(bsz, -1)[batch_idxs].view(
new_bsz * beam_size, attn.size(1), -1
)
attn_buf.resize_as_(attn)
bsz = new_bsz
else:
batch_idxs = None
# set active_mask so that values > cand_size indicate eos hypos
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
active_mask = buffer("active_mask")
torch.add(
eos_mask.type_as(cand_offsets) * cand_size,
cand_offsets[: eos_mask.size(1)],
out=active_mask.resize_(0),
)
# get the top beam_size active hypotheses, which are just the hypos
# with the smallest values in active_mask
active_hypos, _ignore = buffer("active_hypos"), buffer("_ignore")
torch.topk(
active_mask,
k=beam_size,
dim=1,
largest=False,
out=(_ignore, active_hypos),
)
active_bbsz_idx = buffer("active_bbsz_idx")
torch.gather(
cand_bbsz_idx,
dim=1,
index=active_hypos,
out=active_bbsz_idx,
)
active_scores = torch.gather(
cand_scores,
dim=1,
index=active_hypos,
out=scores[:, step].view(bsz, beam_size),
)
active_bbsz_idx = active_bbsz_idx.view(-1)
active_scores = active_scores.view(-1)
# copy tokens and scores for active hypotheses
torch.index_select(
tokens[:, : step + 1],
dim=0,
index=active_bbsz_idx,
out=tokens_buf[:, : step + 1],
)
torch.gather(
cand_indices,
dim=1,
index=active_hypos,
out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
)
if step > 0:
torch.index_select(
scores[:, :step],
dim=0,
index=active_bbsz_idx,
out=scores_buf[:, :step],
)
torch.gather(
cand_scores,
dim=1,
index=active_hypos,
out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
)
# copy attention for active hypotheses
if attn is not None:
torch.index_select(
attn[:, :, : step + 2],
dim=0,
index=active_bbsz_idx,
out=attn_buf[:, :, : step + 2],
)
# swap buffers
tokens, tokens_buf = tokens_buf, tokens
scores, scores_buf = scores_buf, scores
if attn is not None:
attn, attn_buf = attn_buf, attn
# reorder incremental state in decoder
reorder_state = active_bbsz_idx
# sort by score descending
for sent in range(len(finalized)):
finalized[sent] = sorted(
finalized[sent], key=lambda r: r["score"], reverse=True
)
return finalized
def _decode_one(self, tokens, model, encoder_out, incremental_state, log_probs):
"""Run one decoder step and return (log-)probabilities for the last
position, along with the attention over the source."""
with torch.no_grad():
if incremental_state is not None:
decoder_out = list(
model.decoder(
tokens, encoder_out, incremental_state=incremental_state
)
)
else:
decoder_out = list(model.decoder(tokens, encoder_out))
decoder_out[0] = decoder_out[0][:, -1, :]
attn = decoder_out[1]
if attn is not None:
attn = attn[:, -1, :]
probs = get_normalized_probs(decoder_out, log_probs=log_probs)
return probs, attn
def translate_batch(
self,
batch,
maxlen_a=1.0,
maxlen_b=50,
prefix_size=0,
remove_bpe=None,
nbest=1,
ignore_case=True,
):
"""
Args:
batch (dict): The model input batch. Must have keys `net_input`, `target` and `ntokens`
maxlen_a (float): Default 1.0
maxlen_b (Optional[int]): Generate sequences of max lengths `maxlen_a*x + maxlen_b` where `x = input sentence length`.
Default 50
prefix_size (int): Prefix size. Default 0
remove_bpe (Optional[str]): BPE token. Default `None`
nbest (int): Number of hypotheses to output. Default 1
ignore_case (bool): Ignore case during online evaluation. Default `True`
Returns:
(list[str], list[str]): The BLEU-tokenized translations and their references for the given batch
"""
translations = self.generate_batch_translations(
batch,
maxlen_a=maxlen_a,
maxlen_b=maxlen_b,
prefix_size=prefix_size,
)
ref_toks = []
sys_toks = []
for sample_id, src_tokens, target_tokens, hypos in translations:
# Process input and ground truth
has_target = target_tokens is not None
target_tokens = target_tokens.int().cpu() if has_target else None
src_str = self.src_dict.string(src_tokens, remove_bpe)
if has_target:
target_str = self.trg_dict.string(target_tokens, remove_bpe)
# Process top predictions
for i, hypo in enumerate(hypos[: min(len(hypos), nbest)]):
hypo_tokens, hypo_str, alignment = post_process_prediction(
hypo_tokens=hypo["tokens"].int().cpu(),
alignment=hypo["alignment"].int().cpu()
if hypo["alignment"] is not None
else None,
align_dict=None,
tgt_dict=self.trg_dict,
remove_bpe=remove_bpe,
)
# Score only the top hypothesis
if has_target and i == 0:
src_str = detokenize_subtokenized_sentence(src_str)
target_str = detokenize_subtokenized_sentence(target_str)
hypo_str = detokenize_subtokenized_sentence(hypo_str)
sys_tok = bleu_tokenize(
hypo_str.lower() if ignore_case else hypo_str
)
ref_tok = bleu_tokenize(
target_str.lower() if ignore_case else target_str
)
sys_toks.append(sys_tok)
ref_toks.append(ref_tok)
return sys_toks, ref_toks
def get_normalized_probs(net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output."""
logits = net_output[0]
if log_probs:
return F.log_softmax(logits, dim=-1, dtype=torch.float32)
else:
return F.softmax(logits, dim=-1, dtype=torch.float32)
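# Usage sketch (hypothetical setup; assumes a trained TransformerModel `model`,
# WMT17 `src_dict`/`trg_dict`, and a prepared `batch` dict with "id",
# "net_input", "target" and "ntokens" entries; remove_bpe="@@ " assumes
# fairseq-style BPE continuation markers):
#
#   generator = SequenceGenerator(
#       model, src_dict, trg_dict, beam_size=4, len_penalty=0.6
#   )
#   sys_toks, ref_toks = generator.translate_batch(
#       batch, maxlen_a=1.0, maxlen_b=50, remove_bpe="@@ "
#   )
#   # sys_toks / ref_toks are BLEU-tokenized hypothesis / reference strings,
#   # ready for corpus-level BLEU scoring.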