compute

package
v0.2.0 Latest
Warning

This package is not in the latest version of its module.

Published: Mar 16, 2026 License: Apache-2.0 Imports: 26 Imported by: 0

Documentation

Overview

Package compute implements tensor computation engines and operations.

Index

Constants

This section is empty.

Variables

var ErrMemoryLimitExceeded = errors.New("memory limit exceeded")

ErrMemoryLimitExceeded is returned when a tensor allocation would exceed the configured memory limit.

Functions

func FusedRMSNorm

func FusedRMSNorm(input, weight *tensor.TensorNumeric[float32], epsilon float32) (output, scales *tensor.TensorNumeric[float32], err error)

FusedRMSNorm computes x * rsqrt(mean(x^2) + eps) * weight in a single pass, avoiding materialization of the squared, mean, and rsqrt intermediate tensors. Input shape: [..., D] where D is the last dimension (hidden size). Weight shape: [D]. Returns (output, scales), where output has the same shape as input and scales has shape [..., 1] containing the per-row rsqrt(mean(x^2)+eps) values.
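
The fused formula can be sketched on a plain float32 slice. This is a reference for the math only, not this package's API; rmsNormRow is an illustrative name.

```go
package main

import (
	"fmt"
	"math"
)

// rmsNormRow mirrors the documented formula for one row of the input:
// out[i] = x[i] * rsqrt(mean(x^2) + eps) * weight[i].
// The returned scale is the per-row rsqrt(mean(x^2)+eps) value, matching
// the second result of FusedRMSNorm.
func rmsNormRow(x, weight []float32, eps float32) (out []float32, scale float32) {
	var sumSq float64
	for _, v := range x {
		sumSq += float64(v) * float64(v)
	}
	mean := sumSq / float64(len(x))
	scale = float32(1 / math.Sqrt(mean+float64(eps)))
	out = make([]float32, len(x))
	for i, v := range x {
		out[i] = v * scale * weight[i]
	}
	return out, scale
}

func main() {
	// For x = {3, 4}, mean(x^2) = 12.5, so scale = 1/sqrt(12.5).
	out, scale := rmsNormRow([]float32{3, 4}, []float32{1, 1}, 0)
	fmt.Println(out, scale)
}
```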

func FusedRoPE

func FusedRoPE(input, cosAngles, sinAngles *tensor.TensorNumeric[float32], rotaryDim int) (*tensor.TensorNumeric[float32], error)

FusedRoPE applies rotary position embeddings in a single pass. Input shape: [batch, seq_len, head_dim] where head_dim is even. cos/sin shape: [seq_len, half_dim] (precomputed angles). rotaryDim: number of dimensions that receive rotation (<= head_dim, must be even). For each position (b, s):

out[..., i]            = in[..., i] * cos[s,i] - in[..., i+half] * sin[s,i]      (i < half)
out[..., i+half]       = in[..., i+half] * cos[s,i] + in[..., i] * sin[s,i]      (i < half)
out[..., rotaryDim..]  = in[..., rotaryDim..]                                      (pass-through)
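
The per-position rotation above can be checked with a plain-slice sketch; ropeRow is an illustrative name, not part of this API, and half is assumed to be rotaryDim/2 per the formulas. Rotating the pair (1, 0) by pi/2 lands near (0, 1) while dimensions past rotaryDim pass through.

```go
package main

import (
	"fmt"
	"math"
)

// ropeRow applies the documented rotation to one position's head vector.
// cos and sin hold the precomputed angles for this position, each of
// length half = rotaryDim/2. Indices >= rotaryDim pass through unchanged.
func ropeRow(in, cos, sin []float32, rotaryDim int) []float32 {
	out := make([]float32, len(in))
	copy(out, in) // pass-through for i >= rotaryDim
	half := rotaryDim / 2
	for i := 0; i < half; i++ {
		out[i] = in[i]*cos[i] - in[i+half]*sin[i]
		out[i+half] = in[i+half]*cos[i] + in[i]*sin[i]
	}
	return out
}

func main() {
	theta := math.Pi / 2
	cos := []float32{float32(math.Cos(theta))}
	sin := []float32{float32(math.Sin(theta))}
	// First pair rotated ~90 degrees; the last two dimensions pass through.
	fmt.Println(ropeRow([]float32{1, 0, 5, 6}, cos, sin, 2))
}
```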

func FusedSiLUGate

func FusedSiLUGate(gate, up *tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error)

FusedSiLUGate computes silu(gate) * up in a single element-wise pass. SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)). gate and up must have the same shape. This avoids materializing separate sigmoid and multiplication intermediate tensors.
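
The gating formula can be sketched on plain slices; siluGate is an illustrative name, not part of this API.

```go
package main

import (
	"fmt"
	"math"
)

// siluGate mirrors the fused computation on plain slices:
// out[i] = silu(gate[i]) * up[i], with silu(x) = x / (1 + exp(-x)).
func siluGate(gate, up []float32) []float32 {
	out := make([]float32, len(gate))
	for i, g := range gate {
		silu := g / (1 + float32(math.Exp(float64(-g))))
		out[i] = silu * up[i]
	}
	return out
}

func main() {
	// silu(0) = 0, and silu(x) approaches x for large positive x.
	fmt.Println(siluGate([]float32{0, 10}, []float32{3, 1}))
}
```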

Types

type CPUEngine

type CPUEngine[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

CPUEngine is a CPU-based implementation of the Engine interface.

func NewCPUEngine

func NewCPUEngine[T tensor.Numeric](ops numeric.Arithmetic[T]) *CPUEngine[T]

NewCPUEngine constructs a new CPUEngine for the given numeric operations. A no-op logger and no-op collector are used by default; call SetLogger/SetCollector to override.

func (*CPUEngine[T]) Add

func (e *CPUEngine[T]) Add(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Add performs element-wise addition with broadcasting.

func (*CPUEngine[T]) AddScalar

func (e *CPUEngine[T]) AddScalar(_ context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

AddScalar adds a scalar to each element of a tensor.

func (*CPUEngine[T]) Close

func (e *CPUEngine[T]) Close(_ context.Context) error

Close is a no-op for CPUEngine. It satisfies the shutdown.Closer interface.

func (*CPUEngine[T]) Concat

func (e *CPUEngine[T]) Concat(_ context.Context, tensors []*tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Concat concatenates a list of tensors along a given axis.

func (*CPUEngine[T]) Copy

func (e *CPUEngine[T]) Copy(_ context.Context, dst, src *tensor.TensorNumeric[T]) error

Copy copies src into dst; shapes must match.

func (*CPUEngine[T]) Cos

func (e *CPUEngine[T]) Cos(_ context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Cos computes the element-wise cosine of a tensor.

func (*CPUEngine[T]) Div

func (e *CPUEngine[T]) Div(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Div performs element-wise division with broadcasting. For integer types, division by zero returns an error.

func (*CPUEngine[T]) DivScalar

func (e *CPUEngine[T]) DivScalar(_ context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

DivScalar divides a tensor by a scalar value element-wise.

func (*CPUEngine[T]) Exp

func (e *CPUEngine[T]) Exp(_ context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Exp computes the element-wise exponential of a tensor.

func (*CPUEngine[T]) Fill

func (e *CPUEngine[T]) Fill(_ context.Context, t *tensor.TensorNumeric[T], value T) error

Fill sets all elements of t to value.

func (*CPUEngine[T]) Gather

func (e *CPUEngine[T]) Gather(_ context.Context, params *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], output *tensor.TensorNumeric[T]) error

Gather performs an embedding-style gather. params must be 2D [vocab, dim]. indices may be 1D [N] or 2D [batch, seq]. output must be [indices..., dim], i.e., [N, dim] or [batch, seq, dim].
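
The lookup semantics can be sketched on flattened row-major slices; gather2D is an illustrative name, not part of this API.

```go
package main

import "fmt"

// gather2D is a plain-slice reference for the embedding lookup: params is a
// [vocab, dim] table flattened row-major, indices is [N], and the result is
// the [N, dim] rows copied out in index order.
func gather2D(params []float32, dim int, indices []int) []float32 {
	out := make([]float32, 0, len(indices)*dim)
	for _, idx := range indices {
		out = append(out, params[idx*dim:(idx+1)*dim]...)
	}
	return out
}

func main() {
	// A 3-row table with dim=2; look up rows 2 and 0, in that order.
	params := []float32{10, 11, 20, 21, 30, 31}
	fmt.Println(gather2D(params, 2, []int{2, 0}))
}
```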

func (*CPUEngine[T]) Log

func (e *CPUEngine[T]) Log(_ context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Log computes the element-wise natural logarithm of a tensor.

func (*CPUEngine[T]) MatMul

func (e *CPUEngine[T]) MatMul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

MatMul performs matrix multiplication of two tensors.

func (*CPUEngine[T]) MemoryTracker

func (e *CPUEngine[T]) MemoryTracker() *MemoryTracker

MemoryTracker returns the engine's memory tracker.

func (*CPUEngine[T]) Mul

func (e *CPUEngine[T]) Mul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Mul performs element-wise multiplication with broadcasting.

func (*CPUEngine[T]) MulScalar

func (e *CPUEngine[T]) MulScalar(_ context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

MulScalar performs element-wise multiplication of a tensor by a scalar.

func (*CPUEngine[T]) OneHot

func (e *CPUEngine[T]) OneHot(_ context.Context, input *tensor.TensorNumeric[int], depth int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

OneHot creates a one-hot encoding of the input tensor.
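
The encoding can be sketched on plain slices; oneHot is an illustrative name, not part of this API.

```go
package main

import "fmt"

// oneHot is a plain-slice sketch of the encoding: for input of length N it
// produces an [N, depth] row-major matrix with a single 1 in each row, at
// the column given by that row's class index.
func oneHot(input []int, depth int) []float32 {
	out := make([]float32, len(input)*depth)
	for i, v := range input {
		out[i*depth+v] = 1
	}
	return out
}

func main() {
	// Classes 1 and 0 with depth 3 give rows [0 1 0] and [1 0 0].
	fmt.Println(oneHot([]int{1, 0}, 3))
}
```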

func (*CPUEngine[T]) Ops

func (e *CPUEngine[T]) Ops() numeric.Arithmetic[T]

Ops returns the arithmetic ops for this engine.

func (*CPUEngine[T]) Pow

func (e *CPUEngine[T]) Pow(
	ctx context.Context,
	base, exponent *tensor.TensorNumeric[T],
	dst ...*tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error)

Pow raises each element of a tensor to the power of the corresponding element in another tensor.

func (*CPUEngine[T]) RandomUniform

func (e *CPUEngine[T]) RandomUniform(_ context.Context, t *tensor.TensorNumeric[T], minVal, maxVal T) error

RandomUniform fills t with random values between minVal and maxVal.

func (*CPUEngine[T]) ReduceMean

func (e *CPUEngine[T]) ReduceMean(
	ctx context.Context,
	a *tensor.TensorNumeric[T],
	axis int,
	keepDims bool,
	dst ...*tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error)

ReduceMean calculates the mean of elements along a specified axis.

func (*CPUEngine[T]) ReduceSum

func (e *CPUEngine[T]) ReduceSum(
	ctx context.Context,
	a *tensor.TensorNumeric[T],
	axis int,
	keepDims bool,
	dst ...*tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error)

ReduceSum delegates to Sum for reduction along an axis.

func (*CPUEngine[T]) Repeat

func (e *CPUEngine[T]) Repeat(_ context.Context, a *tensor.TensorNumeric[T], axis int, repetitions int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Repeat repeats the input tensor along a given axis a specified number of times.

func (*CPUEngine[T]) Reshape

func (e *CPUEngine[T]) Reshape(_ context.Context, a *tensor.TensorNumeric[T], shape []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Reshape changes the shape of a tensor without changing its data.

func (*CPUEngine[T]) Rsqrt

func (e *CPUEngine[T]) Rsqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Rsqrt computes the element-wise reciprocal square root of a tensor.

func (*CPUEngine[T]) ScatterAdd

func (e *CPUEngine[T]) ScatterAdd(_ context.Context, dEmbeddingTable *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], dOut *tensor.TensorNumeric[T]) error

ScatterAdd performs a row-wise scatter-add for embeddings. dEmbeddingTable must be [vocab, dim]. indices may be 1D [N] or multi-dim with flattened length N. dOut must be [N, dim]. For each i in [0..N), it applies: dEmbeddingTable[indices[i], :] += dOut[i, :].
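
The update rule can be sketched on plain slices; scatterAddRows is an illustrative name, not part of this API. Note that repeated indices accumulate, which is what embedding-gradient updates require.

```go
package main

import "fmt"

// scatterAddRows mirrors the documented update on a [vocab, dim] table
// flattened row-major: table[indices[i], :] += dOut[i, :] for each i.
func scatterAddRows(table []float32, dim int, indices []int, dOut []float32) {
	for i, idx := range indices {
		row := table[idx*dim : (idx+1)*dim]
		for j := range row {
			row[j] += dOut[i*dim+j]
		}
	}
}

func main() {
	table := make([]float32, 3*2) // [vocab=3, dim=2], zero-initialized
	// Both gradient rows target index 1, so row 1 accumulates them.
	scatterAddRows(table, 2, []int{1, 1}, []float32{1, 2, 3, 4})
	fmt.Println(table)
}
```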

func (*CPUEngine[T]) SetCollector

func (e *CPUEngine[T]) SetCollector(c metrics.Collector)

SetCollector replaces the engine's metrics collector.

func (*CPUEngine[T]) SetLogger

func (e *CPUEngine[T]) SetLogger(l log.Logger)

SetLogger replaces the engine's logger.

func (*CPUEngine[T]) SetMemoryLimit

func (e *CPUEngine[T]) SetMemoryLimit(bytes int64)

SetMemoryLimit configures the maximum number of bytes this engine may allocate for tensors. A limit of 0 disables enforcement.

func (*CPUEngine[T]) Sin

func (e *CPUEngine[T]) Sin(_ context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Sin computes the element-wise sine of a tensor.

func (*CPUEngine[T]) Softmax

func (e *CPUEngine[T]) Softmax(_ context.Context, a *tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Softmax applies the softmax function to a tensor along a given axis. If axis is negative, it is interpreted relative to the last axis (e.g., -1 means last axis).
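
A negative axis resolves as axis+ndim before reducing. The per-row math can be sketched in plain Go using the usual numerically stable formulation (subtracting the row max is an assumption about implementations generally, not a claim about this engine); softmaxRow is an illustrative name.

```go
package main

import (
	"fmt"
	"math"
)

// softmaxRow applies a numerically stable softmax to one row: subtracting
// the row max before exponentiating leaves the result unchanged
// mathematically but avoids overflow for large logits.
func softmaxRow(x []float32) []float32 {
	maxV := x[0]
	for _, v := range x[1:] {
		if v > maxV {
			maxV = v
		}
	}
	var sum float64
	exps := make([]float64, len(x))
	for i, v := range x {
		exps[i] = math.Exp(float64(v - maxV))
		sum += exps[i]
	}
	out := make([]float32, len(x))
	for i := range out {
		out[i] = float32(exps[i] / sum)
	}
	return out
}

func main() {
	// Equal logits produce a uniform distribution.
	fmt.Println(softmaxRow([]float32{1, 1, 1, 1}))
}
```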

func (*CPUEngine[T]) Split

func (e *CPUEngine[T]) Split(_ context.Context, a *tensor.TensorNumeric[T], numSplits int, axis int) ([]*tensor.TensorNumeric[T], error)

Split splits a tensor into numSplits along the given axis. All splits are equal-sized; shape[axis] must be divisible by numSplits.

func (*CPUEngine[T]) Sqrt

func (e *CPUEngine[T]) Sqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Sqrt computes the element-wise square root of a tensor.

func (*CPUEngine[T]) Sub

func (e *CPUEngine[T]) Sub(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Sub performs element-wise subtraction with broadcasting.

func (*CPUEngine[T]) Sum

func (e *CPUEngine[T]) Sum(
	_ context.Context,
	a *tensor.TensorNumeric[T],
	axis int,
	keepDims bool,
	dst ...*tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error)

Sum computes the sum of tensor elements along the specified axis. If keepDims is true, the reduced dimensions are retained with size 1. An optional destination tensor can be provided to store the result.
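
The reduction can be sketched for axis 0 of a 2D tensor on a flattened row-major slice; sumAxis0 is an illustrative name, not part of this API.

```go
package main

import "fmt"

// sumAxis0 reduces a [rows, cols] row-major matrix along axis 0. The result
// holds cols values; with keepDims=true the shape would be reported as
// [1, cols], but the data is identical.
func sumAxis0(data []float32, rows, cols int) []float32 {
	out := make([]float32, cols)
	for r := 0; r < rows; r++ {
		for c := 0; c < cols; c++ {
			out[c] += data[r*cols+c]
		}
	}
	return out
}

func main() {
	// [[1 2 3], [4 5 6]] summed over rows gives one value per column.
	fmt.Println(sumAxis0([]float32{1, 2, 3, 4, 5, 6}, 2, 3))
}
```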

func (*CPUEngine[T]) Tanh

func (e *CPUEngine[T]) Tanh(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Tanh applies the hyperbolic tangent activation element-wise.

func (*CPUEngine[T]) TanhPrime

func (e *CPUEngine[T]) TanhPrime(ctx context.Context, a, upstream *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

TanhPrime computes tanh'(a) * upstream element-wise.
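
Since d/dx tanh(x) = 1 - tanh(x)^2, the backward step can be sketched on plain slices; tanhPrime here is an illustrative reference, not this engine's implementation.

```go
package main

import (
	"fmt"
	"math"
)

// tanhPrime mirrors the documented computation element-wise:
// out[i] = (1 - tanh(a[i])^2) * upstream[i], the chain-rule step
// for backpropagating through tanh.
func tanhPrime(a, upstream []float32) []float32 {
	out := make([]float32, len(a))
	for i, v := range a {
		t := math.Tanh(float64(v))
		out[i] = float32((1 - t*t) * float64(upstream[i]))
	}
	return out
}

func main() {
	// tanh'(0) = 1, so the upstream gradient passes through unchanged at 0.
	fmt.Println(tanhPrime([]float32{0}, []float32{2.5}))
}
```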

func (*CPUEngine[T]) Transpose

func (e *CPUEngine[T]) Transpose(_ context.Context, a *tensor.TensorNumeric[T], axes []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Transpose transposes a tensor along the given axes.

func (*CPUEngine[T]) UnaryOp

func (e *CPUEngine[T]) UnaryOp(ctx context.Context, a *tensor.TensorNumeric[T], op func(T) T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

UnaryOp applies a unary element-wise operation.

func (*CPUEngine[T]) Zero

func (e *CPUEngine[T]) Zero(_ context.Context, a *tensor.TensorNumeric[T]) error

Zero sets all elements of tensor a to zero.

func (*CPUEngine[T]) Zeros

func (e *CPUEngine[T]) Zeros(ctx context.Context, a *tensor.TensorNumeric[T], shape []int) error

Zeros fills the tensor with zeros. If shape is provided, (re)allocates to that shape.

type DType

type DType int

DType selects the compute precision for GPU operations.

const (
	// DTypeF32 uses float32 for all compute (default).
	DTypeF32 DType = iota
	// DTypeFP16 uses FP16 for elementwise ops and MatMul.
	// Activations are converted F32->FP16 before compute and FP16->F32 after.
	// Reductions (RMSNorm, Softmax) accumulate in FP32 for precision.
	DTypeFP16

	// DTypeFP8 uses FP8 E4M3 weights with FP16 compute for element-wise ops.
	// Weights are quantized to FP8 at load time, dequantized to FP16 on GPU.
	// MatMul uses cublasLtMatmul (auto-detected via FP8E4M3Storage).
	DTypeFP8
)

type Engine

type Engine[T tensor.Numeric] interface {
	// Ops returns the numeric.Arithmetic operations for the engine's numeric type.
	Ops() numeric.Arithmetic[T]
	// UnaryOp applies a unary function `op` to each element of tensor `a`.
	// It returns a new tensor with the results.
	// Returns an error if the input tensor is nil.
	UnaryOp(ctx context.Context, a *tensor.TensorNumeric[T], op func(T) T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Add performs element-wise addition of two tensors, with support for broadcasting.
	// It returns a new tensor with the results.
	// Returns an error if tensors are nil or their shapes are not compatible for broadcasting.
	Add(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Sub performs element-wise subtraction of two tensors, with support for broadcasting.
	// It returns a new tensor with the results.
	// Returns an error if tensors are nil or their shapes are not compatible for broadcasting.
	Sub(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Mul performs element-wise multiplication of two tensors, with support for broadcasting.
	// It returns a new tensor with the results.
	// Returns an error if tensors are nil or their shapes are not compatible for broadcasting.
	Mul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Div performs element-wise division of two tensors, with support for broadcasting.
	// It returns a new tensor with the results.
	// Returns an error if tensors are nil or their shapes are not compatible for broadcasting.
	Div(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// MatMul performs matrix multiplication of two 2D tensors.
	// It returns a new tensor with the result.
	// Returns an error if the tensors are nil, not 2D, or their shapes are incompatible for matrix multiplication.
	MatMul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Transpose transposes a tensor along the given axes.
	// It returns a new tensor with the result.
	// Returns an error if the tensor is nil or the axes are invalid.
	Transpose(ctx context.Context, a *tensor.TensorNumeric[T], axes []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Sum calculates the sum of elements along a specified axis.
	// A negative axis means summing along all axes, returning a scalar tensor.
	// If keepDims is true, the reduced dimensions are retained with size 1.
	// Returns a new tensor with the reduced shape.
	// Returns an error if the tensor is nil or the axis is out of bounds.
	Sum(
		ctx context.Context,
		a *tensor.TensorNumeric[T],
		axis int,
		keepDims bool,
		dst ...*tensor.TensorNumeric[T],
	) (*tensor.TensorNumeric[T], error)

	// Exp computes the element-wise exponential of a tensor.
	Exp(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Log computes the element-wise natural logarithm of a tensor.
	Log(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Sin computes the element-wise sine of a tensor.
	Sin(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Cos computes the element-wise cosine of a tensor.
	Cos(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Tanh applies the hyperbolic tangent activation function element-wise.
	Tanh(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// TanhPrime computes the element-wise gradient of tanh at `a` multiplied by `upstream`.
	// This is useful for backpropagation where `upstream` is dL/dy and the result is dL/dx.
	TanhPrime(ctx context.Context, a, upstream *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Pow raises each element of a tensor to the power of the corresponding element in another tensor.
	Pow(ctx context.Context, base, exponent *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Zero sets all elements of a tensor to zero.
	Zero(ctx context.Context, a *tensor.TensorNumeric[T]) error

	// Zeros fills the tensor with zeros. If a shape is provided, the tensor is reallocated to that shape.
	Zeros(ctx context.Context, a *tensor.TensorNumeric[T], shape []int) error

	// Copy copies the data from one tensor to another.
	Copy(ctx context.Context, dst, src *tensor.TensorNumeric[T]) error

	// Gather performs an embedding-style gather.
	// params must be 2D [vocab, dim].
	// indices may be 1D [N] or 2D [batch, seq].
	// output must be [indices..., dim], i.e., [N, dim] or [batch, seq, dim].
	Gather(ctx context.Context, params *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], output *tensor.TensorNumeric[T]) error

	// ScatterAdd performs a row-wise scatter-add for embeddings.
	// dEmbeddingTable must be [vocab, dim].
	// indices may be 1D [N] or multi-dim with flattened length N.
	// dOut must be [N, dim].
	// For each i in [0..N), it applies: dEmbeddingTable[indices[i], :] += dOut[i, :].
	ScatterAdd(ctx context.Context, dEmbeddingTable *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], dOut *tensor.TensorNumeric[T]) error

	// RandomUniform fills the tensor with random values from a uniform distribution.
	RandomUniform(ctx context.Context, t *tensor.TensorNumeric[T], minVal, maxVal T) error

	// Fill fills the tensor with a scalar value.
	Fill(ctx context.Context, t *tensor.TensorNumeric[T], value T) error

	// MulScalar performs element-wise multiplication of a tensor by a scalar.
	MulScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// DivScalar performs element-wise division of a tensor by a scalar.
	DivScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Softmax applies the softmax function to a tensor along a given axis.
	Softmax(ctx context.Context, a *tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// ReduceSum calculates the sum of elements along a specified axis. It mirrors Sum;
	// implementations may delegate to Sum or use a reduction-specific code path.
	ReduceSum(
		ctx context.Context,
		a *tensor.TensorNumeric[T],
		axis int,
		keepDims bool,
		dst ...*tensor.TensorNumeric[T],
	) (*tensor.TensorNumeric[T], error)

	// AddScalar adds a scalar to each element of a tensor.
	AddScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Sqrt computes the element-wise square root of a tensor.
	Sqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Split splits a tensor into multiple tensors along a given axis.
	Split(ctx context.Context, a *tensor.TensorNumeric[T], numSplits int, axis int) ([]*tensor.TensorNumeric[T], error)

	// Concat concatenates a list of tensors along a given axis.
	Concat(ctx context.Context, tensors []*tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Repeat repeats the input tensor along a given axis a specified number of times.
	Repeat(
		ctx context.Context,
		a *tensor.TensorNumeric[T],
		axis int,
		repetitions int,
		dst ...*tensor.TensorNumeric[T],
	) (*tensor.TensorNumeric[T], error)

	// OneHot creates a one-hot encoding of the input tensor.
	OneHot(ctx context.Context, input *tensor.TensorNumeric[int], depth int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Reshape changes the shape of a tensor without changing its data.
	Reshape(ctx context.Context, a *tensor.TensorNumeric[T], shape []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// ReduceMean calculates the mean of elements along a specified axis.
	ReduceMean(
		ctx context.Context,
		a *tensor.TensorNumeric[T],
		axis int,
		keepDims bool,
		dst ...*tensor.TensorNumeric[T],
	) (*tensor.TensorNumeric[T], error)

	// Rsqrt computes the element-wise reciprocal square root of a tensor.
	Rsqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
}

Engine defines the interface for a computation engine (e.g., CPU, GPU). All tensor operations should be routed through an Engine implementation to ensure hardware interoperability and optimized performance.

type EngineProxy

type EngineProxy[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

EngineProxy wraps an Engine[T] and optionally records traced operations.

func NewEngineProxy

func NewEngineProxy[T tensor.Numeric](real Engine[T]) *EngineProxy[T]

NewEngineProxy creates a new EngineProxy wrapping the given engine.

func (*EngineProxy[T]) Add

func (p *EngineProxy[T]) Add(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) AddScalar

func (p *EngineProxy[T]) AddScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) ArenaUsedBytes

func (p *EngineProxy[T]) ArenaUsedBytes() int

ArenaUsedBytes returns the current arena offset from the underlying engine.

func (*EngineProxy[T]) Concat

func (p *EngineProxy[T]) Concat(ctx context.Context, tensors []*tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) Copy

func (p *EngineProxy[T]) Copy(ctx context.Context, dst, src *tensor.TensorNumeric[T]) error

func (*EngineProxy[T]) Cos

func (p *EngineProxy[T]) Cos(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) Div

func (p *EngineProxy[T]) Div(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) DivScalar

func (p *EngineProxy[T]) DivScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) Exp

func (p *EngineProxy[T]) Exp(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) Fill

func (p *EngineProxy[T]) Fill(ctx context.Context, t *tensor.TensorNumeric[T], value T) error

func (*EngineProxy[T]) FusedRMSNormGPU

func (p *EngineProxy[T]) FusedRMSNormGPU(input, weight *tensor.TensorNumeric[float32], epsilon float32) (*tensor.TensorNumeric[float32], *tensor.TensorNumeric[float32], error)

FusedRMSNormGPU delegates to the underlying engine if it implements FusedRMSNormer.

func (*EngineProxy[T]) GPUFusedAddRMSNorm

func (p *EngineProxy[T]) GPUFusedAddRMSNorm(input, residual, weight *tensor.TensorNumeric[T], eps float32) (
	normed *tensor.TensorNumeric[T],
	residualOut *tensor.TensorNumeric[T],
	scales *tensor.TensorNumeric[T],
	err error,
)

GPUFusedAddRMSNorm delegates to the underlying engine's FusedAddRMSNormProvider implementation and records the operation for tracing. This allows fusedAddRMSNormNode to call through the proxy without unwrapping it, which is required for CompileTraced to capture the operation.

func (*EngineProxy[T]) Gather

func (p *EngineProxy[T]) Gather(ctx context.Context, params *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], output *tensor.TensorNumeric[T]) error

func (*EngineProxy[T]) Log

func (p *EngineProxy[T]) Log(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) MatMul

func (p *EngineProxy[T]) MatMul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) MatMulTransposeB

func (p *EngineProxy[T]) MatMulTransposeB(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

MatMulTransposeB delegates to the underlying engine if it implements TransposeBMatMuler.

func (*EngineProxy[T]) Mul

func (p *EngineProxy[T]) Mul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) MulScalar

func (p *EngineProxy[T]) MulScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) OneHot

func (p *EngineProxy[T]) OneHot(ctx context.Context, input *tensor.TensorNumeric[int], depth int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) Ops

func (p *EngineProxy[T]) Ops() numeric.Arithmetic[T]

func (*EngineProxy[T]) Pow

func (p *EngineProxy[T]) Pow(ctx context.Context, base, exponent *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) RandomUniform

func (p *EngineProxy[T]) RandomUniform(ctx context.Context, t *tensor.TensorNumeric[T], minVal, maxVal T) error

func (*EngineProxy[T]) Real

func (p *EngineProxy[T]) Real() Engine[T]

Real returns the underlying engine.

func (*EngineProxy[T]) ReduceMean

func (p *EngineProxy[T]) ReduceMean(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) ReduceSum

func (p *EngineProxy[T]) ReduceSum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) Repeat

func (p *EngineProxy[T]) Repeat(ctx context.Context, a *tensor.TensorNumeric[T], axis int, repetitions int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) ResetPool

func (p *EngineProxy[T]) ResetPool()

ResetPool delegates to the underlying engine if it implements PoolResetter.

func (*EngineProxy[T]) Reshape

func (p *EngineProxy[T]) Reshape(ctx context.Context, a *tensor.TensorNumeric[T], shape []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) Rsqrt

func (p *EngineProxy[T]) Rsqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) ScatterAdd

func (p *EngineProxy[T]) ScatterAdd(ctx context.Context, dEmbeddingTable *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], dOut *tensor.TensorNumeric[T]) error

func (*EngineProxy[T]) SetArenaResetFloor

func (p *EngineProxy[T]) SetArenaResetFloor(floor int)

SetArenaResetFloor sets the minimum reset offset on the underlying engine.

func (*EngineProxy[T]) Sin

func (p *EngineProxy[T]) Sin(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) Softmax

func (p *EngineProxy[T]) Softmax(ctx context.Context, a *tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) Split

func (p *EngineProxy[T]) Split(ctx context.Context, a *tensor.TensorNumeric[T], numSplits int, axis int) ([]*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) Sqrt

func (p *EngineProxy[T]) Sqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) StartTracing

func (p *EngineProxy[T]) StartTracing(tracer TraceRecorder[T])

StartTracing enables tracing with the given recorder.

func (*EngineProxy[T]) StopTracing

func (p *EngineProxy[T]) StopTracing()

StopTracing disables tracing.

func (*EngineProxy[T]) Sub

func (p *EngineProxy[T]) Sub(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) Sum

func (p *EngineProxy[T]) Sum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) Tanh

func (p *EngineProxy[T]) Tanh(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) TanhPrime

func (p *EngineProxy[T]) TanhPrime(ctx context.Context, a, upstream *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) Transpose

func (p *EngineProxy[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T], axes []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) UnaryOp

func (p *EngineProxy[T]) UnaryOp(ctx context.Context, a *tensor.TensorNumeric[T], op func(T) T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*EngineProxy[T]) Zero

func (p *EngineProxy[T]) Zero(ctx context.Context, a *tensor.TensorNumeric[T]) error

func (*EngineProxy[T]) Zeros

func (p *EngineProxy[T]) Zeros(ctx context.Context, a *tensor.TensorNumeric[T], shape []int) error

type FP16ToF32Converter

type FP16ToF32Converter interface {
	ConvertFP16ToF32(t *tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error)
}

FP16ToF32Converter is an optional interface for engines that can convert a tensor with Float16Storage to a regular float32 GPU tensor. This is used at the end of the FP16 forward pass to produce F32 logits for sampling.

type FailableTensor

type FailableTensor[T tensor.Numeric] struct {
	*tensor.TensorNumeric[T]
	// contains filtered or unexported fields
}

FailableTensor wraps a tensor and can be configured to fail on specific operations.

func NewFailableTensor

func NewFailableTensor[T tensor.Numeric](t *tensor.TensorNumeric[T]) *FailableTensor[T]

NewFailableTensor creates a new FailableTensor wrapper.

func (*FailableTensor[T]) Set

func (f *FailableTensor[T]) Set(value T, indices ...int) error

Set overrides the tensor's Set method to allow controlled failures.

func (*FailableTensor[T]) SetFailOnSet

func (f *FailableTensor[T]) SetFailOnSet(fail bool)

SetFailOnSet configures the tensor to fail on Set operations.

func (*FailableTensor[T]) SetFailOnSetAfter

func (f *FailableTensor[T]) SetFailOnSetAfter(count int)

SetFailOnSetAfter configures the tensor to fail after N Set calls.

type FailableZeroer

type FailableZeroer[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

FailableZeroer can be configured to fail on Zero operations.

func NewFailableZeroer

func NewFailableZeroer[T tensor.Numeric](engine *TestableEngine[T]) *FailableZeroer[T]

NewFailableZeroer creates a new FailableZeroer.

func (*FailableZeroer[T]) SetFailOnZero

func (f *FailableZeroer[T]) SetFailOnZero(fail bool)

SetFailOnZero configures the zeroer to fail on Zero operations.

func (*FailableZeroer[T]) Zero

func (f *FailableZeroer[T]) Zero(ctx context.Context, a *tensor.TensorNumeric[T]) error

Zero performs the zero operation with controlled failure capability.

type FusedAddRMSNormProvider

type FusedAddRMSNormProvider[T tensor.Numeric] interface {
	// GPUFusedAddRMSNorm computes:
	//   sum    = input + residual
	//   normed = rmsnorm(sum, weight, eps)
	// Both inputs are read-only. Returns (normalized, sum, scales, error).
	GPUFusedAddRMSNorm(input, residual, weight *tensor.TensorNumeric[T], eps float32) (
		normed *tensor.TensorNumeric[T],
		residualOut *tensor.TensorNumeric[T],
		scales *tensor.TensorNumeric[T],
		err error,
	)
}

FusedAddRMSNormProvider is implemented by engines that support fused residual-add + RMS normalization in a single GPU kernel launch. This eliminates one kernel launch per fusion point (2 per transformer layer).
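The fused kernel's math is equivalent to the following unfused CPU reference on a single row (a sketch on plain float32 slices, not the package's tensor types; the function name is illustrative):

```go
package main

import (
	"fmt"
	"math"
)

// addRMSNormRef is an unfused reference for GPUFusedAddRMSNorm on one
// row: sum = input + residual, then RMSNorm(sum, weight, eps).
// Returns (normed, sum, scale) matching the fused kernel's outputs.
func addRMSNormRef(input, residual, weight []float32, eps float32) (normed, sum []float32, scale float32) {
	n := len(input)
	sum = make([]float32, n)
	var sq float64
	for i := range input {
		sum[i] = input[i] + residual[i]
		sq += float64(sum[i]) * float64(sum[i])
	}
	scale = float32(1.0 / math.Sqrt(sq/float64(n)+float64(eps)))
	normed = make([]float32, n)
	for i := range sum {
		normed[i] = sum[i] * scale * weight[i]
	}
	return normed, sum, scale
}

func main() {
	normed, sum, scale := addRMSNormRef(
		[]float32{1, 2}, []float32{1, 0}, []float32{1, 1}, 0)
	fmt.Println(normed, sum, scale) // [1 1] [2 2] 0.5
}
```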

type FusedNormAddProvider

type FusedNormAddProvider[T tensor.Numeric] interface {
	// GPUFusedNormAdd computes:
	//   normed = rmsnorm(input, weight, eps)
	//   output = normed + residual
	// All inputs are read-only. Returns (output, error).
	GPUFusedNormAdd(input, weight, residual *tensor.TensorNumeric[T], eps float32) (*tensor.TensorNumeric[T], error)
}

FusedNormAddProvider is implemented by engines that support fused RMSNorm + elementwise Add in a single GPU kernel launch. output = rmsnorm(input, weight, eps) + residual. This eliminates one kernel launch per fusion point.

type FusedQKNormRoPEProvider

type FusedQKNormRoPEProvider[T tensor.Numeric] interface {
	// GPUFusedQKNormRoPE applies per-head RMSNorm + RoPE to combined Q+K data.
	// input: [totalHeads, headDim] (Q heads then K heads, contiguous).
	// weightQ/weightK: [headDim] RMSNorm weights.
	// cosAngles/sinAngles: [halfRotary] precomputed angles for current position.
	// Returns output: [totalHeads, headDim].
	GPUFusedQKNormRoPE(
		input *tensor.TensorNumeric[T],
		weightQ, weightK *tensor.TensorNumeric[T],
		cosAngles, sinAngles *tensor.TensorNumeric[T],
		eps float32,
		totalHeads, headDim, numQHeads, halfRotary int,
	) (*tensor.TensorNumeric[T], error)
}

FusedQKNormRoPEProvider is implemented by engines that support fused per-head QK RMSNorm + RoPE in a single GPU kernel launch. This replaces 4 kernel launches (Q_norm + K_norm + Q_RoPE + K_RoPE) with 1 per GQA layer during decode.

type FusedRMSNormer

type FusedRMSNormer interface {
	FusedRMSNormGPU(input, weight *tensor.TensorNumeric[float32], epsilon float32) (output, scales *tensor.TensorNumeric[float32], err error)
}

FusedRMSNormer is an optional interface for engines that support GPU-accelerated fused RMSNorm. Layers can type-assert to this to use the fused kernel. Returns (output, scales) where scales contains per-row rsqrt values for backward pass.
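The "type-assert to use the fused kernel" pattern looks like this (a self-contained sketch; Tensor and the two engine types are stubs standing in for the real tensor and engine types):

```go
package main

import "fmt"

// Tensor is a stub standing in for *tensor.TensorNumeric[float32].
type Tensor struct{ data []float32 }

// FusedRMSNormer mirrors the optional interface from this package.
type FusedRMSNormer interface {
	FusedRMSNormGPU(input, weight *Tensor, epsilon float32) (output, scales *Tensor, err error)
}

type cpuEngine struct{} // no fused support
type gpuEngine struct{}

func (gpuEngine) FusedRMSNormGPU(input, weight *Tensor, eps float32) (*Tensor, *Tensor, error) {
	return &Tensor{}, &Tensor{}, nil // real kernel elided
}

// rmsNorm dispatches to the fused kernel when the engine supports it
// and falls back to the unfused path otherwise.
func rmsNorm(engine any, input, weight *Tensor, eps float32) string {
	if f, ok := engine.(FusedRMSNormer); ok {
		_, _, _ = f.FusedRMSNormGPU(input, weight, eps)
		return "fused"
	}
	return "fallback"
}

func main() {
	fmt.Println(rmsNorm(gpuEngine{}, nil, nil, 1e-6)) // fused
	fmt.Println(rmsNorm(cpuEngine{}, nil, nil, 1e-6)) // fallback
}
```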

type FusedRoPEProvider

type FusedRoPEProvider[T tensor.Numeric] interface {
	GPUFusedRoPE(input, cosAngles, sinAngles *tensor.TensorNumeric[T], rotaryDim int) (*tensor.TensorNumeric[T], error)
}

FusedRoPEProvider is implemented by engines that support fused GPU RoPE.

type FusedScaledSoftmaxProvider

type FusedScaledSoftmaxProvider[T tensor.Numeric] interface {
	GPUScaledSoftmax(input *tensor.TensorNumeric[T], scale float32, axis int) (*tensor.TensorNumeric[T], error)
}

FusedScaledSoftmaxProvider is implemented by engines that support fused GPU scaled softmax. It computes output = softmax(input * scale) in a single kernel launch, eliminating the MulScalar + Softmax chain (saves 1 kernel launch per call).

type FusedSwiGLUProvider

type FusedSwiGLUProvider[T tensor.Numeric] interface {
	GPUFusedSwiGLU(w1, w3 *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
}

FusedSwiGLUProvider is implemented by engines that support fused GPU SwiGLU. It computes output[i] = w1[i] * sigmoid(w1[i]) * w3[i] in a single kernel, eliminating the Concat + Split + sigmoid + Mul + Mul chain.

type GPUArgmaxer

type GPUArgmaxer interface {
	GPUArgmax(t *tensor.TensorNumeric[float32]) (int, error)
}

GPUArgmaxer is an optional interface for engines that can compute argmax entirely on GPU, returning just the index without copying logits to host. This eliminates the ~1MB D2H copy per token for greedy decoding.

type GPUEngine

type GPUEngine[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

GPUEngine is a GPU-accelerated implementation of the Engine interface. MatMul uses BLAS for maximum performance. Elementwise, scalar, activation, and math operations use native GPU kernels for float32 types. Operations without GPU kernels delegate to CPUEngine.

GPUEngine uses a device-resident pipeline: output tensors have GPUStorage so data stays on GPU between chained operations. A memory pool avoids per-operation malloc/free, and a dedicated stream enables async kernel execution.

GPUEngine is backend-agnostic via the GRAL interfaces (internal/gpuapi/). The CUDA, ROCm, and OpenCL adapters implement these interfaces.

func NewGPUEngine

func NewGPUEngine[T tensor.Numeric](ops numeric.Arithmetic[T], deviceID ...int) (*GPUEngine[T], error)

NewGPUEngine creates a new GPUEngine backed by CUDA via the GRAL abstraction. An optional deviceID selects the GPU (default 0). Call Close() when done to release all resources.

func (*GPUEngine[T]) Add

func (e *GPUEngine[T]) Add(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) AddScalar

func (e *GPUEngine[T]) AddScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) ArenaUsedBytes

func (e *GPUEngine[T]) ArenaUsedBytes() int

ArenaUsedBytes returns the current arena offset (bytes in use).

func (*GPUEngine[T]) BatchNormForwardInference

func (e *GPUEngine[T]) BatchNormForwardInference(
	_ context.Context,
	x, scale, bias, mean, variance *tensor.TensorNumeric[T],
	epsilon float64,
) (*tensor.TensorNumeric[T], error)

BatchNormForwardInference performs batch normalization in inference mode using pre-computed running mean and variance via the GPU DNN backend. x must be [N, C, H, W]. scale, bias, mean, variance must each be [C].

func (*GPUEngine[T]) BatchNormForwardTraining

func (e *GPUEngine[T]) BatchNormForwardTraining(
	_ context.Context,
	x, scale, bias *tensor.TensorNumeric[T],
	runningMean, runningVariance *tensor.TensorNumeric[T],
	epsilon, expAvgFactor float64,
) (*tensor.TensorNumeric[T], *tensor.TensorNumeric[T], *tensor.TensorNumeric[T], error)

BatchNormForwardTraining performs batch normalization computing batch statistics. x: [N, C, H, W], scale/bias: [C]. Returns y: [N, C, H, W], saveMean: [C], saveInvVariance: [C]. Updates runningMean and runningVariance in-place.

func (*GPUEngine[T]) Close

func (e *GPUEngine[T]) Close() error

Close releases the BLAS handle, DNN handle, GPU stream, and drains the memory pool. The engine must not be used after Close.

func (*GPUEngine[T]) Concat

func (e *GPUEngine[T]) Concat(ctx context.Context, tensors []*tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) Conv2dBackwardData

func (e *GPUEngine[T]) Conv2dBackwardData(
	_ context.Context,
	w *tensor.TensorNumeric[T],
	dy *tensor.TensorNumeric[T],
	dxShape [4]int,
	strides [2]int,
	pads [4]int,
	dilations [2]int,
	groups int,
) (*tensor.TensorNumeric[T], error)

Conv2dBackwardData computes the gradient of the convolution input via cuDNN. w: [C_out, C_in/groups, kH, kW], dy: [N, C_out, outH, outW]. Returns dx: [N, C_in, H, W].

func (*GPUEngine[T]) Conv2dBackwardFilter

func (e *GPUEngine[T]) Conv2dBackwardFilter(
	_ context.Context,
	x *tensor.TensorNumeric[T],
	dy *tensor.TensorNumeric[T],
	dwShape [4]int,
	strides [2]int,
	pads [4]int,
	dilations [2]int,
	groups int,
) (*tensor.TensorNumeric[T], error)

Conv2dBackwardFilter computes the gradient of the convolution filter via cuDNN. x: [N, C_in, H, W], dy: [N, C_out, outH, outW]. Returns dw: [C_out, C_in/groups, kH, kW].

func (*GPUEngine[T]) Conv2dForward

func (e *GPUEngine[T]) Conv2dForward(
	_ context.Context,
	x, w *tensor.TensorNumeric[T],
	bias *tensor.TensorNumeric[T],
	strides [2]int,
	pads [4]int,
	dilations [2]int,
	groups int,
) (*tensor.TensorNumeric[T], error)

Conv2dForward performs 2D convolution using the GPU DNN backend. x must be [N, C_in, H, W], w must be [C_out, C_in/groups, kH, kW]. bias is optional (nil to skip). pads is [top, left, bottom, right]. Returns error if padding is asymmetric (cuDNN requires symmetric padding).

func (*GPUEngine[T]) ConvertFP16ToF32

func (e *GPUEngine[T]) ConvertFP16ToF32(t *tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error)

ConvertFP16ToF32 converts a tensor with Float16Storage to a regular float32 GPU tensor using the FP16->F32 kernel. Returns the input unchanged if it does not have Float16Storage.

func (*GPUEngine[T]) Copy

func (e *GPUEngine[T]) Copy(ctx context.Context, dst, src *tensor.TensorNumeric[T]) error

func (*GPUEngine[T]) Cos

func (e *GPUEngine[T]) Cos(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) CudnnActivationBackward

func (e *GPUEngine[T]) CudnnActivationBackward(
	_ context.Context,
	x, y, dy *tensor.TensorNumeric[T],
	mode gpuapi.ActivationMode,
) (*tensor.TensorNumeric[T], error)

CudnnActivationBackward computes the gradient of an activation function via cuDNN. y: forward output, dy: upstream gradient, x: original input. All must have the same shape. Returns dx with the same shape.

func (*GPUEngine[T]) CudnnActivationForward

func (e *GPUEngine[T]) CudnnActivationForward(
	_ context.Context,
	x *tensor.TensorNumeric[T],
	mode gpuapi.ActivationMode,
) (*tensor.TensorNumeric[T], error)

CudnnActivationForward applies an activation function via the GPU DNN backend. mode selects the activation: ActivationReLU, ActivationSigmoid, ActivationTanh. The input tensor shape is preserved in the output.

func (*GPUEngine[T]) CudnnBatchNormBackward

func (e *GPUEngine[T]) CudnnBatchNormBackward(
	_ context.Context,
	x, dy, scale *tensor.TensorNumeric[T],
	saveMean, saveInvVariance *tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], *tensor.TensorNumeric[T], *tensor.TensorNumeric[T], error)

CudnnBatchNormBackward computes gradients for batch normalization via cuDNN. x: [N, C, H, W], dy: [N, C, H, W], scale: [C]. saveMean, saveInvVariance: [C] (from BatchNormForwardTraining). Returns dx: [N, C, H, W], dScale: [C], dBias: [C].

func (*GPUEngine[T]) CudnnPoolingBackward

func (e *GPUEngine[T]) CudnnPoolingBackward(
	_ context.Context,
	x, y, dy *tensor.TensorNumeric[T],
	mode gpuapi.PoolingMode,
	windowH, windowW, padH, padW, strideH, strideW int,
) (*tensor.TensorNumeric[T], error)

CudnnPoolingBackward computes the gradient of 2D pooling via cuDNN. y: forward output [N,C,outH,outW], dy: upstream gradient [N,C,outH,outW], x: forward input [N,C,H,W]. Returns dx: [N,C,H,W].

func (*GPUEngine[T]) CudnnPoolingForward

func (e *GPUEngine[T]) CudnnPoolingForward(
	_ context.Context,
	x *tensor.TensorNumeric[T],
	mode gpuapi.PoolingMode,
	windowH, windowW, padH, padW, strideH, strideW int,
) (*tensor.TensorNumeric[T], error)

CudnnPoolingForward performs 2D pooling via the GPU DNN backend. x must be [N, C, H, W]. Returns [N, C, outH, outW].

func (*GPUEngine[T]) CudnnSoftmaxForward

func (e *GPUEngine[T]) CudnnSoftmaxForward(
	_ context.Context,
	x *tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error)

CudnnSoftmaxForward computes softmax via the GPU DNN backend over the channel dimension. x must be [N, C, H, W] (or reshaped to fit).

func (*GPUEngine[T]) DTypeValue

func (e *GPUEngine[T]) DTypeValue() DType

DTypeValue returns the current compute precision.

func (*GPUEngine[T]) DeviceID

func (e *GPUEngine[T]) DeviceID() int

DeviceID returns the GPU device ID this engine is bound to.

func (*GPUEngine[T]) Div

func (e *GPUEngine[T]) Div(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) DivScalar

func (e *GPUEngine[T]) DivScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) Exp

func (e *GPUEngine[T]) Exp(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) Fill

func (e *GPUEngine[T]) Fill(ctx context.Context, t *tensor.TensorNumeric[T], value T) error

func (*GPUEngine[T]) FusedRMSNormGPU

func (e *GPUEngine[T]) FusedRMSNormGPU(input, weight *tensor.TensorNumeric[float32], epsilon float32) (*tensor.TensorNumeric[float32], *tensor.TensorNumeric[float32], error)

FusedRMSNormGPU implements the FusedRMSNormer interface for GPUEngine. Uses the fused GPU kernel when input is GPU-resident, falls back to CPU otherwise. Returns (output, scales) where scales contains per-row rsqrt values for backward pass.

func (*GPUEngine[T]) GPUArgmax

func (e *GPUEngine[T]) GPUArgmax(t *tensor.TensorNumeric[float32]) (int, error)

GPUArgmax finds the index of the maximum element in a GPU-resident float32 tensor. Returns the index as an int without copying the full tensor to the host. Only copies back a single int32 (4 bytes) instead of the entire tensor.

func (*GPUEngine[T]) GPUFusedAddRMSNorm

func (e *GPUEngine[T]) GPUFusedAddRMSNorm(
	input, residual *tensor.TensorNumeric[T],
	weight *tensor.TensorNumeric[T],
	eps float32,
) (normed *tensor.TensorNumeric[T], residualOut *tensor.TensorNumeric[T], scales *tensor.TensorNumeric[T], err error)

GPUFusedAddRMSNorm computes sum = input + residual and normed = rmsnorm(sum, weight, eps) in a single GPU kernel launch. Both inputs are read-only; outputs go to separate buffers. This replaces Add + RMSNorm (2 kernel launches) with 1.

func (*GPUEngine[T]) GPUFusedNormAdd

func (e *GPUEngine[T]) GPUFusedNormAdd(input, weight, residual *tensor.TensorNumeric[T], eps float32) (*tensor.TensorNumeric[T], error)

GPUFusedNormAdd computes output = rmsnorm(input, weight, eps) + residual in a single GPU kernel launch. Replaces separate RMSNorm + Add (2 launches → 1).

func (*GPUEngine[T]) GPUFusedQKNormRoPE

func (e *GPUEngine[T]) GPUFusedQKNormRoPE(
	input *tensor.TensorNumeric[T],
	weightQ, weightK *tensor.TensorNumeric[T],
	cosAngles, sinAngles *tensor.TensorNumeric[T],
	eps float32,
	totalHeads, headDim, numQHeads, halfRotary int,
) (*tensor.TensorNumeric[T], error)

GPUFusedQKNormRoPE applies per-head RMSNorm + RoPE to combined Q+K heads in a single GPU kernel launch. This replaces 4 kernel launches per GQA layer. input: [totalHeads, headDim], weightQ/weightK: [headDim], cosAngles/sinAngles: [halfRotary], output: [totalHeads, headDim].

func (*GPUEngine[T]) GPUFusedRoPE

func (e *GPUEngine[T]) GPUFusedRoPE(input, cosAngles, sinAngles *tensor.TensorNumeric[T], rotaryDim int) (*tensor.TensorNumeric[T], error)

GPUFusedRoPE applies rotary position embeddings in a single GPU kernel launch. This replaces Split + 4 Mul + Sub + Add + Concat (8 operations, ~10 D2D memcpy) with 1 kernel.
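The rotation the kernel fuses follows the half-rotation layout described for FusedRoPE above; a per-vector CPU reference looks like this (a sketch on plain slices, for one position's precomputed angles):

```go
package main

import "fmt"

// ropeRef applies the half-rotation RoPE layout to one head vector:
// pairs (i, i+half) within the first rotaryDim dims are rotated by
// the precomputed angles, dims past rotaryDim pass through unchanged.
func ropeRef(in, cos, sin []float32, rotaryDim int) []float32 {
	out := make([]float32, len(in))
	copy(out, in) // pass-through for dims >= rotaryDim
	half := rotaryDim / 2
	for i := 0; i < half; i++ {
		a, b := in[i], in[i+half]
		out[i] = a*cos[i] - b*sin[i]
		out[i+half] = b*cos[i] + a*sin[i]
	}
	return out
}

func main() {
	// cos=0, sin=1 is a 90-degree rotation of the first pair;
	// the third dim is beyond rotaryDim and passes through.
	fmt.Println(ropeRef([]float32{3, 4, 5}, []float32{0}, []float32{1}, 2)) // [-4 3 5]
}
```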

func (*GPUEngine[T]) GPUFusedSwiGLU

func (e *GPUEngine[T]) GPUFusedSwiGLU(w1, w3 *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

GPUFusedSwiGLU computes SwiGLU(w1, w3) = w1 * sigmoid(w1) * w3 in a single GPU kernel. This replaces Concat + Split + sigmoid + Mul + Mul (5 operations, ~4 D2D memcpy per layer) with 1 kernel.

func (*GPUEngine[T]) GPUScaledSoftmax

func (e *GPUEngine[T]) GPUScaledSoftmax(input *tensor.TensorNumeric[T], scale float32, axis int) (*tensor.TensorNumeric[T], error)

GPUScaledSoftmax computes softmax(input * scale) in a single GPU kernel launch. This replaces MulScalar + Softmax (2 kernel launches) with 1, saving 26 launches per token for 26 transformer layers.

func (*GPUEngine[T]) GPUStream

func (e *GPUEngine[T]) GPUStream() gpuapi.Stream

GPUStream returns the engine's gpuapi.Stream for async memory operations.

func (*GPUEngine[T]) Gather

func (e *GPUEngine[T]) Gather(ctx context.Context, params *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], output *tensor.TensorNumeric[T]) error

func (*GPUEngine[T]) IsManagedMemory

func (e *GPUEngine[T]) IsManagedMemory() bool

IsManagedMemory returns true if the engine uses managed memory for weight uploads and the arena allocator.

func (*GPUEngine[T]) Log

func (e *GPUEngine[T]) Log(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) MatMul

func (e *GPUEngine[T]) MatMul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

MatMul performs matrix multiplication using GPU BLAS for float32 and BFloat16 tensors. For Q4_0 quantized tensors, uses the Q4 dequant-GEMM kernel. For other types, it falls back to the CPU implementation. Supports 2D matrices and batched matmul (3D+ tensors).

func (*GPUEngine[T]) MatMulTransposeB

func (e *GPUEngine[T]) MatMulTransposeB(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

MatMulTransposeB computes C = A * B^T using cuBLAS SgemmNT, avoiding an explicit Transpose allocation and kernel launch. A is [...batch, m, k], B is [...batch, n, k], result is [...batch, m, n]. Supports batch broadcasting (bBatch=1 broadcasts B across A's batch).

func (*GPUEngine[T]) Mul

func (e *GPUEngine[T]) Mul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) MulScalar

func (e *GPUEngine[T]) MulScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) OOMFallbackCount

func (e *GPUEngine[T]) OOMFallbackCount() int64

OOMFallbackCount returns the number of times GPU OOM triggered CPU fallback.

func (*GPUEngine[T]) OneHot

func (e *GPUEngine[T]) OneHot(ctx context.Context, input *tensor.TensorNumeric[int], depth int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) Ops

func (e *GPUEngine[T]) Ops() numeric.Arithmetic[T]

Ops returns the arithmetic ops for this engine.

func (*GPUEngine[T]) Pow

func (e *GPUEngine[T]) Pow(ctx context.Context, base, exponent *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) RandomUniform

func (e *GPUEngine[T]) RandomUniform(ctx context.Context, t *tensor.TensorNumeric[T], minVal, maxVal T) error

func (*GPUEngine[T]) ReduceMean

func (e *GPUEngine[T]) ReduceMean(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) ReduceSum

func (e *GPUEngine[T]) ReduceSum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) Repeat

func (e *GPUEngine[T]) Repeat(ctx context.Context, a *tensor.TensorNumeric[T], axis int, repetitions int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) ResetPool

func (e *GPUEngine[T]) ResetPool()

ResetPool resets the arena pool, reclaiming all per-pass allocations. This is a no-op if the pool is not arena-backed.

func (*GPUEngine[T]) Reshape

func (e *GPUEngine[T]) Reshape(ctx context.Context, a *tensor.TensorNumeric[T], shape []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) Rsqrt

func (e *GPUEngine[T]) Rsqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) ScatterAdd

func (e *GPUEngine[T]) ScatterAdd(ctx context.Context, dEmbeddingTable *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], dOut *tensor.TensorNumeric[T]) error

func (*GPUEngine[T]) SetArenaResetFloor

func (e *GPUEngine[T]) SetArenaResetFloor(floor int)

SetArenaResetFloor sets the minimum offset that arena Reset will rewind to.

func (*GPUEngine[T]) SetDType

func (e *GPUEngine[T]) SetDType(d DType)

SetDType sets the compute precision for elementwise ops and MatMul. DTypeFP16 enables the FP16 inference path: F32 inputs are converted to FP16 on GPU, FP16 kernels run, and results are converted back to F32.

func (*GPUEngine[T]) SetLogger

func (e *GPUEngine[T]) SetLogger(l log.Logger)

SetLogger replaces the engine's logger.

func (*GPUEngine[T]) Sin

func (e *GPUEngine[T]) Sin(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) Softmax

func (e *GPUEngine[T]) Softmax(ctx context.Context, a *tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) Split

func (e *GPUEngine[T]) Split(ctx context.Context, a *tensor.TensorNumeric[T], numSplits int, axis int) ([]*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) Sqrt

func (e *GPUEngine[T]) Sqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) Stream

func (e *GPUEngine[T]) Stream() unsafe.Pointer

Stream returns the engine's GPU stream as an unsafe.Pointer (cudaStream_t).

func (*GPUEngine[T]) Sub

func (e *GPUEngine[T]) Sub(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) Sum

func (e *GPUEngine[T]) Sum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) Sync

func (e *GPUEngine[T]) Sync() error

Sync synchronizes the GPU stream, blocking until all enqueued operations complete. Use for benchmarking or when explicit synchronization is needed.

func (*GPUEngine[T]) Tanh

func (e *GPUEngine[T]) Tanh(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) TanhPrime

func (e *GPUEngine[T]) TanhPrime(ctx context.Context, a, upstream *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) Transpose

func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T], axes []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) UnaryOp

func (e *GPUEngine[T]) UnaryOp(ctx context.Context, a *tensor.TensorNumeric[T], op func(T) T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

func (*GPUEngine[T]) UploadWeights

func (e *GPUEngine[T]) UploadWeights(tensors []*tensor.TensorNumeric[float32]) error

UploadWeights copies CPU-resident tensors to GPU device memory in place. Tensors that already have GPUStorage are skipped. Q4 quantized weights get their raw bytes uploaded and cached in Q4Storage to avoid per-op H2D. This is called once at model load time.

On devices with managed memory support (e.g., GB10), weights are allocated with cudaMallocManaged and populated via direct CPU memcpy. The GPU can then access them without any explicit H2D transfer.

func (*GPUEngine[T]) Zero

func (e *GPUEngine[T]) Zero(ctx context.Context, a *tensor.TensorNumeric[T]) error

func (*GPUEngine[T]) Zeros

func (e *GPUEngine[T]) Zeros(ctx context.Context, a *tensor.TensorNumeric[T], shape []int) error

type GPUStreamAccessor

type GPUStreamAccessor interface {
	GPUStream() gpuapi.Stream
}

GPUStreamAccessor is an optional interface for engines that provide their gpuapi.Stream for async memory operations (e.g., KV cache D2D copies during CUDA graph capture).

type MemoryTracker

type MemoryTracker struct {
	// contains filtered or unexported fields
}

MemoryTracker tracks total allocated bytes with an optional upper limit. All methods are safe for concurrent use.

func NewMemoryTracker

func NewMemoryTracker(limit int64) *MemoryTracker

NewMemoryTracker creates a tracker with the given byte limit. A limit of 0 disables enforcement (unlimited).

func (*MemoryTracker) Alloc

func (m *MemoryTracker) Alloc(bytes int64) error

Alloc reserves bytes. If the allocation would exceed the limit, it returns ErrMemoryLimitExceeded without modifying the counter.

func (*MemoryTracker) Allocated

func (m *MemoryTracker) Allocated() int64

Allocated returns the current total allocated bytes.

func (*MemoryTracker) Free

func (m *MemoryTracker) Free(bytes int64)

Free releases previously allocated bytes.

func (*MemoryTracker) Limit

func (m *MemoryTracker) Limit() int64

Limit returns the configured byte limit (0 means unlimited).

type PoolResetter

type PoolResetter interface {
	ResetPool()
}

PoolResetter is an optional interface for engines that use arena-based memory pools. Call ResetPool() at the start of each forward pass to reclaim all per-pass intermediate allocations in O(1).
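The per-pass protocol amounts to an optional-interface check at the top of each forward pass; a self-contained sketch with a stub engine:

```go
package main

import "fmt"

// PoolResetter mirrors the optional interface from this package.
type PoolResetter interface{ ResetPool() }

// arenaEngine is a stub engine with an arena-style pool.
type arenaEngine struct{ resets int }

func (e *arenaEngine) ResetPool() { e.resets++ }

// forwardPass reclaims per-pass intermediates before running layers,
// skipping engines that don't use an arena pool.
func forwardPass(engine any) {
	if pr, ok := engine.(PoolResetter); ok {
		pr.ResetPool() // O(1) reclaim of the previous pass's allocations
	}
	// ... run model layers ...
}

func main() {
	e := &arenaEngine{}
	forwardPass(e)
	forwardPass(e)
	fmt.Println(e.resets) // 2
}
```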

type StreamProvider

type StreamProvider interface {
	// Stream returns the engine's GPU stream as an unsafe.Pointer (cudaStream_t).
	Stream() unsafe.Pointer
}

StreamProvider is an optional interface for engines that expose their underlying GPU stream for CUDA graph capture.

type TensorArena

type TensorArena struct {
	// contains filtered or unexported fields
}

TensorArena is a power-of-2 bucketed pool for float32 backing arrays. It reduces GC pressure during inference by reusing allocations. Thread-safe with per-bucket mutex protection.

func (*TensorArena) Get

func (a *TensorArena) Get(n int) []float32

Get returns a float32 slice of at least n elements from the arena. The returned slice has length n but may have capacity rounded up to a power of 2.
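The bucket capacity for a request can be computed with a power-of-2 round-up; a sketch of the mapping (the rounding function name is illustrative):

```go
package main

import (
	"fmt"
	"math/bits"
)

// roundUpPow2 returns the smallest power of two >= n — the bucket
// capacity a Get(n) request would map to in a power-of-2 pool.
func roundUpPow2(n int) int {
	if n <= 1 {
		return 1
	}
	return 1 << bits.Len(uint(n-1))
}

func main() {
	fmt.Println(roundUpPow2(5), roundUpPow2(8), roundUpPow2(1000)) // 8 8 1024
}
```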

func (*TensorArena) Put

func (a *TensorArena) Put(buf []float32)

Put returns a buffer to the arena for reuse.

func (*TensorArena) Reset

func (a *TensorArena) Reset()

Reset clears all pooled buffers, allowing GC to collect them.

type TensorPool

type TensorPool[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

TensorPool provides reusable tensor buffers keyed by shape. Acquire returns a tensor from the pool or allocates a new one. Release returns a tensor to the pool for future reuse. The pool is safe for concurrent use.

func NewTensorPool

func NewTensorPool[T tensor.Numeric]() *TensorPool[T]

NewTensorPool creates a new empty tensor pool.

func (*TensorPool[T]) Acquire

func (p *TensorPool[T]) Acquire(shape []int) (*tensor.TensorNumeric[T], error)

Acquire returns a tensor with the given shape. If the pool has a matching buffer, it is returned (zeroed). Otherwise a new tensor is allocated.

func (*TensorPool[T]) Len

func (p *TensorPool[T]) Len() int

Len returns the total number of tensors currently in the pool.

func (*TensorPool[T]) Release

func (p *TensorPool[T]) Release(t *tensor.TensorNumeric[T])

Release returns a tensor to the pool for future reuse. For GPU-backed tensors, the device memory is freed immediately (returned to the GPU MemPool for reuse) rather than holding the tensor reference, since GPU memory is a scarce resource managed by a separate pool. The tensor must not be used after calling Release.

type TestableEngine

type TestableEngine[T tensor.Numeric] struct {
	*CPUEngine[T]
}

TestableEngine extends CPUEngine with methods that allow controlled error injection. This enables testing of previously unreachable error paths.

func NewTestableEngine

func NewTestableEngine[T tensor.Numeric](ops numeric.Arithmetic[T]) *TestableEngine[T]

NewTestableEngine creates a new TestableEngine.

func (*TestableEngine[T]) TestableMatMul

func (e *TestableEngine[T]) TestableMatMul(_ context.Context, a, b *tensor.TensorNumeric[T], result *FailableTensor[T]) error

TestableMatMul performs matrix multiplication with a FailableTensor result. This allows testing the error path in MatMul when result.Set() fails.

func (*TestableEngine[T]) TestableSum

func (e *TestableEngine[T]) TestableSum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, _ bool, zeroer *FailableZeroer[T], result *tensor.TensorNumeric[T]) error

TestableSum performs sum with a FailableZeroer. This allows testing the error path in Sum when Zero() fails.

func (*TestableEngine[T]) TestableTranspose

func (e *TestableEngine[T]) TestableTranspose(_ context.Context, a *tensor.TensorNumeric[T], result *FailableTensor[T]) error

TestableTranspose performs transpose with a FailableTensor result. This allows testing the error path in Transpose when result.Set() fails.

type TraceRecorder

type TraceRecorder[T tensor.Numeric] interface {
	Record(opName string, inputs []*tensor.TensorNumeric[T], output *tensor.TensorNumeric[T], extra map[string]any)
	RecordMultiOutput(opName string, inputs []*tensor.TensorNumeric[T], outputs []*tensor.TensorNumeric[T], extra map[string]any)
	RecordGather(params *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], output *tensor.TensorNumeric[T], extra map[string]any)
}

TraceRecorder is the interface used by EngineProxy to record traced operations.

type TracedOp

type TracedOp struct {
	OpName    string
	InputIDs  []int          // slot indices for inputs
	OutputID  int            // slot index for output
	OutputIDs []int          // for multi-output ops like Split
	ExtraArgs map[string]any // axis, scalar value, shape, etc.
}

TracedOp records a single engine operation with slot-based tensor identity.

type Tracer

type Tracer[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

Tracer records engine operations and tracks tensor identity by pointer. It assigns each unique tensor pointer a slot index, enabling later compilation into an ExecutionPlan.

func NewTracer

func NewTracer[T tensor.Numeric](frozenTensors []*tensor.TensorNumeric[T]) *Tracer[T]

NewTracer creates a Tracer and pre-registers frozen tensors (model weights) with slot indices.

func (*Tracer[T]) FrozenSlots

func (t *Tracer[T]) FrozenSlots() []int

FrozenSlots returns the slot indices of frozen (weight) tensors.

func (*Tracer[T]) HasOpaqueOps

func (t *Tracer[T]) HasOpaqueOps() bool

HasOpaqueOps reports whether any opaque (non-traceable) operations were encountered during tracing.

func (*Tracer[T]) MarkOpaque

func (t *Tracer[T]) MarkOpaque()

MarkOpaque marks the trace as containing opaque operations.

func (*Tracer[T]) NextSlot

func (t *Tracer[T]) NextSlot() int

NextSlot returns the next slot index that would be assigned. This indicates the total number of slots allocated.

func (*Tracer[T]) Record

func (t *Tracer[T]) Record(opName string, inputs []*tensor.TensorNumeric[T], output *tensor.TensorNumeric[T], extra map[string]any)

Record appends a TracedOp for a single-output operation.

func (*Tracer[T]) RecordGather

func (t *Tracer[T]) RecordGather(params *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], output *tensor.TensorNumeric[T], extra map[string]any)

RecordGather appends a TracedOp for Gather which uses int indices.

func (*Tracer[T]) RecordMultiOutput

func (t *Tracer[T]) RecordMultiOutput(opName string, inputs []*tensor.TensorNumeric[T], outputs []*tensor.TensorNumeric[T], extra map[string]any)

RecordMultiOutput appends a TracedOp for a multi-output operation (e.g., Split).

func (*Tracer[T]) SlotFor

func (t *Tracer[T]) SlotFor(tn *tensor.TensorNumeric[T]) int

SlotFor returns the existing slot for a tensor or assigns a new one. This is the exported version of slotFor.

func (*Tracer[T]) SlotShapes

func (t *Tracer[T]) SlotShapes() map[int][]int

SlotShapes returns the shape for each slot index.

func (*Tracer[T]) TracedOps

func (t *Tracer[T]) TracedOps() []TracedOp

TracedOps returns the recorded operations in order.

type TransposeBMatMuler

type TransposeBMatMuler[T tensor.Numeric] interface {
	MatMulTransposeB(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
}

TransposeBMatMuler is an optional interface for engines that can compute C = A * B^T without explicitly transposing B. This avoids an extra GPU allocation and kernel launch for the transpose operation. A is [batch, m, k], B is [batch, n, k], result is [batch, m, n].
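The shape convention is easiest to see in a naive 2D reference (a sketch; batching and the BLAS path are elided):

```go
package main

import "fmt"

// matMulTransposeBRef computes C = A * B^T for 2D matrices without
// materializing B^T: A is [m][k], B is [n][k], C is [m][n].
func matMulTransposeBRef(a, b [][]float32) [][]float32 {
	c := make([][]float32, len(a))
	for i := range a {
		c[i] = make([]float32, len(b))
		for j := range b {
			var s float32
			for p := range a[i] {
				s += a[i][p] * b[j][p] // row j of B is column j of B^T
			}
			c[i][j] = s
		}
	}
	return c
}

func main() {
	c := matMulTransposeBRef([][]float32{{1, 2}}, [][]float32{{3, 4}, {5, 6}})
	fmt.Println(c) // [[11 17]]
}
```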

type WeightUploader

type WeightUploader interface {
	UploadWeights(tensors []*tensor.TensorNumeric[float32]) error
}

WeightUploader is an optional interface for engines that can pre-upload model weights to device memory at load time. This eliminates per-operation host-to-device copies during inference. Each tensor's storage is replaced in-place from CPUStorage to device-resident storage.
