kernels

package v0.3.0

Published: Mar 21, 2026 License: Apache-2.0 Imports: 5 Imported by: 0

Documentation

Overview

Package kernels provides Go wrappers for custom CUDA kernels. Build libkernels.a first:

	cd internal/cuda/kernels && make

All functional code requires the "cuda" build tag.

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func Add

func Add(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

Add launches the elementwise add kernel: c = a + b.

func AddBroadcast

func AddBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error

AddBroadcast launches the broadcast add kernel.

func AddBroadcast4D

func AddBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error

AddBroadcast4D launches the 4D broadcast add kernel.

func AddFP16

func AddFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

AddFP16 launches the FP16 elementwise add kernel: c = a + b.

func AddScalar

func AddScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error

AddScalar launches the scalar add kernel: c = a + scalar.

func Argmax

func Argmax(input unsafe.Pointer, result unsafe.Pointer,
	scratch unsafe.Pointer, n int, s unsafe.Pointer) error

Argmax launches the GPU argmax kernel. input: [n] float32 on device, result: single int32 on device, scratch: device temp storage of at least 2*ceil(n/256)*4 bytes.
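The documented scratch requirement can be computed host-side before allocating device memory. A minimal sketch (the helper name argmaxScratchBytes is illustrative, not part of this package):

```go
package main

import "fmt"

// argmaxScratchBytes returns the minimum device scratch size for Argmax,
// per the documented formula: 2 * ceil(n/256) * 4 bytes.
func argmaxScratchBytes(n int) int {
	blocks := (n + 255) / 256 // ceil(n/256)
	return 2 * blocks * 4
}

func main() {
	fmt.Println(argmaxScratchBytes(1))     // 8
	fmt.Println(argmaxScratchBytes(256))   // 8
	fmt.Println(argmaxScratchBytes(257))   // 16
	fmt.Println(argmaxScratchBytes(50000)) // 2*196*4 = 1568
}
```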

func Cos

func Cos(a, c unsafe.Pointer, n int, s unsafe.Pointer) error

Cos launches the elementwise cos kernel: c = cos(a).

func DequantFP8E4M3ToFP16

func DequantFP8E4M3ToFP16(input, output unsafe.Pointer, scale float32, n int, s unsafe.Pointer) error

DequantFP8E4M3ToFP16 launches the FP8 E4M3 -> FP16 dequantization kernel. input: n bytes of FP8 E4M3 data, output: n FP16 values, scale: per-tensor scale factor.

func DequantQ4KF32

func DequantQ4KF32(
	src, dst unsafe.Pointer,
	rows, K int,
	stream unsafe.Pointer,
) error

DequantQ4KF32 dequantizes Q4_K super-blocks to FP32 in global memory. src is raw Q4_K super-blocks, dst is [rows, K] FP32.

func Div

func Div(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

Div launches the elementwise divide kernel: c = a / b.

func DivBroadcast

func DivBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error

DivBroadcast launches the broadcast div kernel.

func DivBroadcast4D

func DivBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error

DivBroadcast4D launches the 4D broadcast div kernel.

func DivFP16

func DivFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

DivFP16 launches the FP16 elementwise divide kernel: c = a / b.

func DivScalar

func DivScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error

DivScalar launches the scalar divide kernel: c = a / scalar.

func Exp

func Exp(a, c unsafe.Pointer, n int, s unsafe.Pointer) error

Exp launches the elementwise exp kernel: c = exp(a).

func F32ToFP16

func F32ToFP16(src, dst unsafe.Pointer, n int, s unsafe.Pointer) error

F32ToFP16 converts n float32 elements to FP16 on GPU.

func FP4GemvF16 added in v0.3.0

func FP4GemvF16(
	wFP4, scales, x, y unsafe.Pointer,
	M, K int,
	stream unsafe.Pointer,
) error

FP4GemvF16 performs NVFP4 fused dequant-GEMV with FP16 activations:

y[m] = sum_k( dequant(W_fp4[m,k]) * x_fp16[k] )

W_fp4: device pointer to packed NVFP4 data [M, K] (8 bytes per block of 16).
scales: device pointer to [M * ceil(K/16)] float16 block scales.
x: device pointer to [K] float16 input vector.
y: device pointer to [M] float32 output vector.

func FP8Add

func FP8Add(a, b, c unsafe.Pointer, scaleA, scaleB float32, n int, s unsafe.Pointer) error

FP8Add launches the FP8 dequant+add kernel: c[i] = dequant(a[i])*scaleA + dequant(b[i])*scaleB. a, b: FP8 E4M3 inputs, c: FP16 output.

func FP8Gemm added in v0.3.0

func FP8Gemm(a, b, c unsafe.Pointer, m, k, n int, scaleA, scaleB float32, stream unsafe.Pointer) error

FP8Gemm launches an FP8 E4M3 GEMM using cublasLt. A: [M, K] FP8 E4M3, B: [K, N] FP8 E4M3, C: [M, N] FP16 output. scaleA and scaleB are per-tensor dequantization scales. The output is computed as: C = (scaleA * scaleB) * (A @ B) in FP16.
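The math FP8Gemm performs can be checked against a plain CPU GEMM. This sketch uses float32 stand-ins for the FP8/FP16 device buffers and is not part of the package API:

```go
package main

import "fmt"

// refScaledGemm mimics FP8Gemm's math on the CPU with float32 stand-ins:
// C[M,N] = (scaleA * scaleB) * (A[M,K] @ B[K,N]), row-major.
// The real kernel consumes FP8 E4M3 device buffers and writes FP16.
func refScaledGemm(a, b []float32, m, k, n int, scaleA, scaleB float32) []float32 {
	c := make([]float32, m*n)
	s := scaleA * scaleB
	for i := 0; i < m; i++ {
		for j := 0; j < n; j++ {
			var acc float32
			for p := 0; p < k; p++ {
				acc += a[i*k+p] * b[p*n+j]
			}
			c[i*n+j] = s * acc
		}
	}
	return c
}

func main() {
	a := []float32{1, 2, 3, 4} // 2x2
	b := []float32{1, 0, 0, 1} // identity
	// scaleA*scaleB = 1, so C = A
	fmt.Println(refScaledGemm(a, b, 2, 2, 2, 0.5, 2)) // [1 2 3 4]
}
```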

func FP8Mul

func FP8Mul(a, b, c unsafe.Pointer, scaleA, scaleB float32, n int, s unsafe.Pointer) error

FP8Mul launches the FP8 dequant+mul kernel: c[i] = dequant(a[i])*scaleA * dequant(b[i])*scaleB. a, b: FP8 E4M3 inputs, c: FP16 output.

func FP8RMSNorm

func FP8RMSNorm(input, weight, output unsafe.Pointer, scale, eps float32, rows, D int, s unsafe.Pointer) error

FP8RMSNorm launches the FP8 dequant+RMSNorm kernel. input: FP8 E4M3 [rows, D], weight: FP16 [D], output: FP16 [rows, D]. Dequantizes input on load, computes RMSNorm with FP32 accumulation, writes FP16.

func FP16ToF32

func FP16ToF32(src, dst unsafe.Pointer, n int, s unsafe.Pointer) error

FP16ToF32 converts n FP16 elements to float32 on GPU.

func Fill

func Fill(data unsafe.Pointer, value float32, n int, s unsafe.Pointer) error

Fill launches the fill kernel: sets all elements to value.

func FlashAttention2Decode added in v0.3.0

func FlashAttention2Decode(
	Q, K, V, O unsafe.Pointer,
	numBH, maxKVLen, headDim, kvLen int,
	kvLenPtr unsafe.Pointer,
	numQueryHeads, numKVHeads int,
	stream unsafe.Pointer,
) error

FlashAttention2Decode computes single-query attention for autoregressive decode using FlashAttention-2 with multi-warp KV parallelism. Supports GQA: numQueryHeads may differ from numKVHeads (must be a multiple).

Q: [batch*numQueryHeads, 1, headDim] -- single query per head.
K: [batch*numKVHeads, maxKVLen, headDim] -- pre-allocated KV cache buffer.
V: [batch*numKVHeads, maxKVLen, headDim] -- pre-allocated KV cache buffer.
O: [batch*numQueryHeads, 1, headDim] -- output.

kvLen is the actual KV sequence length (used when kvLenPtr is nil). kvLenPtr is a GPU-resident int32 pointer; when non-nil the kernel reads the KV length from GPU memory at runtime, making it compatible with CUDA graph replay (the value is not frozen at capture time).

func FlashAttention2Forward added in v0.3.0

func FlashAttention2Forward(
	Q, K, V, O unsafe.Pointer,
	batch, heads, seqLen, headDim int,
	causal bool,
	stream unsafe.Pointer,
) error

FlashAttention2Forward computes scaled dot-product attention using the FlashAttention-2 tiled algorithm. All tensors are [batch, heads, seq_len, head_dim] in row-major order. When causal is true, an upper-triangular mask is applied. Memory usage is O(N), not O(N^2).
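For intuition, a naive O(N^2)-memory CPU reference for a single (batch, head) slice, assuming the standard 1/sqrt(headDim) scaling; FlashAttention2Forward produces the same result tile-by-tile without materializing the score matrix. The function refAttention is illustrative, not part of this package:

```go
package main

import (
	"fmt"
	"math"
)

// refAttention computes softmax(Q K^T / sqrt(d)) V for one head on the CPU.
// q, k, v are [seqLen, d] row-major; causal applies an upper-triangular mask.
func refAttention(q, k, v []float64, seqLen, d int, causal bool) []float64 {
	out := make([]float64, seqLen*d)
	scale := 1 / math.Sqrt(float64(d))
	for i := 0; i < seqLen; i++ {
		limit := seqLen
		if causal {
			limit = i + 1 // attend only to positions <= i
		}
		scores := make([]float64, limit)
		maxS := math.Inf(-1)
		for j := 0; j < limit; j++ {
			var dot float64
			for c := 0; c < d; c++ {
				dot += q[i*d+c] * k[j*d+c]
			}
			scores[j] = dot * scale
			if scores[j] > maxS {
				maxS = scores[j]
			}
		}
		var sum float64
		for j := range scores {
			scores[j] = math.Exp(scores[j] - maxS) // stable softmax
			sum += scores[j]
		}
		for j := 0; j < limit; j++ {
			w := scores[j] / sum
			for c := 0; c < d; c++ {
				out[i*d+c] += w * v[j*d+c]
			}
		}
	}
	return out
}

func main() {
	// With causal masking, position 0 attends only to itself,
	// so its output row equals v's first row.
	q := []float64{1, 0, 0, 1}
	k := []float64{1, 0, 0, 1}
	v := []float64{2, 3, 4, 5}
	fmt.Println(refAttention(q, k, v, 2, 2, true))
}
```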

func FlashAttentionDecode

func FlashAttentionDecode(
	Q, K, V, O unsafe.Pointer,
	numBH, maxKVLen, headDim, kvLen int,
	kvLenPtr unsafe.Pointer,
	numQueryHeads, numKVHeads int,
	stream unsafe.Pointer,
) error

FlashAttentionDecode computes single-query attention for autoregressive decode. Supports GQA: numQueryHeads may differ from numKVHeads (must be a multiple).

Q: [batch*numQueryHeads, 1, headDim] -- single query per head.
K: [batch*numKVHeads, maxKVLen, headDim] -- pre-allocated KV cache buffer.
V: [batch*numKVHeads, maxKVLen, headDim] -- pre-allocated KV cache buffer.
O: [batch*numQueryHeads, 1, headDim] -- output.

kvLen is the actual KV sequence length (used when kvLenPtr is nil). kvLenPtr is a GPU-resident int32 pointer; when non-nil the kernel reads the KV length from GPU memory at runtime, making it compatible with CUDA graph replay (the value is not frozen at capture time).

func FlashAttentionForward

func FlashAttentionForward(
	Q, K, V, O unsafe.Pointer,
	batch, heads, seqLen, headDim int,
	causal bool,
	stream unsafe.Pointer,
) error

FlashAttentionForward computes scaled dot-product attention using a fused tiled kernel. All tensors are in [batch, heads, seq_len, head_dim] layout. When causal is true, an upper-triangular mask is applied.

func FusedAddRMSNormF32

func FusedAddRMSNormF32(input, residual, weight, normedOut, sumOut unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error

FusedAddRMSNormF32 performs fused residual add + RMSNorm in one kernel launch. input: [rows, D] (read-only), residual: [rows, D] (read-only), weight: [D], normedOut: [rows, D], sumOut: [rows, D].
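A per-row CPU sketch of the fused step, assuming (as the fusion implies) that sumOut receives input + residual and normedOut receives the RMSNorm of that sum, with eps added inside the square root; refAddRMSNorm is illustrative, not the package API:

```go
package main

import (
	"fmt"
	"math"
)

// refAddRMSNorm fuses residual add + RMSNorm for one row:
// sum = input + residual, then normed = sum / rms(sum) * weight,
// where rms(x) = sqrt(mean(x^2) + eps). Accumulation in float64.
func refAddRMSNorm(input, residual, weight []float32, eps float32) (normed, sum []float32) {
	d := len(input)
	sum = make([]float32, d)
	normed = make([]float32, d)
	var ss float64
	for i := 0; i < d; i++ {
		sum[i] = input[i] + residual[i]
		ss += float64(sum[i]) * float64(sum[i])
	}
	inv := float32(1 / math.Sqrt(ss/float64(d)+float64(eps)))
	for i := 0; i < d; i++ {
		normed[i] = sum[i] * inv * weight[i]
	}
	return normed, sum
}

func main() {
	normed, sum := refAddRMSNorm(
		[]float32{1, 1}, []float32{1, 1}, []float32{1, 1}, 0)
	fmt.Println(sum)    // [2 2]
	fmt.Println(normed) // rms([2,2]) = 2, so [1 1]
}
```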

func FusedNormAddF32

func FusedNormAddF32(input, weight, residual, output unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error

FusedNormAddF32 applies RMSNorm then adds residual in one kernel launch. output = rmsnorm(input, weight, eps) + residual. input: [rows, D], weight: [D], residual: [rows, D], output: [rows, D].

func FusedQKNormRoPEF32

func FusedQKNormRoPEF32(input, weightQ, weightK, cosAngles, sinAngles, output unsafe.Pointer, eps float32, totalHeads, headDim, numQHeads, halfRotary int, s unsafe.Pointer) error

FusedQKNormRoPEF32 applies per-head RMSNorm + RoPE to combined Q+K heads. input: [totalHeads, headDim], weightQ/weightK: [headDim], cosAngles/sinAngles: [halfRotary], output: [totalHeads, headDim].

func FusedRoPEF32

func FusedRoPEF32(
	input, cosAngles, sinAngles, output unsafe.Pointer,
	batch, seqLen, headDim, halfRotary, cosStride int,
	stream unsafe.Pointer,
) error

FusedRoPEF32 applies fused rotary positional embedding (RoPE) to FP32 data.

func FusedSwiGLUF32

func FusedSwiGLUF32(
	w1, w3, output unsafe.Pointer,
	n int,
	stream unsafe.Pointer,
) error

FusedSwiGLUF32 applies fused SwiGLU activation: output[i] = w1[i] * sigmoid(w1[i]) * w3[i].
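The documented formula is silu(w1) * w3; a direct CPU mirror (refSwiGLU is illustrative, not part of the package):

```go
package main

import (
	"fmt"
	"math"
)

// refSwiGLU mirrors the documented formula on the CPU:
// output[i] = w1[i] * sigmoid(w1[i]) * w3[i], i.e. silu(w1) * w3.
func refSwiGLU(w1, w3 []float64) []float64 {
	out := make([]float64, len(w1))
	for i := range w1 {
		sig := 1 / (1 + math.Exp(-w1[i]))
		out[i] = w1[i] * sig * w3[i]
	}
	return out
}

func main() {
	// silu(0) = 0, so the first output element is exactly 0.
	fmt.Println(refSwiGLU([]float64{0, 1}, []float64{5, 1}))
}
```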

func Gather

func Gather(table unsafe.Pointer, indices unsafe.Pointer,
	output unsafe.Pointer, N, D, V int, s unsafe.Pointer) error

Gather launches the embedding table gather kernel with int64 indices. table: [V, D], indices: [N] int64, output: [N, D].

func GatherI32

func GatherI32(table unsafe.Pointer, indices unsafe.Pointer,
	output unsafe.Pointer, N, D, V int, s unsafe.Pointer) error

GatherI32 launches the embedding table gather kernel with int32 indices. table: [V, D], indices: [N] int32, output: [N, D].

func GemmQ4F32

func GemmQ4F32(
	A_q4, B, C unsafe.Pointer,
	M, K, N, dataOffset int,
	stream unsafe.Pointer,
) error

GemmQ4F32 performs Q4_0 dequant-GEMM: C = dequant(A_q4) * B. A_q4 is in GPU separated layout (scales then data), B is [K, N] FP32, C is [M, N] FP32. dataOffset is the byte offset from A_q4 to the packed data region.

func GemmQ8F32

func GemmQ8F32(
	A_q8, B, C unsafe.Pointer,
	M, K, N int,
	stream unsafe.Pointer,
) error

GemmQ8F32 performs Q8_0 dequant-GEMM: C = dequant(A_q8) * B. A_q8 is packed Q8_0 blocks, B is [K, N] FP32, C is [M, N] FP32.

func GemvQ4KDp4aF32 added in v0.3.0

func GemvQ4KDp4aF32(
	W_q4k, x, y unsafe.Pointer,
	M, K int,
	stream unsafe.Pointer,
) error

GemvQ4KDp4aF32 performs Q4_K fused dequant-GEMV using dp4a INT8 dot-product. Same interface as GemvQ4KF32 but uses __dp4a for higher throughput.

func GemvQ4KDp4aF32Available added in v0.3.0

func GemvQ4KDp4aF32Available() bool

GemvQ4KDp4aF32Available reports whether the dp4a INT8 Q4_K GEMV kernel is loaded.

func GemvQ4KF32

func GemvQ4KF32(
	W_q4k, x, y unsafe.Pointer,
	M, K int,
	stream unsafe.Pointer,
) error

GemvQ4KF32 performs Q4_K fused dequant-GEMV: y = dequant(W_q4k) * x. W_q4k is raw Q4_K super-blocks, x is [K] FP32, y is [M] FP32.

func GemvQ4KSm121F32 added in v0.3.0

func GemvQ4KSm121F32(
	W_q4k, x, y unsafe.Pointer,
	M, K int,
	stream unsafe.Pointer,
) error

GemvQ4KSm121F32 performs Q4_K fused dequant-GEMV using the sm_121 optimized kernel (8 warps/block, vectorized 128-bit loads, __ldcg activation caching).

Falls back to GemvQ4KF32 when the sm_121 kernel is unavailable. K must be a multiple of 256.

func GemvQ5KF32 added in v0.3.0

func GemvQ5KF32(
	W_q5k, x, y unsafe.Pointer,
	M, K int,
	stream unsafe.Pointer,
) error

GemvQ5KF32 performs Q5_K fused dequant-GEMV: y = dequant(W_q5k) * x. W_q5k is raw Q5_K super-blocks, x is [K] FP32, y is [M] FP32.

func GemvQ5_0F32 added in v0.3.0

func GemvQ5_0F32(
	W_q5_0, x, y unsafe.Pointer,
	M, K int,
	stream unsafe.Pointer,
) error

GemvQ5_0F32 performs Q5_0 fused dequant-GEMV: y = dequant(W_q5_0) * x. W_q5_0 is raw Q5_0 blocks, x is [K] FP32, y is [M] FP32.

func GemvQ6KF32 added in v0.3.0

func GemvQ6KF32(
	W_q6k, x, y unsafe.Pointer,
	M, K int,
	stream unsafe.Pointer,
) error

GemvQ6KF32 performs Q6_K fused dequant-GEMV: y = dequant(W_q6k) * x. W_q6k is raw Q6_K super-blocks, x is [K] FP32, y is [M] FP32.

func GemvWarpF16 added in v0.3.0

func GemvWarpF16(y, A, x unsafe.Pointer, M, N int, s unsafe.Pointer) error

GemvWarpF16 computes y = A*x using the warp-specialized GEMV kernel (FP16). Each warp handles a different output row tile for decode-phase (batch=1) workloads. y[M], A[M x N] row-major, x[N]. All FP16. Accumulation in FP32 for precision.

func GemvWarpF32 added in v0.3.0

func GemvWarpF32(y, A, x unsafe.Pointer, M, N int, s unsafe.Pointer) error

GemvWarpF32 computes y = A*x using the warp-specialized GEMV kernel (FP32). Each warp handles a different output row tile for decode-phase (batch=1) workloads. y[M], A[M x N] row-major, x[N]. All FP32.
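The layout contract (y[M], A[M x N] row-major, x[N]) is easy to validate against a scalar CPU GEMV; refGemv is an illustrative reference, not the package API:

```go
package main

import "fmt"

// refGemv computes y = A*x on the CPU with the layout GemvWarpF32
// documents: y[M], A[M x N] row-major, x[N].
func refGemv(a, x []float32, m, n int) []float32 {
	y := make([]float32, m)
	for i := 0; i < m; i++ {
		var acc float32
		for j := 0; j < n; j++ {
			acc += a[i*n+j] * x[j]
		}
		y[i] = acc
	}
	return y
}

func main() {
	a := []float32{1, 2, 3, 4, 5, 6} // 2x3
	x := []float32{1, 1, 1}
	fmt.Println(refGemv(a, x, 2, 3)) // [6 15]
}
```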

func IncrementCounter

func IncrementCounter(counter unsafe.Pointer, delta int, s unsafe.Pointer) error

IncrementCounter atomically increments a GPU-resident int32 by delta.

func IsFP4GemvSupported added in v0.3.0

func IsFP4GemvSupported() bool

IsFP4GemvSupported returns true if the current GPU supports NVFP4 GEMV (requires sm_100+ Blackwell architecture).

func IsFP8GemmSupported added in v0.3.0

func IsFP8GemmSupported() bool

IsFP8GemmSupported returns true if the current GPU supports FP8 GEMM (requires sm_89+ Ada Lovelace architecture).

func IsPagedAttentionSupported added in v0.3.0

func IsPagedAttentionSupported() bool

IsPagedAttentionSupported returns true if the paged attention kernel symbol was loaded from libkernels.so.

func IsQ4KSm121Supported added in v0.3.0

func IsQ4KSm121Supported() bool

IsQ4KSm121Supported reports whether the loaded kernel library contains the sm_121 optimized Q4_K GEMV kernel AND the current GPU is sm_12x (Blackwell). The result is cached after the first call.

func IsRaggedAttentionSupported added in v0.3.0

func IsRaggedAttentionSupported() bool

IsRaggedAttentionSupported returns true if the ragged attention kernel symbol was loaded from libkernels.so.

func IsSelectiveScanSupported added in v0.3.0

func IsSelectiveScanSupported() bool

IsSelectiveScanSupported reports whether the selective scan kernel is available.

func Log

func Log(a, c unsafe.Pointer, n int, s unsafe.Pointer) error

Log launches the elementwise log kernel: c = log(a).

func Mul

func Mul(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

Mul launches the elementwise multiply kernel: c = a * b.

func MulBroadcast

func MulBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error

MulBroadcast launches the broadcast mul kernel.

func MulBroadcast4D

func MulBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error

MulBroadcast4D launches the 4D broadcast mul kernel.

func MulFP16

func MulFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

MulFP16 launches the FP16 elementwise multiply kernel: c = a * b.

func MulScalar

func MulScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error

MulScalar launches the scalar multiply kernel: c = a * scalar.

func OffsetMemcpy

func OffsetMemcpy(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s unsafe.Pointer) error

OffsetMemcpy copies dim floats from src to dst at offset counter*dim. counter is a GPU-resident int32. Used for GPU-driven KV cache append.

func OffsetMemcpyFP16

func OffsetMemcpyFP16(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s unsafe.Pointer) error

OffsetMemcpyFP16 copies dim floats from F32 src to FP16 dst at offset counter*dim. counter is a GPU-resident int32. Used for GPU-driven FP16 KV cache append.

func PagedAttentionForward added in v0.3.0

func PagedAttentionForward(
	Q, O unsafe.Pointer,
	blockPtrsK, blockPtrsV unsafe.Pointer,
	blockIndices unsafe.Pointer,
	seqLen, blockSize, headDim int,
	numQHeads, numKVHeads int,
	batch int,
	stream unsafe.Pointer,
) error

PagedAttentionForward computes scaled dot-product attention with block-table indirection for paged KV caches.

Q: [batch*numQHeads, headDim] -- single query per head.
O: [batch*numQHeads, headDim] -- output, same shape as Q.
blockPtrsK: device array of float* pointers to K blocks; each block holds [blockSize, numKVHeads, headDim] floats.
blockPtrsV: device array of float* pointers to V blocks (same layout).
blockIndices: device array [batch * maxNumBlocks] mapping logical block index to physical block index in the blockPtrs arrays.
seqLen: actual number of valid K/V token positions.
blockSize: number of token positions per block.
headDim: dimension per head.
numQHeads: number of query heads per batch element.
numKVHeads: number of KV heads per batch element.
batch: number of sequences in the batch.

func Pow

func Pow(base, exp, c unsafe.Pointer, n int, s unsafe.Pointer) error

Pow launches the elementwise power kernel: c = base ^ exp.

func PowScalar

func PowScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error

PowScalar launches the scalar power kernel: c = pow(a, scalar).

func RMSNorm

func RMSNorm(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error

RMSNorm launches the fused RMSNorm kernel. input: [rows, D], weight: [D], output: [rows, D], scales: [rows].

func RMSNormFP16

func RMSNormFP16(input, weight, output unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error

RMSNormFP16 launches the FP16 RMSNorm kernel with FP32 accumulation. input: [rows, D], weight: [D], output: [rows, D].

func RaggedAttentionForward added in v0.3.0

func RaggedAttentionForward(
	Q, K, V, O unsafe.Pointer,
	seqLens, cumSeqLens unsafe.Pointer,
	batch, numQHeads, numKVHeads, headDim int,
	stream unsafe.Pointer,
) error

RaggedAttentionForward computes scaled dot-product attention for variable-length sequences packed into a single batch (ragged batching). A block-diagonal attention mask prevents cross-sequence attention.

Q: [totalTokens * numQHeads, headDim] -- packed queries.
K: [totalTokens * numKVHeads, headDim] -- packed keys.
V: [totalTokens * numKVHeads, headDim] -- packed values.
O: [totalTokens * numQHeads, headDim] -- output.
seqLens: [batch] int32 -- actual sequence length for each sequence.
cumSeqLens: [batch] int32 -- cumulative offsets (prefix sums, first = 0).
batch: number of sequences.
numQHeads: number of query heads.
numKVHeads: number of KV heads.
headDim: dimension per head.
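The cumSeqLens array is a host-side prefix sum over seqLens with first element 0, per the parameter description above. A minimal sketch (the helper name cumSeqLens is illustrative):

```go
package main

import "fmt"

// cumSeqLens builds the cumulative-offset array RaggedAttentionForward
// expects from per-sequence lengths: prefix sums with first element 0.
func cumSeqLens(seqLens []int32) []int32 {
	cum := make([]int32, len(seqLens))
	var off int32
	for i, n := range seqLens {
		cum[i] = off
		off += n
	}
	return cum
}

func main() {
	fmt.Println(cumSeqLens([]int32{3, 5, 2})) // [0 3 8]
}
```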

func Repeat

func Repeat(src, dst unsafe.Pointer, outerSize, axisDim, innerSize, reps int, s unsafe.Pointer) error

Repeat launches the repeat kernel: replicates axisDim elements along an axis. outerSize = product of dims before axis, axisDim = size of axis, innerSize = product of dims after axis.

func ResetCounter

func ResetCounter(counter unsafe.Pointer, value int, s unsafe.Pointer) error

ResetCounter sets a GPU-resident int32 to value.

func RoPESelect

func RoPESelect(cosTable, sinTable, cosOut, sinOut, counter unsafe.Pointer,
	halfRotary int, s unsafe.Pointer) error

RoPESelect copies halfRotary cos/sin values from the precomputed table at position counter[0]. Used for GPU-driven RoPE angle selection.

func Rsqrt

func Rsqrt(a, c unsafe.Pointer, n int, s unsafe.Pointer) error

Rsqrt launches the elementwise rsqrt kernel: c = 1/sqrt(a).

func ScaledSoftmaxF32

func ScaledSoftmaxF32(
	input, output unsafe.Pointer,
	outer, inner, axisSize int,
	scale float32,
	stream unsafe.Pointer,
) error

ScaledSoftmaxF32 applies fused scaled softmax: output = softmax(input * scale).
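A per-slice CPU reference of softmax(input * scale), with the usual max-subtraction for numerical stability (an assumption about the kernel's internals; the result is the same either way). refScaledSoftmax is illustrative, not the package API:

```go
package main

import (
	"fmt"
	"math"
)

// refScaledSoftmax applies softmax(x * scale) along a 1-D slice,
// subtracting the max before exponentiating for numerical stability.
func refScaledSoftmax(x []float64, scale float64) []float64 {
	out := make([]float64, len(x))
	maxV := math.Inf(-1)
	for _, v := range x {
		if v*scale > maxV {
			maxV = v * scale
		}
	}
	var sum float64
	for i, v := range x {
		out[i] = math.Exp(v*scale - maxV)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

func main() {
	// scale 0 flattens the logits, so the distribution is uniform.
	fmt.Println(refScaledSoftmax([]float64{1, 2, 3}, 0)) // [1/3 1/3 1/3]
}
```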

func ScaledSoftmaxFP16

func ScaledSoftmaxFP16(
	input, output unsafe.Pointer,
	outer, inner, axisSize int,
	scale float32,
	stream unsafe.Pointer,
) error

ScaledSoftmaxFP16 applies fused scaled softmax on FP16 data with FP32 accumulation.

func SelectiveScanForward added in v0.3.0

func SelectiveScanForward(x, A, B, C, D, y unsafe.Pointer,
	batch, dModel, dState, seqLen int, s unsafe.Pointer) error

SelectiveScanForward launches the GPU selective scan kernel for Mamba/SSM.

x: [batch, d_model, seq_len] input on device.
A: [d_model, d_state] state matrix on device.
B: [batch, d_state, seq_len] input-dependent state on device.
C: [batch, d_state, seq_len] output-dependent state on device.
D: [d_model] skip connection on device (may be nil).
y: [batch, d_model, seq_len] output on device.

func SgemvM1

func SgemvM1(y, A, x unsafe.Pointer, M, N int, s unsafe.Pointer) error

SgemvM1 computes y = A*x for M=1 decode (single-token GEMV). y[M], A[M x N] row-major, x[N]. All FP32.

func Sin

func Sin(a, c unsafe.Pointer, n int, s unsafe.Pointer) error

Sin launches the elementwise sin kernel: c = sin(a).

func Softmax

func Softmax(input, output unsafe.Pointer, outer, inner, axisSize int, s unsafe.Pointer) error

Softmax launches the softmax kernel along an axis.

func Sqrt

func Sqrt(a, c unsafe.Pointer, n int, s unsafe.Pointer) error

Sqrt launches the elementwise sqrt kernel: c = sqrt(a).

func Sub

func Sub(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

Sub launches the elementwise subtract kernel: c = a - b.

func SubBroadcast

func SubBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error

SubBroadcast launches the broadcast sub kernel.

func SubBroadcast4D

func SubBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error

SubBroadcast4D launches the 4D broadcast sub kernel.

func SubFP16

func SubFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

SubFP16 launches the FP16 elementwise subtract kernel: c = a - b.

func SubScalar

func SubScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error

SubScalar launches the scalar subtract kernel: c = a - scalar.

func SumAxis

func SumAxis(input, output unsafe.Pointer, outer, inner, axisSize int, s unsafe.Pointer) error

SumAxis launches the sum-reduction kernel along an axis.

func Tanh

func Tanh(a, c unsafe.Pointer, n int, s unsafe.Pointer) error

Tanh launches the elementwise tanh kernel: c = tanh(a).

func TanhPrime

func TanhPrime(a, upstream, c unsafe.Pointer, n int, s unsafe.Pointer) error

TanhPrime launches the tanh derivative kernel: c = (1 - tanh(a)^2) * upstream.

func Transpose2D

func Transpose2D(input, output unsafe.Pointer, rows, cols int, s unsafe.Pointer) error

Transpose2D launches the tiled 2D transpose kernel. Input: [rows, cols] -> Output: [cols, rows].

func TransposeND

func TransposeND(input, output unsafe.Pointer,
	inStrides, outStrides, perm []int32,
	ndim, total int, s unsafe.Pointer) error

TransposeND launches the general N-D transpose kernel.
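The stride arithmetic behind a general N-D transpose can be sketched on the CPU: decompose each linear output index into output coordinates, then gather from the input using the input strides of the permuted axes. refTransposeND is an illustrative reference, not the package API:

```go
package main

import "fmt"

// refTransposeND permutes axes via strides -- the per-element index math a
// general N-D transpose kernel performs. in is row-major with the given
// shape; perm[i] names which input axis becomes output axis i.
func refTransposeND(in []float32, shape, perm []int) []float32 {
	ndim := len(shape)
	inStrides := make([]int, ndim) // row-major input strides
	s := 1
	for i := ndim - 1; i >= 0; i-- {
		inStrides[i] = s
		s *= shape[i]
	}
	outShape := make([]int, ndim) // output shape = permuted input shape
	for i, p := range perm {
		outShape[i] = shape[p]
	}
	out := make([]float32, len(in))
	for o := range out {
		rem, inIdx := o, 0
		for i := ndim - 1; i >= 0; i-- { // peel output coords, innermost first
			coord := rem % outShape[i]
			rem /= outShape[i]
			inIdx += coord * inStrides[perm[i]]
		}
		out[o] = in[inIdx]
	}
	return out
}

func main() {
	// 2x3 matrix transposed to 3x2
	in := []float32{1, 2, 3, 4, 5, 6}
	fmt.Println(refTransposeND(in, []int{2, 3}, []int{1, 0})) // [1 4 2 5 3 6]
}
```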

Types

type KernelLib

type KernelLib struct {
	// contains filtered or unexported fields
}

KernelLib holds dlopen'd function pointers for custom CUDA kernels compiled into libkernels.so.
