metrics

package metrics, version v0.2.0
Note: this package is not in the latest version of its module.
Published: May 18, 2023 | License: Apache-2.0 | Imports: 7 | Imported by: 0

Documentation

Overview

Package metrics holds a library of metrics and defines Interface, which all of them implement.

Constants

const (
	LossMetricType     = "loss"
	AccuracyMetricType = "accuracy"
)

Variables

This section is empty.

Functions

func BatchSize

func BatchSize(data *Node) *Node

BatchSize returns the batch size of the data node (assumed to be its first dimension), cast to the same dtype as data.

func BinaryAccuracyGraph

func BinaryAccuracyGraph(labels, predictions []*Node) *Node

BinaryAccuracyGraph can be used in combination with New*Metric functions to build metrics for binary accuracy. It assumes predictions are probabilities, that labels are {0, 1} and that predictions and labels have the same shape and dtype.
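To make the computed quantity concrete, here is a minimal plain-Go sketch over float64 slices rather than graph nodes. The function name binaryAccuracy is hypothetical and does not use this package's API, and the 0.5 decision threshold is an assumption, since the documentation only states that predictions are probabilities.

```go
package main

// binaryAccuracy is a hypothetical plain-Go sketch of the quantity
// BinaryAccuracyGraph computes on the graph: the fraction of examples whose
// thresholded probability matches the {0, 1} label.
// Assumptions: a 0.5 threshold, and labels/predictions of the same length.
func binaryAccuracy(labels, predictions []float64) float64 {
	if len(labels) == 0 {
		return 0
	}
	hits := 0
	for i, p := range predictions {
		predicted := 0.0
		if p >= 0.5 { // assumed threshold
			predicted = 1.0
		}
		if predicted == labels[i] {
			hits++
		}
	}
	return float64(hits) / float64(len(labels))
}
```

For example, binaryAccuracy([]float64{1, 0, 1, 0}, []float64{0.9, 0.2, 0.4, 0.7}) yields 0.5: the first two predictions are hits and the last two are misses.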

func BinaryLogitsAccuracyGraph

func BinaryLogitsAccuracyGraph(labels, logits []*Node) *Node

BinaryLogitsAccuracyGraph can be used in combination with New*Metric functions to build binary accuracy metrics for logits. Note that a logit of exactly 0 is counted as a miss. It assumes predictions are logits, that labels are {0, 1}, and that predictions and labels have the same size and dtype. Their shapes may differ (e.g.: `[batch_size, 1]` and `[batch_size]`); they will be reshaped to the logits shape before the accuracy is calculated.
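The decision rule described above (a positive logit predicts 1, a negative logit predicts 0, and a zero logit is always a miss) can be sketched in plain Go on flattened slices. binaryLogitsAccuracy is a hypothetical helper, not part of this package, and it assumes labels and logits already have the same length.

```go
package main

// binaryLogitsAccuracy is a hypothetical plain-Go sketch of what
// BinaryLogitsAccuracyGraph measures: a positive logit predicts 1, a negative
// logit predicts 0, and a logit of exactly 0 is always counted as a miss.
// labels and logits are assumed to be flattened to the same length.
func binaryLogitsAccuracy(labels, logits []float64) float64 {
	if len(logits) == 0 {
		return 0
	}
	hits := 0
	for i, z := range logits {
		switch {
		case z > 0 && labels[i] == 1:
			hits++
		case z < 0 && labels[i] == 0:
			hits++
		}
		// z == 0 matches neither case above, so it counts as a miss.
	}
	return float64(hits) / float64(len(logits))
}
```

With labels {1, 0, 1, 0} and logits {2, -1.5, 0, 0}, the first two examples are hits and both zero logits are misses, giving 0.5 regardless of their labels.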

func SparseCategoricalAccuracyGraph

func SparseCategoricalAccuracyGraph(labels, logits []*Node) *Node

SparseCategoricalAccuracyGraph returns the accuracy: the fraction of examples for which argmax(logits) is the true label. It works with either probabilities or logits. Ties are counted as misses. Labels are expected to have an integer dtype, and the returned value has the same dtype as logits.
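The following plain-Go sketch illustrates the argmax-with-ties rule on a small batch. sparseCategoricalAccuracy is a hypothetical helper that operates on slices rather than graph nodes; it treats an example as a hit only when the true label's index holds the unique maximum of its row.

```go
package main

// sparseCategoricalAccuracy is a hypothetical plain-Go sketch of the quantity
// described for SparseCategoricalAccuracyGraph. Each row of logits is one
// example; it is a hit only if the unique maximum sits at the true label's
// index. Any tie for the maximum counts as a miss.
func sparseCategoricalAccuracy(labels []int, logits [][]float64) float64 {
	if len(logits) == 0 {
		return 0
	}
	hits := 0
	for i, row := range logits {
		best, bestCount := 0, 0 // argmax index and how many entries share the max
		for j, v := range row {
			switch {
			case v > row[best]:
				best, bestCount = j, 1
			case v == row[best]:
				bestCount++
			}
		}
		if bestCount == 1 && best == labels[i] {
			hits++
		}
	}
	return float64(hits) / float64(len(logits))
}
```

With labels {0, 2, 1} and logit rows {2, 1, 0}, {0.1, 0.2, 0.9} and {1, 1, 0}, the first two rows are hits and the third is a tie, so the accuracy is 2/3.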

Types

type BaseMetricGraph

type BaseMetricGraph func(labels, predictions []*Node) *Node

BaseMetricGraph is a graph-building function for any metric that can be computed statelessly, without needing a context. It should return a scalar: the mean of the metric over the given batch.

type Interface

type Interface interface {
	// Name of the metric.
	Name() string

	// ShortName is a shortened version of the name (preferably a few characters) to display in progress bars or
	// similar UIs.
	ShortName() string

	// ScopeName is the name used to store state: a combination of the name and something unique.
	ScopeName() string

	// MetricType is a key for metrics that share the same quantity or semantics. E.g.:
	// "Moving-Average-Accuracy" and "Batch-Accuracy" would both have the same
	// "accuracy" metric type, and for instance, can be displayed on the same plot, sharing
	// the Y-axis.
	MetricType() string

	// UpdateGraph builds a graph that takes as input the predictions (or logits) and labels and
	// outputs the resulting metric (a scalar).
	UpdateGraph(ctx *context.Context, labels, predictions []*Node) (metric *Node)

	// PrettyPrint is used to pretty-print a metric value, usually in a short form.
	PrettyPrint(value tensor.Tensor) string

	// Reset resets the metric's internal counters when starting a new evaluation. Note that this may be
	// called before UpdateGraph; the metric should handle this without errors.
	Reset(ctx *context.Context) error
}

Interface is the interface implemented by every metric.

func NewBaseMetric

func NewBaseMetric(name, shortName, metricType string, metricFn BaseMetricGraph, pPrintFn PrettyPrintFn) Interface

NewBaseMetric creates a stateless metric from any BaseMetricGraph function; it returns the metric computed solely on the last batch. pPrintFn may be nil, in which case a default is used.

func NewExponentialMovingAverageMetric

func NewExponentialMovingAverageMetric(name, shortName, metricType string, metricFn BaseMetricGraph, pPrintFn PrettyPrintFn, newExampleWeight float64) Interface

NewExponentialMovingAverageMetric creates a metric from any BaseMetricGraph function. It weights new examples by newExampleWeight and decays the rest (the previous average) by a factor of 1-newExampleWeight.

A typical value of newExampleWeight is 0.01; the smaller the value, the slower the moving average moves. pPrintFn may be nil, in which case a default is used.

There is no set prior: the metric starts as a plain average and becomes an exponential moving average once enough terms have been seen.
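A minimal sketch of this averaging behavior, assuming one scalar value per update: the weight for each update is the larger of 1/count and newExampleWeight, so early updates produce an ordinary running mean and later ones a fixed-weight exponential moving average. The emaState type and this switch-over rule are assumptions inferred from the description above, not the package's actual implementation.

```go
package main

// emaState is a hypothetical plain-Go stand-in for the state such a metric
// would keep; it is not this package's implementation.
type emaState struct {
	count int     // number of updates seen so far
	mean  float64 // current average
}

// update folds a new value into the average and returns it.
// Assumed rule: while 1/count exceeds newExampleWeight, use 1/count (this is
// exactly a plain running mean); afterwards use newExampleWeight, decaying
// the previous average by 1-newExampleWeight.
func (s *emaState) update(value, newExampleWeight float64) float64 {
	s.count++
	w := 1.0 / float64(s.count)
	if w < newExampleWeight {
		w = newExampleWeight
	}
	s.mean = (1-w)*s.mean + w*value
	return s.mean
}
```

With newExampleWeight = 0.5 and values 1, 0, 1, the average goes 1, 0.5, 0.75; from the third update on, 1/count drops below 0.5 and the fixed weight takes over. With the typical 0.01, the first 100 updates behave as a plain average.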

func NewMeanBinaryAccuracy

func NewMeanBinaryAccuracy(name, shortName string) Interface

NewMeanBinaryAccuracy returns a new binary accuracy metric with the given names.

func NewMeanBinaryLogitsAccuracy

func NewMeanBinaryLogitsAccuracy(name, shortName string) Interface

NewMeanBinaryLogitsAccuracy returns a new binary accuracy metric for logits with the given names.

func NewMeanMetric

func NewMeanMetric(name, shortName, metricType string, metricFn BaseMetricGraph, pPrintFn PrettyPrintFn) Interface

NewMeanMetric creates a metric from any BaseMetricGraph function. It assumes the batch size, used to weight each new result in the mean, is given by the first dimension of the labels node. pPrintFn may be nil, in which case a default is used.
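The batch-size weighting can be illustrated with a small plain-Go accumulator. weightedMean is hypothetical and not part of this package: it keeps running sums of batchMean*batchSize and of the batch sizes, so larger batches count proportionally more, and reset plays the role of starting a new evaluation.

```go
package main

// weightedMean is a hypothetical plain-Go sketch of a batch-size-weighted
// mean; it is not this package's implementation.
type weightedMean struct {
	total float64 // sum of batchMean * batchSize over all batches
	count float64 // sum of batch sizes
}

// update adds one batch's mean, weighted by its batch size, and returns
// the mean over all examples seen since the last reset.
func (m *weightedMean) update(batchMean float64, batchSize int) float64 {
	m.total += batchMean * float64(batchSize)
	m.count += float64(batchSize)
	return m.total / m.count
}

// reset clears the accumulated state, e.g. before a new evaluation.
func (m *weightedMean) reset() { m.total, m.count = 0, 0 }
```

A batch of 3 with mean 1.0 followed by a batch of 1 with mean 0.0 gives (3*1.0 + 1*0.0) / 4 = 0.75, whereas an unweighted average of the two batch means would give 0.5.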

func NewMovingAverageBinaryAccuracy

func NewMovingAverageBinaryAccuracy(name, shortName string, newExampleWeight float64) Interface

NewMovingAverageBinaryAccuracy returns a new moving-average binary accuracy metric with the given names. A typical value of newExampleWeight is 0.01; the smaller the value, the slower the moving average moves.

func NewMovingAverageBinaryLogitsAccuracy

func NewMovingAverageBinaryLogitsAccuracy(name, shortName string, newExampleWeight float64) Interface

NewMovingAverageBinaryLogitsAccuracy returns a new moving-average binary accuracy metric for logits with the given names. A typical value of newExampleWeight is 0.01; the smaller the value, the slower the moving average moves.

func NewMovingAverageSparseCategoricalAccuracy

func NewMovingAverageSparseCategoricalAccuracy(name, shortName string, newExampleWeight float64) Interface

NewMovingAverageSparseCategoricalAccuracy returns a new moving-average sparse categorical accuracy metric with the given names. The accuracy is the fraction of examples for which argmax(logits) is the true label. It works with either probabilities or logits. Ties are counted as misses. Labels are expected to have an integer dtype, and the returned value has the same dtype as logits. A typical value of newExampleWeight is 0.01; the smaller the value, the slower the moving average moves.

func NewSparseCategoricalAccuracy

func NewSparseCategoricalAccuracy(name, shortName string) Interface

NewSparseCategoricalAccuracy returns a new sparse categorical accuracy metric with the given names. The accuracy is the fraction of examples for which argmax(logits) is the true label. It works with either probabilities or logits. Ties are counted as misses. Labels are expected to have an integer dtype, and the returned value has the same dtype as logits.

type PrettyPrintFn

type PrettyPrintFn func(value tensor.Tensor) string

PrettyPrintFn is a function to convert a metric value to a string.
