optimizer

package v1.38.1

Published: Mar 30, 2026 License: Apache-2.0 Imports: 9 Imported by: 0

Documentation

Overview

Package optimizer provides various optimization algorithms for neural networks, including AdamW and SGD.

Stability: beta

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type AdamW

type AdamW[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

AdamW implements the AdamW optimizer (Adam with decoupled weight decay).

func NewAdamW

func NewAdamW[T tensor.Numeric](engine compute.Engine[T], learningRate, beta1, beta2, epsilon, weightDecay T) *AdamW[T]

NewAdamW creates a new AdamW optimizer.

func (*AdamW[T]) SetLR added in v1.8.0

func (a *AdamW[T]) SetLR(lr T)

SetLR sets the learning rate. This is typically called by a scheduler.

func (*AdamW[T]) SetMaxGradNorm added in v1.11.0

func (a *AdamW[T]) SetMaxGradNorm(maxGradNorm float64)

SetMaxGradNorm sets the maximum gradient norm for gradient clipping. If maxGradNorm <= 0, gradient clipping is disabled.

func (*AdamW[T]) Step

func (a *AdamW[T]) Step(ctx context.Context, params []*graph.Parameter[T]) error

Step updates the parameters based on their gradients.

type AdamW8bit added in v1.5.0

type AdamW8bit[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

AdamW8bit implements the AdamW optimizer with block-wise INT8 quantization for first and second moment estimates. Parameters remain in full precision. This reduces optimizer state memory by ~4x compared to FP32 AdamW.

func NewAdamW8bit added in v1.5.0

func NewAdamW8bit[T tensor.Numeric](lr, beta1, beta2, eps, wd float32) *AdamW8bit[T]

NewAdamW8bit creates a new 8-bit AdamW optimizer.

func (*AdamW8bit[T]) Step added in v1.5.0

func (a *AdamW8bit[T]) Step(ctx context.Context, params []*graph.Parameter[T]) error

Step updates parameters based on their gradients. Moment estimates are stored in INT8 and dequantized for computation, then re-quantized after update.

type EMA added in v0.2.1

type EMA[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

EMA wraps an Optimizer with Exponential Moving Average weight averaging. After each inner optimizer step, it updates shadow weights:

shadow = decay * shadow + (1-decay) * param.Value

Call SwapShadow before validation to use averaged weights, then SwapBack to restore training weights.

func NewEMA added in v0.2.1

func NewEMA[T tensor.Numeric](inner Optimizer[T], engine compute.Engine[T], decay T) *EMA[T]

NewEMA creates a new EMA wrapper around the given optimizer.

func (*EMA[T]) Step added in v0.2.1

func (e *EMA[T]) Step(ctx context.Context, params []*graph.Parameter[T]) error

Step runs the inner optimizer step and then updates shadow weights.

func (*EMA[T]) SwapBack added in v0.2.1

func (e *EMA[T]) SwapBack(ctx context.Context, params []*graph.Parameter[T]) error

SwapBack is a semantic alias for SwapShadow; the swap operation is symmetric.

func (*EMA[T]) SwapShadow added in v0.2.1

func (e *EMA[T]) SwapShadow(ctx context.Context, params []*graph.Parameter[T]) error

SwapShadow swaps param.Value with shadow weights for validation/checkpointing.

type Int8State added in v1.5.0

type Int8State struct {
	// contains filtered or unexported fields
}

Int8State holds a block-wise INT8-quantized representation of a float32 slice. Each block of blockSize elements shares a single scale factor, reducing memory from 4 bytes/element to ~1 byte/element (+ negligible scale overhead).

type Optimizer

type Optimizer[T tensor.Numeric] interface {
	Step(ctx context.Context, params []*graph.Parameter[T]) error
}

Optimizer defines the interface for optimization algorithms.

type SGD

type SGD[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

SGD implements the stochastic gradient descent optimizer.

func NewSGD

func NewSGD[T tensor.Numeric](engine compute.Engine[T], ops numeric.Arithmetic[T], learningRate float32) *SGD[T]

NewSGD creates a new SGD optimizer.

func (*SGD[T]) Clip

func (s *SGD[T]) Clip(ctx context.Context, params []*graph.Parameter[T], threshold float32)

Clip clips the gradients of the parameters by a threshold.

func (*SGD[T]) SetLR added in v1.8.0

func (s *SGD[T]) SetLR(lr T)

SetLR sets the learning rate. This is typically called by a scheduler.

func (*SGD[T]) Step

func (s *SGD[T]) Step(ctx context.Context, params []*graph.Parameter[T]) error

Step updates the parameters based on their gradients.

type SWA added in v0.2.1

type SWA[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

SWA wraps an Optimizer with Stochastic Weight Averaging. Unlike EMA which averages every step, SWA averages at epoch boundaries. Call UpdateAverage at the end of each epoch (after startEpoch). Call SwapWeights before validation to use averaged weights.

func NewSWA added in v0.2.1

func NewSWA[T tensor.Numeric](inner Optimizer[T], engine compute.Engine[T], startEpoch int) *SWA[T]

NewSWA creates a new SWA wrapper around the given optimizer.

func (*SWA[T]) NAveraged added in v0.2.1

func (s *SWA[T]) NAveraged() int

NAveraged returns the number of checkpoints averaged so far.

func (*SWA[T]) Step added in v0.2.1

func (s *SWA[T]) Step(ctx context.Context, params []*graph.Parameter[T]) error

Step delegates to the inner optimizer.

func (*SWA[T]) SwapWeights added in v0.2.1

func (s *SWA[T]) SwapWeights(ctx context.Context, params []*graph.Parameter[T]) error

SwapWeights swaps live params with averaged params.

func (*SWA[T]) UpdateAverage added in v0.2.1

func (s *SWA[T]) UpdateAverage(ctx context.Context, params []*graph.Parameter[T], epoch int) error

UpdateAverage updates the running average of parameters. Should be called at the end of each epoch. Only averages when epoch >= startEpoch. Formula: avg = avg + (param - avg) / (n + 1)
