ssm

package
v1.25.1
Published: Mar 26, 2026 License: Apache-2.0 Imports: 12 Imported by: 0

Documentation

Overview

Package ssm implements state space model layers.

Stability: alpha

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type BCNorm added in v1.7.0

type BCNorm[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

BCNorm implements L2 normalization with a learnable gain for the B and C matrices of a state space model. It stabilizes the SSM recurrence by preventing the B/C values from growing unbounded, which is especially important when complex-valued RoPE rotations are applied.

For an input x of shape [..., dim]:

norm = sqrt(sum(x^2, dim=-1) / dim + eps)
out = gain * x / norm

This is similar to RMSNorm but applied specifically to the SSM projection outputs before they enter the selective scan.

func NewBCNorm added in v1.7.0

func NewBCNorm[T tensor.Numeric](name string, engine compute.Engine[T], ops numeric.Arithmetic[T], dim int) (*BCNorm[T], error)

NewBCNorm creates a new BCNorm layer.

func (*BCNorm[T]) Backward added in v1.7.0

func (bn *BCNorm[T]) Backward(_ context.Context, dOut *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Backward computes gradients for BCNorm.

func (*BCNorm[T]) Forward added in v1.7.0

func (bn *BCNorm[T]) Forward(_ context.Context, input *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward applies BCNorm: gain * x / rms(x).

func (*BCNorm[T]) Name added in v1.7.0

func (bn *BCNorm[T]) Name() string

Name returns the layer name.

func (*BCNorm[T]) Parameters added in v1.7.0

func (bn *BCNorm[T]) Parameters() []*graph.Parameter[T]

Parameters returns the trainable parameters (gain).

type ComplexSSMState added in v1.7.0

type ComplexSSMState[T tensor.Numeric] struct {

	// SSM parameters
	A *graph.Parameter[T] // [d_inner, d_state]
	D *graph.Parameter[T] // [d_inner]
	// contains filtered or unexported fields
}

ComplexSSMState implements complex-valued SSM state tracking using RoPE embeddings on the B and C matrices. The hidden state is split into pairs of dimensions that are treated as (real, imaginary) components. RoPE rotates each pair by a position-dependent angle, encoding temporal structure into the state without doubling memory.

This follows the Mamba 3 design where B and C are rotated by RoPE before the selective scan, giving the recurrence complex-valued dynamics.

Input shape: [batch, seq_len, d_model]
Output shape: [batch, seq_len, d_model]

func NewComplexSSMState added in v1.7.0

func NewComplexSSMState[T tensor.Numeric](
	name string,
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	dModel, dInner, dState, dtRank, convKer int,
	maxSeqLen int,
	opts ...ComplexSSMStateOption[T],
) (*ComplexSSMState[T], error)

NewComplexSSMState creates a new ComplexSSMState block.

dState must be even since dimensions are paired as (real, imaginary) for complex-valued RoPE rotation.

func (*ComplexSSMState[T]) Attributes added in v1.7.0

func (c *ComplexSSMState[T]) Attributes() map[string]interface{}

Attributes returns the layer attributes.

func (*ComplexSSMState[T]) Backward added in v1.7.0

func (c *ComplexSSMState[T]) Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes gradients for the ComplexSSMState block.

func (*ComplexSSMState[T]) Forward added in v1.7.0

func (c *ComplexSSMState[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes the complex-valued SSM forward pass.

Input: [batch, seq_len, d_model]
Output: [batch, seq_len, d_model]

func (*ComplexSSMState[T]) Name added in v1.7.0

func (c *ComplexSSMState[T]) Name() string

Name returns the layer name.

func (*ComplexSSMState[T]) OpType added in v1.7.0

func (c *ComplexSSMState[T]) OpType() string

OpType returns the operation type.

func (*ComplexSSMState[T]) OutputShape added in v1.7.0

func (c *ComplexSSMState[T]) OutputShape() []int

OutputShape returns the output shape.

func (*ComplexSSMState[T]) Parameters added in v1.7.0

func (c *ComplexSSMState[T]) Parameters() []*graph.Parameter[T]

Parameters returns all trainable parameters.

func (*ComplexSSMState[T]) SetName added in v1.7.0

func (c *ComplexSSMState[T]) SetName(n string)

SetName sets the layer name.

type ComplexSSMStateOption added in v1.7.0

type ComplexSSMStateOption[T tensor.Numeric] func(*ComplexSSMState[T])

ComplexSSMStateOption is a functional option for ComplexSSMState.

func WithComplexDiscretizationMode added in v1.7.0

func WithComplexDiscretizationMode[T tensor.Numeric](mode DiscretizationMode) ComplexSSMStateOption[T]

WithComplexDiscretizationMode sets the discretization mode.

type DiscretizationMode added in v1.7.0

type DiscretizationMode int

DiscretizationMode controls how the continuous SSM (A, B) is discretized.

const (
	// ZOH uses zero-order hold discretization (Mamba 1/2 default):
	//   Ā = exp(Δ * A)
	//   B̄ = Δ * B
	ZOH DiscretizationMode = iota

	// ExpTrap uses exponential-trapezoidal discretization (Mamba 3):
	//   Ā = exp(Δ * A)
	//   B̄ = Δ * (I + exp(Δ * A)) / 2 * B
	//
	// This gives richer system dynamics by taking a trapezoidal average of the
	// continuous-time B at both endpoints of the discretization interval.
	ExpTrap
)

type Linear added in v1.7.0

type Linear[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

Linear is a simple linear projection layer used by ComplexSSMState.

func NewLinear added in v1.7.0

func NewLinear[T tensor.Numeric](name string, engine compute.Engine[T], ops numeric.Arithmetic[T], inDim, outDim int) (*Linear[T], error)

NewLinear creates a simple linear projection.

func (*Linear[T]) Backward added in v1.7.0

func (l *Linear[T]) Backward(_ context.Context, _ types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes gradients for the linear layer.

func (*Linear[T]) Forward added in v1.7.0

func (l *Linear[T]) Forward(_ context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes y = x @ W for the linear projection.

func (*Linear[T]) Parameters added in v1.7.0

func (l *Linear[T]) Parameters() []*graph.Parameter[T]

Parameters returns the trainable parameters.

type MIMOMambaBlock added in v1.7.0

type MIMOMambaBlock[T tensor.Numeric] struct {

	// Per-head SSM parameters
	A []*graph.Parameter[T] // each [headDim, d_state]
	D []*graph.Parameter[T] // each [headDim]
	// contains filtered or unexported fields
}

MIMOMambaBlock implements a multi-input multi-output SSM block with multiple parallel state spaces (heads) and cross-channel mixing.

Architecture:

  • Input projection: d_model -> 2*d_inner (x and z branches)
  • Depthwise causal Conv1D on x branch
  • SiLU activation on x
  • SSM parameter projection: x -> (dt, B, C) per head
  • Multi-head selective scan: each head processes d_inner/num_heads channels with its own A, D parameters and independent state space
  • Cross-head mixing: linear projection across head outputs
  • Gate: mixed_y * SiLU(z)
  • Output projection: d_inner -> d_model

The multi-head design allows different heads to specialize on different temporal patterns, similar to multi-head attention but for SSM recurrence.

Input shape: [batch, seq_len, d_model]
Output shape: [batch, seq_len, d_model]
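The head partitioning and cross-head mixing steps can be sketched independently of the package. `splitHeads` and `mixHeads` are hypothetical helpers operating on a single timestep's channels; the real block applies these steps across the whole [batch, seq_len, d_inner] tensor.

```go
package main

import "fmt"

// splitHeads partitions a timestep's dInner channels into numHeads
// contiguous groups, mirroring how each head owns dInner/numHeads channels.
func splitHeads(x []float64, numHeads int) [][]float64 {
	headDim := len(x) / numHeads
	heads := make([][]float64, numHeads)
	for h := range heads {
		heads[h] = x[h*headDim : (h+1)*headDim]
	}
	return heads
}

// mixHeads applies a numHeads x numHeads mixing matrix across head outputs,
// channel by channel — the "cross-head mixing" step above.
func mixHeads(heads [][]float64, w [][]float64) [][]float64 {
	out := make([][]float64, len(heads))
	for i := range out {
		out[i] = make([]float64, len(heads[0]))
		for j := range heads {
			for c, v := range heads[j] {
				out[i][c] += w[i][j] * v
			}
		}
	}
	return out
}

func main() {
	heads := splitHeads([]float64{1, 2, 3, 4, 5, 6}, 3) // 3 heads, headDim = 2
	// Example mixing: head 0's output becomes the mean of heads 1 and 2.
	w := [][]float64{{0, 0.5, 0.5}, {0, 1, 0}, {0, 0, 1}}
	fmt.Println(mixHeads(heads, w)) // → [[4 5] [3 4] [5 6]]
}
```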

func NewMIMOMambaBlock added in v1.7.0

func NewMIMOMambaBlock[T tensor.Numeric](
	name string,
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	dModel, dInner, dState, dtRank, convKer, numHeads int,
	opts ...MIMOMambaBlockOption[T],
) (*MIMOMambaBlock[T], error)

NewMIMOMambaBlock creates a new multi-head MIMO SSM block.

Parameters:

  • dModel: input/output dimension
  • dInner: inner SSM dimension (must be divisible by numHeads)
  • dState: SSM state dimension per head
  • dtRank: rank of dt projection
  • convKer: depthwise conv1d kernel size
  • numHeads: number of parallel SSM heads

func (*MIMOMambaBlock[T]) Attributes added in v1.7.0

func (m *MIMOMambaBlock[T]) Attributes() map[string]interface{}

Attributes returns the layer attributes.

func (*MIMOMambaBlock[T]) Backward added in v1.7.0

func (m *MIMOMambaBlock[T]) Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes gradients for the MIMO Mamba block.

func (*MIMOMambaBlock[T]) Forward added in v1.7.0

func (m *MIMOMambaBlock[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes the MIMO Mamba block forward pass.

Input: [batch, seq_len, d_model]
Output: [batch, seq_len, d_model]

func (*MIMOMambaBlock[T]) Name added in v1.7.0

func (m *MIMOMambaBlock[T]) Name() string

Name returns the layer name.

func (*MIMOMambaBlock[T]) OpType added in v1.7.0

func (m *MIMOMambaBlock[T]) OpType() string

OpType returns the operation type.

func (*MIMOMambaBlock[T]) OutputShape added in v1.7.0

func (m *MIMOMambaBlock[T]) OutputShape() []int

OutputShape returns the output shape.

func (*MIMOMambaBlock[T]) Parameters added in v1.7.0

func (m *MIMOMambaBlock[T]) Parameters() []*graph.Parameter[T]

Parameters returns all trainable parameters.

func (*MIMOMambaBlock[T]) SetName added in v1.7.0

func (m *MIMOMambaBlock[T]) SetName(n string)

SetName sets the layer name.

type MIMOMambaBlockOption added in v1.7.0

type MIMOMambaBlockOption[T tensor.Numeric] func(*MIMOMambaBlock[T])

MIMOMambaBlockOption is a functional option for MIMOMambaBlock.

func WithMIMODiscretizationMode added in v1.7.0

func WithMIMODiscretizationMode[T tensor.Numeric](mode DiscretizationMode) MIMOMambaBlockOption[T]

WithMIMODiscretizationMode sets the SSM discretization mode.

type MambaBlock

type MambaBlock[T tensor.Numeric] struct {

	// SSM parameters
	A *graph.Parameter[T] // [d_inner, d_state] — log-space initialization
	D *graph.Parameter[T] // [d_inner] — skip connection
	// contains filtered or unexported fields
}

MambaBlock implements the Mamba selective state space model block.

Architecture (Mamba-1 style):

  • Input projection: d_model -> 2*d_inner (split into x and z branches)
  • Depthwise causal Conv1D on x branch (kernel_size=4, groups=d_inner)
  • SiLU activation on x
  • SSM parameter projection: x -> (dt, B, C)
  • Selective scan: discretize A,B with softplus(dt), run parallel scan
  • Gate: y * SiLU(z)
  • Output projection: d_inner -> d_model

Input shape: [batch, seq_len, d_model]
Output shape: [batch, seq_len, d_model]

func NewMambaBlock

func NewMambaBlock[T tensor.Numeric](
	name string,
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	dModel, dInner, dState, dtRank, convKer int,
	opts ...MambaBlockOption[T],
) (*MambaBlock[T], error)

NewMambaBlock creates a new MambaBlock.

Parameters:

  • dModel: input/output dimension
  • dInner: inner SSM dimension (typically 2*dModel)
  • dState: SSM state dimension (e.g. 16)
  • dtRank: rank of dt projection (typically dModel/16 or ceil(dModel/16))
  • convKer: depthwise conv1d kernel size (typically 4)
  • opts: optional functional options (e.g. WithDiscretizationMode)

func (*MambaBlock[T]) Attributes

func (m *MambaBlock[T]) Attributes() map[string]interface{}

Attributes returns the layer attributes.

func (*MambaBlock[T]) Backward

func (m *MambaBlock[T]) Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes gradients for the Mamba block using the chain rule.

func (*MambaBlock[T]) Forward

func (m *MambaBlock[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes the Mamba block forward pass.

Input: [batch, seq_len, d_model]
Output: [batch, seq_len, d_model]

func (*MambaBlock[T]) Name

func (m *MambaBlock[T]) Name() string

Name returns the layer name.

func (*MambaBlock[T]) OpType

func (m *MambaBlock[T]) OpType() string

OpType returns the operation type.

func (*MambaBlock[T]) OutputShape

func (m *MambaBlock[T]) OutputShape() []int

OutputShape returns the output shape.

func (*MambaBlock[T]) Parameters

func (m *MambaBlock[T]) Parameters() []*graph.Parameter[T]

Parameters returns all trainable parameters.

func (*MambaBlock[T]) ScaleOutputProj added in v1.15.1

func (m *MambaBlock[T]) ScaleOutputProj(scale float64)

ScaleOutputProj scales the output projection weights by the given factor. Used by multi-layer models to apply residual scaling (1/sqrt(NLayers)).

func (*MambaBlock[T]) SetName

func (m *MambaBlock[T]) SetName(n string)

SetName sets the layer name.

type MambaBlockOption added in v1.7.0

type MambaBlockOption[T tensor.Numeric] func(*MambaBlock[T])

MambaBlockOption is a functional option for configuring a MambaBlock.

func WithDiscretizationMode added in v1.7.0

func WithDiscretizationMode[T tensor.Numeric](mode DiscretizationMode) MambaBlockOption[T]

WithDiscretizationMode sets the SSM discretization mode. Defaults to ZOH for backward compatibility.

type S4 added in v1.8.0

type S4[T tensor.Float] struct {
	// contains filtered or unexported fields
}

S4 implements a diagonal State Space Model (S4D variant).

The continuous-time state space model is:

x'(t) = A x(t) + B u(t)
y(t)  = C x(t) + D u(t)

With diagonal A, the discrete-time equations become element-wise:

x_k = a * x_{k-1} + b * u_k
y_k = sum(c * x_k) + d * u_k

where a = exp(dt * A_diag) ensures stability when A_diag < 0.

Input shape: [batch, seq_len, input_dim]
Output shape: [batch, seq_len, input_dim]

Parameters (per input dimension, state_dim internal states):

A_log [input_dim, state_dim] - log(-A), parameterizing stable eigenvalues
B     [input_dim, state_dim] - input-to-state projection
C     [input_dim, state_dim] - state-to-output projection
D     [input_dim]            - skip connection

func NewS4 added in v1.8.0

func NewS4[T tensor.Float](
	name string,
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	inputDim, stateDim int,
) (*S4[T], error)

NewS4 creates a new S4 layer with HiPPO-inspired initialization.

func (*S4[T]) Attributes added in v1.8.0

func (s *S4[T]) Attributes() map[string]interface{}

Attributes returns the layer attributes.

func (*S4[T]) Backward added in v1.8.0

func (s *S4[T]) Backward(_ context.Context, _ types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes gradients using full backpropagation through time (BPTT).

The adjoint equation for the hidden state is:

dL/dx_k = dL/dy_k * C + dL/dx_{k+1} * A_disc

We iterate backward from the last timestep, accumulating gradients for all parameters.

func (*S4[T]) Forward added in v1.8.0

func (s *S4[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward runs the diagonal SSM scan over the sequence.

Input: [batch, seq_len, input_dim]
Output: [batch, seq_len, input_dim]

All arithmetic is routed through engine primitives so the computation graph is fully traceable by the tracing compiler.

func (*S4[T]) OpType added in v1.8.0

func (s *S4[T]) OpType() string

OpType returns the operation type.

func (*S4[T]) OutputShape added in v1.8.0

func (s *S4[T]) OutputShape() []int

OutputShape returns the output shape.

func (*S4[T]) Parameters added in v1.8.0

func (s *S4[T]) Parameters() []*graph.Parameter[T]

Parameters returns all trainable parameters.
