Documentation ¶
Overview ¶
Package ssm implements state space model layers.
Stability: alpha
Index ¶
- type BCNorm
- func (bn *BCNorm[T]) Backward(_ context.Context, dOut *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (bn *BCNorm[T]) Forward(_ context.Context, input *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (bn *BCNorm[T]) Name() string
- func (bn *BCNorm[T]) Parameters() []*graph.Parameter[T]
- type ComplexSSMState
- func (c *ComplexSSMState[T]) Attributes() map[string]interface{}
- func (c *ComplexSSMState[T]) Backward(ctx context.Context, mode types.BackwardMode, ...) ([]*tensor.TensorNumeric[T], error)
- func (c *ComplexSSMState[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (c *ComplexSSMState[T]) Name() string
- func (c *ComplexSSMState[T]) OpType() string
- func (c *ComplexSSMState[T]) OutputShape() []int
- func (c *ComplexSSMState[T]) Parameters() []*graph.Parameter[T]
- func (c *ComplexSSMState[T]) SetName(n string)
- type ComplexSSMStateOption
- type DiscretizationMode
- type Linear
- func (l *Linear[T]) Backward(_ context.Context, _ types.BackwardMode, dOut *tensor.TensorNumeric[T], ...) ([]*tensor.TensorNumeric[T], error)
- func (l *Linear[T]) Forward(_ context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (l *Linear[T]) Parameters() []*graph.Parameter[T]
- type MIMOMambaBlock
- func (m *MIMOMambaBlock[T]) Attributes() map[string]interface{}
- func (m *MIMOMambaBlock[T]) Backward(ctx context.Context, mode types.BackwardMode, ...) ([]*tensor.TensorNumeric[T], error)
- func (m *MIMOMambaBlock[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (m *MIMOMambaBlock[T]) Name() string
- func (m *MIMOMambaBlock[T]) OpType() string
- func (m *MIMOMambaBlock[T]) OutputShape() []int
- func (m *MIMOMambaBlock[T]) Parameters() []*graph.Parameter[T]
- func (m *MIMOMambaBlock[T]) SetName(n string)
- type MIMOMambaBlockOption
- type MambaBlock
- func (m *MambaBlock[T]) Attributes() map[string]interface{}
- func (m *MambaBlock[T]) Backward(ctx context.Context, mode types.BackwardMode, ...) ([]*tensor.TensorNumeric[T], error)
- func (m *MambaBlock[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (m *MambaBlock[T]) Name() string
- func (m *MambaBlock[T]) OpType() string
- func (m *MambaBlock[T]) OutputShape() []int
- func (m *MambaBlock[T]) Parameters() []*graph.Parameter[T]
- func (m *MambaBlock[T]) SetName(n string)
- type MambaBlockOption
- type S4
- func (s *S4[T]) Attributes() map[string]interface{}
- func (s *S4[T]) Backward(_ context.Context, _ types.BackwardMode, ...) ([]*tensor.TensorNumeric[T], error)
- func (s *S4[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (s *S4[T]) OpType() string
- func (s *S4[T]) OutputShape() []int
- func (s *S4[T]) Parameters() []*graph.Parameter[T]
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
This section is empty.
Types ¶
type BCNorm ¶ added in v1.7.0
BCNorm implements L2 normalization with a learnable gain for the B and C matrices of a state space model. It stabilizes the SSM recurrence by preventing the B/C values from growing unbounded, which is especially important when complex-valued RoPE rotations are applied.
For an input x of shape [..., dim]:
norm = sqrt(sum(x^2, dim=-1) / dim + eps)
out  = gain * x / norm
This is similar to RMSNorm but applied specifically to the SSM projection outputs before they enter the selective scan.
func NewBCNorm ¶ added in v1.7.0
func NewBCNorm[T tensor.Numeric](name string, engine compute.Engine[T], ops numeric.Arithmetic[T], dim int) (*BCNorm[T], error)
NewBCNorm creates a new BCNorm layer.
func (*BCNorm[T]) Backward ¶ added in v1.7.0
func (bn *BCNorm[T]) Backward(_ context.Context, dOut *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Backward computes gradients for BCNorm.
func (*BCNorm[T]) Forward ¶ added in v1.7.0
func (bn *BCNorm[T]) Forward(_ context.Context, input *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward applies BCNorm: gain * x / rms(x).
func (*BCNorm[T]) Parameters ¶ added in v1.7.0
func (bn *BCNorm[T]) Parameters() []*graph.Parameter[T]
Parameters returns the trainable parameters (gain).
type ComplexSSMState ¶ added in v1.7.0
type ComplexSSMState[T tensor.Numeric] struct {
	// SSM parameters
	A *graph.Parameter[T] // [d_inner, d_state]
	D *graph.Parameter[T] // [d_inner]
	// contains filtered or unexported fields
}
ComplexSSMState implements complex-valued SSM state tracking using RoPE embeddings on the B and C matrices. The hidden state is split into pairs of dimensions that are treated as (real, imaginary) components. RoPE rotates each pair by a position-dependent angle, encoding temporal structure into the state without doubling memory.
This follows the Mamba 3 design where B and C are rotated by RoPE before the selective scan, giving the recurrence complex-valued dynamics.
Input shape: [batch, seq_len, d_model]
Output shape: [batch, seq_len, d_model]
func NewComplexSSMState ¶ added in v1.7.0
func NewComplexSSMState[T tensor.Numeric](
	name string,
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	dModel, dInner, dState, dtRank, convKer int,
	maxSeqLen int,
	opts ...ComplexSSMStateOption[T],
) (*ComplexSSMState[T], error)
NewComplexSSMState creates a new ComplexSSMState block.
dState must be even since dimensions are paired as (real, imaginary) for complex-valued RoPE rotation.
func (*ComplexSSMState[T]) Attributes ¶ added in v1.7.0
func (c *ComplexSSMState[T]) Attributes() map[string]interface{}
func (*ComplexSSMState[T]) Backward ¶ added in v1.7.0
func (c *ComplexSSMState[T]) Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes gradients for the ComplexSSMState block.
func (*ComplexSSMState[T]) Forward ¶ added in v1.7.0
func (c *ComplexSSMState[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the complex-valued SSM forward pass.
Input: [batch, seq_len, d_model]
Output: [batch, seq_len, d_model]
func (*ComplexSSMState[T]) Name ¶ added in v1.7.0
func (c *ComplexSSMState[T]) Name() string
func (*ComplexSSMState[T]) OpType ¶ added in v1.7.0
func (c *ComplexSSMState[T]) OpType() string
func (*ComplexSSMState[T]) OutputShape ¶ added in v1.7.0
func (c *ComplexSSMState[T]) OutputShape() []int
func (*ComplexSSMState[T]) Parameters ¶ added in v1.7.0
func (c *ComplexSSMState[T]) Parameters() []*graph.Parameter[T]
Parameters returns all trainable parameters.
func (*ComplexSSMState[T]) SetName ¶ added in v1.7.0
func (c *ComplexSSMState[T]) SetName(n string)
type ComplexSSMStateOption ¶ added in v1.7.0
type ComplexSSMStateOption[T tensor.Numeric] func(*ComplexSSMState[T])
ComplexSSMStateOption is a functional option for ComplexSSMState.
func WithComplexDiscretizationMode ¶ added in v1.7.0
func WithComplexDiscretizationMode[T tensor.Numeric](mode DiscretizationMode) ComplexSSMStateOption[T]
WithComplexDiscretizationMode sets the discretization mode.
type DiscretizationMode ¶ added in v1.7.0
type DiscretizationMode int
DiscretizationMode controls how the continuous SSM (A, B) is discretized.
const (
	// ZOH uses zero-order hold discretization (Mamba 1/2 default):
	//   Ā = exp(Δ * A)
	//   B̄ = Δ * B
	ZOH DiscretizationMode = iota

	// ExpTrap uses exponential-trapezoidal discretization (Mamba 3):
	//   Ā = exp(Δ * A)
	//   B̄ = Δ * (I + exp(Δ * A)) / 2 * B
	//
	// This gives richer system dynamics by taking a trapezoidal average of the
	// continuous-time B at both endpoints of the discretization interval.
	ExpTrap
)
type Linear ¶ added in v1.7.0
Linear is a simple linear projection layer used by ComplexSSMState.
func NewLinear ¶ added in v1.7.0
func NewLinear[T tensor.Numeric](name string, engine compute.Engine[T], ops numeric.Arithmetic[T], inDim, outDim int) (*Linear[T], error)
NewLinear creates a simple linear projection.
func (*Linear[T]) Backward ¶ added in v1.7.0
func (l *Linear[T]) Backward(_ context.Context, _ types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes gradients for the linear layer.
func (*Linear[T]) Forward ¶ added in v1.7.0
func (l *Linear[T]) Forward(_ context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes y = x @ W for the linear projection.
func (*Linear[T]) Parameters ¶ added in v1.7.0
func (l *Linear[T]) Parameters() []*graph.Parameter[T]
Parameters returns the trainable parameters.
type MIMOMambaBlock ¶ added in v1.7.0
type MIMOMambaBlock[T tensor.Numeric] struct {
	// Per-head SSM parameters
	A []*graph.Parameter[T] // each [headDim, d_state]
	D []*graph.Parameter[T] // each [headDim]
	// contains filtered or unexported fields
}
MIMOMambaBlock implements a multi-input multi-output SSM block with multiple parallel state spaces (heads) and cross-channel mixing.
Architecture:
- Input projection: d_model -> 2*d_inner (x and z branches)
- Depthwise causal Conv1D on x branch
- SiLU activation on x
- SSM parameter projection: x -> (dt, B, C) per head
- Multi-head selective scan: each head processes d_inner/num_heads channels with its own A, D parameters and independent state space
- Cross-head mixing: linear projection across head outputs
- Gate: mixed_y * SiLU(z)
- Output projection: d_inner -> d_model
The multi-head design allows different heads to specialize on different temporal patterns, similar to multi-head attention but for SSM recurrence.
Input shape: [batch, seq_len, d_model]
Output shape: [batch, seq_len, d_model]
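The cross-head mixing step can be sketched with plain slices. The [numHeads][numHeads] mixing-matrix shape here is an assumption about the projection's structure, not the package's exact parameterization:

```go
package main

import "fmt"

// mixHeads replaces each head's output vector (headDim values) with a
// linear combination of all heads' outputs at the same channel offset,
// letting specialized heads exchange information before the gate.
func mixHeads(y, w [][]float64) [][]float64 {
	numHeads, headDim := len(y), len(y[0])
	out := make([][]float64, numHeads)
	for h := range out {
		out[h] = make([]float64, headDim)
		for d := 0; d < headDim; d++ {
			for k := 0; k < numHeads; k++ {
				out[h][d] += w[h][k] * y[k][d]
			}
		}
	}
	return out
}

func main() {
	y := [][]float64{{1, 2}, {3, 4}} // two heads, headDim = 2
	w := [][]float64{{1, 0}, {1, 1}} // head 1 also receives head 0's output
	fmt.Println(mixHeads(y, w))      // [[1 2] [4 6]]
}
```

Without this step each head's state space would be fully independent; the mixing is what makes the block multi-input multi-output rather than just a batched single-head scan.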
func NewMIMOMambaBlock ¶ added in v1.7.0
func NewMIMOMambaBlock[T tensor.Numeric](
	name string,
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	dModel, dInner, dState, dtRank, convKer, numHeads int,
	opts ...MIMOMambaBlockOption[T],
) (*MIMOMambaBlock[T], error)
NewMIMOMambaBlock creates a new multi-head MIMO SSM block.
Parameters:
- dModel: input/output dimension
- dInner: inner SSM dimension (must be divisible by numHeads)
- dState: SSM state dimension per head
- dtRank: rank of dt projection
- convKer: depthwise conv1d kernel size
- numHeads: number of parallel SSM heads
func (*MIMOMambaBlock[T]) Attributes ¶ added in v1.7.0
func (m *MIMOMambaBlock[T]) Attributes() map[string]interface{}
func (*MIMOMambaBlock[T]) Backward ¶ added in v1.7.0
func (m *MIMOMambaBlock[T]) Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes gradients for the MIMO Mamba block.
func (*MIMOMambaBlock[T]) Forward ¶ added in v1.7.0
func (m *MIMOMambaBlock[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the MIMO Mamba block forward pass.
Input: [batch, seq_len, d_model]
Output: [batch, seq_len, d_model]
func (*MIMOMambaBlock[T]) Name ¶ added in v1.7.0
func (m *MIMOMambaBlock[T]) Name() string
func (*MIMOMambaBlock[T]) OpType ¶ added in v1.7.0
func (m *MIMOMambaBlock[T]) OpType() string
func (*MIMOMambaBlock[T]) OutputShape ¶ added in v1.7.0
func (m *MIMOMambaBlock[T]) OutputShape() []int
func (*MIMOMambaBlock[T]) Parameters ¶ added in v1.7.0
func (m *MIMOMambaBlock[T]) Parameters() []*graph.Parameter[T]
Parameters returns all trainable parameters.
func (*MIMOMambaBlock[T]) SetName ¶ added in v1.7.0
func (m *MIMOMambaBlock[T]) SetName(n string)
type MIMOMambaBlockOption ¶ added in v1.7.0
type MIMOMambaBlockOption[T tensor.Numeric] func(*MIMOMambaBlock[T])
MIMOMambaBlockOption is a functional option for MIMOMambaBlock.
func WithMIMODiscretizationMode ¶ added in v1.7.0
func WithMIMODiscretizationMode[T tensor.Numeric](mode DiscretizationMode) MIMOMambaBlockOption[T]
WithMIMODiscretizationMode sets the SSM discretization mode.
type MambaBlock ¶
type MambaBlock[T tensor.Numeric] struct {
	// SSM parameters
	A *graph.Parameter[T] // [d_inner, d_state] — log-space initialization
	D *graph.Parameter[T] // [d_inner] — skip connection
	// contains filtered or unexported fields
}
MambaBlock implements the Mamba selective state space model block.
Architecture (Mamba-1 style):
- Input projection: d_model -> 2*d_inner (split into x and z branches)
- Depthwise causal Conv1D on x branch (kernel_size=4, groups=d_inner)
- SiLU activation on x
- SSM parameter projection: x -> (dt, B, C)
- Selective scan: discretize A,B with softplus(dt), run parallel scan
- Gate: y * SiLU(z)
- Output projection: d_inner -> d_model
Input shape: [batch, seq_len, d_model]
Output shape: [batch, seq_len, d_model]
func NewMambaBlock ¶
func NewMambaBlock[T tensor.Numeric](
	name string,
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	dModel, dInner, dState, dtRank, convKer int,
	opts ...MambaBlockOption[T],
) (*MambaBlock[T], error)
NewMambaBlock creates a new MambaBlock.
Parameters:
- dModel: input/output dimension
- dInner: inner SSM dimension (typically 2*dModel)
- dState: SSM state dimension (e.g. 16)
- dtRank: rank of dt projection (typically dModel/16 or ceil(dModel/16))
- convKer: depthwise conv1d kernel size (typically 4)
- opts: optional functional options (e.g. WithDiscretizationMode)
func (*MambaBlock[T]) Attributes ¶
func (m *MambaBlock[T]) Attributes() map[string]interface{}
func (*MambaBlock[T]) Backward ¶
func (m *MambaBlock[T]) Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes gradients for the Mamba block using the chain rule.
func (*MambaBlock[T]) Forward ¶
func (m *MambaBlock[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the Mamba block forward pass.
Input: [batch, seq_len, d_model]
Output: [batch, seq_len, d_model]
func (*MambaBlock[T]) Name ¶
func (m *MambaBlock[T]) Name() string
func (*MambaBlock[T]) OpType ¶
func (m *MambaBlock[T]) OpType() string
func (*MambaBlock[T]) OutputShape ¶
func (m *MambaBlock[T]) OutputShape() []int
func (*MambaBlock[T]) Parameters ¶
func (m *MambaBlock[T]) Parameters() []*graph.Parameter[T]
Parameters returns all trainable parameters.
func (*MambaBlock[T]) SetName ¶
func (m *MambaBlock[T]) SetName(n string)
type MambaBlockOption ¶ added in v1.7.0
type MambaBlockOption[T tensor.Numeric] func(*MambaBlock[T])
MambaBlockOption is a functional option for configuring a MambaBlock.
func WithDiscretizationMode ¶ added in v1.7.0
func WithDiscretizationMode[T tensor.Numeric](mode DiscretizationMode) MambaBlockOption[T]
WithDiscretizationMode sets the SSM discretization mode. Defaults to ZOH for backward compatibility.
type S4 ¶ added in v1.8.0
S4 implements a diagonal State Space Model (S4D variant).
The continuous-time state space model is:
x'(t) = A x(t) + B u(t)
y(t)  = C x(t) + D u(t)
With diagonal A, the discrete-time equations become element-wise:
x_k = a * x_{k-1} + b * u_k
y_k = sum(c * x_k) + d * u_k
where a = exp(dt * A_diag) ensures stability when A_diag < 0.
Input shape: [batch, seq_len, input_dim]
Output shape: [batch, seq_len, input_dim]
Parameters (per input dimension, state_dim internal states):
A_log [input_dim, state_dim] - log(-A), parameterizing stable eigenvalues
B     [input_dim, state_dim] - input-to-state projection
C     [input_dim, state_dim] - state-to-output projection
D     [input_dim]            - skip connection
func NewS4 ¶ added in v1.8.0
func NewS4[T tensor.Float](
	name string,
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	inputDim, stateDim int,
) (*S4[T], error)
NewS4 creates a new S4 layer with HiPPO-inspired initialization.
func (*S4[T]) Attributes ¶ added in v1.8.0
func (s *S4[T]) Attributes() map[string]interface{}
Attributes returns the layer attributes.
func (*S4[T]) Backward ¶ added in v1.8.0
func (s *S4[T]) Backward(_ context.Context, _ types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes gradients using full backpropagation through time (BPTT).
The adjoint equation for the hidden state is:
dL/dx_k = dL/dy_k * C + dL/dx_{k+1} * A_disc
We iterate backward from the last timestep, accumulating gradients for all parameters.
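The adjoint recursion can be sketched for a single diagonal state. An illustrative scalar sketch (the real implementation also accumulates the A, B, C, D parameter gradients at each step):

```go
package main

import "fmt"

// s4Adjoint iterates the adjoint equation backward in time:
// dL/dx_k = dL/dy_k * c + dL/dx_{k+1} * aDisc,
// with dL/dx past the last timestep taken as zero.
func s4Adjoint(dy []float64, c, aDisc float64) []float64 {
	T := len(dy)
	dx := make([]float64, T)
	next := 0.0 // dL/dx_{T}
	for k := T - 1; k >= 0; k-- {
		dx[k] = dy[k]*c + next*aDisc
		next = dx[k]
	}
	return dx
}

func main() {
	dy := []float64{0, 0, 1}             // gradient arrives only at the last step
	fmt.Println(s4Adjoint(dy, 1.0, 0.5)) // [0.25 0.5 1]
}
```

The gradient flowing to earlier timesteps is attenuated by aDisc at each step, mirroring the forward decay; this is why stable eigenvalues (|aDisc| < 1) also keep BPTT well-conditioned.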
func (*S4[T]) Forward ¶ added in v1.8.0
func (s *S4[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward runs the diagonal SSM scan over the sequence.
Input: [batch, seq_len, input_dim]
Output: [batch, seq_len, input_dim]
All arithmetic is routed through engine primitives so the computation graph is fully traceable by the tracing compiler.
func (*S4[T]) OutputShape ¶ added in v1.8.0
func (s *S4[T]) OutputShape() []int
OutputShape returns the output shape.
func (*S4[T]) Parameters ¶ added in v1.8.0
func (s *S4[T]) Parameters() []*graph.Parameter[T]
Parameters returns all trainable parameters.