Documentation
¶
Overview ¶
Package gpuapi defines internal interfaces for GPU runtime operations.
The GPU Runtime Abstraction Layer (GRAL) decouples GPUEngine and GPUStorage from vendor-specific APIs (CUDA, ROCm, OpenCL). Each vendor implements the Runtime, BLAS, DNN, and KernelRunner interfaces via adapter packages.
These interfaces are internal and not exported to users. The public Engine[T] interface in compute/ is unchanged.
Index ¶
- Variables
- func PrintCUBLASProfile()
- type ActivationMode
- type BLAS
- type BLASBatched
- type BLASBatchedTransposeB
- type BLASTransposeB
- type BatchNormMode
- type CUDAArenaPool
- func (p *CUDAArenaPool) Alloc(deviceID, byteSize int) (unsafe.Pointer, error)
- func (p *CUDAArenaPool) AllocManaged(deviceID, byteSize int) (unsafe.Pointer, error)
- func (p *CUDAArenaPool) Drain() error
- func (p *CUDAArenaPool) Free(deviceID int, ptr unsafe.Pointer, byteSize int)
- func (p *CUDAArenaPool) FreeManaged(deviceID int, ptr unsafe.Pointer, byteSize int)
- func (p *CUDAArenaPool) Inner() *cuda.ArenaPool
- func (p *CUDAArenaPool) Reset()
- func (p *CUDAArenaPool) SetResetFloor(floor int)
- func (p *CUDAArenaPool) Stats() (int, int)
- func (p *CUDAArenaPool) UsedBytes() int
- type CUDABlas
- func (b *CUDABlas) BFloat16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error
- func (b *CUDABlas) Destroy() error
- func (b *CUDABlas) Float16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error
- func (b *CUDABlas) Handle() *cublas.Handle
- func (b *CUDABlas) MixedBF16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error
- func (b *CUDABlas) MixedFP16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error
- func (b *CUDABlas) SetStream(stream Stream) error
- func (b *CUDABlas) Sgemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error
- func (b *CUDABlas) SgemmNT(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error
- func (b *CUDABlas) SgemmNTStridedBatched(m, n, k int, alpha float32, a unsafe.Pointer, strideA int64, ...) error
- func (b *CUDABlas) SgemmStridedBatched(m, n, k int, alpha float32, a unsafe.Pointer, strideA int64, ...) error
- type CUDABlasProfiler
- func (p *CUDABlasProfiler) BFloat16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error
- func (p *CUDABlasProfiler) Destroy() error
- func (p *CUDABlasProfiler) Float16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error
- func (p *CUDABlasProfiler) Handle() *CUDABlas
- func (p *CUDABlasProfiler) IsEnabled() bool
- func (p *CUDABlasProfiler) MixedBF16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error
- func (p *CUDABlasProfiler) MixedFP16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error
- func (p *CUDABlasProfiler) PrintSummary()
- func (p *CUDABlasProfiler) ResetProfile()
- func (p *CUDABlasProfiler) SetStream(stream Stream) error
- func (p *CUDABlasProfiler) Sgemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error
- func (p *CUDABlasProfiler) SgemmNT(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error
- func (p *CUDABlasProfiler) SgemmNTStridedBatched(m, n, k int, alpha float32, a unsafe.Pointer, strideA int64, ...) error
- func (p *CUDABlasProfiler) SgemmStridedBatched(m, n, k int, alpha float32, a unsafe.Pointer, strideA int64, ...) error
- func (p *CUDABlasProfiler) Summary() ProfileSummary
- type CUDADNN
- func (d *CUDADNN) ActivationBackward(mode ActivationMode, y unsafe.Pointer, dy unsafe.Pointer, x unsafe.Pointer, ...) error
- func (d *CUDADNN) ActivationForward(mode ActivationMode, x unsafe.Pointer, shape [4]int, y unsafe.Pointer, ...) error
- func (d *CUDADNN) AddTensor(alpha float32, b unsafe.Pointer, bShape [4]int, beta float32, y unsafe.Pointer, ...) error
- func (d *CUDADNN) BatchNormBackward(x unsafe.Pointer, xShape [4]int, dy unsafe.Pointer, scale unsafe.Pointer, ...) error
- func (d *CUDADNN) BatchNormForwardInference(x unsafe.Pointer, xShape [4]int, scale, bias, mean, variance unsafe.Pointer, ...) error
- func (d *CUDADNN) BatchNormForwardTraining(x unsafe.Pointer, xShape [4]int, scale, bias unsafe.Pointer, channels int, ...) error
- func (d *CUDADNN) ConvBackwardData(w unsafe.Pointer, wShape [4]int, dy unsafe.Pointer, dyShape [4]int, ...) error
- func (d *CUDADNN) ConvBackwardFilter(x unsafe.Pointer, xShape [4]int, dy unsafe.Pointer, dyShape [4]int, ...) error
- func (d *CUDADNN) ConvForward(x unsafe.Pointer, xShape [4]int, w unsafe.Pointer, wShape [4]int, ...) error
- func (d *CUDADNN) Destroy() error
- func (d *CUDADNN) Handle() *cudnn.Handle
- func (d *CUDADNN) PoolingBackward(mode PoolingMode, y unsafe.Pointer, dy unsafe.Pointer, yShape [4]int, ...) error
- func (d *CUDADNN) PoolingForward(mode PoolingMode, x unsafe.Pointer, xShape [4]int, y unsafe.Pointer, ...) error
- func (d *CUDADNN) SetStream(stream Stream) error
- func (d *CUDADNN) SoftmaxForward(x unsafe.Pointer, shape [4]int, y unsafe.Pointer, stream Stream) error
- type CUDAKernels
- func (k *CUDAKernels) Add(a, b, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) AddBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s Stream) error
- func (k *CUDAKernels) AddBroadcast4D(a, b, c unsafe.Pointer, ...) error
- func (k *CUDAKernels) AddFP16(a, b, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) AddScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) Argmax(input, result, scratch unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) Cos(a, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) DequantFP8E4M3ToFP16(input, output unsafe.Pointer, scale float32, n int, s Stream) error
- func (k *CUDAKernels) DequantQ4KF32(src, dst unsafe.Pointer, rows, K int, s Stream) error
- func (k *CUDAKernels) Div(a, b, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) DivBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s Stream) error
- func (k *CUDAKernels) DivBroadcast4D(a, b, c unsafe.Pointer, ...) error
- func (k *CUDAKernels) DivFP16(a, b, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) DivScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) Exp(a, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) F32ToFP16(src, dst unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) FP16ToF32(src, dst unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) Fill(data unsafe.Pointer, value float32, n int, s Stream) error
- func (k *CUDAKernels) FusedAddRMSNormF32(input, residual, weight, normedOut, sumOut unsafe.Pointer, eps float32, ...) error
- func (k *CUDAKernels) FusedNormAddF32(input, weight, residual, output unsafe.Pointer, eps float32, rows, D int, ...) error
- func (k *CUDAKernels) FusedQKNormRoPEF32(input, weightQ, weightK, cosAngles, sinAngles, output unsafe.Pointer, ...) error
- func (k *CUDAKernels) FusedRoPEF32(input, cosAngles, sinAngles, output unsafe.Pointer, ...) error
- func (k *CUDAKernels) FusedSwiGLUF32(w1, w3, output unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) Gather(table, indices, output unsafe.Pointer, N, D, V int, s Stream) error
- func (k *CUDAKernels) GemmQ4F32(aQ4, b, c unsafe.Pointer, m, kk, n, dataOffset int, s Stream) error
- func (k *CUDAKernels) GemmQ8F32(aQ8, b, c unsafe.Pointer, m, kk, n int, s Stream) error
- func (k *CUDAKernels) GemvQ4KF32(wQ4K, x, y unsafe.Pointer, M, K int, s Stream) error
- func (k *CUDAKernels) IncrementCounter(counter unsafe.Pointer, delta int, s Stream) error
- func (k *CUDAKernels) Log(a, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) Mul(a, b, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) MulBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s Stream) error
- func (k *CUDAKernels) MulBroadcast4D(a, b, c unsafe.Pointer, ...) error
- func (k *CUDAKernels) MulFP16(a, b, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) MulScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) OffsetMemcpy(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s Stream) error
- func (k *CUDAKernels) OffsetMemcpyFP16(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s Stream) error
- func (k *CUDAKernels) Pow(base, exp, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) PowScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) RMSNorm(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, ...) error
- func (k *CUDAKernels) RMSNormFP16(input, weight, output unsafe.Pointer, eps float32, rows, D int, s Stream) error
- func (k *CUDAKernels) Repeat(src, dst unsafe.Pointer, outerSize, axisDim, innerSize, reps int, s Stream) error
- func (k *CUDAKernels) ResetCounter(counter unsafe.Pointer, value int, s Stream) error
- func (k *CUDAKernels) RoPESelect(cosTable, sinTable, cosOut, sinOut, counter unsafe.Pointer, halfRotary int, ...) error
- func (k *CUDAKernels) Rsqrt(a, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) ScaledSoftmaxF32(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, ...) error
- func (k *CUDAKernels) ScaledSoftmaxFP16(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, ...) error
- func (k *CUDAKernels) SgemvM1(y, A, x unsafe.Pointer, M, N int, s Stream) error
- func (k *CUDAKernels) Sin(a, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) Softmax(input, output unsafe.Pointer, outer, inner, axisSize int, s Stream) error
- func (k *CUDAKernels) Sqrt(a, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) Sub(a, b, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) SubBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s Stream) error
- func (k *CUDAKernels) SubBroadcast4D(a, b, c unsafe.Pointer, ...) error
- func (k *CUDAKernels) SubFP16(a, b, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) SubScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) SumAxis(input, output unsafe.Pointer, outer, inner, axisSize int, s Stream) error
- func (k *CUDAKernels) Tanh(a, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) TanhPrime(a, upstream, c unsafe.Pointer, n int, s Stream) error
- func (k *CUDAKernels) Transpose2D(input, output unsafe.Pointer, rows, cols int, s Stream) error
- func (k *CUDAKernels) TransposeND(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ...) error
- type CUDAMemPool
- func (p *CUDAMemPool) Alloc(deviceID, byteSize int) (unsafe.Pointer, error)
- func (p *CUDAMemPool) AllocManaged(deviceID, byteSize int) (unsafe.Pointer, error)
- func (p *CUDAMemPool) Drain() error
- func (p *CUDAMemPool) Free(deviceID int, ptr unsafe.Pointer, byteSize int)
- func (p *CUDAMemPool) FreeManaged(deviceID int, ptr unsafe.Pointer, byteSize int)
- func (p *CUDAMemPool) Inner() *cuda.MemPool
- func (p *CUDAMemPool) Stats() (int, int)
- type CUDARuntime
- func (r *CUDARuntime) CreateStream() (Stream, error)
- func (r *CUDARuntime) DeviceType() device.Type
- func (r *CUDARuntime) Free(ptr unsafe.Pointer) error
- func (r *CUDARuntime) GetDeviceCount() (int, error)
- func (r *CUDARuntime) Malloc(byteSize int) (unsafe.Pointer, error)
- func (r *CUDARuntime) Memcpy(dst, src unsafe.Pointer, count int, kind MemcpyKind) error
- func (r *CUDARuntime) MemcpyAsync(dst, src unsafe.Pointer, count int, kind MemcpyKind, stream Stream) error
- func (r *CUDARuntime) MemcpyPeer(dst unsafe.Pointer, dstDevice int, src unsafe.Pointer, srcDevice int, ...) error
- func (r *CUDARuntime) SetDevice(deviceID int) error
- type DNN
- type KernelRunner
- type MemPool
- type MemcpyKind
- type OpSummary
- type OpenCLDNN
- func (d *OpenCLDNN) ActivationBackward(_ ActivationMode, _ unsafe.Pointer, _ unsafe.Pointer, _ unsafe.Pointer, ...) error
- func (d *OpenCLDNN) ActivationForward(_ ActivationMode, _ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, _ Stream) error
- func (d *OpenCLDNN) AddTensor(_ float32, _ unsafe.Pointer, _ [4]int, _ float32, _ unsafe.Pointer, _ [4]int, ...) error
- func (d *OpenCLDNN) BatchNormBackward(_ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, _ unsafe.Pointer, _ int, ...) error
- func (d *OpenCLDNN) BatchNormForwardInference(_ unsafe.Pointer, _ [4]int, _, _, _, _ unsafe.Pointer, _ int, _ float64, ...) error
- func (d *OpenCLDNN) BatchNormForwardTraining(_ unsafe.Pointer, _ [4]int, _, _ unsafe.Pointer, _ int, _, _ float64, ...) error
- func (d *OpenCLDNN) ConvBackwardData(_ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, ...) error
- func (d *OpenCLDNN) ConvBackwardFilter(_ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, ...) error
- func (d *OpenCLDNN) ConvForward(_ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, ...) error
- func (d *OpenCLDNN) Destroy() error
- func (d *OpenCLDNN) PoolingBackward(_ PoolingMode, _ unsafe.Pointer, _ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, ...) error
- func (d *OpenCLDNN) PoolingForward(_ PoolingMode, _ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, _ [4]int, ...) error
- func (d *OpenCLDNN) SetStream(_ Stream) error
- func (d *OpenCLDNN) SoftmaxForward(_ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, _ Stream) error
- type OpenCLMemPool
- func (p *OpenCLMemPool) Alloc(_ int, byteSize int) (unsafe.Pointer, error)
- func (p *OpenCLMemPool) AllocManaged(_, _ int) (unsafe.Pointer, error)
- func (p *OpenCLMemPool) Drain() error
- func (p *OpenCLMemPool) Free(_ int, ptr unsafe.Pointer, byteSize int)
- func (p *OpenCLMemPool) FreeManaged(_ int, _ unsafe.Pointer, _ int)
- func (p *OpenCLMemPool) Stats() (int, int)
- type OpenCLRuntime
- func (r *OpenCLRuntime) CLContext() unsafe.Pointer
- func (r *OpenCLRuntime) CLDevice() unsafe.Pointer
- func (r *OpenCLRuntime) CLQueue() unsafe.Pointer
- func (r *OpenCLRuntime) CreateStream() (Stream, error)
- func (r *OpenCLRuntime) DeviceType() device.Type
- func (r *OpenCLRuntime) Free(ptr unsafe.Pointer) error
- func (r *OpenCLRuntime) GetDeviceCount() (int, error)
- func (r *OpenCLRuntime) Malloc(byteSize int) (unsafe.Pointer, error)
- func (r *OpenCLRuntime) Memcpy(dst, src unsafe.Pointer, count int, kind MemcpyKind) error
- func (r *OpenCLRuntime) MemcpyAsync(dst, src unsafe.Pointer, count int, kind MemcpyKind, _ Stream) error
- func (r *OpenCLRuntime) MemcpyPeer(dst unsafe.Pointer, _ int, src unsafe.Pointer, _ int, count int) error
- func (r *OpenCLRuntime) SetDevice(deviceID int) error
- type PoolingMode
- type ProfileSummary
- type ROCmBlas
- func (b *ROCmBlas) BFloat16Gemm(_, _, _ int, _ float32, _, _ unsafe.Pointer, _ float32, _ unsafe.Pointer) error
- func (b *ROCmBlas) Destroy() error
- func (b *ROCmBlas) Float16Gemm(_, _, _ int, _ float32, _, _ unsafe.Pointer, _ float32, _ unsafe.Pointer) error
- func (b *ROCmBlas) Handle() *rocblas.Handle
- func (b *ROCmBlas) MixedBF16Gemm(_, _, _ int, _ float32, _, _ unsafe.Pointer, _ float32, _ unsafe.Pointer) error
- func (b *ROCmBlas) MixedFP16Gemm(_, _, _ int, _ float32, _, _ unsafe.Pointer, _ float32, _ unsafe.Pointer) error
- func (b *ROCmBlas) SetStream(stream Stream) error
- func (b *ROCmBlas) Sgemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error
- type ROCmDNN
- func (d *ROCmDNN) ActivationBackward(mode ActivationMode, y unsafe.Pointer, dy unsafe.Pointer, x unsafe.Pointer, ...) error
- func (d *ROCmDNN) ActivationForward(mode ActivationMode, x unsafe.Pointer, shape [4]int, y unsafe.Pointer, ...) error
- func (d *ROCmDNN) AddTensor(alpha float32, b unsafe.Pointer, bShape [4]int, beta float32, y unsafe.Pointer, ...) error
- func (d *ROCmDNN) BatchNormBackward(x unsafe.Pointer, xShape [4]int, dy unsafe.Pointer, scale unsafe.Pointer, ...) error
- func (d *ROCmDNN) BatchNormForwardInference(x unsafe.Pointer, xShape [4]int, scale, bias, mean, variance unsafe.Pointer, ...) error
- func (d *ROCmDNN) BatchNormForwardTraining(x unsafe.Pointer, xShape [4]int, scale, bias unsafe.Pointer, channels int, ...) error
- func (d *ROCmDNN) ConvBackwardData(w unsafe.Pointer, wShape [4]int, dy unsafe.Pointer, dyShape [4]int, ...) error
- func (d *ROCmDNN) ConvBackwardFilter(x unsafe.Pointer, xShape [4]int, dy unsafe.Pointer, dyShape [4]int, ...) error
- func (d *ROCmDNN) ConvForward(x unsafe.Pointer, xShape [4]int, w unsafe.Pointer, wShape [4]int, ...) error
- func (d *ROCmDNN) Destroy() error
- func (d *ROCmDNN) Handle() *miopen.Handle
- func (d *ROCmDNN) PoolingBackward(mode PoolingMode, y unsafe.Pointer, dy unsafe.Pointer, yShape [4]int, ...) error
- func (d *ROCmDNN) PoolingForward(mode PoolingMode, x unsafe.Pointer, xShape [4]int, y unsafe.Pointer, ...) error
- func (d *ROCmDNN) SetStream(stream Stream) error
- func (d *ROCmDNN) SoftmaxForward(x unsafe.Pointer, shape [4]int, y unsafe.Pointer, stream Stream) error
- type ROCmKernels
- func (k *ROCmKernels) Add(a, b, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) AddBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, _ Stream) error
- func (k *ROCmKernels) AddBroadcast4D(_, _, _ unsafe.Pointer, _, _, _, _, _, _, _, _, _, _, _, _ int, _ Stream) error
- func (k *ROCmKernels) AddFP16(_, _, _ unsafe.Pointer, _ int, _ Stream) error
- func (k *ROCmKernels) AddScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) Argmax(_, _, _ unsafe.Pointer, _ int, _ Stream) error
- func (k *ROCmKernels) Cos(a, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) DequantFP8E4M3ToFP16(_, _ unsafe.Pointer, _ float32, _ int, _ Stream) error
- func (k *ROCmKernels) DequantQ4KF32(src, dst unsafe.Pointer, rows, K int, _ Stream) error
- func (k *ROCmKernels) Div(a, b, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) DivBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, _ Stream) error
- func (k *ROCmKernels) DivBroadcast4D(_, _, _ unsafe.Pointer, _, _, _, _, _, _, _, _, _, _, _, _ int, _ Stream) error
- func (k *ROCmKernels) DivFP16(_, _, _ unsafe.Pointer, _ int, _ Stream) error
- func (k *ROCmKernels) DivScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) Exp(a, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) F32ToFP16(_, _ unsafe.Pointer, _ int, _ Stream) error
- func (k *ROCmKernels) FP16ToF32(_, _ unsafe.Pointer, _ int, _ Stream) error
- func (k *ROCmKernels) Fill(data unsafe.Pointer, value float32, n int, s Stream) error
- func (k *ROCmKernels) FusedAddRMSNormF32(_, _, _, _, _ unsafe.Pointer, _ float32, _, _ int, _ Stream) error
- func (k *ROCmKernels) FusedNormAddF32(_, _, _, _ unsafe.Pointer, _ float32, _, _ int, _ Stream) error
- func (k *ROCmKernels) FusedQKNormRoPEF32(_, _, _, _, _, _ unsafe.Pointer, _ float32, _, _, _, _ int, _ Stream) error
- func (k *ROCmKernels) FusedRoPEF32(_, _, _, _ unsafe.Pointer, _, _, _, _, _ int, _ Stream) error
- func (k *ROCmKernels) FusedSwiGLUF32(_, _, _ unsafe.Pointer, _ int, _ Stream) error
- func (k *ROCmKernels) Gather(table, indices, output unsafe.Pointer, N, D, V int, _ Stream) error
- func (k *ROCmKernels) GemmQ4F32(aQ4, b, c unsafe.Pointer, m, kk, n, dataOffset int, s Stream) error
- func (k *ROCmKernels) GemmQ8F32(aQ8, b, c unsafe.Pointer, m, kk, n int, _ Stream) error
- func (k *ROCmKernels) GemvQ4KF32(wQ4K, x, y unsafe.Pointer, M, K int, _ Stream) error
- func (k *ROCmKernels) IncrementCounter(_ unsafe.Pointer, _ int, _ Stream) error
- func (k *ROCmKernels) Log(a, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) Mul(a, b, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) MulBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, _ Stream) error
- func (k *ROCmKernels) MulBroadcast4D(_, _, _ unsafe.Pointer, _, _, _, _, _, _, _, _, _, _, _, _ int, _ Stream) error
- func (k *ROCmKernels) MulFP16(_, _, _ unsafe.Pointer, _ int, _ Stream) error
- func (k *ROCmKernels) MulScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) OffsetMemcpy(_, _, _ unsafe.Pointer, _, _ int, _ Stream) error
- func (k *ROCmKernels) OffsetMemcpyFP16(_, _, _ unsafe.Pointer, _, _ int, _ Stream) error
- func (k *ROCmKernels) Pow(base, exp, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) PowScalar(_ unsafe.Pointer, _ float32, _ unsafe.Pointer, _ int, _ Stream) error
- func (k *ROCmKernels) RMSNorm(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, ...) error
- func (k *ROCmKernels) RMSNormFP16(_, _, _ unsafe.Pointer, _ float32, _, _ int, _ Stream) error
- func (k *ROCmKernels) Repeat(_ unsafe.Pointer, _ unsafe.Pointer, _, _, _, _ int, _ Stream) error
- func (k *ROCmKernels) ResetCounter(_ unsafe.Pointer, _ int, _ Stream) error
- func (k *ROCmKernels) RoPESelect(_, _, _, _, _ unsafe.Pointer, _ int, _ Stream) error
- func (k *ROCmKernels) Rsqrt(a, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) ScaledSoftmaxF32(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ Stream) error
- func (k *ROCmKernels) ScaledSoftmaxFP16(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ Stream) error
- func (k *ROCmKernels) SgemvM1(_, _, _ unsafe.Pointer, _, _ int, _ Stream) error
- func (k *ROCmKernels) Sin(a, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) Softmax(input, output unsafe.Pointer, outer, inner, axisSize int, s Stream) error
- func (k *ROCmKernels) Sqrt(a, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) Sub(a, b, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) SubBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, _ Stream) error
- func (k *ROCmKernels) SubBroadcast4D(_, _, _ unsafe.Pointer, _, _, _, _, _, _, _, _, _, _, _, _ int, _ Stream) error
- func (k *ROCmKernels) SubFP16(_, _, _ unsafe.Pointer, _ int, _ Stream) error
- func (k *ROCmKernels) SubScalar(_ unsafe.Pointer, _ float32, _ unsafe.Pointer, _ int, _ Stream) error
- func (k *ROCmKernels) SumAxis(input, output unsafe.Pointer, outer, inner, axisSize int, s Stream) error
- func (k *ROCmKernels) Tanh(a, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) TanhPrime(a, upstream, c unsafe.Pointer, n int, s Stream) error
- func (k *ROCmKernels) Transpose2D(input, output unsafe.Pointer, rows, cols int, _ Stream) error
- func (k *ROCmKernels) TransposeND(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ...) error
- type ROCmMemPool
- func (p *ROCmMemPool) Alloc(deviceID, byteSize int) (unsafe.Pointer, error)
- func (p *ROCmMemPool) AllocManaged(_, _ int) (unsafe.Pointer, error)
- func (p *ROCmMemPool) Drain() error
- func (p *ROCmMemPool) Free(deviceID int, ptr unsafe.Pointer, byteSize int)
- func (p *ROCmMemPool) FreeManaged(_ int, _ unsafe.Pointer, _ int)
- func (p *ROCmMemPool) Inner() *hip.MemPool
- func (p *ROCmMemPool) Stats() (int, int)
- type ROCmRuntime
- func (r *ROCmRuntime) CreateStream() (Stream, error)
- func (r *ROCmRuntime) DeviceType() device.Type
- func (r *ROCmRuntime) Free(ptr unsafe.Pointer) error
- func (r *ROCmRuntime) GetDeviceCount() (int, error)
- func (r *ROCmRuntime) Malloc(byteSize int) (unsafe.Pointer, error)
- func (r *ROCmRuntime) Memcpy(dst, src unsafe.Pointer, count int, kind MemcpyKind) error
- func (r *ROCmRuntime) MemcpyAsync(dst, src unsafe.Pointer, count int, kind MemcpyKind, stream Stream) error
- func (r *ROCmRuntime) MemcpyPeer(dst unsafe.Pointer, dstDevice int, src unsafe.Pointer, srcDevice int, ...) error
- func (r *ROCmRuntime) SetDevice(deviceID int) error
- type Runtime
- type Stream
Constants ¶
This section is empty.
Variables ¶
var BLASFactory func() (BLAS, error)
BLASFactory creates a BLAS instance. Registered by cuda_blas_purego.go via init() when the cuBLAS library is available at runtime.
var DNNFactory func() (DNN, error)
DNNFactory creates a DNN instance. Registered by cuda_dnn.go via init().
Functions ¶
func PrintCUBLASProfile ¶
func PrintCUBLASProfile()
PrintCUBLASProfile prints the cuBLAS profiling summary if profiling is enabled. Safe to call even when profiling is disabled (no-op).
Types ¶
type ActivationMode ¶
type ActivationMode int
ActivationMode selects the activation function for DNN operations.
const ( ActivationSigmoid ActivationMode = iota ActivationReLU ActivationTanh ActivationClippedReLU ActivationELU )
type BLAS ¶
type BLAS interface {
// Sgemm performs single-precision general matrix multiplication:
// C = alpha * A * B + beta * C
// where A is m x k, B is k x n, and C is m x n.
// All matrices are contiguous row-major. The implementation handles
// the row-major to column-major conversion internally.
Sgemm(m, n, k int, alpha float32,
a unsafe.Pointer, b unsafe.Pointer,
beta float32, c unsafe.Pointer,
) error
// BFloat16Gemm performs BFloat16 general matrix multiplication:
// C = alpha * A * B + beta * C
// where A is m x k, B is k x n, and C is m x n.
// All matrices are contiguous row-major BFloat16 elements.
// Computation is performed in float32 for precision (CUBLAS_COMPUTE_32F).
// Returns an error on backends that do not support BFloat16 GEMM.
BFloat16Gemm(m, n, k int, alpha float32,
a unsafe.Pointer, b unsafe.Pointer,
beta float32, c unsafe.Pointer,
) error
// Float16Gemm performs FP16 general matrix multiplication:
// C = alpha * A * B + beta * C
// where A is m x k, B is k x n, and C is m x n.
// All matrices are contiguous row-major FP16 elements.
// Computation is performed in float32 for precision (CUBLAS_COMPUTE_32F).
Float16Gemm(m, n, k int, alpha float32,
a unsafe.Pointer, b unsafe.Pointer,
beta float32, c unsafe.Pointer,
) error
// MixedFP16Gemm performs mixed-precision GEMM with FP16 inputs and FP32 output:
// C_f32 = alpha * A_fp16 * B_fp16 + beta * C_f32
MixedFP16Gemm(m, n, k int, alpha float32,
a unsafe.Pointer, b unsafe.Pointer,
beta float32, c unsafe.Pointer,
) error
// MixedBF16Gemm performs mixed-precision GEMM with BF16 inputs and FP32 output:
// C_f32 = alpha * A_bf16 * B_bf16 + beta * C_f32
// where A is m x k, B is k x n (both BFloat16), and C is m x n (float32).
// Computation uses CUBLAS_COMPUTE_32F for precision.
// Returns an error on backends that do not support mixed-precision GEMM.
MixedBF16Gemm(m, n, k int, alpha float32,
a unsafe.Pointer, b unsafe.Pointer,
beta float32, c unsafe.Pointer,
) error
// SetStream associates the BLAS handle with an asynchronous stream.
SetStream(stream Stream) error
// Destroy releases the BLAS handle resources.
Destroy() error
}
BLAS abstracts GPU-accelerated Basic Linear Algebra Subprograms. Each vendor (cuBLAS, rocBLAS, CLBlast) provides an implementation.
type BLASBatched ¶
type BLASBatched interface {
SgemmStridedBatched(m, n, k int, alpha float32,
a unsafe.Pointer, strideA int64,
b unsafe.Pointer, strideB int64,
beta float32,
c unsafe.Pointer, strideC int64,
batch int,
) error
}
BLASBatched is an optional extension that supports strided batched GEMM. All batch elements share the same m, n, k dimensions and alpha/beta scalars. Matrices are accessed at base + i*stride for batch element i.
type BLASBatchedTransposeB ¶
type BLASBatchedTransposeB interface {
SgemmNTStridedBatched(m, n, k int, alpha float32,
a unsafe.Pointer, strideA int64,
b unsafe.Pointer, strideB int64,
beta float32,
c unsafe.Pointer, strideC int64,
batch int,
) error
}
BLASBatchedTransposeB is an optional extension that supports strided batched C = alpha * A * B^T + beta * C without explicitly transposing B. A is [m, k], B is [n, k] per batch.
type BLASTransposeB ¶
type BLASTransposeB interface {
SgemmNT(m, n, k int, alpha float32,
a unsafe.Pointer, b unsafe.Pointer,
beta float32, c unsafe.Pointer,
) error
}
BLASTransposeB is an optional extension that supports computing C = alpha * A * B^T + beta * C without explicitly transposing B. A is m x k (row-major), B is n x k (row-major), C is m x n.
type BatchNormMode ¶
type BatchNormMode int
BatchNormMode selects the batch normalization mode.
const ( BatchNormPerActivation BatchNormMode = iota BatchNormSpatial )
type CUDAArenaPool ¶
type CUDAArenaPool struct {
// contains filtered or unexported fields
}
CUDAArenaPool adapts cuda.ArenaPool to the gpuapi.MemPool interface. It also exposes Reset() for use between forward passes.
func NewCUDAArenaPool ¶
func NewCUDAArenaPool(deviceID, capacityBytes int, fallback *cuda.MemPool) (*CUDAArenaPool, error)
NewCUDAArenaPool creates a new arena-backed pool on the given device. capacityBytes is the size of the pre-allocated arena region. fallback is the MemPool used when the arena is exhausted.
func (*CUDAArenaPool) Alloc ¶
func (p *CUDAArenaPool) Alloc(deviceID, byteSize int) (unsafe.Pointer, error)
func (*CUDAArenaPool) AllocManaged ¶
func (p *CUDAArenaPool) AllocManaged(deviceID, byteSize int) (unsafe.Pointer, error)
func (*CUDAArenaPool) Drain ¶
func (p *CUDAArenaPool) Drain() error
func (*CUDAArenaPool) Free ¶
func (p *CUDAArenaPool) Free(deviceID int, ptr unsafe.Pointer, byteSize int)
func (*CUDAArenaPool) FreeManaged ¶
func (p *CUDAArenaPool) FreeManaged(deviceID int, ptr unsafe.Pointer, byteSize int)
func (*CUDAArenaPool) Inner ¶
func (p *CUDAArenaPool) Inner() *cuda.ArenaPool
Inner returns the underlying cuda.ArenaPool.
func (*CUDAArenaPool) Reset ¶
func (p *CUDAArenaPool) Reset()
Reset rewinds the arena, reclaiming all per-pass allocations.
func (*CUDAArenaPool) SetResetFloor ¶
func (p *CUDAArenaPool) SetResetFloor(floor int)
SetResetFloor sets the minimum offset that Reset will rewind to.
func (*CUDAArenaPool) Stats ¶
func (p *CUDAArenaPool) Stats() (int, int)
func (*CUDAArenaPool) UsedBytes ¶
func (p *CUDAArenaPool) UsedBytes() int
UsedBytes returns the current arena offset (bytes in use).
type CUDABlas ¶
type CUDABlas struct {
// contains filtered or unexported fields
}
CUDABlas implements the BLAS interface using cuBLAS via purego.
func NewCUDABlas ¶
NewCUDABlas creates a new cuBLAS adapter. The caller must call Destroy when done.
func NewCUDABlasFromHandle ¶
NewCUDABlasFromHandle wraps an existing cuBLAS handle. The caller retains ownership; Destroy on this adapter is a no-op.
func (*CUDABlas) BFloat16Gemm ¶
func (*CUDABlas) Float16Gemm ¶
func (*CUDABlas) MixedBF16Gemm ¶
func (*CUDABlas) MixedFP16Gemm ¶
func (*CUDABlas) SgemmNT ¶
func (b *CUDABlas) SgemmNT(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, beta float32, c unsafe.Pointer, ) error
SgemmNT performs C = alpha * A * B^T + beta * C where A is [m, k] and B is [n, k] (row-major). This avoids an explicit Transpose of B.
func (*CUDABlas) SgemmNTStridedBatched ¶
func (b *CUDABlas) SgemmNTStridedBatched(m, n, k int, alpha float32, a unsafe.Pointer, strideA int64, bPtr unsafe.Pointer, strideB int64, beta float32, c unsafe.Pointer, strideC int64, batch int, ) error
SgemmNTStridedBatched performs batched C = alpha * A * B^T + beta * C using strided batched GEMM, where A is [m, k] and B is [n, k] per batch.
func (*CUDABlas) SgemmStridedBatched ¶
func (b *CUDABlas) SgemmStridedBatched(m, n, k int, alpha float32, a unsafe.Pointer, strideA int64, bPtr unsafe.Pointer, strideB int64, beta float32, c unsafe.Pointer, strideC int64, batch int, ) error
SgemmStridedBatched performs batched C = alpha * A * B + beta * C using cublasSgemmStridedBatched. All batch elements share the same m, n, k.
type CUDABlasProfiler ¶
type CUDABlasProfiler struct {
// contains filtered or unexported fields
}
CUDABlasProfiler wraps CUDABlas with optional per-call timing. When ZERFOO_PROFILE_CUBLAS=1, each Sgemm/SgemmNT/batched call is timed and recorded. Call PrintSummary to dump stats.
func WrapWithProfiler ¶
func WrapWithProfiler(b *CUDABlas) *CUDABlasProfiler
WrapWithProfiler returns a CUDABlasProfiler wrapping b. Per-call timing is recorded only when ZERFOO_PROFILE_CUBLAS=1; otherwise the wrapper delegates to b without profiling overhead.
func (*CUDABlasProfiler) BFloat16Gemm ¶
func (*CUDABlasProfiler) Destroy ¶
func (p *CUDABlasProfiler) Destroy() error
func (*CUDABlasProfiler) Float16Gemm ¶
func (*CUDABlasProfiler) Handle ¶
func (p *CUDABlasProfiler) Handle() *CUDABlas
func (*CUDABlasProfiler) IsEnabled ¶
func (p *CUDABlasProfiler) IsEnabled() bool
IsEnabled returns whether profiling is active.
func (*CUDABlasProfiler) MixedBF16Gemm ¶
func (*CUDABlasProfiler) MixedFP16Gemm ¶
func (*CUDABlasProfiler) PrintSummary ¶
func (p *CUDABlasProfiler) PrintSummary()
PrintSummary prints cuBLAS profiling stats to stderr.
func (*CUDABlasProfiler) ResetProfile ¶
func (p *CUDABlasProfiler) ResetProfile()
ResetProfile clears all recorded calls and increments the generation.
func (*CUDABlasProfiler) SetStream ¶
func (p *CUDABlasProfiler) SetStream(stream Stream) error
func (*CUDABlasProfiler) SgemmNTStridedBatched ¶
func (*CUDABlasProfiler) SgemmStridedBatched ¶
func (*CUDABlasProfiler) Summary ¶
func (p *CUDABlasProfiler) Summary() ProfileSummary
Summary returns aggregated profiling statistics.
type CUDADNN ¶
type CUDADNN struct {
// contains filtered or unexported fields
}
CUDADNN implements the DNN interface using cuDNN.
func NewCUDADNNFromHandle ¶
NewCUDADNNFromHandle wraps an existing cuDNN handle.
func (*CUDADNN) ActivationBackward ¶
func (*CUDADNN) ActivationForward ¶
func (*CUDADNN) BatchNormBackward ¶
func (*CUDADNN) BatchNormForwardInference ¶
func (*CUDADNN) BatchNormForwardTraining ¶
func (*CUDADNN) ConvBackwardData ¶
func (*CUDADNN) ConvBackwardFilter ¶
func (*CUDADNN) ConvForward ¶
func (*CUDADNN) PoolingBackward ¶
func (*CUDADNN) PoolingForward ¶
type CUDAKernels ¶
type CUDAKernels struct{}
CUDAKernels implements the KernelRunner interface using custom CUDA kernels.
func NewCUDAKernels ¶
func NewCUDAKernels() *CUDAKernels
NewCUDAKernels returns a new CUDA kernel runner adapter.
func (*CUDAKernels) AddBroadcast ¶
func (*CUDAKernels) AddBroadcast4D ¶
func (*CUDAKernels) DequantFP8E4M3ToFP16 ¶
func (*CUDAKernels) DequantQ4KF32 ¶
func (*CUDAKernels) DivBroadcast ¶
func (*CUDAKernels) DivBroadcast4D ¶
func (*CUDAKernels) FusedAddRMSNormF32 ¶
func (*CUDAKernels) FusedNormAddF32 ¶
func (*CUDAKernels) FusedQKNormRoPEF32 ¶
func (*CUDAKernels) FusedRoPEF32 ¶
func (*CUDAKernels) FusedSwiGLUF32 ¶
func (*CUDAKernels) GemvQ4KF32 ¶
func (*CUDAKernels) IncrementCounter ¶
func (*CUDAKernels) MulBroadcast ¶
func (*CUDAKernels) MulBroadcast4D ¶
func (*CUDAKernels) OffsetMemcpy ¶
func (*CUDAKernels) OffsetMemcpyFP16 ¶
func (*CUDAKernels) RMSNormFP16 ¶
func (*CUDAKernels) ResetCounter ¶
func (*CUDAKernels) RoPESelect ¶
func (*CUDAKernels) ScaledSoftmaxF32 ¶
func (*CUDAKernels) ScaledSoftmaxFP16 ¶
func (*CUDAKernels) SubBroadcast ¶
func (*CUDAKernels) SubBroadcast4D ¶
func (*CUDAKernels) Transpose2D ¶
func (*CUDAKernels) TransposeND ¶
type CUDAMemPool ¶
type CUDAMemPool struct {
// contains filtered or unexported fields
}
CUDAMemPool implements the MemPool interface by wrapping cuda.MemPool.
func NewCUDAMemPool ¶
func NewCUDAMemPool() *CUDAMemPool
NewCUDAMemPool creates a new CUDA memory pool adapter.
func NewCUDAMemPoolFrom ¶
func NewCUDAMemPoolFrom(pool *cuda.MemPool) *CUDAMemPool
NewCUDAMemPoolFrom wraps an existing cuda.MemPool.
func (*CUDAMemPool) Alloc ¶
func (p *CUDAMemPool) Alloc(deviceID, byteSize int) (unsafe.Pointer, error)
func (*CUDAMemPool) AllocManaged ¶
func (p *CUDAMemPool) AllocManaged(deviceID, byteSize int) (unsafe.Pointer, error)
func (*CUDAMemPool) Drain ¶
func (p *CUDAMemPool) Drain() error
func (*CUDAMemPool) Free ¶
func (p *CUDAMemPool) Free(deviceID int, ptr unsafe.Pointer, byteSize int)
func (*CUDAMemPool) FreeManaged ¶
func (p *CUDAMemPool) FreeManaged(deviceID int, ptr unsafe.Pointer, byteSize int)
func (*CUDAMemPool) Inner ¶
func (p *CUDAMemPool) Inner() *cuda.MemPool
Inner returns the underlying cuda.MemPool for backward compatibility.
func (*CUDAMemPool) Stats ¶
func (p *CUDAMemPool) Stats() (int, int)
type CUDARuntime ¶
type CUDARuntime struct{}
CUDARuntime implements the Runtime interface using the CUDA runtime API.
func NewCUDARuntime ¶
func NewCUDARuntime() *CUDARuntime
NewCUDARuntime returns a new CUDA runtime adapter.
func (*CUDARuntime) CreateStream ¶
func (r *CUDARuntime) CreateStream() (Stream, error)
func (*CUDARuntime) DeviceType ¶
func (r *CUDARuntime) DeviceType() device.Type
func (*CUDARuntime) GetDeviceCount ¶
func (r *CUDARuntime) GetDeviceCount() (int, error)
func (*CUDARuntime) Memcpy ¶
func (r *CUDARuntime) Memcpy(dst, src unsafe.Pointer, count int, kind MemcpyKind) error
func (*CUDARuntime) MemcpyAsync ¶
func (r *CUDARuntime) MemcpyAsync(dst, src unsafe.Pointer, count int, kind MemcpyKind, stream Stream) error
func (*CUDARuntime) MemcpyPeer ¶
func (*CUDARuntime) SetDevice ¶
func (r *CUDARuntime) SetDevice(deviceID int) error
type DNN ¶
type DNN interface {
// ConvForward performs 2D convolution.
// x: [N,C_in,H,W], w: [C_out,C_in/groups,kH,kW], y: [N,C_out,outH,outW].
// bias is optional (nil to skip).
// pads: [padH, padW] (symmetric), strides: [sH, sW], dilations: [dH, dW].
ConvForward(
x unsafe.Pointer, xShape [4]int,
w unsafe.Pointer, wShape [4]int,
bias unsafe.Pointer,
y unsafe.Pointer, yShape [4]int,
pads [2]int, strides [2]int, dilations [2]int,
groups int,
stream Stream,
) error
// ConvBackwardData computes the gradient of the input for 2D convolution.
// w: [C_out,C_in/groups,kH,kW], dy: [N,C_out,outH,outW], dx: [N,C_in,H,W].
ConvBackwardData(
w unsafe.Pointer, wShape [4]int,
dy unsafe.Pointer, dyShape [4]int,
dx unsafe.Pointer, dxShape [4]int,
pads [2]int, strides [2]int, dilations [2]int,
groups int,
stream Stream,
) error
// ConvBackwardFilter computes the gradient of the filter for 2D convolution.
// x: [N,C_in,H,W], dy: [N,C_out,outH,outW], dw: [C_out,C_in/groups,kH,kW].
ConvBackwardFilter(
x unsafe.Pointer, xShape [4]int,
dy unsafe.Pointer, dyShape [4]int,
dw unsafe.Pointer, dwShape [4]int,
pads [2]int, strides [2]int, dilations [2]int,
groups int,
stream Stream,
) error
// BatchNormForwardInference performs batch normalization using running statistics.
// x: [N,C,H,W], scale/bias/mean/variance: [C], y: [N,C,H,W].
BatchNormForwardInference(
x unsafe.Pointer, xShape [4]int,
scale, bias, mean, variance unsafe.Pointer,
channels int,
epsilon float64,
y unsafe.Pointer,
stream Stream,
) error
// BatchNormForwardTraining performs batch normalization computing batch statistics.
// x: [N,C,H,W], scale/bias: [C], y: [N,C,H,W].
// saveMean and saveInvVariance are outputs for the backward pass, each [C].
// runningMean and runningVariance are updated in-place with exponential averaging.
BatchNormForwardTraining(
x unsafe.Pointer, xShape [4]int,
scale, bias unsafe.Pointer,
channels int,
epsilon, expAvgFactor float64,
runningMean, runningVariance unsafe.Pointer,
saveMean, saveInvVariance unsafe.Pointer,
y unsafe.Pointer,
stream Stream,
) error
// BatchNormBackward computes gradients for batch normalization.
// x: [N,C,H,W], dy: [N,C,H,W], scale: [C].
// saveMean, saveInvVariance: [C] (from BatchNormForwardTraining).
// dx: [N,C,H,W], dScale, dBias: [C].
BatchNormBackward(
x unsafe.Pointer, xShape [4]int,
dy unsafe.Pointer,
scale unsafe.Pointer,
channels int,
saveMean, saveInvVariance unsafe.Pointer,
dx, dScale, dBias unsafe.Pointer,
stream Stream,
) error
// ActivationForward applies an activation function element-wise.
// x and y have the same shape [N,C,H,W].
ActivationForward(
mode ActivationMode,
x unsafe.Pointer, shape [4]int,
y unsafe.Pointer,
stream Stream,
) error
// ActivationBackward computes the gradient of an activation function.
// x: original input, y: forward output, dy: upstream gradient, dx: output gradient.
// All have shape [N,C,H,W].
ActivationBackward(
mode ActivationMode,
y unsafe.Pointer, dy unsafe.Pointer,
x unsafe.Pointer, dx unsafe.Pointer,
shape [4]int,
stream Stream,
) error
// PoolingForward performs 2D pooling.
// x: [N,C,H,W], y: [N,C,outH,outW].
PoolingForward(
mode PoolingMode,
x unsafe.Pointer, xShape [4]int,
y unsafe.Pointer, yShape [4]int,
windowH, windowW, padH, padW, strideH, strideW int,
stream Stream,
) error
// PoolingBackward computes the gradient of 2D pooling.
// y: forward output, dy: upstream gradient, x: forward input, dx: output gradient.
PoolingBackward(
mode PoolingMode,
y unsafe.Pointer, dy unsafe.Pointer, yShape [4]int,
x unsafe.Pointer, dx unsafe.Pointer, xShape [4]int,
windowH, windowW, padH, padW, strideH, strideW int,
stream Stream,
) error
// SoftmaxForward computes softmax over the channel dimension.
// x and y have the same shape [N,C,H,W].
SoftmaxForward(
x unsafe.Pointer, shape [4]int,
y unsafe.Pointer,
stream Stream,
) error
// AddTensor performs y = alpha*b + beta*y for bias addition.
// b: [1,C,1,1], y: [N,C,H,W].
AddTensor(
alpha float32,
b unsafe.Pointer, bShape [4]int,
beta float32,
y unsafe.Pointer, yShape [4]int,
stream Stream,
) error
// SetStream associates the DNN handle with an asynchronous stream.
SetStream(stream Stream) error
// Destroy releases the DNN handle resources.
Destroy() error
}
DNN abstracts GPU-accelerated deep neural network primitives. Each vendor (cuDNN, MIOpen) provides an implementation.
All tensor pointers are device memory. Shapes follow NCHW layout. The implementation handles descriptor creation internally.
type KernelRunner ¶
type KernelRunner interface {
// Binary elementwise operations: c[i] = op(a[i], b[i])
Add(a, b, c unsafe.Pointer, n int, stream Stream) error
Sub(a, b, c unsafe.Pointer, n int, stream Stream) error
Mul(a, b, c unsafe.Pointer, n int, stream Stream) error
Div(a, b, c unsafe.Pointer, n int, stream Stream) error
Pow(base, exp, c unsafe.Pointer, n int, stream Stream) error
// Unary elementwise operations: c[i] = op(a[i])
Exp(a, c unsafe.Pointer, n int, stream Stream) error
Log(a, c unsafe.Pointer, n int, stream Stream) error
Sqrt(a, c unsafe.Pointer, n int, stream Stream) error
Rsqrt(a, c unsafe.Pointer, n int, stream Stream) error
Sin(a, c unsafe.Pointer, n int, stream Stream) error
Cos(a, c unsafe.Pointer, n int, stream Stream) error
Tanh(a, c unsafe.Pointer, n int, stream Stream) error
// TanhPrime: c[i] = (1 - tanh(a[i])^2) * upstream[i]
TanhPrime(a, upstream, c unsafe.Pointer, n int, stream Stream) error
// Scalar operations: c[i] = op(a[i], scalar)
AddScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error
MulScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error
DivScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error
SubScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error
PowScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error
// Fill sets all n elements to value.
Fill(data unsafe.Pointer, value float32, n int, stream Stream) error
// SumAxis reduces along one axis: output[outer][inner] = sum(input[outer][k][inner], k=0..axisSize-1).
SumAxis(input, output unsafe.Pointer, outer, inner, axisSize int, stream Stream) error
// Softmax computes softmax along one axis.
Softmax(input, output unsafe.Pointer, outer, inner, axisSize int, stream Stream) error
// GemmQ4F32 performs Q4_0 dequant-GEMM: C = dequant(A_q4) * B.
// A_q4 is in GPU separated layout (scales then data), B is [K,N] float32, C is [M,N] float32.
// dataOffset is the byte offset from A_q4 to the packed data region.
GemmQ4F32(aQ4, b, c unsafe.Pointer, m, k, n, dataOffset int, stream Stream) error
// GemvQ4KF32 performs Q4_K fused dequant-GEMV: y = dequant(W_q4k) * x.
// W_q4k is raw Q4_K super-blocks for matrix [M, K]. x is [K] float32.
// y is [M] float32. K must be a multiple of 256. Batch=1 only.
GemvQ4KF32(wQ4K, x, y unsafe.Pointer, M, K int, stream Stream) error
// DequantQ4KF32 dequantizes Q4_K super-blocks to FP32 in global memory.
// src is raw Q4_K super-blocks for matrix [rows, K]. dst is [rows, K] float32.
// K must be a multiple of 256. Used for non-GEMV cuBLAS path.
DequantQ4KF32(src, dst unsafe.Pointer, rows, K int, stream Stream) error
// GemmQ8F32 performs Q8_0 dequant-GEMM: C = dequant(A_q8) * B.
// A_q8 is packed Q8_0 blocks (36 bytes per 32 values), B is [K,N] float32, C is [M,N] float32.
GemmQ8F32(aQ8, b, c unsafe.Pointer, m, k, n int, stream Stream) error
// Broadcast binary ops: c[r,j] = op(a[r*saRow+j*saCol], b[r*sbRow+j*sbCol]).
// Strides encode broadcasting: D for a full row stride, 1 for a full column stride, 0 to broadcast along that dimension.
AddBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, stream Stream) error
SubBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, stream Stream) error
MulBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, stream Stream) error
DivBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, stream Stream) error
// 4D broadcast binary ops: c[i0,i1,i2,i3] = op(a[...], b[...]) with per-dim strides.
// d0-d3 are output dims; sa0-sa3 and sb0-sb3 are per-dim strides (0 = broadcast).
AddBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, stream Stream) error
SubBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, stream Stream) error
MulBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, stream Stream) error
DivBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, stream Stream) error
// Transpose2D transposes a [rows, cols] matrix to [cols, rows] using tiled shared memory.
Transpose2D(input, output unsafe.Pointer, rows, cols int, stream Stream) error
// TransposeND permutes dimensions of an N-D tensor.
// inStrides/outStrides/perm are int32 slices on host.
TransposeND(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ndim, total int, stream Stream) error
// Gather performs embedding table lookup: output[i,:] = table[indices[i],:].
// table: [V, D], indices: [N] int64 on device, output: [N, D].
Gather(table, indices, output unsafe.Pointer, N, D, V int, stream Stream) error
// RMSNorm computes fused RMSNorm: output = input * rsqrt(mean(input^2) + eps) * weight.
// input: [rows, D], weight: [D], output: [rows, D], scales: [rows] (per-row rsqrt values for backward).
RMSNorm(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, stream Stream) error
// Repeat replicates elements along an axis.
// outerSize = product of dims before axis, axisDim = size of axis,
// innerSize = product of dims after axis, reps = number of repetitions.
Repeat(src, dst unsafe.Pointer, outerSize, axisDim, innerSize, reps int, stream Stream) error
// Argmax finds the index of the maximum element in a float32 array on device.
// input: [n] float32, result: single int32 on device, scratch: temp storage.
// scratch must be at least 2*ceil(n/256)*4 bytes.
Argmax(input, result, scratch unsafe.Pointer, n int, stream Stream) error
// FusedRoPEF32 applies rotary positional embedding in one kernel launch.
// input/output: [batch * seqLen * headDim], cos/sin: [seqLen * cosStride].
FusedRoPEF32(input, cosAngles, sinAngles, output unsafe.Pointer, batch, seqLen, headDim, halfRotary, cosStride int, stream Stream) error
// FusedSwiGLUF32 applies SwiGLU activation in one kernel launch.
// output[i] = w1[i] * sigmoid(w1[i]) * w3[i]. All arrays have n elements.
FusedSwiGLUF32(w1, w3, output unsafe.Pointer, n int, stream Stream) error
// FusedAddRMSNormF32 fuses residual addition and RMSNorm into one kernel launch.
// sum_out = input + residual, normed_out = rmsnorm(sum_out, weight, eps).
// input: [rows, D], residual: [rows, D], weight: [D],
// normedOut: [rows, D], sumOut: [rows, D].
FusedAddRMSNormF32(input, residual, weight, normedOut, sumOut unsafe.Pointer, eps float32, rows, D int, stream Stream) error
// FusedNormAddF32 applies RMSNorm then adds residual in one kernel launch.
// output = rmsnorm(input, weight, eps) + residual.
// input: [rows, D], weight: [D], residual: [rows, D], output: [rows, D].
FusedNormAddF32(input, weight, residual, output unsafe.Pointer, eps float32, rows, D int, stream Stream) error
// FusedQKNormRoPEF32 applies per-head RMSNorm + RoPE to combined Q+K heads.
// Replaces 4 kernel launches (Q_norm + K_norm + Q_RoPE + K_RoPE) with 1.
// input: [totalHeads, headDim], weightQ/weightK: [headDim],
// cosAngles/sinAngles: [halfRotary], output: [totalHeads, headDim].
FusedQKNormRoPEF32(input, weightQ, weightK, cosAngles, sinAngles, output unsafe.Pointer, eps float32, totalHeads, headDim, numQHeads, halfRotary int, stream Stream) error
// ScaledSoftmaxF32 computes softmax(input * scale) in one kernel launch,
// replacing the MulScalar + Softmax chain (saves 1 kernel launch per call).
ScaledSoftmaxF32(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, stream Stream) error
// FP16 elementwise operations: inputs and outputs are __half (2 bytes each).
AddFP16(a, b, c unsafe.Pointer, n int, stream Stream) error
SubFP16(a, b, c unsafe.Pointer, n int, stream Stream) error
MulFP16(a, b, c unsafe.Pointer, n int, stream Stream) error
DivFP16(a, b, c unsafe.Pointer, n int, stream Stream) error
// RMSNormFP16 computes RMSNorm on FP16 data with FP32 accumulation.
RMSNormFP16(input, weight, output unsafe.Pointer, eps float32, rows, D int, stream Stream) error
// ScaledSoftmaxFP16 computes softmax(input * scale) on FP16 data with FP32 accumulation.
ScaledSoftmaxFP16(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, stream Stream) error
// F32ToFP16 converts n float32 elements to FP16 on device.
F32ToFP16(src, dst unsafe.Pointer, n int, stream Stream) error
// FP16ToF32 converts n FP16 elements to float32 on device.
FP16ToF32(src, dst unsafe.Pointer, n int, stream Stream) error
// DequantFP8E4M3ToFP16 dequantizes n FP8 E4M3 bytes to FP16 on device.
// output[i] = fp8_to_fp16(input[i]) * scale.
DequantFP8E4M3ToFP16(input, output unsafe.Pointer, scale float32, n int, stream Stream) error
// GPU-resident counter operations for CUDA graph position tracking.
IncrementCounter(counter unsafe.Pointer, delta int, stream Stream) error
ResetCounter(counter unsafe.Pointer, value int, stream Stream) error
// OffsetMemcpy copies dim floats from src to dst at offset counter*dim.
// counter is a GPU-resident int32. Used for GPU-driven KV cache append.
OffsetMemcpy(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, stream Stream) error
// OffsetMemcpyFP16 copies dim floats from F32 src to FP16 dst at offset counter*dim.
// counter is a GPU-resident int32. Used for GPU-driven FP16 KV cache append.
OffsetMemcpyFP16(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, stream Stream) error
// RoPESelect copies halfRotary cos/sin values from the precomputed table
// at position counter[0]. Used for GPU-driven RoPE angle selection.
RoPESelect(cosTable, sinTable, cosOut, sinOut, counter unsafe.Pointer, halfRotary int, stream Stream) error
// SgemvM1 computes y = A*x for M=1 decode (single-token GEMV).
// y[M], A[M x N] row-major, x[N].
SgemvM1(y, A, x unsafe.Pointer, M, N int, stream Stream) error
}
KernelRunner abstracts GPU compute kernels for elementwise, scalar, reduction, and utility operations. Each vendor provides an implementation using its own kernel compilation toolchain (CUDA .cu, HIP .hip, OpenCL .cl).
type MemPool ¶
type MemPool interface {
// Alloc returns a device pointer of at least byteSize bytes on the given device.
// May return a cached pointer from a previous Free call.
Alloc(deviceID, byteSize int) (unsafe.Pointer, error)
// Free returns a device pointer to the pool for reuse.
Free(deviceID int, ptr unsafe.Pointer, byteSize int)
// AllocManaged returns a unified memory pointer accessible from both host
// and device. Returns an error on backends that do not support managed memory.
AllocManaged(deviceID, byteSize int) (unsafe.Pointer, error)
// FreeManaged returns a managed memory pointer to the pool for reuse.
FreeManaged(deviceID int, ptr unsafe.Pointer, byteSize int)
// Drain frees all cached pointers back to the device.
Drain() error
// Stats returns the number of cached allocations and their total bytes.
Stats() (allocations int, totalBytes int)
}
MemPool abstracts a GPU device memory pool with size-bucketed caching. Each vendor can reuse the same pool logic since the pool operates on opaque device pointers, but the underlying Malloc/Free come from the vendor's Runtime.
type MemcpyKind ¶
type MemcpyKind int
MemcpyKind specifies the direction of a memory copy operation.
const (
	// MemcpyHostToDevice copies from host (CPU) memory to device (GPU) memory.
	MemcpyHostToDevice MemcpyKind = iota
	// MemcpyDeviceToHost copies from device (GPU) memory to host (CPU) memory.
	MemcpyDeviceToHost
	// MemcpyDeviceToDevice copies between device (GPU) memory regions.
	MemcpyDeviceToDevice
)
type OpSummary ¶
type OpSummary struct {
Op string
M, N, K int
Batch int
Calls int
TotalTime time.Duration
AvgTime time.Duration
}
OpSummary holds per-operation stats.
type OpenCLDNN ¶
type OpenCLDNN struct{}
OpenCLDNN implements the DNN interface for OpenCL. OpenCL has no standard DNN library (like cuDNN or MIOpen), so all operations return ErrNotSupported. The compute engine falls back to CPU.
func (*OpenCLDNN) ActivationBackward ¶
func (*OpenCLDNN) ActivationForward ¶
func (*OpenCLDNN) BatchNormBackward ¶
func (*OpenCLDNN) BatchNormForwardInference ¶
func (*OpenCLDNN) BatchNormForwardTraining ¶
func (*OpenCLDNN) ConvBackwardData ¶
func (*OpenCLDNN) ConvBackwardFilter ¶
func (*OpenCLDNN) ConvForward ¶
func (*OpenCLDNN) PoolingBackward ¶
func (*OpenCLDNN) PoolingForward ¶
type OpenCLMemPool ¶
type OpenCLMemPool struct {
// contains filtered or unexported fields
}
OpenCLMemPool implements the MemPool interface for OpenCL. It uses a simple size-bucketed cache of cl_mem buffers.
func NewOpenCLMemPool ¶
func NewOpenCLMemPool(rt *OpenCLRuntime) *OpenCLMemPool
NewOpenCLMemPool creates a new OpenCL memory pool.
func (*OpenCLMemPool) AllocManaged ¶
func (p *OpenCLMemPool) AllocManaged(_, _ int) (unsafe.Pointer, error)
func (*OpenCLMemPool) Drain ¶
func (p *OpenCLMemPool) Drain() error
func (*OpenCLMemPool) FreeManaged ¶
func (p *OpenCLMemPool) FreeManaged(_ int, _ unsafe.Pointer, _ int)
func (*OpenCLMemPool) Stats ¶
func (p *OpenCLMemPool) Stats() (int, int)
type OpenCLRuntime ¶
type OpenCLRuntime struct {
// contains filtered or unexported fields
}
OpenCLRuntime implements the Runtime interface using OpenCL.
func NewOpenCLRuntime ¶
func NewOpenCLRuntime() *OpenCLRuntime
NewOpenCLRuntime returns a new OpenCL runtime adapter. Returns nil if libOpenCL is not available on this system.
func (*OpenCLRuntime) CLContext ¶
func (r *OpenCLRuntime) CLContext() unsafe.Pointer
CLContext returns the underlying cl_context pointer.
func (*OpenCLRuntime) CLDevice ¶
func (r *OpenCLRuntime) CLDevice() unsafe.Pointer
CLDevice returns the underlying cl_device_id pointer.
func (*OpenCLRuntime) CLQueue ¶
func (r *OpenCLRuntime) CLQueue() unsafe.Pointer
CLQueue returns the default command queue pointer.
func (*OpenCLRuntime) CreateStream ¶
func (r *OpenCLRuntime) CreateStream() (Stream, error)
func (*OpenCLRuntime) DeviceType ¶
func (r *OpenCLRuntime) DeviceType() device.Type
func (*OpenCLRuntime) GetDeviceCount ¶
func (r *OpenCLRuntime) GetDeviceCount() (int, error)
func (*OpenCLRuntime) Memcpy ¶
func (r *OpenCLRuntime) Memcpy(dst, src unsafe.Pointer, count int, kind MemcpyKind) error
func (*OpenCLRuntime) MemcpyAsync ¶
func (r *OpenCLRuntime) MemcpyAsync(dst, src unsafe.Pointer, count int, kind MemcpyKind, _ Stream) error
func (*OpenCLRuntime) MemcpyPeer ¶
func (*OpenCLRuntime) SetDevice ¶
func (r *OpenCLRuntime) SetDevice(deviceID int) error
type PoolingMode ¶
type PoolingMode int
PoolingMode selects the pooling strategy for DNN operations.
const (
	PoolingMax PoolingMode = iota
	PoolingAverageCountIncludePad
	PoolingAverageCountExcludePad
)
type ProfileSummary ¶
ProfileSummary holds aggregated cuBLAS profiling stats.
type ROCmBlas ¶
type ROCmBlas struct {
// contains filtered or unexported fields
}
ROCmBlas implements the BLAS interface using rocBLAS.
func NewROCmBlas ¶
NewROCmBlas creates a new rocBLAS adapter. The caller must call Destroy when done.
func NewROCmBlasFromHandle ¶
NewROCmBlasFromHandle wraps an existing rocBLAS handle.
func (*ROCmBlas) BFloat16Gemm ¶
func (*ROCmBlas) Float16Gemm ¶
func (*ROCmBlas) MixedBF16Gemm ¶
func (*ROCmBlas) MixedFP16Gemm ¶
type ROCmDNN ¶
type ROCmDNN struct {
// contains filtered or unexported fields
}
ROCmDNN implements the DNN interface using MIOpen.
func NewROCmDNNFromHandle ¶
NewROCmDNNFromHandle wraps an existing MIOpen handle.
func (*ROCmDNN) ActivationBackward ¶
func (*ROCmDNN) ActivationForward ¶
func (*ROCmDNN) BatchNormBackward ¶
func (*ROCmDNN) BatchNormForwardInference ¶
func (*ROCmDNN) BatchNormForwardTraining ¶
func (*ROCmDNN) ConvBackwardData ¶
func (*ROCmDNN) ConvBackwardFilter ¶
func (*ROCmDNN) ConvForward ¶
func (*ROCmDNN) PoolingBackward ¶
func (*ROCmDNN) PoolingForward ¶
type ROCmKernels ¶
type ROCmKernels struct{}
ROCmKernels implements the KernelRunner interface using custom HIP kernels.
func NewROCmKernels ¶
func NewROCmKernels() *ROCmKernels
NewROCmKernels returns a new ROCm kernel runner adapter.
func (*ROCmKernels) AddBroadcast ¶
func (*ROCmKernels) AddBroadcast4D ¶
func (*ROCmKernels) DequantFP8E4M3ToFP16 ¶
func (*ROCmKernels) DequantQ4KF32 ¶
func (*ROCmKernels) DivBroadcast ¶
func (*ROCmKernels) DivBroadcast4D ¶
func (*ROCmKernels) FusedAddRMSNormF32 ¶
func (*ROCmKernels) FusedNormAddF32 ¶
func (*ROCmKernels) FusedQKNormRoPEF32 ¶
func (*ROCmKernels) FusedRoPEF32 ¶
func (*ROCmKernels) FusedSwiGLUF32 ¶
func (*ROCmKernels) GemvQ4KF32 ¶
func (*ROCmKernels) IncrementCounter ¶
func (*ROCmKernels) MulBroadcast ¶
func (*ROCmKernels) MulBroadcast4D ¶
func (*ROCmKernels) OffsetMemcpy ¶
func (*ROCmKernels) OffsetMemcpyFP16 ¶
func (*ROCmKernels) RMSNormFP16 ¶
func (*ROCmKernels) ResetCounter ¶
func (*ROCmKernels) RoPESelect ¶
func (*ROCmKernels) ScaledSoftmaxF32 ¶
func (*ROCmKernels) ScaledSoftmaxFP16 ¶
func (*ROCmKernels) SubBroadcast ¶
func (*ROCmKernels) SubBroadcast4D ¶
func (*ROCmKernels) Transpose2D ¶
func (*ROCmKernels) TransposeND ¶
type ROCmMemPool ¶
type ROCmMemPool struct {
// contains filtered or unexported fields
}
ROCmMemPool implements the MemPool interface by wrapping hip.MemPool.
func NewROCmMemPool ¶
func NewROCmMemPool() *ROCmMemPool
NewROCmMemPool creates a new ROCm memory pool adapter.
func NewROCmMemPoolFrom ¶
func NewROCmMemPoolFrom(pool *hip.MemPool) *ROCmMemPool
NewROCmMemPoolFrom wraps an existing hip.MemPool.
func (*ROCmMemPool) Alloc ¶
func (p *ROCmMemPool) Alloc(deviceID, byteSize int) (unsafe.Pointer, error)
func (*ROCmMemPool) AllocManaged ¶
func (p *ROCmMemPool) AllocManaged(_, _ int) (unsafe.Pointer, error)
func (*ROCmMemPool) Drain ¶
func (p *ROCmMemPool) Drain() error
func (*ROCmMemPool) Free ¶
func (p *ROCmMemPool) Free(deviceID int, ptr unsafe.Pointer, byteSize int)
func (*ROCmMemPool) FreeManaged ¶
func (p *ROCmMemPool) FreeManaged(_ int, _ unsafe.Pointer, _ int)
func (*ROCmMemPool) Inner ¶
func (p *ROCmMemPool) Inner() *hip.MemPool
Inner returns the underlying hip.MemPool for backward compatibility.
func (*ROCmMemPool) Stats ¶
func (p *ROCmMemPool) Stats() (int, int)
type ROCmRuntime ¶
type ROCmRuntime struct{}
ROCmRuntime implements the Runtime interface using the AMD HIP runtime API.
func NewROCmRuntime ¶
func NewROCmRuntime() *ROCmRuntime
NewROCmRuntime returns a new ROCm runtime adapter.
func (*ROCmRuntime) CreateStream ¶
func (r *ROCmRuntime) CreateStream() (Stream, error)
func (*ROCmRuntime) DeviceType ¶
func (r *ROCmRuntime) DeviceType() device.Type
func (*ROCmRuntime) GetDeviceCount ¶
func (r *ROCmRuntime) GetDeviceCount() (int, error)
func (*ROCmRuntime) Memcpy ¶
func (r *ROCmRuntime) Memcpy(dst, src unsafe.Pointer, count int, kind MemcpyKind) error
func (*ROCmRuntime) MemcpyAsync ¶
func (r *ROCmRuntime) MemcpyAsync(dst, src unsafe.Pointer, count int, kind MemcpyKind, stream Stream) error
func (*ROCmRuntime) MemcpyPeer ¶
func (*ROCmRuntime) SetDevice ¶
func (r *ROCmRuntime) SetDevice(deviceID int) error
type Runtime ¶
type Runtime interface {
// DeviceType returns the device type this runtime manages.
DeviceType() device.Type
// SetDevice sets the active GPU device for the calling goroutine.
SetDevice(deviceID int) error
// GetDeviceCount returns the number of available GPU devices.
GetDeviceCount() (int, error)
// Malloc allocates byteSize bytes of device memory.
Malloc(byteSize int) (unsafe.Pointer, error)
// Free releases device memory previously allocated by Malloc.
Free(ptr unsafe.Pointer) error
// Memcpy copies count bytes between host and device memory.
Memcpy(dst, src unsafe.Pointer, count int, kind MemcpyKind) error
// MemcpyAsync copies count bytes asynchronously on the given stream.
MemcpyAsync(dst, src unsafe.Pointer, count int, kind MemcpyKind, stream Stream) error
// MemcpyPeer copies count bytes between devices (peer-to-peer).
MemcpyPeer(dst unsafe.Pointer, dstDevice int, src unsafe.Pointer, srcDevice int, count int) error
// CreateStream creates a new asynchronous command stream.
CreateStream() (Stream, error)
}
Runtime abstracts GPU device and memory management operations. Each vendor (CUDA, ROCm, OpenCL) provides an implementation.
type Stream ¶
type Stream interface {
// Synchronize blocks until all commands in the stream have completed.
Synchronize() error
// Destroy releases the stream resources.
Destroy() error
// Ptr returns the underlying vendor stream handle as an unsafe.Pointer.
// For CUDA this is cudaStream_t, for ROCm this is hipStream_t.
Ptr() unsafe.Pointer
}
Stream represents an asynchronous command queue on a GPU device.