autoopt

package
v1.38.1
Published: Mar 30, 2026 License: Apache-2.0 Imports: 7 Imported by: 0

Documentation

Overview

Package autoopt provides automatic optimization recommendations based on hardware profiling, including quantization format recommendation, kernel implementation selection, hardware-tuned kernel configuration, and multi-device workload scheduling.

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func EstimateWGMMAIterations

func EstimateWGMMAIterations(m, n, k int, cfg *WGMMAConfig) int

EstimateWGMMAIterations returns the number of warp-group MMA iterations needed to cover a GEMM of the given dimensions with the selected tile.
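The counting model can be sketched as a simple tiling computation. The real function takes a *WGMMAConfig for the tile shape; the tile sizes and the assumption that every dimension is tiled independently are illustrative here:

```go
package main

import "fmt"

// ceilDiv rounds the integer division a/b up.
func ceilDiv(a, b int) int { return (a + b - 1) / b }

// estimateIterations is a minimal sketch: the GEMM output is covered by
// ceil(M/tileM) x ceil(N/tileN) tiles, and each tile needs
// ceil(K/tileK) MMA steps along the reduction dimension.
func estimateIterations(m, n, k, tileM, tileN, tileK int) int {
	return ceilDiv(m, tileM) * ceilDiv(n, tileN) * ceilDiv(k, tileK)
}

func main() {
	// A 4096x4096x4096 GEMM with an illustrative 64x128x16 warp-group tile.
	fmt.Println(estimateIterations(4096, 4096, 4096, 64, 128, 16)) // 524288
}
```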

func IsTMACompatible

func IsTMACompatible(rows, cols, elementSizeBytes int) bool

IsTMACompatible checks whether a tensor with the given dimensions and element size can be efficiently loaded via TMA. Returns true if the layout satisfies alignment and size constraints.
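The kind of constraint involved can be sketched as below. The specific rules assumed here (16-byte row alignment, power-of-two element sizes up to 8 bytes) are illustrative assumptions, not the package's exact checks:

```go
package main

import "fmt"

// tmaCompatible is a rough sketch of a TMA layout check: element sizes
// must be one of the supported power-of-two widths, dimensions must be
// positive, and each row must be a multiple of 16 bytes. These exact
// thresholds are assumptions for illustration.
func tmaCompatible(rows, cols, elementSizeBytes int) bool {
	switch elementSizeBytes {
	case 1, 2, 4, 8:
	default:
		return false
	}
	if rows <= 0 || cols <= 0 {
		return false
	}
	rowBytes := cols * elementSizeBytes
	return rowBytes%16 == 0 // assumed: 16-byte row alignment
}

func main() {
	fmt.Println(tmaCompatible(128, 64, 2)) // 128-byte rows: true
	fmt.Println(tmaCompatible(128, 3, 4))  // 12-byte rows: false
}
```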

Types

type AcceleratorInfo

type AcceleratorInfo struct {
	// DeviceID is a unique identifier for this device.
	DeviceID int

	// Type is the accelerator backend (cuda, rocm, metal, cpu).
	Type DeviceType

	// AvailableMemory is the amount of free memory in bytes.
	AvailableMemory int64

	// Utilization is the current device load from 0.0 (idle) to 1.0 (fully loaded).
	Utilization float64

	// QueueDepth is the number of workloads currently queued on this device.
	QueueDepth int
}

AcceleratorInfo describes a single accelerator device available for scheduling.

type BlackwellCapabilities

type BlackwellCapabilities struct {
	// TMA indicates Tensor Memory Accelerator support (inherited from Hopper).
	TMA bool

	// WGMMA indicates warp-group MMA support (inherited from Hopper).
	WGMMA bool

	// FP8Native indicates FP8 tensor core support (inherited from Hopper).
	FP8Native bool

	// FP4Aware indicates the runtime can exploit FP4 quantized weights
	// via native tensor core instructions.
	FP4Aware bool

	// ClusterPrimitives indicates support for cluster-level synchronization
	// and collective operations across thread block clusters.
	ClusterPrimitives bool

	// MaxClusterSize is the maximum cluster size (typically 16 or 32).
	MaxClusterSize int

	// SharedMemoryBytes is the maximum shared memory per SM in bytes.
	SharedMemoryBytes int64
}

BlackwellCapabilities describes SM100+-class GPU features.

func DetectBlackwellCapabilities

func DetectBlackwellCapabilities(gen GPUGeneration) *BlackwellCapabilities

DetectBlackwellCapabilities returns the Blackwell feature set for SM100+ GPUs. Returns nil if the generation is not Blackwell.

type CostModel

type CostModel struct {
	// TransferBandwidth is the device-to-device transfer rate in bytes/sec.
	// Defaults to 16 GB/s (PCIe Gen4 x16) if zero.
	TransferBandwidth float64
}

CostModel estimates the execution time of an Op on a given device.

func (*CostModel) EstimateTimeNs

func (cm *CostModel) EstimateTimeNs(op *Op, dev *DeviceCapability) float64

EstimateTimeNs returns the estimated execution time in nanoseconds for the given op on the given device.

func (*CostModel) TransferCostNs

func (cm *CostModel) TransferCostNs(memoryBytes int64, srcIsGPU, dstIsGPU bool) float64

TransferCostNs estimates the data transfer cost in nanoseconds for moving an op's data from one device to another.
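A bandwidth-based cost of this shape can be sketched as follows. Treating same-side moves (GPU-to-GPU or CPU-to-CPU) as free is a simplification for the sketch; the default bandwidth matches the CostModel documentation:

```go
package main

import "fmt"

// transferCostNs sketches a bandwidth-based transfer cost: zero when both
// endpoints are on the same side (an assumed simplification), otherwise
// bytes divided by bandwidth, converted to nanoseconds.
func transferCostNs(memoryBytes int64, srcIsGPU, dstIsGPU bool, bandwidth float64) float64 {
	if srcIsGPU == dstIsGPU {
		return 0 // assumed: no host<->device crossing, no cost
	}
	if bandwidth <= 0 {
		bandwidth = 16e9 // documented default: 16 GB/s (PCIe Gen4 x16)
	}
	return float64(memoryBytes) / bandwidth * 1e9
}

func main() {
	// Moving 64 MiB host->device at the default 16 GB/s is about 4.19 ms.
	fmt.Printf("%.0f ns\n", transferCostNs(64<<20, false, true, 0))
}
```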

type DeviceAssignment

type DeviceAssignment struct {
	// Op is the operation being assigned.
	Op Op

	// DeviceID is the ID of the device this op is assigned to.
	DeviceID string

	// EstimatedTimeNs is the estimated execution time in nanoseconds.
	EstimatedTimeNs float64

	// TransferCostNs is the estimated data transfer overhead in nanoseconds.
	TransferCostNs float64
}

DeviceAssignment maps an op to a device with its estimated execution cost.

type DeviceCapability

type DeviceCapability struct {
	// ID uniquely identifies this device (e.g. "cpu:0", "cuda:0", "cuda:1").
	ID string

	// Profile is the underlying hardware profile for this device.
	Profile *HardwareProfile

	// FLOPS is the estimated peak floating-point operations per second.
	FLOPS float64

	// MemoryBandwidth is the estimated memory bandwidth in bytes per second.
	MemoryBandwidth float64

	// AvailableMemory is the usable memory in bytes for tensor storage.
	AvailableMemory int64

	// IsGPU is true when this device is a GPU accelerator.
	IsGPU bool
}

DeviceCapability wraps a HardwareProfile with computed performance estimates.

func NewDeviceCapability

func NewDeviceCapability(id string, hw *HardwareProfile, isGPU bool) DeviceCapability

NewDeviceCapability computes performance estimates from a HardwareProfile. The isGPU flag indicates whether this represents a GPU device.

type DeviceType

type DeviceType string

DeviceType identifies the accelerator backend.

const (
	DeviceCUDA  DeviceType = "cuda"
	DeviceROCm  DeviceType = "rocm"
	DeviceMetal DeviceType = "metal"
	DeviceCPU   DeviceType = "cpu"
)

type ElementwiseTemplate

type ElementwiseTemplate struct {
	// NumElements is the total number of elements to process.
	NumElements int
}

ElementwiseTemplate generates element-wise kernel configurations. Maximizes occupancy with simple grid/block sizing.

func (*ElementwiseTemplate) Configure

func (t *ElementwiseTemplate) Configure(profile *HardwareProfile) *KernelConfig

Configure produces an element-wise kernel config tuned to the hardware profile.

type ExecutionPath

type ExecutionPath string

ExecutionPath identifies which kernel dispatch strategy to use for an operation.

const (
	// PathStandard uses the baseline CUDA kernel (Ampere or older).
	PathStandard ExecutionPath = "standard"

	// PathTMA uses Tensor Memory Accelerator for async bulk data movement.
	PathTMA ExecutionPath = "tma"

	// PathWGMMA uses warp-group MMA for matrix operations.
	PathWGMMA ExecutionPath = "wgmma"

	// PathTMAWGMMA combines TMA loads with warp-group MMA compute.
	PathTMAWGMMA ExecutionPath = "tma_wgmma"

	// PathFP4Cluster uses FP4-aware cluster-level execution on Blackwell.
	PathFP4Cluster ExecutionPath = "fp4_cluster"
)

type GEMMTemplate

type GEMMTemplate struct {
	// M, N, K are the matrix dimensions for the GEMM operation.
	M, N, K int
}

GEMMTemplate generates GEMM kernel configurations. Tile sizes and shared memory usage scale with GPU compute capability and available resources.

func (*GEMMTemplate) Configure

func (t *GEMMTemplate) Configure(profile *HardwareProfile) *KernelConfig

Configure produces a GEMM kernel config tuned to the hardware profile.

type GEMVTemplate

type GEMVTemplate struct {
	// Rows and Cols are the matrix dimensions.
	Rows, Cols int
}

GEMVTemplate generates matrix-vector multiply kernel configurations. Vectorization is chosen based on SIMD width.

func (*GEMVTemplate) Configure

func (t *GEMVTemplate) Configure(profile *HardwareProfile) *KernelConfig

Configure produces a GEMV kernel config tuned to the hardware profile.

type GPUGeneration

type GPUGeneration int

GPUGeneration identifies a class of NVIDIA GPU microarchitecture.

const (
	// GPUGenAmpere represents SM 8.x GPUs (A100, RTX 30xx/40xx Ada).
	GPUGenAmpere GPUGeneration = iota

	// GPUGenHopper represents SM 9.0 GPUs (H100, H200).
	GPUGenHopper

	// GPUGenBlackwell represents SM 10.0+ GPUs (B100, B200, GB200).
	GPUGenBlackwell

	// GPUGenUnknown is returned for GPUs older than Ampere or unrecognized SM versions.
	GPUGenUnknown GPUGeneration = -1
)

func DetectGPUGeneration

func DetectGPUGeneration(computeCap string) GPUGeneration

DetectGPUGeneration maps a CUDA compute capability string to a GPUGeneration.
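The mapping can be sketched by parsing the SM major version; the cutoffs below follow the documented constants (SM 8.x Ampere, 9.x Hopper, 10.x+ Blackwell), though the package's exact parsing is not shown:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// detectGeneration sketches the mapping from a compute capability string
// like "8.9" or "9.0" to a generation name, keyed on the SM major version.
func detectGeneration(computeCap string) string {
	major, _, _ := strings.Cut(computeCap, ".")
	m, err := strconv.Atoi(major)
	if err != nil {
		return "unknown"
	}
	switch {
	case m == 8:
		return "ampere"
	case m == 9:
		return "hopper"
	case m >= 10:
		return "blackwell"
	default:
		return "unknown" // pre-Ampere SM versions
	}
}

func main() {
	fmt.Println(detectGeneration("8.9"), detectGeneration("9.0"), detectGeneration("10.0"))
}
```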

func (GPUGeneration) String

func (g GPUGeneration) String() string

String returns the generation name.

type HardwareProfile

type HardwareProfile struct {
	// CPU
	CPUCores  int    // logical CPU count (GOMAXPROCS-visible)
	CPUModel  string // human-readable CPU model string
	HasNEON   bool   // ARM SIMD (Neon)
	HasAVX2   bool   // x86 SIMD (AVX2)
	HasAVX512 bool   // x86 advanced SIMD (AVX-512)
	CacheL1   int64  // L1 data cache size in bytes (0 if unknown)
	CacheL2   int64  // L2 cache size in bytes (0 if unknown)
	CacheL3   int64  // L3 cache size in bytes (0 if unknown)
	TotalRAM  int64  // total physical memory in bytes

	// GPU
	GPUAvailable  bool   // true if a usable GPU was detected
	GPUBackend    string // "cuda", "rocm", "metal", "opencl", or ""
	GPUName       string // human-readable GPU name
	GPUMemory     int64  // GPU memory in bytes (0 if unknown)
	GPUComputeCap string // e.g. "8.9" for CUDA compute capability
	MultiGPU      bool   // true if more than one GPU is available
	GPUCount      int    // number of GPUs (0 if none)
}

HardwareProfile describes the CPU and GPU capabilities of the current system. This mirrors github.com/zerfoo/ztensor/compute.HardwareProfile and will be replaced by a direct import once the ztensor dependency is updated.

func DefaultProfile

func DefaultProfile() *HardwareProfile

DefaultProfile returns a minimal hardware profile based on runtime detection.

type HopperCapabilities

type HopperCapabilities struct {
	// TMA indicates Tensor Memory Accelerator support for async bulk copies.
	TMA bool

	// WGMMA indicates warp-group matrix multiply-accumulate support.
	WGMMA bool

	// FP8Native indicates hardware-native FP8 (E4M3/E5M2) tensor core support.
	FP8Native bool

	// ClusterSize is the maximum thread block cluster size (typically 8 or 16).
	ClusterSize int

	// SharedMemoryBytes is the maximum shared memory per SM in bytes.
	SharedMemoryBytes int64
}

HopperCapabilities describes SM90-class GPU features.

func DetectHopperCapabilities

func DetectHopperCapabilities(gen GPUGeneration) *HopperCapabilities

DetectHopperCapabilities returns the Hopper feature set for SM90 GPUs. Returns nil if the generation is not Hopper or later.

type KernelClass

type KernelClass string

KernelClass identifies a category of computation for which multiple implementation strategies exist (e.g. GEMM, attention, normalization).

const (
	KernelGEMM        KernelClass = "gemm"        // general matrix multiply
	KernelGEMV        KernelClass = "gemv"        // matrix-vector multiply
	KernelAttention   KernelClass = "attention"   // scaled dot-product attention
	KernelRMSNorm     KernelClass = "rmsnorm"     // RMS normalization
	KernelSoftmax     KernelClass = "softmax"     // softmax
	KernelRoPE        KernelClass = "rope"        // rotary positional embedding
	KernelSiLU        KernelClass = "silu"        // SiLU/SwiGLU activation
	KernelElementwise KernelClass = "elementwise" // element-wise add/mul/etc.
	KernelQuantGEMM   KernelClass = "quant_gemm"  // quantized GEMM (Q4/Q8)
	KernelQuantDot    KernelClass = "quant_dot"   // quantized dot product
)

type KernelCodegen

type KernelCodegen struct {
	// contains filtered or unexported fields
}

KernelCodegen generates hardware-optimized kernel configurations using templates selected by kernel class and tuned to the hardware profile.

func NewKernelCodegen

func NewKernelCodegen(profile *HardwareProfile) *KernelCodegen

NewKernelCodegen creates a KernelCodegen bound to the given hardware profile.

func (*KernelCodegen) GenerateConfig

func (c *KernelCodegen) GenerateConfig(class KernelClass, dims ...int) *KernelConfig

GenerateConfig selects a template for the given kernel class, configures it for the bound hardware profile, and returns the resulting KernelConfig.

The dims parameter depends on the kernel class:

  • KernelGEMM: [M, N, K]
  • KernelGEMV: [rows, cols]
  • KernelElementwise: [numElements]
  • Other classes: [numElements] (uses elementwise template)

func (*KernelCodegen) GenerateLaunchParams

func (c *KernelCodegen) GenerateLaunchParams(class KernelClass, totalElements int) (gridDim, blockDim [3]int)

GenerateLaunchParams computes grid and block dimensions for launching a kernel of the given class over totalElements work items.
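For a 1D workload the computation reduces to covering all elements with fixed-size blocks. The block size of 256 is an illustrative choice; the real method also varies parameters by kernel class:

```go
package main

import "fmt"

// launchParams sketches a 1D grid/block computation: a fixed block size
// and enough blocks to cover every element (rounding up).
func launchParams(totalElements int) (gridDim, blockDim [3]int) {
	const threadsPerBlock = 256 // illustrative; the real choice is class-dependent
	blocks := (totalElements + threadsPerBlock - 1) / threadsPerBlock
	return [3]int{blocks, 1, 1}, [3]int{threadsPerBlock, 1, 1}
}

func main() {
	grid, block := launchParams(1_000_000)
	fmt.Println(grid[0], block[0]) // 3907 256
}
```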

type KernelConfig

type KernelConfig struct {
	// Tile sizes for tiled algorithms (GEMM, GEMV).
	TileM int
	TileN int
	TileK int

	// UnrollFactor controls loop unrolling depth.
	UnrollFactor int

	// SharedMemBytes is the shared memory allocation per block (GPU).
	SharedMemBytes int

	// RegistersPerThread is the target register usage per thread (GPU).
	RegistersPerThread int

	// GridDim is the compute grid dimensions [x, y, z].
	GridDim [3]int

	// BlockDim is the block/threadgroup dimensions [x, y, z].
	BlockDim [3]int

	// VectorizationWidth is the SIMD vector width in number of float32 elements.
	VectorizationWidth int
}

KernelConfig holds hardware-optimized kernel launch parameters and tile sizes for a specific kernel class running on specific hardware.

func (*KernelConfig) String

func (cfg *KernelConfig) String() string

String returns a human-readable summary of the kernel config.

type KernelImpl

type KernelImpl string

KernelImpl identifies a specific implementation strategy for a kernel class.

const (
	// Backend implementations
	ImplGenericCPU KernelImpl = "generic_cpu" // scalar Go fallback
	ImplNEON       KernelImpl = "neon"        // ARM NEON SIMD assembly
	ImplAVX2       KernelImpl = "avx2"        // x86 AVX2 SIMD assembly
	ImplAVX512     KernelImpl = "avx512"      // x86 AVX-512 SIMD assembly
	ImplCUDA       KernelImpl = "cuda"        // CUDA GPU kernel
	ImplCUDAFused  KernelImpl = "cuda_fused"  // CUDA fused kernel (e.g. FlashAttention)
	ImplROCm       KernelImpl = "rocm"        // ROCm/HIP GPU kernel
	ImplROCmFused  KernelImpl = "rocm_fused"  // ROCm fused kernel
	ImplMetal      KernelImpl = "metal"       // Apple Metal GPU kernel
	ImplOpenCL     KernelImpl = "opencl"      // OpenCL GPU kernel
)

type KernelSelection

type KernelSelection struct {
	// Selections maps each kernel class to its chosen implementation.
	Selections map[KernelClass]KernelImpl

	// Backend is the recommended compute backend ("cuda", "rocm", "metal", "opencl", "cpu").
	Backend string

	// UseFusedOps is true when fused kernel variants should be preferred.
	UseFusedOps bool

	// UseFlashAttention is true when flash attention is available and recommended.
	UseFlashAttention bool

	// MatMulThreads is the recommended number of threads for CPU GEMM.
	// Zero means use all available cores.
	MatMulThreads int

	// Reason is a human-readable summary of why this selection was made.
	Reason string
}

KernelSelection maps each kernel class to the optimal implementation for the detected hardware.

func SelectKernels

func SelectKernels(hw *HardwareProfile) *KernelSelection

SelectKernels chooses the optimal kernel implementation for each kernel class based on the hardware profile. It returns a KernelSelection that maps every kernel class to a concrete implementation.

type KernelTemplate

type KernelTemplate interface {
	// Configure returns hardware-optimized kernel parameters.
	Configure(profile *HardwareProfile) *KernelConfig
}

KernelTemplate produces an optimal KernelConfig for a given hardware profile.

type LoadBalanced

type LoadBalanced struct{}

LoadBalanced assigns workloads to the device with the lowest combined utilization and queue depth, filtering out devices without enough memory.

func (*LoadBalanced) Select

func (lb *LoadBalanced) Select(accelerators []AcceleratorInfo, workload Workload) int

Select picks the device with the lowest load score that has enough memory.
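The documented policy (memory filter, then lowest combined load) can be sketched as below; the scoring weight that counts each queued workload as 0.1 load is an assumption of the sketch:

```go
package main

import "fmt"

type accel struct {
	availableMemory int64
	utilization     float64
	queueDepth      int
}

// selectLoadBalanced sketches load-balanced selection: skip devices
// without enough free memory, then pick the lowest combined score of
// utilization plus queue depth (weighted 0.1 per queued workload, an
// assumed weight). Returns -1 when no device fits.
func selectLoadBalanced(devs []accel, memRequired int64) int {
	best, bestScore := -1, 0.0
	for i, d := range devs {
		if d.availableMemory < memRequired {
			continue
		}
		score := d.utilization + 0.1*float64(d.queueDepth)
		if best == -1 || score < bestScore {
			best, bestScore = i, score
		}
	}
	return best
}

func main() {
	devs := []accel{
		{availableMemory: 8 << 30, utilization: 0.9, queueDepth: 4},
		{availableMemory: 8 << 30, utilization: 0.2, queueDepth: 1},
		{availableMemory: 1 << 30, utilization: 0.0, queueDepth: 0}, // too little memory
	}
	fmt.Println(selectLoadBalanced(devs, 4<<30)) // picks index 1
}
```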

type Migration

type Migration struct {
	// WorkloadID identifies the workload to migrate.
	WorkloadID string

	// FromDevice is the source device ID.
	FromDevice int

	// ToDevice is the destination device ID.
	ToDevice int

	// Reason explains why this migration was suggested.
	Reason string
}

Migration describes a suggested workload migration from an overloaded device to an underloaded one.

type ModelSpec

type ModelSpec struct {
	// ParameterCount is the total number of parameters in the model.
	ParameterCount int64

	// OriginalFormat is the model's original weight format (e.g. "FP16", "BF16").
	// Used to avoid recommending a format larger than the original.
	OriginalFormat QuantFormat
}

ModelSpec describes a model's resource requirements for quantization recommendation.

type NextGenOptimizer

type NextGenOptimizer struct {
	// contains filtered or unexported fields
}

NextGenOptimizer selects optimal execution paths based on GPU generation.

func NewNextGenOptimizer

func NewNextGenOptimizer(hw *HardwareProfile) *NextGenOptimizer

NewNextGenOptimizer creates an optimizer for the given hardware profile. Returns nil if the profile does not describe a CUDA GPU.

func (*NextGenOptimizer) Describe

func (o *NextGenOptimizer) Describe() string

Describe returns a human-readable summary of the optimizer's configuration.

func (*NextGenOptimizer) Generation

func (o *NextGenOptimizer) Generation() GPUGeneration

Generation returns the detected GPU generation.

func (*NextGenOptimizer) SelectOptimalPath

func (o *NextGenOptimizer) SelectOptimalPath(op *Op) ExecutionPath

SelectOptimalPath picks the best execution path for an operation and GPU generation.

type Op

type Op struct {
	// Name identifies this operation (for debugging/display).
	Name string

	// Class is the kernel class (GEMM, attention, elementwise, etc.).
	Class KernelClass

	// M, N, K are the dimensions for matrix operations.
	// For elementwise ops, M*N gives the number of elements and K is unused.
	M, N, K int

	// MemoryBytes is the total memory footprint of inputs + outputs.
	MemoryBytes int64

	// OutputDeviceHint, if non-empty, suggests a preferred device for the output
	// (to reduce data transfer for downstream consumers).
	OutputDeviceHint string
}

Op represents a computation to be scheduled across devices.

func (*Op) FLOPs

func (op *Op) FLOPs() float64

FLOPs returns the estimated floating-point operations for this op.
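The usual operation counts behind such an estimate look like this. Folding every non-matrix class into the element-wise case is a simplification of this sketch, not necessarily the package's exact model:

```go
package main

import "fmt"

// opFLOPs sketches standard FLOP counts: a GEMM performs one multiply and
// one add per (m, n, k) triple, i.e. 2*M*N*K; a matrix-vector product is
// 2*M*N with K unused; element-wise ops touch each of M*N elements once.
func opFLOPs(class string, m, n, k int) float64 {
	switch class {
	case "gemm", "quant_gemm":
		return 2 * float64(m) * float64(n) * float64(k)
	case "gemv", "quant_dot":
		return 2 * float64(m) * float64(n)
	default:
		return float64(m) * float64(n) // assumed: one op per element
	}
}

func main() {
	fmt.Println(opFLOPs("gemm", 64, 64, 64)) // 2*64^3 = 524288
}
```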

type Preference

type Preference int

Preference expresses the user's priority between inference speed and output quality.

const (
	// PreferQuality favors higher-bit formats that preserve model accuracy.
	PreferQuality Preference = iota

	// PreferBalanced balances throughput and quality (default).
	PreferBalanced

	// PreferSpeed favors lower-bit formats that maximize tokens/second.
	PreferSpeed
)

type Priority

type Priority struct {
	// Order lists device types from most preferred to least preferred.
	Order []DeviceType
}

Priority prefers devices of specific types in a defined order. Within the same device type, it falls back to load-balanced selection.

func (*Priority) Select

func (p *Priority) Select(accelerators []AcceleratorInfo, workload Workload) int

Select picks the first device matching the highest-priority type that has enough memory and lowest load.

type QuantFormat

type QuantFormat string

QuantFormat identifies a quantization format for model weights.

const (
	QuantNVFP4 QuantFormat = "NVFP4"  // 4-bit NVIDIA FP4 (E2M1)
	QuantQ4K   QuantFormat = "Q4_K_M" // 4-bit K-quant (mixed precision)
	QuantQ5K   QuantFormat = "Q5_K_M" // 5-bit K-quant (mixed precision)
	QuantQ6K   QuantFormat = "Q6_K"   // 6-bit K-quant
	QuantQ8_0  QuantFormat = "Q8_0"   // 8-bit quantization
	QuantFP8   QuantFormat = "FP8"    // 8-bit floating point (E4M3FN)
	QuantBF16  QuantFormat = "BF16"   // Brain floating point 16
	QuantFP16  QuantFormat = "FP16"   // IEEE 754 half precision
)

Supported quantization formats ordered roughly from lowest to highest quality.

func (QuantFormat) BitsPerWeight

func (q QuantFormat) BitsPerWeight() float64

BitsPerWeight returns the approximate bits per weight for this format.
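Bits per weight translate directly into a weight-storage estimate. The per-format figures below are illustrative assumptions (K-quants carry block metadata, so their effective bits per weight sit slightly above the nominal width):

```go
package main

import "fmt"

// bitsPerWeight lists assumed approximate widths for the documented
// formats; the K-quant and Q8_0 values include rough metadata overhead.
var bitsPerWeight = map[string]float64{
	"NVFP4": 4.0, "Q4_K_M": 4.5, "Q5_K_M": 5.5, "Q6_K": 6.5,
	"Q8_0": 8.5, "FP8": 8.0, "BF16": 16.0, "FP16": 16.0,
}

// weightBytes estimates weight storage: parameters * bits / 8.
func weightBytes(paramCount int64, format string) int64 {
	return int64(float64(paramCount) * bitsPerWeight[format] / 8)
}

func main() {
	// A 7B-parameter model in FP16 versus a 4-bit K-quant.
	fmt.Printf("FP16:   %.1f GiB\n", float64(weightBytes(7e9, "FP16"))/(1<<30))
	fmt.Printf("Q4_K_M: %.1f GiB\n", float64(weightBytes(7e9, "Q4_K_M"))/(1<<30))
}
```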

type Recommendation

type Recommendation struct {
	// Format is the recommended quantization format.
	Format QuantFormat

	// EstimatedVRAM is the estimated GPU memory usage in bytes.
	// Zero when the model fits in system RAM only.
	EstimatedVRAM int64

	// FitsInVRAM is true when the model fits entirely in GPU memory.
	FitsInVRAM bool

	// Reason is a human-readable explanation for the recommendation.
	Reason string
}

Recommendation is the result of RecommendQuant.

func RecommendQuant

func RecommendQuant(hw *HardwareProfile, model ModelSpec, pref Preference) Recommendation

RecommendQuant recommends the optimal quantization format for a model given the hardware profile and user preference.
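One plausible policy of this shape: walk formats from highest to lowest quality and take the first whose weights fit in GPU memory. The 20% headroom for KV cache and activations, and the candidate list with its assumed bit widths, are illustrative, not the package's exact rules:

```go
package main

import "fmt"

// recommendFormat sketches a quality-first search over quant formats,
// returning the first whose estimated weight bytes fit in the VRAM
// budget. Margin and bit widths are assumptions of this sketch.
func recommendFormat(gpuMemory, paramCount int64) string {
	candidates := []struct {
		name string
		bits float64
	}{
		{"FP16", 16}, {"Q8_0", 8.5}, {"Q6_K", 6.5}, {"Q5_K_M", 5.5}, {"Q4_K_M", 4.5},
	}
	budget := float64(gpuMemory) * 0.8 // assumed: reserve 20% headroom
	for _, c := range candidates {
		if float64(paramCount)*c.bits/8 <= budget {
			return c.name
		}
	}
	return "Q4_K_M" // smallest option; may still spill out of VRAM
}

func main() {
	// A 7B-parameter model on a 24 GiB GPU versus an 8 GiB GPU.
	fmt.Println(recommendFormat(24<<30, 7e9))
	fmt.Println(recommendFormat(8<<30, 7e9))
}
```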

type RoundRobin

type RoundRobin struct {
	// contains filtered or unexported fields
}

RoundRobin cycles through devices in order, skipping devices that lack sufficient memory for the workload.

func (*RoundRobin) Select

func (rr *RoundRobin) Select(accelerators []AcceleratorInfo, workload Workload) int

Select picks the next device in round-robin order that has enough memory.

type Scheduler

type Scheduler struct {
	// contains filtered or unexported fields
}

Scheduler manages workload scheduling across multiple accelerators.

func NewScheduler

func NewScheduler(accelerators []AcceleratorInfo, strategy SchedulingStrategy) *Scheduler

NewScheduler creates a scheduler with the given accelerators and strategy.

func (*Scheduler) AutoMigrate

func (s *Scheduler) AutoMigrate(threshold float64) []Migration

AutoMigrate suggests migrations from devices above the threshold to the least loaded device. The threshold is a utilization value (0.0-1.0).
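The documented policy can be sketched over bare utilization values; moving exactly one workload per overloaded device is an assumption made for the sketch:

```go
package main

import "fmt"

// suggestMigrations sketches AutoMigrate's documented policy: every
// device above the utilization threshold sheds one workload to the
// least-loaded device. Returns {from, to} device index pairs.
func suggestMigrations(utilization []float64, threshold float64) [][2]int {
	least := 0
	for i, u := range utilization {
		if u < utilization[least] {
			least = i
		}
	}
	var moves [][2]int
	for i, u := range utilization {
		if i != least && u > threshold {
			moves = append(moves, [2]int{i, least})
		}
	}
	return moves
}

func main() {
	fmt.Println(suggestMigrations([]float64{0.95, 0.30, 0.85}, 0.8)) // [[0 1] [2 1]]
}
```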

func (*Scheduler) Schedule

func (s *Scheduler) Schedule(workload Workload) (int, error)

Schedule picks the best device for the given workload and records the assignment.

func (*Scheduler) UpdateUtilization

func (s *Scheduler) UpdateUtilization(deviceID int, utilization float64)

UpdateUtilization sets the utilization for the device with the given ID.

type SchedulingStrategy

type SchedulingStrategy interface {
	// Select picks the best device index (into the accelerators slice) for the workload.
	// Returns -1 if no suitable device is found.
	Select(accelerators []AcceleratorInfo, workload Workload) int
}

SchedulingStrategy selects a device for a given workload from a set of accelerators.

type SplitPlan

type SplitPlan struct {
	// Assignments maps each op (by index in the original slice) to a device.
	Assignments []DeviceAssignment

	// TotalEstimatedTimeNs is the estimated wall-clock time assuming
	// operations on different devices run in parallel.
	TotalEstimatedTimeNs float64

	// DeviceMemoryUsed tracks allocated memory per device.
	DeviceMemoryUsed map[string]int64
}

SplitPlan is the result of workload splitting: a set of device assignments.

func (*SplitPlan) String

func (sp *SplitPlan) String() string

String returns a human-readable summary of the split plan.

type SwizzlePattern

type SwizzlePattern int

SwizzlePattern defines the memory swizzle mode for TMA loads.

const (
	// SwizzleNone disables swizzling (linear access).
	SwizzleNone SwizzlePattern = iota

	// Swizzle32B applies 32-byte interleaving to reduce bank conflicts.
	Swizzle32B

	// Swizzle64B applies 64-byte interleaving.
	Swizzle64B

	// Swizzle128B applies 128-byte interleaving (best for large tiles).
	Swizzle128B
)

func RecommendSwizzle

func RecommendSwizzle(elementSizeBytes, tileWidth int) SwizzlePattern

RecommendSwizzle selects the optimal swizzle pattern based on element size and tile width. Larger swizzle patterns reduce bank conflicts for wider tiles.
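The stated heuristic can be sketched as picking the largest pattern that divides the tile row width in bytes; that exact divisibility rule is an assumption about the selection logic:

```go
package main

import "fmt"

// recommendSwizzle sketches the pattern choice: wider tile rows (in
// bytes) get larger swizzle strides, falling back to none when the row
// width is not a multiple of 32 bytes. The rule itself is assumed.
func recommendSwizzle(elementSizeBytes, tileWidth int) string {
	rowBytes := elementSizeBytes * tileWidth
	switch {
	case rowBytes%128 == 0:
		return "swizzle_128b"
	case rowBytes%64 == 0:
		return "swizzle_64b"
	case rowBytes%32 == 0:
		return "swizzle_32b"
	default:
		return "none"
	}
}

func main() {
	fmt.Println(recommendSwizzle(2, 64)) // FP16, 64-wide tile: 128-byte rows
	fmt.Println(recommendSwizzle(4, 8))  // FP32, 8-wide tile: 32-byte rows
}
```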

func (SwizzlePattern) String

func (s SwizzlePattern) String() string

String returns the swizzle pattern name.

type TMAConfig

type TMAConfig struct {
	// Dim is the dimensionality of the TMA operation (2D or 3D).
	Dim TMADim

	// BoxDimX is the width of the TMA box in elements (innermost dimension).
	BoxDimX int

	// BoxDimY is the height of the TMA box in elements.
	BoxDimY int

	// BoxDimZ is the depth of the TMA box in elements (only used for TMA3D).
	BoxDimZ int

	// ElementSizeBytes is the size of each element in bytes (e.g. 2 for FP16, 4 for FP32).
	ElementSizeBytes int

	// Swizzle is the shared memory swizzle pattern.
	Swizzle SwizzlePattern

	// GlobalStride is the stride in bytes between rows in global memory.
	// Zero means the rows are contiguous (stride = BoxDimX * ElementSizeBytes).
	GlobalStride int64
}

TMAConfig describes a Tensor Memory Accelerator descriptor configuration for asynchronous bulk data movement between global and shared memory.

func (*TMAConfig) BoxBytes

func (c *TMAConfig) BoxBytes() int64

BoxBytes returns the total size of one TMA box in bytes.
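Given the field documentation, the box size is the product of the box dimensions times the element size, with the Z dimension counting only for 3D descriptors:

```go
package main

import "fmt"

// boxBytes sketches the TMA box size: BoxDimX * BoxDimY * BoxDimZ *
// ElementSizeBytes, treating BoxDimZ as 1 for 2D descriptors since the
// field documentation says it is only used for TMA3D.
func boxBytes(dim, x, y, z, elementSizeBytes int) int64 {
	if dim == 2 {
		z = 1
	}
	return int64(x) * int64(y) * int64(z) * int64(elementSizeBytes)
}

func main() {
	// A 2D 128x64 box of FP16 elements: 128*64*2 bytes = 16 KiB per box.
	fmt.Println(boxBytes(2, 128, 64, 0, 2)) // 16384
}
```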

func (*TMAConfig) Validate

func (c *TMAConfig) Validate() error

Validate checks that the TMA configuration is valid for SM90 hardware.

type TMADim

type TMADim int

TMADim specifies the dimensionality of a TMA descriptor.

const (
	TMA2D TMADim = 2
	TMA3D TMADim = 3
)

type WGMMAConfig

type WGMMAConfig struct {
	// M is the tile height (output rows per warp group).
	M int

	// N is the tile width (output columns per warp group).
	N int

	// K is the tile depth (reduction dimension per MMA step).
	K int

	// InputType is the data type for the A and B input matrices.
	InputType WGMMADataType

	// AccumulatorFP32 uses FP32 accumulators when true, FP16 when false.
	AccumulatorFP32 bool

	// TransposeB transposes the B matrix (column-major layout).
	TransposeB bool
}

WGMMAConfig describes the configuration for a warp-group matrix multiply-accumulate operation on SM90+ hardware. WGMMA allows a warp group (4 warps = 128 threads) to collectively execute large MMA operations using tensor cores.

func SelectWGMMATile

func SelectWGMMATile(m, n, k int, dt WGMMADataType) *WGMMAConfig

SelectWGMMATile chooses optimal WGMMA tile dimensions for a given GEMM problem size. Returns a WGMMAConfig with the best tile shape for the given M, N, K dimensions and data type.
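A tile-selection sketch under commonly described SM90 wgmma shape constraints: M is fixed at 64 per warp group, K is one core-matrix step (16 elements for 16-bit inputs, 32 for 8-bit), and N is a multiple of 8 up to 256. Rounding N up to the next multiple of 8 is an assumption of this sketch, and the real function also takes the M and K problem sizes into account:

```go
package main

import "fmt"

// selectTile sketches WGMMA tile selection under assumed SM90 shape
// rules: tileM fixed at 64, tileK set by the input width, and tileN
// clamped to a multiple of 8 in [8, 256] near the problem's N.
func selectTile(n, inputBits int) (tileM, tileN, tileK int) {
	tileM = 64
	tileK = 16 // 16-bit inputs (FP16/BF16)
	if inputBits == 8 {
		tileK = 32 // 8-bit inputs (FP8/INT8)
	}
	tileN = (n + 7) / 8 * 8 // round up to a multiple of 8
	if tileN > 256 {
		tileN = 256
	}
	if tileN < 8 {
		tileN = 8
	}
	return tileM, tileN, tileK
}

func main() {
	m, n, k := selectTile(4096, 16) // large FP16 GEMM
	fmt.Println(m, n, k)            // 64 256 16
}
```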

func (*WGMMAConfig) OutputElements

func (c *WGMMAConfig) OutputElements() int

OutputElements returns the number of output elements per warp-group MMA step.

func (*WGMMAConfig) Validate

func (c *WGMMAConfig) Validate() error

Validate checks that the WGMMA configuration is valid for SM90 hardware.

type WGMMADataType

type WGMMADataType int

WGMMADataType identifies the input data type for warp-group MMA.

const (
	WGMMAFP16 WGMMADataType = iota
	WGMMABF16
	WGMMAFP8E4M3
	WGMMAFP8E5M2
	WGMMAINT8
)

func (WGMMADataType) String

func (dt WGMMADataType) String() string

String returns the data type name.

type Workload

type Workload struct {
	// ID uniquely identifies this workload.
	ID string

	// EstimatedFLOPS is the estimated floating-point operations required.
	EstimatedFLOPS float64

	// MemoryRequired is the memory needed in bytes.
	MemoryRequired int64

	// PreferredDevice is the preferred device type (optional, empty means no preference).
	PreferredDevice DeviceType
}

Workload describes a unit of work to be scheduled on an accelerator.

type WorkloadSplitter

type WorkloadSplitter struct {
	// contains filtered or unexported fields
}

WorkloadSplitter partitions operations across available devices.

func NewWorkloadSplitter

func NewWorkloadSplitter(devices []DeviceCapability) *WorkloadSplitter

NewWorkloadSplitter creates a splitter for the given set of devices.

func NewWorkloadSplitterWithCost

func NewWorkloadSplitterWithCost(devices []DeviceCapability, cost CostModel) *WorkloadSplitter

NewWorkloadSplitterWithCost creates a splitter with a custom cost model.

func (*WorkloadSplitter) Split

func (ws *WorkloadSplitter) Split(ops []Op) *SplitPlan

Split assigns each op to the device that minimizes estimated execution time, respecting device memory constraints and accounting for data transfer costs.
