Documentation ¶
Overview ¶
Package autoopt provides automatic optimization recommendations based on hardware profiling: quantization format recommendation, kernel selection and code generation, GPU-generation-aware execution paths, and multi-device workload scheduling and splitting.
Index ¶
- func EstimateWGMMAIterations(m, n, k int, cfg *WGMMAConfig) int
- func IsTMACompatible(rows, cols, elementSizeBytes int) bool
- type AcceleratorInfo
- type BlackwellCapabilities
- type CostModel
- type DeviceAssignment
- type DeviceCapability
- type DeviceType
- type ElementwiseTemplate
- type ExecutionPath
- type GEMMTemplate
- type GEMVTemplate
- type GPUGeneration
- type HardwareProfile
- type HopperCapabilities
- type KernelClass
- type KernelCodegen
- type KernelConfig
- type KernelImpl
- type KernelSelection
- type KernelTemplate
- type LoadBalanced
- type Migration
- type ModelSpec
- type NextGenOptimizer
- type Op
- type Preference
- type Priority
- type QuantFormat
- type Recommendation
- type RoundRobin
- type Scheduler
- type SchedulingStrategy
- type SplitPlan
- type SwizzlePattern
- type TMAConfig
- type TMADim
- type WGMMAConfig
- type WGMMADataType
- type Workload
- type WorkloadSplitter
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func EstimateWGMMAIterations ¶
func EstimateWGMMAIterations(m, n, k int, cfg *WGMMAConfig) int
EstimateWGMMAIterations returns the number of warp-group MMA iterations needed to cover a GEMM of the given dimensions with the selected tile.
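The iteration count described above can be sketched as straightforward ceiling-division tiling. This is an assumption about the implementation, not the package's actual code; `wgmmaConfig` stands in for the documented WGMMAConfig tile fields.

```go
package main

import "fmt"

// wgmmaConfig mirrors the tile fields of WGMMAConfig used here.
type wgmmaConfig struct {
	M, N, K int
}

// estimateWGMMAIterations is a sketch of what EstimateWGMMAIterations
// plausibly computes: the number of M x N x K tiles needed to cover the
// GEMM, rounding each dimension up to a whole tile.
func estimateWGMMAIterations(m, n, k int, cfg *wgmmaConfig) int {
	ceilDiv := func(a, b int) int { return (a + b - 1) / b }
	return ceilDiv(m, cfg.M) * ceilDiv(n, cfg.N) * ceilDiv(k, cfg.K)
}

func main() {
	cfg := &wgmmaConfig{M: 64, N: 128, K: 16}
	// 4096x4096x4096 GEMM: 64 * 32 * 256 = 524288 tile iterations.
	fmt.Println(estimateWGMMAIterations(4096, 4096, 4096, cfg))
}
```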
func IsTMACompatible ¶
func IsTMACompatible(rows, cols, elementSizeBytes int) bool
IsTMACompatible checks whether a tensor with the given dimensions and element size can be efficiently loaded via TMA. Returns true if the layout satisfies alignment and size constraints.
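The kind of checks involved can be sketched as follows. The specific limits (16-byte row pitch alignment, 256-element box dimensions) are assumptions drawn from CUDA's TMA descriptor rules, not values confirmed by this package.

```go
package main

import "fmt"

// isTMACompatible sketches the checks IsTMACompatible likely performs:
// each row must be a multiple of 16 bytes and box dimensions are capped
// at 256 elements per dimension (assumed CUDA TMA descriptor limits).
func isTMACompatible(rows, cols, elementSizeBytes int) bool {
	if rows <= 0 || cols <= 0 || elementSizeBytes <= 0 {
		return false
	}
	rowBytes := cols * elementSizeBytes
	if rowBytes%16 != 0 { // TMA requires 16-byte aligned row pitch
		return false
	}
	return rows <= 256 && cols <= 256 // per-dimension box limit
}

func main() {
	fmt.Println(isTMACompatible(128, 64, 2)) // FP16 tile, 128-byte rows: true
	fmt.Println(isTMACompatible(128, 65, 2)) // 130-byte rows, misaligned: false
}
```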
Types ¶
type AcceleratorInfo ¶
type AcceleratorInfo struct {
// DeviceID is a unique identifier for this device.
DeviceID int
// Type is the accelerator backend (cuda, rocm, metal, cpu).
Type DeviceType
// AvailableMemory is the amount of free memory in bytes.
AvailableMemory int64
// Utilization is the current device load from 0.0 (idle) to 1.0 (fully loaded).
Utilization float64
// QueueDepth is the number of workloads currently queued on this device.
QueueDepth int
}
AcceleratorInfo describes a single accelerator device available for scheduling.
type BlackwellCapabilities ¶
type BlackwellCapabilities struct {
// TMA indicates Tensor Memory Accelerator support (inherited from Hopper).
TMA bool
// WGMMA indicates warp-group MMA support (inherited from Hopper).
WGMMA bool
// FP8Native indicates FP8 tensor core support (inherited from Hopper).
FP8Native bool
// FP4Aware indicates the runtime can exploit FP4 quantized weights
// via native tensor core instructions.
FP4Aware bool
// ClusterPrimitives indicates support for cluster-level synchronization
// and collective operations across thread block clusters.
ClusterPrimitives bool
// MaxClusterSize is the maximum cluster size (typically 16 or 32).
MaxClusterSize int
SharedMemoryBytes int64
}
BlackwellCapabilities describes SM100+-class GPU features.
func DetectBlackwellCapabilities ¶
func DetectBlackwellCapabilities(gen GPUGeneration) *BlackwellCapabilities
DetectBlackwellCapabilities returns the Blackwell feature set for SM100+ GPUs. Returns nil if the generation is not Blackwell.
type CostModel ¶
type CostModel struct {
// TransferBandwidth is the device-to-device transfer rate in bytes/sec.
// Defaults to 16 GB/s (PCIe Gen4 x16) if zero.
TransferBandwidth float64
}
CostModel estimates the execution time of an Op on a given device.
func (*CostModel) EstimateTimeNs ¶
func (cm *CostModel) EstimateTimeNs(op *Op, dev *DeviceCapability) float64
EstimateTimeNs returns the estimated execution time in nanoseconds for the given op on the given device.
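A cost model like this is typically roofline-style: an op takes as long as its compute or its memory traffic, whichever dominates. The sketch below assumes that shape; `device` mirrors the documented DeviceCapability fields, and transfer overhead is ignored.

```go
package main

import "fmt"

// device holds the DeviceCapability fields consulted by this sketch.
type device struct {
	FLOPS           float64 // peak FLOP/s
	MemoryBandwidth float64 // bytes/s
}

// estimateTimeNs sketches a roofline-style estimate consistent with the
// EstimateTimeNs docs: the op is bound by either compute throughput or
// memory bandwidth, whichever takes longer.
func estimateTimeNs(flops float64, memoryBytes int64, dev device) float64 {
	computeNs := flops / dev.FLOPS * 1e9
	memoryNs := float64(memoryBytes) / dev.MemoryBandwidth * 1e9
	if computeNs > memoryNs {
		return computeNs
	}
	return memoryNs
}

func main() {
	dev := device{FLOPS: 1e12, MemoryBandwidth: 1e11} // 1 TFLOP/s, 100 GB/s
	// 1024^3 GEMM: 2*M*N*K FLOPs (standard convention), compute-bound here.
	flops := 2.0 * 1024 * 1024 * 1024
	fmt.Println(estimateTimeNs(flops, 12*1024*1024, dev))
}
```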
type DeviceAssignment ¶
type DeviceAssignment struct {
// Op is the operation being assigned.
Op Op
// DeviceID is the ID of the device this op is assigned to.
DeviceID string
// EstimatedTimeNs is the estimated execution time in nanoseconds.
EstimatedTimeNs float64
// TransferCostNs is the estimated data transfer overhead in nanoseconds.
TransferCostNs float64
}
DeviceAssignment maps an op to a device with its estimated execution cost.
type DeviceCapability ¶
type DeviceCapability struct {
// ID uniquely identifies this device (e.g. "cpu:0", "cuda:0", "cuda:1").
ID string
// Profile is the underlying hardware profile for this device.
Profile *HardwareProfile
// FLOPS is the estimated peak floating-point operations per second.
FLOPS float64
// MemoryBandwidth is the estimated memory bandwidth in bytes per second.
MemoryBandwidth float64
// AvailableMemory is the usable memory in bytes for tensor storage.
AvailableMemory int64
// IsGPU is true when this device is a GPU accelerator.
IsGPU bool
}
DeviceCapability wraps a HardwareProfile with computed performance estimates.
func NewDeviceCapability ¶
func NewDeviceCapability(id string, hw *HardwareProfile, isGPU bool) DeviceCapability
NewDeviceCapability computes performance estimates from a HardwareProfile. The isGPU flag indicates whether this represents a GPU device.
type DeviceType ¶
type DeviceType string
DeviceType identifies the accelerator backend.
const (
	DeviceCUDA  DeviceType = "cuda"
	DeviceROCm  DeviceType = "rocm"
	DeviceMetal DeviceType = "metal"
	DeviceCPU   DeviceType = "cpu"
)
type ElementwiseTemplate ¶
type ElementwiseTemplate struct {
// NumElements is the total number of elements to process.
NumElements int
}
ElementwiseTemplate generates element-wise kernel configurations. Maximizes occupancy with simple grid/block sizing.
func (*ElementwiseTemplate) Configure ¶
func (t *ElementwiseTemplate) Configure(profile *HardwareProfile) *KernelConfig
Configure produces an element-wise kernel config tuned to the hardware profile.
type ExecutionPath ¶
type ExecutionPath string
ExecutionPath identifies which kernel dispatch strategy to use for an operation.
const (
	// PathStandard uses the baseline CUDA kernel (Ampere or older).
	PathStandard ExecutionPath = "standard"
	// PathTMA uses Tensor Memory Accelerator for async bulk data movement.
	PathTMA ExecutionPath = "tma"
	// PathWGMMA uses warp-group MMA for matrix operations.
	PathWGMMA ExecutionPath = "wgmma"
	// PathTMAWGMMA combines TMA loads with warp-group MMA compute.
	PathTMAWGMMA ExecutionPath = "tma_wgmma"
	// PathFP4Cluster uses FP4-aware cluster-level execution on Blackwell.
	PathFP4Cluster ExecutionPath = "fp4_cluster"
)
type GEMMTemplate ¶
type GEMMTemplate struct {
// M, N, K are the matrix dimensions for the GEMM operation.
M, N, K int
}
GEMMTemplate generates GEMM kernel configurations. Tile sizes and shared memory usage scale with GPU compute capability and available resources.
func (*GEMMTemplate) Configure ¶
func (t *GEMMTemplate) Configure(profile *HardwareProfile) *KernelConfig
Configure produces a GEMM kernel config tuned to the hardware profile.
type GEMVTemplate ¶
type GEMVTemplate struct {
// Rows and Cols are the matrix dimensions.
Rows, Cols int
}
GEMVTemplate generates matrix-vector multiply kernel configurations. Vectorization is chosen based on SIMD width.
func (*GEMVTemplate) Configure ¶
func (t *GEMVTemplate) Configure(profile *HardwareProfile) *KernelConfig
Configure produces a GEMV kernel config tuned to the hardware profile.
type GPUGeneration ¶
type GPUGeneration int
GPUGeneration identifies a class of NVIDIA GPU microarchitecture.
const (
	// GPUGenAmpere represents SM 8.x GPUs (A100, RTX 30xx/40xx Ada).
	GPUGenAmpere GPUGeneration = iota
	// GPUGenHopper represents SM 9.0 GPUs (H100, H200).
	GPUGenHopper
	// GPUGenBlackwell represents SM 10.0+ GPUs (B100, B200, GB200).
	GPUGenBlackwell
	// GPUGenUnknown is returned for GPUs older than Ampere or unrecognized SM versions.
	GPUGenUnknown GPUGeneration = -1
)
func DetectGPUGeneration ¶
func DetectGPUGeneration(computeCap string) GPUGeneration
DetectGPUGeneration maps a CUDA compute capability string to a GPUGeneration.
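The mapping can be sketched from the SM ranges given in the GPUGeneration docs. This is a plausible reconstruction, not the package's actual parser; the generation constants are stand-ins for the documented values.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// Generation values mirroring the documented GPUGeneration constants.
const (
	genAmpere = iota
	genHopper
	genBlackwell
	genUnknown = -1
)

// detectGPUGeneration sketches the likely mapping from a CUDA compute
// capability string ("major.minor") to a generation: 8.x Ampere,
// 9.x Hopper, 10.x+ Blackwell, anything older or unparseable unknown.
func detectGPUGeneration(computeCap string) int {
	major, _, _ := strings.Cut(computeCap, ".")
	v, err := strconv.Atoi(major)
	if err != nil {
		return genUnknown
	}
	switch {
	case v >= 10:
		return genBlackwell
	case v == 9:
		return genHopper
	case v == 8:
		return genAmpere
	default:
		return genUnknown
	}
}

func main() {
	fmt.Println(detectGPUGeneration("8.9"))  // Ampere-class (Ada)
	fmt.Println(detectGPUGeneration("9.0"))  // Hopper
	fmt.Println(detectGPUGeneration("10.0")) // Blackwell
}
```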
func (GPUGeneration) String ¶
func (g GPUGeneration) String() string
String returns the generation name.
type HardwareProfile ¶
type HardwareProfile struct {
// CPU
CPUCores int // logical CPU count (GOMAXPROCS-visible)
CPUModel string // human-readable CPU model string
HasNEON bool // ARM SIMD (Neon)
HasAVX2 bool // x86 SIMD (AVX2)
HasAVX512 bool // x86 advanced SIMD (AVX-512)
CacheL1 int64 // L1 data cache size in bytes (0 if unknown)
CacheL2 int64 // L2 cache size in bytes (0 if unknown)
CacheL3 int64 // L3 cache size in bytes (0 if unknown)
TotalRAM int64 // total physical memory in bytes
// GPU
GPUAvailable bool // true if a usable GPU was detected
GPUBackend string // "cuda", "rocm", "metal", "opencl", or ""
GPUName string // human-readable GPU name
GPUMemory int64 // GPU memory in bytes (0 if unknown)
GPUComputeCap string // e.g. "8.9" for CUDA compute capability
MultiGPU bool // true if more than one GPU is available
GPUCount int // number of GPUs (0 if none)
}
HardwareProfile describes the CPU and GPU capabilities of the current system. This mirrors github.com/zerfoo/ztensor/compute.HardwareProfile and will be replaced by a direct import once the ztensor dependency is updated.
func DefaultProfile ¶
func DefaultProfile() *HardwareProfile
DefaultProfile returns a minimal hardware profile based on runtime detection.
type HopperCapabilities ¶
type HopperCapabilities struct {
// TMA indicates Tensor Memory Accelerator support for async bulk copies.
TMA bool
// WGMMA indicates warp-group matrix multiply-accumulate support.
WGMMA bool
// FP8Native indicates hardware-native FP8 (E4M3/E5M2) tensor core support.
FP8Native bool
// ClusterSize is the maximum thread block cluster size (typically 8 or 16).
ClusterSize int
SharedMemoryBytes int64
}
HopperCapabilities describes SM90-class GPU features.
func DetectHopperCapabilities ¶
func DetectHopperCapabilities(gen GPUGeneration) *HopperCapabilities
DetectHopperCapabilities returns the Hopper feature set for SM90 GPUs. Returns nil if the generation is not Hopper or later.
type KernelClass ¶
type KernelClass string
KernelClass identifies a category of computation for which multiple implementation strategies exist (e.g. GEMM, attention, normalization).
const (
	KernelGEMM        KernelClass = "gemm"        // general matrix multiply
	KernelGEMV        KernelClass = "gemv"        // matrix-vector multiply
	KernelAttention   KernelClass = "attention"   // scaled dot-product attention
	KernelRMSNorm     KernelClass = "rmsnorm"     // RMS normalization
	KernelSoftmax     KernelClass = "softmax"     // softmax
	KernelRoPE        KernelClass = "rope"        // rotary positional embedding
	KernelSiLU        KernelClass = "silu"        // SiLU/SwiGLU activation
	KernelElementwise KernelClass = "elementwise" // element-wise add/mul/etc.
	KernelQuantGEMM   KernelClass = "quant_gemm"  // quantized GEMM (Q4/Q8)
	KernelQuantDot    KernelClass = "quant_dot"   // quantized dot product
)
type KernelCodegen ¶
type KernelCodegen struct {
// contains filtered or unexported fields
}
KernelCodegen generates hardware-optimized kernel configurations using templates selected by kernel class and tuned to the hardware profile.
func NewKernelCodegen ¶
func NewKernelCodegen(profile *HardwareProfile) *KernelCodegen
NewKernelCodegen creates a KernelCodegen bound to the given hardware profile.
func (*KernelCodegen) GenerateConfig ¶
func (c *KernelCodegen) GenerateConfig(class KernelClass, dims ...int) *KernelConfig
GenerateConfig selects a template for the given kernel class, configures it for the bound hardware profile, and returns the resulting KernelConfig.
The dims parameter depends on the kernel class:
- KernelGEMM: [M, N, K]
- KernelGEMV: [rows, cols]
- KernelElementwise: [numElements]
- Other classes: [numElements] (uses elementwise template)
func (*KernelCodegen) GenerateLaunchParams ¶
func (c *KernelCodegen) GenerateLaunchParams(class KernelClass, totalElements int) (gridDim, blockDim [3]int)
GenerateLaunchParams computes grid and block dimensions for launching a kernel of the given class over totalElements work items.
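The grid/block computation for a one-dimensional launch is conventionally a fixed block size with enough blocks to cover every work item. The sketch below assumes a 256-thread block, which is a common default but not a value taken from the package.

```go
package main

import "fmt"

// launchParams sketches the standard 1-D launch computation the docs
// describe: a fixed block size, with the grid rounded up so every work
// item is covered. The 256-thread block size is an assumption.
func launchParams(totalElements int) (gridDim, blockDim [3]int) {
	const threadsPerBlock = 256
	blocks := (totalElements + threadsPerBlock - 1) / threadsPerBlock
	return [3]int{blocks, 1, 1}, [3]int{threadsPerBlock, 1, 1}
}

func main() {
	grid, block := launchParams(1_000_000)
	fmt.Println(grid, block) // [3907 1 1] [256 1 1]
}
```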
type KernelConfig ¶
type KernelConfig struct {
// Tile sizes for tiled algorithms (GEMM, GEMV).
TileM int
TileN int
TileK int
// UnrollFactor controls loop unrolling depth.
UnrollFactor int
SharedMemBytes int
// RegistersPerThread is the target register usage per thread (GPU).
RegistersPerThread int
// GridDim is the compute grid dimensions [x, y, z].
GridDim [3]int
// BlockDim is the block/threadgroup dimensions [x, y, z].
BlockDim [3]int
// VectorizationWidth is the SIMD vector width in number of float32 elements.
VectorizationWidth int
}
KernelConfig holds hardware-optimized kernel launch parameters and tile sizes for a specific kernel class running on specific hardware.
func (*KernelConfig) String ¶
func (cfg *KernelConfig) String() string
String returns a human-readable summary of the kernel config.
type KernelImpl ¶
type KernelImpl string
KernelImpl identifies a specific implementation strategy for a kernel class.
const (
	// Backend implementations
	ImplGenericCPU KernelImpl = "generic_cpu" // scalar Go fallback
	ImplNEON       KernelImpl = "neon"        // ARM NEON SIMD assembly
	ImplAVX2       KernelImpl = "avx2"        // x86 AVX2 SIMD assembly
	ImplAVX512     KernelImpl = "avx512"      // x86 AVX-512 SIMD assembly
	ImplCUDA       KernelImpl = "cuda"        // CUDA GPU kernel
	ImplCUDAFused  KernelImpl = "cuda_fused"  // CUDA fused kernel (e.g. FlashAttention)
	ImplROCm       KernelImpl = "rocm"        // ROCm/HIP GPU kernel
	ImplROCmFused  KernelImpl = "rocm_fused"  // ROCm fused kernel
	ImplMetal      KernelImpl = "metal"       // Apple Metal GPU kernel
	ImplOpenCL     KernelImpl = "opencl"      // OpenCL GPU kernel
)
type KernelSelection ¶
type KernelSelection struct {
// Selections maps each kernel class to its chosen implementation.
Selections map[KernelClass]KernelImpl
// Backend is the recommended compute backend ("cuda", "rocm", "metal", "opencl", "cpu").
Backend string
// UseFusedOps is true when fused kernel variants should be preferred.
UseFusedOps bool
// UseFlashAttention is true when flash attention is available and recommended.
UseFlashAttention bool
// MatMulThreads is the recommended number of threads for CPU GEMM.
// Zero means use all available cores.
MatMulThreads int
// Reason is a human-readable summary of why this selection was made.
Reason string
}
KernelSelection maps each kernel class to the optimal implementation for the detected hardware.
func SelectKernels ¶
func SelectKernels(hw *HardwareProfile) *KernelSelection
SelectKernels chooses the optimal kernel implementation for each kernel class based on the hardware profile. It returns a KernelSelection that maps every kernel class to a concrete implementation.
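The backend choice at the heart of this selection plausibly follows a priority order: a detected GPU backend wins, otherwise the widest available CPU SIMD implementation. The sketch below assumes that ordering; `profile` mirrors the relevant HardwareProfile fields.

```go
package main

import "fmt"

// profile holds the HardwareProfile fields consulted by this sketch.
type profile struct {
	GPUAvailable bool
	GPUBackend   string
	HasAVX512    bool
	HasAVX2      bool
	HasNEON      bool
}

// selectBackend sketches a priority order SelectKernels plausibly
// follows: GPU backend first, then the widest CPU SIMD, then the
// scalar fallback. The exact ordering is an assumption.
func selectBackend(hw profile) string {
	if hw.GPUAvailable && hw.GPUBackend != "" {
		return hw.GPUBackend
	}
	switch {
	case hw.HasAVX512:
		return "avx512"
	case hw.HasAVX2:
		return "avx2"
	case hw.HasNEON:
		return "neon"
	}
	return "generic_cpu"
}

func main() {
	fmt.Println(selectBackend(profile{GPUAvailable: true, GPUBackend: "cuda"})) // cuda
	fmt.Println(selectBackend(profile{HasAVX2: true}))                          // avx2
}
```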
type KernelTemplate ¶
type KernelTemplate interface {
// Configure returns hardware-optimized kernel parameters.
Configure(profile *HardwareProfile) *KernelConfig
}
KernelTemplate produces an optimal KernelConfig for a given hardware profile.
type LoadBalanced ¶
type LoadBalanced struct{}
LoadBalanced assigns workloads to the device with the lowest combined utilization and queue depth, filtering out devices without enough memory.
func (*LoadBalanced) Select ¶
func (lb *LoadBalanced) Select(accelerators []AcceleratorInfo, workload Workload) int
Select picks the device with the lowest load score that has enough memory.
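The selection described above can be sketched as a filter plus a score minimization. The queue-depth weight of 0.1 is an assumption for illustration; only the documented AcceleratorInfo fields are used.

```go
package main

import "fmt"

// acceleratorInfo mirrors the AcceleratorInfo fields used for scoring.
type acceleratorInfo struct {
	Utilization     float64
	QueueDepth      int
	AvailableMemory int64
}

// selectLoadBalanced sketches the documented behavior: skip devices
// that cannot hold the workload, then pick the lowest combined score
// of utilization and queue depth (weight assumed).
func selectLoadBalanced(devs []acceleratorInfo, memRequired int64) int {
	best, bestScore := -1, 0.0
	for i, d := range devs {
		if d.AvailableMemory < memRequired {
			continue
		}
		score := d.Utilization + 0.1*float64(d.QueueDepth)
		if best == -1 || score < bestScore {
			best, bestScore = i, score
		}
	}
	return best // -1 when no device has enough memory
}

func main() {
	devs := []acceleratorInfo{
		{Utilization: 0.9, AvailableMemory: 8 << 30},
		{Utilization: 0.2, QueueDepth: 1, AvailableMemory: 8 << 30},
		{AvailableMemory: 1 << 20}, // idle but nearly out of memory
	}
	fmt.Println(selectLoadBalanced(devs, 1<<30)) // 1
}
```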
type Migration ¶
type Migration struct {
// WorkloadID identifies the workload to migrate.
WorkloadID string
// FromDevice is the source device ID.
FromDevice int
// ToDevice is the destination device ID.
ToDevice int
// Reason explains why this migration was suggested.
Reason string
}
Migration describes a suggested workload migration from an overloaded device to an underloaded one.
type ModelSpec ¶
type ModelSpec struct {
// ParameterCount is the total number of parameters in the model.
ParameterCount int64
// OriginalFormat is the model's original weight format (e.g. "FP16", "BF16").
// Used to avoid recommending a format larger than the original.
OriginalFormat QuantFormat
}
ModelSpec describes a model's resource requirements for quantization recommendation.
type NextGenOptimizer ¶
type NextGenOptimizer struct {
// contains filtered or unexported fields
}
NextGenOptimizer selects optimal execution paths based on GPU generation.
func NewNextGenOptimizer ¶
func NewNextGenOptimizer(hw *HardwareProfile) *NextGenOptimizer
NewNextGenOptimizer creates an optimizer for the given hardware profile. Returns nil if the profile does not describe a CUDA GPU.
func (*NextGenOptimizer) Describe ¶
func (o *NextGenOptimizer) Describe() string
Describe returns a human-readable summary of the optimizer's configuration.
func (*NextGenOptimizer) Generation ¶
func (o *NextGenOptimizer) Generation() GPUGeneration
Generation returns the detected GPU generation.
func (*NextGenOptimizer) SelectOptimalPath ¶
func (o *NextGenOptimizer) SelectOptimalPath(op *Op) ExecutionPath
SelectOptimalPath picks the best execution path for an operation and GPU generation.
type Op ¶
type Op struct {
// Name identifies this operation (for debugging/display).
Name string
// Class is the kernel class (GEMM, attention, elementwise, etc.).
Class KernelClass
// M, N, K are the dimensions for matrix operations.
// For elementwise ops, M*N gives the number of elements and K is unused.
M, N, K int
// MemoryBytes is the total memory footprint of inputs + outputs.
MemoryBytes int64
// OutputDeviceHint, if non-empty, suggests a preferred device for the output
// (to reduce data transfer for downstream consumers).
OutputDeviceHint string
}
Op represents a computation to be scheduled across devices.
type Preference ¶
type Preference int
Preference expresses the user's priority between inference speed and output quality.
const (
	// PreferQuality favors higher-bit formats that preserve model accuracy.
	PreferQuality Preference = iota
	// PreferBalanced balances throughput and quality (default).
	PreferBalanced
	// PreferSpeed favors lower-bit formats that maximize tokens/second.
	PreferSpeed
)
type Priority ¶
type Priority struct {
// Order lists device types from most preferred to least preferred.
Order []DeviceType
}
Priority prefers devices of specific types in a defined order. Within the same device type, it falls back to load-balanced selection.
type QuantFormat ¶
type QuantFormat string
QuantFormat identifies a quantization format for model weights.
const (
	QuantNVFP4 QuantFormat = "NVFP4"  // 4-bit NVIDIA FP4 (E2M1)
	QuantQ4K   QuantFormat = "Q4_K_M" // 4-bit K-quant (mixed precision)
	QuantQ5K   QuantFormat = "Q5_K_M" // 5-bit K-quant (mixed precision)
	QuantQ6K   QuantFormat = "Q6_K"   // 6-bit K-quant
	QuantQ8_0  QuantFormat = "Q8_0"   // 8-bit quantization
	QuantFP8   QuantFormat = "FP8"    // 8-bit floating point (E4M3FN)
	QuantBF16  QuantFormat = "BF16"   // Brain floating point 16
	QuantFP16  QuantFormat = "FP16"   // IEEE 754 half precision
)
Supported quantization formats ordered roughly from lowest to highest quality.
func (QuantFormat) BitsPerWeight ¶
func (q QuantFormat) BitsPerWeight() float64
BitsPerWeight returns the approximate bits per weight for this format.
type Recommendation ¶
type Recommendation struct {
// Format is the recommended quantization format.
Format QuantFormat
// EstimatedVRAM is the estimated GPU memory usage in bytes.
// Zero when the model fits in system RAM only.
EstimatedVRAM int64
// FitsInVRAM is true when the model fits entirely in GPU memory.
FitsInVRAM bool
// Reason is a human-readable explanation for the recommendation.
Reason string
}
Recommendation is the result of RecommendQuant.
func RecommendQuant ¶
func RecommendQuant(hw *HardwareProfile, model ModelSpec, pref Preference) Recommendation
RecommendQuant recommends the optimal quantization format for a model given the hardware profile and user preference.
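The core of such a recommendation is estimating each format's weight footprint (parameter count times bits per weight, divided by 8) and picking the highest-quality format that fits. The sketch below assumes a 20% headroom factor for KV cache and activations, and approximate bits-per-weight values; neither is taken from the package.

```go
package main

import "fmt"

// recommendFormat sketches the decision RecommendQuant plausibly makes:
// walk formats from highest to lowest quality and return the first
// whose estimated footprint fits in GPU memory. The headroom factor
// and bits-per-weight table are illustrative assumptions.
func recommendFormat(paramCount, gpuMemory int64) string {
	formats := []struct {
		name string
		bits float64 // approximate effective bits per weight
	}{
		{"FP16", 16}, {"Q8_0", 8.5}, {"Q6_K", 6.56}, {"Q4_K_M", 4.85},
	}
	for _, f := range formats { // highest quality first
		weightBytes := float64(paramCount) * f.bits / 8
		if weightBytes*1.2 <= float64(gpuMemory) { // assumed 20% headroom
			return f.name
		}
	}
	return "Q4_K_M" // smallest format here; may still not fit in VRAM
}

func main() {
	const gib = int64(1) << 30
	fmt.Println(recommendFormat(7_000_000_000, 24*gib)) // FP16 fits in 24 GiB
	fmt.Println(recommendFormat(7_000_000_000, 8*gib))  // 8 GiB fits Q6_K, not Q8_0
}
```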
type RoundRobin ¶
type RoundRobin struct {
// contains filtered or unexported fields
}
RoundRobin cycles through devices in order, skipping devices that lack sufficient memory for the workload.
func (*RoundRobin) Select ¶
func (rr *RoundRobin) Select(accelerators []AcceleratorInfo, workload Workload) int
Select picks the next device in round-robin order that has enough memory.
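The stateful cursor-plus-skip behavior can be sketched as below. The real strategy operates on AcceleratorInfo; here only the available-memory values are kept, and the cursor field is an assumed implementation detail.

```go
package main

import "fmt"

// roundRobin sketches the documented behavior: advance a cursor through
// the device list, skipping devices that lack memory, and remember the
// position across calls.
type roundRobin struct {
	next int
}

func (rr *roundRobin) Select(available []int64, memRequired int64) int {
	n := len(available)
	for i := 0; i < n; i++ {
		idx := (rr.next + i) % n
		if available[idx] >= memRequired {
			rr.next = idx + 1 // resume after the chosen device next time
			return idx
		}
	}
	return -1 // no device has enough memory
}

func main() {
	rr := &roundRobin{}
	mem := []int64{4 << 30, 1 << 20, 4 << 30} // device 1 is nearly full
	fmt.Println(rr.Select(mem, 1<<30)) // 0
	fmt.Println(rr.Select(mem, 1<<30)) // 2 (skips device 1)
	fmt.Println(rr.Select(mem, 1<<30)) // 0 (wraps around)
}
```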
type Scheduler ¶
type Scheduler struct {
// contains filtered or unexported fields
}
Scheduler manages workload scheduling across multiple accelerators.
func NewScheduler ¶
func NewScheduler(accelerators []AcceleratorInfo, strategy SchedulingStrategy) *Scheduler
NewScheduler creates a scheduler with the given accelerators and strategy.
func (*Scheduler) AutoMigrate ¶
AutoMigrate suggests migrations from devices above the threshold to the least loaded device. The threshold is a utilization value (0.0-1.0).
func (*Scheduler) Schedule ¶
Schedule picks the best device for the given workload and records the assignment.
func (*Scheduler) UpdateUtilization ¶
UpdateUtilization sets the utilization for the device with the given ID.
type SchedulingStrategy ¶
type SchedulingStrategy interface {
// Select picks the best device index (into the accelerators slice) for the workload.
// Returns -1 if no suitable device is found.
Select(accelerators []AcceleratorInfo, workload Workload) int
}
SchedulingStrategy selects a device for a given workload from a set of accelerators.
type SplitPlan ¶
type SplitPlan struct {
// Assignments maps each op (by index in the original slice) to a device.
Assignments []DeviceAssignment
// TotalEstimatedTimeNs is the estimated wall-clock time assuming
// operations on different devices run in parallel.
TotalEstimatedTimeNs float64
// DeviceMemoryUsed tracks allocated memory per device.
DeviceMemoryUsed map[string]int64
}
SplitPlan is the result of workload splitting: a set of device assignments.
type SwizzlePattern ¶
type SwizzlePattern int
SwizzlePattern defines the memory swizzle mode for TMA loads.
const (
	// SwizzleNone disables swizzling (linear access).
	SwizzleNone SwizzlePattern = iota
	// Swizzle32B applies 32-byte interleaving to reduce bank conflicts.
	Swizzle32B
	// Swizzle64B applies 64-byte interleaving.
	Swizzle64B
	// Swizzle128B applies 128-byte interleaving (best for large tiles).
	Swizzle128B
)
func RecommendSwizzle ¶
func RecommendSwizzle(elementSizeBytes, tileWidth int) SwizzlePattern
RecommendSwizzle selects the optimal swizzle pattern based on element size and tile width. Larger swizzle patterns reduce bank conflicts for wider tiles.
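A plausible rule, consistent with the pattern docs, is to pick the largest swizzle span the tile's row width (in bytes) can fill. The thresholds below are assumptions; the constants mirror the documented SwizzlePattern values.

```go
package main

import "fmt"

// Swizzle values mirroring the documented SwizzlePattern constants.
const (
	swizzleNone = iota
	swizzle32B
	swizzle64B
	swizzle128B
)

// recommendSwizzle sketches the likely rule: the widest swizzle span
// that the tile's row width can fill. Thresholds are assumptions.
func recommendSwizzle(elementSizeBytes, tileWidth int) int {
	rowBytes := elementSizeBytes * tileWidth
	switch {
	case rowBytes >= 128:
		return swizzle128B
	case rowBytes >= 64:
		return swizzle64B
	case rowBytes >= 32:
		return swizzle32B
	}
	return swizzleNone
}

func main() {
	fmt.Println(recommendSwizzle(2, 64)) // FP16, 128-byte rows -> 128B swizzle
	fmt.Println(recommendSwizzle(4, 8))  // FP32, 32-byte rows  -> 32B swizzle
}
```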
func (SwizzlePattern) String ¶
func (s SwizzlePattern) String() string
String returns the swizzle pattern name.
type TMAConfig ¶
type TMAConfig struct {
// Dim is the dimensionality of the TMA operation (2D or 3D).
Dim TMADim
// BoxDimX is the width of the TMA box in elements (innermost dimension).
BoxDimX int
// BoxDimY is the height of the TMA box in elements.
BoxDimY int
// BoxDimZ is the depth of the TMA box in elements (only used for TMA3D).
BoxDimZ int
// ElementSizeBytes is the size of each element in bytes (e.g. 2 for FP16, 4 for FP32).
ElementSizeBytes int
// Swizzle is the shared memory swizzle pattern.
Swizzle SwizzlePattern
// GlobalStride is the stride in bytes between rows in global memory.
// Zero means the rows are contiguous (stride = BoxDimX * ElementSizeBytes).
GlobalStride int64
}
TMAConfig describes a Tensor Memory Accelerator descriptor configuration for asynchronous bulk data movement between global and shared memory.
type WGMMAConfig ¶
type WGMMAConfig struct {
// M is the tile height (output rows per warp group).
M int
// N is the tile width (output columns per warp group).
N int
// K is the tile depth (reduction dimension per MMA step).
K int
// InputType is the data type for the A and B input matrices.
InputType WGMMADataType
// AccumulatorFP32 uses FP32 accumulators when true, FP16 when false.
AccumulatorFP32 bool
// TransposeB transposes the B matrix (column-major layout).
TransposeB bool
}
WGMMAConfig describes the configuration for a warp-group matrix multiply-accumulate operation on SM90+ hardware. WGMMA allows a warp group (4 warps = 128 threads) to collectively execute large MMA operations using tensor cores.
func SelectWGMMATile ¶
func SelectWGMMATile(m, n, k int, dt WGMMADataType) *WGMMAConfig
SelectWGMMATile chooses optimal WGMMA tile dimensions for a given GEMM problem size. Returns a WGMMAConfig with the best tile shape for the given M, N, K dimensions and data type.
func (*WGMMAConfig) OutputElements ¶
func (c *WGMMAConfig) OutputElements() int
OutputElements returns the number of output elements per warp-group MMA step.
func (*WGMMAConfig) Validate ¶
func (c *WGMMAConfig) Validate() error
Validate checks that the WGMMA configuration is valid for SM90 hardware.
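The SM90 constraints such a check likely enforces: wgmma instructions use a fixed M of 64, N a multiple of 8 up to 256, and K fixed by input width (16 for 16-bit inputs, 32 for 8-bit). Treat the exact rules below as assumptions drawn from the PTX ISA, not from this package.

```go
package main

import (
	"errors"
	"fmt"
)

// validateWGMMA sketches the tile-shape rules Validate plausibly
// enforces for SM90 wgmma (assumed from the PTX ISA, not the package).
func validateWGMMA(m, n, k int, eightBitInput bool) error {
	if m != 64 {
		return errors.New("wgmma requires M == 64")
	}
	if n < 8 || n > 256 || n%8 != 0 {
		return errors.New("wgmma requires N in 8..256, multiple of 8")
	}
	wantK := 16 // 16-bit inputs (FP16/BF16)
	if eightBitInput {
		wantK = 32 // 8-bit inputs (FP8/INT8)
	}
	if k != wantK {
		return fmt.Errorf("wgmma requires K == %d for this input type", wantK)
	}
	return nil
}

func main() {
	fmt.Println(validateWGMMA(64, 128, 16, false)) // <nil>
	fmt.Println(validateWGMMA(64, 100, 16, false)) // multiple-of-8 error
}
```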
type WGMMADataType ¶
type WGMMADataType int
WGMMADataType identifies the input data type for warp-group MMA.
const (
	WGMMAFP16 WGMMADataType = iota
	WGMMABF16
	WGMMAFP8E4M3
	WGMMAFP8E5M2
	WGMMAINT8
)
func (WGMMADataType) String ¶
func (dt WGMMADataType) String() string
String returns the data type name.
type Workload ¶
type Workload struct {
// ID uniquely identifies this workload.
ID string
// EstimatedFLOPS is the estimated floating-point operations required.
EstimatedFLOPS float64
// MemoryRequired is the memory needed in bytes.
MemoryRequired int64
// PreferredDevice is the preferred device type (optional, empty means no preference).
PreferredDevice DeviceType
}
Workload describes a unit of work to be scheduled on an accelerator.
type WorkloadSplitter ¶
type WorkloadSplitter struct {
// contains filtered or unexported fields
}
WorkloadSplitter partitions operations across available devices.
func NewWorkloadSplitter ¶
func NewWorkloadSplitter(devices []DeviceCapability) *WorkloadSplitter
NewWorkloadSplitter creates a splitter for the given set of devices.
func NewWorkloadSplitterWithCost ¶
func NewWorkloadSplitterWithCost(devices []DeviceCapability, cost CostModel) *WorkloadSplitter
NewWorkloadSplitterWithCost creates a splitter with a custom cost model.
func (*WorkloadSplitter) Split ¶
func (ws *WorkloadSplitter) Split(ops []Op) *SplitPlan
Split assigns each op to the device that minimizes estimated execution time, respecting device memory constraints and accounting for data transfer costs.
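The greedy shape of this assignment can be sketched as follows. Per-device op times, memory limits, and transfer costs from the real splitter are omitted for brevity; with uniform op times, "finishes earliest" reduces to "least busy device".

```go
package main

import "fmt"

// splitGreedy sketches the strategy the Split docs describe: assign
// each op to whichever device finishes it earliest, tracking per-device
// busy time so devices work in parallel.
func splitGreedy(opTimesNs []float64, numDevices int) (assignment []int, makespanNs float64) {
	busy := make([]float64, numDevices)
	assignment = make([]int, len(opTimesNs))
	for i, t := range opTimesNs {
		best := 0 // uniform per-device times: earliest finish = least busy
		for d := 1; d < numDevices; d++ {
			if busy[d] < busy[best] {
				best = d
			}
		}
		busy[best] += t
		assignment[i] = best
	}
	for _, b := range busy { // wall-clock time = longest device timeline
		if b > makespanNs {
			makespanNs = b
		}
	}
	return assignment, makespanNs
}

func main() {
	assign, total := splitGreedy([]float64{5, 3, 4, 2}, 2)
	fmt.Println(assign, total) // [0 1 1 0] 7
}
```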