Documentation ¶
Index ¶
- func CrossEntropy(g *ag.Graph, x ag.Node, c int) ag.Node
- func CrossEntropySeq(g *ag.Graph, predicted []ag.Node, target []int, reduceMean bool) ag.Node
- func Distance(g *ag.Graph, x ag.Node, target mat.Float) ag.Node
- func FocalLoss(g *ag.Graph, x ag.Node, c int, gamma mat.Float) ag.Node
- func MAE(g *ag.Graph, x ag.Node, y ag.Node, reduceMean bool) ag.Node
- func MAESeq(g *ag.Graph, predicted []ag.Node, target []ag.Node, reduceMean bool) ag.Node
- func MSE(g *ag.Graph, x ag.Node, y ag.Node, reduceMean bool) ag.Node
- func MSESeq(g *ag.Graph, predicted []ag.Node, target []ag.Node, reduceMean bool) ag.Node
- func NLL(g *ag.Graph, x ag.Node, y ag.Node) ag.Node
- func Norm2Quantization(g *ag.Graph, x ag.Node) ag.Node
- func OneHotQuantization(g *ag.Graph, x ag.Node, q mat.Float) ag.Node
- func Perplexity(g *ag.Graph, x ag.Node, c int) ag.Node
- func SPG(g *ag.Graph, logPropActions []ag.Node, logProbTargets []ag.Node) ag.Node
- func WeightedCrossEntropy(weights []mat.Float) func(g *ag.Graph, x ag.Node, c int) ag.Node
- func WeightedFocalLoss(weights []mat.Float) func(g *ag.Graph, x ag.Node, c int, gamma mat.Float) ag.Node
- func ZeroOneQuantization(g *ag.Graph, x ag.Node) ag.Node
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func CrossEntropy ¶
CrossEntropy implements a cross-entropy loss function. x is the raw scores for each class (logits). c is the index of the gold class.
func CrossEntropySeq ¶
CrossEntropySeq calculates the cross-entropy loss over the given sequence of predictions. If reduceMean is true, the total loss is divided by the sequence length.
func FocalLoss ¶ added in v0.6.0
FocalLoss implements a variant of the CrossEntropy loss that reduces the loss contribution from "easy" examples and increases the importance of correcting misclassified examples. x is the raw scores for each class (logits). c is the index of the gold class. gamma is the focusing parameter (gamma ≥ 0).
func MAE ¶
MAE measures the mean absolute error (a.k.a. L1 Loss) between each element in the input x and target y.
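A scalar sketch of the computation (the helper name mae is hypothetical; reduceMean mirrors the parameter in the signature above):

```go
package main

import (
	"fmt"
	"math"
)

// mae accumulates |x[i]-y[i]| over all elements; when reduceMean is
// true the sum is divided by the number of elements.
func mae(x, y []float64, reduceMean bool) float64 {
	loss := 0.0
	for i := range x {
		loss += math.Abs(x[i] - y[i])
	}
	if reduceMean {
		loss /= float64(len(x))
	}
	return loss
}

func main() {
	fmt.Println(mae([]float64{1, 2}, []float64{0, 4}, true))
}
```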
func MSE ¶
MSE measures the mean squared error (squared L2 norm) between each element in the input x and target y.
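A scalar sketch of the squared-error computation (illustrative only; the library may apply an additional constant factor such as 1/2 — check the source for the exact scaling):

```go
package main

import "fmt"

// mse accumulates (x[i]-y[i])^2 over all elements; when reduceMean is
// true the sum is divided by the number of elements.
func mse(x, y []float64, reduceMean bool) float64 {
	loss := 0.0
	for i := range x {
		d := x[i] - y[i]
		loss += d * d
	}
	if reduceMean {
		loss /= float64(len(x))
	}
	return loss
}

func main() {
	fmt.Println(mse([]float64{1, 2}, []float64{0, 4}, false))
}
```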
func NLL ¶
NLL returns the negative log-likelihood loss of the input x with respect to the target y. The target is expected to be a one-hot vector.
func Norm2Quantization ¶
Norm2Quantization is a loss function that is minimized when norm2(x) = 1.
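One natural formulation consistent with that description is (Σ x[i]² − 1)², which is zero exactly when norm2(x) = 1; a scalar sketch (the library's exact expression may differ):

```go
package main

import "fmt"

// norm2Quantization penalizes the squared deviation of the squared
// L2 norm from one: (sum_i x[i]^2 - 1)^2.
func norm2Quantization(x []float64) float64 {
	sumSq := 0.0
	for _, v := range x {
		sumSq += v * v
	}
	return (sumSq - 1) * (sumSq - 1)
}

func main() {
	fmt.Println(norm2Quantization([]float64{1, 0, 0}), norm2Quantization([]float64{0.5, 0.5}))
}
```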
func OneHotQuantization ¶
OneHotQuantization is a loss function that pushes the vector x towards being one-hot. q is the quantization regularizer weight (suggested value: 0.00001).
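One plausible formulation (an assumption, not taken from the library source) combines a zero-one penalty, which vanishes when every component is 0 or 1, with the norm penalty above, which vanishes when exactly one component is 1, all scaled by q:

```go
package main

import "fmt"

// oneHotQuantization (assumed formulation): q * (zeroOne + norm2), where
// zeroOne = sum_i x[i]^2 * (1-x[i])^2 pushes components towards {0, 1}
// and norm2 = (sum_i x[i]^2 - 1)^2 pushes the squared norm towards 1.
// Both terms are zero exactly when x is one-hot.
func oneHotQuantization(x []float64, q float64) float64 {
	zeroOne := 0.0
	sumSq := 0.0
	for _, v := range x {
		zeroOne += v * v * (1 - v) * (1 - v)
		sumSq += v * v
	}
	norm2 := (sumSq - 1) * (sumSq - 1)
	return q * (zeroOne + norm2)
}

func main() {
	fmt.Println(oneHotQuantization([]float64{0, 1, 0}, 1e-5), oneHotQuantization([]float64{0.5, 0.5}, 1e-5))
}
```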
func Perplexity ¶
Perplexity computes the perplexity, implemented as the exponential of the cross-entropy.
func SPG ¶
SPG (Softmax Policy Gradient) is a policy-gradient loss used in reinforcement learning. logPropActions are the log-probabilities of the actions chosen by the agent at each time step; logProbTargets are the results of the reward function, i.e. the predicted log-likelihood of the ground truth at each time step.
func WeightedCrossEntropy ¶ added in v0.6.0
WeightedCrossEntropy implements a weighted cross-entropy loss function. x is the raw scores for each class (logits). c is the index of the gold class. The loss is scaled by the weighting factor weights[class] ∈ [0, 1].
func WeightedFocalLoss ¶ added in v0.6.0
WeightedFocalLoss implements a variant of the CrossEntropy loss that reduces the loss contribution from "easy" examples and increases the importance of correcting misclassified examples. x is the raw scores for each class (logits). c is the index of the gold class. gamma is the focusing parameter (gamma ≥ 0). This function is scaled by a weighting factor weights[class] ∈ [0,1].
Types ¶
This section is empty.