training

package
v1.38.4
Published: Mar 31, 2026 License: Apache-2.0 Imports: 11 Imported by: 0

Documentation

Overview

Package training provides adapter implementations for bridging existing and new interfaces.

Package training provides tools for training neural networks.

Package training provides the V2 trainer API using a Batch and pluggable strategy.

Package training provides neural network training orchestration for the Zerfoo ML framework. (Stability: beta)

The package implements a layered design: a core Trainer interface for single-step parameter updates, a DefaultTrainer that wires together a computation graph, a loss node, an optimizer, and a pluggable gradient strategy, and a higher-level TrainingWorkflow interface for full training-loop orchestration with data providers and model providers.

Trainer and DefaultTrainer

Trainer is the fundamental training interface. It performs one training step: forward pass, loss computation, backward pass, and parameter update.

trainer := training.NewDefaultTrainer[float32](g, lossNode, opt, nil)
loss, err := trainer.TrainStep(ctx, g, opt, inputs, targets)

DefaultTrainer is the standard implementation. It delegates gradient computation to a GradientStrategy and parameter updates to an optimizer.Optimizer. When no strategy is provided, it defaults to DefaultBackpropStrategy.

Gradient Strategies

GradientStrategy controls how gradients are computed for each training step. Two strategies are provided:

  • DefaultBackpropStrategy — standard backpropagation through the loss and model graph.
  • OneStepApproximationStrategy — a one-step gradient approximation, designed for training recurrent models without full BPTT.

Custom strategies can implement the GradientStrategy interface to add auxiliary losses, gradient clipping, deep supervision, or other specialized gradient computation techniques.

Optimizers

The optimizer sub-package defines the optimizer.Optimizer interface and provides several implementations:

  • AdamW — Adam with decoupled weight decay.
  • SGD — Stochastic gradient descent with optional momentum.
  • EMA — Exponential moving average of model parameters.
  • SWA — Stochastic weight averaging.

Batch and Data Iteration

Batch groups inputs and targets for a single training step. Inputs are provided as a map from graph input nodes to tensors; targets are a single tensor.

DataIterator provides sequential access to batches. ChunkedDataIterator loads batches in chunks via a callback, keeping only one chunk in memory at a time for large datasets. DataIteratorAdapter wraps a static slice of batches as a DataIterator.

Model Interface

Model defines a trainable model with Forward, Backward, and Parameters methods. This is the low-level model interface used by training components that need direct forward/backward control.

Training Workflow and Plugin Registry

For full training-loop orchestration, TrainingWorkflow combines data providers, model providers, metrics, and the training loop into a single interface. TrainerWorkflowAdapter bridges the core Trainer interface to TrainingWorkflow for use with the plugin system.

PluginRegistry enables runtime registration and lookup of workflows, data providers, model providers, sequence providers, metric computers, and cross validators. Global registries Float32Registry and Float64Registry are provided for common numeric types.

See interfaces_doc.go for detailed documentation of the plugin architecture and workflow interfaces. Stability: beta

Package training defines training-time gradient computation strategies.

Package training provides generic interfaces for ML training workflows.

Package training provides comprehensive documentation for the generic training interfaces.

Overview

The training package implements a plugin-based architecture that allows domain-specific applications to customize training behavior while maintaining a generic, reusable core. This design follows the hexagonal architecture pattern, where the core framework defines ports (interfaces) and domain-specific applications provide adapters (implementations).

Core Design Principles

  1. **Separation of Concerns**: Generic ML training logic is separated from domain-specific business logic (e.g., custom scoring or evaluation requirements).

  2. **Dependency Inversion**: High-level training workflows depend on abstractions, not concrete implementations.

  3. **Plugin Architecture**: Components can be registered and swapped at runtime using the plugin registry system.

  4. **Backward Compatibility**: Adapter patterns allow legacy code to work with new interfaces.

  5. **Extensibility**: Configuration structures include extension points for domain-specific customization.

Interface Hierarchy

## Primary Interfaces

### TrainingWorkflow[T]

The main orchestrator for training processes. Implementations define the complete training pipeline including initialization, training loops, validation, and cleanup.

Usage Pattern:

workflow, err := registry.GetWorkflow(ctx, "standard", config)
workflow.Initialize(ctx, workflowConfig)
result, err := workflow.Train(ctx, dataProvider, modelProvider)

### DataProvider[T]

Abstracts data access patterns for training and validation. This replaces domain-specific data loading with a generic interface that can handle any data source.

Implementations should provide:

  - Efficient batch iteration
  - Train/validation data splitting
  - Metadata for training customization
  - Resource management (cleanup)

Usage Pattern:

dataProvider, err := registry.GetDataProvider(ctx, "csv", config)
trainingData, err := dataProvider.GetTrainingData(ctx, batchConfig)
for trainingData.Next(ctx) {
    batch := trainingData.Batch()
    // Process batch
}

### ModelProvider[T]

Abstracts model creation and management. This allows different model architectures to be used with the same training workflow.

Implementations should handle:

  - Model instantiation from configuration
  - Model serialization/deserialization
  - Model metadata and introspection

Usage Pattern:

modelProvider, err := registry.GetModelProvider(ctx, "mlp", config)
model, err := modelProvider.CreateModel(ctx, modelConfig)
err = modelProvider.SaveModel(ctx, model, "/path/to/model")

## Supporting Interfaces

### SequenceProvider[T]

Generalizes sequence generation for curriculum learning. This replaces the domain-specific EraSequencer with a generic interface that can handle any sequencing strategy.

Common implementations:

  - Consecutive sequence generation
  - Random sequence sampling
  - Curriculum learning strategies
  - Time-series aware sequencing

### MetricComputer[T]

Provides extensible metric computation. Applications can register custom metrics while using standard training workflows.

Usage Pattern:

computer := NewMetricComputer()
computer.RegisterMetric("mse", mseFunction)
computer.RegisterMetric("custom_score", customFunction)
metrics, err := computer.ComputeMetrics(ctx, predictions, targets, metadata)

### CrossValidator[T]

Implements various cross-validation strategies in a generic way. This allows time-series CV, k-fold CV, group-based CV, etc. to be used interchangeably.

Configuration System

All interfaces use structured configuration with extension points:

## Extension Pattern

All configuration structures include an `Extensions` field that allows domain-specific applications to pass additional configuration without modifying the core framework:

config := WorkflowConfig{
    NumEpochs: 100,
    Extensions: map[string]interface{}{
        "domain": map[string]interface{}{
            "task_id":     "main",
            "batch_limit": 120,
        },
    },
}

## Configuration Validation

Implementations should validate their configuration and return descriptive errors for invalid or missing parameters.

Plugin Registry System

The plugin registry enables runtime component selection and supports multiple implementations of each interface.

## Registration Pattern

// Register a component
err := Float32Registry.RegisterWorkflow("standard", func(ctx context.Context, config map[string]interface{}) (TrainingWorkflow[float32], error) {
    return NewStandardWorkflow(config), nil
})

// Use registered component
workflow, err := Float32Registry.GetWorkflow(ctx, "standard", config)

## Factory Functions

All registry factories receive context and configuration, enabling:

  - Initialization validation
  - Resource allocation
  - Graceful error handling
  - Context-aware cleanup

Adapter Pattern Implementation

The adapter pattern allows smooth migration from legacy interfaces to new generic interfaces:

## TrainerWorkflowAdapter

Adapts existing Trainer implementations to work with the TrainingWorkflow interface:

legacy := NewDefaultTrainer(graph, loss, optimizer, strategy)
adapter := NewTrainerWorkflowAdapter(legacy, optimizer)
// Now 'adapter' can be used with new workflow system

## Migration Strategy

1. Implement adapters for existing components
2. Register adapted components in plugin registry
3. Gradually replace legacy direct usage with registry-based usage
4. Implement new components directly against generic interfaces

Error Handling Patterns

## Context Cancellation

All interface methods accept context.Context and should respect cancellation:

func (w *MyWorkflow) Train(ctx context.Context, ...) (*TrainingResult[T], error) {
    for epoch := 0; epoch < maxEpochs; epoch++ {
        select {
        case <-ctx.Done():
            return nil, ctx.Err()
        default:
            // Continue training
        }
    }
    return result, nil
}

## Resource Cleanup

Implementations should provide proper resource cleanup:

func (p *MyDataProvider) Close() error {
    // Close files, database connections, etc.
    return nil
}

Use defer for automatic cleanup:

dataProvider, err := registry.GetDataProvider(ctx, "csv", config)
if err != nil {
    return err
}
defer dataProvider.Close()

Performance Considerations

## Memory Management

  - Implement proper resource cleanup in Close() methods
  - Use streaming where possible for large datasets
  - Consider memory-mapped files for large data

## Concurrency

  - DataIterator implementations should be thread-safe
  - Plugin registry uses read-write mutexes for thread safety
  - Consider goroutine pools for parallel processing

Testing Patterns

## Mock Implementations

Create mock implementations for testing:

type MockDataProvider[T tensor.Numeric] struct {
    batches []*Batch[T]
}

func (m *MockDataProvider[T]) GetTrainingData(ctx context.Context, config BatchConfig) (DataIterator[T], error) {
    return NewDataIteratorAdapter(m.batches), nil
}

## Integration Testing

Test complete workflows with real implementations:

func TestWorkflowIntegration(t *testing.T) {
    registry := NewPluginRegistry[float32]()
    registry.RegisterWorkflow("test", testWorkflowFactory)
    registry.RegisterDataProvider("mock", mockDataProviderFactory)

    workflow, _ := registry.GetWorkflow(ctx, "test", config)
    dataProvider, _ := registry.GetDataProvider(ctx, "mock", dataConfig)

    result, err := workflow.Train(ctx, dataProvider, modelProvider)
    assert.NoError(t, err)
    assert.NotNil(t, result)
}

Migration from Domain-Specific Code

## Configuration Migration

Domain-specific configuration moves to extensions:

Before:

config := LegacyTrainingConfig{
    Epochs:     100,
    TaskID:     "main",
    BatchLimit: 120,
}

After:

config := WorkflowConfig{
    NumEpochs: 100,
    Extensions: map[string]interface{}{
        "domain": map[string]interface{}{
            "task_id":     "main",
            "batch_limit": 120,
        },
    },
}

Best Practices

## Implementation Guidelines

1. **Validate Early**: Check configuration in Initialize() methods
2. **Fail Fast**: Return errors immediately for invalid states
3. **Resource Cleanup**: Always implement proper cleanup in Close() methods
4. **Context Awareness**: Respect context cancellation in long-running operations
5. **Thread Safety**: Ensure implementations are thread-safe where documented

## Extension Guidelines

1. **Namespace Extensions**: Use descriptive keys in Extensions maps
2. **Validate Extensions**: Check extension configuration early
3. **Provide Defaults**: Handle missing extension configuration gracefully
4. **Document Extensions**: Clearly document expected extension parameters

## Testing Guidelines

1. **Mock Dependencies**: Use mock implementations for unit testing
2. **Test Error Cases**: Ensure proper error handling and cleanup
3. **Integration Tests**: Test complete workflows with real data
4. **Performance Tests**: Benchmark critical paths with realistic data

Examples

See the following files for complete implementation examples:

  - adapter.go: Adapter pattern implementations
  - registry.go: Plugin registry system
  - interfaces.go: Core interface definitions

Package training provides core components for neural network training.

Package training provides a plugin registry for training components.

Package training defines default backpropagation strategy.

Package training defines the one-step gradient approximation strategy.

Index

Constants

This section is empty.

Variables

var (
	Float32Registry = NewPluginRegistry[float32]()
	Float64Registry = NewPluginRegistry[float64]()
)

Global registry instances for common numeric types.

Functions

func CreateWindows added in v1.10.0

func CreateWindows(data [][]float64, windowLen int) (windows [][][]float64, labels []float64)

CreateWindows converts flat rows into overlapping temporal windows. Each window contains windowLen consecutive rows. Labels come from the last row's last element.

func ParseWindowSizes added in v1.10.0

func ParseWindowSizes(s string) []int

ParseWindowSizes parses a comma-separated string of window sizes. Example: "15,30,60,120" -> []int{15, 30, 60, 120}

Types

type Backend added in v1.9.0

type Backend interface {
	Train(features [][]float64, labels []float64, config TrainConfig) (*TrainResult, error)
}

Backend is implemented by models that train on flat tabular data. Walk-forward validators dispatch to this interface when the model does not implement WindowedBackend.

type Batch added in v0.2.1

type Batch[T tensor.Numeric] struct {
	Inputs  map[graph.Node[T]]*tensor.TensorNumeric[T]
	Targets *tensor.TensorNumeric[T]
}

Batch groups the stable inputs for a single training step.

Inputs are provided as a map keyed by the graph's input nodes. Targets are provided as a single tensor; strategies may interpret targets appropriately for the chosen loss.

type BatchConfig added in v0.2.1

type BatchConfig struct {
	BatchSize  int                    `json:"batch_size"`
	Shuffle    bool                   `json:"shuffle"`
	DropLast   bool                   `json:"drop_last"`
	NumWorkers int                    `json:"num_workers"`
	Extensions map[string]interface{} `json:"extensions"`
}

BatchConfig configures batch processing.

type CVConfig added in v0.2.1

type CVConfig struct {
	Strategy   string                 `json:"strategy"` // "k_fold", "time_series", "group"
	NumFolds   int                    `json:"num_folds"`
	GroupBy    string                 `json:"group_by"`  // For group-based CV
	PurgeGap   int                    `json:"purge_gap"` // For time-series CV
	TestSize   float64                `json:"test_size"`
	RandomSeed uint64                 `json:"random_seed"`
	Extensions map[string]interface{} `json:"extensions"`
}

CVConfig configures cross-validation.

type CVResult added in v0.2.1

type CVResult[T tensor.Numeric] struct {
	MeanLoss    T                      `json:"mean_loss"`
	StdLoss     T                      `json:"std_loss"`
	MeanMetrics map[string]float64     `json:"mean_metrics"`
	StdMetrics  map[string]float64     `json:"std_metrics"`
	FoldResults []ValidationResult[T]  `json:"fold_results"`
	TotalTime   float64                `json:"total_time_seconds"`
	Extensions  map[string]interface{} `json:"extensions"`
}

CVResult contains cross-validation results.

type ChunkLoader added in v0.2.1

type ChunkLoader[T tensor.Numeric] func(chunkIdx int) ([]*Batch[T], error)

ChunkLoader is a callback that returns the next chunk of batches. It receives a zero-based chunk index and returns the batches for that chunk. Return nil batches (or empty slice) with nil error to signal no more chunks.

type ChunkedDataIterator added in v0.2.1

type ChunkedDataIterator[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

ChunkedDataIterator loads batches in chunks via a callback function. Each chunk represents a logical unit of data (e.g., one era, one file shard). Only one chunk's batches are held in memory at a time; the previous chunk's data is released when the next chunk is loaded.

func NewChunkedDataIterator added in v0.2.1

func NewChunkedDataIterator[T tensor.Numeric](loader ChunkLoader[T]) *ChunkedDataIterator[T]

NewChunkedDataIterator creates a new iterator that loads batches in chunks. The loader function is called with increasing chunk indices starting from 0. When the loader returns nil or an empty slice with no error, iteration ends.

func (*ChunkedDataIterator[T]) Batch added in v0.2.1

func (c *ChunkedDataIterator[T]) Batch() *Batch[T]

Batch returns the current batch. Returns nil if called before Next or after iteration is exhausted.

func (*ChunkedDataIterator[T]) Close added in v0.2.1

func (c *ChunkedDataIterator[T]) Close() error

Close releases resources held by the iterator.

func (*ChunkedDataIterator[T]) Error added in v0.2.1

func (c *ChunkedDataIterator[T]) Error() error

Error returns any error that occurred during chunk loading.

func (*ChunkedDataIterator[T]) Next added in v0.2.1

func (c *ChunkedDataIterator[T]) Next(_ context.Context) bool

Next advances to the next batch. Returns false when all chunks are exhausted or an error occurs. Automatically loads the next chunk when the current one is fully consumed.

func (*ChunkedDataIterator[T]) Reset added in v0.2.1

func (c *ChunkedDataIterator[T]) Reset() error

Reset rewinds the iterator to the beginning, allowing re-iteration from chunk 0. The loader will be called again starting from index 0.

type CrossValidator added in v0.2.1

type CrossValidator[T tensor.Numeric] interface {
	// CreateFolds generates cross-validation folds from the dataset
	CreateFolds(ctx context.Context, dataset DataProvider[T], config CVConfig) ([]Fold[T], error)

	// ValidateModel performs cross-validation on a model
	ValidateModel(ctx context.Context, dataset DataProvider[T], modelProvider ModelProvider[T], config CVConfig) (*CVResult[T], error)
}

CrossValidator provides generic cross-validation strategies.

type CrossValidatorFactory added in v0.2.1

type CrossValidatorFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (CrossValidator[T], error)

CrossValidatorFactory creates CrossValidator instances

type DataIterator added in v0.2.1

type DataIterator[T tensor.Numeric] interface {
	// Next advances to the next batch, returns false when exhausted
	Next(ctx context.Context) bool

	// Batch returns the current batch data
	Batch() *Batch[T]

	// Error returns any error that occurred during iteration
	Error() error

	// Close releases iterator resources
	Close() error

	// Reset rewinds the iterator to the beginning
	Reset() error
}

DataIterator provides sequential access to training/validation batches.

type DataIteratorAdapter added in v0.2.1

type DataIteratorAdapter[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

DataIteratorAdapter provides a simple iterator implementation over static data.

func NewDataIteratorAdapter added in v0.2.1

func NewDataIteratorAdapter[T tensor.Numeric](batches []*Batch[T]) *DataIteratorAdapter[T]

NewDataIteratorAdapter creates a new data iterator from a slice of batches.

func (*DataIteratorAdapter[T]) Batch added in v0.2.1

func (d *DataIteratorAdapter[T]) Batch() *Batch[T]

Batch implements DataIterator.Batch

func (*DataIteratorAdapter[T]) Close added in v0.2.1

func (d *DataIteratorAdapter[T]) Close() error

Close implements DataIterator.Close

func (*DataIteratorAdapter[T]) Error added in v0.2.1

func (d *DataIteratorAdapter[T]) Error() error

Error implements DataIterator.Error

func (*DataIteratorAdapter[T]) Next added in v0.2.1

func (d *DataIteratorAdapter[T]) Next(ctx context.Context) bool

Next implements DataIterator.Next

func (*DataIteratorAdapter[T]) Reset added in v0.2.1

func (d *DataIteratorAdapter[T]) Reset() error

Reset implements DataIterator.Reset

type DataProvider added in v0.2.1

type DataProvider[T tensor.Numeric] interface {
	// GetTrainingData returns training data in batches
	GetTrainingData(ctx context.Context, config BatchConfig) (DataIterator[T], error)

	// GetValidationData returns validation data in batches
	GetValidationData(ctx context.Context, config BatchConfig) (DataIterator[T], error)

	// GetMetadata returns dataset metadata for training customization
	GetMetadata() map[string]interface{}

	// Close releases any resources held by the provider
	Close() error
}

DataProvider abstracts data access patterns for training and validation. This replaces domain-specific data loading with a generic interface.

type DataProviderFactory added in v0.2.1

type DataProviderFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (DataProvider[T], error)

DataProviderFactory creates DataProvider instances

type DefaultBackpropStrategy added in v0.2.1

type DefaultBackpropStrategy[T tensor.Numeric] struct{}

DefaultBackpropStrategy performs standard backpropagation through the loss and model graph.

func NewDefaultBackpropStrategy added in v0.2.1

func NewDefaultBackpropStrategy[T tensor.Numeric]() *DefaultBackpropStrategy[T]

NewDefaultBackpropStrategy constructs a DefaultBackpropStrategy.

func (*DefaultBackpropStrategy[T]) ComputeGradients added in v0.2.1

func (s *DefaultBackpropStrategy[T]) ComputeGradients(
	ctx context.Context,
	g *graph.Graph[T],
	loss graph.Node[T],
	batch Batch[T],
) (T, error)

ComputeGradients runs forward pass, computes loss, runs backward passes, and leaves parameter gradients populated on the graph's parameters.

type DefaultTrainer added in v0.2.1

type DefaultTrainer[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

DefaultTrainer encapsulates stable training components and delegates gradient computation to a strategy.

func NewDefaultTrainer added in v0.2.1

func NewDefaultTrainer[T tensor.Numeric](
	g *graph.Graph[T],
	loss graph.Node[T],
	optimizer opt.Optimizer[T],
	strategy GradientStrategy[T],
) *DefaultTrainer[T]

NewDefaultTrainer constructs a new DefaultTrainer. If strategy is nil, DefaultBackpropStrategy is used.

func (*DefaultTrainer[T]) TrainStep added in v0.2.1

func (t *DefaultTrainer[T]) TrainStep(
	ctx context.Context,
	g *graph.Graph[T],
	optimizer opt.Optimizer[T],
	inputs map[graph.Node[T]]*tensor.TensorNumeric[T],
	targets *tensor.TensorNumeric[T],
) (T, error)

TrainStep performs a single training step using the configured strategy and optimizer.

type EarlyStopConfig added in v1.8.0

type EarlyStopConfig struct {
	// Patience is the number of epochs without improvement before stopping.
	Patience int
	// Alpha is the EMA smoothing factor (0 < alpha < 1). Default: 0.1.
	Alpha float64
	// MinDelta is the minimum improvement to count as progress.
	MinDelta float64
	// Mode is "min" for loss (lower is better) or "max" for accuracy (higher is better).
	Mode string
}

EarlyStopConfig configures smoothed early stopping behavior.

type EarlyStopping added in v1.8.0

type EarlyStopping struct {
	// contains filtered or unexported fields
}

EarlyStopping tracks a smoothed metric via exponential moving average and signals when training should stop due to lack of improvement.

func NewEarlyStopping added in v1.8.0

func NewEarlyStopping(config EarlyStopConfig) *EarlyStopping

NewEarlyStopping creates a new EarlyStopping instance with the given config. If Alpha is zero, it defaults to 0.1. If Mode is empty, it defaults to "min".

func (*EarlyStopping) BestMetric added in v1.8.0

func (es *EarlyStopping) BestMetric() float64

BestMetric returns the best smoothed metric observed so far.

func (*EarlyStopping) Reset added in v1.8.0

func (es *EarlyStopping) Reset()

Reset clears all state so the instance can be reused.

func (*EarlyStopping) Step added in v1.8.0

func (es *EarlyStopping) Step(metric float64) bool

Step records a new metric value and returns true if training should stop. On the first call, it initializes the smoothed metric to the raw value. On subsequent calls, it applies EMA smoothing and checks for improvement.

type Fold added in v0.2.1

type Fold[T tensor.Numeric] interface {
	// TrainData returns the training data for this fold
	TrainData() DataProvider[T]

	// ValidData returns the validation data for this fold
	ValidData() DataProvider[T]

	// FoldIndex returns the fold index
	FoldIndex() int

	// Metadata returns fold-specific metadata
	Metadata() map[string]interface{}
}

Fold represents a single cross-validation fold.

type GradientStrategy added in v0.2.1

type GradientStrategy[T tensor.Numeric] interface {
	ComputeGradients(
		ctx context.Context,
		g *graph.Graph[T],
		loss graph.Node[T],
		batch Batch[T],
	) (lossValue T, err error)
}

GradientStrategy encapsulates how to compute gradients for a training step.

Implementations may perform standard backprop through the loss, use approximations, or incorporate auxiliary losses (e.g., deep supervision). The strategy must leave parameter gradients populated on the graph's parameters so that the optimizer can apply updates afterwards.

type GradientStrategyAdapter added in v0.2.1

type GradientStrategyAdapter[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

GradientStrategyAdapter adapts GradientStrategy to work with the new interface system.

func NewGradientStrategyAdapter added in v0.2.1

func NewGradientStrategyAdapter[T tensor.Numeric](strategy GradientStrategy[T], g *graph.Graph[T], lossNode graph.Node[T]) *GradientStrategyAdapter[T]

NewGradientStrategyAdapter creates a new gradient strategy adapter.

func (*GradientStrategyAdapter[T]) ComputeGradientsFromBatch added in v0.2.1

func (a *GradientStrategyAdapter[T]) ComputeGradientsFromBatch(ctx context.Context, batch *Batch[T]) (T, error)

ComputeGradientsFromBatch adapts batch processing to the legacy GradientStrategy interface.

type MetricComputer added in v0.2.1

type MetricComputer[T tensor.Numeric] interface {
	// ComputeMetrics calculates metrics from predictions and targets
	ComputeMetrics(ctx context.Context, predictions, targets *tensor.TensorNumeric[T], metadata map[string]interface{}) (map[string]float64, error)

	// RegisterMetric adds a new metric computation
	RegisterMetric(name string, metric MetricFunction[T])

	// UnregisterMetric removes a metric computation
	UnregisterMetric(name string)

	// AvailableMetrics returns all registered metric names
	AvailableMetrics() []string
}

MetricComputer provides extensible metric computation.

type MetricComputerFactory added in v0.2.1

type MetricComputerFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (MetricComputer[T], error)

MetricComputerFactory creates MetricComputer instances

type MetricFunction added in v0.2.1

type MetricFunction[T tensor.Numeric] func(ctx context.Context, predictions, targets *tensor.TensorNumeric[T], metadata map[string]interface{}) (float64, error)

MetricFunction defines a single metric computation.

type Model

type Model[T tensor.Numeric] interface {
	// Forward performs the forward pass of the model.
	Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
	// Backward performs the backward pass of the model.
	Backward(ctx context.Context, grad *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
	// Parameters returns the parameters of the model.
	Parameters() []*graph.Parameter[T]
}

Model defines the interface for a trainable model.

type ModelConfig added in v0.2.1

type ModelConfig struct {
	Type         string                 `json:"type"`
	Architecture map[string]interface{} `json:"architecture"`
	Hyperparams  map[string]interface{} `json:"hyperparams"`
	Extensions   map[string]interface{} `json:"extensions"`
}

ModelConfig configures model creation.

type ModelInfo added in v0.2.1

type ModelInfo struct {
	Name         string                 `json:"name"`
	Version      string                 `json:"version"`
	Architecture string                 `json:"architecture"`
	Parameters   int64                  `json:"parameter_count"`
	InputShape   []int                  `json:"input_shape"`
	OutputShape  []int                  `json:"output_shape"`
	Extensions   map[string]interface{} `json:"extensions"`
}

ModelInfo contains model metadata.

type ModelProvider added in v0.2.1

type ModelProvider[T tensor.Numeric] interface {
	// CreateModel creates a new model instance
	CreateModel(ctx context.Context, config ModelConfig) (*graph.Graph[T], error)

	// LoadModel loads a pre-trained model
	LoadModel(ctx context.Context, path string) (*graph.Graph[T], error)

	// SaveModel saves the current model state
	SaveModel(ctx context.Context, model *graph.Graph[T], path string) error

	// GetModelInfo returns model metadata
	GetModelInfo() ModelInfo
}

ModelProvider abstracts model creation and management. This allows different model architectures to be used with the same training workflow.

type ModelProviderFactory added in v0.2.1

type ModelProviderFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (ModelProvider[T], error)

ModelProviderFactory creates ModelProvider instances

type OneStepApproximationStrategy added in v0.2.1

type OneStepApproximationStrategy[T tensor.Numeric] struct{}

OneStepApproximationStrategy performs a one-step gradient approximation. It is designed for training recurrent models without full BPTT.

func NewOneStepApproximationStrategy added in v0.2.1

func NewOneStepApproximationStrategy[T tensor.Numeric]() *OneStepApproximationStrategy[T]

NewOneStepApproximationStrategy constructs a OneStepApproximationStrategy.

func (*OneStepApproximationStrategy[T]) ComputeGradients added in v0.2.1

func (s *OneStepApproximationStrategy[T]) ComputeGradients(
	ctx context.Context,
	g *graph.Graph[T],
	loss graph.Node[T],
	batch Batch[T],
) (T, error)

ComputeGradients performs a forward pass and a one-step backward pass.

type PluginRegistry added in v0.2.1

type PluginRegistry[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

PluginRegistry manages registered training components and provides factory functions.

func NewPluginRegistry added in v0.2.1

func NewPluginRegistry[T tensor.Numeric]() *PluginRegistry[T]

NewPluginRegistry creates a new plugin registry.

func (*PluginRegistry[T]) Clear added in v0.2.1

func (r *PluginRegistry[T]) Clear()

Clear removes all registrations.

func (*PluginRegistry[T]) GetCrossValidator added in v0.2.1

func (r *PluginRegistry[T]) GetCrossValidator(ctx context.Context, name string, config map[string]interface{}) (CrossValidator[T], error)

GetCrossValidator retrieves a registered cross validator factory and creates an instance.

func (*PluginRegistry[T]) GetDataProvider added in v0.2.1

func (r *PluginRegistry[T]) GetDataProvider(ctx context.Context, name string, config map[string]interface{}) (DataProvider[T], error)

GetDataProvider retrieves a registered data provider factory and creates an instance.

func (*PluginRegistry[T]) GetMetricComputer added in v0.2.1

func (r *PluginRegistry[T]) GetMetricComputer(ctx context.Context, name string, config map[string]interface{}) (MetricComputer[T], error)

GetMetricComputer retrieves a registered metric computer factory and creates an instance.

func (*PluginRegistry[T]) GetModelProvider added in v0.2.1

func (r *PluginRegistry[T]) GetModelProvider(ctx context.Context, name string, config map[string]interface{}) (ModelProvider[T], error)

GetModelProvider retrieves a registered model provider factory and creates an instance.

func (*PluginRegistry[T]) GetSequenceProvider added in v0.2.1

func (r *PluginRegistry[T]) GetSequenceProvider(ctx context.Context, name string, config map[string]interface{}) (SequenceProvider[T], error)

GetSequenceProvider retrieves a registered sequence provider factory and creates an instance.

func (*PluginRegistry[T]) GetWorkflow added in v0.2.1

func (r *PluginRegistry[T]) GetWorkflow(ctx context.Context, name string, config map[string]interface{}) (TrainingWorkflow[T], error)

GetWorkflow retrieves a registered workflow factory and creates an instance.

func (*PluginRegistry[T]) ListCrossValidators added in v0.2.1

func (r *PluginRegistry[T]) ListCrossValidators() []string

ListCrossValidators returns all registered cross validator names.

func (*PluginRegistry[T]) ListDataProviders added in v0.2.1

func (r *PluginRegistry[T]) ListDataProviders() []string

ListDataProviders returns all registered data provider names.

func (*PluginRegistry[T]) ListMetricComputers added in v0.2.1

func (r *PluginRegistry[T]) ListMetricComputers() []string

ListMetricComputers returns all registered metric computer names.

func (*PluginRegistry[T]) ListModelProviders added in v0.2.1

func (r *PluginRegistry[T]) ListModelProviders() []string

ListModelProviders returns all registered model provider names.

func (*PluginRegistry[T]) ListSequenceProviders added in v0.2.1

func (r *PluginRegistry[T]) ListSequenceProviders() []string

ListSequenceProviders returns all registered sequence provider names.

func (*PluginRegistry[T]) ListWorkflows added in v0.2.1

func (r *PluginRegistry[T]) ListWorkflows() []string

ListWorkflows returns all registered workflow names.

func (*PluginRegistry[T]) RegisterCrossValidator added in v0.2.1

func (r *PluginRegistry[T]) RegisterCrossValidator(name string, factory CrossValidatorFactory[T]) error

RegisterCrossValidator registers a cross validator factory.

func (*PluginRegistry[T]) RegisterDataProvider added in v0.2.1

func (r *PluginRegistry[T]) RegisterDataProvider(name string, factory DataProviderFactory[T]) error

RegisterDataProvider registers a data provider factory.

func (*PluginRegistry[T]) RegisterMetricComputer added in v0.2.1

func (r *PluginRegistry[T]) RegisterMetricComputer(name string, factory MetricComputerFactory[T]) error

RegisterMetricComputer registers a metric computer factory.

func (*PluginRegistry[T]) RegisterModelProvider added in v0.2.1

func (r *PluginRegistry[T]) RegisterModelProvider(name string, factory ModelProviderFactory[T]) error

RegisterModelProvider registers a model provider factory.

func (*PluginRegistry[T]) RegisterSequenceProvider added in v0.2.1

func (r *PluginRegistry[T]) RegisterSequenceProvider(name string, factory SequenceProviderFactory[T]) error

RegisterSequenceProvider registers a sequence provider factory.

func (*PluginRegistry[T]) RegisterWorkflow added in v0.2.1

func (r *PluginRegistry[T]) RegisterWorkflow(name string, factory WorkflowFactory[T]) error

RegisterWorkflow registers a training workflow factory.

func (*PluginRegistry[T]) Summary added in v0.2.1

func (r *PluginRegistry[T]) Summary() map[string]int

Summary returns the number of registered components in each category.

func (*PluginRegistry[T]) UnregisterCrossValidator added in v0.2.1

func (r *PluginRegistry[T]) UnregisterCrossValidator(name string)

UnregisterCrossValidator removes a cross validator registration.

func (*PluginRegistry[T]) UnregisterDataProvider added in v0.2.1

func (r *PluginRegistry[T]) UnregisterDataProvider(name string)

UnregisterDataProvider removes a data provider registration.

func (*PluginRegistry[T]) UnregisterMetricComputer added in v0.2.1

func (r *PluginRegistry[T]) UnregisterMetricComputer(name string)

UnregisterMetricComputer removes a metric computer registration.

func (*PluginRegistry[T]) UnregisterModelProvider added in v0.2.1

func (r *PluginRegistry[T]) UnregisterModelProvider(name string)

UnregisterModelProvider removes a model provider registration.

func (*PluginRegistry[T]) UnregisterSequenceProvider added in v0.2.1

func (r *PluginRegistry[T]) UnregisterSequenceProvider(name string)

UnregisterSequenceProvider removes a sequence provider registration.

func (*PluginRegistry[T]) UnregisterWorkflow added in v0.2.1

func (r *PluginRegistry[T]) UnregisterWorkflow(name string)

UnregisterWorkflow removes a workflow registration.

type Predictor added in v1.9.0

type Predictor interface {
	Predict(modelPath string, features [][]float64) ([]float64, error)
}

Predictor is implemented by models that predict from flat feature vectors.

type SequenceConfig added in v0.2.1

type SequenceConfig struct {
	MaxSeqLen    int                    `json:"max_seq_len"`
	NumSequences int                    `json:"num_sequences"`
	Strategy     string                 `json:"strategy"` // "consecutive", "random", "curriculum"
	Extensions   map[string]interface{} `json:"extensions"`
}

SequenceConfig configures sequence generation.

type SequenceProvider added in v0.2.1

type SequenceProvider[T tensor.Numeric] interface {
	// GenerateSequences creates training sequences from the dataset
	GenerateSequences(ctx context.Context, dataset DataProvider[T], config SequenceConfig) ([]DataProvider[T], error)

	// GenerateTrainValidationSplit creates train/validation splits
	GenerateTrainValidationSplit(ctx context.Context, dataset DataProvider[T], config SplitConfig) (DataProvider[T], DataProvider[T], error)

	// SetRandomSeed sets the random seed for reproducible sequence generation
	SetRandomSeed(seed uint64)
}

SequenceProvider abstracts sequence generation for curriculum learning. This replaces the domain-specific EraSequencer with a generic interface.

type SequenceProviderFactory added in v0.2.1

type SequenceProviderFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (SequenceProvider[T], error)

SequenceProviderFactory creates SequenceProvider instances.

type SimpleModelProvider added in v0.2.1

type SimpleModelProvider[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

SimpleModelProvider provides a basic model provider implementation.

func NewSimpleModelProvider added in v0.2.1

func NewSimpleModelProvider[T tensor.Numeric](
	factory func(ctx context.Context, config ModelConfig) (*graph.Graph[T], error),
	info ModelInfo,
) *SimpleModelProvider[T]

NewSimpleModelProvider creates a new simple model provider.

func (*SimpleModelProvider[T]) CreateModel added in v0.2.1

func (s *SimpleModelProvider[T]) CreateModel(ctx context.Context, config ModelConfig) (*graph.Graph[T], error)

CreateModel implements ModelProvider.CreateModel.

func (*SimpleModelProvider[T]) GetModelInfo added in v0.2.1

func (s *SimpleModelProvider[T]) GetModelInfo() ModelInfo

GetModelInfo implements ModelProvider.GetModelInfo.

func (*SimpleModelProvider[T]) LoadModel added in v0.2.1

func (s *SimpleModelProvider[T]) LoadModel(ctx context.Context, path string) (*graph.Graph[T], error)

LoadModel implements ModelProvider.LoadModel.

func (*SimpleModelProvider[T]) SaveModel added in v0.2.1

func (s *SimpleModelProvider[T]) SaveModel(ctx context.Context, model *graph.Graph[T], path string) error

SaveModel implements ModelProvider.SaveModel by writing model parameters to a GGUF file using the shared ztensor/gguf writer. Each graph parameter is stored as a float32 tensor. The model info architecture name is written as the general.architecture metadata key.

type SplitConfig added in v0.2.1

type SplitConfig struct {
	ValidationRatio float64                `json:"validation_ratio"`
	Strategy        string                 `json:"strategy"` // "random", "chronological", "stratified"
	RandomSeed      uint64                 `json:"random_seed"`
	Extensions      map[string]interface{} `json:"extensions"`
}

SplitConfig configures train/validation splitting.
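Given the JSON tags above, a chronological 80/20 split could be expressed as follows (the values are illustrative, not defaults; the supported Strategy values are those listed in the field comment):

```json
{
  "validation_ratio": 0.2,
  "strategy": "chronological",
  "random_seed": 42,
  "extensions": {}
}
```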

type TrainConfig added in v1.9.0

type TrainConfig struct {
	Epochs       int
	BatchSize    int
	LearningRate float64
	WeightDecay  float64
}

TrainConfig holds hyperparameters shared across flat and windowed backends.

type TrainResult added in v1.9.0

type TrainResult struct {
	FinalLoss   float64
	BestLoss    float64
	BestEpoch   int
	TotalEpochs int
}

TrainResult holds the outcome of a training run.

func DispatchTrain added in v1.9.0

func DispatchTrain(backend Backend, features [][]float64, labels []float64, config TrainConfig) (*TrainResult, error)

DispatchTrain checks whether backend implements WindowedBackend and calls TrainWindowed if so; otherwise it falls back to Backend.Train with flat features. This is the dispatch logic used by walk-forward validators.

func DispatchTrainWindowed added in v1.9.0

func DispatchTrainWindowed(backend Backend, windows [][][]float64, labels []float64, config TrainConfig) (*TrainResult, error)

DispatchTrainWindowed dispatches a windowed training call. If the backend implements WindowedBackend it calls TrainWindowed directly; otherwise it flattens the windows and falls back to Backend.Train.

type Trainer

type Trainer[T tensor.Numeric] interface {
	// TrainStep performs a single training step for a model.
	// It takes the model's parameters, the optimizer, and the input/target data.
	// It is responsible for computing the loss, gradients, and updating the parameters.
	TrainStep(
		ctx context.Context,
		modelGraph *graph.Graph[T],
		optimizer optimizer.Optimizer[T],
		inputs map[graph.Node[T]]*tensor.TensorNumeric[T],
		targets *tensor.TensorNumeric[T],
	) (loss T, err error)
}

Trainer is an interface for model-specific training orchestration.

type TrainerWorkflowAdapter added in v0.2.1

type TrainerWorkflowAdapter[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

TrainerWorkflowAdapter adapts the existing Trainer interface to the new TrainingWorkflow interface. This allows legacy trainer implementations to work with the new generic workflow system.

func NewTrainerWorkflowAdapter added in v0.2.1

func NewTrainerWorkflowAdapter[T tensor.Numeric](trainer Trainer[T], opt optimizer.Optimizer[T]) *TrainerWorkflowAdapter[T]

NewTrainerWorkflowAdapter creates a new adapter for legacy trainers.

func (*TrainerWorkflowAdapter[T]) GetMetrics added in v0.2.1

func (a *TrainerWorkflowAdapter[T]) GetMetrics() map[string]interface{}

GetMetrics implements TrainingWorkflow.GetMetrics.

func (*TrainerWorkflowAdapter[T]) Initialize added in v0.2.1

func (a *TrainerWorkflowAdapter[T]) Initialize(ctx context.Context, config WorkflowConfig) error

Initialize implements TrainingWorkflow.Initialize.

func (*TrainerWorkflowAdapter[T]) Shutdown added in v0.2.1

func (a *TrainerWorkflowAdapter[T]) Shutdown(ctx context.Context) error

Shutdown implements TrainingWorkflow.Shutdown.

func (*TrainerWorkflowAdapter[T]) Train added in v0.2.1

func (a *TrainerWorkflowAdapter[T]) Train(ctx context.Context, dataset DataProvider[T], modelProvider ModelProvider[T]) (*TrainingResult[T], error)

Train implements TrainingWorkflow.Train by adapting to the legacy Trainer interface.

func (*TrainerWorkflowAdapter[T]) Validate added in v0.2.1

func (a *TrainerWorkflowAdapter[T]) Validate(ctx context.Context, dataset DataProvider[T], modelProvider ModelProvider[T]) (*ValidationResult[T], error)

Validate implements TrainingWorkflow.Validate.

type TrainingResult added in v0.2.1

type TrainingResult[T tensor.Numeric] struct {
	FinalLoss    T                      `json:"final_loss"`
	BestLoss     T                      `json:"best_loss"`
	BestEpoch    int                    `json:"best_epoch"`
	TotalEpochs  int                    `json:"total_epochs"`
	TrainingTime float64                `json:"training_time_seconds"`
	Metrics      map[string]float64     `json:"metrics"`
	ModelPath    string                 `json:"model_path,omitempty"`
	Extensions   map[string]interface{} `json:"extensions"`
}

TrainingResult contains training outcome information.

type TrainingWorkflow added in v0.2.1

type TrainingWorkflow[T tensor.Numeric] interface {
	// Initialize prepares the workflow with configuration and dependencies
	Initialize(ctx context.Context, config WorkflowConfig) error

	// Train executes the complete training workflow
	Train(ctx context.Context, dataset DataProvider[T], model ModelProvider[T]) (*TrainingResult[T], error)

	// Validate performs validation on the trained model
	Validate(ctx context.Context, dataset DataProvider[T], model ModelProvider[T]) (*ValidationResult[T], error)

	// GetMetrics returns current training metrics
	GetMetrics() map[string]interface{}

	// Shutdown cleans up resources
	Shutdown(ctx context.Context) error
}

TrainingWorkflow orchestrates the complete training process with pluggable components. This interface allows domain-specific applications to customize training behavior while maintaining framework-agnostic core logic.

type ValidationResult added in v0.2.1

type ValidationResult[T tensor.Numeric] struct {
	Loss           T                      `json:"loss"`
	Metrics        map[string]float64     `json:"metrics"`
	SampleCount    int                    `json:"sample_count"`
	ValidationTime float64                `json:"validation_time_seconds"`
	Extensions     map[string]interface{} `json:"extensions"`
}

ValidationResult contains validation outcome information.

type WindowedBackend added in v1.9.0

type WindowedBackend interface {
	TrainWindowed(windows [][][]float64, labels []float64, config TrainConfig) (*TrainResult, error)
}

WindowedBackend is implemented by time-series models that consume temporal windows rather than flat feature vectors. Walk-forward validators check for this interface via type assertion before falling back to standard training.

type WindowedPredictor added in v1.9.0

type WindowedPredictor interface {
	PredictWindowed(modelPath string, windows [][][]float64) ([]float64, error)
}

WindowedPredictor is implemented by models that predict from temporal windows instead of flat feature vectors.

type WorkflowConfig added in v0.2.1

type WorkflowConfig struct {
	// Training configuration
	NumEpochs    int     `json:"num_epochs"`
	LearningRate float64 `json:"learning_rate"`
	EarlyStopTol float64 `json:"early_stop_tolerance"`
	MaxNoImprove int     `json:"max_no_improve"`
	RandomSeed   uint64  `json:"random_seed"`

	// Component configurations
	BatchConfig   BatchConfig            `json:"batch_config"`
	ModelConfig   ModelConfig            `json:"model_config"`
	MetricConfigs map[string]interface{} `json:"metric_configs"`

	// Extension point for domain-specific configuration
	Extensions map[string]interface{} `json:"extensions"`
}

WorkflowConfig configures the training workflow.
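Given the field tags above, a WorkflowConfig serialized to JSON takes the following shape (all values illustrative; the contents of BatchConfig and ModelConfig depend on the components in use, so they are left empty here):

```json
{
  "num_epochs": 100,
  "learning_rate": 0.001,
  "early_stop_tolerance": 0.0001,
  "max_no_improve": 10,
  "random_seed": 42,
  "batch_config": {},
  "model_config": {},
  "metric_configs": {},
  "extensions": {}
}
```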

type WorkflowFactory added in v0.2.1

type WorkflowFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (TrainingWorkflow[T], error)

WorkflowFactory creates TrainingWorkflow instances.

Directories

Path	Synopsis
automl	Package automl provides automated machine learning utilities including Bayesian hyperparameter optimization.
fp8	Package fp8 implements FP8 mixed-precision training support.
lora	Package lora implements LoRA and QLoRA fine-tuning adapters.
loss	Package loss provides various loss functions for neural networks.
nas	Package nas implements neural architecture search using DARTS.
online	Package online implements online learning with drift detection and model rollback.
optimizer	Package optimizer provides various optimization algorithms for neural networks.
scheduler	Package scheduler provides learning rate scheduling strategies for optimizers.
