Documentation
Overview ¶
Package training provides neural network training orchestration for the Zerfoo ML framework. (Stability: beta)
The package implements a layered design: a core Trainer interface for single-step parameter updates, a DefaultTrainer that wires together a computation graph, a loss node, an optimizer, and a pluggable gradient strategy, and a higher-level TrainingWorkflow interface for full training-loop orchestration with data providers and model providers.
Trainer and DefaultTrainer ¶
Trainer is the fundamental training interface. It performs one training step: forward pass, loss computation, backward pass, and parameter update.
trainer := training.NewDefaultTrainer[float32](g, lossNode, opt, nil)
loss, err := trainer.TrainStep(ctx, g, opt, inputs, targets)
DefaultTrainer is the standard implementation. It delegates gradient computation to a GradientStrategy and parameter updates to an optimizer.Optimizer. When no strategy is provided, it defaults to DefaultBackpropStrategy.
Gradient Strategies ¶
GradientStrategy controls how gradients are computed for each training step. Two strategies are provided:
- DefaultBackpropStrategy performs standard backpropagation through the full computation graph.
- OneStepApproximationStrategy performs a single-step gradient approximation, useful for recurrent models where full BPTT is too expensive.
Custom strategies can implement the GradientStrategy interface to add auxiliary losses, gradient clipping, deep supervision, or other specialized gradient computation techniques.
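As an illustration of one thing a custom strategy might do after computing raw gradients, here is a self-contained sketch of global-norm gradient clipping over plain float64 slices. The function name and types are illustrative assumptions for this sketch; the real GradientStrategy operates on graph parameters, not slices.

```go
package main

import (
	"fmt"
	"math"
)

// clipByGlobalNorm rescales all gradients in place so their combined
// L2 norm does not exceed maxNorm, returning the pre-clip norm. A
// custom GradientStrategy could apply the same idea to parameter
// gradients after backpropagation.
func clipByGlobalNorm(grads [][]float64, maxNorm float64) float64 {
	var sumSq float64
	for _, g := range grads {
		for _, v := range g {
			sumSq += v * v
		}
	}
	norm := math.Sqrt(sumSq)
	if norm <= maxNorm || norm == 0 {
		return norm
	}
	scale := maxNorm / norm
	for _, g := range grads {
		for i := range g {
			g[i] *= scale
		}
	}
	return norm
}

func main() {
	grads := [][]float64{{3, 4}} // global norm 5
	clipByGlobalNorm(grads, 1.0) // rescaled to norm 1
	fmt.Println(grads[0])
}
```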
Optimizers ¶
The optimizer sub-package defines the optimizer.Optimizer interface and provides several implementations:
- AdamW — Adam with decoupled weight decay.
- SGD — Stochastic gradient descent with optional momentum.
- EMA — Exponential moving average of model parameters.
- SWA — Stochastic weight averaging.
Batch and Data Iteration ¶
Batch groups inputs and targets for a single training step. Inputs are provided as a map from graph input nodes to tensors; targets are a single tensor.
DataIterator provides sequential access to batches. ChunkedDataIterator loads batches in chunks via a callback, keeping only one chunk in memory at a time for large datasets. DataIteratorAdapter wraps a static slice of batches as a DataIterator.
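The chunk-loading contract can be sketched against plain int slices (hypothetical stand-ins for *Batch[T] values): the loader is called with increasing indices, and returning a nil or empty chunk with no error ends iteration.

```go
package main

import "fmt"

// loader mimics the ChunkLoader contract on plain ints: called with
// increasing chunk indices, it returns nil once no chunks remain.
func loader(chunk int) ([]int, error) {
	chunks := [][]int{{1, 2}, {3}}
	if chunk >= len(chunks) {
		return nil, nil // nil chunk signals end of iteration
	}
	return chunks[chunk], nil
}

// drain consumes every element the loader yields, chunk by chunk,
// holding only one chunk in memory at a time.
func drain(load func(int) ([]int, error)) ([]int, error) {
	var out []int
	for i := 0; ; i++ {
		c, err := load(i)
		if err != nil {
			return out, err
		}
		if len(c) == 0 {
			return out, nil
		}
		out = append(out, c...)
	}
}

func main() {
	all, _ := drain(loader)
	fmt.Println(all)
}
```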
Model Interface ¶
Model defines a trainable model with Forward, Backward, and Parameters methods. This is the low-level model interface used by training components that need direct forward/backward control.
Training Workflow and Plugin Registry ¶
For full training-loop orchestration, TrainingWorkflow combines data providers, model providers, metrics, and the training loop into a single interface. TrainerWorkflowAdapter bridges the core Trainer interface to TrainingWorkflow for use with the plugin system.
PluginRegistry enables runtime registration and lookup of workflows, data providers, model providers, sequence providers, metric computers, and cross validators. Global registries Float32Registry and Float64Registry are provided for common numeric types.
See interfaces_doc.go for detailed documentation of the plugin architecture and workflow interfaces.
Package training provides comprehensive documentation for the generic training interfaces.
Overview ¶
The training package implements a plugin-based architecture that allows domain-specific applications to customize training behavior while maintaining a generic, reusable core. This design follows the hexagonal architecture pattern, where the core framework defines ports (interfaces) and domain-specific applications provide adapters (implementations).
Core Design Principles ¶
1. **Separation of Concerns**: Generic ML training logic is separated from domain-specific business logic (e.g., custom scoring or evaluation requirements).
2. **Dependency Inversion**: High-level training workflows depend on abstractions, not concrete implementations.
3. **Plugin Architecture**: Components can be registered and swapped at runtime using the plugin registry system.
4. **Backward Compatibility**: Adapter patterns allow legacy code to work with new interfaces.
5. **Extensibility**: Configuration structures include extension points for domain-specific customization.
Interface Hierarchy ¶
## Primary Interfaces
### TrainingWorkflow[T]
The main orchestrator for training processes. Implementations define the complete training pipeline including initialization, training loops, validation, and cleanup.
Usage Pattern:
workflow, _ := registry.GetWorkflow(ctx, "standard", config)
workflow.Initialize(ctx, workflowConfig)
result, _ := workflow.Train(ctx, dataProvider, modelProvider)
### DataProvider[T]
Abstracts data access patterns for training and validation. This replaces domain-specific data loading with a generic interface that can handle any data source.
Implementations should provide:
- Efficient batch iteration
- Train/validation data splitting
- Metadata for training customization
- Resource management (cleanup)
Usage Pattern:
dataProvider, _ := registry.GetDataProvider(ctx, "csv", config)
trainingData, _ := dataProvider.GetTrainingData(ctx, batchConfig)
for trainingData.Next(ctx) {
batch := trainingData.Batch()
// Process batch
}
### ModelProvider[T]
Abstracts model creation and management. This allows different model architectures to be used with the same training workflow.
Implementations should handle:
- Model instantiation from configuration
- Model serialization/deserialization
- Model metadata and introspection
Usage Pattern:
modelProvider, _ := registry.GetModelProvider(ctx, "mlp", config)
model, _ := modelProvider.CreateModel(ctx, modelConfig)
modelProvider.SaveModel(ctx, model, "/path/to/model")
## Supporting Interfaces
### SequenceProvider[T]
Generalizes sequence generation for curriculum learning. This replaces the domain-specific EraSequencer with a generic interface that can handle any sequencing strategy.
Common implementations:
- Consecutive sequence generation
- Random sequence sampling
- Curriculum learning strategies
- Time-series aware sequencing
### MetricComputer[T]
Provides extensible metric computation. Applications can register custom metrics while using standard training workflows.
Usage Pattern:
computer := NewMetricComputer()
computer.RegisterMetric("mse", mseFunction)
computer.RegisterMetric("custom_score", customFunction)
metrics, _ := computer.ComputeMetrics(ctx, predictions, targets, metadata)
### CrossValidator[T]
Implements various cross-validation strategies in a generic way. This allows time-series CV, k-fold CV, group-based CV, etc. to be used interchangeably.
Configuration System ¶
All interfaces use structured configuration with extension points:
## Extension Pattern
All configuration structures include an `Extensions` field that allows domain-specific applications to pass additional configuration without modifying the core framework:
config := WorkflowConfig{
NumEpochs: 100,
Extensions: map[string]interface{}{
"domain": map[string]interface{}{
"task_id": "main",
"batch_limit": 120,
},
},
}
## Configuration Validation
Implementations should validate their configuration and return descriptive errors for invalid or missing parameters.
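A validation helper might look like the following sketch. The field names mirror WorkflowConfig from this package, but the cut-down struct and the helper itself are hypothetical:

```go
package main

import (
	"errors"
	"fmt"
)

// workflowConfig is a cut-down stand-in for WorkflowConfig, used
// here only to illustrate early configuration validation.
type workflowConfig struct {
	NumEpochs  int
	Extensions map[string]interface{}
}

// validate returns a descriptive error for each invalid or missing
// parameter instead of failing later, mid-training.
func validate(c workflowConfig) error {
	if c.NumEpochs <= 0 {
		return errors.New("workflow config: NumEpochs must be positive")
	}
	if ext, ok := c.Extensions["domain"]; ok {
		if _, ok := ext.(map[string]interface{}); !ok {
			return errors.New(`workflow config: extension "domain" must be a map`)
		}
	}
	return nil
}

func main() {
	fmt.Println(validate(workflowConfig{NumEpochs: 0}))
}
```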
Plugin Registry System ¶
The plugin registry enables runtime component selection and supports multiple implementations of each interface.
## Registration Pattern
// Register a component
err := Float32Registry.RegisterWorkflow("standard", func(ctx context.Context, config map[string]interface{}) (TrainingWorkflow[float32], error) {
return NewStandardWorkflow(config), nil
})
// Use registered component
workflow, err := Float32Registry.GetWorkflow(ctx, "standard", config)
## Factory Functions
All registry factories receive context and configuration, enabling:
- Initialization validation
- Resource allocation
- Graceful error handling
- Context-aware cleanup
Adapter Pattern Implementation ¶
The adapter pattern allows smooth migration from legacy interfaces to new generic interfaces:
## TrainerWorkflowAdapter
Adapts existing Trainer implementations to work with the TrainingWorkflow interface:
legacy := NewDefaultTrainer(graph, loss, optimizer, strategy)
adapter := NewTrainerWorkflowAdapter(legacy, optimizer)
// Now 'adapter' can be used with the new workflow system
## Migration Strategy
1. Implement adapters for existing components
2. Register adapted components in the plugin registry
3. Gradually replace legacy direct usage with registry-based usage
4. Implement new components directly against the generic interfaces
Error Handling Patterns ¶
## Context Cancellation
All interface methods accept context.Context and should respect cancellation:
func (w *MyWorkflow) Train(ctx context.Context, ...) (*TrainingResult[T], error) {
for epoch := 0; epoch < maxEpochs; epoch++ {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
// Continue training
}
}
}
## Resource Cleanup
Implementations should provide proper resource cleanup:
func (p *MyDataProvider) Close() error {
// Close files, database connections, etc.
return nil
}
Use defer for automatic cleanup:
dataProvider, _ := registry.GetDataProvider(ctx, "csv", config)
defer dataProvider.Close()
Performance Considerations ¶
## Memory Management
- Implement proper resource cleanup in Close() methods
- Use streaming where possible for large datasets
- Consider memory-mapped files for large data
## Concurrency
- DataIterator implementations should be thread-safe
- Plugin registry uses read-write mutexes for thread safety
- Consider goroutine pools for parallel processing
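The read-write-mutex pattern the registry uses can be sketched with a tiny generic name-to-factory map. This is a simplified stand-in for PluginRegistry, not its actual implementation:

```go
package main

import (
	"fmt"
	"sync"
)

// registry is a minimal thread-safe name -> factory map: writes take
// the exclusive lock, lookups take the shared read lock, so many
// readers can resolve factories concurrently.
type registry[T any] struct {
	mu        sync.RWMutex
	factories map[string]func() T
}

func newRegistry[T any]() *registry[T] {
	return &registry[T]{factories: map[string]func() T{}}
}

func (r *registry[T]) register(name string, f func() T) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.factories[name] = f
}

func (r *registry[T]) get(name string) (T, bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	f, ok := r.factories[name]
	if !ok {
		var zero T
		return zero, false
	}
	return f(), true
}

func main() {
	r := newRegistry[int]()
	r.register("answer", func() int { return 42 })
	v, ok := r.get("answer")
	fmt.Println(v, ok)
}
```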
Testing Patterns ¶
## Mock Implementations
Create mock implementations for testing:
type MockDataProvider[T tensor.Numeric] struct {
	batches []*Batch[T]
}

func (m *MockDataProvider[T]) GetTrainingData(ctx context.Context, config BatchConfig) (DataIterator[T], error) {
	return NewDataIteratorAdapter(m.batches), nil
}
## Integration Testing
Test complete workflows with real implementations:
func TestWorkflowIntegration(t *testing.T) {
registry := NewPluginRegistry[float32]()
registry.RegisterWorkflow("test", testWorkflowFactory)
registry.RegisterDataProvider("mock", mockDataProviderFactory)
workflow, _ := registry.GetWorkflow(ctx, "test", config)
dataProvider, _ := registry.GetDataProvider(ctx, "mock", dataConfig)
result, err := workflow.Train(ctx, dataProvider, modelProvider)
assert.NoError(t, err)
assert.NotNil(t, result)
}
Migration from Domain-Specific Code ¶
## Configuration Migration
Domain-specific configuration moves to extensions:
Before:
config := LegacyTrainingConfig{
Epochs: 100,
TaskID: "main",
BatchLimit: 120,
}
After:
config := WorkflowConfig{
NumEpochs: 100,
Extensions: map[string]interface{}{
"domain": map[string]interface{}{
"task_id": "main",
"batch_limit": 120,
},
},
}
Best Practices ¶
## Implementation Guidelines
1. **Validate Early**: Check configuration in Initialize() methods
2. **Fail Fast**: Return errors immediately for invalid states
3. **Resource Cleanup**: Always implement proper cleanup in Close() methods
4. **Context Awareness**: Respect context cancellation in long-running operations
5. **Thread Safety**: Ensure implementations are thread-safe where documented
## Extension Guidelines
1. **Namespace Extensions**: Use descriptive keys in Extensions maps
2. **Validate Extensions**: Check extension configuration early
3. **Provide Defaults**: Handle missing extension configuration gracefully
4. **Document Extensions**: Clearly document expected extension parameters
## Testing Guidelines
1. **Mock Dependencies**: Use mock implementations for unit testing
2. **Test Error Cases**: Ensure proper error handling and cleanup
3. **Integration Tests**: Test complete workflows with real data
4. **Performance Tests**: Benchmark critical paths with realistic data
Examples ¶
See the following files for complete implementation examples:
- adapter.go: Adapter pattern implementations
- registry.go: Plugin registry system
- interfaces.go: Core interface definitions
Index ¶
- Variables
- func CreateWindows(data [][]float64, windowLen int) (windows [][][]float64, labels []float64)
- func ParseWindowSizes(s string) []int
- type Backend
- type Batch
- type BatchConfig
- type CVConfig
- type CVResult
- type ChunkLoader
- type ChunkedDataIterator
- type CrossValidator
- type CrossValidatorFactory
- type DataIterator
- type DataIteratorAdapter
- type DataProvider
- type DataProviderFactory
- type DefaultBackpropStrategy
- type DefaultTrainer
- type EarlyStopConfig
- type EarlyStopping
- type Fold
- type GradientStrategy
- type GradientStrategyAdapter
- type MetricComputer
- type MetricComputerFactory
- type MetricFunction
- type Model
- type ModelConfig
- type ModelInfo
- type ModelProvider
- type ModelProviderFactory
- type OneStepApproximationStrategy
- type PluginRegistry
- func (r *PluginRegistry[T]) Clear()
- func (r *PluginRegistry[T]) GetCrossValidator(ctx context.Context, name string, config map[string]interface{}) (CrossValidator[T], error)
- func (r *PluginRegistry[T]) GetDataProvider(ctx context.Context, name string, config map[string]interface{}) (DataProvider[T], error)
- func (r *PluginRegistry[T]) GetMetricComputer(ctx context.Context, name string, config map[string]interface{}) (MetricComputer[T], error)
- func (r *PluginRegistry[T]) GetModelProvider(ctx context.Context, name string, config map[string]interface{}) (ModelProvider[T], error)
- func (r *PluginRegistry[T]) GetSequenceProvider(ctx context.Context, name string, config map[string]interface{}) (SequenceProvider[T], error)
- func (r *PluginRegistry[T]) GetWorkflow(ctx context.Context, name string, config map[string]interface{}) (TrainingWorkflow[T], error)
- func (r *PluginRegistry[T]) ListCrossValidators() []string
- func (r *PluginRegistry[T]) ListDataProviders() []string
- func (r *PluginRegistry[T]) ListMetricComputers() []string
- func (r *PluginRegistry[T]) ListModelProviders() []string
- func (r *PluginRegistry[T]) ListSequenceProviders() []string
- func (r *PluginRegistry[T]) ListWorkflows() []string
- func (r *PluginRegistry[T]) RegisterCrossValidator(name string, factory CrossValidatorFactory[T]) error
- func (r *PluginRegistry[T]) RegisterDataProvider(name string, factory DataProviderFactory[T]) error
- func (r *PluginRegistry[T]) RegisterMetricComputer(name string, factory MetricComputerFactory[T]) error
- func (r *PluginRegistry[T]) RegisterModelProvider(name string, factory ModelProviderFactory[T]) error
- func (r *PluginRegistry[T]) RegisterSequenceProvider(name string, factory SequenceProviderFactory[T]) error
- func (r *PluginRegistry[T]) RegisterWorkflow(name string, factory WorkflowFactory[T]) error
- func (r *PluginRegistry[T]) Summary() map[string]int
- func (r *PluginRegistry[T]) UnregisterCrossValidator(name string)
- func (r *PluginRegistry[T]) UnregisterDataProvider(name string)
- func (r *PluginRegistry[T]) UnregisterMetricComputer(name string)
- func (r *PluginRegistry[T]) UnregisterModelProvider(name string)
- func (r *PluginRegistry[T]) UnregisterSequenceProvider(name string)
- func (r *PluginRegistry[T]) UnregisterWorkflow(name string)
- type Predictor
- type SequenceConfig
- type SequenceProvider
- type SequenceProviderFactory
- type SimpleModelProvider
- func (s *SimpleModelProvider[T]) CreateModel(ctx context.Context, config ModelConfig) (*graph.Graph[T], error)
- func (s *SimpleModelProvider[T]) GetModelInfo() ModelInfo
- func (s *SimpleModelProvider[T]) LoadModel(ctx context.Context, path string) (*graph.Graph[T], error)
- func (s *SimpleModelProvider[T]) SaveModel(ctx context.Context, model *graph.Graph[T], path string) error
- type SplitConfig
- type TrainConfig
- type TrainResult
- type Trainer
- type TrainerWorkflowAdapter
- func (a *TrainerWorkflowAdapter[T]) GetMetrics() map[string]interface{}
- func (a *TrainerWorkflowAdapter[T]) Initialize(ctx context.Context, config WorkflowConfig) error
- func (a *TrainerWorkflowAdapter[T]) Shutdown(ctx context.Context) error
- func (a *TrainerWorkflowAdapter[T]) Train(ctx context.Context, dataset DataProvider[T], modelProvider ModelProvider[T]) (*TrainingResult[T], error)
- func (a *TrainerWorkflowAdapter[T]) Validate(ctx context.Context, dataset DataProvider[T], modelProvider ModelProvider[T]) (*ValidationResult[T], error)
- type TrainingResult
- type TrainingWorkflow
- type ValidationResult
- type WindowedBackend
- type WindowedPredictor
- type WorkflowConfig
- type WorkflowFactory
Constants ¶
This section is empty.
Variables ¶
var (
	Float32Registry = NewPluginRegistry[float32]()
	Float64Registry = NewPluginRegistry[float64]()
)
Global registry instances for common numeric types.
Functions ¶
func CreateWindows ¶ added in v1.10.0
CreateWindows converts flat rows into overlapping temporal windows. Each window contains windowLen consecutive rows. Labels come from the last row's last element.
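A self-contained sketch of that windowing scheme, written against plain slices; the package's CreateWindows has this shape, but its exact edge-case behavior may differ:

```go
package main

import "fmt"

// createWindows builds overlapping windows of windowLen consecutive
// rows; each window's label is the last element of its last row.
func createWindows(data [][]float64, windowLen int) (windows [][][]float64, labels []float64) {
	for i := 0; i+windowLen <= len(data); i++ {
		w := data[i : i+windowLen]
		windows = append(windows, w)
		last := w[windowLen-1]
		labels = append(labels, last[len(last)-1])
	}
	return windows, labels
}

func main() {
	data := [][]float64{{1, 10}, {2, 20}, {3, 30}}
	w, l := createWindows(data, 2) // two overlapping windows
	fmt.Println(len(w), l)
}
```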
func ParseWindowSizes ¶ added in v1.10.0
ParseWindowSizes parses a comma-separated string of window sizes. Example: "15,30,60,120" -> []int{15, 30, 60, 120}
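The parsing it describes can be sketched in a few lines. This is a hypothetical reimplementation for illustration; the package version may handle malformed input differently:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseWindowSizes splits a comma-separated list such as "15,30,60"
// into ints, skipping empty or non-numeric entries.
func parseWindowSizes(s string) []int {
	var sizes []int
	for _, part := range strings.Split(s, ",") {
		n, err := strconv.Atoi(strings.TrimSpace(part))
		if err != nil {
			continue
		}
		sizes = append(sizes, n)
	}
	return sizes
}

func main() {
	fmt.Println(parseWindowSizes("15,30,60,120"))
}
```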
Types ¶
type Backend ¶ added in v1.9.0
type Backend interface {
Train(features [][]float64, labels []float64, config TrainConfig) (*TrainResult, error)
}
Backend is implemented by models that train on flat tabular data. Walk-forward validators dispatch to this interface when the model does not implement WindowedBackend.
type Batch ¶ added in v0.2.1
type Batch[T tensor.Numeric] struct {
	Inputs  map[graph.Node[T]]*tensor.TensorNumeric[T]
	Targets *tensor.TensorNumeric[T]
}
Batch groups the stable inputs for a single training step.
Inputs are provided as a map keyed by the graph's input nodes. Targets are provided as a single tensor; strategies may interpret targets appropriately for the chosen loss.
type BatchConfig ¶ added in v0.2.1
type BatchConfig struct {
BatchSize int `json:"batch_size"`
Shuffle bool `json:"shuffle"`
DropLast bool `json:"drop_last"`
NumWorkers int `json:"num_workers"`
Extensions map[string]interface{} `json:"extensions"`
}
BatchConfig configures batch processing.
type CVConfig ¶ added in v0.2.1
type CVConfig struct {
Strategy string `json:"strategy"` // "k_fold", "time_series", "group"
NumFolds int `json:"num_folds"`
GroupBy string `json:"group_by"` // For group-based CV
PurgeGap int `json:"purge_gap"` // For time-series CV
TestSize float64 `json:"test_size"`
RandomSeed uint64 `json:"random_seed"`
Extensions map[string]interface{} `json:"extensions"`
}
CVConfig configures cross-validation.
type CVResult ¶ added in v0.2.1
type CVResult[T tensor.Numeric] struct {
	MeanLoss    T                      `json:"mean_loss"`
	StdLoss     T                      `json:"std_loss"`
	MeanMetrics map[string]float64     `json:"mean_metrics"`
	StdMetrics  map[string]float64     `json:"std_metrics"`
	FoldResults []ValidationResult[T]  `json:"fold_results"`
	TotalTime   float64                `json:"total_time_seconds"`
	Extensions  map[string]interface{} `json:"extensions"`
}
CVResult contains cross-validation results.
type ChunkLoader ¶ added in v0.2.1
ChunkLoader is a callback that returns the next chunk of batches. It receives a zero-based chunk index and returns the batches for that chunk. Return nil batches (or empty slice) with nil error to signal no more chunks.
type ChunkedDataIterator ¶ added in v0.2.1
ChunkedDataIterator loads batches in chunks via a callback function. Each chunk represents a logical unit of data (e.g., one era, one file shard). Only one chunk's batches are held in memory at a time; the previous chunk's data is released when the next chunk is loaded.
func NewChunkedDataIterator ¶ added in v0.2.1
func NewChunkedDataIterator[T tensor.Numeric](loader ChunkLoader[T]) *ChunkedDataIterator[T]
NewChunkedDataIterator creates a new iterator that loads batches in chunks. The loader function is called with increasing chunk indices starting from 0. When the loader returns nil or an empty slice with no error, iteration ends.
func (*ChunkedDataIterator[T]) Batch ¶ added in v0.2.1
func (c *ChunkedDataIterator[T]) Batch() *Batch[T]
Batch returns the current batch. Returns nil if called before Next or after iteration is exhausted.
func (*ChunkedDataIterator[T]) Close ¶ added in v0.2.1
func (c *ChunkedDataIterator[T]) Close() error
Close releases resources held by the iterator.
func (*ChunkedDataIterator[T]) Error ¶ added in v0.2.1
func (c *ChunkedDataIterator[T]) Error() error
Error returns any error that occurred during chunk loading.
func (*ChunkedDataIterator[T]) Next ¶ added in v0.2.1
func (c *ChunkedDataIterator[T]) Next(_ context.Context) bool
Next advances to the next batch. Returns false when all chunks are exhausted or an error occurs. Automatically loads the next chunk when the current one is fully consumed.
func (*ChunkedDataIterator[T]) Reset ¶ added in v0.2.1
func (c *ChunkedDataIterator[T]) Reset() error
Reset rewinds the iterator to the beginning, allowing re-iteration from chunk 0. The loader will be called again starting from index 0.
type CrossValidator ¶ added in v0.2.1
type CrossValidator[T tensor.Numeric] interface {
	// CreateFolds generates cross-validation folds from the dataset
	CreateFolds(ctx context.Context, dataset DataProvider[T], config CVConfig) ([]Fold[T], error)
	// ValidateModel performs cross-validation on a model
	ValidateModel(ctx context.Context, dataset DataProvider[T], modelProvider ModelProvider[T], config CVConfig) (*CVResult[T], error)
}
CrossValidator provides generic cross-validation strategies.
type CrossValidatorFactory ¶ added in v0.2.1
type CrossValidatorFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (CrossValidator[T], error)
CrossValidatorFactory creates CrossValidator instances.
type DataIterator ¶ added in v0.2.1
type DataIterator[T tensor.Numeric] interface {
	// Next advances to the next batch, returns false when exhausted
	Next(ctx context.Context) bool
	// Batch returns the current batch data
	Batch() *Batch[T]
	// Error returns any error that occurred during iteration
	Error() error
	// Close releases iterator resources
	Close() error
	// Reset rewinds the iterator to the beginning
	Reset() error
}
DataIterator provides sequential access to training/validation batches.
type DataIteratorAdapter ¶ added in v0.2.1
DataIteratorAdapter provides a simple iterator implementation over static data.
func NewDataIteratorAdapter ¶ added in v0.2.1
func NewDataIteratorAdapter[T tensor.Numeric](batches []*Batch[T]) *DataIteratorAdapter[T]
NewDataIteratorAdapter creates a new data iterator from a slice of batches.
func (*DataIteratorAdapter[T]) Batch ¶ added in v0.2.1
func (d *DataIteratorAdapter[T]) Batch() *Batch[T]
Batch implements DataIterator.Batch
func (*DataIteratorAdapter[T]) Close ¶ added in v0.2.1
func (d *DataIteratorAdapter[T]) Close() error
Close implements DataIterator.Close
func (*DataIteratorAdapter[T]) Error ¶ added in v0.2.1
func (d *DataIteratorAdapter[T]) Error() error
Error implements DataIterator.Error
func (*DataIteratorAdapter[T]) Next ¶ added in v0.2.1
func (d *DataIteratorAdapter[T]) Next(ctx context.Context) bool
Next implements DataIterator.Next
func (*DataIteratorAdapter[T]) Reset ¶ added in v0.2.1
func (d *DataIteratorAdapter[T]) Reset() error
Reset implements DataIterator.Reset
type DataProvider ¶ added in v0.2.1
type DataProvider[T tensor.Numeric] interface {
	// GetTrainingData returns training data in batches
	GetTrainingData(ctx context.Context, config BatchConfig) (DataIterator[T], error)
	// GetValidationData returns validation data in batches
	GetValidationData(ctx context.Context, config BatchConfig) (DataIterator[T], error)
	// GetMetadata returns dataset metadata for training customization
	GetMetadata() map[string]interface{}
	// Close releases any resources held by the provider
	Close() error
}
DataProvider abstracts data access patterns for training and validation. This replaces domain-specific data loading with a generic interface.
type DataProviderFactory ¶ added in v0.2.1
type DataProviderFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (DataProvider[T], error)
DataProviderFactory creates DataProvider instances.
type DefaultBackpropStrategy ¶ added in v0.2.1
DefaultBackpropStrategy performs standard backpropagation through the loss and model graph.
func NewDefaultBackpropStrategy ¶ added in v0.2.1
func NewDefaultBackpropStrategy[T tensor.Numeric]() *DefaultBackpropStrategy[T]
NewDefaultBackpropStrategy constructs a DefaultBackpropStrategy.
func (*DefaultBackpropStrategy[T]) ComputeGradients ¶ added in v0.2.1
func (s *DefaultBackpropStrategy[T]) ComputeGradients(
	ctx context.Context,
	g *graph.Graph[T],
	loss graph.Node[T],
	batch Batch[T],
) (T, error)
ComputeGradients runs forward pass, computes loss, runs backward passes, and leaves parameter gradients populated on the graph's parameters.
type DefaultTrainer ¶ added in v0.2.1
DefaultTrainer encapsulates stable training components and delegates gradient computation to a strategy.
func NewDefaultTrainer ¶ added in v0.2.1
func NewDefaultTrainer[T tensor.Numeric](
	g *graph.Graph[T],
	loss graph.Node[T],
	optimizer opt.Optimizer[T],
	strategy GradientStrategy[T],
) *DefaultTrainer[T]
NewDefaultTrainer constructs a new DefaultTrainer. If strategy is nil, DefaultBackpropStrategy is used.
func (*DefaultTrainer[T]) TrainStep ¶ added in v0.2.1
func (t *DefaultTrainer[T]) TrainStep(
	ctx context.Context,
	g *graph.Graph[T],
	optimizer opt.Optimizer[T],
	inputs map[graph.Node[T]]*tensor.TensorNumeric[T],
	targets *tensor.TensorNumeric[T],
) (T, error)
TrainStep performs a single training step using the configured strategy and optimizer.
type EarlyStopConfig ¶ added in v1.8.0
type EarlyStopConfig struct {
// Patience is the number of epochs without improvement before stopping.
Patience int
// Alpha is the EMA smoothing factor (0 < alpha < 1). Default: 0.1.
Alpha float64
// MinDelta is the minimum improvement to count as progress.
MinDelta float64
// Mode is "min" for loss (lower is better) or "max" for accuracy (higher is better).
Mode string
}
EarlyStopConfig configures smoothed early stopping behavior.
type EarlyStopping ¶ added in v1.8.0
type EarlyStopping struct {
// contains filtered or unexported fields
}
EarlyStopping tracks a smoothed metric via exponential moving average and signals when training should stop due to lack of improvement.
func NewEarlyStopping ¶ added in v1.8.0
func NewEarlyStopping(config EarlyStopConfig) *EarlyStopping
NewEarlyStopping creates a new EarlyStopping instance with the given config. If Alpha is zero, it defaults to 0.1. If Mode is empty, it defaults to "min".
func (*EarlyStopping) BestMetric ¶ added in v1.8.0
func (es *EarlyStopping) BestMetric() float64
BestMetric returns the best smoothed metric observed so far.
func (*EarlyStopping) Reset ¶ added in v1.8.0
func (es *EarlyStopping) Reset()
Reset clears all state so the instance can be reused.
func (*EarlyStopping) Step ¶ added in v1.8.0
func (es *EarlyStopping) Step(metric float64) bool
Step records a new metric value and returns true if training should stop. On the first call, it initializes the smoothed metric to the raw value. On subsequent calls, it applies EMA smoothing and checks for improvement.
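The EMA update behind Step can be sketched as follows for mode "min" (a simplified stand-in for EarlyStopping with the defaults described above, not the package's actual type):

```go
package main

import "fmt"

// emaStopper mirrors the documented "min" behavior: it smooths the
// metric with an exponential moving average and stops after
// `patience` steps without an improvement of at least minDelta.
type emaStopper struct {
	patience int
	alpha    float64
	minDelta float64
	smoothed float64
	best     float64
	bad      int
	started  bool
}

// step records a metric and returns true when training should stop.
func (e *emaStopper) step(metric float64) bool {
	if !e.started {
		// First call: initialize the smoothed metric to the raw value.
		e.smoothed, e.best, e.started = metric, metric, true
		return false
	}
	e.smoothed = e.alpha*metric + (1-e.alpha)*e.smoothed
	if e.best-e.smoothed > e.minDelta { // lower is better in "min" mode
		e.best = e.smoothed
		e.bad = 0
		return false
	}
	e.bad++
	return e.bad >= e.patience
}

func main() {
	es := &emaStopper{patience: 2, alpha: 0.1}
	for _, loss := range []float64{1.0, 1.0, 1.0} {
		fmt.Println(es.step(loss)) // stops on the third flat epoch
	}
}
```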
type Fold ¶ added in v0.2.1
type Fold[T tensor.Numeric] interface {
	// TrainData returns the training data for this fold
	TrainData() DataProvider[T]
	// ValidData returns the validation data for this fold
	ValidData() DataProvider[T]
	// FoldIndex returns the fold index
	FoldIndex() int
	// Metadata returns fold-specific metadata
	Metadata() map[string]interface{}
}
Fold represents a single cross-validation fold.
type GradientStrategy ¶ added in v0.2.1
type GradientStrategy[T tensor.Numeric] interface {
	ComputeGradients(
		ctx context.Context,
		g *graph.Graph[T],
		loss graph.Node[T],
		batch Batch[T],
	) (lossValue T, err error)
}
GradientStrategy encapsulates how to compute gradients for a training step.
Implementations may perform standard backprop through the loss, use approximations, or incorporate auxiliary losses (e.g., deep supervision). The strategy must leave parameter gradients populated on the graph's parameters so that the optimizer can apply updates afterwards.
type GradientStrategyAdapter ¶ added in v0.2.1
GradientStrategyAdapter adapts GradientStrategy to work with the new interface system.
func NewGradientStrategyAdapter ¶ added in v0.2.1
func NewGradientStrategyAdapter[T tensor.Numeric](strategy GradientStrategy[T], g *graph.Graph[T], lossNode graph.Node[T]) *GradientStrategyAdapter[T]
NewGradientStrategyAdapter creates a new gradient strategy adapter.
func (*GradientStrategyAdapter[T]) ComputeGradientsFromBatch ¶ added in v0.2.1
func (a *GradientStrategyAdapter[T]) ComputeGradientsFromBatch(ctx context.Context, batch *Batch[T]) (T, error)
ComputeGradientsFromBatch adapts batch processing to the legacy GradientStrategy interface.
type MetricComputer ¶ added in v0.2.1
type MetricComputer[T tensor.Numeric] interface {
	// ComputeMetrics calculates metrics from predictions and targets
	ComputeMetrics(ctx context.Context, predictions, targets *tensor.TensorNumeric[T], metadata map[string]interface{}) (map[string]float64, error)
	// RegisterMetric adds a new metric computation
	RegisterMetric(name string, metric MetricFunction[T])
	// UnregisterMetric removes a metric computation
	UnregisterMetric(name string)
	// AvailableMetrics returns all registered metric names
	AvailableMetrics() []string
}
MetricComputer provides extensible metric computation.
type MetricComputerFactory ¶ added in v0.2.1
type MetricComputerFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (MetricComputer[T], error)
MetricComputerFactory creates MetricComputer instances.
type MetricFunction ¶ added in v0.2.1
type MetricFunction[T tensor.Numeric] func(ctx context.Context, predictions, targets *tensor.TensorNumeric[T], metadata map[string]interface{}) (float64, error)
MetricFunction defines a single metric computation.
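A metric with the documented shape reduces predictions and targets to a single float64. Here is a self-contained mean-squared-error sketch over plain slices; the real MetricFunction takes tensors and a metadata map, so this signature is a simplification:

```go
package main

import (
	"errors"
	"fmt"
)

// mse returns the mean squared error between predictions and
// targets, the kind of scalar a registered MetricFunction computes.
func mse(predictions, targets []float64) (float64, error) {
	if len(predictions) != len(targets) || len(predictions) == 0 {
		return 0, errors.New("mse: mismatched or empty inputs")
	}
	var sum float64
	for i, p := range predictions {
		d := p - targets[i]
		sum += d * d
	}
	return sum / float64(len(predictions)), nil
}

func main() {
	v, _ := mse([]float64{1, 2, 3}, []float64{1, 2, 5})
	fmt.Println(v)
}
```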
type Model ¶
type Model[T tensor.Numeric] interface {
	// Forward performs the forward pass of the model.
	Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
	// Backward performs the backward pass of the model.
	Backward(ctx context.Context, grad *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
	// Parameters returns the parameters of the model.
	Parameters() []*graph.Parameter[T]
}
Model defines the interface for a trainable model.
type ModelConfig ¶ added in v0.2.1
type ModelConfig struct {
Type string `json:"type"`
Architecture map[string]interface{} `json:"architecture"`
Hyperparams map[string]interface{} `json:"hyperparams"`
Extensions map[string]interface{} `json:"extensions"`
}
ModelConfig configures model creation.
type ModelInfo ¶ added in v0.2.1
type ModelInfo struct {
Name string `json:"name"`
Version string `json:"version"`
Architecture string `json:"architecture"`
Parameters int64 `json:"parameter_count"`
InputShape []int `json:"input_shape"`
OutputShape []int `json:"output_shape"`
Extensions map[string]interface{} `json:"extensions"`
}
ModelInfo contains model metadata.
type ModelProvider ¶ added in v0.2.1
type ModelProvider[T tensor.Numeric] interface {
	// CreateModel creates a new model instance
	CreateModel(ctx context.Context, config ModelConfig) (*graph.Graph[T], error)
	// LoadModel loads a pre-trained model
	LoadModel(ctx context.Context, path string) (*graph.Graph[T], error)
	// SaveModel saves the current model state
	SaveModel(ctx context.Context, model *graph.Graph[T], path string) error
	// GetModelInfo returns model metadata
	GetModelInfo() ModelInfo
}
ModelProvider abstracts model creation and management. This allows different model architectures to be used with the same training workflow.
type ModelProviderFactory ¶ added in v0.2.1
type ModelProviderFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (ModelProvider[T], error)
ModelProviderFactory creates ModelProvider instances.
type OneStepApproximationStrategy ¶ added in v0.2.1
OneStepApproximationStrategy performs a one-step gradient approximation. It is designed for training recurrent models without full BPTT.
func NewOneStepApproximationStrategy ¶ added in v0.2.1
func NewOneStepApproximationStrategy[T tensor.Numeric]() *OneStepApproximationStrategy[T]
NewOneStepApproximationStrategy constructs a OneStepApproximationStrategy.
type PluginRegistry ¶ added in v0.2.1
PluginRegistry manages registered training components and provides factory functions.
func NewPluginRegistry ¶ added in v0.2.1
func NewPluginRegistry[T tensor.Numeric]() *PluginRegistry[T]
NewPluginRegistry creates a new plugin registry.
func (*PluginRegistry[T]) Clear ¶ added in v0.2.1
func (r *PluginRegistry[T]) Clear()
Clear removes all registrations.
func (*PluginRegistry[T]) GetCrossValidator ¶ added in v0.2.1
func (r *PluginRegistry[T]) GetCrossValidator(ctx context.Context, name string, config map[string]interface{}) (CrossValidator[T], error)
GetCrossValidator retrieves a registered cross validator factory and creates an instance.
func (*PluginRegistry[T]) GetDataProvider ¶ added in v0.2.1
func (r *PluginRegistry[T]) GetDataProvider(ctx context.Context, name string, config map[string]interface{}) (DataProvider[T], error)
GetDataProvider retrieves a registered data provider factory and creates an instance.
func (*PluginRegistry[T]) GetMetricComputer ¶ added in v0.2.1
func (r *PluginRegistry[T]) GetMetricComputer(ctx context.Context, name string, config map[string]interface{}) (MetricComputer[T], error)
GetMetricComputer retrieves a registered metric computer factory and creates an instance.
func (*PluginRegistry[T]) GetModelProvider ¶ added in v0.2.1
func (r *PluginRegistry[T]) GetModelProvider(ctx context.Context, name string, config map[string]interface{}) (ModelProvider[T], error)
GetModelProvider retrieves a registered model provider factory and creates an instance.
func (*PluginRegistry[T]) GetSequenceProvider ¶ added in v0.2.1
func (r *PluginRegistry[T]) GetSequenceProvider(ctx context.Context, name string, config map[string]interface{}) (SequenceProvider[T], error)
GetSequenceProvider retrieves a registered sequence provider factory and creates an instance.
func (*PluginRegistry[T]) GetWorkflow ¶ added in v0.2.1
func (r *PluginRegistry[T]) GetWorkflow(ctx context.Context, name string, config map[string]interface{}) (TrainingWorkflow[T], error)
GetWorkflow retrieves a registered workflow factory and creates an instance.
func (*PluginRegistry[T]) ListCrossValidators ¶ added in v0.2.1
func (r *PluginRegistry[T]) ListCrossValidators() []string
ListCrossValidators returns all registered cross validator names.
func (*PluginRegistry[T]) ListDataProviders ¶ added in v0.2.1
func (r *PluginRegistry[T]) ListDataProviders() []string
ListDataProviders returns all registered data provider names.
func (*PluginRegistry[T]) ListMetricComputers ¶ added in v0.2.1
func (r *PluginRegistry[T]) ListMetricComputers() []string
ListMetricComputers returns all registered metric computer names.
func (*PluginRegistry[T]) ListModelProviders ¶ added in v0.2.1
func (r *PluginRegistry[T]) ListModelProviders() []string
ListModelProviders returns all registered model provider names.
func (*PluginRegistry[T]) ListSequenceProviders ¶ added in v0.2.1
func (r *PluginRegistry[T]) ListSequenceProviders() []string
ListSequenceProviders returns all registered sequence provider names.
func (*PluginRegistry[T]) ListWorkflows ¶ added in v0.2.1
func (r *PluginRegistry[T]) ListWorkflows() []string
ListWorkflows returns all registered workflow names.
func (*PluginRegistry[T]) RegisterCrossValidator ¶ added in v0.2.1
func (r *PluginRegistry[T]) RegisterCrossValidator(name string, factory CrossValidatorFactory[T]) error
RegisterCrossValidator registers a cross validator factory.
func (*PluginRegistry[T]) RegisterDataProvider ¶ added in v0.2.1
func (r *PluginRegistry[T]) RegisterDataProvider(name string, factory DataProviderFactory[T]) error
RegisterDataProvider registers a data provider factory.
func (*PluginRegistry[T]) RegisterMetricComputer ¶ added in v0.2.1
func (r *PluginRegistry[T]) RegisterMetricComputer(name string, factory MetricComputerFactory[T]) error
RegisterMetricComputer registers a metric computer factory.
func (*PluginRegistry[T]) RegisterModelProvider ¶ added in v0.2.1
func (r *PluginRegistry[T]) RegisterModelProvider(name string, factory ModelProviderFactory[T]) error
RegisterModelProvider registers a model provider factory.
func (*PluginRegistry[T]) RegisterSequenceProvider ¶ added in v0.2.1
func (r *PluginRegistry[T]) RegisterSequenceProvider(name string, factory SequenceProviderFactory[T]) error
RegisterSequenceProvider registers a sequence provider factory.
func (*PluginRegistry[T]) RegisterWorkflow ¶ added in v0.2.1
func (r *PluginRegistry[T]) RegisterWorkflow(name string, factory WorkflowFactory[T]) error
RegisterWorkflow registers a training workflow factory.
func (*PluginRegistry[T]) Summary ¶ added in v0.2.1
func (r *PluginRegistry[T]) Summary() map[string]int
Summary returns a summary of all registered components.
func (*PluginRegistry[T]) UnregisterCrossValidator ¶ added in v0.2.1
func (r *PluginRegistry[T]) UnregisterCrossValidator(name string)
UnregisterCrossValidator removes a cross validator registration.
func (*PluginRegistry[T]) UnregisterDataProvider ¶ added in v0.2.1
func (r *PluginRegistry[T]) UnregisterDataProvider(name string)
UnregisterDataProvider removes a data provider registration.
func (*PluginRegistry[T]) UnregisterMetricComputer ¶ added in v0.2.1
func (r *PluginRegistry[T]) UnregisterMetricComputer(name string)
UnregisterMetricComputer removes a metric computer registration.
func (*PluginRegistry[T]) UnregisterModelProvider ¶ added in v0.2.1
func (r *PluginRegistry[T]) UnregisterModelProvider(name string)
UnregisterModelProvider removes a model provider registration.
func (*PluginRegistry[T]) UnregisterSequenceProvider ¶ added in v0.2.1
func (r *PluginRegistry[T]) UnregisterSequenceProvider(name string)
UnregisterSequenceProvider removes a sequence provider registration.
func (*PluginRegistry[T]) UnregisterWorkflow ¶ added in v0.2.1
func (r *PluginRegistry[T]) UnregisterWorkflow(name string)
UnregisterWorkflow removes a workflow registration.
type Predictor ¶ added in v1.9.0
Predictor is implemented by models that predict from flat feature vectors.
type SequenceConfig ¶ added in v0.2.1
type SequenceConfig struct {
MaxSeqLen int `json:"max_seq_len"`
NumSequences int `json:"num_sequences"`
Strategy string `json:"strategy"` // "consecutive", "random", "curriculum"
Extensions map[string]interface{} `json:"extensions"`
}
SequenceConfig configures sequence generation.
type SequenceProvider ¶ added in v0.2.1
type SequenceProvider[T tensor.Numeric] interface {
	// GenerateSequences creates training sequences from the dataset.
	GenerateSequences(ctx context.Context, dataset DataProvider[T], config SequenceConfig) ([]DataProvider[T], error)

	// GenerateTrainValidationSplit creates train/validation splits.
	GenerateTrainValidationSplit(ctx context.Context, dataset DataProvider[T], config SplitConfig) (DataProvider[T], DataProvider[T], error)

	// SetRandomSeed sets the random seed for reproducible sequence generation.
	SetRandomSeed(seed uint64)
}
SequenceProvider abstracts sequence generation for curriculum learning. This replaces the domain-specific EraSequencer with a generic interface.
type SequenceProviderFactory ¶ added in v0.2.1
type SequenceProviderFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (SequenceProvider[T], error)
SequenceProviderFactory creates SequenceProvider instances.
type SimpleModelProvider ¶ added in v0.2.1
SimpleModelProvider provides a basic model provider implementation.
func NewSimpleModelProvider ¶ added in v0.2.1
func NewSimpleModelProvider[T tensor.Numeric](
	factory func(ctx context.Context, config ModelConfig) (*graph.Graph[T], error),
	info ModelInfo,
) *SimpleModelProvider[T]
NewSimpleModelProvider creates a new simple model provider.
func (*SimpleModelProvider[T]) CreateModel ¶ added in v0.2.1
func (s *SimpleModelProvider[T]) CreateModel(ctx context.Context, config ModelConfig) (*graph.Graph[T], error)
CreateModel implements ModelProvider.CreateModel.
func (*SimpleModelProvider[T]) GetModelInfo ¶ added in v0.2.1
func (s *SimpleModelProvider[T]) GetModelInfo() ModelInfo
GetModelInfo implements ModelProvider.GetModelInfo.
type SplitConfig ¶ added in v0.2.1
type SplitConfig struct {
ValidationRatio float64 `json:"validation_ratio"`
Strategy string `json:"strategy"` // "random", "chronological", "stratified"
RandomSeed uint64 `json:"random_seed"`
Extensions map[string]interface{} `json:"extensions"`
}
SplitConfig configures train/validation splitting.
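To illustrate the "chronological" strategy, here is a self-contained sketch of how a ValidationRatio-style split might behave on sample indices. The helper is hypothetical and not part of the package:

```go
package main

import "fmt"

// chronologicalSplit splits n sample indices into train and validation
// sets without shuffling: the last validationRatio fraction of samples
// becomes the validation set, preserving temporal order. This mirrors
// the "chronological" Strategy in SplitConfig, which matters for
// time-series data where random splits would leak future information.
func chronologicalSplit(n int, validationRatio float64) (train, val []int) {
	cut := n - int(float64(n)*validationRatio)
	for i := 0; i < n; i++ {
		if i < cut {
			train = append(train, i)
		} else {
			val = append(val, i)
		}
	}
	return train, val
}

func main() {
	train, val := chronologicalSplit(10, 0.2)
	fmt.Println(train) // earliest 80% of samples
	fmt.Println(val)   // most recent 20% of samples
}
```

The "random" and "stratified" strategies would instead shuffle (seeded by RandomSeed) or balance label proportions before cutting.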
type TrainConfig ¶ added in v1.9.0
TrainConfig holds hyperparameters shared across flat and windowed backends.
type TrainResult ¶ added in v1.9.0
TrainResult holds the outcome of a training run.
func DispatchTrain ¶ added in v1.9.0
func DispatchTrain(backend Backend, features [][]float64, labels []float64, config TrainConfig) (*TrainResult, error)
DispatchTrain checks whether backend implements WindowedBackend and calls TrainWindowed if so; otherwise it falls back to Backend.Train with flat features. This is the dispatch logic used by walk-forward validators.
func DispatchTrainWindowed ¶ added in v1.9.0
func DispatchTrainWindowed(backend Backend, windows [][][]float64, labels []float64, config TrainConfig) (*TrainResult, error)
DispatchTrainWindowed dispatches a windowed training call. If the backend implements WindowedBackend it calls TrainWindowed directly; otherwise it flattens the windows and falls back to Backend.Train.
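The dispatch logic can be illustrated with simplified stand-ins for Backend and WindowedBackend (the real signatures also carry labels and a TrainConfig, and return *TrainResult):

```go
package main

import "fmt"

// Simplified stand-ins for the package's Backend and WindowedBackend.
type Backend interface {
	Train(features [][]float64) string
}

type WindowedBackend interface {
	TrainWindowed(windows [][][]float64) string
}

// dispatchTrainWindowed mirrors DispatchTrainWindowed: prefer the
// windowed path via type assertion; otherwise flatten each temporal
// window into a single feature row and fall back to flat Train.
func dispatchTrainWindowed(b Backend, windows [][][]float64) string {
	if wb, ok := b.(WindowedBackend); ok {
		return wb.TrainWindowed(windows)
	}
	flat := make([][]float64, len(windows))
	for i, w := range windows {
		for _, step := range w {
			flat[i] = append(flat[i], step...)
		}
	}
	return b.Train(flat)
}

type flatOnly struct{}

func (flatOnly) Train(f [][]float64) string { return fmt.Sprintf("flat:%dx%d", len(f), len(f[0])) }

type windowed struct{ flatOnly }

func (windowed) TrainWindowed(w [][][]float64) string { return fmt.Sprintf("windowed:%d", len(w)) }

func main() {
	// 2 windows, each with 2 time steps of 2 features.
	w := [][][]float64{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}
	fmt.Println(dispatchTrainWindowed(flatOnly{}, w)) // flat:2x4
	fmt.Println(dispatchTrainWindowed(windowed{}, w)) // windowed:2
}
```

Because the check is an ordinary interface type assertion, time-series backends opt in simply by implementing TrainWindowed; callers such as walk-forward validators never need to know which kind of backend they hold.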
type Trainer ¶
type Trainer[T tensor.Numeric] interface {
	// TrainStep performs a single training step for a model.
	// It takes the model's parameters, the optimizer, and the input/target data.
	// It is responsible for computing the loss, gradients, and updating the parameters.
	TrainStep(
		ctx context.Context,
		modelGraph *graph.Graph[T],
		optimizer optimizer.Optimizer[T],
		inputs map[graph.Node[T]]*tensor.TensorNumeric[T],
		targets *tensor.TensorNumeric[T],
	) (loss T, err error)
}
Trainer is an interface for model-specific training orchestration.
type TrainerWorkflowAdapter ¶ added in v0.2.1
TrainerWorkflowAdapter adapts the existing Trainer interface to the new TrainingWorkflow interface. This allows legacy trainer implementations to work with the new generic workflow system.
func NewTrainerWorkflowAdapter ¶ added in v0.2.1
func NewTrainerWorkflowAdapter[T tensor.Numeric](trainer Trainer[T], opt optimizer.Optimizer[T]) *TrainerWorkflowAdapter[T]
NewTrainerWorkflowAdapter creates a new adapter for legacy trainers.
func (*TrainerWorkflowAdapter[T]) GetMetrics ¶ added in v0.2.1
func (a *TrainerWorkflowAdapter[T]) GetMetrics() map[string]interface{}
GetMetrics implements TrainingWorkflow.GetMetrics.
func (*TrainerWorkflowAdapter[T]) Initialize ¶ added in v0.2.1
func (a *TrainerWorkflowAdapter[T]) Initialize(ctx context.Context, config WorkflowConfig) error
Initialize implements TrainingWorkflow.Initialize.
func (*TrainerWorkflowAdapter[T]) Shutdown ¶ added in v0.2.1
func (a *TrainerWorkflowAdapter[T]) Shutdown(ctx context.Context) error
Shutdown implements TrainingWorkflow.Shutdown.
func (*TrainerWorkflowAdapter[T]) Train ¶ added in v0.2.1
func (a *TrainerWorkflowAdapter[T]) Train(ctx context.Context, dataset DataProvider[T], modelProvider ModelProvider[T]) (*TrainingResult[T], error)
Train implements TrainingWorkflow.Train by adapting to the legacy Trainer interface.
func (*TrainerWorkflowAdapter[T]) Validate ¶ added in v0.2.1
func (a *TrainerWorkflowAdapter[T]) Validate(ctx context.Context, dataset DataProvider[T], modelProvider ModelProvider[T]) (*ValidationResult[T], error)
Validate implements TrainingWorkflow.Validate.
type TrainingResult ¶ added in v0.2.1
type TrainingResult[T tensor.Numeric] struct {
	FinalLoss    T                      `json:"final_loss"`
	BestLoss     T                      `json:"best_loss"`
	BestEpoch    int                    `json:"best_epoch"`
	TotalEpochs  int                    `json:"total_epochs"`
	TrainingTime float64                `json:"training_time_seconds"`
	Metrics      map[string]float64     `json:"metrics"`
	ModelPath    string                 `json:"model_path,omitempty"`
	Extensions   map[string]interface{} `json:"extensions"`
}
TrainingResult contains training outcome information.
type TrainingWorkflow ¶ added in v0.2.1
type TrainingWorkflow[T tensor.Numeric] interface {
	// Initialize prepares the workflow with configuration and dependencies.
	Initialize(ctx context.Context, config WorkflowConfig) error

	// Train executes the complete training workflow.
	Train(ctx context.Context, dataset DataProvider[T], model ModelProvider[T]) (*TrainingResult[T], error)

	// Validate performs validation on the trained model.
	Validate(ctx context.Context, dataset DataProvider[T], model ModelProvider[T]) (*ValidationResult[T], error)

	// GetMetrics returns current training metrics.
	GetMetrics() map[string]interface{}

	// Shutdown cleans up resources.
	Shutdown(ctx context.Context) error
}
TrainingWorkflow orchestrates the complete training process with pluggable components. This interface allows domain-specific applications to customize training behavior while maintaining framework-agnostic core logic.
type ValidationResult ¶ added in v0.2.1
type ValidationResult[T tensor.Numeric] struct {
	Loss           T                      `json:"loss"`
	Metrics        map[string]float64     `json:"metrics"`
	SampleCount    int                    `json:"sample_count"`
	ValidationTime float64                `json:"validation_time_seconds"`
	Extensions     map[string]interface{} `json:"extensions"`
}
ValidationResult contains validation outcome information.
type WindowedBackend ¶ added in v1.9.0
type WindowedBackend interface {
TrainWindowed(windows [][][]float64, labels []float64, config TrainConfig) (*TrainResult, error)
}
WindowedBackend is implemented by time-series models that consume temporal windows rather than flat feature vectors. Walk-forward validators check for this interface via type assertion before falling back to standard training.
type WindowedPredictor ¶ added in v1.9.0
type WindowedPredictor interface {
PredictWindowed(modelPath string, windows [][][]float64) ([]float64, error)
}
WindowedPredictor is implemented by models that predict from temporal windows instead of flat feature vectors.
type WorkflowConfig ¶ added in v0.2.1
type WorkflowConfig struct {
// Training configuration
NumEpochs int `json:"num_epochs"`
LearningRate float64 `json:"learning_rate"`
EarlyStopTol float64 `json:"early_stop_tolerance"`
MaxNoImprove int `json:"max_no_improve"`
RandomSeed uint64 `json:"random_seed"`
// Component configurations
BatchConfig BatchConfig `json:"batch_config"`
ModelConfig ModelConfig `json:"model_config"`
MetricConfigs map[string]interface{} `json:"metric_configs"`
// Extension point for domain-specific configuration
Extensions map[string]interface{} `json:"extensions"`
}
WorkflowConfig configures the training workflow.
type WorkflowFactory ¶ added in v0.2.1
type WorkflowFactory[T tensor.Numeric] func(ctx context.Context, config map[string]interface{}) (TrainingWorkflow[T], error)
WorkflowFactory creates TrainingWorkflow instances.
Source Files
¶
Directories
¶
| Path | Synopsis |
|---|---|
| automl | Package automl provides automated machine learning utilities including Bayesian hyperparameter optimization. |
| fp8 | Package fp8 implements FP8 mixed-precision training support. |
| lora | Package lora implements LoRA and QLoRA fine-tuning adapters. |
| loss | Package loss provides various loss functions for neural networks. |
| nas | Package nas implements neural architecture search using DARTS. |
| online | Package online implements online learning with drift detection and model rollback. |
| optimizer | Package optimizer provides various optimization algorithms for neural networks. |
| scheduler | Package scheduler provides learning rate scheduling strategies for optimizers. |