gpuapi

package
v0.2.0 Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Mar 16, 2026 License: Apache-2.0 Imports: 18 Imported by: 0

Documentation

Overview

Package gpuapi defines internal interfaces for GPU runtime operations.

The GPU Runtime Abstraction Layer (GRAL) decouples GPUEngine and GPUStorage from vendor-specific APIs (CUDA, ROCm, OpenCL). Each vendor implements the Runtime, BLAS, DNN, and KernelRunner interfaces via adapter packages.

These interfaces are internal and not exported to users. The public Engine[T] interface in compute/ is unchanged.

Index

Constants

This section is empty.

Variables

View Source
var BLASFactory func() (BLAS, error)

BLASFactory creates a BLAS instance. Registered by cuda_blas_purego.go via init() when the cuBLAS library is available at runtime.

View Source
var DNNFactory func() (DNN, error)

DNNFactory creates a DNN instance. Registered by cuda_dnn.go via init().

Functions

func PrintCUBLASProfile

func PrintCUBLASProfile()

PrintCUBLASProfile prints the cuBLAS profiling summary if profiling is enabled. Safe to call even when profiling is disabled (no-op).

Types

type ActivationMode

type ActivationMode int

ActivationMode selects the activation function for DNN operations.

const (
	ActivationSigmoid ActivationMode = iota
	ActivationReLU
	ActivationTanh
	ActivationClippedReLU
	ActivationELU
)

type BLAS

type BLAS interface {
	// Sgemm performs single-precision general matrix multiplication:
	//   C = alpha * A * B + beta * C
	// where A is m x k, B is k x n, and C is m x n.
	// All matrices are contiguous row-major. The implementation handles
	// the row-major to column-major conversion internally.
	Sgemm(m, n, k int, alpha float32,
		a unsafe.Pointer, b unsafe.Pointer,
		beta float32, c unsafe.Pointer,
	) error

	// BFloat16Gemm performs BFloat16 general matrix multiplication:
	//   C = alpha * A * B + beta * C
	// where A is m x k, B is k x n, and C is m x n.
	// All matrices are contiguous row-major BFloat16 elements.
	// Computation is performed in float32 for precision (CUBLAS_COMPUTE_32F).
	// Returns an error on backends that do not support BFloat16 GEMM.
	BFloat16Gemm(m, n, k int, alpha float32,
		a unsafe.Pointer, b unsafe.Pointer,
		beta float32, c unsafe.Pointer,
	) error

	// Float16Gemm performs FP16 general matrix multiplication:
	//   C = alpha * A * B + beta * C
	// where A is m x k, B is k x n, and C is m x n.
	// All matrices are contiguous row-major FP16 elements.
	// Computation is performed in float32 for precision (CUBLAS_COMPUTE_32F).
	Float16Gemm(m, n, k int, alpha float32,
		a unsafe.Pointer, b unsafe.Pointer,
		beta float32, c unsafe.Pointer,
	) error

	// MixedFP16Gemm performs mixed-precision GEMM with FP16 inputs and FP32 output:
	//   C_f32 = alpha * A_fp16 * B_fp16 + beta * C_f32
	MixedFP16Gemm(m, n, k int, alpha float32,
		a unsafe.Pointer, b unsafe.Pointer,
		beta float32, c unsafe.Pointer,
	) error

	// MixedBF16Gemm performs mixed-precision GEMM with BF16 inputs and FP32 output:
	//   C_f32 = alpha * A_bf16 * B_bf16 + beta * C_f32
	// where A is m x k, B is k x n (both BFloat16), and C is m x n (float32).
	// Computation uses CUBLAS_COMPUTE_32F for precision.
	// Returns an error on backends that do not support mixed-precision GEMM.
	MixedBF16Gemm(m, n, k int, alpha float32,
		a unsafe.Pointer, b unsafe.Pointer,
		beta float32, c unsafe.Pointer,
	) error

	// SetStream associates the BLAS handle with an asynchronous stream.
	SetStream(stream Stream) error

	// Destroy releases the BLAS handle resources.
	Destroy() error
}

BLAS abstracts GPU-accelerated Basic Linear Algebra Subprograms. Each vendor (cuBLAS, rocBLAS, CLBlast) provides an implementation.

type BLASBatched

type BLASBatched interface {
	SgemmStridedBatched(m, n, k int, alpha float32,
		a unsafe.Pointer, strideA int64,
		b unsafe.Pointer, strideB int64,
		beta float32,
		c unsafe.Pointer, strideC int64,
		batch int,
	) error
}

BLASBatched is an optional extension that supports strided batched GEMM. All batch elements share the same m, n, k dimensions and alpha/beta scalars. Matrices are accessed at base + i*stride for batch element i.

type BLASBatchedTransposeB

type BLASBatchedTransposeB interface {
	SgemmNTStridedBatched(m, n, k int, alpha float32,
		a unsafe.Pointer, strideA int64,
		b unsafe.Pointer, strideB int64,
		beta float32,
		c unsafe.Pointer, strideC int64,
		batch int,
	) error
}

BLASBatchedTransposeB is an optional extension that supports strided batched C = A * B^T without explicitly transposing B. A is [m, k], B is [n, k] per batch.

type BLASTransposeB

type BLASTransposeB interface {
	SgemmNT(m, n, k int, alpha float32,
		a unsafe.Pointer, b unsafe.Pointer,
		beta float32, c unsafe.Pointer,
	) error
}

BLASTransposeB is an optional extension that supports computing C = alpha * A * B^T + beta * C without explicitly transposing B. A is m x k (row-major), B is n x k (row-major), C is m x n.

type BatchNormMode

type BatchNormMode int

BatchNormMode selects the batch normalization mode.

const (
	BatchNormPerActivation BatchNormMode = iota
	BatchNormSpatial
)

type CUDAArenaPool

type CUDAArenaPool struct {
	// contains filtered or unexported fields
}

CUDAArenaPool adapts cuda.ArenaPool to the gpuapi.MemPool interface. It also exposes Reset() for use between forward passes.

func NewCUDAArenaPool

func NewCUDAArenaPool(deviceID, capacityBytes int, fallback *cuda.MemPool) (*CUDAArenaPool, error)

NewCUDAArenaPool creates a new arena-backed pool on the given device. capacityBytes is the size of the pre-allocated arena region. fallback is the MemPool used when the arena is exhausted.

func (*CUDAArenaPool) Alloc

func (p *CUDAArenaPool) Alloc(deviceID, byteSize int) (unsafe.Pointer, error)

func (*CUDAArenaPool) AllocManaged

func (p *CUDAArenaPool) AllocManaged(deviceID, byteSize int) (unsafe.Pointer, error)

func (*CUDAArenaPool) Drain

func (p *CUDAArenaPool) Drain() error

func (*CUDAArenaPool) Free

func (p *CUDAArenaPool) Free(deviceID int, ptr unsafe.Pointer, byteSize int)

func (*CUDAArenaPool) FreeManaged

func (p *CUDAArenaPool) FreeManaged(deviceID int, ptr unsafe.Pointer, byteSize int)

func (*CUDAArenaPool) Inner

func (p *CUDAArenaPool) Inner() *cuda.ArenaPool

Inner returns the underlying cuda.ArenaPool.

func (*CUDAArenaPool) Reset

func (p *CUDAArenaPool) Reset()

Reset rewinds the arena, reclaiming all per-pass allocations.

func (*CUDAArenaPool) SetResetFloor

func (p *CUDAArenaPool) SetResetFloor(floor int)

SetResetFloor sets the minimum offset that Reset will rewind to.

func (*CUDAArenaPool) Stats

func (p *CUDAArenaPool) Stats() (int, int)

func (*CUDAArenaPool) UsedBytes

func (p *CUDAArenaPool) UsedBytes() int

UsedBytes returns the current arena offset (bytes in use).

type CUDABlas

type CUDABlas struct {
	// contains filtered or unexported fields
}

CUDABlas implements the BLAS interface using cuBLAS via purego.

func NewCUDABlas

func NewCUDABlas() (*CUDABlas, error)

NewCUDABlas creates a new cuBLAS adapter. The caller must call Destroy when done.

func NewCUDABlasFromHandle

func NewCUDABlasFromHandle(h *cublas.Handle) *CUDABlas

NewCUDABlasFromHandle wraps an existing cuBLAS handle. The caller retains ownership; Destroy on this adapter is a no-op.

func (*CUDABlas) BFloat16Gemm

func (b *CUDABlas) BFloat16Gemm(m, n, k int, alpha float32,
	a unsafe.Pointer, bPtr unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

func (*CUDABlas) Destroy

func (b *CUDABlas) Destroy() error

func (*CUDABlas) Float16Gemm

func (b *CUDABlas) Float16Gemm(m, n, k int, alpha float32,
	a unsafe.Pointer, bPtr unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

func (*CUDABlas) Handle

func (b *CUDABlas) Handle() *cublas.Handle

Handle returns the underlying cuBLAS handle for backward compatibility.

func (*CUDABlas) MixedBF16Gemm

func (b *CUDABlas) MixedBF16Gemm(m, n, k int, alpha float32,
	a unsafe.Pointer, bPtr unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

func (*CUDABlas) MixedFP16Gemm

func (b *CUDABlas) MixedFP16Gemm(m, n, k int, alpha float32,
	a unsafe.Pointer, bPtr unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

func (*CUDABlas) SetStream

func (b *CUDABlas) SetStream(stream Stream) error

func (*CUDABlas) Sgemm

func (b *CUDABlas) Sgemm(m, n, k int, alpha float32,
	a unsafe.Pointer, bPtr unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

func (*CUDABlas) SgemmNT

func (b *CUDABlas) SgemmNT(m, n, k int, alpha float32,
	a unsafe.Pointer, bPtr unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

SgemmNT performs C = alpha * A * B^T + beta * C where A is [m, k] and B is [n, k] (row-major). This avoids an explicit Transpose of B.

func (*CUDABlas) SgemmNTStridedBatched

func (b *CUDABlas) SgemmNTStridedBatched(m, n, k int, alpha float32,
	a unsafe.Pointer, strideA int64,
	bPtr unsafe.Pointer, strideB int64,
	beta float32,
	c unsafe.Pointer, strideC int64,
	batch int,
) error

SgemmNTStridedBatched performs batched C = A * B^T using strided batched GEMM.

func (*CUDABlas) SgemmStridedBatched

func (b *CUDABlas) SgemmStridedBatched(m, n, k int, alpha float32,
	a unsafe.Pointer, strideA int64,
	bPtr unsafe.Pointer, strideB int64,
	beta float32,
	c unsafe.Pointer, strideC int64,
	batch int,
) error

SgemmStridedBatched performs batched C = alpha * A * B + beta * C using cublasSgemmStridedBatched. All batch elements share the same m, n, k.

type CUDABlasProfiler

type CUDABlasProfiler struct {
	// contains filtered or unexported fields
}

CUDABlasProfiler wraps CUDABlas with optional per-call timing. When ZERFOO_PROFILE_CUBLAS=1, each Sgemm/SgemmNT/batched call is timed and recorded. Call PrintSummary to dump stats.

func WrapWithProfiler

func WrapWithProfiler(b *CUDABlas) *CUDABlasProfiler

WrapWithProfiler returns a profiling wrapper if ZERFOO_PROFILE_CUBLAS=1, otherwise returns the original CUDABlas unchanged.

func (*CUDABlasProfiler) BFloat16Gemm

func (p *CUDABlasProfiler) BFloat16Gemm(m, n, k int, alpha float32,
	a unsafe.Pointer, bPtr unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

func (*CUDABlasProfiler) Destroy

func (p *CUDABlasProfiler) Destroy() error

func (*CUDABlasProfiler) Float16Gemm

func (p *CUDABlasProfiler) Float16Gemm(m, n, k int, alpha float32,
	a unsafe.Pointer, bPtr unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

func (*CUDABlasProfiler) Handle

func (p *CUDABlasProfiler) Handle() *CUDABlas

func (*CUDABlasProfiler) IsEnabled

func (p *CUDABlasProfiler) IsEnabled() bool

IsEnabled returns whether profiling is active.

func (*CUDABlasProfiler) MixedBF16Gemm

func (p *CUDABlasProfiler) MixedBF16Gemm(m, n, k int, alpha float32,
	a unsafe.Pointer, bPtr unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

func (*CUDABlasProfiler) MixedFP16Gemm

func (p *CUDABlasProfiler) MixedFP16Gemm(m, n, k int, alpha float32,
	a unsafe.Pointer, bPtr unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

func (*CUDABlasProfiler) PrintSummary

func (p *CUDABlasProfiler) PrintSummary()

PrintSummary prints cuBLAS profiling stats to stderr.

func (*CUDABlasProfiler) ResetProfile

func (p *CUDABlasProfiler) ResetProfile()

ResetProfile clears all recorded calls and increments the generation.

func (*CUDABlasProfiler) SetStream

func (p *CUDABlasProfiler) SetStream(stream Stream) error

func (*CUDABlasProfiler) Sgemm

func (p *CUDABlasProfiler) Sgemm(m, n, k int, alpha float32,
	a unsafe.Pointer, bPtr unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

func (*CUDABlasProfiler) SgemmNT

func (p *CUDABlasProfiler) SgemmNT(m, n, k int, alpha float32,
	a unsafe.Pointer, bPtr unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

func (*CUDABlasProfiler) SgemmNTStridedBatched

func (p *CUDABlasProfiler) SgemmNTStridedBatched(m, n, k int, alpha float32,
	a unsafe.Pointer, strideA int64,
	bPtr unsafe.Pointer, strideB int64,
	beta float32,
	c unsafe.Pointer, strideC int64,
	batch int,
) error

func (*CUDABlasProfiler) SgemmStridedBatched

func (p *CUDABlasProfiler) SgemmStridedBatched(m, n, k int, alpha float32,
	a unsafe.Pointer, strideA int64,
	bPtr unsafe.Pointer, strideB int64,
	beta float32,
	c unsafe.Pointer, strideC int64,
	batch int,
) error

func (*CUDABlasProfiler) Summary

func (p *CUDABlasProfiler) Summary() ProfileSummary

Summary returns aggregated profiling statistics.

type CUDADNN

type CUDADNN struct {
	// contains filtered or unexported fields
}

CUDADNN implements the DNN interface using cuDNN.

func NewCUDADNN

func NewCUDADNN() (*CUDADNN, error)

NewCUDADNN creates a new cuDNN adapter.

func NewCUDADNNFromHandle

func NewCUDADNNFromHandle(h *cudnn.Handle) *CUDADNN

NewCUDADNNFromHandle wraps an existing cuDNN handle.

func (*CUDADNN) ActivationBackward

func (d *CUDADNN) ActivationBackward(
	mode ActivationMode,
	y unsafe.Pointer, dy unsafe.Pointer,
	x unsafe.Pointer, dx unsafe.Pointer,
	shape [4]int,
	stream Stream,
) error

func (*CUDADNN) ActivationForward

func (d *CUDADNN) ActivationForward(
	mode ActivationMode,
	x unsafe.Pointer, shape [4]int,
	y unsafe.Pointer,
	stream Stream,
) error

func (*CUDADNN) AddTensor

func (d *CUDADNN) AddTensor(
	alpha float32,
	b unsafe.Pointer, bShape [4]int,
	beta float32,
	y unsafe.Pointer, yShape [4]int,
	stream Stream,
) error

func (*CUDADNN) BatchNormBackward

func (d *CUDADNN) BatchNormBackward(
	x unsafe.Pointer, xShape [4]int,
	dy unsafe.Pointer,
	scale unsafe.Pointer,
	channels int,
	saveMean, saveInvVariance unsafe.Pointer,
	dx, dScale, dBias unsafe.Pointer,
	stream Stream,
) error

func (*CUDADNN) BatchNormForwardInference

func (d *CUDADNN) BatchNormForwardInference(
	x unsafe.Pointer, xShape [4]int,
	scale, bias, mean, variance unsafe.Pointer,
	channels int,
	epsilon float64,
	y unsafe.Pointer,
	stream Stream,
) error

func (*CUDADNN) BatchNormForwardTraining

func (d *CUDADNN) BatchNormForwardTraining(
	x unsafe.Pointer, xShape [4]int,
	scale, bias unsafe.Pointer,
	channels int,
	epsilon, expAvgFactor float64,
	runningMean, runningVariance unsafe.Pointer,
	saveMean, saveInvVariance unsafe.Pointer,
	y unsafe.Pointer,
	stream Stream,
) error

func (*CUDADNN) ConvBackwardData

func (d *CUDADNN) ConvBackwardData(
	w unsafe.Pointer, wShape [4]int,
	dy unsafe.Pointer, dyShape [4]int,
	dx unsafe.Pointer, dxShape [4]int,
	pads [2]int, strides [2]int, dilations [2]int,
	groups int,
	stream Stream,
) error

func (*CUDADNN) ConvBackwardFilter

func (d *CUDADNN) ConvBackwardFilter(
	x unsafe.Pointer, xShape [4]int,
	dy unsafe.Pointer, dyShape [4]int,
	dw unsafe.Pointer, dwShape [4]int,
	pads [2]int, strides [2]int, dilations [2]int,
	groups int,
	stream Stream,
) error

func (*CUDADNN) ConvForward

func (d *CUDADNN) ConvForward(
	x unsafe.Pointer, xShape [4]int,
	w unsafe.Pointer, wShape [4]int,
	bias unsafe.Pointer,
	y unsafe.Pointer, yShape [4]int,
	pads [2]int, strides [2]int, dilations [2]int,
	groups int,
	stream Stream,
) error

func (*CUDADNN) Destroy

func (d *CUDADNN) Destroy() error

func (*CUDADNN) Handle

func (d *CUDADNN) Handle() *cudnn.Handle

Handle returns the underlying cuDNN handle for backward compatibility.

func (*CUDADNN) PoolingBackward

func (d *CUDADNN) PoolingBackward(
	mode PoolingMode,
	y unsafe.Pointer, dy unsafe.Pointer, yShape [4]int,
	x unsafe.Pointer, dx unsafe.Pointer, xShape [4]int,
	windowH, windowW, padH, padW, strideH, strideW int,
	stream Stream,
) error

func (*CUDADNN) PoolingForward

func (d *CUDADNN) PoolingForward(
	mode PoolingMode,
	x unsafe.Pointer, xShape [4]int,
	y unsafe.Pointer, yShape [4]int,
	windowH, windowW, padH, padW, strideH, strideW int,
	stream Stream,
) error

func (*CUDADNN) SetStream

func (d *CUDADNN) SetStream(stream Stream) error

func (*CUDADNN) SoftmaxForward

func (d *CUDADNN) SoftmaxForward(
	x unsafe.Pointer, shape [4]int,
	y unsafe.Pointer,
	stream Stream,
) error

type CUDAKernels

type CUDAKernels struct{}

CUDAKernels implements the KernelRunner interface using custom CUDA kernels.

func NewCUDAKernels

func NewCUDAKernels() *CUDAKernels

NewCUDAKernels returns a new CUDA kernel runner adapter.

func (*CUDAKernels) Add

func (k *CUDAKernels) Add(a, b, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) AddBroadcast

func (k *CUDAKernels) AddBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s Stream) error

func (*CUDAKernels) AddBroadcast4D

func (k *CUDAKernels) AddBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s Stream) error

func (*CUDAKernels) AddFP16

func (k *CUDAKernels) AddFP16(a, b, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) AddScalar

func (k *CUDAKernels) AddScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) Argmax

func (k *CUDAKernels) Argmax(input, result, scratch unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) Cos

func (k *CUDAKernels) Cos(a, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) DequantFP8E4M3ToFP16

func (k *CUDAKernels) DequantFP8E4M3ToFP16(input, output unsafe.Pointer, scale float32, n int, s Stream) error

func (*CUDAKernels) DequantQ4KF32

func (k *CUDAKernels) DequantQ4KF32(src, dst unsafe.Pointer, rows, K int, s Stream) error

func (*CUDAKernels) Div

func (k *CUDAKernels) Div(a, b, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) DivBroadcast

func (k *CUDAKernels) DivBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s Stream) error

func (*CUDAKernels) DivBroadcast4D

func (k *CUDAKernels) DivBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s Stream) error

func (*CUDAKernels) DivFP16

func (k *CUDAKernels) DivFP16(a, b, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) DivScalar

func (k *CUDAKernels) DivScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) Exp

func (k *CUDAKernels) Exp(a, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) F32ToFP16

func (k *CUDAKernels) F32ToFP16(src, dst unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) FP16ToF32

func (k *CUDAKernels) FP16ToF32(src, dst unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) Fill

func (k *CUDAKernels) Fill(data unsafe.Pointer, value float32, n int, s Stream) error

func (*CUDAKernels) FusedAddRMSNormF32

func (k *CUDAKernels) FusedAddRMSNormF32(input, residual, weight, normedOut, sumOut unsafe.Pointer, eps float32, rows, D int, s Stream) error

func (*CUDAKernels) FusedNormAddF32

func (k *CUDAKernels) FusedNormAddF32(input, weight, residual, output unsafe.Pointer, eps float32, rows, D int, s Stream) error

func (*CUDAKernels) FusedQKNormRoPEF32

func (k *CUDAKernels) FusedQKNormRoPEF32(input, weightQ, weightK, cosAngles, sinAngles, output unsafe.Pointer, eps float32, totalHeads, headDim, numQHeads, halfRotary int, s Stream) error

func (*CUDAKernels) FusedRoPEF32

func (k *CUDAKernels) FusedRoPEF32(input, cosAngles, sinAngles, output unsafe.Pointer, batch, seqLen, headDim, halfRotary, cosStride int, s Stream) error

func (*CUDAKernels) FusedSwiGLUF32

func (k *CUDAKernels) FusedSwiGLUF32(w1, w3, output unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) Gather

func (k *CUDAKernels) Gather(table, indices, output unsafe.Pointer, N, D, V int, s Stream) error

func (*CUDAKernels) GemmQ4F32

func (k *CUDAKernels) GemmQ4F32(aQ4, b, c unsafe.Pointer, m, kk, n, dataOffset int, s Stream) error

func (*CUDAKernels) GemmQ8F32

func (k *CUDAKernels) GemmQ8F32(aQ8, b, c unsafe.Pointer, m, kk, n int, s Stream) error

func (*CUDAKernels) GemvQ4KF32

func (k *CUDAKernels) GemvQ4KF32(wQ4K, x, y unsafe.Pointer, M, K int, s Stream) error

func (*CUDAKernels) IncrementCounter

func (k *CUDAKernels) IncrementCounter(counter unsafe.Pointer, delta int, s Stream) error

func (*CUDAKernels) Log

func (k *CUDAKernels) Log(a, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) Mul

func (k *CUDAKernels) Mul(a, b, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) MulBroadcast

func (k *CUDAKernels) MulBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s Stream) error

func (*CUDAKernels) MulBroadcast4D

func (k *CUDAKernels) MulBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s Stream) error

func (*CUDAKernels) MulFP16

func (k *CUDAKernels) MulFP16(a, b, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) MulScalar

func (k *CUDAKernels) MulScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) OffsetMemcpy

func (k *CUDAKernels) OffsetMemcpy(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s Stream) error

func (*CUDAKernels) OffsetMemcpyFP16

func (k *CUDAKernels) OffsetMemcpyFP16(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s Stream) error

func (*CUDAKernels) Pow

func (k *CUDAKernels) Pow(base, exp, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) PowScalar

func (k *CUDAKernels) PowScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) RMSNorm

func (k *CUDAKernels) RMSNorm(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, s Stream) error

func (*CUDAKernels) RMSNormFP16

func (k *CUDAKernels) RMSNormFP16(input, weight, output unsafe.Pointer, eps float32, rows, D int, s Stream) error

func (*CUDAKernels) Repeat

func (k *CUDAKernels) Repeat(src, dst unsafe.Pointer, outerSize, axisDim, innerSize, reps int, s Stream) error

func (*CUDAKernels) ResetCounter

func (k *CUDAKernels) ResetCounter(counter unsafe.Pointer, value int, s Stream) error

func (*CUDAKernels) RoPESelect

func (k *CUDAKernels) RoPESelect(cosTable, sinTable, cosOut, sinOut, counter unsafe.Pointer,
	halfRotary int, s Stream) error

func (*CUDAKernels) Rsqrt

func (k *CUDAKernels) Rsqrt(a, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) ScaledSoftmaxF32

func (k *CUDAKernels) ScaledSoftmaxF32(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, s Stream) error

func (*CUDAKernels) ScaledSoftmaxFP16

func (k *CUDAKernels) ScaledSoftmaxFP16(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, s Stream) error

func (*CUDAKernels) SgemvM1

func (k *CUDAKernels) SgemvM1(y, A, x unsafe.Pointer, M, N int, s Stream) error

func (*CUDAKernels) Sin

func (k *CUDAKernels) Sin(a, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) Softmax

func (k *CUDAKernels) Softmax(input, output unsafe.Pointer, outer, inner, axisSize int, s Stream) error

func (*CUDAKernels) Sqrt

func (k *CUDAKernels) Sqrt(a, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) Sub

func (k *CUDAKernels) Sub(a, b, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) SubBroadcast

func (k *CUDAKernels) SubBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s Stream) error

func (*CUDAKernels) SubBroadcast4D

func (k *CUDAKernels) SubBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s Stream) error

func (*CUDAKernels) SubFP16

func (k *CUDAKernels) SubFP16(a, b, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) SubScalar

func (k *CUDAKernels) SubScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) SumAxis

func (k *CUDAKernels) SumAxis(input, output unsafe.Pointer, outer, inner, axisSize int, s Stream) error

func (*CUDAKernels) Tanh

func (k *CUDAKernels) Tanh(a, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) TanhPrime

func (k *CUDAKernels) TanhPrime(a, upstream, c unsafe.Pointer, n int, s Stream) error

func (*CUDAKernels) Transpose2D

func (k *CUDAKernels) Transpose2D(input, output unsafe.Pointer, rows, cols int, s Stream) error

func (*CUDAKernels) TransposeND

func (k *CUDAKernels) TransposeND(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ndim, total int, s Stream) error

type CUDAMemPool

type CUDAMemPool struct {
	// contains filtered or unexported fields
}

CUDAMemPool implements the MemPool interface by wrapping cuda.MemPool.

func NewCUDAMemPool

func NewCUDAMemPool() *CUDAMemPool

NewCUDAMemPool creates a new CUDA memory pool adapter.

func NewCUDAMemPoolFrom

func NewCUDAMemPoolFrom(pool *cuda.MemPool) *CUDAMemPool

NewCUDAMemPoolFrom wraps an existing cuda.MemPool.

func (*CUDAMemPool) Alloc

func (p *CUDAMemPool) Alloc(deviceID, byteSize int) (unsafe.Pointer, error)

func (*CUDAMemPool) AllocManaged

func (p *CUDAMemPool) AllocManaged(deviceID, byteSize int) (unsafe.Pointer, error)

func (*CUDAMemPool) Drain

func (p *CUDAMemPool) Drain() error

func (*CUDAMemPool) Free

func (p *CUDAMemPool) Free(deviceID int, ptr unsafe.Pointer, byteSize int)

func (*CUDAMemPool) FreeManaged

func (p *CUDAMemPool) FreeManaged(deviceID int, ptr unsafe.Pointer, byteSize int)

func (*CUDAMemPool) Inner

func (p *CUDAMemPool) Inner() *cuda.MemPool

Inner returns the underlying cuda.MemPool for backward compatibility.

func (*CUDAMemPool) Stats

func (p *CUDAMemPool) Stats() (int, int)

type CUDARuntime

type CUDARuntime struct{}

CUDARuntime implements the Runtime interface using the CUDA runtime API.

func NewCUDARuntime

func NewCUDARuntime() *CUDARuntime

NewCUDARuntime returns a new CUDA runtime adapter.

func (*CUDARuntime) CreateStream

func (r *CUDARuntime) CreateStream() (Stream, error)

func (*CUDARuntime) DeviceType

func (r *CUDARuntime) DeviceType() device.Type

func (*CUDARuntime) Free

func (r *CUDARuntime) Free(ptr unsafe.Pointer) error

func (*CUDARuntime) GetDeviceCount

func (r *CUDARuntime) GetDeviceCount() (int, error)

func (*CUDARuntime) Malloc

func (r *CUDARuntime) Malloc(byteSize int) (unsafe.Pointer, error)

func (*CUDARuntime) Memcpy

func (r *CUDARuntime) Memcpy(dst, src unsafe.Pointer, count int, kind MemcpyKind) error

func (*CUDARuntime) MemcpyAsync

func (r *CUDARuntime) MemcpyAsync(dst, src unsafe.Pointer, count int, kind MemcpyKind, stream Stream) error

func (*CUDARuntime) MemcpyPeer

func (r *CUDARuntime) MemcpyPeer(dst unsafe.Pointer, dstDevice int, src unsafe.Pointer, srcDevice int, count int) error

func (*CUDARuntime) SetDevice

func (r *CUDARuntime) SetDevice(deviceID int) error

type DNN

type DNN interface {
	// ConvForward performs 2D convolution.
	// x: [N,C_in,H,W], w: [C_out,C_in/groups,kH,kW], y: [N,C_out,outH,outW].
	// bias is optional (nil to skip).
	// pads: [padH, padW] (symmetric), strides: [sH, sW], dilations: [dH, dW].
	ConvForward(
		x unsafe.Pointer, xShape [4]int,
		w unsafe.Pointer, wShape [4]int,
		bias unsafe.Pointer,
		y unsafe.Pointer, yShape [4]int,
		pads [2]int, strides [2]int, dilations [2]int,
		groups int,
		stream Stream,
	) error

	// ConvBackwardData computes the gradient of the input for 2D convolution.
	// w: [C_out,C_in/groups,kH,kW], dy: [N,C_out,outH,outW], dx: [N,C_in,H,W].
	ConvBackwardData(
		w unsafe.Pointer, wShape [4]int,
		dy unsafe.Pointer, dyShape [4]int,
		dx unsafe.Pointer, dxShape [4]int,
		pads [2]int, strides [2]int, dilations [2]int,
		groups int,
		stream Stream,
	) error

	// ConvBackwardFilter computes the gradient of the filter for 2D convolution.
	// x: [N,C_in,H,W], dy: [N,C_out,outH,outW], dw: [C_out,C_in/groups,kH,kW].
	ConvBackwardFilter(
		x unsafe.Pointer, xShape [4]int,
		dy unsafe.Pointer, dyShape [4]int,
		dw unsafe.Pointer, dwShape [4]int,
		pads [2]int, strides [2]int, dilations [2]int,
		groups int,
		stream Stream,
	) error

	// BatchNormForwardInference performs batch normalization using running statistics.
	// x: [N,C,H,W], scale/bias/mean/variance: [C], y: [N,C,H,W].
	BatchNormForwardInference(
		x unsafe.Pointer, xShape [4]int,
		scale, bias, mean, variance unsafe.Pointer,
		channels int,
		epsilon float64,
		y unsafe.Pointer,
		stream Stream,
	) error

	// BatchNormForwardTraining performs batch normalization computing batch statistics.
	// x: [N,C,H,W], scale/bias: [C], y: [N,C,H,W].
	// saveMean and saveInvVariance are outputs for the backward pass, each [C].
	// runningMean and runningVariance are updated in-place with exponential averaging.
	BatchNormForwardTraining(
		x unsafe.Pointer, xShape [4]int,
		scale, bias unsafe.Pointer,
		channels int,
		epsilon, expAvgFactor float64,
		runningMean, runningVariance unsafe.Pointer,
		saveMean, saveInvVariance unsafe.Pointer,
		y unsafe.Pointer,
		stream Stream,
	) error

	// BatchNormBackward computes gradients for batch normalization.
	// x: [N,C,H,W], dy: [N,C,H,W], scale: [C].
	// saveMean, saveInvVariance: [C] (from BatchNormForwardTraining).
	// dx: [N,C,H,W], dScale, dBias: [C].
	BatchNormBackward(
		x unsafe.Pointer, xShape [4]int,
		dy unsafe.Pointer,
		scale unsafe.Pointer,
		channels int,
		saveMean, saveInvVariance unsafe.Pointer,
		dx, dScale, dBias unsafe.Pointer,
		stream Stream,
	) error

	// ActivationForward applies an activation function element-wise.
	// x and y have the same shape [N,C,H,W].
	ActivationForward(
		mode ActivationMode,
		x unsafe.Pointer, shape [4]int,
		y unsafe.Pointer,
		stream Stream,
	) error

	// ActivationBackward computes the gradient of an activation function.
	// x: original input, y: forward output, dy: upstream gradient, dx: output gradient.
	// All have shape [N,C,H,W].
	ActivationBackward(
		mode ActivationMode,
		y unsafe.Pointer, dy unsafe.Pointer,
		x unsafe.Pointer, dx unsafe.Pointer,
		shape [4]int,
		stream Stream,
	) error

	// PoolingForward performs 2D pooling.
	// x: [N,C,H,W], y: [N,C,outH,outW].
	PoolingForward(
		mode PoolingMode,
		x unsafe.Pointer, xShape [4]int,
		y unsafe.Pointer, yShape [4]int,
		windowH, windowW, padH, padW, strideH, strideW int,
		stream Stream,
	) error

	// PoolingBackward computes the gradient of 2D pooling.
	// y: forward output, dy: upstream gradient, x: forward input, dx: output gradient.
	PoolingBackward(
		mode PoolingMode,
		y unsafe.Pointer, dy unsafe.Pointer, yShape [4]int,
		x unsafe.Pointer, dx unsafe.Pointer, xShape [4]int,
		windowH, windowW, padH, padW, strideH, strideW int,
		stream Stream,
	) error

	// SoftmaxForward computes softmax over the channel dimension.
	// x and y have the same shape [N,C,H,W].
	SoftmaxForward(
		x unsafe.Pointer, shape [4]int,
		y unsafe.Pointer,
		stream Stream,
	) error

	// AddTensor performs y = alpha*b + beta*y for bias addition.
	// b: [1,C,1,1], y: [N,C,H,W].
	AddTensor(
		alpha float32,
		b unsafe.Pointer, bShape [4]int,
		beta float32,
		y unsafe.Pointer, yShape [4]int,
		stream Stream,
	) error

	// SetStream associates the DNN handle with an asynchronous stream.
	SetStream(stream Stream) error

	// Destroy releases the DNN handle resources.
	Destroy() error
}

DNN abstracts GPU-accelerated deep neural network primitives. Each vendor (cuDNN, MIOpen) provides an implementation.

All tensor pointers are device memory. Shapes follow NCHW layout. The implementation handles descriptor creation internally.

type KernelRunner

// KernelRunner abstracts GPU compute kernels for elementwise, scalar,
// reduction, and utility operations. Each vendor provides an implementation
// using its own kernel compilation toolchain (CUDA .cu, HIP .hip, OpenCL .cl).
//
// Pointer arguments are device memory unless a method notes otherwise
// (host-side slices are called out explicitly, e.g. TransposeND).
type KernelRunner interface {
	// Binary elementwise operations: c[i] = op(a[i], b[i])
	Add(a, b, c unsafe.Pointer, n int, stream Stream) error
	Sub(a, b, c unsafe.Pointer, n int, stream Stream) error
	Mul(a, b, c unsafe.Pointer, n int, stream Stream) error
	Div(a, b, c unsafe.Pointer, n int, stream Stream) error
	// Pow: c[i] = base[i] raised to the power exp[i].
	Pow(base, exp, c unsafe.Pointer, n int, stream Stream) error

	// Unary elementwise operations: c[i] = op(a[i])
	Exp(a, c unsafe.Pointer, n int, stream Stream) error
	Log(a, c unsafe.Pointer, n int, stream Stream) error
	Sqrt(a, c unsafe.Pointer, n int, stream Stream) error
	Rsqrt(a, c unsafe.Pointer, n int, stream Stream) error
	Sin(a, c unsafe.Pointer, n int, stream Stream) error
	Cos(a, c unsafe.Pointer, n int, stream Stream) error
	Tanh(a, c unsafe.Pointer, n int, stream Stream) error

	// TanhPrime: c[i] = (1 - tanh(a[i])^2) * upstream[i]
	TanhPrime(a, upstream, c unsafe.Pointer, n int, stream Stream) error

	// Scalar operations: c[i] = op(a[i], scalar)
	AddScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error
	MulScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error
	DivScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error

	// SubScalar: c[i] = a[i] - scalar. PowScalar: c[i] = a[i] ^ scalar.
	SubScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error
	PowScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error

	// Fill sets all n elements to value.
	Fill(data unsafe.Pointer, value float32, n int, stream Stream) error

	// SumAxis reduces along one axis: output[outer][inner] = sum(input[outer][k][inner], k=0..axisSize-1).
	SumAxis(input, output unsafe.Pointer, outer, inner, axisSize int, stream Stream) error

	// Softmax computes softmax along one axis.
	Softmax(input, output unsafe.Pointer, outer, inner, axisSize int, stream Stream) error

	// GemmQ4F32 performs Q4_0 dequant-GEMM: C = dequant(A_q4) * B.
	// A_q4 is in GPU-separated layout (scales then data), B is [K,N] float32, C is [M,N] float32.
	// dataOffset is the byte offset from A_q4 to the packed data region.
	GemmQ4F32(aQ4, b, c unsafe.Pointer, m, k, n, dataOffset int, stream Stream) error

	// GemvQ4KF32 performs Q4_K fused dequant-GEMV: y = dequant(W_q4k) * x.
	// W_q4k is raw Q4_K super-blocks for matrix [M, K]. x is [K] float32.
	// y is [M] float32. K must be a multiple of 256. Batch=1 only.
	GemvQ4KF32(wQ4K, x, y unsafe.Pointer, M, K int, stream Stream) error

	// DequantQ4KF32 dequantizes Q4_K super-blocks to FP32 in global memory.
	// src is raw Q4_K super-blocks for matrix [rows, K]. dst is [rows, K] float32.
	// K must be a multiple of 256. Used for non-GEMV cuBLAS path.
	DequantQ4KF32(src, dst unsafe.Pointer, rows, K int, stream Stream) error

	// GemmQ8F32 performs Q8_0 dequant-GEMM: C = dequant(A_q8) * B.
	// A_q8 is packed Q8_0 blocks (36 bytes per 32 values), B is [K,N] float32, C is [M,N] float32.
	GemmQ8F32(aQ8, b, c unsafe.Pointer, m, k, n int, stream Stream) error

	// Broadcast binary ops: c[r,c] = op(a[r*saRow+c*saCol], b[r*sbRow+c*sbCol]).
	// A stride of 0 broadcasts that dimension; a non-broadcast operand uses
	// row stride D and column stride 1.
	AddBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, stream Stream) error
	SubBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, stream Stream) error
	MulBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, stream Stream) error
	DivBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, stream Stream) error

	// 4D broadcast binary ops: c[i0,i1,i2,i3] = op(a[...], b[...]) with per-dim strides.
	// d0-d3 are output dims; sa0-sa3 and sb0-sb3 are per-dim strides (0 = broadcast).
	AddBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, stream Stream) error
	SubBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, stream Stream) error
	MulBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, stream Stream) error
	DivBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, stream Stream) error

	// Transpose2D transposes a [rows, cols] matrix to [cols, rows] using tiled shared memory.
	Transpose2D(input, output unsafe.Pointer, rows, cols int, stream Stream) error

	// TransposeND permutes dimensions of an N-D tensor.
	// inStrides/outStrides/perm are int32 slices on host.
	TransposeND(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ndim, total int, stream Stream) error

	// Gather performs embedding table lookup: output[i,:] = table[indices[i],:].
	// table: [V, D], indices: [N] int64 on device, output: [N, D].
	Gather(table, indices, output unsafe.Pointer, N, D, V int, stream Stream) error

	// RMSNorm computes fused RMSNorm: output = input * rsqrt(mean(input^2) + eps) * weight.
	// input: [rows, D], weight: [D], output: [rows, D], scales: [rows] (per-row rsqrt values for backward).
	RMSNorm(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, stream Stream) error

	// Repeat replicates elements along an axis.
	// outerSize = product of dims before axis, axisDim = size of axis,
	// innerSize = product of dims after axis, reps = number of repetitions.
	Repeat(src, dst unsafe.Pointer, outerSize, axisDim, innerSize, reps int, stream Stream) error

	// Argmax finds the index of the maximum element in a float32 array on device.
	// input: [n] float32, result: single int32 on device, scratch: temp storage.
	// scratch must be at least 2*ceil(n/256)*4 bytes.
	Argmax(input, result, scratch unsafe.Pointer, n int, stream Stream) error

	// FusedRoPEF32 applies rotary positional embedding in one kernel launch.
	// input/output: [batch * seqLen * headDim], cos/sin: [seqLen * cosStride].
	FusedRoPEF32(input, cosAngles, sinAngles, output unsafe.Pointer, batch, seqLen, headDim, halfRotary, cosStride int, stream Stream) error

	// FusedSwiGLUF32 applies SwiGLU activation in one kernel launch.
	// output[i] = w1[i] * sigmoid(w1[i]) * w3[i]. All arrays have n elements.
	FusedSwiGLUF32(w1, w3, output unsafe.Pointer, n int, stream Stream) error

	// FusedAddRMSNormF32 fuses residual addition and RMSNorm into one kernel launch.
	// sumOut = input + residual; normedOut = rmsnorm(sumOut, weight, eps).
	// input: [rows, D], residual: [rows, D], weight: [D],
	// normedOut: [rows, D], sumOut: [rows, D].
	FusedAddRMSNormF32(input, residual, weight, normedOut, sumOut unsafe.Pointer, eps float32, rows, D int, stream Stream) error

	// FusedNormAddF32 applies RMSNorm then adds residual in one kernel launch.
	// output = rmsnorm(input, weight, eps) + residual.
	// input: [rows, D], weight: [D], residual: [rows, D], output: [rows, D].
	FusedNormAddF32(input, weight, residual, output unsafe.Pointer, eps float32, rows, D int, stream Stream) error

	// FusedQKNormRoPEF32 applies per-head RMSNorm + RoPE to combined Q+K heads.
	// Replaces 4 kernel launches (Q_norm + K_norm + Q_RoPE + K_RoPE) with 1.
	// input: [totalHeads, headDim], weightQ/weightK: [headDim],
	// cosAngles/sinAngles: [halfRotary], output: [totalHeads, headDim].
	FusedQKNormRoPEF32(input, weightQ, weightK, cosAngles, sinAngles, output unsafe.Pointer, eps float32, totalHeads, headDim, numQHeads, halfRotary int, stream Stream) error

	// ScaledSoftmaxF32 computes softmax(input * scale) in one kernel launch,
	// replacing the MulScalar + Softmax chain (saves 1 kernel launch per call).
	ScaledSoftmaxF32(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, stream Stream) error

	// FP16 elementwise operations: inputs and outputs are __half (2 bytes each).
	AddFP16(a, b, c unsafe.Pointer, n int, stream Stream) error
	SubFP16(a, b, c unsafe.Pointer, n int, stream Stream) error
	MulFP16(a, b, c unsafe.Pointer, n int, stream Stream) error
	DivFP16(a, b, c unsafe.Pointer, n int, stream Stream) error

	// RMSNormFP16 computes RMSNorm on FP16 data with FP32 accumulation.
	RMSNormFP16(input, weight, output unsafe.Pointer, eps float32, rows, D int, stream Stream) error

	// ScaledSoftmaxFP16 computes softmax(input * scale) on FP16 data with FP32 accumulation.
	ScaledSoftmaxFP16(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, stream Stream) error

	// F32ToFP16 converts n float32 elements to FP16 on device.
	F32ToFP16(src, dst unsafe.Pointer, n int, stream Stream) error

	// FP16ToF32 converts n FP16 elements to float32 on device.
	FP16ToF32(src, dst unsafe.Pointer, n int, stream Stream) error

	// DequantFP8E4M3ToFP16 dequantizes n FP8 E4M3 bytes to FP16 on device.
	// output[i] = fp8_to_fp16(input[i]) * scale.
	DequantFP8E4M3ToFP16(input, output unsafe.Pointer, scale float32, n int, stream Stream) error

	// GPU-resident counter operations for CUDA graph position tracking.
	IncrementCounter(counter unsafe.Pointer, delta int, stream Stream) error
	ResetCounter(counter unsafe.Pointer, value int, stream Stream) error

	// OffsetMemcpy copies dim floats from src to dst at offset counter*dim.
	// counter is a GPU-resident int32. Used for GPU-driven KV cache append.
	// NOTE(review): maxSeqLen presumably bounds the destination offset — confirm
	// against the kernel implementation.
	OffsetMemcpy(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, stream Stream) error

	// OffsetMemcpyFP16 copies dim floats from F32 src to FP16 dst at offset counter*dim.
	// counter is a GPU-resident int32. Used for GPU-driven FP16 KV cache append.
	OffsetMemcpyFP16(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, stream Stream) error

	// RoPESelect copies halfRotary cos/sin values from the precomputed table
	// at position counter[0]. Used for GPU-driven RoPE angle selection.
	RoPESelect(cosTable, sinTable, cosOut, sinOut, counter unsafe.Pointer, halfRotary int, stream Stream) error

	// SgemvM1 computes y = A*x for M=1 decode (single-token GEMV).
	// y[M], A[M x N] row-major, x[N].
	SgemvM1(y, A, x unsafe.Pointer, M, N int, stream Stream) error
}

KernelRunner abstracts GPU compute kernels for elementwise, scalar, reduction, and utility operations. Each vendor provides an implementation using its own kernel compilation toolchain (CUDA .cu, HIP .hip, OpenCL .cl).

type MemPool

// MemPool abstracts a GPU device memory pool with size-bucketed caching.
// The pool operates on opaque device pointers, so the same pool logic can be
// shared across vendors while the underlying Malloc/Free come from the
// vendor's Runtime.
type MemPool interface {
	// Alloc returns a device pointer of at least byteSize bytes on the given device.
	// May return a cached pointer from a previous Free call.
	Alloc(deviceID, byteSize int) (unsafe.Pointer, error)
	// Free returns a device pointer to the pool for reuse.
	// NOTE(review): byteSize presumably must equal the size requested at Alloc
	// so the pointer lands in the correct size bucket — confirm with implementations.
	Free(deviceID int, ptr unsafe.Pointer, byteSize int)
	// AllocManaged returns a unified memory pointer accessible from both host
	// and device. Returns an error on backends that do not support managed memory.
	AllocManaged(deviceID, byteSize int) (unsafe.Pointer, error)
	// FreeManaged returns a managed memory pointer to the pool for reuse.
	FreeManaged(deviceID int, ptr unsafe.Pointer, byteSize int)
	// Drain frees all cached pointers back to the device.
	Drain() error
	// Stats returns the number of cached allocations and their total bytes.
	Stats() (allocations int, totalBytes int)
}

MemPool abstracts a GPU device memory pool with size-bucketed caching. Each vendor can reuse the same pool logic since the pool operates on opaque device pointers, but the underlying Malloc/Free come from the vendor's Runtime.

type MemcpyKind

// MemcpyKind specifies the direction of a memory copy operation.
type MemcpyKind int

MemcpyKind specifies the direction of a memory copy operation.

// Copy directions accepted by Runtime.Memcpy and Runtime.MemcpyAsync.
const (
	// MemcpyHostToDevice copies from host (CPU) memory to device (GPU) memory.
	MemcpyHostToDevice MemcpyKind = iota
	// MemcpyDeviceToHost copies from device (GPU) memory to host (CPU) memory.
	MemcpyDeviceToHost
	// MemcpyDeviceToDevice copies between device (GPU) memory regions.
	MemcpyDeviceToDevice
)

type OpSummary

// OpSummary holds aggregated profiling stats for one operation/shape.
type OpSummary struct {
	Op        string        // operation name
	M, N, K   int           // GEMM problem dimensions
	Batch     int           // batch count for batched calls
	Calls     int           // number of recorded invocations
	TotalTime time.Duration // cumulative time across all calls
	AvgTime   time.Duration // average per call — presumably TotalTime/Calls; confirm in profiler
}

OpSummary holds per-operation stats.

type OpenCLDNN

// OpenCLDNN implements the DNN interface for OpenCL. OpenCL has no standard
// DNN library (like cuDNN or MIOpen), so all operations return ErrNotSupported
// and the compute engine falls back to CPU.
type OpenCLDNN struct{}

OpenCLDNN implements the DNN interface for OpenCL. OpenCL has no standard DNN library (like cuDNN or MIOpen), so all operations return ErrNotSupported. The compute engine falls back to CPU.

func NewOpenCLDNN

func NewOpenCLDNN() *OpenCLDNN

NewOpenCLDNN returns a new OpenCL DNN stub.

func (*OpenCLDNN) ActivationBackward

func (d *OpenCLDNN) ActivationBackward(
	_ ActivationMode,
	_ unsafe.Pointer, _ unsafe.Pointer,
	_ unsafe.Pointer, _ unsafe.Pointer,
	_ [4]int,
	_ Stream,
) error

func (*OpenCLDNN) ActivationForward

func (d *OpenCLDNN) ActivationForward(
	_ ActivationMode,
	_ unsafe.Pointer, _ [4]int,
	_ unsafe.Pointer,
	_ Stream,
) error

func (*OpenCLDNN) AddTensor

func (d *OpenCLDNN) AddTensor(
	_ float32,
	_ unsafe.Pointer, _ [4]int,
	_ float32,
	_ unsafe.Pointer, _ [4]int,
	_ Stream,
) error

func (*OpenCLDNN) BatchNormBackward

func (d *OpenCLDNN) BatchNormBackward(
	_ unsafe.Pointer, _ [4]int,
	_ unsafe.Pointer,
	_ unsafe.Pointer,
	_ int,
	_, _ unsafe.Pointer,
	_, _, _ unsafe.Pointer,
	_ Stream,
) error

func (*OpenCLDNN) BatchNormForwardInference

func (d *OpenCLDNN) BatchNormForwardInference(
	_ unsafe.Pointer, _ [4]int,
	_, _, _, _ unsafe.Pointer,
	_ int,
	_ float64,
	_ unsafe.Pointer,
	_ Stream,
) error

func (*OpenCLDNN) BatchNormForwardTraining

func (d *OpenCLDNN) BatchNormForwardTraining(
	_ unsafe.Pointer, _ [4]int,
	_, _ unsafe.Pointer,
	_ int,
	_, _ float64,
	_, _ unsafe.Pointer,
	_, _ unsafe.Pointer,
	_ unsafe.Pointer,
	_ Stream,
) error

func (*OpenCLDNN) ConvBackwardData

func (d *OpenCLDNN) ConvBackwardData(
	_ unsafe.Pointer, _ [4]int,
	_ unsafe.Pointer, _ [4]int,
	_ unsafe.Pointer, _ [4]int,
	_, _, _ [2]int,
	_ int,
	_ Stream,
) error

func (*OpenCLDNN) ConvBackwardFilter

func (d *OpenCLDNN) ConvBackwardFilter(
	_ unsafe.Pointer, _ [4]int,
	_ unsafe.Pointer, _ [4]int,
	_ unsafe.Pointer, _ [4]int,
	_, _, _ [2]int,
	_ int,
	_ Stream,
) error

func (*OpenCLDNN) ConvForward

func (d *OpenCLDNN) ConvForward(
	_ unsafe.Pointer, _ [4]int,
	_ unsafe.Pointer, _ [4]int,
	_ unsafe.Pointer,
	_ unsafe.Pointer, _ [4]int,
	_, _, _ [2]int,
	_ int,
	_ Stream,
) error

func (*OpenCLDNN) Destroy

func (d *OpenCLDNN) Destroy() error

func (*OpenCLDNN) PoolingBackward

func (d *OpenCLDNN) PoolingBackward(
	_ PoolingMode,
	_ unsafe.Pointer, _ unsafe.Pointer, _ [4]int,
	_ unsafe.Pointer, _ unsafe.Pointer, _ [4]int,
	_, _, _, _, _, _ int,
	_ Stream,
) error

func (*OpenCLDNN) PoolingForward

func (d *OpenCLDNN) PoolingForward(
	_ PoolingMode,
	_ unsafe.Pointer, _ [4]int,
	_ unsafe.Pointer, _ [4]int,
	_, _, _, _, _, _ int,
	_ Stream,
) error

func (*OpenCLDNN) SetStream

func (d *OpenCLDNN) SetStream(_ Stream) error

func (*OpenCLDNN) SoftmaxForward

func (d *OpenCLDNN) SoftmaxForward(
	_ unsafe.Pointer, _ [4]int,
	_ unsafe.Pointer,
	_ Stream,
) error

type OpenCLMemPool

type OpenCLMemPool struct {
	// contains filtered or unexported fields
}

OpenCLMemPool implements the MemPool interface for OpenCL. It uses a simple size-bucketed cache of cl_mem buffers.

func NewOpenCLMemPool

func NewOpenCLMemPool(rt *OpenCLRuntime) *OpenCLMemPool

NewOpenCLMemPool creates a new OpenCL memory pool.

func (*OpenCLMemPool) Alloc

func (p *OpenCLMemPool) Alloc(_ int, byteSize int) (unsafe.Pointer, error)

func (*OpenCLMemPool) AllocManaged

func (p *OpenCLMemPool) AllocManaged(_, _ int) (unsafe.Pointer, error)

func (*OpenCLMemPool) Drain

func (p *OpenCLMemPool) Drain() error

func (*OpenCLMemPool) Free

func (p *OpenCLMemPool) Free(_ int, ptr unsafe.Pointer, byteSize int)

func (*OpenCLMemPool) FreeManaged

func (p *OpenCLMemPool) FreeManaged(_ int, _ unsafe.Pointer, _ int)

func (*OpenCLMemPool) Stats

func (p *OpenCLMemPool) Stats() (int, int)

type OpenCLRuntime

type OpenCLRuntime struct {
	// contains filtered or unexported fields
}

OpenCLRuntime implements the Runtime interface using OpenCL.

func NewOpenCLRuntime

func NewOpenCLRuntime() *OpenCLRuntime

NewOpenCLRuntime returns a new OpenCL runtime adapter. Returns nil if libOpenCL is not available on this system.

func (*OpenCLRuntime) CLContext

func (r *OpenCLRuntime) CLContext() unsafe.Pointer

CLContext returns the underlying cl_context pointer.

func (*OpenCLRuntime) CLDevice

func (r *OpenCLRuntime) CLDevice() unsafe.Pointer

CLDevice returns the underlying cl_device_id pointer.

func (*OpenCLRuntime) CLQueue

func (r *OpenCLRuntime) CLQueue() unsafe.Pointer

CLQueue returns the default command queue pointer.

func (*OpenCLRuntime) CreateStream

func (r *OpenCLRuntime) CreateStream() (Stream, error)

func (*OpenCLRuntime) DeviceType

func (r *OpenCLRuntime) DeviceType() device.Type

func (*OpenCLRuntime) Free

func (r *OpenCLRuntime) Free(ptr unsafe.Pointer) error

func (*OpenCLRuntime) GetDeviceCount

func (r *OpenCLRuntime) GetDeviceCount() (int, error)

func (*OpenCLRuntime) Malloc

func (r *OpenCLRuntime) Malloc(byteSize int) (unsafe.Pointer, error)

func (*OpenCLRuntime) Memcpy

func (r *OpenCLRuntime) Memcpy(dst, src unsafe.Pointer, count int, kind MemcpyKind) error

func (*OpenCLRuntime) MemcpyAsync

func (r *OpenCLRuntime) MemcpyAsync(dst, src unsafe.Pointer, count int, kind MemcpyKind, _ Stream) error

func (*OpenCLRuntime) MemcpyPeer

func (r *OpenCLRuntime) MemcpyPeer(dst unsafe.Pointer, _ int, src unsafe.Pointer, _ int, count int) error

func (*OpenCLRuntime) SetDevice

func (r *OpenCLRuntime) SetDevice(deviceID int) error

type PoolingMode

// PoolingMode selects the pooling strategy for DNN operations.
type PoolingMode int

PoolingMode selects the pooling strategy for DNN operations.

const (
	// PoolingMax takes the maximum value in each pooling window.
	PoolingMax PoolingMode = iota
	// PoolingAverageCountIncludePad averages over the full window,
	// counting padded positions in the divisor.
	PoolingAverageCountIncludePad
	// PoolingAverageCountExcludePad averages over valid (non-padded)
	// positions only.
	PoolingAverageCountExcludePad
)

type ProfileSummary

// ProfileSummary holds aggregated cuBLAS profiling stats.
type ProfileSummary struct {
	TotalCalls    int           // total profiled calls across all operations
	TotalDuration time.Duration // cumulative time across all operations
	ByOp          []OpSummary   // per-operation breakdown
}

ProfileSummary holds aggregated cuBLAS profiling stats.

type ROCmBlas

type ROCmBlas struct {
	// contains filtered or unexported fields
}

ROCmBlas implements the BLAS interface using rocBLAS.

func NewROCmBlas

func NewROCmBlas() (*ROCmBlas, error)

NewROCmBlas creates a new rocBLAS adapter. The caller must call Destroy when done.

func NewROCmBlasFromHandle

func NewROCmBlasFromHandle(h *rocblas.Handle) *ROCmBlas

NewROCmBlasFromHandle wraps an existing rocBLAS handle.

func (*ROCmBlas) BFloat16Gemm

func (b *ROCmBlas) BFloat16Gemm(_, _, _ int, _ float32,
	_, _ unsafe.Pointer, _ float32, _ unsafe.Pointer,
) error

func (*ROCmBlas) Destroy

func (b *ROCmBlas) Destroy() error

func (*ROCmBlas) Float16Gemm

func (b *ROCmBlas) Float16Gemm(_, _, _ int, _ float32,
	_, _ unsafe.Pointer, _ float32, _ unsafe.Pointer,
) error

func (*ROCmBlas) Handle

func (b *ROCmBlas) Handle() *rocblas.Handle

Handle returns the underlying rocBLAS handle.

func (*ROCmBlas) MixedBF16Gemm

func (b *ROCmBlas) MixedBF16Gemm(_, _, _ int, _ float32,
	_, _ unsafe.Pointer, _ float32, _ unsafe.Pointer,
) error

func (*ROCmBlas) MixedFP16Gemm

func (b *ROCmBlas) MixedFP16Gemm(_, _, _ int, _ float32,
	_, _ unsafe.Pointer, _ float32, _ unsafe.Pointer,
) error

func (*ROCmBlas) SetStream

func (b *ROCmBlas) SetStream(stream Stream) error

func (*ROCmBlas) Sgemm

func (b *ROCmBlas) Sgemm(m, n, k int, alpha float32,
	a unsafe.Pointer, bPtr unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

type ROCmDNN

type ROCmDNN struct {
	// contains filtered or unexported fields
}

ROCmDNN implements the DNN interface using MIOpen.

func NewROCmDNN

func NewROCmDNN() (*ROCmDNN, error)

NewROCmDNN creates a new MIOpen adapter.

func NewROCmDNNFromHandle

func NewROCmDNNFromHandle(h *miopen.Handle) *ROCmDNN

NewROCmDNNFromHandle wraps an existing MIOpen handle.

func (*ROCmDNN) ActivationBackward

func (d *ROCmDNN) ActivationBackward(
	mode ActivationMode,
	y unsafe.Pointer, dy unsafe.Pointer,
	x unsafe.Pointer, dx unsafe.Pointer,
	shape [4]int,
	stream Stream,
) error

func (*ROCmDNN) ActivationForward

func (d *ROCmDNN) ActivationForward(
	mode ActivationMode,
	x unsafe.Pointer, shape [4]int,
	y unsafe.Pointer,
	stream Stream,
) error

func (*ROCmDNN) AddTensor

func (d *ROCmDNN) AddTensor(
	alpha float32,
	b unsafe.Pointer, bShape [4]int,
	beta float32,
	y unsafe.Pointer, yShape [4]int,
	stream Stream,
) error

func (*ROCmDNN) BatchNormBackward

func (d *ROCmDNN) BatchNormBackward(
	x unsafe.Pointer, xShape [4]int,
	dy unsafe.Pointer,
	scale unsafe.Pointer,
	channels int,
	saveMean, saveInvVariance unsafe.Pointer,
	dx, dScale, dBias unsafe.Pointer,
	stream Stream,
) error

func (*ROCmDNN) BatchNormForwardInference

func (d *ROCmDNN) BatchNormForwardInference(
	x unsafe.Pointer, xShape [4]int,
	scale, bias, mean, variance unsafe.Pointer,
	channels int,
	epsilon float64,
	y unsafe.Pointer,
	stream Stream,
) error

func (*ROCmDNN) BatchNormForwardTraining

func (d *ROCmDNN) BatchNormForwardTraining(
	x unsafe.Pointer, xShape [4]int,
	scale, bias unsafe.Pointer,
	channels int,
	epsilon, expAvgFactor float64,
	runningMean, runningVariance unsafe.Pointer,
	saveMean, saveInvVariance unsafe.Pointer,
	y unsafe.Pointer,
	stream Stream,
) error

func (*ROCmDNN) ConvBackwardData

func (d *ROCmDNN) ConvBackwardData(
	w unsafe.Pointer, wShape [4]int,
	dy unsafe.Pointer, dyShape [4]int,
	dx unsafe.Pointer, dxShape [4]int,
	pads [2]int, strides [2]int, dilations [2]int,
	groups int,
	stream Stream,
) error

func (*ROCmDNN) ConvBackwardFilter

func (d *ROCmDNN) ConvBackwardFilter(
	x unsafe.Pointer, xShape [4]int,
	dy unsafe.Pointer, dyShape [4]int,
	dw unsafe.Pointer, dwShape [4]int,
	pads [2]int, strides [2]int, dilations [2]int,
	groups int,
	stream Stream,
) error

func (*ROCmDNN) ConvForward

func (d *ROCmDNN) ConvForward(
	x unsafe.Pointer, xShape [4]int,
	w unsafe.Pointer, wShape [4]int,
	bias unsafe.Pointer,
	y unsafe.Pointer, yShape [4]int,
	pads [2]int, strides [2]int, dilations [2]int,
	groups int,
	stream Stream,
) error

func (*ROCmDNN) Destroy

func (d *ROCmDNN) Destroy() error

func (*ROCmDNN) Handle

func (d *ROCmDNN) Handle() *miopen.Handle

Handle returns the underlying MIOpen handle.

func (*ROCmDNN) PoolingBackward

func (d *ROCmDNN) PoolingBackward(
	mode PoolingMode,
	y unsafe.Pointer, dy unsafe.Pointer, yShape [4]int,
	x unsafe.Pointer, dx unsafe.Pointer, xShape [4]int,
	windowH, windowW, padH, padW, strideH, strideW int,
	stream Stream,
) error

func (*ROCmDNN) PoolingForward

func (d *ROCmDNN) PoolingForward(
	mode PoolingMode,
	x unsafe.Pointer, xShape [4]int,
	y unsafe.Pointer, yShape [4]int,
	windowH, windowW, padH, padW, strideH, strideW int,
	stream Stream,
) error

func (*ROCmDNN) SetStream

func (d *ROCmDNN) SetStream(stream Stream) error

func (*ROCmDNN) SoftmaxForward

func (d *ROCmDNN) SoftmaxForward(
	x unsafe.Pointer, shape [4]int,
	y unsafe.Pointer,
	stream Stream,
) error

type ROCmKernels

// ROCmKernels implements the KernelRunner interface using custom HIP kernels.
type ROCmKernels struct{}

ROCmKernels implements the KernelRunner interface using custom HIP kernels.

func NewROCmKernels

func NewROCmKernels() *ROCmKernels

NewROCmKernels returns a new ROCm kernel runner adapter.

func (*ROCmKernels) Add

func (k *ROCmKernels) Add(a, b, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) AddBroadcast

func (k *ROCmKernels) AddBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, _ Stream) error

func (*ROCmKernels) AddBroadcast4D

func (k *ROCmKernels) AddBroadcast4D(_, _, _ unsafe.Pointer, _, _, _, _, _, _, _, _, _, _, _, _ int, _ Stream) error

func (*ROCmKernels) AddFP16

func (k *ROCmKernels) AddFP16(_, _, _ unsafe.Pointer, _ int, _ Stream) error

func (*ROCmKernels) AddScalar

func (k *ROCmKernels) AddScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) Argmax

func (k *ROCmKernels) Argmax(_, _, _ unsafe.Pointer, _ int, _ Stream) error

func (*ROCmKernels) Cos

func (k *ROCmKernels) Cos(a, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) DequantFP8E4M3ToFP16

func (k *ROCmKernels) DequantFP8E4M3ToFP16(_, _ unsafe.Pointer, _ float32, _ int, _ Stream) error

func (*ROCmKernels) DequantQ4KF32

func (k *ROCmKernels) DequantQ4KF32(src, dst unsafe.Pointer, rows, K int, _ Stream) error

func (*ROCmKernels) Div

func (k *ROCmKernels) Div(a, b, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) DivBroadcast

func (k *ROCmKernels) DivBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, _ Stream) error

func (*ROCmKernels) DivBroadcast4D

func (k *ROCmKernels) DivBroadcast4D(_, _, _ unsafe.Pointer, _, _, _, _, _, _, _, _, _, _, _, _ int, _ Stream) error

func (*ROCmKernels) DivFP16

func (k *ROCmKernels) DivFP16(_, _, _ unsafe.Pointer, _ int, _ Stream) error

func (*ROCmKernels) DivScalar

func (k *ROCmKernels) DivScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) Exp

func (k *ROCmKernels) Exp(a, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) F32ToFP16

func (k *ROCmKernels) F32ToFP16(_, _ unsafe.Pointer, _ int, _ Stream) error

func (*ROCmKernels) FP16ToF32

func (k *ROCmKernels) FP16ToF32(_, _ unsafe.Pointer, _ int, _ Stream) error

func (*ROCmKernels) Fill

func (k *ROCmKernels) Fill(data unsafe.Pointer, value float32, n int, s Stream) error

func (*ROCmKernels) FusedAddRMSNormF32

func (k *ROCmKernels) FusedAddRMSNormF32(_, _, _, _, _ unsafe.Pointer, _ float32, _, _ int, _ Stream) error

func (*ROCmKernels) FusedNormAddF32

func (k *ROCmKernels) FusedNormAddF32(_, _, _, _ unsafe.Pointer, _ float32, _, _ int, _ Stream) error

func (*ROCmKernels) FusedQKNormRoPEF32

func (k *ROCmKernels) FusedQKNormRoPEF32(_, _, _, _, _, _ unsafe.Pointer, _ float32, _, _, _, _ int, _ Stream) error

func (*ROCmKernels) FusedRoPEF32

func (k *ROCmKernels) FusedRoPEF32(_, _, _, _ unsafe.Pointer, _, _, _, _, _ int, _ Stream) error

func (*ROCmKernels) FusedSwiGLUF32

func (k *ROCmKernels) FusedSwiGLUF32(_, _, _ unsafe.Pointer, _ int, _ Stream) error

func (*ROCmKernels) Gather

func (k *ROCmKernels) Gather(table, indices, output unsafe.Pointer, N, D, V int, _ Stream) error

func (*ROCmKernels) GemmQ4F32

func (k *ROCmKernels) GemmQ4F32(aQ4, b, c unsafe.Pointer, m, kk, n, dataOffset int, s Stream) error

func (*ROCmKernels) GemmQ8F32

func (k *ROCmKernels) GemmQ8F32(aQ8, b, c unsafe.Pointer, m, kk, n int, _ Stream) error

func (*ROCmKernels) GemvQ4KF32

func (k *ROCmKernels) GemvQ4KF32(wQ4K, x, y unsafe.Pointer, M, K int, _ Stream) error

func (*ROCmKernels) IncrementCounter

func (k *ROCmKernels) IncrementCounter(_ unsafe.Pointer, _ int, _ Stream) error

func (*ROCmKernels) Log

func (k *ROCmKernels) Log(a, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) Mul

func (k *ROCmKernels) Mul(a, b, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) MulBroadcast

func (k *ROCmKernels) MulBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, _ Stream) error

func (*ROCmKernels) MulBroadcast4D

func (k *ROCmKernels) MulBroadcast4D(_, _, _ unsafe.Pointer, _, _, _, _, _, _, _, _, _, _, _, _ int, _ Stream) error

func (*ROCmKernels) MulFP16

func (k *ROCmKernels) MulFP16(_, _, _ unsafe.Pointer, _ int, _ Stream) error

func (*ROCmKernels) MulScalar

func (k *ROCmKernels) MulScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) OffsetMemcpy

func (k *ROCmKernels) OffsetMemcpy(_, _, _ unsafe.Pointer, _, _ int, _ Stream) error

func (*ROCmKernels) OffsetMemcpyFP16

func (k *ROCmKernels) OffsetMemcpyFP16(_, _, _ unsafe.Pointer, _, _ int, _ Stream) error

func (*ROCmKernels) Pow

func (k *ROCmKernels) Pow(base, exp, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) PowScalar

func (k *ROCmKernels) PowScalar(_ unsafe.Pointer, _ float32, _ unsafe.Pointer, _ int, _ Stream) error

func (*ROCmKernels) RMSNorm

func (k *ROCmKernels) RMSNorm(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, _ Stream) error

func (*ROCmKernels) RMSNormFP16

func (k *ROCmKernels) RMSNormFP16(_, _, _ unsafe.Pointer, _ float32, _, _ int, _ Stream) error

func (*ROCmKernels) Repeat

func (k *ROCmKernels) Repeat(_ unsafe.Pointer, _ unsafe.Pointer, _, _, _, _ int, _ Stream) error

func (*ROCmKernels) ResetCounter

func (k *ROCmKernels) ResetCounter(_ unsafe.Pointer, _ int, _ Stream) error

func (*ROCmKernels) RoPESelect

func (k *ROCmKernels) RoPESelect(_, _, _, _, _ unsafe.Pointer, _ int, _ Stream) error

func (*ROCmKernels) Rsqrt

func (k *ROCmKernels) Rsqrt(a, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) ScaledSoftmaxF32

func (k *ROCmKernels) ScaledSoftmaxF32(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ Stream) error

func (*ROCmKernels) ScaledSoftmaxFP16

func (k *ROCmKernels) ScaledSoftmaxFP16(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ Stream) error

func (*ROCmKernels) SgemvM1

func (k *ROCmKernels) SgemvM1(_, _, _ unsafe.Pointer, _, _ int, _ Stream) error

func (*ROCmKernels) Sin

func (k *ROCmKernels) Sin(a, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) Softmax

func (k *ROCmKernels) Softmax(input, output unsafe.Pointer, outer, inner, axisSize int, s Stream) error

func (*ROCmKernels) Sqrt

func (k *ROCmKernels) Sqrt(a, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) Sub

func (k *ROCmKernels) Sub(a, b, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) SubBroadcast

func (k *ROCmKernels) SubBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, _ Stream) error

func (*ROCmKernels) SubBroadcast4D

func (k *ROCmKernels) SubBroadcast4D(_, _, _ unsafe.Pointer, _, _, _, _, _, _, _, _, _, _, _, _ int, _ Stream) error

func (*ROCmKernels) SubFP16

func (k *ROCmKernels) SubFP16(_, _, _ unsafe.Pointer, _ int, _ Stream) error

func (*ROCmKernels) SubScalar

func (k *ROCmKernels) SubScalar(_ unsafe.Pointer, _ float32, _ unsafe.Pointer, _ int, _ Stream) error

func (*ROCmKernels) SumAxis

func (k *ROCmKernels) SumAxis(input, output unsafe.Pointer, outer, inner, axisSize int, s Stream) error

func (*ROCmKernels) Tanh

func (k *ROCmKernels) Tanh(a, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) TanhPrime

func (k *ROCmKernels) TanhPrime(a, upstream, c unsafe.Pointer, n int, s Stream) error

func (*ROCmKernels) Transpose2D

func (k *ROCmKernels) Transpose2D(input, output unsafe.Pointer, rows, cols int, _ Stream) error

func (*ROCmKernels) TransposeND

func (k *ROCmKernels) TransposeND(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ndim, total int, _ Stream) error

type ROCmMemPool

type ROCmMemPool struct {
	// contains filtered or unexported fields
}

ROCmMemPool implements the MemPool interface by wrapping hip.MemPool.

func NewROCmMemPool

func NewROCmMemPool() *ROCmMemPool

NewROCmMemPool creates a new ROCm memory pool adapter.

func NewROCmMemPoolFrom

func NewROCmMemPoolFrom(pool *hip.MemPool) *ROCmMemPool

NewROCmMemPoolFrom wraps an existing hip.MemPool.

func (*ROCmMemPool) Alloc

func (p *ROCmMemPool) Alloc(deviceID, byteSize int) (unsafe.Pointer, error)

func (*ROCmMemPool) AllocManaged

func (p *ROCmMemPool) AllocManaged(_, _ int) (unsafe.Pointer, error)

func (*ROCmMemPool) Drain

func (p *ROCmMemPool) Drain() error

func (*ROCmMemPool) Free

func (p *ROCmMemPool) Free(deviceID int, ptr unsafe.Pointer, byteSize int)

func (*ROCmMemPool) FreeManaged

func (p *ROCmMemPool) FreeManaged(_ int, _ unsafe.Pointer, _ int)

func (*ROCmMemPool) Inner

func (p *ROCmMemPool) Inner() *hip.MemPool

Inner returns the underlying hip.MemPool for backward compatibility.

func (*ROCmMemPool) Stats

func (p *ROCmMemPool) Stats() (int, int)

type ROCmRuntime

// ROCmRuntime implements the Runtime interface using the AMD HIP runtime API.
type ROCmRuntime struct{}

ROCmRuntime implements the Runtime interface using the AMD HIP runtime API.

func NewROCmRuntime

func NewROCmRuntime() *ROCmRuntime

NewROCmRuntime returns a new ROCm runtime adapter.

func (*ROCmRuntime) CreateStream

func (r *ROCmRuntime) CreateStream() (Stream, error)

func (*ROCmRuntime) DeviceType

func (r *ROCmRuntime) DeviceType() device.Type

func (*ROCmRuntime) Free

func (r *ROCmRuntime) Free(ptr unsafe.Pointer) error

func (*ROCmRuntime) GetDeviceCount

func (r *ROCmRuntime) GetDeviceCount() (int, error)

func (*ROCmRuntime) Malloc

func (r *ROCmRuntime) Malloc(byteSize int) (unsafe.Pointer, error)

func (*ROCmRuntime) Memcpy

func (r *ROCmRuntime) Memcpy(dst, src unsafe.Pointer, count int, kind MemcpyKind) error

func (*ROCmRuntime) MemcpyAsync

func (r *ROCmRuntime) MemcpyAsync(dst, src unsafe.Pointer, count int, kind MemcpyKind, stream Stream) error

func (*ROCmRuntime) MemcpyPeer

func (r *ROCmRuntime) MemcpyPeer(dst unsafe.Pointer, dstDevice int, src unsafe.Pointer, srcDevice int, count int) error

func (*ROCmRuntime) SetDevice

func (r *ROCmRuntime) SetDevice(deviceID int) error

type Runtime

type Runtime interface {
	// DeviceType returns the device type this runtime manages.
	DeviceType() device.Type

	// SetDevice sets the active GPU device for the calling goroutine.
	SetDevice(deviceID int) error
	// GetDeviceCount returns the number of available GPU devices.
	GetDeviceCount() (int, error)

	// Malloc allocates byteSize bytes of device memory.
	Malloc(byteSize int) (unsafe.Pointer, error)
	// Free releases device memory previously allocated by Malloc.
	Free(ptr unsafe.Pointer) error
	// Memcpy copies count bytes between host and device memory.
	Memcpy(dst, src unsafe.Pointer, count int, kind MemcpyKind) error
	// MemcpyAsync copies count bytes asynchronously on the given stream.
	MemcpyAsync(dst, src unsafe.Pointer, count int, kind MemcpyKind, stream Stream) error
	// MemcpyPeer copies count bytes between devices (peer-to-peer).
	MemcpyPeer(dst unsafe.Pointer, dstDevice int, src unsafe.Pointer, srcDevice int, count int) error

	// CreateStream creates a new asynchronous command stream.
	CreateStream() (Stream, error)
}

Runtime abstracts GPU device and memory management operations. Each vendor (CUDA, ROCm, OpenCL) provides an implementation.

type Stream

type Stream interface {
	// Synchronize blocks until all commands in the stream have completed.
	Synchronize() error
	// Destroy releases the stream resources.
	Destroy() error
	// Ptr returns the underlying vendor stream handle as an unsafe.Pointer.
	// For CUDA this is cudaStream_t, for ROCm this is hipStream_t.
	Ptr() unsafe.Pointer
}

Stream represents an asynchronous command queue on a GPU device.

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL