nas

package
v1.5.0
Published: Mar 18, 2026 License: Apache-2.0 Imports: 15 Imported by: 0

Documentation

Overview

Package nas implements Neural Architecture Search for the Zerfoo ML framework.

The search space is defined as a directed acyclic graph (DAG) of cells, where each cell contains nodes connected by edges. Each edge carries an operation type (e.g., convolution, pooling, skip connection). The search space can be sampled randomly or enumerated exhaustively for small spaces.

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func DefaultOpCosts

func DefaultOpCosts() map[OpType]OpCost

DefaultOpCosts returns cost estimates for each operation type. The values model relative cost assuming a spatial dimension of 32x32 with 64 channels.

func DefaultOpParams

func DefaultOpParams() map[OpType]int64

DefaultOpParams returns estimated parameter counts for each operation type, assuming a spatial dimension of 32x32 with 64 input/output channels.

func ExportGGUF

func ExportGGUF(w io.Writer, arch *DiscretizedArch, cfg ExportConfig, weights map[string][]float32, shapes map[string][]int) error

ExportGGUF writes a NAS-discovered architecture and its trained weights to a GGUF v3 file. The architecture topology is encoded in GGUF metadata under nas.* keys, and model hyperparameters are stored under ts.signal.* keys for compatibility with the standard time-series inference path.

The weights map keys are tensor names (e.g., "blk.0.attn_q.weight") and values are flat float32 slices. Each tensor's shape is provided in the shapes map with the same keys. Shapes use row-major (PyTorch) convention; they are reversed to GGML order on write.
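The shape convention swap can be sketched as a simple slice reversal (a minimal illustration of the row-major-to-GGML reordering described above, not the package's exact code):

```go
package main

import "fmt"

// toGGMLOrder reverses a row-major (PyTorch-convention) shape into
// GGML's fastest-dimension-first order, as done on write.
func toGGMLOrder(shape []int) []int {
	out := make([]int, len(shape))
	for i, d := range shape {
		out[len(shape)-1-i] = d
	}
	return out
}

func main() {
	// A [64, 128] row-major weight becomes [128, 64] in GGML order.
	fmt.Println(toGGMLOrder([]int{64, 128}))
}
```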

func LoadNASArchFromGGUF

func LoadNASArchFromGGUF(f *gguf.File) (*DiscretizedArch, ExportConfig, error)

LoadNASArchFromGGUF reads a NAS-exported GGUF file and reconstructs the DiscretizedArch and ExportConfig from its metadata. This enables round-trip verification: export then load back and confirm the architecture matches.

func SharpeRatio

func SharpeRatio(returns []float64) float64

SharpeRatio computes the Sharpe ratio from a series of returns. Returns 0 if there are fewer than 2 values or if the standard deviation is 0.
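The documented behavior can be sketched as mean return over standard deviation, with the two degenerate cases returning 0. Whether the package uses the sample (n-1) or population standard deviation, or applies annualization, is an assumption here:

```go
package main

import (
	"fmt"
	"math"
)

// sharpeRatio sketches the documented contract: mean return divided by
// the standard deviation, 0 for fewer than 2 values or zero deviation.
// Sample (n-1) standard deviation is assumed.
func sharpeRatio(returns []float64) float64 {
	if len(returns) < 2 {
		return 0
	}
	mean := 0.0
	for _, r := range returns {
		mean += r
	}
	mean /= float64(len(returns))
	varSum := 0.0
	for _, r := range returns {
		varSum += (r - mean) * (r - mean)
	}
	std := math.Sqrt(varSum / float64(len(returns)-1))
	if std == 0 {
		return 0
	}
	return mean / std
}

func main() {
	fmt.Printf("%.3f\n", sharpeRatio([]float64{0.01, 0.02, -0.005, 0.015}))
}
```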

func ValidateExportRoundTrip

func ValidateExportRoundTrip(arch *DiscretizedArch, cfg ExportConfig, weights map[string][]float32, shapes map[string][]int) error

ValidateExportRoundTrip exports a NAS architecture to GGUF, parses it back, and verifies the architecture matches. Returns an error if any field differs.

Types

type CalibrationPoint

type CalibrationPoint struct {
	Cell    Cell
	Latency float64 // measured latency in seconds
}

CalibrationPoint pairs a cell architecture with its measured latency.

type Cell

type Cell struct {
	NumNodes int
	Edges    []Edge
}

Cell represents a single architecture cell as a DAG of edges between nodes.

func (Cell) Valid

func (c Cell) Valid() bool

Valid reports whether the cell is a valid DAG: every edge must have From < To (which guarantees no cycles), and node indices must be in [0, NumNodes).
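The validity rule can be sketched with local stand-in types (the real Cell and Edge are defined in this package): because nodes are numbered in topological order, requiring From < To on every edge makes back edges, and therefore cycles, impossible.

```go
package main

import "fmt"

// Local stand-ins for the package's Edge and Cell types.
type Edge struct {
	From, To int
}

type Cell struct {
	NumNodes int
	Edges    []Edge
}

// valid mirrors the documented rule: indices in [0, NumNodes) and
// From < To on every edge.
func valid(c Cell) bool {
	for _, e := range c.Edges {
		if e.From < 0 || e.To >= c.NumNodes || e.From >= e.To {
			return false
		}
	}
	return true
}

func main() {
	ok := Cell{NumNodes: 3, Edges: []Edge{{0, 1}, {1, 2}, {0, 2}}}
	bad := Cell{NumNodes: 3, Edges: []Edge{{2, 1}}} // From > To: a back edge
	fmt.Println(valid(ok), valid(bad))              // true false
}
```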

type DARTSLayer

type DARTSLayer[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

DARTSLayer implements a DARTS (Differentiable Architecture Search) mixed-operation layer. It computes a softmax-weighted mixture of candidate operations, where the architecture parameters (alpha) are learnable and the forward pass is differentiable through the softmax weights.

func NewDARTSLayer

func NewDARTSLayer[T tensor.Numeric](engine compute.Engine[T], ops numeric.Arithmetic[T], candidates []graph.Node[T]) (*DARTSLayer[T], error)

NewDARTSLayer creates a new DARTS mixed-operation layer with the given candidate operations. The alpha architecture parameters are initialized to zero, giving uniform softmax weights. At least 2 candidates are required.

func (*DARTSLayer[T]) Attributes

func (d *DARTSLayer[T]) Attributes() map[string]interface{}

Attributes returns the layer's attributes.

func (*DARTSLayer[T]) Backward

func (d *DARTSLayer[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes gradients for both the input and the alpha architecture parameters.

Given output = sum_i w_i * op_i(x) where w = softmax(alpha):

  • dInput = sum_i w_i * op_i.Backward(dOut)
  • dAlpha_k = sum_j dOut_j * sum_i op_i(x)_j * w_i * (delta_{ik} - w_k), which simplifies to dAlpha_k = w_k * dot(dOut, op_k(x) - output)

func (*DARTSLayer[T]) Forward

func (d *DARTSLayer[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes the softmax-weighted mixture of all candidate operations. output = sum_i softmax(alpha)_i * op_i(input)

func (*DARTSLayer[T]) OpType

func (d *DARTSLayer[T]) OpType() string

OpType returns the operation type identifier.

func (*DARTSLayer[T]) OutputShape

func (d *DARTSLayer[T]) OutputShape() []int

OutputShape returns the output shape, which matches the first candidate's output shape.

func (*DARTSLayer[T]) Parameters

func (d *DARTSLayer[T]) Parameters() []*graph.Parameter[T]

Parameters returns the learnable architecture parameters (alpha).

func (*DARTSLayer[T]) Weights

func (d *DARTSLayer[T]) Weights() []T

Weights returns the current softmax weights over candidate operations.

type DARTSOptimizer

type DARTSOptimizer[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

DARTSOptimizer implements bilevel optimization for DARTS (Liu et al. 2019). Each step alternates between:

  • Inner loop: update network weights w using training loss gradient.
  • Outer loop: update architecture parameters alpha using validation loss gradient.

func NewDARTSOptimizer

func NewDARTSOptimizer[T tensor.Numeric](engine compute.Engine[T], ops numeric.Arithmetic[T], layer *DARTSLayer[T], cfg DARTSOptimizerConfig[T]) (*DARTSOptimizer[T], error)

NewDARTSOptimizer creates a new DARTS bilevel optimizer.

func (*DARTSOptimizer[T]) Step

func (d *DARTSOptimizer[T]) Step(
	ctx context.Context,
	trainInput, trainTarget *tensor.TensorNumeric[T],
	valInput, valTarget *tensor.TensorNumeric[T],
) error

Step performs one bilevel optimization step:

  1. Inner update: forward on trainInput, compute training loss, backprop, update network weights w.
  2. Outer update: forward on valInput, compute validation loss, backprop, update architecture alpha.

type DARTSOptimizerConfig

type DARTSOptimizerConfig[T tensor.Numeric] struct {
	// WeightLR is the learning rate for network weight updates (inner loop).
	WeightLR T
	// AlphaLR is the learning rate for architecture parameter updates (outer loop).
	AlphaLR T
}

DARTSOptimizerConfig holds configuration for the DARTS bilevel optimizer.

type DiscretizedArch

type DiscretizedArch struct {
	Cell        Cell
	TotalParams int64
}

DiscretizedArch represents a concrete architecture obtained by selecting the argmax operation for each edge in the cell DAG.

func Discretize

func Discretize[T tensor.Numeric](alpha []T, space *SearchSpace, maxParams int64) (*DiscretizedArch, error)

Discretize converts continuous DARTS architecture weights (alpha) into a concrete cell architecture by selecting the argmax operation per edge. It validates the resulting architecture against a maximum parameter budget.

The alpha slice must have length numEdges * numOps, laid out as [edge0_op0, edge0_op1, ..., edge0_opN, edge1_op0, ...]. This matches the alpha parameter shape from DARTSLayer when the search space has numOps candidates.
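The per-edge argmax over that layout can be sketched as follows (an illustration of the indexing scheme, not the package's Discretize implementation, which also validates the parameter budget):

```go
package main

import "fmt"

// discretizeOps picks, for each edge, the op index with the largest
// alpha, given alpha laid out as
// [edge0_op0, ..., edge0_opN, edge1_op0, ...].
func discretizeOps(alpha []float64, numEdges, numOps int) []int {
	sel := make([]int, numEdges)
	for e := 0; e < numEdges; e++ {
		best := 0
		for o := 1; o < numOps; o++ {
			if alpha[e*numOps+o] > alpha[e*numOps+best] {
				best = o
			}
		}
		sel[e] = best
	}
	return sel
}

func main() {
	// 2 edges x 3 ops: edge 0 favors op 2, edge 1 favors op 0.
	alpha := []float64{0.1, 0.3, 0.9, 1.2, 0.0, -0.5}
	fmt.Println(discretizeOps(alpha, 2, 3)) // [2 0]
}
```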

type Edge

type Edge struct {
	From int
	To   int
	Op   OpType
}

Edge represents a directed edge in a cell DAG, connecting node From to node To with operation Op. Edges must satisfy From < To to ensure acyclicity.

type ExportConfig

type ExportConfig struct {
	// ModelName is the human-readable model name stored in general.name.
	ModelName string
	// HiddenDim is the hidden dimension of the discovered architecture.
	HiddenDim int
	// NumLayers is the number of stacked cells in the architecture.
	NumLayers int
	// InputFeatures is the number of input features (for time-series models).
	InputFeatures int
	// PatchLen is the patch length for patch-based architectures.
	PatchLen int
	// HorizonLen is the forecast horizon length.
	HorizonLen int
}

ExportConfig holds configuration for exporting a NAS-discovered architecture to GGUF format.

type HWProfile

type HWProfile struct {
	// Name is a human-readable identifier for the device.
	Name string
	// FLOPSThroughput is the peak FP32 throughput in GFLOPS.
	FLOPSThroughput float64
	// MemBandwidthGBs is the memory bandwidth in GB/s.
	MemBandwidthGBs float64
}

HWProfile describes the hardware capabilities of a target device.

func DGXSpark

func DGXSpark() HWProfile

DGXSpark returns the hardware profile for the NVIDIA DGX Spark (GB10 GPU).

type LatencyEstimator

type LatencyEstimator struct {
	// contains filtered or unexported fields
}

LatencyEstimator predicts inference latency for cell architectures using a linear cost model calibrated against measured hardware benchmarks.
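Before calibration, a first-order version of such a cost model is compute time plus memory-traffic time, summed over the cell's edge operations. The sketch below uses local copies of OpCost and HWProfile and unit coefficients; the real estimator fits its coefficients from calibration data, so treat this as an illustration of the model's shape only.

```go
package main

import "fmt"

// Local copies of the package's cost and hardware profile types.
type OpCost struct {
	FLOPs    float64 // floating-point operations per instance
	MemBytes float64 // bytes of memory traffic per instance
}

type HWProfile struct {
	FLOPSThroughput float64 // peak FP32 throughput in GFLOPS
	MemBandwidthGBs float64 // memory bandwidth in GB/s
}

// estimate sums compute time and memory time over all edge ops,
// i.e. the uncalibrated linear model with alpha = beta = 1, bias = 0.
func estimate(hw HWProfile, edgeOps []OpCost) float64 {
	t := 0.0
	for _, c := range edgeOps {
		t += c.FLOPs/(hw.FLOPSThroughput*1e9) + c.MemBytes/(hw.MemBandwidthGBs*1e9)
	}
	return t
}

func main() {
	hw := HWProfile{FLOPSThroughput: 1000, MemBandwidthGBs: 500} // hypothetical device
	ops := []OpCost{{FLOPs: 2e9, MemBytes: 1e8}, {FLOPs: 0, MemBytes: 5e7}}
	fmt.Printf("%.4f s\n", estimate(hw, ops))
}
```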

func NewLatencyEstimator

func NewLatencyEstimator(hw HWProfile) *LatencyEstimator

NewLatencyEstimator creates an estimator for the given hardware profile using default operation costs.

func (*LatencyEstimator) Calibrate

func (e *LatencyEstimator) Calibrate(data []CalibrationPoint)

Calibrate fits the linear model coefficients (alpha, beta, bias) using ordinary least squares on the provided calibration data.

func (*LatencyEstimator) Estimate

func (e *LatencyEstimator) Estimate(c Cell) float64

Estimate returns the predicted inference latency in seconds for a cell.

func (*LatencyEstimator) LatencyEstimate

func (e *LatencyEstimator) LatencyEstimate(c Cell) float64

LatencyEstimate predicts inference latency for a cell architecture using the calibrated model. This is an alias for Estimate for API convenience.

func (*LatencyEstimator) RSquared

func (e *LatencyEstimator) RSquared(data []CalibrationPoint) float64

RSquared computes the coefficient of determination (R^2) of the estimator on the given data points.
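The R^2 formula itself is standard and can be sketched independently of the estimator, with predictions supplied directly: R^2 = 1 - SS_res/SS_tot.

```go
package main

import "fmt"

// rSquared computes the coefficient of determination,
// 1 - SS_res/SS_tot, for measured values and model predictions.
func rSquared(measured, predicted []float64) float64 {
	mean := 0.0
	for _, y := range measured {
		mean += y
	}
	mean /= float64(len(measured))
	ssRes, ssTot := 0.0, 0.0
	for i, y := range measured {
		ssRes += (y - predicted[i]) * (y - predicted[i])
		ssTot += (y - mean) * (y - mean)
	}
	return 1 - ssRes/ssTot
}

func main() {
	y := []float64{1, 2, 3, 4}
	fmt.Println(rSquared(y, y)) // perfect fit -> 1
}
```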

type OpCost

type OpCost struct {
	FLOPs    float64 // floating-point operations per instance
	MemBytes float64 // bytes of memory traffic per instance
}

OpCost defines the per-instance cost of an operation type in terms of compute (FLOPs) and memory transfers (bytes read + written).

type OpType

type OpType string

OpType represents an operation type that can be placed on a cell edge.

const (
	OpConv3x3     OpType = "conv_3x3"
	OpConv5x5     OpType = "conv_5x5"
	OpSepConv3x3  OpType = "sep_conv_3x3"
	OpSepConv5x5  OpType = "sep_conv_5x5"
	OpAvgPool3x3  OpType = "avg_pool_3x3"
	OpMaxPool3x3  OpType = "max_pool_3x3"
	OpSkipConnect OpType = "skip_connect"
	OpZero        OpType = "zero"
)

func AllOps

func AllOps() []OpType

AllOps returns the default set of all 8 operation types.

type SearchSpace

type SearchSpace struct {
	NumNodes int
	Ops      []OpType
}

SearchSpace defines the space of possible cell architectures. It is parameterized by the number of nodes and the set of candidate operations.

func DefaultSignalSearchSpace

func DefaultSignalSearchSpace() *SearchSpace

DefaultSignalSearchSpace returns the default DARTS search space for PatchTST-like signal models: 4 nodes with pooling, skip, and zero ops.

func NewSearchSpace

func NewSearchSpace(numNodes int) *SearchSpace

NewSearchSpace creates a search space with the given number of nodes and all 8 default operation types.

func NewSearchSpaceWithOps

func NewSearchSpaceWithOps(numNodes int, ops []OpType) *SearchSpace

NewSearchSpaceWithOps creates a search space with the given number of nodes and a custom set of operation types.

func (*SearchSpace) Enumerate

func (s *SearchSpace) Enumerate(maxCells int) []Cell

Enumerate returns up to maxCells cell architectures by exhaustive enumeration. The cells are generated in lexicographic order over the operation assignments to edges.

func (*SearchSpace) NumCells

func (s *SearchSpace) NumCells() int64

NumCells returns the total number of possible cell architectures: len(Ops)^numEdges where numEdges = numNodes*(numNodes-1)/2.
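The count is straightforward to reproduce; for the default 4-node, 8-op space it is 8^6:

```go
package main

import "fmt"

// numCells computes len(ops)^numEdges with numEdges = n*(n-1)/2,
// mirroring the documented formula.
func numCells(numNodes, numOps int) int64 {
	numEdges := numNodes * (numNodes - 1) / 2
	total := int64(1)
	for i := 0; i < numEdges; i++ {
		total *= int64(numOps)
	}
	return total
}

func main() {
	// 4 nodes -> 6 edges; with all 8 ops: 8^6 = 262144 candidate cells.
	fmt.Println(numCells(4, 8))
}
```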

func (*SearchSpace) Sample

func (s *SearchSpace) Sample(rng *rand.Rand) Cell

Sample randomly samples a valid cell architecture from the search space.

type SignalDataProvider

type SignalDataProvider interface {
	// TrainBatch returns a (input, target) pair for the training split.
	TrainBatch() (input, target []float32, shape []int, err error)
	// ValBatch returns a (input, target) pair for the validation split.
	ValBatch() (input, target []float32, shape []int, err error)
}

SignalDataProvider supplies training and validation data for the NAS search. Implementations can load from disk or generate synthetic data for testing.
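A minimal synthetic implementation might look like the following; the interface is re-declared locally so the sketch is self-contained, and the zero-filled batches and shape layout are assumptions for illustration only.

```go
package main

import "fmt"

// SignalDataProvider re-declares the documented interface so this
// example stands alone.
type SignalDataProvider interface {
	TrainBatch() (input, target []float32, shape []int, err error)
	ValBatch() (input, target []float32, shape []int, err error)
}

// syntheticProvider returns fixed zero-filled batches; a real
// implementation would load windows of market or sensor data.
type syntheticProvider struct {
	features, patchLen, horizon int
}

func (p syntheticProvider) TrainBatch() ([]float32, []float32, []int, error) {
	in := make([]float32, p.features*p.patchLen)
	out := make([]float32, p.horizon)
	return in, out, []int{1, p.patchLen, p.features}, nil
}

func (p syntheticProvider) ValBatch() ([]float32, []float32, []int, error) {
	return p.TrainBatch()
}

func main() {
	var dp SignalDataProvider = syntheticProvider{features: 3, patchLen: 16, horizon: 4}
	in, tgt, shape, _ := dp.TrainBatch()
	fmt.Println(len(in), len(tgt), shape) // 48 4 [1 16 3]
}
```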

type SignalSearchConfig

type SignalSearchConfig struct {
	// NumTrials is the number of DARTS search trials to run. Each trial
	// randomly initializes the architecture parameters and runs bilevel
	// optimization for SearchSteps steps.
	NumTrials int
	// SearchSteps is the number of bilevel optimization steps per trial.
	SearchSteps int
	// WeightLR is the inner-loop learning rate for network weights.
	WeightLR float64
	// AlphaLR is the outer-loop learning rate for architecture parameters.
	AlphaLR float64
	// MaxParams is the maximum parameter budget for the discretized architecture.
	// Zero means no limit.
	MaxParams int64

	// PatchTST-like architecture dimensions.
	// InputFeatures is the number of input features per time step.
	InputFeatures int
	// PatchLen is the number of time steps per patch.
	PatchLen int
	// HorizonLen is the forecast horizon length.
	HorizonLen int
	// HiddenDim is the hidden dimension of the model.
	HiddenDim int
	// NumLayers is the number of stacked cells.
	NumLayers int

	// SearchSpace defines the DARTS search space. If nil, a default space
	// suitable for PatchTST-like architectures is used.
	SearchSpace *SearchSpace

	// Seed for reproducibility. Zero means non-deterministic.
	Seed uint64
}

SignalSearchConfig holds configuration for a NAS search over time-series signal model architectures.

type SignalSearchOutput

type SignalSearchOutput struct {
	// Best is the result with the lowest validation loss across all trials.
	Best SignalSearchResult
	// AllResults contains results from every trial.
	AllResults []SignalSearchResult
	// ExportConfig is the GGUF export configuration derived from the search config.
	ExportConfig ExportConfig
}

SignalSearchOutput holds the complete output of RunSignalNAS.

func RunSignalNAS

RunSignalNAS runs the full NAS search pipeline for time-series signal models. It performs multiple DARTS trials, discretizes the best architecture, and returns a result ready for GGUF export.

type SignalSearchResult

type SignalSearchResult struct {
	// Trial is the 0-based trial index.
	Trial int
	// Arch is the discretized architecture discovered in this trial.
	Arch *DiscretizedArch
	// Metric is the evaluation metric for the trial: lower is better when
	// it is a validation loss, higher is better when it is a Sharpe ratio.
	Metric float64
	// FinalLoss is the validation loss at the end of the search.
	FinalLoss float64
}

SignalSearchResult holds the result of a NAS search trial.
