Documentation
Overview ¶
Package online implements online learning with drift detection and model rollback.
Stability: alpha
It provides components for continuous model adaptation, including triggers that decide when retraining should occur based on loss drift or sample-count schedules.
Index ¶
- Constants
- type AuditEvent
- type AuditLog
- type AutoNASConfig
- type AutoNASTrigger
- type CompositeValidator
- type DriftAlert
- type DriftConfig
- type DriftDetector
- type DriftTrigger
- type EWC
- func (e *EWC) Baseline() []float64
- func (e *EWC) ComputeFisher(data [][]float64, lossFn func([]float64) float64) error
- func (e *EWC) Fisher() []float64
- func (e *EWC) Lambda() float64
- func (e *EWC) Loss(taskLoss float64, currentWeights []float64) float64
- func (e *EWC) Penalty(currentWeights []float64) float64
- func (e *EWC) SetLambda(lambda float64)
- func (e *EWC) UpdateBaseline(newWeights []float64)
- type FeedbackCollector
- func (fc *FeedbackCollector) Close() error
- func (fc *FeedbackCollector) Flush() ([]FeedbackSignal, error)
- func (fc *FeedbackCollector) ReadAll() ([]FeedbackSignal, error)
- func (fc *FeedbackCollector) Record(signal FeedbackSignal) error
- func (fc *FeedbackCollector) Start(ctx context.Context)
- func (fc *FeedbackCollector) Stop()
- type FeedbackConfig
- type FeedbackSignal
- type IncrementalUpdater
- type LoRAUpdateConfig
- type LossDeltaValidator
- type ModelSnapshot
- type NASProposal
- type RollbackConfig
- type RollbackManager
- type Sample
- type ScheduledTrigger
- type Trigger
- type TriggerConfig
- type TriggerState
- type ValidationConfig
- type ValidationResult
- type Validator
- type WeightNormValidator
Constants ¶
const (
	EventTrigger    = "trigger"
	EventUpdate     = "update"
	EventRollback   = "rollback"
	EventValidation = "validation"
)
EventType constants identify the kind of audit event.
Variables ¶
This section is empty.
Functions ¶
This section is empty.
Types ¶
type AuditEvent ¶
type AuditEvent struct {
Timestamp time.Time `json:"timestamp"`
EventType string `json:"event_type"`
Details map[string]any `json:"details,omitempty"`
Outcome string `json:"outcome"`
}
AuditEvent records a single auditable action in the online learning pipeline.
type AuditLog ¶
type AuditLog struct {
// contains filtered or unexported fields
}
AuditLog writes and reads AuditEvents as JSONL (one JSON object per line).
func NewAuditLog ¶
NewAuditLog opens or creates a JSONL audit log file at path.
func (*AuditLog) Log ¶
func (a *AuditLog) Log(event AuditEvent) error
Log marshals the event to JSON, appends a newline, and flushes to disk.
func (*AuditLog) ReadAll ¶
func (a *AuditLog) ReadAll() ([]AuditEvent, error)
ReadAll reads all events from the audit log file.
type AutoNASConfig ¶
type AutoNASConfig struct {
// ImprovementThreshold is the minimum relative Sharpe ratio improvement
// required to propose the new architecture (e.g. 0.05 means 5%).
ImprovementThreshold float64
// SearchConfig is the NAS search configuration passed to RunSignalNAS.
SearchConfig nas.SignalSearchConfig
// Validators are the safety validators (ADR-052) that must approve
// the architecture replacement before it is accepted.
Validators []Validator
}
AutoNASConfig holds configuration for the automated NAS trigger.
type AutoNASTrigger ¶
type AutoNASTrigger struct {
// contains filtered or unexported fields
}
AutoNASTrigger listens for DriftAlert events and runs NAS search to discover improved architectures. When the discovered architecture's Sharpe ratio exceeds the current model's by at least ImprovementThreshold, it proposes the replacement through the safety validation pipeline.
func NewAutoNASTrigger ¶
func NewAutoNASTrigger(cfg AutoNASConfig) *AutoNASTrigger
NewAutoNASTrigger creates a new AutoNASTrigger with the given configuration.
func (*AutoNASTrigger) OnDriftAlert ¶
func (t *AutoNASTrigger) OnDriftAlert(ctx context.Context, alert DriftAlert, currentSharpe float64, data nas.SignalDataProvider) (*NASProposal, error)
OnDriftAlert processes a drift alert by running NAS search and proposing an architecture replacement if the discovered architecture is sufficiently better than the current model. The currentSharpe argument is the Sharpe ratio of the currently deployed model.
func (*AutoNASTrigger) Proposals ¶
func (t *AutoNASTrigger) Proposals() []NASProposal
Proposals returns all recorded proposals.
type CompositeValidator ¶
type CompositeValidator struct {
Validators []Validator
}
CompositeValidator runs multiple validators and returns the first failure.
func NewCompositeValidator ¶
func NewCompositeValidator(validators ...Validator) *CompositeValidator
NewCompositeValidator returns a CompositeValidator that runs all provided validators in order.
func (*CompositeValidator) Validate ¶
func (c *CompositeValidator) Validate(before, after ModelSnapshot) ValidationResult
Validate runs each validator in order and returns the first failure. If all validators pass, it returns a passing result.
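The first-failure semantics can be sketched standalone. The snippet below reimplements the documented types for illustration; the toy `lossDelta` validator is hypothetical and not part of the package:

```go
package main

import "fmt"

// ModelSnapshot and ValidationResult mirror the documented types.
type ModelSnapshot struct {
	Weights map[string][]float32
	Loss    float64
}

type ValidationResult struct {
	Pass   bool
	Reason string
}

type Validator interface {
	Validate(before, after ModelSnapshot) ValidationResult
}

// CompositeValidator runs each validator in order and returns the
// first failure; if all pass, it returns a passing result.
type CompositeValidator struct{ Validators []Validator }

func (c *CompositeValidator) Validate(before, after ModelSnapshot) ValidationResult {
	for _, v := range c.Validators {
		if r := v.Validate(before, after); !r.Pass {
			return r // first failure wins
		}
	}
	return ValidationResult{Pass: true}
}

// lossDelta is a toy validator used only for this demo.
type lossDelta struct{ max float64 }

func (v lossDelta) Validate(before, after ModelSnapshot) ValidationResult {
	if after.Loss-before.Loss > v.max {
		return ValidationResult{Pass: false, Reason: "loss increased too much"}
	}
	return ValidationResult{Pass: true}
}

func main() {
	c := &CompositeValidator{Validators: []Validator{lossDelta{max: 0.01}}}
	r := c.Validate(ModelSnapshot{Loss: 0.5}, ModelSnapshot{Loss: 0.6})
	fmt.Println(r.Pass, r.Reason)
}
```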
type DriftAlert ¶
type DriftAlert struct {
// Timestamp is when the alert was raised.
Timestamp time.Time
// CurrentSharpe is the Sharpe ratio of the current window.
CurrentSharpe float64
// MeanSharpe is the rolling mean of historical Sharpe ratios.
MeanSharpe float64
// Threshold is mean - 1 sigma; the alert fires when CurrentSharpe < Threshold.
Threshold float64
// WindowSize is the number of P&L observations in the rolling window.
WindowSize int
}
DriftAlert is emitted when the current window Sharpe ratio falls below the rolling mean minus one standard deviation.
type DriftConfig ¶
type DriftConfig struct {
// WindowSize is the number of daily P&L observations in the rolling window.
// Defaults to 30 if zero.
WindowSize int
}
DriftConfig holds parameters for the DriftDetector.
type DriftDetector ¶
type DriftDetector struct {
// contains filtered or unexported fields
}
DriftDetector computes a rolling Sharpe ratio over a configurable window of daily P&L observations and raises an alert when the current Sharpe ratio drops below the historical mean minus one standard deviation.
func NewDriftDetector ¶
func NewDriftDetector(cfg DriftConfig) *DriftDetector
NewDriftDetector creates a new DriftDetector with the given configuration.
func (*DriftDetector) CurrentSharpe ¶
func (d *DriftDetector) CurrentSharpe() float64
CurrentSharpe returns the Sharpe ratio of the current window. Returns 0 if there is insufficient data.
func (*DriftDetector) Observe ¶
func (d *DriftDetector) Observe(date time.Time, pnl float64) *DriftAlert
Observe records a daily P&L value and returns a DriftAlert if the current window Sharpe ratio has dropped below the rolling mean minus one sigma. Returns nil when there is no alert or insufficient data.
func (*DriftDetector) Window ¶
func (d *DriftDetector) Window() []float64
Window returns a copy of the current P&L window.
type DriftTrigger ¶
type DriftTrigger struct {
Config TriggerConfig
// Now returns the current time. If nil, time.Now is used.
Now func() time.Time
}
DriftTrigger fires when the rolling mean loss increases by more than DriftThreshold relative to the baseline (first half of the eval window).
func (*DriftTrigger) RecordSample ¶
func (d *DriftTrigger) RecordSample(state *TriggerState, loss float64)
RecordSample appends the loss to RecentLosses, trimming to EvalWindowSize.
func (*DriftTrigger) ShouldRetrain ¶
func (d *DriftTrigger) ShouldRetrain(state *TriggerState, newLoss float64) bool
ShouldRetrain returns true when the rolling mean of the most recent half of the eval window exceeds the baseline (first half) by more than DriftThreshold, provided enough samples have been collected and the cooldown period has elapsed.
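The baseline-versus-recent comparison can be illustrated with a standalone sketch (a hypothetical reimplementation of the documented rule, ignoring cooldown for brevity):

```go
package main

import "fmt"

func mean(xs []float64) float64 {
	var s float64
	for _, x := range xs {
		s += x
	}
	return s / float64(len(xs))
}

// driftShouldRetrain reports whether the mean loss over the most recent
// half of the eval window exceeds the baseline mean (first half) by
// more than the relative drift threshold.
func driftShouldRetrain(losses []float64, windowSize int, driftThreshold float64) bool {
	if len(losses) < windowSize || windowSize < 2 {
		return false // not enough samples yet
	}
	w := losses[len(losses)-windowSize:]
	baseline := mean(w[:windowSize/2])
	recent := mean(w[windowSize/2:])
	if baseline == 0 {
		return false
	}
	return (recent-baseline)/baseline > driftThreshold
}

func main() {
	losses := []float64{1.0, 1.0, 1.2, 1.2}
	fmt.Println(driftShouldRetrain(losses, 4, 0.1)) // 20% drift vs 10% threshold
}
```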
type EWC ¶ added in v1.8.0
type EWC struct {
// contains filtered or unexported fields
}
EWC implements Elastic Weight Consolidation for continual learning. It prevents catastrophic forgetting by penalizing changes to parameters that were important for previously learned tasks. Importance is estimated via the diagonal of the Fisher information matrix.
func NewEWC ¶ added in v1.8.0
NewEWC creates a new EWC instance with the given model weights as the baseline and the specified number of samples for Fisher estimation. The lambda (regularization strength) defaults to 1.0 and can be set via SetLambda.
func (*EWC) ComputeFisher ¶ added in v1.8.0
func (e *EWC) ComputeFisher(data [][]float64, lossFn func([]float64) float64) error
ComputeFisher estimates the diagonal of the Fisher information matrix using the provided data and loss function. For each data point it forms a per-sample loss (treating that point as a single-element dataset) and computes the squared gradient via finite differences.
The lossFn receives a weight vector and returns the scalar loss.
func (*EWC) Fisher ¶ added in v1.8.0
func (e *EWC) Fisher() []float64
Fisher returns a copy of the current Fisher information diagonal. Returns nil if Fisher has not been computed.
func (*EWC) Loss ¶ added in v1.8.0
func (e *EWC) Loss(taskLoss float64, currentWeights []float64) float64
Loss computes a total loss that includes both the task loss and the EWC penalty: totalLoss = taskLoss + Penalty(currentWeights). This is a convenience method for use in training loops.
func (*EWC) Penalty ¶ added in v1.8.0
func (e *EWC) Penalty(currentWeights []float64) float64
Penalty computes the EWC penalty term for the given current weights. The penalty is:
(lambda / 2) * sum_i(fisher[i] * (currentWeights[i] - baseline[i])^2)
Returns 0 if Fisher has not been computed yet.
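The documented penalty formula can be computed standalone; `ewcPenalty` below is an illustrative reimplementation, not the package's method:

```go
package main

import "fmt"

// ewcPenalty computes (lambda / 2) * sum_i(fisher[i] * (w[i] - baseline[i])^2),
// returning 0 when Fisher has not been computed (nil slice).
func ewcPenalty(lambda float64, fisher, w, baseline []float64) float64 {
	if fisher == nil {
		return 0
	}
	var p float64
	for i := range fisher {
		d := w[i] - baseline[i]
		p += fisher[i] * d * d
	}
	return lambda / 2 * p
}

func main() {
	fisher := []float64{1, 2}   // per-parameter importance
	baseline := []float64{0, 0} // weights after the previous task
	current := []float64{1, 1}  // weights during the new task
	fmt.Println(ewcPenalty(2.0, fisher, current, baseline))
}
```

Parameters with larger Fisher values are penalized more strongly for drifting from the baseline, which is what protects previously learned tasks.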
func (*EWC) SetLambda ¶ added in v1.8.0
func (e *EWC) SetLambda(lambda float64)
SetLambda sets the EWC regularization strength. Higher values more strongly penalize deviations from the baseline weights on important parameters.
func (*EWC) UpdateBaseline ¶ added in v1.8.0
func (e *EWC) UpdateBaseline(newWeights []float64)
UpdateBaseline updates the reference weights to a new set of weights. This should be called after the model has finished learning a new task, so that future EWC penalties are computed relative to the new optimal weights. The Fisher information is preserved across baseline updates.
type FeedbackCollector ¶
type FeedbackCollector struct {
// contains filtered or unexported fields
}
FeedbackCollector buffers feedback signals in memory and periodically flushes them to JSONL files on disk.
func NewFeedbackCollector ¶
func NewFeedbackCollector(cfg FeedbackConfig) (*FeedbackCollector, error)
NewFeedbackCollector creates a new FeedbackCollector, creating the StoragePath directory if it does not exist.
func (*FeedbackCollector) Close ¶
func (fc *FeedbackCollector) Close() error
Close releases resources held by the collector.
func (*FeedbackCollector) Flush ¶
func (fc *FeedbackCollector) Flush() ([]FeedbackSignal, error)
Flush writes all buffered signals to a JSONL file on disk, returns them, and clears the buffer.
func (*FeedbackCollector) ReadAll ¶
func (fc *FeedbackCollector) ReadAll() ([]FeedbackSignal, error)
ReadAll reads all feedback signals from JSONL files in the storage directory.
func (*FeedbackCollector) Record ¶
func (fc *FeedbackCollector) Record(signal FeedbackSignal) error
Record appends a signal to the in-memory buffer. If the buffer reaches BufferSize, it is automatically flushed to disk.
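The buffer-then-flush pattern can be sketched in memory; the snippet below is a hypothetical simplification where flushed batches stand in for the JSONL files the real collector writes:

```go
package main

import "fmt"

type signal struct{ PredictionID string }

// collector buffers signals and flushes a batch once the buffer
// reaches bufferSize; flushed batches stand in for JSONL files.
type collector struct {
	bufferSize int
	buf        []signal
	flushed    [][]signal
}

func (c *collector) record(s signal) {
	c.buf = append(c.buf, s)
	if len(c.buf) >= c.bufferSize {
		c.flush()
	}
}

func (c *collector) flush() {
	if len(c.buf) == 0 {
		return
	}
	batch := make([]signal, len(c.buf))
	copy(batch, c.buf)
	c.flushed = append(c.flushed, batch)
	c.buf = c.buf[:0]
}

func main() {
	c := &collector{bufferSize: 2}
	for _, id := range []string{"p1", "p2", "p3"} {
		c.record(signal{PredictionID: id})
	}
	fmt.Println(len(c.flushed), len(c.buf)) // one flushed batch, one buffered signal
}
```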
func (*FeedbackCollector) Start ¶
func (fc *FeedbackCollector) Start(ctx context.Context)
Start begins a background goroutine that flushes buffered signals every FlushInterval.
func (*FeedbackCollector) Stop ¶
func (fc *FeedbackCollector) Stop()
Stop stops the background goroutine and performs a final flush.
type FeedbackConfig ¶
type FeedbackConfig struct {
// BufferSize is the maximum number of signals buffered in memory
// before an automatic flush to disk.
BufferSize int
// FlushInterval is the duration between periodic background flushes.
FlushInterval time.Duration
// StoragePath is the directory where JSONL feedback files are written.
StoragePath string
}
FeedbackConfig holds parameters for the FeedbackCollector.
type FeedbackSignal ¶
type FeedbackSignal struct {
Timestamp time.Time `json:"timestamp"`
PredictionID string `json:"prediction_id"`
Predicted []float32 `json:"predicted"`
Actual []float32 `json:"actual"`
Score float64 `json:"score"`
}
FeedbackSignal represents a single feedback observation comparing a model's prediction against the actual outcome.
type IncrementalUpdater ¶
type IncrementalUpdater struct {
// contains filtered or unexported fields
}
IncrementalUpdater applies incremental LoRA updates using SGD. It maintains adapter weights and supports rollback to the pre-update state.
func NewIncrementalUpdater ¶
func NewIncrementalUpdater(cfg LoRAUpdateConfig) *IncrementalUpdater
NewIncrementalUpdater creates a new updater with the given configuration. The adapter dimensions are inferred from the first Update call's sample sizes.
func (*IncrementalUpdater) CommitUpdate ¶
func (u *IncrementalUpdater) CommitUpdate()
CommitUpdate finalizes the current update by discarding the rollback snapshot.
func (*IncrementalUpdater) CurrentLoss ¶
func (u *IncrementalUpdater) CurrentLoss(samples []Sample) float64
CurrentLoss computes the MSE loss over the given samples using the current adapter weights. Returns +Inf if there are no samples or the adapter is not initialized.
func (*IncrementalUpdater) Rollback ¶
func (u *IncrementalUpdater) Rollback() error
Rollback restores the adapter weights to the state before the last Update call. Returns an error if no snapshot is available (e.g. after CommitUpdate or before any Update).
func (*IncrementalUpdater) Update ¶
func (u *IncrementalUpdater) Update(samples []Sample) error
Update applies MaxSteps of SGD on the LoRA adapter parameters using the provided samples. Each step iterates over all samples, computing MSE loss gradients and updating A and B matrices in-place.
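The snapshot/rollback/commit lifecycle around Update can be sketched with a flat weight vector; this is an illustrative stand-in for the adapter matrices, not the package's implementation:

```go
package main

import (
	"errors"
	"fmt"
)

// updater holds adapter weights and a pre-update snapshot so the last
// update can be rolled back until it is committed.
type updater struct {
	weights  []float64
	snapshot []float64
}

func (u *updater) update(delta []float64) {
	u.snapshot = append([]float64(nil), u.weights...) // snapshot before modifying
	for i := range u.weights {
		u.weights[i] += delta[i]
	}
}

func (u *updater) rollback() error {
	if u.snapshot == nil {
		return errors.New("no snapshot available")
	}
	copy(u.weights, u.snapshot)
	u.snapshot = nil
	return nil
}

func (u *updater) commit() { u.snapshot = nil } // discard the rollback snapshot

func main() {
	u := &updater{weights: []float64{1, 2}}
	u.update([]float64{0.5, -0.5})
	fmt.Println(u.weights) // updated in place
	_ = u.rollback()
	fmt.Println(u.weights) // restored to pre-update state
}
```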
type LoRAUpdateConfig ¶
type LoRAUpdateConfig struct {
// Rank is the low-rank dimension for LoRA adapter matrices.
Rank int
// Alpha is the scaling factor; the LoRA output is multiplied by Alpha/Rank.
Alpha int
// LR is the learning rate for SGD updates.
LR float64
// MaxSteps is the maximum number of gradient descent steps per Update call.
MaxSteps int
// TargetModules lists the module names that LoRA adapters are applied to.
TargetModules []string
}
LoRAUpdateConfig holds parameters for an incremental LoRA update pass.
type LossDeltaValidator ¶
type LossDeltaValidator struct {
MaxLossDelta float64
}
LossDeltaValidator rejects updates where the loss increases by more than MaxLossDelta.
func NewLossDeltaValidator ¶
func NewLossDeltaValidator(maxDelta float64) *LossDeltaValidator
NewLossDeltaValidator returns a LossDeltaValidator with the given threshold.
func (*LossDeltaValidator) Validate ¶
func (v *LossDeltaValidator) Validate(before, after ModelSnapshot) ValidationResult
Validate rejects the update if after.Loss - before.Loss > MaxLossDelta.
type ModelSnapshot ¶
type ModelSnapshot struct {
// Weights maps layer names to their weight tensors as flat float32 slices.
Weights map[string][]float32
// Loss is the model loss on a validation set.
Loss float64
}
ModelSnapshot captures model state for comparison during validation.
type NASProposal ¶
type NASProposal struct {
// Alert is the drift alert that triggered the NAS search.
Alert DriftAlert
// SearchOutput is the full NAS search output.
SearchOutput *nas.SignalSearchOutput
// CurrentSharpe is the Sharpe ratio of the current model at trigger time.
CurrentSharpe float64
// DiscoveredSharpe is the Sharpe-like metric of the discovered architecture.
DiscoveredSharpe float64
// Improvement is the relative improvement (discovered - current) / |current|.
Improvement float64
// Accepted indicates whether the proposal passed safety validation.
Accepted bool
// RejectionReason is non-empty if the proposal was rejected by a validator.
RejectionReason string
}
NASProposal represents a proposed architecture replacement discovered by automated NAS after a drift event.
type RollbackConfig ¶
type RollbackConfig struct {
// MaxVersions is the maximum number of snapshots to retain.
MaxVersions int
// StoragePath is the directory where snapshot files are stored.
StoragePath string
}
RollbackConfig holds parameters for the rollback manager.
type RollbackManager ¶
type RollbackManager struct {
// contains filtered or unexported fields
}
RollbackManager manages versioned model weight snapshots on disk, supporting snapshot creation, rollback, listing, and pruning.
func NewRollbackManager ¶
func NewRollbackManager(cfg RollbackConfig) (*RollbackManager, error)
NewRollbackManager creates a RollbackManager, creating StoragePath if it does not exist. It scans the directory for existing snapshots so that persistence across restarts is maintained.
func (*RollbackManager) Close ¶
func (m *RollbackManager) Close() error
Close is a no-op that satisfies resource cleanup conventions.
func (*RollbackManager) ListSnapshots ¶
func (m *RollbackManager) ListSnapshots() []string
ListSnapshots returns snapshot IDs sorted by creation time, newest first.
func (*RollbackManager) Prune ¶
func (m *RollbackManager) Prune() error
Prune deletes the oldest snapshots beyond MaxVersions.
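The retention rule can be sketched as a pure function over the newest-first ID order that ListSnapshots documents; this is an illustrative sketch, not the on-disk implementation:

```go
package main

import "fmt"

// prune splits snapshot IDs (sorted newest first) into those retained
// and those deleted, keeping at most maxVersions.
func prune(idsNewestFirst []string, maxVersions int) (keep, deleted []string) {
	if len(idsNewestFirst) <= maxVersions {
		return idsNewestFirst, nil
	}
	return idsNewestFirst[:maxVersions], idsNewestFirst[maxVersions:]
}

func main() {
	keep, deleted := prune([]string{"v3", "v2", "v1"}, 2)
	fmt.Println(keep, deleted)
}
```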
type ScheduledTrigger ¶
type ScheduledTrigger struct {
Config TriggerConfig
// Interval is the number of samples between triggers.
Interval int
// Now returns the current time. If nil, time.Now is used.
Now func() time.Time
}
ScheduledTrigger fires every N samples regardless of loss values.
func (*ScheduledTrigger) RecordSample ¶
func (s *ScheduledTrigger) RecordSample(state *TriggerState, loss float64)
RecordSample appends the loss to RecentLosses and increments SampleCount.
func (*ScheduledTrigger) ShouldRetrain ¶
func (s *ScheduledTrigger) ShouldRetrain(state *TriggerState, _ float64) bool
ShouldRetrain returns true every Interval samples, provided the cooldown period has elapsed.
type Trigger ¶
type Trigger interface {
// ShouldRetrain returns true if conditions are met for retraining
// given the current state and a new loss observation.
ShouldRetrain(state *TriggerState, newLoss float64) bool
// RecordSample records a new loss observation into the trigger state.
RecordSample(state *TriggerState, loss float64)
}
Trigger decides when a model should be retrained.
type TriggerConfig ¶
type TriggerConfig struct {
// DriftThreshold is the relative increase in rolling mean loss over
// baseline that triggers retraining (e.g. 0.1 means 10% increase).
DriftThreshold float64
// MinSampleCount is the minimum number of samples that must be
// recorded before a trigger can fire.
MinSampleCount int
// EvalWindowSize is the number of recent losses used to compute the
// rolling mean for drift detection.
EvalWindowSize int
// CooldownPeriod is the minimum duration between consecutive triggers.
CooldownPeriod time.Duration
}
TriggerConfig holds parameters that control when retraining is triggered.
type TriggerState ¶
type TriggerState struct {
// LastTriggerTime is the time the trigger last fired.
LastTriggerTime time.Time
// SampleCount is the total number of samples recorded.
SampleCount int
// RecentLosses holds the most recent loss values, up to
// EvalWindowSize entries.
RecentLosses []float64
}
TriggerState tracks the mutable state for a trigger across evaluations.
type ValidationConfig ¶
type ValidationConfig struct {
// MaxLossDelta is the maximum allowed increase in loss (after - before).
MaxLossDelta float64
// MaxWeightNorm is the maximum allowed L2 norm for any single weight tensor.
MaxWeightNorm float64
// MaxGradNorm is the maximum allowed gradient norm (reserved for future use).
MaxGradNorm float64
}
ValidationConfig holds thresholds for the built-in validators.
type ValidationResult ¶
type ValidationResult struct {
// Pass is true if the validation passed.
Pass bool
// Reason describes why validation failed (empty when Pass is true).
Reason string
}
ValidationResult is the outcome of a safety validation check.
type Validator ¶
type Validator interface {
// Validate compares the model state before and after an update and
// returns whether the update should be accepted.
Validate(before, after ModelSnapshot) ValidationResult
}
Validator checks whether a model update is safe to promote.
type WeightNormValidator ¶
type WeightNormValidator struct {
MaxWeightNorm float64
}
WeightNormValidator rejects updates where the L2 norm of any weight tensor in the updated model exceeds MaxWeightNorm.
func NewWeightNormValidator ¶
func NewWeightNormValidator(maxNorm float64) *WeightNormValidator
NewWeightNormValidator returns a WeightNormValidator with the given threshold.
func (*WeightNormValidator) Validate ¶
func (v *WeightNormValidator) Validate(_, after ModelSnapshot) ValidationResult
Validate rejects the update if any weight tensor in after has an L2 norm exceeding MaxWeightNorm.