compute

package
v0.3.0 Latest
Published: Mar 21, 2026 License: Apache-2.0 Imports: 30 Imported by: 23

Documentation

Overview

Package compute implements tensor computation engines and operations.

Index

Constants

This section is empty.

Variables

var ErrMemoryLimitExceeded = errors.New("memory limit exceeded")

ErrMemoryLimitExceeded is returned when a tensor allocation would exceed the configured memory limit.

Functions

func ComputeAmax added in v0.3.0

func ComputeAmax[T tensor.Numeric](_ context.Context, ops numeric.Arithmetic[T], t *tensor.TensorNumeric[T]) (float32, error)

ComputeAmax returns the maximum absolute value of all elements in t as float32. It scans the tensor data on the CPU. For GPU tensors, data must be accessible on the host (e.g. via a prior device-to-host copy). Returns 0 for empty tensors.
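The amax contract above can be sketched with a plain-slice reference (the helper name here is illustrative, not part of the package):

```go
package main

import (
	"fmt"
	"math"
)

// amax returns the maximum absolute value in xs as float32,
// and 0 for an empty (or nil) slice, mirroring ComputeAmax's contract.
func amax(xs []float32) float32 {
	var m float64
	for _, x := range xs {
		if a := math.Abs(float64(x)); a > m {
			m = a
		}
	}
	return float32(m)
}

func main() {
	fmt.Println(amax([]float32{-3.5, 2, 1}), amax(nil)) // 3.5 0
}
```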

func DequantW4ToFP16 added in v0.3.0

func DequantW4ToFP16(w4 w4Storage) *tensor.Float16Storage

DequantW4ToFP16 dequantizes 4-bit quantized weights to FP16 format. This is useful for GPU paths that want FP16 weight data for cuBLAS HGEMM.

func FusedRMSNorm

func FusedRMSNorm(input, weight *tensor.TensorNumeric[float32], epsilon float32) (output, scales *tensor.TensorNumeric[float32], err error)

FusedRMSNorm computes x * rsqrt(mean(x^2) + eps) * weight in a single pass, avoiding materializing the squared, mean, and rsqrt intermediate tensors. Input shape: [..., D] where D is the last dimension (hidden size). Weight shape: [D]. Returns (output, scales), where output has the same shape as the input and scales has shape [..., 1] containing the per-row rsqrt(mean(x^2)+eps) values.

func FusedRoPE

func FusedRoPE(input, cosAngles, sinAngles *tensor.TensorNumeric[float32], rotaryDim int) (*tensor.TensorNumeric[float32], error)

FusedRoPE applies rotary position embeddings in a single pass. Input shape: [batch, seq_len, head_dim] where head_dim is even. cos/sin shape: [seq_len, half_dim] (precomputed angles). rotaryDim: number of dimensions that receive rotation (<= head_dim, must be even). For each position (b, s):

out[..., i]            = in[..., i] * cos[s,i] - in[..., i+half] * sin[s,i]      (i < half)
out[..., i+half]       = in[..., i+half] * cos[s,i] + in[..., i] * sin[s,i]      (i < half)
out[..., rotaryDim..]  = in[..., rotaryDim..]                                      (pass-through)

func FusedSiLUGate

func FusedSiLUGate(gate, up *tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error)

FusedSiLUGate computes silu(gate) * up in a single element-wise pass, where SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)). gate and up must have the same shape. Fusing avoids materializing the intermediate tensors of the unfused sequence (a sigmoid followed by two multiplications).

func IsW4A16 added in v0.3.0

func IsW4A16[T tensor.Numeric](a, b *tensor.TensorNumeric[T]) bool

IsW4A16 returns true if the two tensors form a W4A16 mixed-precision pair (one operand has 4-bit weights, the other has FP16 activations).

func MatMulW4A16 added in v0.3.0

func MatMulW4A16[T tensor.Numeric](
	ctx context.Context,
	eng Engine[T],
	a, b *tensor.TensorNumeric[T],
	dst ...*tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error)

MatMulW4A16 performs mixed-precision matrix multiplication with 4-bit quantized weights and FP16 activations. The 4-bit weights are dequantized to float32 and the FP16 activations are decoded to float32 for computation. The result is a float32 tensor.

This is the CPU fallback path. GPU engines can override with fused dequant-GEMM kernels for better performance.

func QuantFormat added in v0.3.0

func QuantFormat[T tensor.Numeric](s tensor.Storage[T]) string

QuantFormat returns a string identifying the 4-bit quantization format of the given storage, or "" if it's not a recognized 4-bit format.

func ScaleForFP8 added in v0.3.0

func ScaleForFP8[T tensor.Numeric](ctx context.Context, ops numeric.Arithmetic[T], t *tensor.TensorNumeric[T]) (float32, error)

ScaleForFP8 returns the scale factor for FP8 E4M3FN quantization: 448.0 / amax. Returns an error if the tensor is nil. Returns +Inf if amax is zero (all-zero tensor).

func TryW4A16MatMul added in v0.3.0

func TryW4A16MatMul[T tensor.Numeric](
	ctx context.Context,
	eng Engine[T],
	a, b *tensor.TensorNumeric[T],
	dst ...*tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], bool, error)

TryW4A16MatMul attempts to dispatch a W4A16 mixed-precision MatMul. The boolean reports whether the inputs matched the W4A16 pattern: on a match it returns the result and true; otherwise it returns nil and false with a nil error, and the caller should fall back to a regular MatMul.

Types

type CPUEngine

type CPUEngine[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

CPUEngine is a CPU-based implementation of the Engine interface.

func NewCPUEngine

func NewCPUEngine[T tensor.Numeric](ops numeric.Arithmetic[T]) *CPUEngine[T]

NewCPUEngine constructs a new CPUEngine for the given numeric operations. A no-op logger and no-op collector are used by default; call SetLogger/SetCollector to override.

func (*CPUEngine[T]) Add

func (e *CPUEngine[T]) Add(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Add performs element-wise addition with broadcasting.

func (*CPUEngine[T]) AddScalar

func (e *CPUEngine[T]) AddScalar(_ context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

AddScalar performs element-wise addition of a scalar to a tensor.

func (*CPUEngine[T]) Close

func (e *CPUEngine[T]) Close(_ context.Context) error

Close is a no-op for CPUEngine. It satisfies the shutdown.Closer interface.

func (*CPUEngine[T]) Concat

func (e *CPUEngine[T]) Concat(_ context.Context, tensors []*tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Concat concatenates a list of tensors along a given axis.

func (*CPUEngine[T]) Copy

func (e *CPUEngine[T]) Copy(_ context.Context, dst, src *tensor.TensorNumeric[T]) error

Copy copies src into dst; shapes must match.

func (*CPUEngine[T]) Cos

func (e *CPUEngine[T]) Cos(_ context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Cos computes the element-wise cosine of a tensor.

func (*CPUEngine[T]) Div

func (e *CPUEngine[T]) Div(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Div performs element-wise division with broadcasting. For integer types, division by zero returns an error.

func (*CPUEngine[T]) DivScalar

func (e *CPUEngine[T]) DivScalar(_ context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

DivScalar divides a tensor by a scalar value element-wise.

func (*CPUEngine[T]) Exp

func (e *CPUEngine[T]) Exp(_ context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Exp computes the element-wise exponential of a tensor.

func (*CPUEngine[T]) Fill

func (e *CPUEngine[T]) Fill(_ context.Context, t *tensor.TensorNumeric[T], value T) error

Fill sets all elements of t to value.

func (*CPUEngine[T]) Gather

func (e *CPUEngine[T]) Gather(_ context.Context, params *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], output *tensor.TensorNumeric[T]) error

Gather performs an embedding-style gather. params must be 2D [vocab, dim]. indices may be 1D [N] or 2D [batch, seq]. output must be [indices..., dim], i.e., [N, dim] or [batch, seq, dim].
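For the 1D-indices case, the documented semantics amount to copying one params row per index (a plain-slice sketch; the real method operates on tensors and validates shapes):

```go
package main

import "fmt"

// gather copies params[indices[i]] into row i of the output,
// the embedding-style lookup Gather describes: out is [N, dim].
func gather(params [][]float32, indices []int) [][]float32 {
	out := make([][]float32, len(indices))
	for i, idx := range indices {
		row := make([]float32, len(params[idx]))
		copy(row, params[idx]) // copy, so out doesn't alias params
		out[i] = row
	}
	return out
}

func main() {
	fmt.Println(gather([][]float32{{1, 2}, {3, 4}}, []int{1, 0})) // [[3 4] [1 2]]
}
```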

func (*CPUEngine[T]) Log

func (e *CPUEngine[T]) Log(_ context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Log computes the element-wise natural logarithm of a tensor.

func (*CPUEngine[T]) MatMul

func (e *CPUEngine[T]) MatMul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

MatMul performs matrix multiplication of two tensors.

func (*CPUEngine[T]) MemoryTracker

func (e *CPUEngine[T]) MemoryTracker() *MemoryTracker

MemoryTracker returns the engine's memory tracker.

func (*CPUEngine[T]) Mul

func (e *CPUEngine[T]) Mul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Mul performs element-wise multiplication with broadcasting.

func (*CPUEngine[T]) MulScalar

func (e *CPUEngine[T]) MulScalar(_ context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

MulScalar performs element-wise multiplication of a tensor by a scalar.

func (*CPUEngine[T]) OneHot

func (e *CPUEngine[T]) OneHot(_ context.Context, input *tensor.TensorNumeric[int], depth int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

OneHot creates a one-hot encoding of the input tensor.

func (*CPUEngine[T]) Ops

func (e *CPUEngine[T]) Ops() numeric.Arithmetic[T]

Ops returns the arithmetic ops for this engine.

func (*CPUEngine[T]) Pow

func (e *CPUEngine[T]) Pow(
	ctx context.Context,
	base, exponent *tensor.TensorNumeric[T],
	dst ...*tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error)

Pow raises each element of a tensor to the power of the corresponding element in another tensor.

func (*CPUEngine[T]) RandomUniform

func (e *CPUEngine[T]) RandomUniform(_ context.Context, t *tensor.TensorNumeric[T], minVal, maxVal T) error

RandomUniform fills t with random values between minVal and maxVal.

func (*CPUEngine[T]) ReduceMean

func (e *CPUEngine[T]) ReduceMean(
	ctx context.Context,
	a *tensor.TensorNumeric[T],
	axis int,
	keepDims bool,
	dst ...*tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error)

ReduceMean calculates the mean of elements along a specified axis.

func (*CPUEngine[T]) ReduceSum

func (e *CPUEngine[T]) ReduceSum(
	ctx context.Context,
	a *tensor.TensorNumeric[T],
	axis int,
	keepDims bool,
	dst ...*tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error)

ReduceSum delegates to Sum for reduction along an axis.

func (*CPUEngine[T]) Repeat

func (e *CPUEngine[T]) Repeat(_ context.Context, a *tensor.TensorNumeric[T], axis int, repetitions int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Repeat repeats the input tensor along a given axis a specified number of times.

func (*CPUEngine[T]) Reshape

func (e *CPUEngine[T]) Reshape(_ context.Context, a *tensor.TensorNumeric[T], shape []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Reshape changes the shape of a tensor without changing its data.

func (*CPUEngine[T]) Rsqrt

func (e *CPUEngine[T]) Rsqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Rsqrt computes the element-wise reciprocal square root of a tensor.

func (*CPUEngine[T]) ScatterAdd

func (e *CPUEngine[T]) ScatterAdd(_ context.Context, dEmbeddingTable *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], dOut *tensor.TensorNumeric[T]) error

ScatterAdd performs a row-wise scatter-add for embeddings. dEmbeddingTable must be [vocab, dim]. indices may be 1D [N] or multi-dim with flattened length N. dOut must be [N, dim]. For each i in [0..N), it applies: dEmbeddingTable[indices[i], :] += dOut[i, :].
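The accumulation rule above can be sketched with plain slices; note that duplicate indices accumulate, which is why the update is += rather than = (illustrative helper, not the package's implementation):

```go
package main

import "fmt"

// scatterAdd applies table[indices[i], :] += dOut[i, :] for each i,
// the embedding-gradient accumulation ScatterAdd documents.
func scatterAdd(table [][]float32, indices []int, dOut [][]float32) {
	for i, idx := range indices {
		for j, v := range dOut[i] {
			table[idx][j] += v
		}
	}
}

func main() {
	table := [][]float32{{0, 0}, {0, 0}}
	// The duplicate index 1 accumulates both gradient rows.
	scatterAdd(table, []int{1, 1}, [][]float32{{1, 2}, {3, 4}})
	fmt.Println(table) // [[0 0] [4 6]]
}
```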

func (*CPUEngine[T]) SetCollector

func (e *CPUEngine[T]) SetCollector(c metrics.Collector)

SetCollector replaces the engine's metrics collector.

func (*CPUEngine[T]) SetLogger

func (e *CPUEngine[T]) SetLogger(l log.Logger)

SetLogger replaces the engine's logger.

func (*CPUEngine[T]) SetMemoryLimit

func (e *CPUEngine[T]) SetMemoryLimit(bytes int64)

SetMemoryLimit configures the maximum number of bytes this engine may allocate for tensors. A limit of 0 disables enforcement.

func (*CPUEngine[T]) Sin

func (e *CPUEngine[T]) Sin(_ context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Sin computes the element-wise sine of a tensor.

func (*CPUEngine[T]) Softmax

func (e *CPUEngine[T]) Softmax(_ context.Context, a *tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Softmax applies the softmax function to a tensor along a given axis. If axis is negative, it is interpreted relative to the last axis (e.g., -1 means last axis).

func (*CPUEngine[T]) Split

func (e *CPUEngine[T]) Split(_ context.Context, a *tensor.TensorNumeric[T], numSplits int, axis int) ([]*tensor.TensorNumeric[T], error)

Split splits a tensor into numSplits along the given axis. All splits are equal-sized; shape[axis] must be divisible by numSplits.

func (*CPUEngine[T]) Sqrt

func (e *CPUEngine[T]) Sqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Sqrt computes the element-wise square root of a tensor.

func (*CPUEngine[T]) Sub

func (e *CPUEngine[T]) Sub(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Sub performs element-wise subtraction with broadcasting.

func (*CPUEngine[T]) Sum

func (e *CPUEngine[T]) Sum(
	_ context.Context,
	a *tensor.TensorNumeric[T],
	axis int,
	keepDims bool,
	dst ...*tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error)

Sum computes the sum of tensor elements along the specified axis. If keepDims is true, the reduced dimensions are retained with size 1. An optional destination tensor can be provided to store the result.

func (*CPUEngine[T]) Tanh

func (e *CPUEngine[T]) Tanh(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Tanh applies the hyperbolic tangent activation element-wise.

func (*CPUEngine[T]) TanhPrime

func (e *CPUEngine[T]) TanhPrime(ctx context.Context, a, upstream *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

TanhPrime computes tanh'(a) * upstream element-wise.

func (*CPUEngine[T]) Transpose

func (e *CPUEngine[T]) Transpose(_ context.Context, a *tensor.TensorNumeric[T], axes []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Transpose transposes the tensor along the given axes.

func (*CPUEngine[T]) UnaryOp

func (e *CPUEngine[T]) UnaryOp(ctx context.Context, a *tensor.TensorNumeric[T], op func(T) T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

UnaryOp applies a unary element-wise operation.

func (*CPUEngine[T]) Zero

func (e *CPUEngine[T]) Zero(_ context.Context, a *tensor.TensorNumeric[T]) error

Zero sets all elements of tensor a to zero.

func (*CPUEngine[T]) Zeros

func (e *CPUEngine[T]) Zeros(ctx context.Context, a *tensor.TensorNumeric[T], shape []int) error

Zeros fills the tensor with zeros. If shape is provided, (re)allocates to that shape.

type DType

type DType int

DType selects the compute precision for GPU operations.

const (
	// DTypeF32 uses float32 for all compute (default).
	DTypeF32 DType = iota
	// DTypeFP16 uses FP16 for elementwise ops and MatMul.
	// Activations are converted F32->FP16 before compute and FP16->F32 after.
	// Reductions (RMSNorm, Softmax) accumulate in FP32 for precision.
	DTypeFP16

	// DTypeFP8 uses FP8 E4M3 weights with FP16 compute for element-wise ops.
	// Weights are quantized to FP8 at load time, dequantized to FP16 on GPU.
	// MatMul uses cublasLtMatmul (auto-detected via FP8E4M3Storage).
	DTypeFP8
)

type Engine

type Engine[T tensor.Numeric] interface {
	// Ops returns the numeric.Arithmetic operations for the engine's numeric type.
	Ops() numeric.Arithmetic[T]
	// UnaryOp applies a unary function `op` to each element of tensor `a`.
	// It returns a new tensor with the results.
	// Returns an error if the input tensor is nil.
	UnaryOp(ctx context.Context, a *tensor.TensorNumeric[T], op func(T) T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Add performs element-wise addition of two tensors, with support for broadcasting.
	// It returns a new tensor with the results.
	// Returns an error if tensors are nil or their shapes are not compatible for broadcasting.
	Add(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Sub performs element-wise subtraction of two tensors, with support for broadcasting.
	// It returns a new tensor with the results.
	// Returns an error if tensors are nil or their shapes are not compatible for broadcasting.
	Sub(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Mul performs element-wise multiplication of two tensors, with support for broadcasting.
	// It returns a new tensor with the results.
	// Returns an error if tensors are nil or their shapes are not compatible for broadcasting.
	Mul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Div performs element-wise division of two tensors, with support for broadcasting.
	// It returns a new tensor with the results.
	// Returns an error if tensors are nil or their shapes are not compatible for broadcasting.
	Div(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// MatMul performs matrix multiplication of two 2D tensors.
	// It returns a new tensor with the result.
	// Returns an error if the tensors are nil, not 2D, or their shapes are incompatible for matrix multiplication.
	MatMul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Transpose transposes a tensor along the given axes.
	// It returns a new tensor with the result.
	// Returns an error if the tensor is nil or the axes are invalid.
	Transpose(ctx context.Context, a *tensor.TensorNumeric[T], axes []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Sum calculates the sum of elements along a specified axis.
	// A negative axis means summing along all axes, returning a scalar tensor.
	// If keepDims is true, the reduced dimensions are retained with size 1.
	// Returns a new tensor with the reduced shape.
	// Returns an error if the tensor is nil or the axis is out of bounds.
	Sum(
		ctx context.Context,
		a *tensor.TensorNumeric[T],
		axis int,
		keepDims bool,
		dst ...*tensor.TensorNumeric[T],
	) (*tensor.TensorNumeric[T], error)

	// Exp computes the element-wise exponential of a tensor.
	Exp(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Log computes the element-wise natural logarithm of a tensor.
	Log(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Sin computes the element-wise sine of a tensor.
	Sin(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Cos computes the element-wise cosine of a tensor.
	Cos(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Tanh applies the hyperbolic tangent activation function element-wise.
	Tanh(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// TanhPrime computes the element-wise gradient of tanh at `a` multiplied by `upstream`.
	// This is useful for backpropagation where `upstream` is dL/dy and the result is dL/dx.
	TanhPrime(ctx context.Context, a, upstream *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Pow raises each element of a tensor to the power of the corresponding element in another tensor.
	Pow(ctx context.Context, base, exponent *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Zero sets all elements of a tensor to zero.
	Zero(ctx context.Context, a *tensor.TensorNumeric[T]) error

	// Zeros fills the tensor with zeros. If a shape is provided, the tensor is reallocated to that shape.
	Zeros(ctx context.Context, a *tensor.TensorNumeric[T], shape []int) error

	// Copy copies the data from one tensor to another.
	Copy(ctx context.Context, dst, src *tensor.TensorNumeric[T]) error

	// Gather performs an embedding-style gather.
	// params must be 2D [vocab, dim].
	// indices may be 1D [N] or 2D [batch, seq].
	// output must be [indices..., dim], i.e., [N, dim] or [batch, seq, dim].
	Gather(ctx context.Context, params *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], output *tensor.TensorNumeric[T]) error

	// ScatterAdd performs a row-wise scatter-add for embeddings.
	// dEmbeddingTable must be [vocab, dim].
	// indices may be 1D [N] or multi-dim with flattened length N.
	// dOut must be [N, dim].
	// For each i in [0..N), it applies: dEmbeddingTable[indices[i], :] += dOut[i, :].
	ScatterAdd(ctx context.Context, dEmbeddingTable *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], dOut *tensor.TensorNumeric[T]) error

	// RandomUniform fills the tensor with random values from a uniform distribution.
	RandomUniform(ctx context.Context, t *tensor.TensorNumeric[T], minVal, maxVal T) error

	// Fill fills the tensor with a scalar value.
	Fill(ctx context.Context, t *tensor.TensorNumeric[T], value T) error

	// MulScalar performs element-wise multiplication of a tensor by a scalar.
	MulScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// DivScalar performs element-wise division of a tensor by a scalar.
	DivScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Softmax applies the softmax function to a tensor along a given axis.
	Softmax(ctx context.Context, a *tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// ReduceSum calculates the sum of elements along a specified axis, similar to Sum but potentially with different
	// internal handling or optimizations for reduction operations.
	ReduceSum(
		ctx context.Context,
		a *tensor.TensorNumeric[T],
		axis int,
		keepDims bool,
		dst ...*tensor.TensorNumeric[T],
	) (*tensor.TensorNumeric[T], error)

	// AddScalar performs element-wise addition of a scalar to a tensor.
	AddScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Sqrt computes the element-wise square root of a tensor.
	Sqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Split splits a tensor into multiple tensors along a given axis.
	Split(ctx context.Context, a *tensor.TensorNumeric[T], numSplits int, axis int) ([]*tensor.TensorNumeric[T], error)

	// Concat concatenates a list of tensors along a given axis.
	Concat(ctx context.Context, tensors []*tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Repeat repeats the input tensor along a given axis a specified number of times.
	Repeat(
		ctx context.Context,
		a *tensor.TensorNumeric[T],
		axis int,
		repetitions int,
		dst ...*tensor.TensorNumeric[T],
	) (*tensor.TensorNumeric[T], error)

	// OneHot creates a one-hot encoding of the input tensor.
	OneHot(ctx context.Context, input *tensor.TensorNumeric[int], depth int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// Reshape changes the shape of a tensor without changing its data.
	Reshape(ctx context.Context, a *tensor.TensorNumeric[T], shape []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

	// ReduceMean calculates the mean of elements along a specified axis.
	ReduceMean(
		ctx context.Context,
		a *tensor.TensorNumeric[T],
		axis int,
		keepDims bool,
		dst ...*tensor.TensorNumeric[T],
	) (*tensor.TensorNumeric[T], error)

	// Rsqrt computes the element-wise reciprocal square root of a tensor.
	Rsqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
}

Engine defines the interface for a computation engine (e.g., CPU, GPU). All tensor operations should be routed through an Engine implementation to ensure hardware interoperability and optimized performance.

type EngineProxy

type EngineProxy[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

EngineProxy wraps an Engine[T] and optionally records traced operations.

func NewEngineProxy

func NewEngineProxy[T tensor.Numeric](real Engine[T]) *EngineProxy[T]

NewEngineProxy creates a new EngineProxy wrapping the given engine.

func (*EngineProxy[T]) Add

func (p *EngineProxy[T]) Add(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Add delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) AddScalar

func (p *EngineProxy[T]) AddScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

AddScalar delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) ArenaUsedBytes

func (p *EngineProxy[T]) ArenaUsedBytes() int

ArenaUsedBytes returns the current arena offset from the underlying engine.

func (*EngineProxy[T]) Concat

func (p *EngineProxy[T]) Concat(ctx context.Context, tensors []*tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Concat delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) Copy

func (p *EngineProxy[T]) Copy(ctx context.Context, dst, src *tensor.TensorNumeric[T]) error

Copy delegates to the underlying engine.

func (*EngineProxy[T]) Cos

func (p *EngineProxy[T]) Cos(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Cos delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) Div

func (p *EngineProxy[T]) Div(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Div delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) DivScalar

func (p *EngineProxy[T]) DivScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

DivScalar delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) Exp

func (p *EngineProxy[T]) Exp(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Exp delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) Fill

func (p *EngineProxy[T]) Fill(ctx context.Context, t *tensor.TensorNumeric[T], value T) error

Fill delegates to the underlying engine.

func (*EngineProxy[T]) FusedRMSNormGPU

func (p *EngineProxy[T]) FusedRMSNormGPU(input, weight *tensor.TensorNumeric[float32], epsilon float32) (*tensor.TensorNumeric[float32], *tensor.TensorNumeric[float32], error)

FusedRMSNormGPU delegates to the underlying engine if it implements FusedRMSNormer.

func (*EngineProxy[T]) GPUFusedAddRMSNorm

func (p *EngineProxy[T]) GPUFusedAddRMSNorm(input, residual, weight *tensor.TensorNumeric[T], eps float32) (
	normed *tensor.TensorNumeric[T],
	residualOut *tensor.TensorNumeric[T],
	scales *tensor.TensorNumeric[T],
	err error,
)

GPUFusedAddRMSNorm delegates to the underlying engine's FusedAddRMSNormProvider implementation and records the operation for tracing. This allows fusedAddRMSNormNode to call through the proxy without unwrapping it, which is required for CompileTraced to capture the operation.

func (*EngineProxy[T]) Gather

func (p *EngineProxy[T]) Gather(ctx context.Context, params *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], output *tensor.TensorNumeric[T]) error

Gather delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) Log

func (p *EngineProxy[T]) Log(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Log delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) MatMul

func (p *EngineProxy[T]) MatMul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

MatMul delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) MatMulTransposeB

func (p *EngineProxy[T]) MatMulTransposeB(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

MatMulTransposeB delegates to the underlying engine if it implements TransposeBMatMuler.

func (*EngineProxy[T]) Mul

func (p *EngineProxy[T]) Mul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Mul delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) MulScalar

func (p *EngineProxy[T]) MulScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

MulScalar delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) OneHot

func (p *EngineProxy[T]) OneHot(ctx context.Context, input *tensor.TensorNumeric[int], depth int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

OneHot delegates to the underlying engine.

func (*EngineProxy[T]) Ops

func (p *EngineProxy[T]) Ops() numeric.Arithmetic[T]

Ops delegates to the underlying engine.

func (*EngineProxy[T]) Pow

func (p *EngineProxy[T]) Pow(ctx context.Context, base, exponent *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Pow delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) RandomUniform

func (p *EngineProxy[T]) RandomUniform(ctx context.Context, t *tensor.TensorNumeric[T], minVal, maxVal T) error

RandomUniform delegates to the underlying engine.

func (*EngineProxy[T]) Real

func (p *EngineProxy[T]) Real() Engine[T]

Real returns the underlying engine.

func (*EngineProxy[T]) ReduceMean

func (p *EngineProxy[T]) ReduceMean(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

ReduceMean delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) ReduceSum

func (p *EngineProxy[T]) ReduceSum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

ReduceSum delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) Repeat

func (p *EngineProxy[T]) Repeat(ctx context.Context, a *tensor.TensorNumeric[T], axis int, repetitions int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Repeat delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) ResetPool

func (p *EngineProxy[T]) ResetPool()

ResetPool delegates to the underlying engine if it implements PoolResetter.

func (*EngineProxy[T]) Reshape

func (p *EngineProxy[T]) Reshape(ctx context.Context, a *tensor.TensorNumeric[T], shape []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Reshape delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) Rsqrt

func (p *EngineProxy[T]) Rsqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Rsqrt delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) ScatterAdd

func (p *EngineProxy[T]) ScatterAdd(ctx context.Context, dEmbeddingTable *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], dOut *tensor.TensorNumeric[T]) error

ScatterAdd delegates to the underlying engine.

func (*EngineProxy[T]) SetArenaResetFloor

func (p *EngineProxy[T]) SetArenaResetFloor(floor int)

SetArenaResetFloor sets the minimum reset offset on the underlying engine.

func (*EngineProxy[T]) Sin

func (p *EngineProxy[T]) Sin(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Sin delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) Softmax

func (p *EngineProxy[T]) Softmax(ctx context.Context, a *tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Softmax delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) Split

func (p *EngineProxy[T]) Split(ctx context.Context, a *tensor.TensorNumeric[T], numSplits int, axis int) ([]*tensor.TensorNumeric[T], error)

Split delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) Sqrt

func (p *EngineProxy[T]) Sqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Sqrt delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) StartTracing

func (p *EngineProxy[T]) StartTracing(tracer TraceRecorder[T])

StartTracing enables tracing with the given recorder.

func (*EngineProxy[T]) StopTracing

func (p *EngineProxy[T]) StopTracing()

StopTracing disables tracing.

func (*EngineProxy[T]) Sub

func (p *EngineProxy[T]) Sub(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Sub delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) Sum

func (p *EngineProxy[T]) Sum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Sum delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) Tanh

func (p *EngineProxy[T]) Tanh(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Tanh delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) TanhPrime

func (p *EngineProxy[T]) TanhPrime(ctx context.Context, a, upstream *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

TanhPrime delegates to the underlying engine.

func (*EngineProxy[T]) Transpose

func (p *EngineProxy[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T], axes []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Transpose delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) UnaryOp

func (p *EngineProxy[T]) UnaryOp(ctx context.Context, a *tensor.TensorNumeric[T], op func(T) T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

UnaryOp delegates to the underlying engine and records the operation if tracing.

func (*EngineProxy[T]) Zero

func (p *EngineProxy[T]) Zero(ctx context.Context, a *tensor.TensorNumeric[T]) error

Zero delegates to the underlying engine.

func (*EngineProxy[T]) Zeros

func (p *EngineProxy[T]) Zeros(ctx context.Context, a *tensor.TensorNumeric[T], shape []int) error

Zeros delegates to the underlying engine.

type FP16ToF32Converter

type FP16ToF32Converter interface {
	ConvertFP16ToF32(t *tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error)
}

FP16ToF32Converter is an optional interface for engines that can convert a tensor with Float16Storage to a regular float32 GPU tensor. This is used at the end of the FP16 forward pass to produce F32 logits for sampling.

type FailableTensor

type FailableTensor[T tensor.Numeric] struct {
	*tensor.TensorNumeric[T]
	// contains filtered or unexported fields
}

FailableTensor wraps a tensor and can be configured to fail on specific operations.

func NewFailableTensor

func NewFailableTensor[T tensor.Numeric](t *tensor.TensorNumeric[T]) *FailableTensor[T]

NewFailableTensor creates a new FailableTensor wrapper.

func (*FailableTensor[T]) Set

func (f *FailableTensor[T]) Set(value T, indices ...int) error

Set overrides the tensor's Set method to allow controlled failures.

func (*FailableTensor[T]) SetFailOnSet

func (f *FailableTensor[T]) SetFailOnSet(fail bool)

SetFailOnSet configures the tensor to fail on Set operations.

func (*FailableTensor[T]) SetFailOnSetAfter

func (f *FailableTensor[T]) SetFailOnSetAfter(count int)

SetFailOnSetAfter configures the tensor to fail after N Set calls.

type FailableZeroer

type FailableZeroer[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

FailableZeroer can be configured to fail on Zero operations.

func NewFailableZeroer

func NewFailableZeroer[T tensor.Numeric](engine *TestableEngine[T]) *FailableZeroer[T]

NewFailableZeroer creates a new FailableZeroer.

func (*FailableZeroer[T]) SetFailOnZero

func (f *FailableZeroer[T]) SetFailOnZero(fail bool)

SetFailOnZero configures the zeroer to fail on Zero operations.

func (*FailableZeroer[T]) Zero

func (f *FailableZeroer[T]) Zero(ctx context.Context, a *tensor.TensorNumeric[T]) error

Zero performs the zero operation with controlled failure capability.

type FusedAddRMSNormProvider

type FusedAddRMSNormProvider[T tensor.Numeric] interface {
	// GPUFusedAddRMSNorm computes:
	//   sum    = input + residual
	//   normed = rmsnorm(sum, weight, eps)
	// input and residual are read-only. Returns (normed, residualOut, scales, error).
	GPUFusedAddRMSNorm(input, residual, weight *tensor.TensorNumeric[T], eps float32) (
		normed *tensor.TensorNumeric[T],
		residualOut *tensor.TensorNumeric[T],
		scales *tensor.TensorNumeric[T],
		err error,
	)
}

FusedAddRMSNormProvider is implemented by engines that support fused residual-add + RMS normalization in a single GPU kernel launch. This eliminates one kernel launch per fusion point (2 per transformer layer).

type FusedNormAddProvider

type FusedNormAddProvider[T tensor.Numeric] interface {
	// GPUFusedNormAdd computes:
	//   normed = rmsnorm(input, weight, eps)
	//   output = normed + residual
	// All inputs are read-only. Returns (output, error).
	GPUFusedNormAdd(input, weight, residual *tensor.TensorNumeric[T], eps float32) (*tensor.TensorNumeric[T], error)
}

FusedNormAddProvider is implemented by engines that support fused RMSNorm + elementwise Add in a single GPU kernel launch. output = rmsnorm(input, weight, eps) + residual. This eliminates one kernel launch per fusion point.

type FusedQKNormRoPEProvider

type FusedQKNormRoPEProvider[T tensor.Numeric] interface {
	// GPUFusedQKNormRoPE applies per-head RMSNorm + RoPE to combined Q+K data.
	// input: [totalHeads, headDim] (Q heads then K heads, contiguous).
	// weightQ/weightK: [headDim] RMSNorm weights.
	// cosAngles/sinAngles: [halfRotary] precomputed angles for current position.
	// Returns output: [totalHeads, headDim].
	GPUFusedQKNormRoPE(
		input *tensor.TensorNumeric[T],
		weightQ, weightK *tensor.TensorNumeric[T],
		cosAngles, sinAngles *tensor.TensorNumeric[T],
		eps float32,
		totalHeads, headDim, numQHeads, halfRotary int,
	) (*tensor.TensorNumeric[T], error)
}

FusedQKNormRoPEProvider is implemented by engines that support fused per-head QK RMSNorm + RoPE in a single GPU kernel launch. This replaces 4 kernel launches (Q_norm + K_norm + Q_RoPE + K_RoPE) with 1 per GQA layer during decode.

type FusedRMSNormer

type FusedRMSNormer interface {
	FusedRMSNormGPU(input, weight *tensor.TensorNumeric[float32], epsilon float32) (output, scales *tensor.TensorNumeric[float32], err error)
}

FusedRMSNormer is an optional interface for engines that support GPU-accelerated fused RMSNorm. Layers can type-assert to this to use the fused kernel. Returns (output, scales) where scales contains per-row rsqrt values for backward pass.

type FusedRoPEProvider

type FusedRoPEProvider[T tensor.Numeric] interface {
	GPUFusedRoPE(input, cosAngles, sinAngles *tensor.TensorNumeric[T], rotaryDim int) (*tensor.TensorNumeric[T], error)
}

FusedRoPEProvider is implemented by engines that support fused GPU RoPE.

type FusedScaledSoftmaxProvider

type FusedScaledSoftmaxProvider[T tensor.Numeric] interface {
	GPUScaledSoftmax(input *tensor.TensorNumeric[T], scale float32, axis int) (*tensor.TensorNumeric[T], error)
}

FusedScaledSoftmaxProvider is implemented by engines that support fused GPU scaled softmax. It computes output = softmax(input * scale) in a single kernel launch, eliminating the MulScalar + Softmax chain (saves 1 kernel launch per call).

type FusedSwiGLUProvider

type FusedSwiGLUProvider[T tensor.Numeric] interface {
	GPUFusedSwiGLU(w1, w3 *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
}

FusedSwiGLUProvider is implemented by engines that support fused GPU SwiGLU. It computes output[i] = w1[i] * sigmoid(w1[i]) * w3[i] in a single kernel, eliminating the Concat + Split + sigmoid + Mul + Mul chain.
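The documented per-element formula is simple enough to express as a host reference (function name hypothetical), which is useful for comparing against the fused kernel's output:

```go
package main

import (
	"fmt"
	"math"
)

// swiGLURow is a CPU reference: out[i] = w1[i] * sigmoid(w1[i]) * w3[i].
func swiGLURow(w1, w3 []float32) []float32 {
	out := make([]float32, len(w1))
	for i := range w1 {
		sig := 1 / (1 + math.Exp(-float64(w1[i])))
		out[i] = w1[i] * float32(sig) * w3[i]
	}
	return out
}

func main() {
	fmt.Println(swiGLURow([]float32{0, 20}, []float32{5, 2}))
}
```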

type GPUArgmaxer

type GPUArgmaxer interface {
	GPUArgmax(t *tensor.TensorNumeric[float32]) (int, error)
}

GPUArgmaxer is an optional interface for engines that can compute argmax entirely on GPU, returning just the index without copying logits to host. This eliminates the ~1MB D2H copy per token for greedy decoding.
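A host-side reference makes the contract concrete (function name hypothetical; whether the GPU kernel breaks ties by first occurrence is an assumption worth verifying against the engine):

```go
package main

import "fmt"

// argmaxHost returns the index of the maximum element,
// keeping the first occurrence on ties.
func argmaxHost(x []float32) int {
	best := 0
	for i, v := range x {
		if v > x[best] {
			best = i
		}
	}
	return best
}

func main() {
	fmt.Println(argmaxHost([]float32{1, 5, 3}))
}
```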

type GPUEngine

type GPUEngine[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

GPUEngine is a GPU-accelerated implementation of the Engine interface. MatMul uses BLAS for maximum performance. Elementwise, scalar, activation, and math operations use native GPU kernels for float32 types. Operations without GPU kernels delegate to CPUEngine.

GPUEngine uses a device-resident pipeline: output tensors have GPUStorage so data stays on GPU between chained operations. A memory pool avoids per-operation malloc/free, and a dedicated stream enables async kernel execution.

GPUEngine is backend-agnostic via the GRAL interfaces (internal/gpuapi/). The CUDA, ROCm, and OpenCL adapters implement these interfaces.

func NewGPUEngine

func NewGPUEngine[T tensor.Numeric](ops numeric.Arithmetic[T], deviceID ...int) (*GPUEngine[T], error)

NewGPUEngine creates a new GPUEngine backed by CUDA via the GRAL abstraction. An optional deviceID selects the GPU (default 0). Call Close() when done to release all resources.

func (*GPUEngine[T]) Add

func (e *GPUEngine[T]) Add(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Add performs element-wise addition.

func (*GPUEngine[T]) AddScalar

func (e *GPUEngine[T]) AddScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

AddScalar adds a scalar to each element.

func (*GPUEngine[T]) ArenaUsedBytes

func (e *GPUEngine[T]) ArenaUsedBytes() int

ArenaUsedBytes returns the current arena offset (bytes in use).

func (*GPUEngine[T]) BatchNormForwardInference

func (e *GPUEngine[T]) BatchNormForwardInference(
	_ context.Context,
	x, scale, bias, mean, variance *tensor.TensorNumeric[T],
	epsilon float64,
) (*tensor.TensorNumeric[T], error)

BatchNormForwardInference performs batch normalization in inference mode using pre-computed running mean and variance via the GPU DNN backend. x must be [N, C, H, W]. scale, bias, mean, variance must each be [C].

func (*GPUEngine[T]) BatchNormForwardTraining

func (e *GPUEngine[T]) BatchNormForwardTraining(
	_ context.Context,
	x, scale, bias *tensor.TensorNumeric[T],
	runningMean, runningVariance *tensor.TensorNumeric[T],
	epsilon, expAvgFactor float64,
) (*tensor.TensorNumeric[T], *tensor.TensorNumeric[T], *tensor.TensorNumeric[T], error)

BatchNormForwardTraining performs batch normalization computing batch statistics. x: [N, C, H, W], scale/bias: [C]. Returns y: [N, C, H, W], saveMean: [C], saveInvVariance: [C]. Updates runningMean and runningVariance in-place.

func (*GPUEngine[T]) Close

func (e *GPUEngine[T]) Close() error

Close releases the BLAS handle, DNN handle, GPU stream, and drains the memory pool. The engine must not be used after Close.

func (*GPUEngine[T]) Concat

func (e *GPUEngine[T]) Concat(ctx context.Context, tensors []*tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Concat concatenates tensors along an axis.

func (*GPUEngine[T]) Conv2dBackwardData

func (e *GPUEngine[T]) Conv2dBackwardData(
	_ context.Context,
	w *tensor.TensorNumeric[T],
	dy *tensor.TensorNumeric[T],
	dxShape [4]int,
	strides [2]int,
	pads [4]int,
	dilations [2]int,
	groups int,
) (*tensor.TensorNumeric[T], error)

Conv2dBackwardData computes the gradient of the convolution input via cuDNN. w: [C_out, C_in/groups, kH, kW], dy: [N, C_out, outH, outW]. Returns dx: [N, C_in, H, W].

func (*GPUEngine[T]) Conv2dBackwardFilter

func (e *GPUEngine[T]) Conv2dBackwardFilter(
	_ context.Context,
	x *tensor.TensorNumeric[T],
	dy *tensor.TensorNumeric[T],
	dwShape [4]int,
	strides [2]int,
	pads [4]int,
	dilations [2]int,
	groups int,
) (*tensor.TensorNumeric[T], error)

Conv2dBackwardFilter computes the gradient of the convolution filter via cuDNN. x: [N, C_in, H, W], dy: [N, C_out, outH, outW]. Returns dw: [C_out, C_in/groups, kH, kW].

func (*GPUEngine[T]) Conv2dForward

func (e *GPUEngine[T]) Conv2dForward(
	_ context.Context,
	x, w *tensor.TensorNumeric[T],
	bias *tensor.TensorNumeric[T],
	strides [2]int,
	pads [4]int,
	dilations [2]int,
	groups int,
) (*tensor.TensorNumeric[T], error)

Conv2dForward performs 2D convolution using the GPU DNN backend. x must be [N, C_in, H, W], w must be [C_out, C_in/groups, kH, kW]. bias is optional (nil to skip). pads is [top, left, bottom, right]. Returns error if padding is asymmetric (cuDNN requires symmetric padding).

func (*GPUEngine[T]) ConvertFP16ToF32

func (e *GPUEngine[T]) ConvertFP16ToF32(t *tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error)

ConvertFP16ToF32 converts a tensor with Float16Storage to a regular float32 GPU tensor using the FP16->F32 kernel. Returns the input unchanged if it does not have Float16Storage.

func (*GPUEngine[T]) Copy

func (e *GPUEngine[T]) Copy(ctx context.Context, dst, src *tensor.TensorNumeric[T]) error

Copy copies data from source to destination tensor.

func (*GPUEngine[T]) Cos

func (e *GPUEngine[T]) Cos(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Cos computes the element-wise cosine.

func (*GPUEngine[T]) CudnnActivationBackward

func (e *GPUEngine[T]) CudnnActivationBackward(
	_ context.Context,
	x, y, dy *tensor.TensorNumeric[T],
	mode gpuapi.ActivationMode,
) (*tensor.TensorNumeric[T], error)

CudnnActivationBackward computes the gradient of an activation function via cuDNN. y: forward output, dy: upstream gradient, x: original input. All must have the same shape. Returns dx with the same shape.

func (*GPUEngine[T]) CudnnActivationForward

func (e *GPUEngine[T]) CudnnActivationForward(
	_ context.Context,
	x *tensor.TensorNumeric[T],
	mode gpuapi.ActivationMode,
) (*tensor.TensorNumeric[T], error)

CudnnActivationForward applies an activation function via the GPU DNN backend. mode selects the activation: ActivationReLU, ActivationSigmoid, ActivationTanh. The input tensor shape is preserved in the output.

func (*GPUEngine[T]) CudnnBatchNormBackward

func (e *GPUEngine[T]) CudnnBatchNormBackward(
	_ context.Context,
	x, dy, scale *tensor.TensorNumeric[T],
	saveMean, saveInvVariance *tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], *tensor.TensorNumeric[T], *tensor.TensorNumeric[T], error)

CudnnBatchNormBackward computes gradients for batch normalization via cuDNN. x: [N, C, H, W], dy: [N, C, H, W], scale: [C]. saveMean, saveInvVariance: [C] (from BatchNormForwardTraining). Returns dx: [N, C, H, W], dScale: [C], dBias: [C].

func (*GPUEngine[T]) CudnnPoolingBackward

func (e *GPUEngine[T]) CudnnPoolingBackward(
	_ context.Context,
	x, y, dy *tensor.TensorNumeric[T],
	mode gpuapi.PoolingMode,
	windowH, windowW, padH, padW, strideH, strideW int,
) (*tensor.TensorNumeric[T], error)

CudnnPoolingBackward computes the gradient of 2D pooling via cuDNN. y: forward output [N,C,outH,outW], dy: upstream gradient [N,C,outH,outW], x: forward input [N,C,H,W]. Returns dx: [N,C,H,W].

func (*GPUEngine[T]) CudnnPoolingForward

func (e *GPUEngine[T]) CudnnPoolingForward(
	_ context.Context,
	x *tensor.TensorNumeric[T],
	mode gpuapi.PoolingMode,
	windowH, windowW, padH, padW, strideH, strideW int,
) (*tensor.TensorNumeric[T], error)

CudnnPoolingForward performs 2D pooling via the GPU DNN backend. x must be [N, C, H, W]. Returns [N, C, outH, outW].

func (*GPUEngine[T]) CudnnSoftmaxForward

func (e *GPUEngine[T]) CudnnSoftmaxForward(
	_ context.Context,
	x *tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error)

CudnnSoftmaxForward computes softmax via the GPU DNN backend over the channel dimension. x must be [N, C, H, W] (or reshaped to fit).

func (*GPUEngine[T]) DTypeValue

func (e *GPUEngine[T]) DTypeValue() DType

DTypeValue returns the current compute precision.

func (*GPUEngine[T]) DeviceID

func (e *GPUEngine[T]) DeviceID() int

DeviceID returns the GPU device ID this engine is bound to.

func (*GPUEngine[T]) Div

func (e *GPUEngine[T]) Div(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Div performs element-wise division.

func (*GPUEngine[T]) DivScalar

func (e *GPUEngine[T]) DivScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

DivScalar divides each element by a scalar.

func (*GPUEngine[T]) Exp

func (e *GPUEngine[T]) Exp(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Exp computes the element-wise exponential.

func (*GPUEngine[T]) Fill

func (e *GPUEngine[T]) Fill(ctx context.Context, t *tensor.TensorNumeric[T], value T) error

Fill fills the tensor with a scalar value.

func (*GPUEngine[T]) FusedRMSNormGPU

func (e *GPUEngine[T]) FusedRMSNormGPU(input, weight *tensor.TensorNumeric[float32], epsilon float32) (*tensor.TensorNumeric[float32], *tensor.TensorNumeric[float32], error)

FusedRMSNormGPU implements the FusedRMSNormer interface for GPUEngine. Uses the fused GPU kernel when input is GPU-resident, falls back to CPU otherwise. Returns (output, scales) where scales contains per-row rsqrt values for backward pass.

func (*GPUEngine[T]) GPUArgmax

func (e *GPUEngine[T]) GPUArgmax(t *tensor.TensorNumeric[float32]) (int, error)

GPUArgmax finds the index of the maximum element in a GPU-resident float32 tensor. Returns the index as an int without copying the full tensor to the host. Only copies back a single int32 (4 bytes) instead of the entire tensor.

func (*GPUEngine[T]) GPUFusedAddRMSNorm

func (e *GPUEngine[T]) GPUFusedAddRMSNorm(
	input, residual *tensor.TensorNumeric[T],
	weight *tensor.TensorNumeric[T],
	eps float32,
) (normed *tensor.TensorNumeric[T], residualOut *tensor.TensorNumeric[T], scales *tensor.TensorNumeric[T], err error)

GPUFusedAddRMSNorm computes sum = input + residual and normed = rmsnorm(sum, weight, eps) in a single GPU kernel launch. The input and residual tensors are read-only; outputs go to separate buffers. This replaces Add + RMSNorm (2 kernel launches) with 1.

func (*GPUEngine[T]) GPUFusedNormAdd

func (e *GPUEngine[T]) GPUFusedNormAdd(input, weight, residual *tensor.TensorNumeric[T], eps float32) (*tensor.TensorNumeric[T], error)

GPUFusedNormAdd computes output = rmsnorm(input, weight, eps) + residual in a single GPU kernel launch. Replaces separate RMSNorm + Add (2 launches → 1).

func (*GPUEngine[T]) GPUFusedQKNormRoPE

func (e *GPUEngine[T]) GPUFusedQKNormRoPE(
	input *tensor.TensorNumeric[T],
	weightQ, weightK *tensor.TensorNumeric[T],
	cosAngles, sinAngles *tensor.TensorNumeric[T],
	eps float32,
	totalHeads, headDim, numQHeads, halfRotary int,
) (*tensor.TensorNumeric[T], error)

GPUFusedQKNormRoPE applies per-head RMSNorm + RoPE to combined Q+K heads in a single GPU kernel launch. This replaces 4 kernel launches per GQA layer. input: [totalHeads, headDim], weightQ/weightK: [headDim], cosAngles/sinAngles: [halfRotary], output: [totalHeads, headDim].

func (*GPUEngine[T]) GPUFusedRoPE

func (e *GPUEngine[T]) GPUFusedRoPE(input, cosAngles, sinAngles *tensor.TensorNumeric[T], rotaryDim int) (*tensor.TensorNumeric[T], error)

GPUFusedRoPE applies rotary position embeddings in a single GPU kernel launch. This replaces Split + 4 Mul + Sub + Add + Concat (8 operations, ~10 D2D memcpy) with 1 kernel.
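For one head vector, a rotate-half CPU reference looks like the following. The pairing convention (element i rotated with element i + rotaryDim/2) is an assumption here, suggested by the Split-based chain the kernel replaces; check the engine source before relying on exact layouts.

```go
package main

import "fmt"

// ropeRow applies a rotate-half RoPE formulation to one head vector.
// cosA/sinA hold rotaryDim/2 precomputed angles for the current position.
func ropeRow(x, cosA, sinA []float32, rotaryDim int) []float32 {
	out := make([]float32, len(x))
	copy(out, x) // dimensions beyond rotaryDim pass through unchanged
	half := rotaryDim / 2
	for i := 0; i < half; i++ {
		x1, x2 := x[i], x[i+half]
		out[i] = x1*cosA[i] - x2*sinA[i]
		out[i+half] = x1*sinA[i] + x2*cosA[i]
	}
	return out
}

func main() {
	// A quarter-turn (cos=0, sin=1) rotates [1, 0] into [0, 1].
	fmt.Println(ropeRow([]float32{1, 0}, []float32{0}, []float32{1}, 2))
}
```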

func (*GPUEngine[T]) GPUFusedSwiGLU

func (e *GPUEngine[T]) GPUFusedSwiGLU(w1, w3 *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

GPUFusedSwiGLU computes SwiGLU(w1, w3) = w1 * sigmoid(w1) * w3 in a single GPU kernel. This replaces Concat + Split + sigmoid + Mul + Mul (5 operations, ~4 D2D memcpy per layer) with 1 kernel.

func (*GPUEngine[T]) GPUScaledSoftmax

func (e *GPUEngine[T]) GPUScaledSoftmax(input *tensor.TensorNumeric[T], scale float32, axis int) (*tensor.TensorNumeric[T], error)

GPUScaledSoftmax computes softmax(input * scale) in a single GPU kernel launch. This replaces MulScalar + Softmax (2 kernel launches) with 1, saving one launch per call (26 per token for a 26-layer model).

func (*GPUEngine[T]) GPUStream

func (e *GPUEngine[T]) GPUStream() gpuapi.Stream

GPUStream returns the engine's gpuapi.Stream for async memory operations.

func (*GPUEngine[T]) Gather

func (e *GPUEngine[T]) Gather(ctx context.Context, params *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], output *tensor.TensorNumeric[T]) error

Gather performs an embedding-style gather.

func (*GPUEngine[T]) IsManagedMemory

func (e *GPUEngine[T]) IsManagedMemory() bool

IsManagedMemory returns true if the engine uses managed memory for weight uploads and the arena allocator.

func (*GPUEngine[T]) IsPagedGQASupported added in v0.3.0

func (e *GPUEngine[T]) IsPagedGQASupported() bool

IsPagedGQASupported returns true when the paged attention CUDA kernel is loaded and available.

func (*GPUEngine[T]) Log

func (e *GPUEngine[T]) Log(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Log computes the element-wise natural logarithm.

func (*GPUEngine[T]) MatMul

func (e *GPUEngine[T]) MatMul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

MatMul performs matrix multiplication using GPU BLAS for float32 and BFloat16 tensors. For Q4_0 quantized tensors, uses the Q4 dequant-GEMM kernel. For other types, it falls back to the CPU implementation. Supports 2D matrices and batched matmul (3D+ tensors).

func (*GPUEngine[T]) MatMulTransposeB

func (e *GPUEngine[T]) MatMulTransposeB(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

MatMulTransposeB computes C = A * B^T using cuBLAS SgemmNT, avoiding an explicit Transpose allocation and kernel launch. A is [...batch, m, k], B is [...batch, n, k], result is [...batch, m, n]. Supports batch broadcasting (bBatch=1 broadcasts B across A's batch).
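The single-batch, row-major semantics can be pinned down with a naive CPU reference (function name hypothetical) matching the documented shapes A: [m, k], B: [n, k], C: [m, n]:

```go
package main

import "fmt"

// matMulTransposeBRef computes C = A * B^T for row-major slices:
// a is [m, k], b is [n, k], result is [m, n].
func matMulTransposeBRef(a, b []float32, m, k, n int) []float32 {
	c := make([]float32, m*n)
	for i := 0; i < m; i++ {
		for j := 0; j < n; j++ {
			var s float32
			for p := 0; p < k; p++ {
				// Row i of A dotted with row j of B (= column j of B^T).
				s += a[i*k+p] * b[j*k+p]
			}
			c[i*n+j] = s
		}
	}
	return c
}

func main() {
	// B is the 2x2 identity, so C equals A.
	fmt.Println(matMulTransposeBRef([]float32{1, 2, 3, 4}, []float32{1, 0, 0, 1}, 2, 2, 2))
}
```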

func (*GPUEngine[T]) Mul

func (e *GPUEngine[T]) Mul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Mul performs element-wise multiplication.

func (*GPUEngine[T]) MulScalar

func (e *GPUEngine[T]) MulScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

MulScalar multiplies each element by a scalar.

func (*GPUEngine[T]) OOMFallbackCount

func (e *GPUEngine[T]) OOMFallbackCount() int64

OOMFallbackCount returns the number of times GPU OOM triggered CPU fallback.

func (*GPUEngine[T]) OneHot

func (e *GPUEngine[T]) OneHot(ctx context.Context, input *tensor.TensorNumeric[int], depth int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

OneHot creates a one-hot encoding.

func (*GPUEngine[T]) Ops

func (e *GPUEngine[T]) Ops() numeric.Arithmetic[T]

Ops returns the arithmetic ops for this engine.

func (*GPUEngine[T]) PagedGQA added in v0.3.0

func (e *GPUEngine[T]) PagedGQA(
	Q *tensor.TensorNumeric[float32],
	blockPtrsK, blockPtrsV unsafe.Pointer,
	blockIndices unsafe.Pointer,
	seqLen, blockSize, headDim int,
	numQHeads, numKVHeads int,
	batch int,
) (*tensor.TensorNumeric[float32], error)

PagedGQA computes scaled dot-product attention with block-table indirection for paged KV caches. When the paged attention kernel is not available, it returns an error.

Q: [batch*numQHeads, headDim] query tensor (GPU-resident). blockPtrsK: device array of float* pointers to K blocks. blockPtrsV: device array of float* pointers to V blocks. blockIndices: device array [batch * maxNumBlocks] logical→physical mapping. Returns output [batch*numQHeads, headDim].

func (*GPUEngine[T]) Pow

func (e *GPUEngine[T]) Pow(ctx context.Context, base, exponent *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Pow raises each element to the given power.

func (*GPUEngine[T]) RandomUniform

func (e *GPUEngine[T]) RandomUniform(ctx context.Context, t *tensor.TensorNumeric[T], minVal, maxVal T) error

RandomUniform fills the tensor with uniform random values.

func (*GPUEngine[T]) ReduceMean

func (e *GPUEngine[T]) ReduceMean(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

ReduceMean computes the mean of elements along an axis.

func (*GPUEngine[T]) ReduceSum

func (e *GPUEngine[T]) ReduceSum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

ReduceSum computes the sum of elements along an axis.

func (*GPUEngine[T]) Repeat

func (e *GPUEngine[T]) Repeat(ctx context.Context, a *tensor.TensorNumeric[T], axis int, repetitions int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Repeat repeats the tensor along an axis.

func (*GPUEngine[T]) ResetPool

func (e *GPUEngine[T]) ResetPool()

ResetPool resets the arena pool, reclaiming all per-pass allocations. This is a no-op if the pool is not arena-backed.

func (*GPUEngine[T]) Reshape

func (e *GPUEngine[T]) Reshape(ctx context.Context, a *tensor.TensorNumeric[T], shape []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Reshape changes the shape without changing data.

func (*GPUEngine[T]) Rsqrt

func (e *GPUEngine[T]) Rsqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Rsqrt computes the element-wise reciprocal square root.

func (*GPUEngine[T]) ScatterAdd

func (e *GPUEngine[T]) ScatterAdd(ctx context.Context, dEmbeddingTable *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], dOut *tensor.TensorNumeric[T]) error

ScatterAdd performs a row-wise scatter-add for embeddings.

func (*GPUEngine[T]) SetArenaResetFloor

func (e *GPUEngine[T]) SetArenaResetFloor(floor int)

SetArenaResetFloor sets the minimum offset that arena Reset will rewind to.

func (*GPUEngine[T]) SetDType

func (e *GPUEngine[T]) SetDType(d DType)

SetDType sets the compute precision for elementwise ops and MatMul. DTypeFP16 enables the FP16 inference path: F32 inputs are converted to FP16 on GPU, FP16 kernels run, and results are converted back to F32.

func (*GPUEngine[T]) SetLogger

func (e *GPUEngine[T]) SetLogger(l log.Logger)

SetLogger replaces the engine's logger.

func (*GPUEngine[T]) Sin

func (e *GPUEngine[T]) Sin(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Sin computes the element-wise sine.

func (*GPUEngine[T]) Softmax

func (e *GPUEngine[T]) Softmax(ctx context.Context, a *tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Softmax applies the softmax function along an axis.

func (*GPUEngine[T]) Split

func (e *GPUEngine[T]) Split(ctx context.Context, a *tensor.TensorNumeric[T], numSplits int, axis int) ([]*tensor.TensorNumeric[T], error)

Split splits a tensor into multiple tensors along an axis.

func (*GPUEngine[T]) Sqrt

func (e *GPUEngine[T]) Sqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Sqrt computes the element-wise square root.

func (*GPUEngine[T]) Stream

func (e *GPUEngine[T]) Stream() unsafe.Pointer

Stream returns the engine's GPU stream as an unsafe.Pointer (cudaStream_t).

func (*GPUEngine[T]) Sub

func (e *GPUEngine[T]) Sub(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Sub performs element-wise subtraction.

func (*GPUEngine[T]) Sum

func (e *GPUEngine[T]) Sum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Sum computes the sum of elements along an axis.

func (*GPUEngine[T]) Sync

func (e *GPUEngine[T]) Sync() error

Sync synchronizes the GPU stream, blocking until all enqueued operations complete. Use for benchmarking or when explicit synchronization is needed.

func (*GPUEngine[T]) Tanh

func (e *GPUEngine[T]) Tanh(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Tanh computes the element-wise hyperbolic tangent.

func (*GPUEngine[T]) TanhPrime

func (e *GPUEngine[T]) TanhPrime(ctx context.Context, a, upstream *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

TanhPrime computes the element-wise gradient of tanh.

func (*GPUEngine[T]) Transpose

func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T], axes []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Transpose transposes a tensor along the given axes.

func (*GPUEngine[T]) UnaryOp

func (e *GPUEngine[T]) UnaryOp(ctx context.Context, a *tensor.TensorNumeric[T], op func(T) T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

UnaryOp applies op to each element of a.

func (*GPUEngine[T]) UploadWeights

func (e *GPUEngine[T]) UploadWeights(tensors []*tensor.TensorNumeric[float32]) error

UploadWeights copies CPU-resident tensors to GPU device memory in place. Tensors that already have GPUStorage are skipped. Q4 quantized weights get their raw bytes uploaded and cached in Q4Storage to avoid per-op H2D. This is called once at model load time.

On devices with managed memory support (e.g., GB10), weights are allocated with cudaMallocManaged and populated via direct CPU memcpy. The GPU can then access them without any explicit H2D transfer.

func (*GPUEngine[T]) Zero

func (e *GPUEngine[T]) Zero(ctx context.Context, a *tensor.TensorNumeric[T]) error

Zero sets all elements to zero.

func (*GPUEngine[T]) Zeros

func (e *GPUEngine[T]) Zeros(ctx context.Context, a *tensor.TensorNumeric[T], shape []int) error

Zeros fills the tensor with zeros.

type GPUStreamAccessor

type GPUStreamAccessor interface {
	GPUStream() gpuapi.Stream
}

GPUStreamAccessor is an optional interface for engines that provide their gpuapi.Stream for async memory operations (e.g., KV cache D2D copies during CUDA graph capture).

type HardwareProfile added in v0.3.0

type HardwareProfile struct {
	// CPU
	CPUCores  int    // logical CPU count (GOMAXPROCS-visible)
	CPUModel  string // human-readable CPU model string
	HasNEON   bool   // ARM SIMD (Neon)
	HasAVX2   bool   // x86 SIMD (AVX2)
	HasAVX512 bool   // x86 advanced SIMD (AVX-512)
	CacheL1   int64  // L1 data cache size in bytes (0 if unknown)
	CacheL2   int64  // L2 cache size in bytes (0 if unknown)
	CacheL3   int64  // L3 cache size in bytes (0 if unknown)
	TotalRAM  int64  // total physical memory in bytes

	// GPU
	GPUAvailable  bool   // true if a usable GPU was detected
	GPUBackend    string // "cuda", "rocm", "metal", "opencl", or ""
	GPUName       string // human-readable GPU name
	GPUMemory     int64  // GPU memory in bytes (0 if unknown)
	GPUComputeCap string // e.g. "8.9" for CUDA compute capability
	MultiGPU      bool   // true if more than one GPU is available
	GPUCount      int    // number of GPUs (0 if none)
}

HardwareProfile describes the CPU and GPU capabilities of the current system. It is used by the auto-optimization framework to select the best engine, kernels, and quantization strategy for the detected hardware.

func ProfileHardware added in v0.3.0

func ProfileHardware() (*HardwareProfile, error)

ProfileHardware detects the hardware capabilities of the current system. CPU information is always populated. GPU fields are populated on a best-effort basis — they remain zero-valued when no GPU is detected.

func (*HardwareProfile) RecommendEngine added in v0.3.0

func (p *HardwareProfile) RecommendEngine() string

RecommendEngine returns the compute backend name that best fits the detected hardware: "cuda", "rocm", "metal", "opencl", or "cpu".
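The selection logic is not documented beyond "best fits the detected hardware". A plausible minimal sketch, assuming the recommendation simply prefers the detected GPU backend and falls back to "cpu" (the real heuristic may also weigh GPU memory, compute capability, and CPU SIMD support); the `profile` struct here mirrors only the GPU-related fields of HardwareProfile:

```go
package main

import "fmt"

// profile mirrors the GPU-related fields of HardwareProfile.
type profile struct {
	GPUAvailable bool
	GPUBackend   string // "cuda", "rocm", "metal", "opencl", or ""
}

// recommendEngine is a hypothetical sketch of the selection logic:
// prefer the detected GPU backend, otherwise fall back to "cpu".
func recommendEngine(p profile) string {
	if p.GPUAvailable && p.GPUBackend != "" {
		return p.GPUBackend
	}
	return "cpu"
}

func main() {
	fmt.Println(recommendEngine(profile{GPUAvailable: true, GPUBackend: "cuda"})) // cuda
	fmt.Println(recommendEngine(profile{}))                                       // cpu
}
```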

func (*HardwareProfile) String added in v0.3.0

func (p *HardwareProfile) String() string

String returns a human-readable summary of the hardware profile.

type MemoryTracker

type MemoryTracker struct {
	// contains filtered or unexported fields
}

MemoryTracker tracks total allocated bytes with an optional upper limit. All methods are safe for concurrent use.

func NewMemoryTracker

func NewMemoryTracker(limit int64) *MemoryTracker

NewMemoryTracker creates a tracker with the given byte limit. A limit of 0 disables enforcement (unlimited).

func (*MemoryTracker) Alloc

func (m *MemoryTracker) Alloc(bytes int64) error

Alloc reserves bytes. If the allocation would exceed the limit, it returns ErrMemoryLimitExceeded without modifying the counter.

func (*MemoryTracker) Allocated

func (m *MemoryTracker) Allocated() int64

Allocated returns the current total allocated bytes.

func (*MemoryTracker) Free

func (m *MemoryTracker) Free(bytes int64)

Free releases previously allocated bytes.

func (*MemoryTracker) Limit

func (m *MemoryTracker) Limit() int64

Limit returns the configured byte limit (0 means unlimited).

type PagedGQAer added in v0.3.0

type PagedGQAer interface {
	PagedGQA(
		Q *tensor.TensorNumeric[float32],
		blockPtrsK, blockPtrsV unsafe.Pointer,
		blockIndices unsafe.Pointer,
		seqLen, blockSize, headDim int,
		numQHeads, numKVHeads int,
		batch int,
	) (*tensor.TensorNumeric[float32], error)

	// IsPagedGQASupported returns true when the paged attention kernel is
	// available on this engine.
	IsPagedGQASupported() bool
}

PagedGQAer is an optional interface for engines that support paged grouped-query attention via block-table indirection. When the engine supports paged attention, callers can pass block pointers and indices instead of contiguous KV tensors.

Q: [batch*numQHeads, headDim]
blockPtrsK: device array of float* pointers to K blocks
blockPtrsV: device array of float* pointers to V blocks
blockIndices: device array [batch * maxNumBlocks] logical→physical mapping
seqLen: valid KV positions
blockSize: tokens per block
headDim: dimension per head
numQHeads: query heads per batch element
numKVHeads: KV heads per batch element
batch: number of sequences

Returns output tensor [batch*numQHeads, headDim].

type PoolResetter

type PoolResetter interface {
	ResetPool()
}

PoolResetter is an optional interface for engines that use arena-based memory pools. Call ResetPool() at the start of each forward pass to reclaim all per-pass intermediate allocations in O(1).
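The optional-interface pattern used by PoolResetter (and the other optional interfaces in this package) is a runtime type assertion at the call site. A minimal sketch, with `fakeEngine` and `forwardPass` as hypothetical stand-ins for a real engine and inference loop:

```go
package main

import "fmt"

// PoolResetter mirrors the optional interface above.
type PoolResetter interface {
	ResetPool()
}

// fakeEngine is a hypothetical engine that tracks per-pass allocations.
type fakeEngine struct{ passAllocs int }

func (e *fakeEngine) ResetPool() { e.passAllocs = 0 }

// forwardPass shows the calling convention: reclaim the arena at the
// start of each forward pass, but only if the engine supports it.
func forwardPass(engine any) {
	if r, ok := engine.(PoolResetter); ok {
		r.ResetPool()
	}
	// ... run the pass, allocating intermediates from the pool ...
}

func main() {
	e := &fakeEngine{passAllocs: 42}
	forwardPass(e)
	fmt.Println(e.passAllocs) // 0
}
```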

type StreamProvider

type StreamProvider interface {
	// Stream returns the engine's GPU stream as an unsafe.Pointer (cudaStream_t).
	Stream() unsafe.Pointer
}

StreamProvider is an optional interface for engines that expose their underlying GPU stream for CUDA graph capture.

type TensorArena

type TensorArena struct {
	// contains filtered or unexported fields
}

TensorArena is a power-of-2 bucketed pool for float32 backing arrays. It reduces GC pressure during inference by reusing allocations. Thread-safe with per-bucket mutex protection.

func (*TensorArena) Get

func (a *TensorArena) Get(n int) []float32

Get returns a float32 slice of at least n elements from the arena. The returned slice has length n but may have capacity rounded up to a power of 2.

func (*TensorArena) Put

func (a *TensorArena) Put(buf []float32)

Put returns a buffer to the arena for reuse.

func (*TensorArena) Reset

func (a *TensorArena) Reset()

Reset clears all pooled buffers, allowing GC to collect them.
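A power-of-2 bucketed pool with the Get/Put contract described above can be sketched as follows; the bucket indexing and per-bucket locking here are assumptions about one plausible layout, not the package's actual code:

```go
package main

import (
	"fmt"
	"math/bits"
	"sync"
)

// arena is a minimal sketch of a power-of-2 bucketed pool for float32
// slices: bucket b holds buffers of capacity exactly 2^b.
type arena struct {
	mu      [32]sync.Mutex
	buckets [32][][]float32
}

// bucketFor returns the smallest b such that 2^b >= n.
func bucketFor(n int) int {
	if n <= 1 {
		return 0
	}
	return bits.Len(uint(n - 1))
}

// Get returns a slice of length n with capacity rounded up to a power
// of 2, reusing a pooled buffer when one is available.
func (a *arena) Get(n int) []float32 {
	b := bucketFor(n)
	a.mu[b].Lock()
	defer a.mu[b].Unlock()
	if k := len(a.buckets[b]); k > 0 {
		buf := a.buckets[b][k-1]
		a.buckets[b] = a.buckets[b][:k-1]
		return buf[:n]
	}
	return make([]float32, n, 1<<b)
}

// Put returns a buffer to its bucket for reuse; only exact power-of-2
// capacities are pooled.
func (a *arena) Put(buf []float32) {
	b := bucketFor(cap(buf))
	if cap(buf) != 1<<b {
		return
	}
	a.mu[b].Lock()
	a.buckets[b] = append(a.buckets[b], buf[:cap(buf)])
	a.mu[b].Unlock()
}

func main() {
	var a arena
	s := a.Get(5)
	fmt.Println(len(s), cap(s)) // 5 8
	a.Put(s)
	s2 := a.Get(6)
	fmt.Println(len(s2), cap(s2)) // 6 8 (same bucket, buffer reused)
}
```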

type TensorPool

type TensorPool[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

TensorPool provides reusable tensor buffers keyed by shape. Acquire returns a tensor from the pool or allocates a new one. Release returns a tensor to the pool for future reuse. The pool is safe for concurrent use.

func NewTensorPool

func NewTensorPool[T tensor.Numeric]() *TensorPool[T]

NewTensorPool creates a new empty tensor pool.

func (*TensorPool[T]) Acquire

func (p *TensorPool[T]) Acquire(shape []int) (*tensor.TensorNumeric[T], error)

Acquire returns a tensor with the given shape. If the pool has a matching buffer, it is returned (zeroed). Otherwise a new tensor is allocated.

func (*TensorPool[T]) Len

func (p *TensorPool[T]) Len() int

Len returns the total number of tensors currently in the pool.

func (*TensorPool[T]) Release

func (p *TensorPool[T]) Release(t *tensor.TensorNumeric[T])

Release returns a tensor to the pool for future reuse. For GPU-backed tensors, the device memory is freed immediately (returned to the GPU MemPool for reuse) rather than holding the tensor reference, since GPU memory is a scarce resource managed by a separate pool. The tensor must not be used after calling Release.
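The shape-keyed Acquire/Release cycle (with zeroing on reuse, as documented above) can be sketched like this; keying the pool by the formatted shape string, and the flat `[]float32` backing, are illustrative assumptions:

```go
package main

import (
	"fmt"
	"sync"
)

// pool is a minimal sketch of a shape-keyed buffer pool.
type pool struct {
	mu   sync.Mutex
	free map[string][][]float32
}

func key(shape []int) string { return fmt.Sprint(shape) }

func numel(shape []int) int {
	n := 1
	for _, d := range shape {
		n *= d
	}
	return n
}

// Acquire returns a buffer for the shape: a zeroed pooled buffer when
// one matches, otherwise a fresh allocation.
func (p *pool) Acquire(shape []int) []float32 {
	p.mu.Lock()
	defer p.mu.Unlock()
	k := key(shape)
	if bufs := p.free[k]; len(bufs) > 0 {
		buf := bufs[len(bufs)-1]
		p.free[k] = bufs[:len(bufs)-1]
		for i := range buf { // reused buffers are handed out zeroed
			buf[i] = 0
		}
		return buf
	}
	return make([]float32, numel(shape))
}

// Release returns a buffer to the pool; the caller must not use it
// afterwards.
func (p *pool) Release(shape []int, buf []float32) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.free[key(shape)] = append(p.free[key(shape)], buf)
}

func main() {
	p := &pool{free: map[string][][]float32{}}
	a := p.Acquire([]int{2, 3})
	a[0] = 7
	p.Release([]int{2, 3}, a)
	b := p.Acquire([]int{2, 3})
	fmt.Println(len(b), b[0]) // 6 0
}
```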

type TestableEngine

type TestableEngine[T tensor.Numeric] struct {
	*CPUEngine[T]
}

TestableEngine extends CPUEngine with methods that allow controlled error injection. This enables testing of previously unreachable error paths.

func NewTestableEngine

func NewTestableEngine[T tensor.Numeric](ops numeric.Arithmetic[T]) *TestableEngine[T]

NewTestableEngine creates a new TestableEngine.

func (*TestableEngine[T]) TestableMatMul

func (e *TestableEngine[T]) TestableMatMul(_ context.Context, a, b *tensor.TensorNumeric[T], result *FailableTensor[T]) error

TestableMatMul performs matrix multiplication with a FailableTensor result. This allows testing the error path in MatMul when result.Set() fails.

func (*TestableEngine[T]) TestableSum

func (e *TestableEngine[T]) TestableSum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, _ bool, zeroer *FailableZeroer[T], result *tensor.TensorNumeric[T]) error

TestableSum performs sum with a FailableZeroer. This allows testing the error path in Sum when Zero() fails.

func (*TestableEngine[T]) TestableTranspose

func (e *TestableEngine[T]) TestableTranspose(_ context.Context, a *tensor.TensorNumeric[T], result *FailableTensor[T]) error

TestableTranspose performs transpose with a FailableTensor result. This allows testing the error path in Transpose when result.Set() fails.

type TraceRecorder

type TraceRecorder[T tensor.Numeric] interface {
	Record(opName string, inputs []*tensor.TensorNumeric[T], output *tensor.TensorNumeric[T], extra map[string]any)
	RecordMultiOutput(opName string, inputs []*tensor.TensorNumeric[T], outputs []*tensor.TensorNumeric[T], extra map[string]any)
	RecordGather(params *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], output *tensor.TensorNumeric[T], extra map[string]any)
}

TraceRecorder is the interface used by EngineProxy to record traced operations.

type TracedOp

type TracedOp struct {
	OpName    string
	InputIDs  []int          // slot indices for inputs
	OutputID  int            // slot index for output
	OutputIDs []int          // for multi-output ops like Split
	ExtraArgs map[string]any // axis, scalar value, shape, etc.
}

TracedOp records a single engine operation with slot-based tensor identity.

type Tracer

type Tracer[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

Tracer records engine operations and tracks tensor identity by pointer. It assigns each unique tensor pointer a slot index, enabling later compilation into an ExecutionPlan.
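The pointer-identity scheme can be sketched in a few lines: each distinct tensor pointer gets a stable, monotonically assigned slot index. This is a minimal illustration of the idea, with a stand-in `tensor` type rather than the package's TensorNumeric:

```go
package main

import "fmt"

// tensor is a stand-in for tensor.TensorNumeric.
type tensor struct{ data []float32 }

// tracer assigns each unique tensor pointer a slot index, a minimal
// sketch of the identity tracking described above.
type tracer struct {
	slots map[*tensor]int
	next  int
}

// slotFor returns the existing slot for a tensor or assigns a new one.
func (t *tracer) slotFor(tn *tensor) int {
	if s, ok := t.slots[tn]; ok {
		return s
	}
	s := t.next
	t.slots[tn] = s
	t.next++
	return s
}

func main() {
	tr := &tracer{slots: map[*tensor]int{}}
	a, b := &tensor{}, &tensor{}
	fmt.Println(tr.slotFor(a), tr.slotFor(b), tr.slotFor(a)) // 0 1 0
}
```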

func NewTracer

func NewTracer[T tensor.Numeric](frozenTensors []*tensor.TensorNumeric[T]) *Tracer[T]

NewTracer creates a Tracer and pre-registers frozen tensors (model weights) with slot indices.

func (*Tracer[T]) FrozenSlots

func (t *Tracer[T]) FrozenSlots() []int

FrozenSlots returns the slot indices of frozen (weight) tensors.

func (*Tracer[T]) HasOpaqueOps

func (t *Tracer[T]) HasOpaqueOps() bool

HasOpaqueOps reports whether any opaque (non-traceable) operations were encountered during tracing.

func (*Tracer[T]) MarkOpaque

func (t *Tracer[T]) MarkOpaque()

MarkOpaque marks the trace as containing opaque operations.

func (*Tracer[T]) NextSlot

func (t *Tracer[T]) NextSlot() int

NextSlot returns the next slot index that would be assigned, i.e., the total number of slots allocated so far.

func (*Tracer[T]) Record

func (t *Tracer[T]) Record(opName string, inputs []*tensor.TensorNumeric[T], output *tensor.TensorNumeric[T], extra map[string]any)

Record appends a TracedOp for a single-output operation.

func (*Tracer[T]) RecordGather

func (t *Tracer[T]) RecordGather(params *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], output *tensor.TensorNumeric[T], extra map[string]any)

RecordGather appends a TracedOp for Gather which uses int indices.

func (*Tracer[T]) RecordMultiOutput

func (t *Tracer[T]) RecordMultiOutput(opName string, inputs []*tensor.TensorNumeric[T], outputs []*tensor.TensorNumeric[T], extra map[string]any)

RecordMultiOutput appends a TracedOp for a multi-output operation (e.g., Split).

func (*Tracer[T]) SlotFor

func (t *Tracer[T]) SlotFor(tn *tensor.TensorNumeric[T]) int

SlotFor returns the existing slot for a tensor or assigns a new one. This is the exported version of slotFor.

func (*Tracer[T]) SlotShapes

func (t *Tracer[T]) SlotShapes() map[int][]int

SlotShapes returns the shape for each slot index.

func (*Tracer[T]) TracedOps

func (t *Tracer[T]) TracedOps() []TracedOp

TracedOps returns the recorded operations in order.

type TransposeBMatMuler

type TransposeBMatMuler[T tensor.Numeric] interface {
	MatMulTransposeB(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
}

TransposeBMatMuler is an optional interface for engines that can compute C = A * B^T without explicitly transposing B. This avoids an extra GPU allocation and kernel launch for the transpose operation. A is [batch, m, k], B is [batch, n, k], result is [batch, m, n].
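The shapes above imply that C[i][j] is simply the dot product of row i of A with row j of B, which is why no explicit transpose is needed. A CPU reference sketch of that semantics for a single batch, with flat row-major slices standing in for the tensor type:

```go
package main

import "fmt"

// matMulTransposeB computes C = A * B^T for row-major A [m,k] and
// B [n,k] without materializing the transpose: each output element is
// a row-row dot product.
func matMulTransposeB(a, b []float32, m, n, k int) []float32 {
	c := make([]float32, m*n)
	for i := 0; i < m; i++ {
		for j := 0; j < n; j++ {
			var sum float32
			for l := 0; l < k; l++ {
				sum += a[i*k+l] * b[j*k+l] // row i of A · row j of B
			}
			c[i*n+j] = sum
		}
	}
	return c
}

func main() {
	a := []float32{1, 2, 3, 4} // [[1 2] [3 4]]
	b := []float32{5, 6, 7, 8} // [[5 6] [7 8]]
	fmt.Println(matMulTransposeB(a, b, 2, 2, 2)) // [17 23 39 53]
}
```

Reading rows of B sequentially (instead of columns of an explicit Bᵀ) is also the cache-friendly access pattern, which is part of why engines expose this as a fused operation.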

type W4A16MatMuler added in v0.3.0

type W4A16MatMuler[T tensor.Numeric] interface {
	// MatMulW4A16 performs C = dequant(W_4bit) * A_fp16 with the weight and
	// activation operands identified by the caller.
	MatMulW4A16(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
}

W4A16MatMuler is an optional interface for engines that support W4A16 mixed-precision matrix multiplication with fused dequantization.

type W4A16Precision added in v0.3.0

type W4A16Precision struct {
	// WeightFormat describes which 4-bit quantization is used.
	WeightFormat string // "Q4_0", "GPTQ_4", "AWQ"
}

W4A16Precision represents the mixed-precision configuration where weights are stored in 4-bit quantized format and activations are in FP16.

func W4A16Info added in v0.3.0

func W4A16Info[T tensor.Numeric](a, b *tensor.TensorNumeric[T]) W4A16Precision

W4A16Info returns metadata about a W4A16 mixed-precision pair. Returns zero value if the inputs are not a W4A16 combination.

type WeightUploader

type WeightUploader interface {
	UploadWeights(tensors []*tensor.TensorNumeric[float32]) error
}

WeightUploader is an optional interface for engines that can pre-upload model weights to device memory at load time. This eliminates per-operation host-to-device copies during inference. Each tensor's storage is replaced in-place from CPUStorage to device-resident storage.
