Versions in this module Expand all Collapse all v0 v0.2.0 Mar 16, 2026 v0.1.0 Mar 16, 2026 Changes in this version + var BLASFactory func() (BLAS, error) + var DNNFactory func() (DNN, error) + func PrintCUBLASProfile() + type ActivationMode int + const ActivationClippedReLU + const ActivationELU + const ActivationReLU + const ActivationSigmoid + const ActivationTanh + type BLAS interface + BFloat16Gemm func(m, n, k int, alpha float32, a unsafe.Pointer, b unsafe.Pointer, beta float32, ...) error + Destroy func() error + Float16Gemm func(m, n, k int, alpha float32, a unsafe.Pointer, b unsafe.Pointer, beta float32, ...) error + MixedBF16Gemm func(m, n, k int, alpha float32, a unsafe.Pointer, b unsafe.Pointer, beta float32, ...) error + MixedFP16Gemm func(m, n, k int, alpha float32, a unsafe.Pointer, b unsafe.Pointer, beta float32, ...) error + SetStream func(stream Stream) error + Sgemm func(m, n, k int, alpha float32, a unsafe.Pointer, b unsafe.Pointer, beta float32, ...) error + type BLASBatched interface + SgemmStridedBatched func(m, n, k int, alpha float32, a unsafe.Pointer, strideA int64, b unsafe.Pointer, ...) error + type BLASBatchedTransposeB interface + SgemmNTStridedBatched func(m, n, k int, alpha float32, a unsafe.Pointer, strideA int64, b unsafe.Pointer, ...) error + type BLASTransposeB interface + SgemmNT func(m, n, k int, alpha float32, a unsafe.Pointer, b unsafe.Pointer, beta float32, ...) error + type BatchNormMode int + const BatchNormPerActivation + const BatchNormSpatial + type CUDAArenaPool struct + func NewCUDAArenaPool(deviceID, capacityBytes int, fallback *cuda.MemPool) (*CUDAArenaPool, error) + func (p *CUDAArenaPool) Alloc(deviceID, byteSize int) (unsafe.Pointer, error) + func (p *CUDAArenaPool) AllocManaged(deviceID, byteSize int) (unsafe.Pointer, error) + func (p *CUDAArenaPool) Drain() error + func (p *CUDAArenaPool) Free(deviceID int, ptr unsafe.Pointer, byteSize int) + func (p *CUDAArenaPool) FreeManaged(deviceID int, ptr unsafe.Pointer, byteSize int) + func (p *CUDAArenaPool) Inner() *cuda.ArenaPool + func (p *CUDAArenaPool) Reset() + func (p *CUDAArenaPool) SetResetFloor(floor int) + func (p *CUDAArenaPool) Stats() (int, int) + func (p *CUDAArenaPool) UsedBytes() int + type CUDABlas struct + func NewCUDABlas() (*CUDABlas, error) + func NewCUDABlasFromHandle(h *cublas.Handle) *CUDABlas + func (b *CUDABlas) BFloat16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error + func (b *CUDABlas) Destroy() error + func (b *CUDABlas) Float16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error + func (b *CUDABlas) Handle() *cublas.Handle + func (b *CUDABlas) MixedBF16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error + func (b *CUDABlas) MixedFP16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error + func (b *CUDABlas) SetStream(stream Stream) error + func (b *CUDABlas) Sgemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error + func (b *CUDABlas) SgemmNT(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error + func (b *CUDABlas) SgemmNTStridedBatched(m, n, k int, alpha float32, a unsafe.Pointer, strideA int64, ...) error + func (b *CUDABlas) SgemmStridedBatched(m, n, k int, alpha float32, a unsafe.Pointer, strideA int64, ...) error + type CUDABlasProfiler struct + func WrapWithProfiler(b *CUDABlas) *CUDABlasProfiler + func (p *CUDABlasProfiler) BFloat16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error + func (p *CUDABlasProfiler) Destroy() error + func (p *CUDABlasProfiler) Float16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error + func (p *CUDABlasProfiler) Handle() *CUDABlas + func (p *CUDABlasProfiler) IsEnabled() bool + func (p *CUDABlasProfiler) MixedBF16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error + func (p *CUDABlasProfiler) MixedFP16Gemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error + func (p *CUDABlasProfiler) PrintSummary() + func (p *CUDABlasProfiler) ResetProfile() + func (p *CUDABlasProfiler) SetStream(stream Stream) error + func (p *CUDABlasProfiler) Sgemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error + func (p *CUDABlasProfiler) SgemmNT(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error + func (p *CUDABlasProfiler) SgemmNTStridedBatched(m, n, k int, alpha float32, a unsafe.Pointer, strideA int64, ...) error + func (p *CUDABlasProfiler) SgemmStridedBatched(m, n, k int, alpha float32, a unsafe.Pointer, strideA int64, ...) error + func (p *CUDABlasProfiler) Summary() ProfileSummary + type CUDADNN struct + func NewCUDADNN() (*CUDADNN, error) + func NewCUDADNNFromHandle(h *cudnn.Handle) *CUDADNN + func (d *CUDADNN) ActivationBackward(mode ActivationMode, y unsafe.Pointer, dy unsafe.Pointer, x unsafe.Pointer, ...) error + func (d *CUDADNN) ActivationForward(mode ActivationMode, x unsafe.Pointer, shape [4]int, y unsafe.Pointer, ...) error + func (d *CUDADNN) AddTensor(alpha float32, b unsafe.Pointer, bShape [4]int, beta float32, y unsafe.Pointer, ...) error + func (d *CUDADNN) BatchNormBackward(x unsafe.Pointer, xShape [4]int, dy unsafe.Pointer, scale unsafe.Pointer, ...) error + func (d *CUDADNN) BatchNormForwardInference(x unsafe.Pointer, xShape [4]int, scale, bias, mean, variance unsafe.Pointer, ...) error + func (d *CUDADNN) BatchNormForwardTraining(x unsafe.Pointer, xShape [4]int, scale, bias unsafe.Pointer, channels int, ...) error + func (d *CUDADNN) ConvBackwardData(w unsafe.Pointer, wShape [4]int, dy unsafe.Pointer, dyShape [4]int, ...) error + func (d *CUDADNN) ConvBackwardFilter(x unsafe.Pointer, xShape [4]int, dy unsafe.Pointer, dyShape [4]int, ...) error + func (d *CUDADNN) ConvForward(x unsafe.Pointer, xShape [4]int, w unsafe.Pointer, wShape [4]int, ...) error + func (d *CUDADNN) Destroy() error + func (d *CUDADNN) Handle() *cudnn.Handle + func (d *CUDADNN) PoolingBackward(mode PoolingMode, y unsafe.Pointer, dy unsafe.Pointer, yShape [4]int, ...) error + func (d *CUDADNN) PoolingForward(mode PoolingMode, x unsafe.Pointer, xShape [4]int, y unsafe.Pointer, ...) error + func (d *CUDADNN) SetStream(stream Stream) error + func (d *CUDADNN) SoftmaxForward(x unsafe.Pointer, shape [4]int, y unsafe.Pointer, stream Stream) error + type CUDAKernels struct + func NewCUDAKernels() *CUDAKernels + func (k *CUDAKernels) Add(a, b, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) AddBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s Stream) error + func (k *CUDAKernels) AddBroadcast4D(a, b, c unsafe.Pointer, ...) error + func (k *CUDAKernels) AddFP16(a, b, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) AddScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) Argmax(input, result, scratch unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) Cos(a, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) DequantFP8E4M3ToFP16(input, output unsafe.Pointer, scale float32, n int, s Stream) error + func (k *CUDAKernels) DequantQ4KF32(src, dst unsafe.Pointer, rows, K int, s Stream) error + func (k *CUDAKernels) Div(a, b, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) DivBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s Stream) error + func (k *CUDAKernels) DivBroadcast4D(a, b, c unsafe.Pointer, ...) error + func (k *CUDAKernels) DivFP16(a, b, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) DivScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) Exp(a, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) F32ToFP16(src, dst unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) FP16ToF32(src, dst unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) Fill(data unsafe.Pointer, value float32, n int, s Stream) error + func (k *CUDAKernels) FusedAddRMSNormF32(input, residual, weight, normedOut, sumOut unsafe.Pointer, eps float32, ...) error + func (k *CUDAKernels) FusedNormAddF32(input, weight, residual, output unsafe.Pointer, eps float32, rows, D int, ...) error + func (k *CUDAKernels) FusedQKNormRoPEF32(input, weightQ, weightK, cosAngles, sinAngles, output unsafe.Pointer, ...) error + func (k *CUDAKernels) FusedRoPEF32(input, cosAngles, sinAngles, output unsafe.Pointer, ...) error + func (k *CUDAKernels) FusedSwiGLUF32(w1, w3, output unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) Gather(table, indices, output unsafe.Pointer, N, D, V int, s Stream) error + func (k *CUDAKernels) GemmQ4F32(aQ4, b, c unsafe.Pointer, m, kk, n, dataOffset int, s Stream) error + func (k *CUDAKernels) GemmQ8F32(aQ8, b, c unsafe.Pointer, m, kk, n int, s Stream) error + func (k *CUDAKernels) GemvQ4KF32(wQ4K, x, y unsafe.Pointer, M, K int, s Stream) error + func (k *CUDAKernels) IncrementCounter(counter unsafe.Pointer, delta int, s Stream) error + func (k *CUDAKernels) Log(a, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) Mul(a, b, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) MulBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s Stream) error + func (k *CUDAKernels) MulBroadcast4D(a, b, c unsafe.Pointer, ...) error + func (k *CUDAKernels) MulFP16(a, b, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) MulScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) OffsetMemcpy(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s Stream) error + func (k *CUDAKernels) OffsetMemcpyFP16(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s Stream) error + func (k *CUDAKernels) Pow(base, exp, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) PowScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) RMSNorm(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, ...) error + func (k *CUDAKernels) RMSNormFP16(input, weight, output unsafe.Pointer, eps float32, rows, D int, s Stream) error + func (k *CUDAKernels) Repeat(src, dst unsafe.Pointer, outerSize, axisDim, innerSize, reps int, s Stream) error + func (k *CUDAKernels) ResetCounter(counter unsafe.Pointer, value int, s Stream) error + func (k *CUDAKernels) RoPESelect(cosTable, sinTable, cosOut, sinOut, counter unsafe.Pointer, halfRotary int, ...) error + func (k *CUDAKernels) Rsqrt(a, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) ScaledSoftmaxF32(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, ...) error + func (k *CUDAKernels) ScaledSoftmaxFP16(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, ...) error + func (k *CUDAKernels) SgemvM1(y, A, x unsafe.Pointer, M, N int, s Stream) error + func (k *CUDAKernels) Sin(a, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) Softmax(input, output unsafe.Pointer, outer, inner, axisSize int, s Stream) error + func (k *CUDAKernels) Sqrt(a, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) Sub(a, b, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) SubBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s Stream) error + func (k *CUDAKernels) SubBroadcast4D(a, b, c unsafe.Pointer, ...) error + func (k *CUDAKernels) SubFP16(a, b, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) SubScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) SumAxis(input, output unsafe.Pointer, outer, inner, axisSize int, s Stream) error + func (k *CUDAKernels) Tanh(a, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) TanhPrime(a, upstream, c unsafe.Pointer, n int, s Stream) error + func (k *CUDAKernels) Transpose2D(input, output unsafe.Pointer, rows, cols int, s Stream) error + func (k *CUDAKernels) TransposeND(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ...) error + type CUDAMemPool struct + func NewCUDAMemPool() *CUDAMemPool + func NewCUDAMemPoolFrom(pool *cuda.MemPool) *CUDAMemPool + func (p *CUDAMemPool) Alloc(deviceID, byteSize int) (unsafe.Pointer, error) + func (p *CUDAMemPool) AllocManaged(deviceID, byteSize int) (unsafe.Pointer, error) + func (p *CUDAMemPool) Drain() error + func (p *CUDAMemPool) Free(deviceID int, ptr unsafe.Pointer, byteSize int) + func (p *CUDAMemPool) FreeManaged(deviceID int, ptr unsafe.Pointer, byteSize int) + func (p *CUDAMemPool) Inner() *cuda.MemPool + func (p *CUDAMemPool) Stats() (int, int) + type CUDARuntime struct + func NewCUDARuntime() *CUDARuntime + func (r *CUDARuntime) CreateStream() (Stream, error) + func (r *CUDARuntime) DeviceType() device.Type + func (r *CUDARuntime) Free(ptr unsafe.Pointer) error + func (r *CUDARuntime) GetDeviceCount() (int, error) + func (r *CUDARuntime) Malloc(byteSize int) (unsafe.Pointer, error) + func (r *CUDARuntime) Memcpy(dst, src unsafe.Pointer, count int, kind MemcpyKind) error + func (r *CUDARuntime) MemcpyAsync(dst, src unsafe.Pointer, count int, kind MemcpyKind, stream Stream) error + func (r *CUDARuntime) MemcpyPeer(dst unsafe.Pointer, dstDevice int, src unsafe.Pointer, srcDevice int, ...) error + func (r *CUDARuntime) SetDevice(deviceID int) error + type DNN interface + ActivationBackward func(mode ActivationMode, y unsafe.Pointer, dy unsafe.Pointer, x unsafe.Pointer, ...) error + ActivationForward func(mode ActivationMode, x unsafe.Pointer, shape [4]int, y unsafe.Pointer, ...) error + AddTensor func(alpha float32, b unsafe.Pointer, bShape [4]int, beta float32, y unsafe.Pointer, ...) error + BatchNormBackward func(x unsafe.Pointer, xShape [4]int, dy unsafe.Pointer, scale unsafe.Pointer, ...) error + BatchNormForwardInference func(x unsafe.Pointer, xShape [4]int, scale, bias, mean, variance unsafe.Pointer, ...) error + BatchNormForwardTraining func(x unsafe.Pointer, xShape [4]int, scale, bias unsafe.Pointer, channels int, ...) error + ConvBackwardData func(w unsafe.Pointer, wShape [4]int, dy unsafe.Pointer, dyShape [4]int, ...) error + ConvBackwardFilter func(x unsafe.Pointer, xShape [4]int, dy unsafe.Pointer, dyShape [4]int, ...) error + ConvForward func(x unsafe.Pointer, xShape [4]int, w unsafe.Pointer, wShape [4]int, ...) error + Destroy func() error + PoolingBackward func(mode PoolingMode, y unsafe.Pointer, dy unsafe.Pointer, yShape [4]int, ...) error + PoolingForward func(mode PoolingMode, x unsafe.Pointer, xShape [4]int, y unsafe.Pointer, ...) error + SetStream func(stream Stream) error + SoftmaxForward func(x unsafe.Pointer, shape [4]int, y unsafe.Pointer, stream Stream) error + type KernelRunner interface + Add func(a, b, c unsafe.Pointer, n int, stream Stream) error + AddBroadcast func(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, stream Stream) error + AddBroadcast4D func(a, b, c unsafe.Pointer, ...) error + AddFP16 func(a, b, c unsafe.Pointer, n int, stream Stream) error + AddScalar func(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error + Argmax func(input, result, scratch unsafe.Pointer, n int, stream Stream) error + Cos func(a, c unsafe.Pointer, n int, stream Stream) error + DequantFP8E4M3ToFP16 func(input, output unsafe.Pointer, scale float32, n int, stream Stream) error + DequantQ4KF32 func(src, dst unsafe.Pointer, rows, K int, stream Stream) error + Div func(a, b, c unsafe.Pointer, n int, stream Stream) error + DivBroadcast func(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, stream Stream) error + DivBroadcast4D func(a, b, c unsafe.Pointer, ...) error + DivFP16 func(a, b, c unsafe.Pointer, n int, stream Stream) error + DivScalar func(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error + Exp func(a, c unsafe.Pointer, n int, stream Stream) error + F32ToFP16 func(src, dst unsafe.Pointer, n int, stream Stream) error + FP16ToF32 func(src, dst unsafe.Pointer, n int, stream Stream) error + Fill func(data unsafe.Pointer, value float32, n int, stream Stream) error + FusedAddRMSNormF32 func(input, residual, weight, normedOut, sumOut unsafe.Pointer, eps float32, ...) error + FusedNormAddF32 func(input, weight, residual, output unsafe.Pointer, eps float32, rows, D int, ...) error + FusedQKNormRoPEF32 func(input, weightQ, weightK, cosAngles, sinAngles, output unsafe.Pointer, ...) error + FusedRoPEF32 func(input, cosAngles, sinAngles, output unsafe.Pointer, ...) error + FusedSwiGLUF32 func(w1, w3, output unsafe.Pointer, n int, stream Stream) error + Gather func(table, indices, output unsafe.Pointer, N, D, V int, stream Stream) error + GemmQ4F32 func(aQ4, b, c unsafe.Pointer, m, k, n, dataOffset int, stream Stream) error + GemmQ8F32 func(aQ8, b, c unsafe.Pointer, m, k, n int, stream Stream) error + GemvQ4KF32 func(wQ4K, x, y unsafe.Pointer, M, K int, stream Stream) error + IncrementCounter func(counter unsafe.Pointer, delta int, stream Stream) error + Log func(a, c unsafe.Pointer, n int, stream Stream) error + Mul func(a, b, c unsafe.Pointer, n int, stream Stream) error + MulBroadcast func(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, stream Stream) error + MulBroadcast4D func(a, b, c unsafe.Pointer, ...) error + MulFP16 func(a, b, c unsafe.Pointer, n int, stream Stream) error + MulScalar func(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error + OffsetMemcpy func(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, stream Stream) error + OffsetMemcpyFP16 func(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, stream Stream) error + Pow func(base, exp, c unsafe.Pointer, n int, stream Stream) error + PowScalar func(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error + RMSNorm func(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, ...) error + RMSNormFP16 func(input, weight, output unsafe.Pointer, eps float32, rows, D int, stream Stream) error + Repeat func(src, dst unsafe.Pointer, outerSize, axisDim, innerSize, reps int, ...) error + ResetCounter func(counter unsafe.Pointer, value int, stream Stream) error + RoPESelect func(cosTable, sinTable, cosOut, sinOut, counter unsafe.Pointer, halfRotary int, ...) error + Rsqrt func(a, c unsafe.Pointer, n int, stream Stream) error + ScaledSoftmaxF32 func(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, ...) error + ScaledSoftmaxFP16 func(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, ...) error + SgemvM1 func(y, A, x unsafe.Pointer, M, N int, stream Stream) error + Sin func(a, c unsafe.Pointer, n int, stream Stream) error + Softmax func(input, output unsafe.Pointer, outer, inner, axisSize int, stream Stream) error + Sqrt func(a, c unsafe.Pointer, n int, stream Stream) error + Sub func(a, b, c unsafe.Pointer, n int, stream Stream) error + SubBroadcast func(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, stream Stream) error + SubBroadcast4D func(a, b, c unsafe.Pointer, ...) error + SubFP16 func(a, b, c unsafe.Pointer, n int, stream Stream) error + SubScalar func(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, stream Stream) error + SumAxis func(input, output unsafe.Pointer, outer, inner, axisSize int, stream Stream) error + Tanh func(a, c unsafe.Pointer, n int, stream Stream) error + TanhPrime func(a, upstream, c unsafe.Pointer, n int, stream Stream) error + Transpose2D func(input, output unsafe.Pointer, rows, cols int, stream Stream) error + TransposeND func(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ...) error + type MemPool interface + Alloc func(deviceID, byteSize int) (unsafe.Pointer, error) + AllocManaged func(deviceID, byteSize int) (unsafe.Pointer, error) + Drain func() error + Free func(deviceID int, ptr unsafe.Pointer, byteSize int) + FreeManaged func(deviceID int, ptr unsafe.Pointer, byteSize int) + Stats func() (allocations int, totalBytes int) + type MemcpyKind int + const MemcpyDeviceToDevice + const MemcpyDeviceToHost + const MemcpyHostToDevice + type OpSummary struct + AvgTime time.Duration + Batch int + Calls int + K int + M int + N int + Op string + TotalTime time.Duration + type OpenCLDNN struct + func NewOpenCLDNN() *OpenCLDNN + func (d *OpenCLDNN) ActivationBackward(_ ActivationMode, _ unsafe.Pointer, _ unsafe.Pointer, _ unsafe.Pointer, ...) error + func (d *OpenCLDNN) ActivationForward(_ ActivationMode, _ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, _ Stream) error + func (d *OpenCLDNN) AddTensor(_ float32, _ unsafe.Pointer, _ [4]int, _ float32, _ unsafe.Pointer, _ [4]int, ...) error + func (d *OpenCLDNN) BatchNormBackward(_ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, _ unsafe.Pointer, _ int, ...) error + func (d *OpenCLDNN) BatchNormForwardInference(_ unsafe.Pointer, _ [4]int, _, _, _, _ unsafe.Pointer, _ int, _ float64, ...) error + func (d *OpenCLDNN) BatchNormForwardTraining(_ unsafe.Pointer, _ [4]int, _, _ unsafe.Pointer, _ int, _, _ float64, ...) error + func (d *OpenCLDNN) ConvBackwardData(_ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, ...) error + func (d *OpenCLDNN) ConvBackwardFilter(_ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, ...) error + func (d *OpenCLDNN) ConvForward(_ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, ...) error + func (d *OpenCLDNN) Destroy() error + func (d *OpenCLDNN) PoolingBackward(_ PoolingMode, _ unsafe.Pointer, _ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, ...) error + func (d *OpenCLDNN) PoolingForward(_ PoolingMode, _ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, _ [4]int, ...) error + func (d *OpenCLDNN) SetStream(_ Stream) error + func (d *OpenCLDNN) SoftmaxForward(_ unsafe.Pointer, _ [4]int, _ unsafe.Pointer, _ Stream) error + type OpenCLMemPool struct + func NewOpenCLMemPool(rt *OpenCLRuntime) *OpenCLMemPool + func (p *OpenCLMemPool) Alloc(_ int, byteSize int) (unsafe.Pointer, error) + func (p *OpenCLMemPool) AllocManaged(_, _ int) (unsafe.Pointer, error) + func (p *OpenCLMemPool) Drain() error + func (p *OpenCLMemPool) Free(_ int, ptr unsafe.Pointer, byteSize int) + func (p *OpenCLMemPool) FreeManaged(_ int, _ unsafe.Pointer, _ int) + func (p *OpenCLMemPool) Stats() (int, int) + type OpenCLRuntime struct + func NewOpenCLRuntime() *OpenCLRuntime + func (r *OpenCLRuntime) CLContext() unsafe.Pointer + func (r *OpenCLRuntime) CLDevice() unsafe.Pointer + func (r *OpenCLRuntime) CLQueue() unsafe.Pointer + func (r *OpenCLRuntime) CreateStream() (Stream, error) + func (r *OpenCLRuntime) DeviceType() device.Type + func (r *OpenCLRuntime) Free(ptr unsafe.Pointer) error + func (r *OpenCLRuntime) GetDeviceCount() (int, error) + func (r *OpenCLRuntime) Malloc(byteSize int) (unsafe.Pointer, error) + func (r *OpenCLRuntime) Memcpy(dst, src unsafe.Pointer, count int, kind MemcpyKind) error + func (r *OpenCLRuntime) MemcpyAsync(dst, src unsafe.Pointer, count int, kind MemcpyKind, _ Stream) error + func (r *OpenCLRuntime) MemcpyPeer(dst unsafe.Pointer, _ int, src unsafe.Pointer, _ int, count int) error + func (r *OpenCLRuntime) SetDevice(deviceID int) error + type PoolingMode int + const PoolingAverageCountExcludePad + const PoolingAverageCountIncludePad + const PoolingMax + type ProfileSummary struct + ByOp []OpSummary + TotalCalls int + TotalDuration time.Duration + type ROCmBlas struct + func NewROCmBlas() (*ROCmBlas, error) + func NewROCmBlasFromHandle(h *rocblas.Handle) *ROCmBlas + func (b *ROCmBlas) BFloat16Gemm(_, _, _ int, _ float32, _, _ unsafe.Pointer, _ float32, _ unsafe.Pointer) error + func (b *ROCmBlas) Destroy() error + func (b *ROCmBlas) Float16Gemm(_, _, _ int, _ float32, _, _ unsafe.Pointer, _ float32, _ unsafe.Pointer) error + func (b *ROCmBlas) Handle() *rocblas.Handle + func (b *ROCmBlas) MixedBF16Gemm(_, _, _ int, _ float32, _, _ unsafe.Pointer, _ float32, _ unsafe.Pointer) error + func (b *ROCmBlas) MixedFP16Gemm(_, _, _ int, _ float32, _, _ unsafe.Pointer, _ float32, _ unsafe.Pointer) error + func (b *ROCmBlas) SetStream(stream Stream) error + func (b *ROCmBlas) Sgemm(m, n, k int, alpha float32, a unsafe.Pointer, bPtr unsafe.Pointer, ...) error + type ROCmDNN struct + func NewROCmDNN() (*ROCmDNN, error) + func NewROCmDNNFromHandle(h *miopen.Handle) *ROCmDNN + func (d *ROCmDNN) ActivationBackward(mode ActivationMode, y unsafe.Pointer, dy unsafe.Pointer, x unsafe.Pointer, ...) error + func (d *ROCmDNN) ActivationForward(mode ActivationMode, x unsafe.Pointer, shape [4]int, y unsafe.Pointer, ...) error + func (d *ROCmDNN) AddTensor(alpha float32, b unsafe.Pointer, bShape [4]int, beta float32, y unsafe.Pointer, ...) error + func (d *ROCmDNN) BatchNormBackward(x unsafe.Pointer, xShape [4]int, dy unsafe.Pointer, scale unsafe.Pointer, ...) error + func (d *ROCmDNN) BatchNormForwardInference(x unsafe.Pointer, xShape [4]int, scale, bias, mean, variance unsafe.Pointer, ...) error + func (d *ROCmDNN) BatchNormForwardTraining(x unsafe.Pointer, xShape [4]int, scale, bias unsafe.Pointer, channels int, ...) error + func (d *ROCmDNN) ConvBackwardData(w unsafe.Pointer, wShape [4]int, dy unsafe.Pointer, dyShape [4]int, ...) error + func (d *ROCmDNN) ConvBackwardFilter(x unsafe.Pointer, xShape [4]int, dy unsafe.Pointer, dyShape [4]int, ...) error + func (d *ROCmDNN) ConvForward(x unsafe.Pointer, xShape [4]int, w unsafe.Pointer, wShape [4]int, ...) error + func (d *ROCmDNN) Destroy() error + func (d *ROCmDNN) Handle() *miopen.Handle + func (d *ROCmDNN) PoolingBackward(mode PoolingMode, y unsafe.Pointer, dy unsafe.Pointer, yShape [4]int, ...) error + func (d *ROCmDNN) PoolingForward(mode PoolingMode, x unsafe.Pointer, xShape [4]int, y unsafe.Pointer, ...) error + func (d *ROCmDNN) SetStream(stream Stream) error + func (d *ROCmDNN) SoftmaxForward(x unsafe.Pointer, shape [4]int, y unsafe.Pointer, stream Stream) error + type ROCmKernels struct + func NewROCmKernels() *ROCmKernels + func (k *ROCmKernels) Add(a, b, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) AddBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, _ Stream) error + func (k *ROCmKernels) AddBroadcast4D(_, _, _ unsafe.Pointer, _, _, _, _, _, _, _, _, _, _, _, _ int, _ Stream) error + func (k *ROCmKernels) AddFP16(_, _, _ unsafe.Pointer, _ int, _ Stream) error + func (k *ROCmKernels) AddScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) Argmax(_, _, _ unsafe.Pointer, _ int, _ Stream) error + func (k *ROCmKernels) Cos(a, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) DequantFP8E4M3ToFP16(_, _ unsafe.Pointer, _ float32, _ int, _ Stream) error + func (k *ROCmKernels) DequantQ4KF32(src, dst unsafe.Pointer, rows, K int, _ Stream) error + func (k *ROCmKernels) Div(a, b, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) DivBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, _ Stream) error + func (k *ROCmKernels) DivBroadcast4D(_, _, _ unsafe.Pointer, _, _, _, _, _, _, _, _, _, _, _, _ int, _ Stream) error + func (k *ROCmKernels) DivFP16(_, _, _ unsafe.Pointer, _ int, _ Stream) error + func (k *ROCmKernels) DivScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) Exp(a, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) F32ToFP16(_, _ unsafe.Pointer, _ int, _ Stream) error + func (k *ROCmKernels) FP16ToF32(_, _ unsafe.Pointer, _ int, _ Stream) error + func (k *ROCmKernels) Fill(data unsafe.Pointer, value float32, n int, s Stream) error + func (k *ROCmKernels) FusedAddRMSNormF32(_, _, _, _, _ unsafe.Pointer, _ float32, _, _ int, _ Stream) error + func (k *ROCmKernels) FusedNormAddF32(_, _, _, _ unsafe.Pointer, _ float32, _, _ int, _ Stream) error + func (k *ROCmKernels) FusedQKNormRoPEF32(_, _, _, _, _, _ unsafe.Pointer, _ float32, _, _, _, _ int, _ Stream) error + func (k *ROCmKernels) FusedRoPEF32(_, _, _, _ unsafe.Pointer, _, _, _, _, _ int, _ Stream) error + func (k *ROCmKernels) FusedSwiGLUF32(_, _, _ unsafe.Pointer, _ int, _ Stream) error + func (k *ROCmKernels) Gather(table, indices, output unsafe.Pointer, N, D, V int, _ Stream) error + func (k *ROCmKernels) GemmQ4F32(aQ4, b, c unsafe.Pointer, m, kk, n, dataOffset int, s Stream) error + func (k *ROCmKernels) GemmQ8F32(aQ8, b, c unsafe.Pointer, m, kk, n int, _ Stream) error + func (k *ROCmKernels) GemvQ4KF32(wQ4K, x, y unsafe.Pointer, M, K int, _ Stream) error + func (k *ROCmKernels) IncrementCounter(_ unsafe.Pointer, _ int, _ Stream) error + func (k *ROCmKernels) Log(a, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) Mul(a, b, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) MulBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, _ Stream) error + func (k *ROCmKernels) MulBroadcast4D(_, _, _ unsafe.Pointer, _, _, _, _, _, _, _, _, _, _, _, _ int, _ Stream) error + func (k *ROCmKernels) MulFP16(_, _, _ unsafe.Pointer, _ int, _ Stream) error + func (k *ROCmKernels) MulScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) OffsetMemcpy(_, _, _ unsafe.Pointer, _, _ int, _ Stream) error + func (k *ROCmKernels) OffsetMemcpyFP16(_, _, _ unsafe.Pointer, _, _ int, _ Stream) error + func (k *ROCmKernels) Pow(base, exp, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) PowScalar(_ unsafe.Pointer, _ float32, _ unsafe.Pointer, _ int, _ Stream) error + func (k *ROCmKernels) RMSNorm(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, ...) error + func (k *ROCmKernels) RMSNormFP16(_, _, _ unsafe.Pointer, _ float32, _, _ int, _ Stream) error + func (k *ROCmKernels) Repeat(_ unsafe.Pointer, _ unsafe.Pointer, _, _, _, _ int, _ Stream) error + func (k *ROCmKernels) ResetCounter(_ unsafe.Pointer, _ int, _ Stream) error + func (k *ROCmKernels) RoPESelect(_, _, _, _, _ unsafe.Pointer, _ int, _ Stream) error + func (k *ROCmKernels) Rsqrt(a, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) ScaledSoftmaxF32(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ Stream) error + func (k *ROCmKernels) ScaledSoftmaxFP16(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ Stream) error + func (k *ROCmKernels) SgemvM1(_, _, _ unsafe.Pointer, _, _ int, _ Stream) error + func (k *ROCmKernels) Sin(a, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) Softmax(input, output unsafe.Pointer, outer, inner, axisSize int, s Stream) error + func (k *ROCmKernels) Sqrt(a, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) Sub(a, b, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) SubBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, _ Stream) error + func (k *ROCmKernels) SubBroadcast4D(_, _, _ unsafe.Pointer, _, _, _, _, _, _, _, _, _, _, _, _ int, _ Stream) error + func (k *ROCmKernels) SubFP16(_, _, _ unsafe.Pointer, _ int, _ Stream) error + func (k *ROCmKernels) SubScalar(_ unsafe.Pointer, _ float32, _ unsafe.Pointer, _ int, _ Stream) error + func (k *ROCmKernels) SumAxis(input, output unsafe.Pointer, outer, inner, axisSize int, s Stream) error + func (k *ROCmKernels) Tanh(a, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) TanhPrime(a, upstream, c unsafe.Pointer, n int, s Stream) error + func (k *ROCmKernels) Transpose2D(input, output unsafe.Pointer, rows, cols int, _ Stream) error + func (k *ROCmKernels) TransposeND(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ...) error + type ROCmMemPool struct + func NewROCmMemPool() *ROCmMemPool + func NewROCmMemPoolFrom(pool *hip.MemPool) *ROCmMemPool + func (p *ROCmMemPool) Alloc(deviceID, byteSize int) (unsafe.Pointer, error) + func (p *ROCmMemPool) AllocManaged(_, _ int) (unsafe.Pointer, error) + func (p *ROCmMemPool) Drain() error + func (p *ROCmMemPool) Free(deviceID int, ptr unsafe.Pointer, byteSize int) + func (p *ROCmMemPool) FreeManaged(_ int, _ unsafe.Pointer, _ int) + func (p *ROCmMemPool) Inner() *hip.MemPool + func (p *ROCmMemPool) Stats() (int, int) + type ROCmRuntime struct + func NewROCmRuntime() *ROCmRuntime + func (r *ROCmRuntime) CreateStream() (Stream, error) + func (r *ROCmRuntime) DeviceType() device.Type + func (r *ROCmRuntime) Free(ptr unsafe.Pointer) error + func (r *ROCmRuntime) GetDeviceCount() (int, error) + func (r *ROCmRuntime) Malloc(byteSize int) (unsafe.Pointer, error) + func (r *ROCmRuntime) Memcpy(dst, src unsafe.Pointer, count int, kind MemcpyKind) error + func (r *ROCmRuntime) MemcpyAsync(dst, src unsafe.Pointer, count int, kind MemcpyKind, stream Stream) error + func (r *ROCmRuntime) MemcpyPeer(dst unsafe.Pointer, dstDevice int, src unsafe.Pointer, srcDevice int, ...) error + func (r *ROCmRuntime) SetDevice(deviceID int) error + type Runtime interface + CreateStream func() (Stream, error) + DeviceType func() device.Type + Free func(ptr unsafe.Pointer) error + GetDeviceCount func() (int, error) + Malloc func(byteSize int) (unsafe.Pointer, error) + Memcpy func(dst, src unsafe.Pointer, count int, kind MemcpyKind) error + MemcpyAsync func(dst, src unsafe.Pointer, count int, kind MemcpyKind, stream Stream) error + MemcpyPeer func(dst unsafe.Pointer, dstDevice int, src unsafe.Pointer, srcDevice int, ...) error + SetDevice func(deviceID int) error + type Stream interface + Destroy func() error + Ptr func() unsafe.Pointer + Synchronize func() error