ssm

package
v1.6.0
Warning

This package is not in the latest version of its module.

Published: Mar 18, 2026 License: Apache-2.0 Imports: 10 Imported by: 0

Documentation

Overview

Package ssm implements state space model layers.

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type MambaBlock

type MambaBlock[T tensor.Numeric] struct {

	// SSM parameters
	A *graph.Parameter[T] // [d_inner, d_state] — log-space initialization
	D *graph.Parameter[T] // [d_inner] — skip connection
	// contains filtered or unexported fields
}

MambaBlock implements the Mamba selective state space model block.

Architecture (Mamba-1 style):

  • Input projection: d_model -> 2*d_inner (split into x and z branches)
  • Depthwise causal Conv1D on x branch (kernel size convKer, typically 4; groups=d_inner)
  • SiLU activation on x
  • SSM parameter projection: x -> (dt, B, C)
  • Selective scan: discretize A and B using the step size softplus(dt), then run a parallel scan
  • Gate: y * SiLU(z)
  • Output projection: d_inner -> d_model

Input shape: [batch, seq_len, d_model]
Output shape: [batch, seq_len, d_model]

func NewMambaBlock

func NewMambaBlock[T tensor.Numeric](
	name string,
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	dModel, dInner, dState, dtRank, convKer int,
) (*MambaBlock[T], error)

NewMambaBlock creates a new MambaBlock.

Parameters:

  • dModel: input/output dimension
  • dInner: inner SSM dimension (typically 2*dModel)
  • dState: SSM state dimension (e.g. 16)
  • dtRank: rank of the dt projection (typically ceil(dModel/16))
  • convKer: depthwise conv1d kernel size (typically 4)
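The typical values listed above can be derived from dModel alone. The sketch below uses a hypothetical helper type, mambaDims, that is not part of the ssm package. The commented-out call only mirrors the documented NewMambaBlock signature. It assumes float32 satisfies tensor.Numeric and that the caller has already built an engine and ops value, whose constructors are not shown in this documentation.

```go
package main

import "fmt"

// mambaDims holds the integer hyperparameters NewMambaBlock takes.
// Hypothetical helper type for illustration only.
type mambaDims struct {
	DModel, DInner, DState, DtRank, ConvKer int
}

// defaultDims fills in the typical values from the parameter list:
// dInner = 2*dModel, dState = 16, dtRank = ceil(dModel/16), convKer = 4.
func defaultDims(dModel int) mambaDims {
	return mambaDims{
		DModel:  dModel,
		DInner:  2 * dModel,
		DState:  16,
		DtRank:  (dModel + 15) / 16, // integer ceil(dModel/16) for dModel > 0
		ConvKer: 4,
	}
}

func main() {
	d := defaultDims(768)
	fmt.Printf("%+v\n", d)

	// Hypothetical call shape (engine and ops must be constructed by the caller):
	// block, err := ssm.NewMambaBlock[float32]("mamba0", engine, ops,
	//     d.DModel, d.DInner, d.DState, d.DtRank, d.ConvKer)
}
```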

func (*MambaBlock[T]) Attributes

func (m *MambaBlock[T]) Attributes() map[string]interface{}

func (*MambaBlock[T]) Backward

func (m *MambaBlock[T]) Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes gradients for the Mamba block using the chain rule.

func (*MambaBlock[T]) Forward

func (m *MambaBlock[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes the Mamba block forward pass.

Input: [batch, seq_len, d_model]
Output: [batch, seq_len, d_model]

func (*MambaBlock[T]) Name

func (m *MambaBlock[T]) Name() string

func (*MambaBlock[T]) OpType

func (m *MambaBlock[T]) OpType() string

func (*MambaBlock[T]) OutputShape

func (m *MambaBlock[T]) OutputShape() []int

func (*MambaBlock[T]) Parameters

func (m *MambaBlock[T]) Parameters() []*graph.Parameter[T]

Parameters returns all trainable parameters.

func (*MambaBlock[T]) SetName

func (m *MambaBlock[T]) SetName(n string)
