Documentation ¶
Overview ¶
Package ssm implements state space model layers.
Index ¶
- type MambaBlock
- func (m *MambaBlock[T]) Attributes() map[string]interface{}
- func (m *MambaBlock[T]) Backward(ctx context.Context, mode types.BackwardMode, ...) ([]*tensor.TensorNumeric[T], error)
- func (m *MambaBlock[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (m *MambaBlock[T]) Name() string
- func (m *MambaBlock[T]) OpType() string
- func (m *MambaBlock[T]) OutputShape() []int
- func (m *MambaBlock[T]) Parameters() []*graph.Parameter[T]
- func (m *MambaBlock[T]) SetName(n string)
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
This section is empty.
Types ¶
type MambaBlock ¶
type MambaBlock[T tensor.Numeric] struct {
	// SSM parameters
	A *graph.Parameter[T] // [d_inner, d_state], log-space initialization
	D *graph.Parameter[T] // [d_inner], skip connection
	// contains filtered or unexported fields
}
MambaBlock implements the Mamba selective state space model block.
Architecture (Mamba-1 style):
- Input projection: d_model -> 2*d_inner (split into x and z branches)
- Depthwise causal Conv1D on x branch (kernel_size=4, groups=d_inner)
- SiLU activation on x
- SSM parameter projection: x -> (dt, B, C)
- Selective scan: discretize A and B using the step size softplus(dt), then run a parallel scan over the sequence
- Gate: y * SiLU(z)
- Output projection: d_inner -> d_model
Input shape: [batch, seq_len, d_model]
Output shape: [batch, seq_len, d_model]
func NewMambaBlock ¶
func NewMambaBlock[T tensor.Numeric](
	name string,
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	dModel, dInner, dState, dtRank, convKer int,
) (*MambaBlock[T], error)
NewMambaBlock creates a new MambaBlock.
Parameters:
- dModel: input/output dimension
- dInner: inner SSM dimension (typically 2*dModel)
- dState: SSM state dimension (e.g. 16)
- dtRank: rank of dt projection (typically ceil(dModel/16))
- convKer: depthwise conv1d kernel size (typically 4)
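The typical values listed above can be derived from dModel alone. The following sketch shows one way to compute them; `mambaDims` and `defaultDims` are hypothetical helpers for illustration and are not part of this package.

```go
package main

import "fmt"

// mambaDims groups the integer arguments NewMambaBlock expects.
// Hypothetical helper type, not part of package ssm.
type mambaDims struct {
	DModel, DInner, DState, DtRank, ConvKer int
}

// defaultDims fills in the "typical" values from the parameter list.
func defaultDims(dModel int) mambaDims {
	return mambaDims{
		DModel:  dModel,
		DInner:  2 * dModel,
		DState:  16,
		DtRank:  (dModel + 15) / 16, // integer form of ceil(dModel/16)
		ConvKer: 4,
	}
}

func main() {
	fmt.Printf("%+v\n", defaultDims(768))
}
```

For dModel = 768 this gives dInner = 1536 and dtRank = 48; the resulting fields would then be passed as the corresponding arguments to NewMambaBlock.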
func (*MambaBlock[T]) Attributes ¶
func (m *MambaBlock[T]) Attributes() map[string]interface{}
func (*MambaBlock[T]) Backward ¶
func (m *MambaBlock[T]) Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes gradients for the Mamba block using the chain rule.
func (*MambaBlock[T]) Forward ¶
func (m *MambaBlock[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the Mamba block forward pass.
Input: [batch, seq_len, d_model]
Output: [batch, seq_len, d_model]
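One early stage of the forward pass, the depthwise causal convolution followed by SiLU, can be sketched on plain slices. This is an assumption-laden illustration rather than the package's code: `causalDepthwiseConv1D` is a hypothetical name, the batch dimension and any bias term are omitted, and left zero-padding is assumed as the way causality is enforced.

```go
package main

import (
	"fmt"
	"math"
)

// silu computes x * sigmoid(x).
func silu(x float64) float64 {
	return x / (1 + math.Exp(-x))
}

// causalDepthwiseConv1D filters each channel with its own kernel.
// Shapes: x is [seqLen][channels]; w is [channels][kernel].
// Positions before the start of the sequence count as zero, so
// out[t] depends only on x[0..t].
func causalDepthwiseConv1D(x, w [][]float64) [][]float64 {
	out := make([][]float64, len(x))
	for t := range x {
		out[t] = make([]float64, len(w))
		for ch, kernel := range w {
			k := len(kernel)
			for j, wj := range kernel {
				src := t - (k - 1) + j // last tap aligns with position t
				if src >= 0 {
					out[t][ch] += wj * x[src][ch]
				}
			}
		}
	}
	return out
}

func main() {
	x := [][]float64{{1}, {2}, {3}, {4}, {5}}
	w := [][]float64{{1, 1, 1, 1}} // one channel, kernel size 4
	conv := causalDepthwiseConv1D(x, w)
	for t := range conv {
		fmt.Println(conv[t][0], silu(conv[t][0]))
	}
}
```

With an all-ones kernel of size 4, each output is the sum of the current and up to three previous inputs, which makes the causal window easy to verify by hand.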
func (*MambaBlock[T]) Name ¶
func (m *MambaBlock[T]) Name() string
func (*MambaBlock[T]) OpType ¶
func (m *MambaBlock[T]) OpType() string
func (*MambaBlock[T]) OutputShape ¶
func (m *MambaBlock[T]) OutputShape() []int
func (*MambaBlock[T]) Parameters ¶
func (m *MambaBlock[T]) Parameters() []*graph.Parameter[T]
Parameters returns all trainable parameters.
func (*MambaBlock[T]) SetName ¶
func (m *MambaBlock[T]) SetName(n string)