Documentation
¶
Overview ¶
Package attention provides attention mechanisms for neural networks.
Stability: stable
It implements scaled dot-product attention together with multi-head variants: grouped-query attention (GQA), global attention, local sliding-window attention, and multi-head latent attention (MLA).
Index ¶
- func BuildCausalSlidingWindowMask[T tensor.Numeric](seqLen, windowSize int) *tensor.TensorNumeric[T]
- func BuildGlobalAttention[T tensor.Numeric](engine compute.Engine[T], ops numeric.Arithmetic[T], _ string, ...) (graph.Node[T], error)
- func BuildGroupQueryAttention[T tensor.Numeric](engine compute.Engine[T], ops numeric.Arithmetic[T], name string, ...) (graph.Node[T], error)
- func BuildMultiHeadLatentAttention[T tensor.Numeric](engine compute.Engine[T], ops numeric.Arithmetic[T], name string, ...) (graph.Node[T], error)
- func QKNorm[T tensor.Numeric](ctx context.Context, engine compute.Engine[T], q, k *tensor.TensorNumeric[T], ...) (qNorm, kNorm *tensor.TensorNumeric[T], err error)
- type AttentionHead
- func (ah *AttentionHead[T]) Attributes() map[string]interface{}
- func (ah *AttentionHead[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], ...) ([]*tensor.TensorNumeric[T], error)
- func (ah *AttentionHead[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (ah *AttentionHead[T]) OpType() string
- func (ah *AttentionHead[T]) OutputShape() []int
- func (ah *AttentionHead[T]) Parameters() []*graph.Parameter[T]
- type AttentionHeadOption
- type AttentionHeadOptions
- type BlockTableReader
- type GQAOption
- type GQAOptions
- type GlobalAttention
- func (ga *GlobalAttention[T]) Attributes() map[string]interface{}
- func (ga *GlobalAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], ...) ([]*tensor.TensorNumeric[T], error)
- func (ga *GlobalAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (ga *GlobalAttention[T]) OpType() string
- func (ga *GlobalAttention[T]) OutputShape() []int
- func (ga *GlobalAttention[T]) Parameters() []*graph.Parameter[T]
- func (ga *GlobalAttention[T]) ScaleRope(ctx context.Context, factor float64) error
- func (ga *GlobalAttention[T]) SetLayerIndex(idx int)
- type GlobalAttentionOption
- type GlobalAttentionOptions
- type GroupedQueryAttention
- func (gqa *GroupedQueryAttention[T]) Attributes() map[string]interface{}
- func (gqa *GroupedQueryAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], ...) ([]*tensor.TensorNumeric[T], error)
- func (gqa *GroupedQueryAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (gqa *GroupedQueryAttention[T]) MergedQKVParameter() *graph.Parameter[T]
- func (gqa *GroupedQueryAttention[T]) OpType() string
- func (gqa *GroupedQueryAttention[T]) OutputShape() []int
- func (gqa *GroupedQueryAttention[T]) Parameters() []*graph.Parameter[T]
- func (gqa *GroupedQueryAttention[T]) ScaleRope(ctx context.Context, factor float64) error
- func (gqa *GroupedQueryAttention[T]) SetBidirectional(bidirectional bool)
- func (gqa *GroupedQueryAttention[T]) SetBlockTableReader(r BlockTableReader[T])
- func (gqa *GroupedQueryAttention[T]) SetMergedQKV(weight *tensor.TensorNumeric[T], qDim, kDim, vDim int)
- func (gqa *GroupedQueryAttention[T]) SetQKNormWeights(qWeight, kWeight *tensor.TensorNumeric[T], eps float32)
- func (gqa *GroupedQueryAttention[T]) SetQKNorms(qNorm, kNorm graph.Node[T])
- type LocalAttention
- func (la *LocalAttention[T]) Attributes() map[string]interface{}
- func (la *LocalAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], ...) ([]*tensor.TensorNumeric[T], error)
- func (la *LocalAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (la *LocalAttention[T]) OpType() string
- func (la *LocalAttention[T]) OutputShape() []int
- func (la *LocalAttention[T]) Parameters() []*graph.Parameter[T]
- type LocalAttentionOption
- type LocalAttentionOptions
- type MultiHeadLatentAttention
- func (m *MultiHeadLatentAttention[T]) Attributes() map[string]any
- func (m *MultiHeadLatentAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], ...) ([]*tensor.TensorNumeric[T], error)
- func (m *MultiHeadLatentAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (m *MultiHeadLatentAttention[T]) OpType() string
- func (m *MultiHeadLatentAttention[T]) OutputShape() []int
- func (m *MultiHeadLatentAttention[T]) Parameters() []*graph.Parameter[T]
- type RopeScaler
- type ScaledDotProductAttention
- func (sdpa *ScaledDotProductAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, ...) ([]*tensor.TensorNumeric[T], error)
- func (sdpa *ScaledDotProductAttention[T]) Forward(ctx context.Context, q, k, v, mask *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (sdpa *ScaledDotProductAttention[T]) SetCausal(causal bool)
- type ScaledDotProductAttentionOption
- type ScaledDotProductAttentionOptions
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func BuildCausalSlidingWindowMask ¶ added in v0.2.1
func BuildCausalSlidingWindowMask[T tensor.Numeric](seqLen, windowSize int) *tensor.TensorNumeric[T]
BuildCausalSlidingWindowMask creates a causal attention mask that also restricts attention to the last windowSize positions. Positions outside the window or in the future are set to a large negative value. Shape: [1, 1, seqLen, seqLen].
func BuildGlobalAttention ¶ added in v0.2.1
func BuildGlobalAttention[T tensor.Numeric]( engine compute.Engine[T], ops numeric.Arithmetic[T], _ string, _ map[string]*graph.Parameter[T], attributes map[string]interface{}, ) (graph.Node[T], error)
BuildGlobalAttention constructs a GlobalAttention node from attributes.

Required attributes:
  - "embed_dim" (int): embedding dimension
  - "num_heads" (int): number of attention heads
  - "num_kv_heads" (int): number of key-value heads
func BuildGroupQueryAttention ¶ added in v0.2.0
func BuildGroupQueryAttention[T tensor.Numeric]( engine compute.Engine[T], ops numeric.Arithmetic[T], name string, params map[string]*graph.Parameter[T], attributes map[string]interface{}, ) (graph.Node[T], error)
BuildGroupQueryAttention constructs a GroupedQueryAttention node for the model builder. Unused parameters are accepted to satisfy the common builder signature.
func BuildMultiHeadLatentAttention ¶ added in v0.2.1
func BuildMultiHeadLatentAttention[T tensor.Numeric]( engine compute.Engine[T], ops numeric.Arithmetic[T], name string, params map[string]*graph.Parameter[T], attributes map[string]any, ) (graph.Node[T], error)
BuildMultiHeadLatentAttention constructs a MultiHeadLatentAttention node for the model builder. It reads kv_lora_dim, num_heads, head_dim, and max_seq_len from attributes, and loads W_Q, W_DKV, W_UK, W_UV, W_O from node parameters.
func QKNorm ¶ added in v0.2.0
func QKNorm[T tensor.Numeric](ctx context.Context, engine compute.Engine[T], q, k *tensor.TensorNumeric[T], epsilon float64) (qNorm, kNorm *tensor.TensorNumeric[T], err error)
QKNorm applies a form of normalization to Query (Q) and Key (K) tensors to stabilize attention score scales, similar to RMSNorm. It normalizes Q and K independently by their respective RMS values. All operations use Engine primitives so they appear in the ExecutionPlan instruction tape.
Types ¶
type AttentionHead ¶ added in v0.2.0
AttentionHead implements a single attention head, including linear projections for Query, Key, and Value, followed by scaled dot-product attention.
func NewAttentionHead ¶ added in v0.2.0
func NewAttentionHead[T tensor.Numeric](engine compute.Engine[T], inputDim, headDim int, opts ...AttentionHeadOption[T]) (*AttentionHead[T], error)
NewAttentionHead creates a new AttentionHead instance. inputDim is the dimension of the input features. headDim is the dimension of the query, key, and value vectors for this head.
func (*AttentionHead[T]) Attributes ¶ added in v0.2.1
func (ah *AttentionHead[T]) Attributes() map[string]interface{}
Attributes returns the attributes for the AttentionHead.
func (*AttentionHead[T]) Backward ¶ added in v0.2.0
func (ah *AttentionHead[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes the gradients for the AttentionHead. dOut has shape (batch, seq_len, head_dim). inputs[0] has shape (batch, seq_len, input_dim).
func (*AttentionHead[T]) Forward ¶ added in v0.2.0
func (ah *AttentionHead[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the output of the attention head. input is expected to be a 3D tensor (batch_size, seq_len, input_dim).
func (*AttentionHead[T]) OpType ¶ added in v0.2.1
func (ah *AttentionHead[T]) OpType() string
OpType returns the operation type of the AttentionHead.
func (*AttentionHead[T]) OutputShape ¶ added in v0.2.0
func (ah *AttentionHead[T]) OutputShape() []int
OutputShape returns the output shape of the AttentionHead. It assumes the input shape is (batch_size, seq_len, input_dim). The output shape will be (batch_size, seq_len, head_dim).
func (*AttentionHead[T]) Parameters ¶ added in v0.2.0
func (ah *AttentionHead[T]) Parameters() []*graph.Parameter[T]
Parameters returns all trainable parameters of the AttentionHead.
type AttentionHeadOption ¶ added in v0.2.0
type AttentionHeadOption[T tensor.Numeric] func(*AttentionHeadOptions[T])
AttentionHeadOption applies an option to AttentionHeadOptions.
func WithBidirectionalAttention ¶ added in v1.9.0
func WithBidirectionalAttention[T tensor.Numeric]() AttentionHeadOption[T]
WithBidirectionalAttention returns an option that disables causal masking, allowing every token to attend to every other token. This is used by encoder-style models such as BERT.
type AttentionHeadOptions ¶ added in v0.2.0
AttentionHeadOptions holds configuration options for AttentionHead.
type BlockTableReader ¶ added in v0.2.1
type BlockTableReader[T tensor.Numeric] interface {
	// ReadKV returns the cached key and value tensors for the given layer
	// as contiguous [batch*numKVHeads, seqLen, headDim] tensors read directly
	// from blocks. Returns false if the layer has no cached data.
	ReadKV(layer int) (k, v *tensor.TensorNumeric[T], ok bool)
}
BlockTableReader reads key/value tensors directly from paged block tables, avoiding the gather-to-contiguous copy. Implementations should iterate over blocks and return the full KV sequence for a given layer.
type GQAOption ¶ added in v0.2.0
type GQAOption[T tensor.Numeric] func(*GQAOptions[T])
GQAOption is a function that applies an option to GQAOptions.
func WithBidirectionalGQA ¶ added in v1.9.0
WithBidirectionalGQA returns an option that disables causal masking in the grouped query attention layer, allowing every position to attend to every other position. This is required for encoder-style models such as BERT.
func WithMaxSeqLen ¶ added in v0.2.0
WithMaxSeqLen sets the maximum sequence length for Rotary Positional Embeddings.
type GQAOptions ¶ added in v0.2.0
type GQAOptions[T tensor.Numeric] struct {
	Base          float64
	MaxSeqLen     int
	Bidirectional bool // when true, disables causal masking for encoder-style models
}
GQAOptions holds configuration options for the GroupedQueryAttention layer.
type GlobalAttention ¶ added in v0.2.0
GlobalAttention wraps GroupedQueryAttention to provide a global attention interface.
func NewGlobalAttention ¶ added in v0.2.0
func NewGlobalAttention[T tensor.Numeric]( engine compute.Engine[T], ops numeric.Arithmetic[T], modelDim, numQueryHeads, numKeyValueHeads int, options ...GlobalAttentionOption, ) (*GlobalAttention[T], error)
NewGlobalAttention creates a new GlobalAttention layer.
Parameters:
  - engine: compute engine for tensor operations
  - ops: arithmetic operations for the numeric type
  - modelDim: model dimension
  - numQueryHeads: number of query heads
  - numKeyValueHeads: number of key/value heads
  - options: functional options for configuration

Default values:
  - base: 10000.0
  - maxSeqLen: 2048
func NewGlobalAttentionFromParams ¶ added in v0.2.0
func NewGlobalAttentionFromParams[T tensor.Numeric](gqa *GroupedQueryAttention[T]) *GlobalAttention[T]
NewGlobalAttentionFromParams creates a new GlobalAttention layer from an existing GroupedQueryAttention layer.
func (*GlobalAttention[T]) Attributes ¶ added in v0.2.0
func (ga *GlobalAttention[T]) Attributes() map[string]interface{}
Attributes returns the attributes.
func (*GlobalAttention[T]) Backward ¶ added in v0.2.0
func (ga *GlobalAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward delegates the backward pass to the wrapped GroupedQueryAttention.
func (*GlobalAttention[T]) Forward ¶ added in v0.2.0
func (ga *GlobalAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the forward pass of the GlobalAttention layer.
func (*GlobalAttention[T]) OpType ¶ added in v0.2.0
func (ga *GlobalAttention[T]) OpType() string
OpType returns the operation type.
func (*GlobalAttention[T]) OutputShape ¶ added in v0.2.0
func (ga *GlobalAttention[T]) OutputShape() []int
OutputShape returns the output shape of the GlobalAttention layer.
func (*GlobalAttention[T]) Parameters ¶ added in v0.2.0
func (ga *GlobalAttention[T]) Parameters() []*graph.Parameter[T]
Parameters returns the parameters of the GlobalAttention layer.
func (*GlobalAttention[T]) ScaleRope ¶ added in v0.2.0
func (ga *GlobalAttention[T]) ScaleRope(ctx context.Context, factor float64) error
ScaleRope scales the rotary positional embeddings.
func (*GlobalAttention[T]) SetLayerIndex ¶ added in v0.2.1
func (ga *GlobalAttention[T]) SetLayerIndex(idx int)
SetLayerIndex sets the layer index for KV cache routing.
type GlobalAttentionOption ¶ added in v0.2.0
type GlobalAttentionOption func(*GlobalAttentionOptions)
GlobalAttentionOption is a function that configures GlobalAttentionOptions.
func WithGlobalAttentionBase ¶ added in v0.2.0
func WithGlobalAttentionBase(base float64) GlobalAttentionOption
WithGlobalAttentionBase sets the base (theta) parameter for rotary positional embeddings.
func WithGlobalAttentionMaxSeqLen ¶ added in v0.2.0
func WithGlobalAttentionMaxSeqLen(maxSeqLen int) GlobalAttentionOption
WithGlobalAttentionMaxSeqLen sets the maximum sequence length.
type GlobalAttentionOptions ¶ added in v0.2.0
GlobalAttentionOptions holds configuration options for GlobalAttention layer.
type GroupedQueryAttention ¶
type GroupedQueryAttention[T tensor.Numeric] struct {
	// LayerIndex identifies this layer within a model for KV cache indexing.
	LayerIndex int

	// SlidingWindowSize, if > 0, restricts attention to the last N positions
	// using a causal sliding window mask during prefill (seqLen > 1).
	SlidingWindowSize int
	// contains filtered or unexported fields
}
GroupedQueryAttention implements grouped query attention mechanism.
func NewGroupedQueryAttention ¶
func NewGroupedQueryAttention[T tensor.Numeric]( engine compute.Engine[T], ops numeric.Arithmetic[T], modelDim, numQueryHeads, numKeyValueHeads int, opts ...GQAOption[T], ) (*GroupedQueryAttention[T], error)
NewGroupedQueryAttention creates a new GroupedQueryAttention layer. modelDim: The dimension of the input and output of the block (d_model). numQueryHeads: The number of query heads. numKeyValueHeads: The number of key/value heads.
func NewGroupedQueryAttentionFromParams ¶ added in v0.2.0
func NewGroupedQueryAttentionFromParams[T tensor.Numeric]( engine compute.Engine[T], ops numeric.Arithmetic[T], modelDim, numQueryHeads, numKeyValueHeads int, wq, wk, wv, wo *core.Dense[T], rope *embeddings.RotaryPositionalEmbedding[T], headDimOverride ...int, ) (*GroupedQueryAttention[T], error)
NewGroupedQueryAttentionFromParams creates a new GroupedQueryAttention layer from existing parameters. headDimOverride, if > 0, sets the per-head dimension explicitly instead of deriving it from modelDim/numQueryHeads. This is required for architectures like Gemma 3 where key_length differs from hidden_size/num_heads.
func (*GroupedQueryAttention[T]) Attributes ¶ added in v0.2.0
func (gqa *GroupedQueryAttention[T]) Attributes() map[string]interface{}
Attributes returns the attributes.
func (*GroupedQueryAttention[T]) Backward ¶
func (gqa *GroupedQueryAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes the gradients for GroupedQueryAttention.
The backward mirrors the forward in reverse order:
- wo backward
- Reverse reshape/transpose (head concatenation)
- SDPA backward
- Reverse K/V head replication (sum over group copies)
- RoPE backward
- Reverse head split (reshape/transpose back to projection shape)
- wq/wk/wv backward
func (*GroupedQueryAttention[T]) Forward ¶
func (gqa *GroupedQueryAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the grouped query attention.
func (*GroupedQueryAttention[T]) MergedQKVParameter ¶ added in v0.2.1
func (gqa *GroupedQueryAttention[T]) MergedQKVParameter() *graph.Parameter[T]
MergedQKVParameter returns the merged QKV parameter for GPU upload, or nil if not set.
func (*GroupedQueryAttention[T]) OpType ¶ added in v0.2.0
func (gqa *GroupedQueryAttention[T]) OpType() string
OpType returns the operation type.
func (*GroupedQueryAttention[T]) OutputShape ¶
func (gqa *GroupedQueryAttention[T]) OutputShape() []int
OutputShape returns the output shape of the GroupedQueryAttention.
func (*GroupedQueryAttention[T]) Parameters ¶
func (gqa *GroupedQueryAttention[T]) Parameters() []*graph.Parameter[T]
Parameters returns the parameters of the GroupedQueryAttention layer.
func (*GroupedQueryAttention[T]) ScaleRope ¶ added in v0.2.0
func (gqa *GroupedQueryAttention[T]) ScaleRope(ctx context.Context, factor float64) error
ScaleRope scales the rotary positional embeddings.
func (*GroupedQueryAttention[T]) SetBidirectional ¶ added in v1.9.0
func (gqa *GroupedQueryAttention[T]) SetBidirectional(bidirectional bool)
SetBidirectional enables or disables bidirectional (non-causal) attention.
func (*GroupedQueryAttention[T]) SetBlockTableReader ¶ added in v0.2.1
func (gqa *GroupedQueryAttention[T]) SetBlockTableReader(r BlockTableReader[T])
SetBlockTableReader sets an optional BlockTableReader that provides KV data directly from paged block tables, bypassing the standard cache gather path.
func (*GroupedQueryAttention[T]) SetMergedQKV ¶ added in v0.2.1
func (gqa *GroupedQueryAttention[T]) SetMergedQKV(weight *tensor.TensorNumeric[T], qDim, kDim, vDim int)
SetMergedQKV sets a merged Q/K/V weight tensor for single-GEMV decode optimization. During decode (seqLen=1), a single MatMul with this weight replaces three separate Q/K/V projections, reducing kernel launch overhead. The output is split into Q, K, V using zero-copy GPU storage views.
func (*GroupedQueryAttention[T]) SetQKNormWeights ¶ added in v0.2.1
func (gqa *GroupedQueryAttention[T]) SetQKNormWeights(qWeight, kWeight *tensor.TensorNumeric[T], eps float32)
SetQKNormWeights stores raw RMSNorm weights for the fused QK norm+RoPE decode path. When set alongside SetQKNorms, the fused kernel replaces 4 kernel launches (Q norm, K norm, Q RoPE, K RoPE) with 1 during decode.
func (*GroupedQueryAttention[T]) SetQKNorms ¶ added in v0.2.1
func (gqa *GroupedQueryAttention[T]) SetQKNorms(qNorm, kNorm graph.Node[T])
SetQKNorms sets optional per-head RMSNorm layers for Q and K projections. Used by architectures like Gemma 3 that normalize Q/K after projection.
type LocalAttention ¶ added in v0.2.0
LocalAttention implements a local, sliding-window self-attention mechanism.
func NewLocalAttention ¶ added in v0.2.0
func NewLocalAttention[T tensor.Numeric]( engine compute.Engine[T], ops numeric.Arithmetic[T], modelDim, numQueryHeads, numKeyValueHeads, windowSize int, opts ...LocalAttentionOption[T], ) (*LocalAttention[T], error)
NewLocalAttention creates a new LocalAttention layer.
func (*LocalAttention[T]) Attributes ¶ added in v0.2.1
func (la *LocalAttention[T]) Attributes() map[string]interface{}
Attributes returns the attributes of the LocalAttention layer.
func (*LocalAttention[T]) Backward ¶ added in v0.2.0
func (la *LocalAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward delegates the backward pass to the wrapped GroupedQueryAttention.
func (*LocalAttention[T]) Forward ¶ added in v0.2.0
func (la *LocalAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the forward pass of the LocalAttention layer.
func (*LocalAttention[T]) OpType ¶ added in v0.2.1
func (la *LocalAttention[T]) OpType() string
OpType returns the operation type of the LocalAttention layer.
func (*LocalAttention[T]) OutputShape ¶ added in v0.2.0
func (la *LocalAttention[T]) OutputShape() []int
OutputShape returns the output shape of the LocalAttention layer.
func (*LocalAttention[T]) Parameters ¶ added in v0.2.0
func (la *LocalAttention[T]) Parameters() []*graph.Parameter[T]
Parameters returns the parameters of the LocalAttention layer.
type LocalAttentionOption ¶ added in v0.2.0
type LocalAttentionOption[T tensor.Numeric] func(*LocalAttentionOptions[T])
LocalAttentionOption is a function that applies an option to LocalAttentionOptions.
func WithLocalMaxSeqLen ¶ added in v0.2.0
func WithLocalMaxSeqLen[T tensor.Numeric](maxSeqLen int) LocalAttentionOption[T]
WithLocalMaxSeqLen sets the maximum sequence length for Rotary Positional Embeddings.
maxSeqLen: The maximum sequence length for Rotary Positional Embeddings.
func WithLocalRopeBase ¶ added in v0.2.0
func WithLocalRopeBase[T tensor.Numeric](base float64) LocalAttentionOption[T]
WithLocalRopeBase sets the base for Rotary Positional Embeddings.
base: The base for Rotary Positional Embeddings.
type LocalAttentionOptions ¶ added in v0.2.0
LocalAttentionOptions holds configuration options for the LocalAttention layer.
type MultiHeadLatentAttention ¶ added in v0.2.1
type MultiHeadLatentAttention[T tensor.Numeric] struct { // contains filtered or unexported fields }
MultiHeadLatentAttention implements Multi-head Latent Attention (MLA) as used in DeepSeek V3/R1. MLA compresses KV into a low-rank latent vector, dramatically reducing KV cache size.
Partial RoPE: RoPE is applied only to the first ropeHeadDim dimensions of Q and K. The remaining dimensions are position-independent, matching the DeepSeek V3 paper specification.
func NewMultiHeadLatentAttention ¶ added in v0.2.1
func NewMultiHeadLatentAttention[T tensor.Numeric]( engine compute.Engine[T], ops numeric.Arithmetic[T], numHeads, headDim, kvLoraDim, ropeHeadDim int, wQ, wDKV, wUK, wUV, wO *core.Dense[T], rope *embeddings.RotaryPositionalEmbedding[T], ) *MultiHeadLatentAttention[T]
NewMultiHeadLatentAttention creates a new MLA layer. ropeHeadDim specifies how many of the headDim dimensions receive RoPE. If ropeHeadDim <= 0 or >= headDim, RoPE is applied to all dimensions (backwards-compatible behavior).
func (*MultiHeadLatentAttention[T]) Attributes ¶ added in v0.2.1
func (m *MultiHeadLatentAttention[T]) Attributes() map[string]any
Attributes returns the layer attributes.
func (*MultiHeadLatentAttention[T]) Backward ¶ added in v0.2.1
func (m *MultiHeadLatentAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes gradients for MLA. dOut: [batch, seqLen, hidden], inputs: original input to Forward.
func (*MultiHeadLatentAttention[T]) Forward ¶ added in v0.2.1
func (m *MultiHeadLatentAttention[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the MLA forward pass. Input: [batch, seqLen, hidden] Output: [batch, seqLen, hidden]
func (*MultiHeadLatentAttention[T]) OpType ¶ added in v0.2.1
func (m *MultiHeadLatentAttention[T]) OpType() string
OpType returns the layer operation type.
func (*MultiHeadLatentAttention[T]) OutputShape ¶ added in v0.2.1
func (m *MultiHeadLatentAttention[T]) OutputShape() []int
OutputShape returns the output shape.
func (*MultiHeadLatentAttention[T]) Parameters ¶ added in v0.2.1
func (m *MultiHeadLatentAttention[T]) Parameters() []*graph.Parameter[T]
Parameters returns all trainable parameters.
type RopeScaler ¶ added in v0.2.0
type RopeScaler[T tensor.Numeric] interface {
	ScaleRope(ctx context.Context, factor float64) error
}
RopeScaler is an interface for layers that support scaling of RoPE.
type ScaledDotProductAttention ¶
type ScaledDotProductAttention[T tensor.Numeric] struct { // contains filtered or unexported fields }
ScaledDotProductAttention implements the scaled dot-product attention mechanism.
func NewBidirectionalSDPA ¶ added in v1.9.0
func NewBidirectionalSDPA[T tensor.Numeric](engine compute.Engine[T], headDim int, opts ...ScaledDotProductAttentionOption[T]) *ScaledDotProductAttention[T]
NewBidirectionalSDPA creates a ScaledDotProductAttention layer with causal masking disabled. All positions attend to all other positions, which is the attention pattern used by encoder models such as BERT.
func NewScaledDotProductAttention ¶
func NewScaledDotProductAttention[T tensor.Numeric](engine compute.Engine[T], headDim int, opts ...ScaledDotProductAttentionOption[T]) *ScaledDotProductAttention[T]
NewScaledDotProductAttention creates a new ScaledDotProductAttention layer.
func (*ScaledDotProductAttention[T]) Backward ¶
func (sdpa *ScaledDotProductAttention[T]) Backward(ctx context.Context, mode types.BackwardMode, dOut, _, _, _ *tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes the gradients for ScaledDotProductAttention. dOut is the gradient from the subsequent layer.
func (*ScaledDotProductAttention[T]) Forward ¶
func (sdpa *ScaledDotProductAttention[T]) Forward(ctx context.Context, q, k, v, mask *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the scaled dot-product attention. Q, K, V are expected to be 3D tensors (batch_size, seq_len, head_dim). mask is an optional 4D tensor (batch_size, num_heads, seq_len_q, seq_len_k).
func (*ScaledDotProductAttention[T]) SetCausal ¶ added in v0.2.1
func (sdpa *ScaledDotProductAttention[T]) SetCausal(causal bool)
SetCausal enables or disables causal (lower-triangular) masking.
type ScaledDotProductAttentionOption ¶ added in v0.2.0
type ScaledDotProductAttentionOption[T tensor.Numeric] func(*ScaledDotProductAttentionOptions[T])
ScaledDotProductAttentionOption applies an option to ScaledDotProductAttentionOptions.
func WithBidirectional ¶ added in v1.9.0
func WithBidirectional[T tensor.Numeric]() ScaledDotProductAttentionOption[T]
WithBidirectional returns an option that disables causal masking, allowing every position to attend to every other position. This is required for encoder-style models such as BERT.
type ScaledDotProductAttentionOptions ¶ added in v0.2.0
type ScaledDotProductAttentionOptions[T tensor.Numeric] struct { // contains filtered or unexported fields }
ScaledDotProductAttentionOptions holds configuration options for ScaledDotProductAttention.