online

package
v1.24.0 Latest
Published: Mar 26, 2026 License: Apache-2.0 Imports: 14 Imported by: 0

Documentation

Overview

Package online implements online learning with drift detection and model rollback.

Stability: alpha

Package online provides online learning components for continuous model adaptation. It includes triggers that decide when retraining should occur, based on either loss drift or a sample-count schedule.

Constants

const (
	EventTrigger    = "trigger"
	EventUpdate     = "update"
	EventRollback   = "rollback"
	EventValidation = "validation"
)

EventType constants identify the kind of audit event.

Variables

This section is empty.

Functions

This section is empty.

Types

type AuditEvent

type AuditEvent struct {
	Timestamp time.Time      `json:"timestamp"`
	EventType string         `json:"event_type"`
	Details   map[string]any `json:"details,omitempty"`
	Outcome   string         `json:"outcome"`
}

AuditEvent records a single auditable action in the online learning pipeline.

type AuditLog

type AuditLog struct {
	// contains filtered or unexported fields
}

AuditLog writes and reads AuditEvents as JSONL (one JSON object per line).

func NewAuditLog

func NewAuditLog(path string) (*AuditLog, error)

NewAuditLog opens or creates a JSONL audit log file at path.

func (*AuditLog) Close

func (a *AuditLog) Close() error

Close closes the underlying file.

func (*AuditLog) Log

func (a *AuditLog) Log(event AuditEvent) error

Log marshals the event to JSON, appends a newline, and flushes to disk.

func (*AuditLog) ReadAll

func (a *AuditLog) ReadAll() ([]AuditEvent, error)

ReadAll reads all events from the audit log file.
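The JSONL round trip described for Log and ReadAll can be sketched with the standard library alone. This is a minimal illustration, not the package's implementation: it writes to an in-memory buffer instead of a file, and the names auditEvent, appendEvent, readEvents, and roundTrip are hypothetical.

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"time"
)

// auditEvent mirrors the documented AuditEvent fields (local copy for the sketch).
type auditEvent struct {
	Timestamp time.Time      `json:"timestamp"`
	EventType string         `json:"event_type"`
	Details   map[string]any `json:"details,omitempty"`
	Outcome   string         `json:"outcome"`
}

// appendEvent writes one JSON object followed by a newline.
func appendEvent(buf *bytes.Buffer, ev auditEvent) error {
	line, err := json.Marshal(ev)
	if err != nil {
		return err
	}
	buf.Write(line)
	buf.WriteByte('\n')
	return nil
}

// readEvents decodes every non-empty line back into an auditEvent.
func readEvents(data []byte) ([]auditEvent, error) {
	var out []auditEvent
	sc := bufio.NewScanner(bytes.NewReader(data))
	for sc.Scan() {
		if len(bytes.TrimSpace(sc.Bytes())) == 0 {
			continue
		}
		var ev auditEvent
		if err := json.Unmarshal(sc.Bytes(), &ev); err != nil {
			return nil, err
		}
		out = append(out, ev)
	}
	return out, sc.Err()
}

// roundTrip appends all events to a buffer and reads them back.
func roundTrip(events []auditEvent) ([]auditEvent, error) {
	var buf bytes.Buffer
	for _, ev := range events {
		if err := appendEvent(&buf, ev); err != nil {
			return nil, err
		}
	}
	return readEvents(buf.Bytes())
}

func main() {
	events, err := roundTrip([]auditEvent{
		{Timestamp: time.Unix(0, 0).UTC(), EventType: "trigger", Outcome: "fired"},
		{Timestamp: time.Unix(60, 0).UTC(), EventType: "rollback", Outcome: "ok"},
	})
	fmt.Println(len(events), err)
}
```

A real log would also need to sync the file after each write to match the "flushes to disk" behavior described for Log; the buffer here has no equivalent.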

type AutoNASConfig

type AutoNASConfig struct {
	// ImprovementThreshold is the minimum relative Sharpe ratio improvement
	// required to propose the new architecture (e.g. 0.05 means 5%).
	ImprovementThreshold float64

	// SearchConfig is the NAS search configuration passed to RunSignalNAS.
	SearchConfig nas.SignalSearchConfig

	// Validators are the safety validators (ADR-052) that must approve
	// the architecture replacement before it is accepted.
	Validators []Validator
}

AutoNASConfig holds configuration for the automated NAS trigger.

type AutoNASTrigger

type AutoNASTrigger struct {
	// contains filtered or unexported fields
}

AutoNASTrigger listens for DriftAlert events and runs NAS search to discover improved architectures. When the discovered architecture's Sharpe ratio exceeds the current model's by at least ImprovementThreshold, it proposes the replacement through the safety validation pipeline.

func NewAutoNASTrigger

func NewAutoNASTrigger(cfg AutoNASConfig) *AutoNASTrigger

NewAutoNASTrigger creates a new AutoNASTrigger with the given configuration.

func (*AutoNASTrigger) OnDriftAlert

func (t *AutoNASTrigger) OnDriftAlert(
	ctx context.Context,
	alert DriftAlert,
	currentSharpe float64,
	data nas.SignalDataProvider,
) (*NASProposal, error)

OnDriftAlert processes a drift alert by running NAS search and proposing an architecture replacement if the discovered architecture is sufficiently better than the current model. The currentSharpe argument is the Sharpe ratio of the currently deployed model.

func (*AutoNASTrigger) Proposals

func (t *AutoNASTrigger) Proposals() []NASProposal

Proposals returns all recorded proposals.

type CompositeValidator

type CompositeValidator struct {
	Validators []Validator
}

CompositeValidator runs multiple validators and returns the first failure.

func NewCompositeValidator

func NewCompositeValidator(validators ...Validator) *CompositeValidator

NewCompositeValidator returns a CompositeValidator that runs all provided validators in order.

func (*CompositeValidator) Validate

func (c *CompositeValidator) Validate(before, after ModelSnapshot) ValidationResult

Validate runs each validator in order and returns the first failure. If all validators pass, it returns a passing result.
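The first-failure semantics can be illustrated with a self-contained sketch. The types below are local stand-ins for ModelSnapshot, ValidationResult, and Validator; the nonEmpty validator is hypothetical and exists only to give the chain a second link.

```go
package main

import "fmt"

type snapshot struct {
	Weights map[string][]float32
	Loss    float64
}

type result struct {
	Pass   bool
	Reason string
}

type validator interface {
	Validate(before, after snapshot) result
}

// lossDelta rejects updates whose loss rises by more than Max.
type lossDelta struct{ Max float64 }

func (v lossDelta) Validate(before, after snapshot) result {
	if d := after.Loss - before.Loss; d > v.Max {
		return result{Reason: fmt.Sprintf("loss delta %.3f exceeds %.3f", d, v.Max)}
	}
	return result{Pass: true}
}

// nonEmpty is a hypothetical validator that rejects snapshots without weights.
type nonEmpty struct{}

func (nonEmpty) Validate(_, after snapshot) result {
	if len(after.Weights) == 0 {
		return result{Reason: "no weights in updated snapshot"}
	}
	return result{Pass: true}
}

// composite runs validators in order and stops at the first failure.
type composite []validator

func (c composite) Validate(before, after snapshot) result {
	for _, v := range c {
		if r := v.Validate(before, after); !r.Pass {
			return r
		}
	}
	return result{Pass: true}
}

func main() {
	checks := composite{lossDelta{Max: 0.1}, nonEmpty{}}
	before := snapshot{Loss: 1.0}
	after := snapshot{Loss: 1.5, Weights: map[string][]float32{"layer0": {0.1}}}
	fmt.Printf("%+v\n", checks.Validate(before, after))
}
```

Because evaluation stops at the first failure, ordering matters: cheap checks placed first avoid running more expensive ones on updates that are already rejected.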

type DriftAlert

type DriftAlert struct {
	// Timestamp is when the alert was raised.
	Timestamp time.Time

	// CurrentSharpe is the Sharpe ratio of the current window.
	CurrentSharpe float64

	// MeanSharpe is the rolling mean of historical Sharpe ratios.
	MeanSharpe float64

	// Threshold is mean - 1 sigma; the alert fires when CurrentSharpe < Threshold.
	Threshold float64

	// WindowSize is the number of P&L observations in the rolling window.
	WindowSize int
}

DriftAlert is emitted when the current window Sharpe ratio falls below the rolling mean minus one standard deviation.

type DriftConfig

type DriftConfig struct {
	// WindowSize is the number of daily P&L observations in the rolling window.
	// Defaults to 30 if zero.
	WindowSize int
}

DriftConfig holds parameters for the DriftDetector.

type DriftDetector

type DriftDetector struct {
	// contains filtered or unexported fields
}

DriftDetector computes a rolling Sharpe ratio over a configurable window of daily P&L observations and raises an alert when the current Sharpe ratio drops below the historical mean minus one standard deviation.

func NewDriftDetector

func NewDriftDetector(cfg DriftConfig) *DriftDetector

NewDriftDetector creates a new DriftDetector with the given configuration.

func (*DriftDetector) CurrentSharpe

func (d *DriftDetector) CurrentSharpe() float64

CurrentSharpe returns the Sharpe ratio of the current window. Returns 0 if there is insufficient data.

func (*DriftDetector) Observe

func (d *DriftDetector) Observe(date time.Time, pnl float64) *DriftAlert

Observe records a daily P&L value and returns a DriftAlert if the current window Sharpe ratio has dropped below the rolling mean minus one sigma. Returns nil when there is no alert or insufficient data.
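The alert rule (current Sharpe below rolling mean minus one sigma) can be sketched as follows. The documentation does not say how the Sharpe ratio is computed or how the history of Sharpe values is accumulated, so everything here is an assumption: Sharpe is taken as mean over population standard deviation with no annualization, one Sharpe value is recorded per full window, and at least two historical values are required before an alert can fire. The names detector, observe, and alerts are hypothetical.

```go
package main

import (
	"fmt"
	"math"
)

// meanStd returns the mean and population standard deviation of xs.
func meanStd(xs []float64) (mean, std float64) {
	if len(xs) == 0 {
		return 0, 0
	}
	for _, x := range xs {
		mean += x
	}
	mean /= float64(len(xs))
	for _, x := range xs {
		std += (x - mean) * (x - mean)
	}
	return mean, math.Sqrt(std / float64(len(xs)))
}

type detector struct {
	size    int
	window  []float64 // most recent P&L values, at most size entries
	history []float64 // Sharpe values of earlier full windows (assumption)
}

// observe adds one P&L value and reports the current Sharpe, the
// threshold (history mean minus one sigma), and whether an alert fires.
func (d *detector) observe(pnl float64) (current, threshold float64, alert bool) {
	d.window = append(d.window, pnl)
	if len(d.window) > d.size {
		d.window = d.window[1:]
	}
	if len(d.window) < d.size {
		return 0, 0, false // insufficient data
	}
	m, s := meanStd(d.window)
	if s == 0 {
		return 0, 0, false // Sharpe undefined for a constant window
	}
	current = m / s
	if len(d.history) >= 2 {
		hm, hs := meanStd(d.history)
		threshold = hm - hs
		alert = current < threshold
	}
	d.history = append(d.history, current)
	return current, threshold, alert
}

// alerts feeds pnls through a fresh detector and records each alert flag.
func alerts(size int, pnls []float64) []bool {
	d := &detector{size: size}
	out := make([]bool, len(pnls))
	for i, p := range pnls {
		_, _, out[i] = d.observe(p)
	}
	return out
}

func main() {
	fmt.Println(alerts(4, []float64{1, 2, 1, 2, 1, 2, 1, 2, -5}))
}
```

In this run the window Sharpe stays constant while P&L alternates between 1 and 2, so no alert fires; the final loss of 5 pulls the window mean to zero and the Sharpe drops below the threshold.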

func (*DriftDetector) Window

func (d *DriftDetector) Window() []float64

Window returns a copy of the current P&L window.

type DriftTrigger

type DriftTrigger struct {
	Config TriggerConfig
	// Now returns the current time. If nil, time.Now is used.
	Now func() time.Time
}

DriftTrigger fires when the rolling mean loss increases by more than DriftThreshold relative to the baseline (first half of the eval window).

func (*DriftTrigger) RecordSample

func (d *DriftTrigger) RecordSample(state *TriggerState, loss float64)

RecordSample appends the loss to RecentLosses, trimming to EvalWindowSize.

func (*DriftTrigger) ShouldRetrain

func (d *DriftTrigger) ShouldRetrain(state *TriggerState, newLoss float64) bool

ShouldRetrain returns true when the rolling mean of the most recent half of the eval window exceeds the baseline (first half) by more than DriftThreshold, provided enough samples have been collected and the cooldown period has elapsed.
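A minimal sketch of the half-window comparison follows. It is written under several assumptions that the documentation leaves open: the newLoss argument is ignored, the trigger itself stamps LastTriggerTime when it fires, RecordSample also increments SampleCount, and a non-positive baseline never fires. The helper names (record, shouldRetrain, demoConfig, at) are hypothetical.

```go
package main

import (
	"fmt"
	"time"
)

type config struct {
	DriftThreshold float64
	MinSampleCount int
	EvalWindowSize int
	CooldownPeriod time.Duration
}

type state struct {
	LastTriggerTime time.Time
	SampleCount     int
	RecentLosses    []float64
}

// record appends a loss and trims the window to EvalWindowSize.
func record(cfg config, st *state, loss float64) {
	st.SampleCount++
	st.RecentLosses = append(st.RecentLosses, loss)
	if n := len(st.RecentLosses); n > cfg.EvalWindowSize {
		st.RecentLosses = st.RecentLosses[n-cfg.EvalWindowSize:]
	}
}

func mean(xs []float64) float64 {
	var sum float64
	for _, x := range xs {
		sum += x
	}
	return sum / float64(len(xs))
}

// shouldRetrain compares the newer half of the window against the older half.
func shouldRetrain(cfg config, st *state, now time.Time) bool {
	if st.SampleCount < cfg.MinSampleCount || len(st.RecentLosses) < cfg.EvalWindowSize {
		return false
	}
	if !st.LastTriggerTime.IsZero() && now.Sub(st.LastTriggerTime) < cfg.CooldownPeriod {
		return false // still cooling down
	}
	half := len(st.RecentLosses) / 2
	if half == 0 {
		return false
	}
	baseline := mean(st.RecentLosses[:half])
	recent := mean(st.RecentLosses[half:])
	if baseline <= 0 {
		return false // relative increase undefined (sketch assumption)
	}
	if (recent-baseline)/baseline > cfg.DriftThreshold {
		st.LastTriggerTime = now
		return true
	}
	return false
}

// demoConfig and at are small helpers for the usage example.
func demoConfig() config {
	return config{DriftThreshold: 0.1, MinSampleCount: 4, EvalWindowSize: 4, CooldownPeriod: time.Hour}
}

func at(minutes int) time.Time {
	return time.Unix(0, 0).Add(time.Duration(minutes) * time.Minute)
}

func main() {
	cfg, st := demoConfig(), &state{}
	for _, l := range []float64{1, 1, 1.5, 1.5} {
		record(cfg, st, l)
	}
	fmt.Println(shouldRetrain(cfg, st, at(0)), shouldRetrain(cfg, st, at(1)), shouldRetrain(cfg, st, at(120)))
}
```

With a baseline mean of 1.0 and a recent mean of 1.5, the relative increase is 50%, well above the 10% threshold; the second call is suppressed only by the one-hour cooldown.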

type EWC added in v1.8.0

type EWC struct {
	// contains filtered or unexported fields
}

EWC implements Elastic Weight Consolidation for continual learning. It prevents catastrophic forgetting by penalizing changes to parameters that were important for previously learned tasks. Importance is estimated via the diagonal of the Fisher information matrix.

func NewEWC added in v1.8.0

func NewEWC(weights []float64, fisherSamples int) *EWC

NewEWC creates a new EWC instance with the given model weights as the baseline and the specified number of samples for Fisher estimation. The lambda (regularization strength) defaults to 1.0 and can be set via SetLambda.

func (*EWC) Baseline added in v1.8.0

func (e *EWC) Baseline() []float64

Baseline returns a copy of the current baseline weights.

func (*EWC) ComputeFisher added in v1.8.0

func (e *EWC) ComputeFisher(data [][]float64, lossFn func([]float64) float64) error

ComputeFisher estimates the diagonal of the Fisher information matrix using the provided data and loss function. For each data point, it constructs a per-sample loss (by passing each data point to lossFn as a single-element dataset) and computes the squared gradient via finite differences.

The lossFn receives a weight vector and a single data point, and returns the scalar loss for that sample.
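The finite-difference estimate of the Fisher diagonal can be sketched as below. The sketch uses a hypothetical two-argument per-sample loss, func(w, x []float64) float64, which differs from the one-argument lossFn in the declared signature; how the package binds each data point to lossFn is not shown in the documentation. Central differences, the step size eps, and averaging over samples are also assumptions.

```go
package main

import "fmt"

// sampleLoss is a hypothetical per-sample loss: weights and one data point in, scalar out.
type sampleLoss func(w, x []float64) float64

// fisherDiag averages the squared per-sample gradient, estimated by
// central finite differences, over all data points.
func fisherDiag(w []float64, data [][]float64, loss sampleLoss, eps float64) []float64 {
	f := make([]float64, len(w))
	if len(data) == 0 {
		return f
	}
	probe := append([]float64(nil), w...) // scratch copy so w is never modified
	for _, x := range data {
		for i := range w {
			probe[i] = w[i] + eps
			up := loss(probe, x)
			probe[i] = w[i] - eps
			down := loss(probe, x)
			probe[i] = w[i] // restore before moving to the next coordinate
			g := (up - down) / (2 * eps)
			f[i] += g * g
		}
	}
	for i := range f {
		f[i] /= float64(len(data))
	}
	return f
}

// squaredError treats x as features followed by a target in the last slot
// and returns the squared error of a linear model.
func squaredError(w, x []float64) float64 {
	n := len(w)
	pred := 0.0
	for i := 0; i < n; i++ {
		pred += w[i] * x[i]
	}
	d := pred - x[n]
	return d * d
}

func main() {
	w := []float64{1, 0}
	data := [][]float64{{2, 3, 1}} // features (2, 3), target 1
	fmt.Println(fisherDiag(w, data, squaredError, 1e-4))
}
```

For this data point the analytic gradient is (4, 6), so the estimated diagonal should be close to (16, 36), up to floating-point error.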

func (*EWC) Fisher added in v1.8.0

func (e *EWC) Fisher() []float64

Fisher returns a copy of the current Fisher information diagonal. Returns nil if Fisher has not been computed.

func (*EWC) Lambda added in v1.8.0

func (e *EWC) Lambda() float64

Lambda returns the current regularization strength.

func (*EWC) Loss added in v1.8.0

func (e *EWC) Loss(taskLoss float64, currentWeights []float64) float64

Loss computes a total loss that includes both the task loss and the EWC penalty: totalLoss = taskLoss + Penalty(currentWeights). This is a convenience method for use in training loops.

func (*EWC) Penalty added in v1.8.0

func (e *EWC) Penalty(currentWeights []float64) float64

Penalty computes the EWC penalty term for the given current weights. The penalty is:

	(lambda / 2) * sum_i(fisher[i] * (currentWeights[i] - baseline[i])^2)

Returns 0 if Fisher has not been computed yet.
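The penalty formula translates directly into a loop. This is a standalone sketch rather than the package's code; ewcPenalty is a hypothetical free function, and it assumes fisher, baseline, and current all have the same length.

```go
package main

import "fmt"

// ewcPenalty computes (lambda / 2) * sum_i(fisher[i] * (current[i] - baseline[i])^2).
// A nil fisher slice stands for "Fisher not computed yet" and yields 0.
func ewcPenalty(lambda float64, fisher, baseline, current []float64) float64 {
	if fisher == nil {
		return 0
	}
	var sum float64
	for i := range current {
		d := current[i] - baseline[i]
		sum += fisher[i] * d * d
	}
	return lambda / 2 * sum
}

func main() {
	fisher := []float64{2, 0.5}
	baseline := []float64{1, 1}
	current := []float64{2, 3}

	p := ewcPenalty(1.0, fisher, baseline, current)
	taskLoss := 1.25
	fmt.Println("penalty:", p, "total:", taskLoss+p)
}
```

The first weight moved by 1 with importance 2 and the second moved by 2 with importance 0.5, so each contributes 2 to the sum and the penalty is half of 4. Adding the task loss gives the total described for Loss.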

func (*EWC) SetLambda added in v1.8.0

func (e *EWC) SetLambda(lambda float64)

SetLambda sets the EWC regularization strength. Higher values more strongly penalize deviations from the baseline weights on important parameters.

func (*EWC) UpdateBaseline added in v1.8.0

func (e *EWC) UpdateBaseline(newWeights []float64)

UpdateBaseline updates the reference weights to a new set of weights. This should be called after the model has finished learning a new task, so that future EWC penalties are computed relative to the new optimal weights. The Fisher information is preserved across baseline updates.

type FeedbackCollector

type FeedbackCollector struct {
	// contains filtered or unexported fields
}

FeedbackCollector buffers feedback signals in memory and periodically flushes them to JSONL files on disk.

func NewFeedbackCollector

func NewFeedbackCollector(cfg FeedbackConfig) (*FeedbackCollector, error)

NewFeedbackCollector creates a new FeedbackCollector, creating the StoragePath directory if it does not exist.

func (*FeedbackCollector) Close

func (fc *FeedbackCollector) Close() error

Close releases resources held by the collector.

func (*FeedbackCollector) Flush

func (fc *FeedbackCollector) Flush() ([]FeedbackSignal, error)

Flush writes all buffered signals to a JSONL file on disk, returns them, and clears the buffer.

func (*FeedbackCollector) ReadAll

func (fc *FeedbackCollector) ReadAll() ([]FeedbackSignal, error)

ReadAll reads all feedback signals from JSONL files in the storage directory.

func (*FeedbackCollector) Record

func (fc *FeedbackCollector) Record(signal FeedbackSignal) error

Record appends a signal to the in-memory buffer. If the buffer reaches BufferSize, it is automatically flushed to disk.
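The buffer-then-flush behavior of Record and Flush can be sketched with a mutex-guarded slice. In this illustration a slice of batches stands in for the JSONL files on disk, and all names (signal, collector, record, flush) are hypothetical.

```go
package main

import (
	"fmt"
	"sync"
)

type signal struct {
	PredictionID string
	Score        float64
}

type collector struct {
	mu      sync.Mutex
	size    int        // plays the role of BufferSize
	buf     []signal   // in-memory buffer
	flushed [][]signal // stands in for JSONL files written to disk
}

// record buffers a signal and flushes automatically once the buffer is full.
func (c *collector) record(s signal) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.buf = append(c.buf, s)
	if len(c.buf) >= c.size {
		c.flushLocked()
	}
}

// flush writes out whatever is buffered and returns it.
func (c *collector) flush() []signal {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.flushLocked()
}

// flushLocked must be called with c.mu held.
func (c *collector) flushLocked() []signal {
	if len(c.buf) == 0 {
		return nil
	}
	batch := c.buf
	c.buf = nil
	c.flushed = append(c.flushed, batch)
	return batch
}

func main() {
	c := &collector{size: 2}
	for i := 0; i < 3; i++ {
		c.record(signal{PredictionID: fmt.Sprintf("p%d", i), Score: float64(i)})
	}
	fmt.Println(len(c.flushed), len(c.buf))
	fmt.Println(len(c.flush()), len(c.flushed))
}
```

Splitting flush into a locked and an unlocked variant lets record trigger a flush without re-acquiring the mutex, which would deadlock with a plain sync.Mutex.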

func (*FeedbackCollector) Start

func (fc *FeedbackCollector) Start(ctx context.Context)

Start begins a background goroutine that flushes buffered signals every FlushInterval.

func (*FeedbackCollector) Stop

func (fc *FeedbackCollector) Stop()

Stop stops the background goroutine and performs a final flush.

type FeedbackConfig

type FeedbackConfig struct {
	// BufferSize is the maximum number of signals buffered in memory
	// before an automatic flush to disk.
	BufferSize int

	// FlushInterval is the duration between periodic background flushes.
	FlushInterval time.Duration

	// StoragePath is the directory where JSONL feedback files are written.
	StoragePath string
}

FeedbackConfig holds parameters for the FeedbackCollector.

type FeedbackSignal

type FeedbackSignal struct {
	Timestamp    time.Time `json:"timestamp"`
	PredictionID string    `json:"prediction_id"`
	Predicted    []float32 `json:"predicted"`
	Actual       []float32 `json:"actual"`
	Score        float64   `json:"score"`
}

FeedbackSignal represents a single feedback observation comparing a model's prediction against the actual outcome.

type IncrementalUpdater

type IncrementalUpdater struct {
	// contains filtered or unexported fields
}

IncrementalUpdater applies incremental LoRA updates using SGD. It maintains adapter weights and supports rollback to the pre-update state.

func NewIncrementalUpdater

func NewIncrementalUpdater(cfg LoRAUpdateConfig) *IncrementalUpdater

NewIncrementalUpdater creates a new updater with the given configuration. The adapter dimensions are inferred from the first Update call's sample sizes.

func (*IncrementalUpdater) CommitUpdate

func (u *IncrementalUpdater) CommitUpdate()

CommitUpdate finalizes the current update by discarding the rollback snapshot.

func (*IncrementalUpdater) CurrentLoss

func (u *IncrementalUpdater) CurrentLoss(samples []Sample) float64

CurrentLoss computes the MSE loss over the given samples using the current adapter weights. Returns +Inf if there are no samples or the adapter is not initialized.

func (*IncrementalUpdater) Rollback

func (u *IncrementalUpdater) Rollback() error

Rollback restores the adapter weights to the state before the last Update call. Returns an error if no snapshot is available (e.g. after CommitUpdate or before any Update).

func (*IncrementalUpdater) Update

func (u *IncrementalUpdater) Update(samples []Sample) error

Update applies MaxSteps of SGD on the LoRA adapter parameters using the provided samples. Each step iterates over all samples, computing MSE loss gradients and updating A and B matrices in-place.
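The update and rollback cycle can be sketched for a bare LoRA adapter with output y = (Alpha/Rank) * B * A * x. This is an illustration under assumptions, not the package's code: it uses float64 rather than float32, fixed dimensions instead of inferring them from the first call, per-sample SGD with gradients of the mean squared error, and hypothetical names (adapter, update, rollback).

```go
package main

import (
	"fmt"
	"math"
)

type sample struct{ Input, Label []float64 }

// adapter holds LoRA factors: A is rank x in, B is out x rank.
type adapter struct {
	A, B  [][]float64
	scale float64  // Alpha / Rank
	prev  *adapter // snapshot taken at the start of update
}

func clone(m [][]float64) [][]float64 {
	out := make([][]float64, len(m))
	for i := range m {
		out[i] = append([]float64(nil), m[i]...)
	}
	return out
}

// forward returns h = A x and y = scale * B h.
func (ad *adapter) forward(x []float64) (h, y []float64) {
	h = make([]float64, len(ad.A))
	for k, row := range ad.A {
		for j, a := range row {
			h[k] += a * x[j]
		}
	}
	y = make([]float64, len(ad.B))
	for i, row := range ad.B {
		for k, b := range row {
			y[i] += ad.scale * b * h[k]
		}
	}
	return h, y
}

// loss is the mean squared error over samples; +Inf when there are none.
func (ad *adapter) loss(samples []sample) float64 {
	if len(samples) == 0 {
		return math.Inf(1)
	}
	var total float64
	for _, s := range samples {
		_, y := ad.forward(s.Input)
		for i := range y {
			d := y[i] - s.Label[i]
			total += d * d / float64(len(y))
		}
	}
	return total / float64(len(samples))
}

// update snapshots the weights, then runs steps passes of per-sample SGD.
func (ad *adapter) update(samples []sample, lr float64, steps int) {
	ad.prev = &adapter{A: clone(ad.A), B: clone(ad.B), scale: ad.scale}
	for step := 0; step < steps; step++ {
		for _, s := range samples {
			h, y := ad.forward(s.Input)
			dh := make([]float64, len(h)) // dLoss/dh, built from B before it changes
			for i := range ad.B {
				g := 2 * (y[i] - s.Label[i]) / float64(len(y))
				for k := range ad.B[i] {
					dh[k] += ad.scale * g * ad.B[i][k]
					ad.B[i][k] -= lr * ad.scale * g * h[k]
				}
			}
			for k := range ad.A {
				for j := range ad.A[k] {
					ad.A[k][j] -= lr * dh[k] * s.Input[j]
				}
			}
		}
	}
}

// rollback restores the snapshot; it reports false if none is available.
func (ad *adapter) rollback() bool {
	if ad.prev == nil {
		return false
	}
	ad.A, ad.B, ad.prev = ad.prev.A, ad.prev.B, nil
	return true
}

func main() {
	ad := &adapter{A: [][]float64{{0.1, 0.1}}, B: [][]float64{{0}}, scale: 1}
	data := []sample{{Input: []float64{1, 2}, Label: []float64{1}}}
	before := ad.loss(data)
	ad.update(data, 0.1, 5)
	fmt.Println(before, ad.loss(data), ad.rollback(), ad.loss(data))
}
```

Starting B at zero follows the common LoRA convention, so the adapter initially contributes nothing; the first step therefore moves only B, and A begins to change from the second step on.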

type LoRAUpdateConfig

type LoRAUpdateConfig struct {
	// Rank is the low-rank dimension for LoRA adapter matrices.
	Rank int

	// Alpha is the scaling factor; the LoRA output is multiplied by Alpha/Rank.
	Alpha int

	// LR is the learning rate for SGD updates.
	LR float64

	// MaxSteps is the maximum number of gradient descent steps per Update call.
	MaxSteps int

	// TargetModules lists the module names that LoRA adapters are applied to.
	TargetModules []string
}

LoRAUpdateConfig holds parameters for an incremental LoRA update pass.

type LossDeltaValidator

type LossDeltaValidator struct {
	MaxLossDelta float64
}

LossDeltaValidator rejects updates where the loss increases by more than MaxLossDelta.

func NewLossDeltaValidator

func NewLossDeltaValidator(maxDelta float64) *LossDeltaValidator

NewLossDeltaValidator returns a LossDeltaValidator with the given threshold.

func (*LossDeltaValidator) Validate

func (v *LossDeltaValidator) Validate(before, after ModelSnapshot) ValidationResult

Validate rejects the update if after.Loss - before.Loss > MaxLossDelta.

type ModelSnapshot

type ModelSnapshot struct {
	// Weights maps layer names to their weight tensors as flat float32 slices.
	Weights map[string][]float32
	// Loss is the model loss on a validation set.
	Loss float64
}

ModelSnapshot captures model state for comparison during validation.

type NASProposal

type NASProposal struct {
	// Alert is the drift alert that triggered the NAS search.
	Alert DriftAlert

	// SearchOutput is the full NAS search output.
	SearchOutput *nas.SignalSearchOutput

	// CurrentSharpe is the Sharpe ratio of the current model at trigger time.
	CurrentSharpe float64

	// DiscoveredSharpe is the Sharpe-like metric of the discovered architecture.
	DiscoveredSharpe float64

	// Improvement is the relative improvement (discovered - current) / |current|.
	Improvement float64

	// Accepted indicates whether the proposal passed safety validation.
	Accepted bool

	// RejectionReason is non-empty if the proposal was rejected by a validator.
	RejectionReason string
}

NASProposal represents a proposed architecture replacement discovered by automated NAS after a drift event.
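The Improvement field and the threshold test from AutoNASConfig can be worked through in a few lines. The function names are hypothetical, and the handling of a current Sharpe of exactly zero (where the relative improvement is undefined) is an assumption of this sketch, since the documentation does not describe it.

```go
package main

import (
	"fmt"
	"math"
)

// relativeImprovement computes (discovered - current) / |current|.
func relativeImprovement(current, discovered float64) float64 {
	return (discovered - current) / math.Abs(current)
}

// shouldPropose reports whether the discovered architecture clears the threshold.
func shouldPropose(current, discovered, threshold float64) bool {
	if current == 0 {
		// Assumption: with no baseline to scale by, any positive Sharpe qualifies.
		return discovered > 0
	}
	return relativeImprovement(current, discovered) >= threshold
}

func main() {
	fmt.Println(shouldPropose(1.0, 1.10, 0.05)) // roughly 10% better
	fmt.Println(shouldPropose(1.0, 1.04, 0.05)) // roughly 4% better
	fmt.Println(shouldPropose(-1.0, -0.5, 0.05))
}
```

Dividing by the absolute value keeps the sign meaningful when the current Sharpe is negative: moving from -1.0 to -0.5 counts as a 50% improvement rather than a decline.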

type RollbackConfig

type RollbackConfig struct {
	// MaxVersions is the maximum number of snapshots to retain.
	MaxVersions int
	// StoragePath is the directory where snapshot files are stored.
	StoragePath string
}

RollbackConfig holds parameters for the rollback manager.

type RollbackManager

type RollbackManager struct {
	// contains filtered or unexported fields
}

RollbackManager manages versioned model weight snapshots on disk, supporting snapshot creation, rollback, listing, and pruning.

func NewRollbackManager

func NewRollbackManager(cfg RollbackConfig) (*RollbackManager, error)

NewRollbackManager creates a RollbackManager, creating StoragePath if it does not exist. It scans the directory for existing snapshots so that persistence across restarts is maintained.

func (*RollbackManager) Close

func (m *RollbackManager) Close() error

Close is a no-op that satisfies resource cleanup conventions.

func (*RollbackManager) ListSnapshots

func (m *RollbackManager) ListSnapshots() []string

ListSnapshots returns snapshot IDs sorted by creation time, newest first.

func (*RollbackManager) Prune

func (m *RollbackManager) Prune() error

Prune deletes the oldest snapshots beyond MaxVersions.

func (*RollbackManager) Rollback

func (m *RollbackManager) Rollback(id string) (map[string][]float32, error)

Rollback deserializes and returns the weights for the given snapshot ID.

func (*RollbackManager) Snapshot

func (m *RollbackManager) Snapshot(id string, weights map[string][]float32) error

Snapshot serializes weights to a gob file and evicts the oldest snapshot if the number of snapshots exceeds MaxVersions.
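A gob round trip with oldest-first eviction can be sketched in memory. The store type below is hypothetical: a map of byte slices stands in for the gob files on disk, and insertion order stands in for creation time. Re-using an existing snapshot ID is not handled.

```go
package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
)

type store struct {
	max   int               // plays the role of MaxVersions
	order []string          // snapshot IDs, oldest first
	blobs map[string][]byte // stands in for gob files on disk
}

// snapshot gob-encodes weights and evicts the oldest entries beyond max.
func (s *store) snapshot(id string, weights map[string][]float32) error {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(weights); err != nil {
		return err
	}
	if s.blobs == nil {
		s.blobs = map[string][]byte{}
	}
	s.blobs[id] = buf.Bytes()
	s.order = append(s.order, id)
	for len(s.order) > s.max {
		delete(s.blobs, s.order[0])
		s.order = s.order[1:]
	}
	return nil
}

// rollback decodes and returns the weights stored under id.
func (s *store) rollback(id string) (map[string][]float32, error) {
	blob, ok := s.blobs[id]
	if !ok {
		return nil, fmt.Errorf("snapshot %q not found", id)
	}
	var w map[string][]float32
	err := gob.NewDecoder(bytes.NewReader(blob)).Decode(&w)
	return w, err
}

func main() {
	s := &store{max: 2}
	for i, id := range []string{"v1", "v2", "v3"} {
		if err := s.snapshot(id, map[string][]float32{"layer0": {float32(i)}}); err != nil {
			fmt.Println("snapshot failed:", err)
			return
		}
	}
	_, err := s.rollback("v1")
	w, _ := s.rollback("v3")
	fmt.Println(err, w)
}
```

With max set to 2, the third snapshot evicts "v1", so rolling back to it fails while "v2" and "v3" remain available.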

type Sample

type Sample struct {
	Input []float32
	Label []float32
}

Sample represents a single training example with input features and target labels.

type ScheduledTrigger

type ScheduledTrigger struct {
	Config TriggerConfig
	// Interval is the number of samples between triggers.
	Interval int
	// Now returns the current time. If nil, time.Now is used.
	Now func() time.Time
}

ScheduledTrigger fires every N samples regardless of loss values.

func (*ScheduledTrigger) RecordSample

func (s *ScheduledTrigger) RecordSample(state *TriggerState, loss float64)

RecordSample appends the loss to RecentLosses and increments SampleCount.

func (*ScheduledTrigger) ShouldRetrain

func (s *ScheduledTrigger) ShouldRetrain(state *TriggerState, _ float64) bool

ShouldRetrain returns true every Interval samples, provided the cooldown period has elapsed.
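The interaction between the sample interval and the cooldown can be sketched as follows. The modulo test on SampleCount and the stamping of LastTriggerTime by the trigger itself are assumptions; schedState, scheduledShouldRetrain, and firedAtSeconds are hypothetical names.

```go
package main

import (
	"fmt"
	"time"
)

type schedState struct {
	LastTriggerTime time.Time
	SampleCount     int
}

// scheduledShouldRetrain fires on every interval-th sample once the cooldown has elapsed.
func scheduledShouldRetrain(st *schedState, interval int, cooldown time.Duration, now time.Time) bool {
	if interval <= 0 || st.SampleCount == 0 || st.SampleCount%interval != 0 {
		return false
	}
	if !st.LastTriggerTime.IsZero() && now.Sub(st.LastTriggerTime) < cooldown {
		return false
	}
	st.LastTriggerTime = now
	return true
}

// firedAtSeconds records n samples one second apart and returns the
// sample counts at which the trigger fired.
func firedAtSeconds(n, interval, cooldownSeconds int) []int {
	st := &schedState{}
	cooldown := time.Duration(cooldownSeconds) * time.Second
	var fired []int
	for i := 1; i <= n; i++ {
		st.SampleCount++
		now := time.Unix(int64(i), 0)
		if scheduledShouldRetrain(st, interval, cooldown, now) {
			fired = append(fired, st.SampleCount)
		}
	}
	return fired
}

func main() {
	fmt.Println(firedAtSeconds(10, 3, 0))
	fmt.Println(firedAtSeconds(10, 3, 5))
}
```

With a five-second cooldown, the sixth sample arrives only three seconds after the first trigger and is skipped, so a cooldown longer than the time it takes to collect Interval samples effectively lowers the trigger rate.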

type Trigger

type Trigger interface {
	// ShouldRetrain returns true if conditions are met for retraining
	// given the current state and a new loss observation.
	ShouldRetrain(state *TriggerState, newLoss float64) bool

	// RecordSample records a new loss observation into the trigger state.
	RecordSample(state *TriggerState, loss float64)
}

Trigger decides when a model should be retrained.

type TriggerConfig

type TriggerConfig struct {
	// DriftThreshold is the relative increase in rolling mean loss over
	// baseline that triggers retraining (e.g. 0.1 means 10% increase).
	DriftThreshold float64

	// MinSampleCount is the minimum number of samples that must be
	// recorded before a trigger can fire.
	MinSampleCount int

	// EvalWindowSize is the number of recent losses used to compute the
	// rolling mean for drift detection.
	EvalWindowSize int

	// CooldownPeriod is the minimum duration between consecutive triggers.
	CooldownPeriod time.Duration
}

TriggerConfig holds parameters that control when retraining is triggered.

type TriggerState

type TriggerState struct {
	// LastTriggerTime is the time the trigger last fired.
	LastTriggerTime time.Time

	// SampleCount is the total number of samples recorded.
	SampleCount int

	// RecentLosses holds the most recent loss values, up to
	// EvalWindowSize entries.
	RecentLosses []float64
}

TriggerState tracks the mutable state for a trigger across evaluations.

type ValidationConfig

type ValidationConfig struct {
	// MaxLossDelta is the maximum allowed increase in loss (after - before).
	MaxLossDelta float64
	// MaxWeightNorm is the maximum allowed L2 norm for any single weight tensor.
	MaxWeightNorm float64
	// MaxGradNorm is the maximum allowed gradient norm (reserved for future use).
	MaxGradNorm float64
}

ValidationConfig holds thresholds for the built-in validators.

type ValidationResult

type ValidationResult struct {
	// Pass is true if the validation passed.
	Pass bool
	// Reason describes why validation failed (empty when Pass is true).
	Reason string
}

ValidationResult is the outcome of a safety validation check.

type Validator

type Validator interface {
	// Validate compares the model state before and after an update and
	// returns whether the update should be accepted.
	Validate(before, after ModelSnapshot) ValidationResult
}

Validator checks whether a model update is safe to promote.

type WeightNormValidator

type WeightNormValidator struct {
	MaxWeightNorm float64
}

WeightNormValidator rejects updates where the L2 norm of any weight tensor in the updated model exceeds MaxWeightNorm.

func NewWeightNormValidator

func NewWeightNormValidator(maxNorm float64) *WeightNormValidator

NewWeightNormValidator returns a WeightNormValidator with the given threshold.

func (*WeightNormValidator) Validate

func (v *WeightNormValidator) Validate(_, after ModelSnapshot) ValidationResult

Validate rejects the update if any weight tensor in after has an L2 norm exceeding MaxWeightNorm.
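The per-tensor L2 norm check can be sketched as below. The function names are hypothetical, and sorting the layer names is a choice made here so that the reported offender is deterministic; the documentation does not say which tensor is reported when several exceed the limit.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// l2Norm returns the Euclidean norm of a flat weight tensor.
func l2Norm(w []float32) float64 {
	var sum float64
	for _, v := range w {
		sum += float64(v) * float64(v)
	}
	return math.Sqrt(sum)
}

// checkNorms reports whether every tensor's norm is at most max and,
// if not, the name of the first offending tensor in sorted order.
func checkNorms(weights map[string][]float32, max float64) (ok bool, offender string) {
	names := make([]string, 0, len(weights))
	for name := range weights {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		if l2Norm(weights[name]) > max {
			return false, name
		}
	}
	return true, ""
}

func main() {
	weights := map[string][]float32{
		"layer0": {3, 4},     // norm 5
		"layer1": {0.5, 0.5}, // norm below 1
	}
	fmt.Println(checkNorms(weights, 5))
	fmt.Println(checkNorms(weights, 4.9))
}
```

Accumulating in float64 avoids losing precision when summing many float32 squares, which matters for large tensors.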
