optimizer

package
v1.2.0 Latest
Warning

This package is not in the latest version of its module.

Published: Mar 15, 2026 License: Apache-2.0 Imports: 8 Imported by: 0

Documentation

Overview

Package optimizer provides various optimization algorithms for neural networks.

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type AdamW

type AdamW[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

AdamW implements the AdamW optimizer.

func NewAdamW

func NewAdamW[T tensor.Numeric](engine compute.Engine[T], learningRate, beta1, beta2, epsilon, weightDecay T) *AdamW[T]

NewAdamW creates a new AdamW optimizer.

func (*AdamW[T]) Step

func (a *AdamW[T]) Step(ctx context.Context, params []*graph.Parameter[T]) error

Step updates the parameters based on their gradients.

type EMA added in v0.2.1

type EMA[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

EMA wraps an Optimizer with Exponential Moving Average weight averaging. After each inner optimizer step, it updates shadow weights:

shadow = decay * shadow + (1-decay) * param.Value

Call SwapShadow before validation to use averaged weights, then SwapBack to restore training weights.

func NewEMA added in v0.2.1

func NewEMA[T tensor.Numeric](inner Optimizer[T], engine compute.Engine[T], decay T) *EMA[T]

NewEMA creates a new EMA wrapper around the given optimizer.

func (*EMA[T]) Step added in v0.2.1

func (e *EMA[T]) Step(ctx context.Context, params []*graph.Parameter[T]) error

Step runs the inner optimizer step and then updates shadow weights.

func (*EMA[T]) SwapBack added in v0.2.1

func (e *EMA[T]) SwapBack(ctx context.Context, params []*graph.Parameter[T]) error

SwapBack is a semantic alias for SwapShadow — the swap operation is symmetric.

func (*EMA[T]) SwapShadow added in v0.2.1

func (e *EMA[T]) SwapShadow(ctx context.Context, params []*graph.Parameter[T]) error

SwapShadow swaps param.Value with shadow weights for validation/checkpointing.

type Optimizer

type Optimizer[T tensor.Numeric] interface {
	Step(ctx context.Context, params []*graph.Parameter[T]) error
}

Optimizer defines the interface for optimization algorithms.

type SGD

type SGD[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

SGD implements the stochastic gradient descent optimizer.

func NewSGD

func NewSGD[T tensor.Numeric](engine compute.Engine[T], ops numeric.Arithmetic[T], learningRate float32) *SGD[T]

NewSGD creates a new SGD optimizer.

func (*SGD[T]) Clip

func (s *SGD[T]) Clip(ctx context.Context, params []*graph.Parameter[T], threshold float32)

Clip clips the gradients of the parameters by a threshold.

func (*SGD[T]) Step

func (s *SGD[T]) Step(ctx context.Context, params []*graph.Parameter[T]) error

Step updates the parameters based on their gradients.

type SWA added in v0.2.1

type SWA[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

SWA wraps an Optimizer with Stochastic Weight Averaging. Unlike EMA, which averages every step, SWA averages at epoch boundaries. Call UpdateAverage at the end of each epoch (after startEpoch), and call SwapWeights before validation to use the averaged weights.

func NewSWA added in v0.2.1

func NewSWA[T tensor.Numeric](inner Optimizer[T], engine compute.Engine[T], startEpoch int) *SWA[T]

NewSWA creates a new SWA wrapper around the given optimizer.

func (*SWA[T]) NAveraged added in v0.2.1

func (s *SWA[T]) NAveraged() int

NAveraged returns the number of checkpoints averaged so far.

func (*SWA[T]) Step added in v0.2.1

func (s *SWA[T]) Step(ctx context.Context, params []*graph.Parameter[T]) error

Step delegates to the inner optimizer.

func (*SWA[T]) SwapWeights added in v0.2.1

func (s *SWA[T]) SwapWeights(ctx context.Context, params []*graph.Parameter[T]) error

SwapWeights swaps live params with averaged params.

func (*SWA[T]) UpdateAverage added in v0.2.1

func (s *SWA[T]) UpdateAverage(ctx context.Context, params []*graph.Parameter[T], epoch int) error

UpdateAverage updates the running average of parameters. Should be called at the end of each epoch. Only averages when epoch >= startEpoch. Formula: avg = avg + (param - avg) / (n + 1)
