attention

package
v1.33.0 Latest
Published: Mar 28, 2026 License: Apache-2.0 Imports: 18 Imported by: 1

Documentation

Overview

Package attention provides attention mechanisms for neural networks.

Stability: stable

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func BuildCausalSlidingWindowMask added in v0.2.1

func BuildCausalSlidingWindowMask[T tensor.Numeric](seqLen, windowSize int) *tensor.TensorNumeric[T]

BuildCausalSlidingWindowMask creates a causal attention mask that also restricts attention to the last windowSize positions. Positions outside the window or in the future are set to a large negative value. Shape: [1, 1, seqLen, seqLen].
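The mask logic can be sketched on plain slices (an illustrative re-derivation, not the package implementation; the real function returns a [1, 1, seqLen, seqLen] tensor, and the window semantics here assume the [max(0, q-windowSize+1), q] range documented for NSAWindowAttention.Forward below):

```go
// causalSlidingWindowMask builds a [seqLen][seqLen] additive mask in which
// key position j is visible from query position i only when j <= i (causal)
// and i-j < windowSize (sliding window); every other entry gets a large
// negative value so softmax drives its weight to zero.
func causalSlidingWindowMask(seqLen, windowSize int) [][]float32 {
	const negInf = float32(-1e9)
	m := make([][]float32, seqLen)
	for i := range m {
		m[i] = make([]float32, seqLen)
		for j := range m[i] {
			if j > i || i-j >= windowSize {
				m[i][j] = negInf
			}
		}
	}
	return m
}
```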

func BuildGlobalAttention added in v0.2.1

func BuildGlobalAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	_ string,
	_ map[string]*graph.Parameter[T],
	attributes map[string]interface{},
) (graph.Node[T], error)

BuildGlobalAttention constructs a GlobalAttention node from attributes.

Required attributes:

  • "embed_dim" (int): embedding dimension
  • "num_heads" (int): number of attention heads
  • "num_kv_heads" (int): number of key-value heads

func BuildGroupQueryAttention added in v0.2.0

func BuildGroupQueryAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	name string,
	params map[string]*graph.Parameter[T],
	attributes map[string]interface{},
) (graph.Node[T], error)

BuildGroupQueryAttention constructs a GroupedQueryAttention node for the model builder. Unused parameters are accepted to satisfy the common builder signature.

func BuildMultiHeadLatentAttention added in v0.2.1

func BuildMultiHeadLatentAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	name string,
	params map[string]*graph.Parameter[T],
	attributes map[string]any,
) (graph.Node[T], error)

BuildMultiHeadLatentAttention constructs a MultiHeadLatentAttention node for the model builder. It reads kv_lora_dim, num_heads, head_dim, and max_seq_len from attributes, and loads W_Q, W_DKV, W_UK, W_UV, W_O from node parameters.

func BuildNativeSparseAttention added in v1.29.0

func BuildNativeSparseAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	name string,
	params map[string]*graph.Parameter[T],
	attributes map[string]interface{},
) (graph.Node[T], error)

BuildNativeSparseAttention constructs a NativeSparseAttention node for the model builder. It reads model_dim, num_heads, num_kv_heads, block_size, top_blocks, top_tokens, and window_size from attributes, and loads gate_coarse, gate_fine, gate_window from node parameters.

func BuildSparseRoutedAttention added in v1.29.0

func BuildSparseRoutedAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	name string,
	params map[string]*graph.Parameter[T],
	attributes map[string]interface{},
) (graph.Node[T], error)

BuildSparseRoutedAttention constructs a SparseRoutedAttention node for the model builder. It reads num_heads, num_kv_heads, head_dim, segment_size, top_k, max_seq_len, and rope_base from attributes. A nil KV cache is used since cache binding happens at generation time.

func QKNorm added in v0.2.0

func QKNorm[T tensor.Numeric](ctx context.Context, engine compute.Engine[T], q, k *tensor.TensorNumeric[T], epsilon float64) (qNorm, kNorm *tensor.TensorNumeric[T], err error)

QKNorm applies a form of normalization to Query (Q) and Key (K) tensors to stabilize attention score scales, similar to RMSNorm. It normalizes Q and K independently by their respective RMS values. All operations use Engine primitives so they appear in the ExecutionPlan instruction tape.

Types

type AttentionHead added in v0.2.0

type AttentionHead[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

AttentionHead implements a single attention head, including linear projections for Query, Key, and Value, followed by scaled dot-product attention.

func NewAttentionHead added in v0.2.0

func NewAttentionHead[T tensor.Numeric](engine compute.Engine[T], inputDim, headDim int, opts ...AttentionHeadOption[T]) (*AttentionHead[T], error)

NewAttentionHead creates a new AttentionHead instance. inputDim is the dimension of the input features. headDim is the dimension of the query, key, and value vectors for this head.

func (*AttentionHead[T]) Attributes added in v0.2.1

func (ah *AttentionHead[T]) Attributes() map[string]interface{}

Attributes returns the attributes for the AttentionHead.

func (*AttentionHead[T]) Backward added in v0.2.0

func (ah *AttentionHead[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes the gradients for the AttentionHead. dOut has shape (batch, seq_len, head_dim). inputs[0] has shape (batch, seq_len, input_dim).

func (*AttentionHead[T]) Forward added in v0.2.0

func (ah *AttentionHead[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes the output of the attention head. input is expected to be a 3D tensor (batch_size, seq_len, input_dim).

func (*AttentionHead[T]) OpType added in v0.2.1

func (ah *AttentionHead[T]) OpType() string

OpType returns the operation type of the AttentionHead.

func (*AttentionHead[T]) OutputShape added in v0.2.0

func (ah *AttentionHead[T]) OutputShape() []int

OutputShape returns the output shape of the AttentionHead. It assumes the input shape is (batch_size, seq_len, input_dim). The output shape will be (batch_size, seq_len, head_dim).

func (*AttentionHead[T]) Parameters added in v0.2.0

func (ah *AttentionHead[T]) Parameters() []*graph.Parameter[T]

Parameters returns all trainable parameters of the AttentionHead.

type AttentionHeadOption added in v0.2.0

type AttentionHeadOption[T tensor.Numeric] func(*AttentionHeadOptions[T])

AttentionHeadOption applies an option to AttentionHeadOptions.

func WithBidirectionalAttention added in v1.9.0

func WithBidirectionalAttention[T tensor.Numeric]() AttentionHeadOption[T]

WithBidirectionalAttention returns an option that disables causal masking, allowing every token to attend to every other token. This is used by encoder-style models such as BERT.

type AttentionHeadOptions added in v0.2.0

type AttentionHeadOptions[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

AttentionHeadOptions holds configuration options for AttentionHead.

type BlockTableReader added in v0.2.1

type BlockTableReader[T tensor.Numeric] interface {
	// ReadKV returns the cached key and value tensors for the given layer
	// as contiguous [batch*numKVHeads, seqLen, headDim] tensors read directly
	// from blocks. Returns false if the layer has no cached data.
	ReadKV(layer int) (k, v *tensor.TensorNumeric[T], ok bool)
}

BlockTableReader reads key/value tensors directly from paged block tables, avoiding the gather-to-contiguous copy. Implementations should iterate over blocks and return the full KV sequence for a given layer.

type GQAOption added in v0.2.0

type GQAOption[T tensor.Numeric] func(*GQAOptions[T])

GQAOption is a function that applies an option to GQAOptions.

func WithBidirectionalGQA added in v1.9.0

func WithBidirectionalGQA[T tensor.Numeric]() GQAOption[T]

WithBidirectionalGQA returns an option that disables causal masking in the grouped query attention layer, allowing every position to attend to every other position. This is required for encoder-style models such as BERT.

func WithMaxSeqLen added in v0.2.0

func WithMaxSeqLen[T tensor.Numeric](maxSeqLen int) GQAOption[T]

WithMaxSeqLen sets the maximum sequence length for Rotary Positional Embeddings.

func WithNoRoPE added in v1.32.0

func WithNoRoPE[T tensor.Numeric]() GQAOption[T]

WithNoRoPE returns an option that disables rotary positional embeddings. Models like GPT-2 use learned position embeddings instead of RoPE, so the GQA layer should pass Q and K through without rotational encoding.

func WithRopeBase added in v0.2.0

func WithRopeBase[T tensor.Numeric](base float64) GQAOption[T]

WithRopeBase sets the base for Rotary Positional Embeddings.

type GQAOptions added in v0.2.0

type GQAOptions[T tensor.Numeric] struct {
	Base          float64
	MaxSeqLen     int
	Bidirectional bool // when true, disables causal masking for encoder-style models
	NoRoPE        bool // when true, skip RoPE creation (for models like GPT-2 that use learned position embeddings)
}

GQAOptions holds configuration options for the GroupedQueryAttention layer.

type GlobalAttention added in v0.2.0

type GlobalAttention[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

GlobalAttention wraps GroupedQueryAttention to provide a global attention interface.

func NewGlobalAttention added in v0.2.0

func NewGlobalAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	modelDim, numQueryHeads, numKeyValueHeads int,
	options ...GlobalAttentionOption,
) (*GlobalAttention[T], error)

NewGlobalAttention creates a new GlobalAttention layer.

Parameters:

  • engine: compute engine for tensor operations
  • ops: arithmetic operations for the numeric type
  • modelDim: model dimension
  • numQueryHeads: number of query heads
  • numKeyValueHeads: number of key/value heads
  • options: functional options for configuration

Default values:

  • base: 10000.0
  • maxSeqLen: 2048
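The options follow Go's functional-options pattern: start from the documented defaults and let each option mutate the config. A minimal re-creation (the types here are illustrative stand-ins for GlobalAttentionOption and GlobalAttentionOptions, not the package's own) shows how defaults and overrides interact:

```go
// globalOpts mirrors the shape of GlobalAttentionOptions.
type globalOpts struct {
	Base      float64
	MaxSeqLen int
}

// globalOpt mutates a config, like GlobalAttentionOption.
type globalOpt func(*globalOpts)

// withBase overrides the RoPE base, like WithGlobalAttentionBase.
func withBase(base float64) globalOpt {
	return func(o *globalOpts) { o.Base = base }
}

// resolveGlobalOpts applies options over the documented defaults.
func resolveGlobalOpts(options ...globalOpt) globalOpts {
	o := globalOpts{Base: 10000.0, MaxSeqLen: 2048} // documented defaults
	for _, opt := range options {
		opt(&o)
	}
	return o
}
```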

func NewGlobalAttentionFromParams added in v0.2.0

func NewGlobalAttentionFromParams[T tensor.Numeric](gqa *GroupedQueryAttention[T]) *GlobalAttention[T]

NewGlobalAttentionFromParams creates a new GlobalAttention layer from an existing GroupedQueryAttention layer.

func (*GlobalAttention[T]) Attributes added in v0.2.0

func (ga *GlobalAttention[T]) Attributes() map[string]interface{}

Attributes returns the attributes.

func (*GlobalAttention[T]) Backward added in v0.2.0

func (ga *GlobalAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward delegates the backward pass to the wrapped GroupedQueryAttention.

func (*GlobalAttention[T]) Forward added in v0.2.0

func (ga *GlobalAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes the forward pass of the GlobalAttention layer.

func (*GlobalAttention[T]) OpType added in v0.2.0

func (ga *GlobalAttention[T]) OpType() string

OpType returns the operation type.

func (*GlobalAttention[T]) OutputShape added in v0.2.0

func (ga *GlobalAttention[T]) OutputShape() []int

OutputShape returns the output shape of the GlobalAttention layer.

func (*GlobalAttention[T]) Parameters added in v0.2.0

func (ga *GlobalAttention[T]) Parameters() []*graph.Parameter[T]

Parameters returns the parameters of the GlobalAttention layer.

func (*GlobalAttention[T]) ScaleRope added in v0.2.0

func (ga *GlobalAttention[T]) ScaleRope(ctx context.Context, factor float64) error

ScaleRope scales the rotary positional embeddings.

func (*GlobalAttention[T]) SetLayerIndex added in v0.2.1

func (ga *GlobalAttention[T]) SetLayerIndex(idx int)

SetLayerIndex sets the layer index for KV cache routing.

type GlobalAttentionOption added in v0.2.0

type GlobalAttentionOption func(*GlobalAttentionOptions)

GlobalAttentionOption is a function that configures GlobalAttentionOptions.

func WithGlobalAttentionBase added in v0.2.0

func WithGlobalAttentionBase(base float64) GlobalAttentionOption

WithGlobalAttentionBase sets the base (theta) parameter for rotary positional embeddings.

func WithGlobalAttentionMaxSeqLen added in v0.2.0

func WithGlobalAttentionMaxSeqLen(maxSeqLen int) GlobalAttentionOption

WithGlobalAttentionMaxSeqLen sets the maximum sequence length.

type GlobalAttentionOptions added in v0.2.0

type GlobalAttentionOptions struct {
	Base      float64
	MaxSeqLen int
}

GlobalAttentionOptions holds configuration options for GlobalAttention layer.

type GroupedQueryAttention

type GroupedQueryAttention[T tensor.Numeric] struct {

	// LayerIndex identifies this layer within a model for KV cache indexing.
	LayerIndex int

	// SlidingWindowSize, if > 0, restricts attention to the last N positions
	// using a causal sliding window mask during prefill (seqLen > 1).
	SlidingWindowSize int
	// contains filtered or unexported fields
}

GroupedQueryAttention implements grouped query attention mechanism.

func NewGroupedQueryAttention

func NewGroupedQueryAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	modelDim, numQueryHeads, numKeyValueHeads int,
	opts ...GQAOption[T],
) (*GroupedQueryAttention[T], error)

NewGroupedQueryAttention creates a new GroupedQueryAttention layer.

Parameters:

  • modelDim: the dimension of the input and output of the block (d_model)
  • numQueryHeads: the number of query heads
  • numKeyValueHeads: the number of key/value heads

func NewGroupedQueryAttentionFromParams added in v0.2.0

func NewGroupedQueryAttentionFromParams[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	modelDim, numQueryHeads, numKeyValueHeads int,
	wq, wk, wv, wo *core.Dense[T],
	rope *embeddings.RotaryPositionalEmbedding[T],
	headDimOverride ...int,
) (*GroupedQueryAttention[T], error)

NewGroupedQueryAttentionFromParams creates a new GroupedQueryAttention layer from existing parameters. headDimOverride, if > 0, sets the per-head dimension explicitly instead of deriving it from modelDim/numQueryHeads. This is required for architectures like Gemma 3 where key_length differs from hidden_size/num_heads.

func (*GroupedQueryAttention[T]) Attributes added in v0.2.0

func (gqa *GroupedQueryAttention[T]) Attributes() map[string]interface{}

Attributes returns the attributes.

func (*GroupedQueryAttention[T]) Backward

func (gqa *GroupedQueryAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes the gradients for GroupedQueryAttention.

The backward mirrors the forward in reverse order:

  1. wo backward
  2. Reverse reshape/transpose (head concatenation)
  3. SDPA backward
  4. Reverse K/V head replication (sum over group copies)
  5. RoPE backward
  6. Reverse head split (reshape/transpose back to projection shape)
  7. wq/wk/wv backward

func (*GroupedQueryAttention[T]) Forward

func (gqa *GroupedQueryAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes the grouped query attention.

func (*GroupedQueryAttention[T]) MergedQKVParameter added in v0.2.1

func (gqa *GroupedQueryAttention[T]) MergedQKVParameter() *graph.Parameter[T]

MergedQKVParameter returns the merged QKV parameter for GPU upload, or nil if not set.

func (*GroupedQueryAttention[T]) OpType added in v0.2.0

func (gqa *GroupedQueryAttention[T]) OpType() string

OpType returns the operation type.

func (*GroupedQueryAttention[T]) OutputShape

func (gqa *GroupedQueryAttention[T]) OutputShape() []int

OutputShape returns the output shape of the GroupedQueryAttention.

func (*GroupedQueryAttention[T]) Parameters

func (gqa *GroupedQueryAttention[T]) Parameters() []*graph.Parameter[T]

Parameters returns the parameters of the GroupedQueryAttention layer.

func (*GroupedQueryAttention[T]) ScaleRope added in v0.2.0

func (gqa *GroupedQueryAttention[T]) ScaleRope(ctx context.Context, factor float64) error

ScaleRope scales the rotary positional embeddings.

func (*GroupedQueryAttention[T]) SetBidirectional added in v1.9.0

func (gqa *GroupedQueryAttention[T]) SetBidirectional(bidirectional bool)

SetBidirectional enables or disables bidirectional (non-causal) attention.

func (*GroupedQueryAttention[T]) SetBlockTableReader added in v0.2.1

func (gqa *GroupedQueryAttention[T]) SetBlockTableReader(r BlockTableReader[T])

SetBlockTableReader sets an optional BlockTableReader that provides KV data directly from paged block tables, bypassing the standard cache gather path.

func (*GroupedQueryAttention[T]) SetDocumentBoundaries added in v1.28.0

func (gqa *GroupedQueryAttention[T]) SetDocumentBoundaries(boundaries []int)

SetDocumentBoundaries sets document boundary positions for document-wise RoPE. When boundaries are set, position IDs reset to 0 at each boundary so each document receives independent positional encoding during multi-document inference. Boundaries are sequence positions (0-indexed) where new documents begin. Pass nil to disable document-wise mode.
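The position-reset behavior can be sketched as follows (docPositions is an illustrative helper, not part of the package API):

```go
// docPositions returns per-token position IDs that reset to 0 at each
// document boundary. boundaries are 0-indexed sequence positions where
// new documents begin; a nil slice yields ordinary 0..seqLen-1 positions.
func docPositions(seqLen int, boundaries []int) []int {
	isBoundary := make(map[int]bool, len(boundaries))
	for _, b := range boundaries {
		isBoundary[b] = true
	}
	pos := make([]int, seqLen)
	p := 0
	for i := 0; i < seqLen; i++ {
		if isBoundary[i] {
			p = 0 // new document: restart positional encoding
		}
		pos[i] = p
		p++
	}
	return pos
}
```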

func (*GroupedQueryAttention[T]) SetMergedQKV added in v0.2.1

func (gqa *GroupedQueryAttention[T]) SetMergedQKV(weight *tensor.TensorNumeric[T], qDim, kDim, vDim int)

SetMergedQKV sets a merged Q/K/V weight tensor for single-GEMV decode optimization. During decode (seqLen=1), a single MatMul with this weight replaces three separate Q/K/V projections, reducing kernel launch overhead. The output is split into Q, K, V using zero-copy GPU storage views.
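The zero-copy split can be illustrated with Go sub-slices, which share the backing array much as the GPU storage views share device memory (a sketch for a single decode row; splitMergedQKV is not part of the package):

```go
// splitMergedQKV slices one merged projection output of width
// qDim+kDim+vDim into its Q, K, V parts without copying: each returned
// slice is a view into the same backing array as row.
func splitMergedQKV(row []float32, qDim, kDim, vDim int) (q, k, v []float32) {
	q = row[:qDim]
	k = row[qDim : qDim+kDim]
	v = row[qDim+kDim : qDim+kDim+vDim]
	return q, k, v
}
```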

func (*GroupedQueryAttention[T]) SetQKNormWeights added in v0.2.1

func (gqa *GroupedQueryAttention[T]) SetQKNormWeights(qWeight, kWeight *tensor.TensorNumeric[T], eps float32)

SetQKNormWeights stores raw RMSNorm weights for the fused QK norm+RoPE decode path. When set alongside SetQKNorms, the fused kernel replaces 4 kernel launches (Q norm, K norm, Q RoPE, K RoPE) with 1 during decode.

func (*GroupedQueryAttention[T]) SetQKNorms added in v0.2.1

func (gqa *GroupedQueryAttention[T]) SetQKNorms(qNorm, kNorm graph.Node[T])

SetQKNorms sets optional per-head RMSNorm layers for Q and K projections. Used by architectures like Gemma 3 that normalize Q/K after projection.

type LocalAttention added in v0.2.0

type LocalAttention[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

LocalAttention implements a local, sliding-window self-attention mechanism.

func NewLocalAttention added in v0.2.0

func NewLocalAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	modelDim, numQueryHeads, numKeyValueHeads, windowSize int,
	opts ...LocalAttentionOption[T],
) (*LocalAttention[T], error)

NewLocalAttention creates a new LocalAttention layer.

func (*LocalAttention[T]) Attributes added in v0.2.1

func (la *LocalAttention[T]) Attributes() map[string]interface{}

Attributes returns the attributes of the LocalAttention layer.

func (*LocalAttention[T]) Backward added in v0.2.0

func (la *LocalAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward delegates the backward pass to the wrapped GroupedQueryAttention.

func (*LocalAttention[T]) Forward added in v0.2.0

func (la *LocalAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes the forward pass of the LocalAttention layer.

func (*LocalAttention[T]) OpType added in v0.2.1

func (la *LocalAttention[T]) OpType() string

OpType returns the operation type of the LocalAttention layer.

func (*LocalAttention[T]) OutputShape added in v0.2.0

func (la *LocalAttention[T]) OutputShape() []int

OutputShape returns the output shape of the LocalAttention layer.

func (*LocalAttention[T]) Parameters added in v0.2.0

func (la *LocalAttention[T]) Parameters() []*graph.Parameter[T]

Parameters returns the parameters of the LocalAttention layer.

type LocalAttentionOption added in v0.2.0

type LocalAttentionOption[T tensor.Numeric] func(*LocalAttentionOptions[T])

LocalAttentionOption is a function that applies an option to LocalAttentionOptions.

func WithLocalMaxSeqLen added in v0.2.0

func WithLocalMaxSeqLen[T tensor.Numeric](maxSeqLen int) LocalAttentionOption[T]

WithLocalMaxSeqLen sets the maximum sequence length for Rotary Positional Embeddings.

maxSeqLen: The maximum sequence length for Rotary Positional Embeddings.

func WithLocalRopeBase added in v0.2.0

func WithLocalRopeBase[T tensor.Numeric](base float64) LocalAttentionOption[T]

WithLocalRopeBase sets the base for Rotary Positional Embeddings.

base: The base for Rotary Positional Embeddings.

type LocalAttentionOptions added in v0.2.0

type LocalAttentionOptions[T tensor.Numeric] struct {
	Base      float64
	MaxSeqLen int
}

LocalAttentionOptions holds configuration options for the LocalAttention layer.

type MultiHeadLatentAttention added in v0.2.1

type MultiHeadLatentAttention[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

MultiHeadLatentAttention implements Multi-head Latent Attention (MLA) as used in DeepSeek V3/R1. MLA compresses KV into a low-rank latent vector, dramatically reducing KV cache size.

Partial RoPE: RoPE is applied only to the first ropeHeadDim dimensions of Q and K. The remaining dimensions are position-independent, matching the DeepSeek V3 paper specification.

func NewMultiHeadLatentAttention added in v0.2.1

func NewMultiHeadLatentAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	numHeads, headDim, kvLoraDim, ropeHeadDim int,
	wQ, wDKV, wUK, wUV, wO *core.Dense[T],
	rope *embeddings.RotaryPositionalEmbedding[T],
) *MultiHeadLatentAttention[T]

NewMultiHeadLatentAttention creates a new MLA layer. ropeHeadDim specifies how many of the headDim dimensions receive RoPE. If ropeHeadDim <= 0 or >= headDim, RoPE is applied to all dimensions (backwards-compatible behavior).

func (*MultiHeadLatentAttention[T]) Attributes added in v0.2.1

func (m *MultiHeadLatentAttention[T]) Attributes() map[string]any

Attributes returns the layer attributes.

func (*MultiHeadLatentAttention[T]) Backward added in v0.2.1

func (m *MultiHeadLatentAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes gradients for MLA. dOut: [batch, seqLen, hidden], inputs: original input to Forward.

func (*MultiHeadLatentAttention[T]) Forward added in v0.2.1

func (m *MultiHeadLatentAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes the MLA forward pass. Input: [batch, seqLen, hidden]. Output: [batch, seqLen, hidden].

func (*MultiHeadLatentAttention[T]) OpType added in v0.2.1

func (m *MultiHeadLatentAttention[T]) OpType() string

OpType returns the layer operation type.

func (*MultiHeadLatentAttention[T]) OutputShape added in v0.2.1

func (m *MultiHeadLatentAttention[T]) OutputShape() []int

OutputShape returns the output shape.

func (*MultiHeadLatentAttention[T]) Parameters added in v0.2.1

func (m *MultiHeadLatentAttention[T]) Parameters() []*graph.Parameter[T]

Parameters returns all trainable parameters.

type NSACoarseCompression added in v1.28.0

type NSACoarseCompression[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

NSACoarseCompression implements the coarse-grained token compression path of Native Sparse Attention (NSA). It divides the KV sequence into blocks of B tokens, computes block-level attention scores by averaging key representations per block, selects top-c blocks per query position, and attends to the selected blocks at full resolution.

func NewNSACoarseCompression added in v1.28.0

func NewNSACoarseCompression[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	blockSize, topBlocks, numHeads, numKVHeads, headDim int,
) *NSACoarseCompression[T]

NewNSACoarseCompression creates a new NSACoarseCompression layer.

Parameters:

  • engine: compute engine for tensor operations
  • ops: arithmetic operations for the numeric type
  • blockSize: number of tokens per KV block (B)
  • topBlocks: number of blocks to select per query position (c)
  • numHeads: number of query attention heads
  • numKVHeads: number of key/value heads
  • headDim: dimension of each attention head

func (*NSACoarseCompression[T]) Attributes added in v1.28.0

func (nsa *NSACoarseCompression[T]) Attributes() map[string]interface{}

Attributes returns the layer configuration.

func (*NSACoarseCompression[T]) Backward added in v1.28.0

Backward implements the straight-through estimator for the block selection. In the forward pass we use hard top-k; the backward pass passes gradients through as if the selection were soft (identity on the selected paths).

func (*NSACoarseCompression[T]) Forward added in v1.28.0

func (nsa *NSACoarseCompression[T]) Forward(_ context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes the coarse-grained compression attention.

Inputs:

  • Q: [batch, numHeads, seqQ, headDim]
  • K: [batch, numKVHeads, seqKV, headDim]
  • V: [batch, numKVHeads, seqKV, headDim]

Returns output with shape [batch, numHeads, seqQ, headDim].

The algorithm:

  1. Reshape K into blocks of size B and compute block-level keys by averaging.
  2. Compute coarse attention scores: Q @ blockKeys^T.
  3. Select top-c blocks per query position.
  4. Gather full-resolution K,V from selected blocks.
  5. Compute fine-grained attention on selected blocks.
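Step 1, block-level key averaging, can be sketched for a single head on plain slices (illustrative; a trailing partial block is averaged over its actual length):

```go
// blockMeanKeys averages keys within consecutive blocks of blockSize
// tokens, producing the block-level keys the coarse path scores against.
// keys is [seqKV][headDim]; the result is [numBlocks][headDim].
func blockMeanKeys(keys [][]float64, blockSize int) [][]float64 {
	var blocks [][]float64
	for start := 0; start < len(keys); start += blockSize {
		end := start + blockSize
		if end > len(keys) {
			end = len(keys)
		}
		mean := make([]float64, len(keys[0]))
		for _, k := range keys[start:end] {
			for d, v := range k {
				mean[d] += v
			}
		}
		for d := range mean {
			mean[d] /= float64(end - start)
		}
		blocks = append(blocks, mean)
	}
	return blocks
}
```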

func (*NSACoarseCompression[T]) OpType added in v1.28.0

func (nsa *NSACoarseCompression[T]) OpType() string

OpType returns the operation type identifier.

func (*NSACoarseCompression[T]) OutputShape added in v1.28.0

func (nsa *NSACoarseCompression[T]) OutputShape() []int

OutputShape returns the output shape from the last forward call.

func (*NSACoarseCompression[T]) Parameters added in v1.28.0

func (nsa *NSACoarseCompression[T]) Parameters() []*graph.Parameter[T]

Parameters returns nil (no trainable parameters).

type NSAFineSelection added in v1.28.0

type NSAFineSelection[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

NSAFineSelection implements the fine-grained token selection path of Native Sparse Attention (NSA). For each query position it scores all KV positions via Q*K^T / sqrt(d), selects the top-f tokens by score, and computes softmax attention only over those selected tokens.

func NewNSAFineSelection added in v1.28.0

func NewNSAFineSelection[T tensor.Numeric](
	engine compute.Engine[T],
	topTokens, numHeads, numKVHeads, headDim int,
) *NSAFineSelection[T]

NewNSAFineSelection creates a new NSAFineSelection layer.

Parameters:

  • engine: compute engine for tensor operations
  • topTokens: number of tokens to select per query position (f)
  • numHeads: number of query heads
  • numKVHeads: number of key/value heads
  • headDim: dimension of each head

func (*NSAFineSelection[T]) Forward added in v1.28.0

func (n *NSAFineSelection[T]) Forward(
	ctx context.Context,
	Q, K, V *tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error)

Forward computes the fine-grained token selection attention.

Inputs:

  • Q: [batch, numHeads, seqQ, headDim]
  • K: [batch, numKVHeads, seqKV, headDim]
  • V: [batch, numKVHeads, seqKV, headDim]

Returns:

  • output: [batch, numHeads, seqQ, headDim]

func (*NSAFineSelection[T]) SelectedIndices added in v1.28.0

func (n *NSAFineSelection[T]) SelectedIndices(
	Q, K *tensor.TensorNumeric[T],
) []int

SelectedIndices computes and returns the top-f token indices per query position without computing the full attention output. This is useful for testing that the selection logic matches manual computation.

Returns indices with shape [batch, numHeads, seqQ, topTokens] (sorted descending by score).

type NSAWindowAttention added in v1.28.0

type NSAWindowAttention[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

NSAWindowAttention implements the sliding window attention path for Native Sparse Attention (NSA). It takes pre-projected Q, K, V tensors with shape [batch, heads, seq, dim] and applies causal sliding window attention, restricting each query to attend only to keys within a fixed-size window.

func NewNSAWindowAttention added in v1.28.0

func NewNSAWindowAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	windowSize, numHeads, numKVHeads, headDim int,
) (*NSAWindowAttention[T], error)

NewNSAWindowAttention creates a new NSAWindowAttention layer.

Parameters:

  • engine: compute engine for tensor operations
  • ops: arithmetic operations for the numeric type
  • windowSize: number of past tokens each query can attend to (the window spans from max(0, q-windowSize) to q inclusive)
  • numHeads: number of query attention heads
  • numKVHeads: number of key/value attention heads
  • headDim: dimension of each attention head

func (*NSAWindowAttention[T]) Forward added in v1.28.0

func (nw *NSAWindowAttention[T]) Forward(ctx context.Context, Q, K, V *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes sliding window attention over pre-projected Q, K, V.

Input shapes:

  • Q: [batch, numHeads, seqQ, headDim]
  • K: [batch, numKVHeads, seqKV, headDim]
  • V: [batch, numKVHeads, seqKV, headDim]

Output shape: [batch, numHeads, seqQ, headDim]

For each query position q, attention is restricted to keys in the range [max(0, q - windowSize + 1), q] (causal sliding window). Positions outside this window receive -1e9 additive masking before softmax.
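The window bounds reduce to a small helper (illustrative):

```go
// windowRange returns the inclusive key range [lo, hi] that query
// position q may attend to under a causal sliding window: the last
// windowSize positions up to and including q, clamped at 0.
func windowRange(q, windowSize int) (lo, hi int) {
	lo = q - windowSize + 1
	if lo < 0 {
		lo = 0
	}
	return lo, q
}
```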

func (*NSAWindowAttention[T]) Scale added in v1.28.0

func (nw *NSAWindowAttention[T]) Scale() float64

Scale returns the attention scaling factor (1/sqrt(headDim)).

func (*NSAWindowAttention[T]) WindowSize added in v1.28.0

func (nw *NSAWindowAttention[T]) WindowSize() int

WindowSize returns the configured window size.

type NativeSparseAttention added in v1.28.0

type NativeSparseAttention[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

NativeSparseAttention implements the full Native Sparse Attention (NSA) mechanism, combining three parallel attention paths:

  • Coarse: block-level compression attention
  • Fine: token-level selection attention
  • Window: sliding window local attention

The outputs are combined via learned per-head sigmoid gates:

O = sigmoid(gateCoarse) * O_coarse + sigmoid(gateFine) * O_fine + sigmoid(gateWindow) * O_window

Gates are initialized to zero so that sigmoid(0) = 0.5 gives equal weighting at initialization.

func NewNativeSparseAttention added in v1.28.0

func NewNativeSparseAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	modelDim, numHeads, numKVHeads, blockSize, topBlocks, topTokens, windowSize int,
) (*NativeSparseAttention[T], error)

NewNativeSparseAttention creates a new NativeSparseAttention layer combining coarse, fine, and window attention paths with learned sigmoid gates.

Parameters:

  • engine: compute engine for tensor operations
  • ops: arithmetic operations for the numeric type
  • modelDim: model dimension (unused directly, reserved for projection layers)
  • numHeads: number of query attention heads
  • numKVHeads: number of key/value attention heads
  • blockSize: number of tokens per KV block for coarse path
  • topBlocks: number of blocks to select for coarse path
  • topTokens: number of tokens to select for fine path
  • windowSize: sliding window size for window path

func (*NativeSparseAttention[T]) Attributes added in v1.28.0

func (nsa *NativeSparseAttention[T]) Attributes() map[string]interface{}

Attributes returns the layer configuration.

func (*NativeSparseAttention[T]) Backward added in v1.28.0

func (nsa *NativeSparseAttention[T]) Backward(_ context.Context, _ types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes gradients for the NativeSparseAttention layer. Gradients flow through the sigmoid gates to the gate parameters and through each attention path via straight-through estimation.

func (*NativeSparseAttention[T]) Forward added in v1.28.0

func (nsa *NativeSparseAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes NativeSparseAttention by running all three paths and combining with learned sigmoid gates.

Inputs:

  • Q: [batch, numHeads, seqQ, headDim]
  • K: [batch, numKVHeads, seqKV, headDim]
  • V: [batch, numKVHeads, seqKV, headDim]

Returns output with shape [batch, numHeads, seqQ, headDim].

func (*NativeSparseAttention[T]) OpType added in v1.28.0

func (nsa *NativeSparseAttention[T]) OpType() string

OpType returns the operation type identifier.

func (*NativeSparseAttention[T]) OutputShape added in v1.28.0

func (nsa *NativeSparseAttention[T]) OutputShape() []int

OutputShape returns the output shape from the last forward call.

func (*NativeSparseAttention[T]) Parameters added in v1.28.0

func (nsa *NativeSparseAttention[T]) Parameters() []*graph.Parameter[T]

Parameters returns the trainable gate parameters.

type RopeScaler added in v0.2.0

type RopeScaler[T tensor.Numeric] interface {
	ScaleRope(ctx context.Context, factor float64) error
}

RopeScaler is an interface for layers that support scaling of RoPE.

type ScaledDotProductAttention

type ScaledDotProductAttention[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

ScaledDotProductAttention implements the scaled dot-product attention mechanism.

func NewBidirectionalSDPA added in v1.9.0

func NewBidirectionalSDPA[T tensor.Numeric](engine compute.Engine[T], headDim int, opts ...ScaledDotProductAttentionOption[T]) *ScaledDotProductAttention[T]

NewBidirectionalSDPA creates a ScaledDotProductAttention layer with causal masking disabled. All positions attend to all other positions, which is the attention pattern used by encoder models such as BERT.

func NewScaledDotProductAttention

func NewScaledDotProductAttention[T tensor.Numeric](engine compute.Engine[T], headDim int, opts ...ScaledDotProductAttentionOption[T]) *ScaledDotProductAttention[T]

NewScaledDotProductAttention creates a new ScaledDotProductAttention layer.

func (*ScaledDotProductAttention[T]) Backward

func (sdpa *ScaledDotProductAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut, _, _, _ *tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward computes the gradients for ScaledDotProductAttention. dOut is the gradient from the subsequent layer.

func (*ScaledDotProductAttention[T]) Forward

func (sdpa *ScaledDotProductAttention[T]) Forward(ctx context.Context, q, k, v, mask *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes the scaled dot-product attention. Q, K, V are expected to be 3D tensors (batch_size, seq_len, head_dim). mask is an optional 4D tensor (batch_size, num_heads, seq_len_q, seq_len_k).

func (*ScaledDotProductAttention[T]) SetCausal added in v0.2.1

func (sdpa *ScaledDotProductAttention[T]) SetCausal(causal bool)

SetCausal enables or disables causal (lower-triangular) masking.

type ScaledDotProductAttentionOption added in v0.2.0

type ScaledDotProductAttentionOption[T tensor.Numeric] func(*ScaledDotProductAttentionOptions[T])

ScaledDotProductAttentionOption applies an option to ScaledDotProductAttentionOptions.

func WithBidirectional added in v1.9.0

func WithBidirectional[T tensor.Numeric]() ScaledDotProductAttentionOption[T]

WithBidirectional returns an option that disables causal masking, allowing every position to attend to every other position. This is required for encoder-style models such as BERT.

func WithHeadCounts added in v1.29.0

func WithHeadCounts[T tensor.Numeric](numQueryHeads, numKVHeads int) ScaledDotProductAttentionOption[T]

WithHeadCounts sets the query and KV head counts, enabling the split-KV flash decode kernel for autoregressive decode with GQA support.
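Under GQA, consecutive groups of query heads share one KV head. A minimal sketch of that index mapping (kvHeadFor is a hypothetical helper, not part of this package):

```go
package main

import "fmt"

// kvHeadFor maps a query head index to the key/value head it shares under
// grouped-query attention: consecutive groups of numQueryHeads/numKVHeads
// query heads read the same KV head.
func kvHeadFor(queryHead, numQueryHeads, numKVHeads int) int {
	group := numQueryHeads / numKVHeads
	return queryHead / group
}

func main() {
	// 8 query heads sharing 2 KV heads: heads 0-3 -> KV 0, heads 4-7 -> KV 1.
	for h := 0; h < 8; h++ {
		fmt.Print(kvHeadFor(h, 8, 2), " ")
	}
	fmt.Println() // 0 0 0 0 1 1 1 1
}
```

Setting numQueryHeads == numKVHeads recovers standard multi-head attention; numKVHeads == 1 is multi-query attention.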

type ScaledDotProductAttentionOptions added in v0.2.0

type ScaledDotProductAttentionOptions[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

ScaledDotProductAttentionOptions holds configuration options for ScaledDotProductAttention.

type SparseRoutedAttention added in v1.29.0

type SparseRoutedAttention[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

SparseRoutedAttention routes queries to a subset of KV segments using cosine similarity scoring. It divides the KV sequence into fixed-size segments, computes cosine similarity between each query and segment centroids (mean key vectors), selects the top-k most similar segments, and performs scaled dot-product attention over the selected segments.

Position encoding uses document-wise RoPE so that position IDs reset at document boundaries during multi-document inference.

KV history is stored in a CompressedKVCache (via SparseRoutedKVCache interface) for efficient long-context inference.

func NewSparseRoutedAttention added in v1.29.0

func NewSparseRoutedAttention[T tensor.Numeric](
	engine compute.Engine[T],
	ops numeric.Arithmetic[T],
	rope *embeddings.RotaryPositionalEmbedding[T],
	kvCache SparseRoutedKVCache[T],
	numHeads, numKVHeads, headDim, segmentSize, topK int,
) (*SparseRoutedAttention[T], error)

NewSparseRoutedAttention creates a new SparseRoutedAttention layer.

Parameters:

  • engine: compute engine for tensor operations (must support CosineSimilarity)
  • ops: arithmetic operations for the numeric type
  • rope: rotary positional embedding (supports document-wise mode)
  • kvCache: compressed KV cache for storing key-value pairs (satisfies SparseRoutedKVCache)
  • numHeads: number of query attention heads
  • numKVHeads: number of key/value attention heads
  • headDim: dimension of each attention head
  • segmentSize: number of tokens per KV segment for routing
  • topK: number of segments to select per query position

func (*SparseRoutedAttention[T]) Attributes added in v1.29.0

func (sra *SparseRoutedAttention[T]) Attributes() map[string]interface{}

Attributes returns the layer configuration.

func (*SparseRoutedAttention[T]) Backward added in v1.29.0

Backward computes gradients for the SparseRoutedAttention layer, using straight-through estimation for the routing selection.

func (*SparseRoutedAttention[T]) Forward added in v1.29.0

func (sra *SparseRoutedAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward computes sparse routed attention.

Inputs:

  • Q: [batch, numHeads, seqQ, headDim]
  • K: [batch, numKVHeads, seqKV, headDim]
  • V: [batch, numKVHeads, seqKV, headDim]

Returns output with shape [batch, numHeads, seqQ, headDim].

The algorithm:

  1. Divide K into segments of segmentSize tokens and compute segment centroids.
  2. For each query position, compute cosine similarity with all centroids.
  3. Select top-k segments per query.
  4. Gather full-resolution K,V from selected segments.
  5. Apply RoPE to Q and selected K.
  6. Compute scaled dot-product attention over selected segments.
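Steps 1-3 of the algorithm above, centroid computation and cosine top-k routing, can be sketched in plain Go. This is an illustrative standalone example with hypothetical helper names, not the package's engine-backed implementation.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// centroids splits keys into segments of segSize tokens and returns the
// mean key vector of each segment (step 1).
func centroids(keys [][]float64, segSize int) [][]float64 {
	var cents [][]float64
	for start := 0; start < len(keys); start += segSize {
		end := start + segSize
		if end > len(keys) {
			end = len(keys)
		}
		c := make([]float64, len(keys[0]))
		for _, k := range keys[start:end] {
			for t := range c {
				c[t] += k[t]
			}
		}
		for t := range c {
			c[t] /= float64(end - start)
		}
		cents = append(cents, c)
	}
	return cents
}

// cosine returns the cosine similarity of two vectors (step 2).
func cosine(a, b []float64) float64 {
	var dot, na, nb float64
	for t := range a {
		dot += a[t] * b[t]
		na += a[t] * a[t]
		nb += b[t] * b[t]
	}
	return dot / (math.Sqrt(na)*math.Sqrt(nb) + 1e-12)
}

// routeTopK returns the indices of the topK segments whose centroids are
// most similar to the query (step 3).
func routeTopK(query []float64, cents [][]float64, topK int) []int {
	idx := make([]int, len(cents))
	for i := range idx {
		idx[i] = i
	}
	sort.Slice(idx, func(a, b int) bool {
		return cosine(query, cents[idx[a]]) > cosine(query, cents[idx[b]])
	})
	if topK > len(idx) {
		topK = len(idx)
	}
	return idx[:topK]
}

func main() {
	keys := [][]float64{{1, 0}, {1, 0}, {0, 1}, {0, 1}}
	cents := centroids(keys, 2) // two segments with centroids [1 0] and [0 1]
	fmt.Println(routeTopK([]float64{0, 1}, cents, 1)) // [1]: the second segment wins
}
```

Steps 4-6 then gather the full-resolution K,V rows of the winning segments and run ordinary scaled dot-product attention over just those tokens.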

func (*SparseRoutedAttention[T]) OpType added in v1.29.0

func (sra *SparseRoutedAttention[T]) OpType() string

OpType returns the operation type identifier.

func (*SparseRoutedAttention[T]) OutputShape added in v1.29.0

func (sra *SparseRoutedAttention[T]) OutputShape() []int

OutputShape returns the output shape from the last forward call.

func (*SparseRoutedAttention[T]) Parameters added in v1.29.0

func (sra *SparseRoutedAttention[T]) Parameters() []*graph.Parameter[T]

Parameters returns nil (no trainable parameters).

func (*SparseRoutedAttention[T]) SetDocumentBoundaries added in v1.29.0

func (sra *SparseRoutedAttention[T]) SetDocumentBoundaries(boundaries []int)

SetDocumentBoundaries sets document boundary positions for document-wise RoPE. Position IDs reset to 0 at each boundary so each document receives independent positional encoding. Pass nil to disable.
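The position-reset behavior can be sketched as follows. This is an illustrative example only; positionIDs is a hypothetical helper, and it assumes each boundary marks the absolute index where a new document starts, which may differ from the package's convention.

```go
package main

import "fmt"

// positionIDs returns per-token positions that reset to 0 at each document
// boundary, as used by document-wise RoPE. boundaries lists the absolute
// index at which each new document starts, in increasing order; nil yields
// the plain positions 0..seqLen-1.
func positionIDs(seqLen int, boundaries []int) []int {
	ids := make([]int, seqLen)
	start, next := 0, 0
	for i := 0; i < seqLen; i++ {
		if next < len(boundaries) && i == boundaries[next] {
			start = i
			next++
		}
		ids[i] = i - start
	}
	return ids
}

func main() {
	// Two documents packed into one sequence: tokens 0-2 and 3-5.
	fmt.Println(positionIDs(6, []int{3})) // [0 1 2 0 1 2]
}
```

Each document therefore receives independent rotary phases, preventing positional leakage across packed documents during multi-document inference.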

type SparseRoutedKVCache added in v1.29.0

type SparseRoutedKVCache[T tensor.Numeric] interface {
	NumLayers() int
	SeqLen() int
	Reset()
}

SparseRoutedKVCache is the interface for compressed KV caches used by SparseRoutedAttention. This is satisfied by generate.CompressedKVCache and avoids an import cycle with the generate package.
