batched

package
v0.3.0
This package is not in the latest version of its module.
Published: Mar 21, 2026 License: Apache-2.0 Imports: 6 Imported by: 0

Documentation

Overview

Package batched provides batched multi-model inference, enabling 1000+ per-source models that share the same architecture to run in a single batched GEMM call rather than N sequential matrix multiplications.
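To illustrate the idea (not the package's internals), here is a minimal pure-Go sketch of what a batched GEMM replaces: N independent matrix-vector products, one per model, over a shared shape. The function name and layout are illustrative only.

```go
package main

import "fmt"

// batchedMatVec computes, for each of N models, that model's
// (out x in) row-major weight matrix times its own input vector.
// Conceptually this is one batched GEMM over a [N, out, in] weight
// tensor instead of N sequential matrix multiplications.
func batchedMatVec(weights, inputs [][]float32, in, out int) [][]float32 {
	outputs := make([][]float32, len(weights))
	for m := range weights { // a real engine fuses this loop into one GEMM kernel
		y := make([]float32, out)
		for r := 0; r < out; r++ {
			var acc float32
			for c := 0; c < in; c++ {
				acc += weights[m][r*in+c] * inputs[m][c]
			}
			y[r] = acc
		}
		outputs[m] = y
	}
	return outputs
}

func main() {
	// Two models, each mapping 2 inputs to 1 output.
	weights := [][]float32{{1, 2}, {3, 4}}
	inputs := [][]float32{{1, 1}, {1, 1}}
	fmt.Println(batchedMatVec(weights, inputs, 2, 1)) // [[3] [7]]
}
```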

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type ActivationType

type ActivationType int

ActivationType identifies a supported activation function.

const (
	// ActivationNone applies no activation (identity).
	ActivationNone ActivationType = iota
	// ActivationReLU applies max(0, x).
	ActivationReLU
	// ActivationTanh applies the hyperbolic tangent function.
	ActivationTanh
)
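The three activations can be sketched in plain Go; the case values below mirror the iota ordering of the const block above, but this helper is illustrative, not the package's implementation.

```go
package main

import (
	"fmt"
	"math"
)

// applyActivation mirrors the three documented activation kinds:
// 0 = identity (ActivationNone), 1 = max(0, x) (ActivationReLU),
// 2 = hyperbolic tangent (ActivationTanh).
func applyActivation(kind int, x float32) float32 {
	switch kind {
	case 1: // ReLU
		if x < 0 {
			return 0
		}
		return x
	case 2: // Tanh
		return float32(math.Tanh(float64(x)))
	default: // identity
		return x
	}
}

func main() {
	fmt.Println(applyActivation(1, -2.5)) // 0
	fmt.Println(applyActivation(0, -2.5)) // -2.5
}
```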

type Architecture

type Architecture struct {
	Layers []LayerSpec
}

Architecture describes the shared model architecture. Every model in a batch must conform to the same architecture (layer sizes and activations).

func (Architecture) Validate

func (a Architecture) Validate() error

Validate checks that the architecture is well-formed.
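The documentation does not spell out what "well-formed" checks. A plausible sketch, assuming Validate requires at least one layer, positive sizes, and consecutive layers whose sizes chain, might look like this (standalone types and logic, not the package's own code):

```go
package main

import "fmt"

// LayerSpec mirrors the documented layer shape for this sketch.
type LayerSpec struct {
	InputSize  int
	OutputSize int
}

// validate sketches plausible well-formedness checks: a non-empty layer
// list, positive sizes, and each layer's InputSize matching the previous
// layer's OutputSize so the forward pass chains correctly.
func validate(layers []LayerSpec) error {
	if len(layers) == 0 {
		return fmt.Errorf("architecture has no layers")
	}
	for i, l := range layers {
		if l.InputSize <= 0 || l.OutputSize <= 0 {
			return fmt.Errorf("layer %d: sizes must be positive", i)
		}
		if i > 0 && layers[i-1].OutputSize != l.InputSize {
			return fmt.Errorf("layer %d: InputSize %d does not match previous OutputSize %d",
				i, l.InputSize, layers[i-1].OutputSize)
		}
	}
	return nil
}

func main() {
	ok := []LayerSpec{{4, 8}, {8, 2}}
	bad := []LayerSpec{{4, 8}, {7, 2}}
	fmt.Println(validate(ok))         // <nil>
	fmt.Println(validate(bad) != nil) // true
}
```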

type BatchedInference

type BatchedInference struct {
	// contains filtered or unexported fields
}

BatchedInference runs forward passes for all models in parallel using batched GEMM operations through the compute.Engine[float32] interface.

func NewBatchedInference

func NewBatchedInference(numModels int, arch Architecture, engine compute.Engine[float32]) (*BatchedInference, error)

NewBatchedInference creates a new BatchedInference for numModels models sharing the given architecture. All tensor arithmetic flows through engine.

func (*BatchedInference) Architecture

func (bi *BatchedInference) Architecture() Architecture

Architecture returns the shared architecture.

func (*BatchedInference) Forward

func (bi *BatchedInference) Forward(inputs [][]float32) ([][]float32, error)

Forward runs all models in parallel via batched GEMM. inputs[i] is the input vector for model i. All inputs must have length equal to the first layer's InputSize. Returns one output vector per model.
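As a reference for the semantics, here is the per-model computation Forward performs for each model in the batch, sketched for a single model: each layer computes y = act(W·x) and feeds y to the next layer. The row-major weight layout (OutputSize rows of InputSize values) is an assumption consistent with the LoadWeights documentation below, and ReLU is hard-coded here for brevity.

```go
package main

import "fmt"

// forwardOne sketches one model's forward pass: for each layer,
// multiply the row-major weight matrix by the current vector, apply
// the activation (ReLU here), and chain the result into the next layer.
// sizes[li] holds {InputSize, OutputSize} for layer li.
func forwardOne(layers [][]float32, sizes [][2]int, x []float32) []float32 {
	for li, w := range layers {
		in, out := sizes[li][0], sizes[li][1]
		y := make([]float32, out)
		for r := 0; r < out; r++ {
			var acc float32
			for c := 0; c < in; c++ {
				acc += w[r*in+c] * x[c]
			}
			if acc < 0 { // ReLU, one of the documented activations
				acc = 0
			}
			y[r] = acc
		}
		x = y
	}
	return x
}

func main() {
	layers := [][]float32{
		{1, -1, -1, 1}, // layer 0: 2 -> 2
		{1, 1},         // layer 1: 2 -> 1
	}
	sizes := [][2]int{{2, 2}, {2, 1}}
	fmt.Println(forwardOne(layers, sizes, []float32{3, 1})) // [2]
}
```

Forward effectively runs this computation for every model at once, with the per-layer loop replaced by a batched GEMM over all models' weights.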

func (*BatchedInference) LoadWeights

func (bi *BatchedInference) LoadWeights(modelIdx int, weights map[string][]float32) error

LoadWeights loads weight parameters for a single model identified by modelIdx. The weights map is keyed by "layer_<N>" and each value must be a row-major float32 slice of size InputSize*OutputSize.
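The expected map shape can be sketched as follows; the helper below only builds correctly-sized zero slices under the documented "layer_&lt;N&gt;" keying and is not part of the package.

```go
package main

import "fmt"

// buildWeights sketches the map LoadWeights expects: one "layer_<N>"
// key per layer, each holding InputSize*OutputSize row-major float32
// values for that layer (zeroed here as placeholders).
func buildWeights(sizes [][2]int) map[string][]float32 {
	w := make(map[string][]float32, len(sizes))
	for i, s := range sizes {
		w[fmt.Sprintf("layer_%d", i)] = make([]float32, s[0]*s[1])
	}
	return w
}

func main() {
	// Two layers: 4 -> 8 and 8 -> 2.
	w := buildWeights([][2]int{{4, 8}, {8, 2}})
	fmt.Println(len(w["layer_0"]), len(w["layer_1"])) // 32 16
}
```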

func (*BatchedInference) NumModels

func (bi *BatchedInference) NumModels() int

NumModels returns the number of models in the batch.

type BatchedWeights

type BatchedWeights struct {
	// contains filtered or unexported fields
}

BatchedWeights holds weight tensors for N models that share the same architecture but have different learned parameters. Weights are stored contiguously with a batch dimension so that a single batched GEMM can process all models at once.
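One plausible contiguous layout for such storage, assuming a leading batch dimension [numModels, OutputSize, InputSize] per layer (the actual internal layout is unexported), is sketched by this offset arithmetic:

```go
package main

import "fmt"

// offset sketches an assumed contiguous layout for one layer's weights
// with a leading batch dimension [numModels, outSize, inSize]:
// model m's (r, c) entry lives at m*outSize*inSize + r*inSize + c,
// so each model's matrix occupies one contiguous outSize*inSize block.
func offset(m, r, c, outSize, inSize int) int {
	return m*outSize*inSize + r*inSize + c
}

func main() {
	// 3 models, each with an 8x4 layer: model 1's block starts at 32.
	fmt.Println(offset(1, 0, 0, 8, 4)) // 32
	fmt.Println(offset(2, 7, 3, 8, 4)) // 95
}
```

Keeping all models' matrices adjacent like this is what lets a single batched GEMM kernel walk the whole batch with simple stride arithmetic.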

func NewBatchedWeights

func NewBatchedWeights(numModels int, arch Architecture) (*BatchedWeights, error)

NewBatchedWeights allocates contiguous weight storage for numModels models.

type LayerSpec

type LayerSpec struct {
	InputSize  int
	OutputSize int
	Activation ActivationType
}

LayerSpec describes one fully-connected layer in the shared architecture.
