Source code for mlbench_core.evaluation.tensorflow.metrics

r"""Define tensorflow metrics."""

import tensorflow as tf


class TopKAccuracy(object):
    """Compute the top-k accuracy of logits.

    Args:
        logits (:obj:`tf.Tensor`): Raw, unnormalized model outputs.
        labels (:obj:`tf.Tensor`): One-hot encoded ground-truth labels.
        topk (:obj:`int`, optional): Defaults to 1. The k in top-k accuracy.
    """

    def __init__(self, logits, labels, topk=1):
        labels = tf.cast(labels, tf.int32)
        true_classes = tf.argmax(labels, axis=1)

        # predicted classes
        pred_probs = tf.nn.softmax(logits, name="softmax_tensor")
        pred_classes = tf.argmax(pred_probs, axis=1)

        # get metrics.
        with tf.name_scope("metrics"):
            if topk == 1:
                self.name = "Prec@1"
                self.metric_op = (
                    tf.reduce_mean(
                        tf.cast(tf.equal(true_classes, pred_classes), tf.float32)
                    )
                    * 100.0
                )
            else:
                topk_op = tf.nn.in_top_k(
                    predictions=pred_probs, targets=true_classes, k=topk
                )
                self.name = "Prec@" + str(topk)
                self.metric_op = tf.reduce_mean(tf.cast(topk_op, tf.float32)) * 100.0
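
The class only builds graph operations; nothing is evaluated at construction time. A minimal usage sketch, assuming TensorFlow 1.x graph mode and hypothetical 10-class placeholders that are not part of this module:

import tensorflow as tf

# Hypothetical inputs: raw model outputs and one-hot ground-truth labels.
logits = tf.placeholder(tf.float32, shape=[None, 10])
labels = tf.placeholder(tf.float32, shape=[None, 10])

# Build the top-5 accuracy op; the metric name and scalar op are stored as attributes.
top5 = TopKAccuracy(logits=logits, labels=labels, topk=5)
print(top5.name)       # "Prec@5"
print(top5.metric_op)  # scalar float32 tensor (in percent), evaluated later in a session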


def topk_accuracy_with_logits(logits, labels, k=1):
    """Compute the top-k accuracy of logits.

    Args:
        logits (:obj:`tf.Tensor`): Raw, unnormalized model outputs.
        labels (:obj:`tf.Tensor`): One-hot encoded ground-truth labels.
        k (:obj:`int`, optional): Defaults to 1. The k in top-k accuracy.

    Returns:
        :obj:`dict`: A dict with the metric name (e.g. ``"Prec@1"``) and a
        scalar accuracy tensor in percent (between 0 and 100).
    """
    m = TopKAccuracy(logits=logits, labels=labels, topk=k)
    return {"name": m.name, "value": m.metric_op}
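
A minimal evaluation sketch for this helper, again assuming TensorFlow 1.x graph mode; the random NumPy batch below is illustrative only:

import numpy as np
import tensorflow as tf

logits = tf.placeholder(tf.float32, shape=[None, 10])
labels = tf.placeholder(tf.float32, shape=[None, 10])

# Returns e.g. {"name": "Prec@1", "value": <scalar tf.Tensor in percent>}.
metric = topk_accuracy_with_logits(logits, labels, k=1)

# Random batch of 32 examples over 10 classes, just to feed the placeholders.
batch_logits = np.random.randn(32, 10).astype(np.float32)
batch_labels = np.eye(10, dtype=np.float32)[np.random.randint(0, 10, size=32)]

with tf.Session() as sess:
    value = sess.run(
        metric["value"],
        feed_dict={logits: batch_logits, labels: batch_labels},
    )
    print("%s: %.2f%%" % (metric["name"], value))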