Documentation
¶
Overview ¶
Package kernels provides Go wrappers for custom CUDA kernels. Build libkernels.a first: `cd internal/cuda/kernels && make`. All functional code requires the "cuda" build tag.
Index ¶
- func Add(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func AddBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
- func AddBroadcast4D(a, b, c unsafe.Pointer, ...) error
- func AddFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func AddScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func Argmax(input unsafe.Pointer, result unsafe.Pointer, scratch unsafe.Pointer, n int, ...) error
- func Cos(a, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func DequantFP8E4M3ToFP16(input, output unsafe.Pointer, scale float32, n int, s unsafe.Pointer) error
- func DequantQ4KF32(src, dst unsafe.Pointer, rows, K int, stream unsafe.Pointer) error
- func Div(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func DivBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
- func DivBroadcast4D(a, b, c unsafe.Pointer, ...) error
- func DivFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func DivScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func Exp(a, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func F32ToFP16(src, dst unsafe.Pointer, n int, s unsafe.Pointer) error
- func FP4GemvF16(wFP4, scales, x, y unsafe.Pointer, M, K int, stream unsafe.Pointer) error
- func FP8Add(a, b, c unsafe.Pointer, scaleA, scaleB float32, n int, s unsafe.Pointer) error
- func FP8Gemm(a, b, c unsafe.Pointer, m, k, n int, scaleA, scaleB float32, ...) error
- func FP8Mul(a, b, c unsafe.Pointer, scaleA, scaleB float32, n int, s unsafe.Pointer) error
- func FP8RMSNorm(input, weight, output unsafe.Pointer, scale, eps float32, rows, D int, ...) error
- func FP16ToF32(src, dst unsafe.Pointer, n int, s unsafe.Pointer) error
- func Fill(data unsafe.Pointer, value float32, n int, s unsafe.Pointer) error
- func FlashAttention2Decode(Q, K, V, O unsafe.Pointer, numBH, maxKVLen, headDim, kvLen int, ...) error
- func FlashAttention2Forward(Q, K, V, O unsafe.Pointer, batch, heads, seqLen, headDim int, causal bool, ...) error
- func FlashAttentionDecode(Q, K, V, O unsafe.Pointer, numBH, maxKVLen, headDim, kvLen int, ...) error
- func FlashAttentionForward(Q, K, V, O unsafe.Pointer, batch, heads, seqLen, headDim int, causal bool, ...) error
- func FusedAddRMSNormF32(input, residual, weight, normedOut, sumOut unsafe.Pointer, eps float32, ...) error
- func FusedNormAddF32(input, weight, residual, output unsafe.Pointer, eps float32, rows, D int, ...) error
- func FusedQKNormRoPEF32(input, weightQ, weightK, cosAngles, sinAngles, output unsafe.Pointer, ...) error
- func FusedRoPEF32(input, cosAngles, sinAngles, output unsafe.Pointer, ...) error
- func FusedSwiGLUF32(w1, w3, output unsafe.Pointer, n int, stream unsafe.Pointer) error
- func Gather(table unsafe.Pointer, indices unsafe.Pointer, output unsafe.Pointer, ...) error
- func GatherI32(table unsafe.Pointer, indices unsafe.Pointer, output unsafe.Pointer, ...) error
- func GemmQ4F32(A_q4, B, C unsafe.Pointer, M, K, N, dataOffset int, stream unsafe.Pointer) error
- func GemmQ8F32(A_q8, B, C unsafe.Pointer, M, K, N int, stream unsafe.Pointer) error
- func GemvQ4KDp4aF32(W_q4k, x, y unsafe.Pointer, M, K int, stream unsafe.Pointer) error
- func GemvQ4KDp4aF32Available() bool
- func GemvQ4KF32(W_q4k, x, y unsafe.Pointer, M, K int, stream unsafe.Pointer) error
- func GemvQ4KSm121F32(W_q4k, x, y unsafe.Pointer, M, K int, stream unsafe.Pointer) error
- func GemvQ5KF32(W_q5k, x, y unsafe.Pointer, M, K int, stream unsafe.Pointer) error
- func GemvQ5_0F32(W_q5_0, x, y unsafe.Pointer, M, K int, stream unsafe.Pointer) error
- func GemvQ6KF32(W_q6k, x, y unsafe.Pointer, M, K int, stream unsafe.Pointer) error
- func GemvWarpF16(y, A, x unsafe.Pointer, M, N int, s unsafe.Pointer) error
- func GemvWarpF32(y, A, x unsafe.Pointer, M, N int, s unsafe.Pointer) error
- func IncrementCounter(counter unsafe.Pointer, delta int, s unsafe.Pointer) error
- func IsFP4GemvSupported() bool
- func IsFP8GemmSupported() bool
- func IsPagedAttentionSupported() bool
- func IsQ4KSm121Supported() bool
- func IsRaggedAttentionSupported() bool
- func IsSelectiveScanSupported() bool
- func Log(a, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func Mul(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func MulBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
- func MulBroadcast4D(a, b, c unsafe.Pointer, ...) error
- func MulFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func MulScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func OffsetMemcpy(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s unsafe.Pointer) error
- func OffsetMemcpyFP16(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s unsafe.Pointer) error
- func PagedAttentionForward(Q, O unsafe.Pointer, blockPtrsK, blockPtrsV unsafe.Pointer, ...) error
- func Pow(base, exp, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func PowScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func RMSNorm(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, ...) error
- func RMSNormFP16(input, weight, output unsafe.Pointer, eps float32, rows, D int, ...) error
- func RaggedAttentionForward(Q, K, V, O unsafe.Pointer, seqLens, cumSeqLens unsafe.Pointer, ...) error
- func Repeat(src, dst unsafe.Pointer, outerSize, axisDim, innerSize, reps int, ...) error
- func ResetCounter(counter unsafe.Pointer, value int, s unsafe.Pointer) error
- func RoPESelect(cosTable, sinTable, cosOut, sinOut, counter unsafe.Pointer, halfRotary int, ...) error
- func Rsqrt(a, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func ScaledSoftmaxF32(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, ...) error
- func ScaledSoftmaxFP16(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, ...) error
- func SelectiveScanForward(x, A, B, C, D, y unsafe.Pointer, batch, dModel, dState, seqLen int, ...) error
- func SgemvM1(y, A, x unsafe.Pointer, M, N int, s unsafe.Pointer) error
- func Sin(a, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func Softmax(input, output unsafe.Pointer, outer, inner, axisSize int, s unsafe.Pointer) error
- func Sqrt(a, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func Sub(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func SubBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
- func SubBroadcast4D(a, b, c unsafe.Pointer, ...) error
- func SubFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func SubScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func SumAxis(input, output unsafe.Pointer, outer, inner, axisSize int, s unsafe.Pointer) error
- func Tanh(a, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func TanhPrime(a, upstream, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func Transpose2D(input, output unsafe.Pointer, rows, cols int, s unsafe.Pointer) error
- func TransposeND(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ...) error
- type KernelLib
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func AddBroadcast ¶
func AddBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
AddBroadcast launches the broadcast add kernel.
func AddBroadcast4D ¶
func AddBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error
AddBroadcast4D launches the 4D broadcast add kernel.
func Argmax ¶
func Argmax(input unsafe.Pointer, result unsafe.Pointer, scratch unsafe.Pointer, n int, s unsafe.Pointer) error
Argmax launches the GPU argmax kernel. input: [n] float32 on device, result: single int32 on device, scratch: device temp storage of at least 2*ceil(n/256)*4 bytes.
func DequantFP8E4M3ToFP16 ¶
func DequantFP8E4M3ToFP16(input, output unsafe.Pointer, scale float32, n int, s unsafe.Pointer) error
DequantFP8E4M3ToFP16 launches the FP8 E4M3 -> FP16 dequantization kernel. input: n bytes of FP8 E4M3 data, output: n FP16 values, scale: per-tensor scale factor.
func DequantQ4KF32 ¶
DequantQ4KF32 dequantizes Q4_K super-blocks to FP32 in global memory. src is raw Q4_K super-blocks, dst is [rows, K] FP32.
func DivBroadcast ¶
func DivBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
DivBroadcast launches the broadcast div kernel.
func DivBroadcast4D ¶
func DivBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error
DivBroadcast4D launches the 4D broadcast div kernel.
func FP4GemvF16 ¶ added in v0.3.0
FP4GemvF16 performs NVFP4 fused dequant-GEMV with FP16 activations:
y[m] = sum_k( dequant(W_fp4[m,k]) * x_fp16[k] )
W_fp4: device pointer to packed NVFP4 data [M, K] (8 bytes per block of 16). scales: device pointer to [M * ceil(K/16)] float16 block scales. x: device pointer to [K] float16 input vector. y: device pointer to [M] float32 output vector.
func FP8Add ¶
FP8Add launches the FP8 dequant+add kernel: c[i] = dequant(a[i])*scaleA + dequant(b[i])*scaleB. a, b: FP8 E4M3 inputs, c: FP16 output.
func FP8Gemm ¶ added in v0.3.0
func FP8Gemm(a, b, c unsafe.Pointer, m, k, n int, scaleA, scaleB float32, stream unsafe.Pointer) error
FP8Gemm launches an FP8 E4M3 GEMM using cublasLt. A: [M, K] FP8 E4M3, B: [K, N] FP8 E4M3, C: [M, N] FP16 output. scaleA and scaleB are per-tensor dequantization scales. The output is computed as: C = (scaleA * scaleB) * (A @ B) in FP16.
func FP8Mul ¶
FP8Mul launches the FP8 dequant+mul kernel: c[i] = dequant(a[i])*scaleA * dequant(b[i])*scaleB. a, b: FP8 E4M3 inputs, c: FP16 output.
func FP8RMSNorm ¶
func FP8RMSNorm(input, weight, output unsafe.Pointer, scale, eps float32, rows, D int, s unsafe.Pointer) error
FP8RMSNorm launches the FP8 dequant+RMSNorm kernel. input: FP8 E4M3 [rows, D], weight: FP16 [D], output: FP16 [rows, D]. Dequantizes input on load, computes RMSNorm with FP32 accumulation, writes FP16.
func FlashAttention2Decode ¶ added in v0.3.0
func FlashAttention2Decode( Q, K, V, O unsafe.Pointer, numBH, maxKVLen, headDim, kvLen int, kvLenPtr unsafe.Pointer, numQueryHeads, numKVHeads int, stream unsafe.Pointer, ) error
FlashAttention2Decode computes single-query attention for autoregressive decode using FlashAttention-2 with multi-warp KV parallelism. Supports GQA: numQueryHeads may differ from numKVHeads (numQueryHeads must be a multiple of numKVHeads).
Q: [batch*numQueryHeads, 1, headDim] -- single query per head. K: [batch*numKVHeads, maxKVLen, headDim] -- pre-allocated KV cache buffer. V: [batch*numKVHeads, maxKVLen, headDim] O: [batch*numQueryHeads, 1, headDim] -- output.
kvLen is the actual KV sequence length (used when kvLenPtr is nil). kvLenPtr is a GPU-resident int32 pointer; when non-nil the kernel reads the KV length from GPU memory at runtime, making it compatible with CUDA graph replay (the value is not frozen at capture time).
func FlashAttention2Forward ¶ added in v0.3.0
func FlashAttention2Forward( Q, K, V, O unsafe.Pointer, batch, heads, seqLen, headDim int, causal bool, stream unsafe.Pointer, ) error
FlashAttention2Forward computes scaled dot-product attention using the FlashAttention-2 tiled algorithm. All tensors are [batch, heads, seq_len, head_dim] in row-major order. When causal is true, an upper-triangular mask is applied. Memory usage is O(N), not O(N^2).
func FlashAttentionDecode ¶
func FlashAttentionDecode( Q, K, V, O unsafe.Pointer, numBH, maxKVLen, headDim, kvLen int, kvLenPtr unsafe.Pointer, numQueryHeads, numKVHeads int, stream unsafe.Pointer, ) error
FlashAttentionDecode computes single-query attention for autoregressive decode. Supports GQA: numQueryHeads may differ from numKVHeads (numQueryHeads must be a multiple of numKVHeads).
Q: [batch*numQueryHeads, 1, headDim] -- single query per head. K: [batch*numKVHeads, maxKVLen, headDim] -- pre-allocated KV cache buffer. V: [batch*numKVHeads, maxKVLen, headDim] O: [batch*numQueryHeads, 1, headDim] -- output.
kvLen is the actual KV sequence length (used when kvLenPtr is nil). kvLenPtr is a GPU-resident int32 pointer; when non-nil the kernel reads the KV length from GPU memory at runtime, making it compatible with CUDA graph replay (the value is not frozen at capture time).
func FlashAttentionForward ¶
func FlashAttentionForward( Q, K, V, O unsafe.Pointer, batch, heads, seqLen, headDim int, causal bool, stream unsafe.Pointer, ) error
FlashAttentionForward computes scaled dot-product attention using a fused tiled kernel. All tensors are in [batch, heads, seq_len, head_dim] layout. When causal is true, an upper-triangular mask is applied.
func FusedAddRMSNormF32 ¶
func FusedAddRMSNormF32(input, residual, weight, normedOut, sumOut unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error
FusedAddRMSNormF32 performs fused residual add + RMSNorm in one kernel launch. input: [rows, D] (read-only), residual: [rows, D] (read-only), weight: [D], normedOut: [rows, D], sumOut: [rows, D].
func FusedNormAddF32 ¶
func FusedNormAddF32(input, weight, residual, output unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error
FusedNormAddF32 applies RMSNorm then adds residual in one kernel launch. output = rmsnorm(input, weight, eps) + residual. input: [rows, D], weight: [D], residual: [rows, D], output: [rows, D].
func FusedQKNormRoPEF32 ¶
func FusedQKNormRoPEF32(input, weightQ, weightK, cosAngles, sinAngles, output unsafe.Pointer, eps float32, totalHeads, headDim, numQHeads, halfRotary int, s unsafe.Pointer) error
FusedQKNormRoPEF32 applies per-head RMSNorm + RoPE to combined Q+K heads. input: [totalHeads, headDim], weightQ/weightK: [headDim], cosAngles/sinAngles: [halfRotary], output: [totalHeads, headDim].
func FusedRoPEF32 ¶
func FusedRoPEF32( input, cosAngles, sinAngles, output unsafe.Pointer, batch, seqLen, headDim, halfRotary, cosStride int, stream unsafe.Pointer, ) error
FusedRoPEF32 applies fused rotary positional embedding (RoPE) to FP32 data.
func FusedSwiGLUF32 ¶
FusedSwiGLUF32 applies fused SwiGLU activation: output[i] = w1[i] * sigmoid(w1[i]) * w3[i].
func Gather ¶
func Gather(table unsafe.Pointer, indices unsafe.Pointer, output unsafe.Pointer, N, D, V int, s unsafe.Pointer) error
Gather launches the embedding table gather kernel with int64 indices. table: [V, D], indices: [N] int64, output: [N, D].
func GatherI32 ¶
func GatherI32(table unsafe.Pointer, indices unsafe.Pointer, output unsafe.Pointer, N, D, V int, s unsafe.Pointer) error
GatherI32 launches the embedding table gather kernel with int32 indices. table: [V, D], indices: [N] int32, output: [N, D].
func GemmQ4F32 ¶
GemmQ4F32 performs Q4_0 dequant-GEMM: C = dequant(A_q4) * B. A_q4 is in GPU separated layout (scales then data), B is [K, N] FP32, C is [M, N] FP32. dataOffset is the byte offset from A_q4 to the packed data region.
func GemmQ8F32 ¶
GemmQ8F32 performs Q8_0 dequant-GEMM: C = dequant(A_q8) * B. A_q8 is packed Q8_0 blocks, B is [K, N] FP32, C is [M, N] FP32.
func GemvQ4KDp4aF32 ¶ added in v0.3.0
GemvQ4KDp4aF32 performs Q4_K fused dequant-GEMV using dp4a INT8 dot-product. Same interface as GemvQ4KF32 but uses __dp4a for higher throughput.
func GemvQ4KDp4aF32Available ¶ added in v0.3.0
func GemvQ4KDp4aF32Available() bool
GemvQ4KDp4aF32Available reports whether the dp4a INT8 Q4_K GEMV kernel is loaded.
func GemvQ4KF32 ¶
GemvQ4KF32 performs Q4_K fused dequant-GEMV: y = dequant(W_q4k) * x. W_q4k is raw Q4_K super-blocks, x is [K] FP32, y is [M] FP32.
func GemvQ4KSm121F32 ¶ added in v0.3.0
GemvQ4KSm121F32 performs Q4_K fused dequant-GEMV using the sm_121 optimized kernel (8 warps/block, vectorized 128-bit loads, __ldcg activation caching).
Falls back to GemvQ4KF32 when the sm_121 kernel is unavailable. K must be a multiple of 256.
func GemvQ5KF32 ¶ added in v0.3.0
GemvQ5KF32 performs Q5_K fused dequant-GEMV: y = dequant(W_q5k) * x. W_q5k is raw Q5_K super-blocks, x is [K] FP32, y is [M] FP32.
func GemvQ5_0F32 ¶ added in v0.3.0
GemvQ5_0F32 performs Q5_0 fused dequant-GEMV: y = dequant(W_q5_0) * x. W_q5_0 is raw Q5_0 blocks, x is [K] FP32, y is [M] FP32.
func GemvQ6KF32 ¶ added in v0.3.0
GemvQ6KF32 performs Q6_K fused dequant-GEMV: y = dequant(W_q6k) * x. W_q6k is raw Q6_K super-blocks, x is [K] FP32, y is [M] FP32.
func GemvWarpF16 ¶ added in v0.3.0
GemvWarpF16 computes y = A*x using the warp-specialized GEMV kernel (FP16). Each warp handles a different output row tile for decode-phase (batch=1) workloads. y[M], A[M x N] row-major, x[N]. All FP16. Accumulation in FP32 for precision.
func GemvWarpF32 ¶ added in v0.3.0
GemvWarpF32 computes y = A*x using the warp-specialized GEMV kernel (FP32). Each warp handles a different output row tile for decode-phase (batch=1) workloads. y[M], A[M x N] row-major, x[N]. All FP32.
func IncrementCounter ¶
IncrementCounter atomically increments a GPU-resident int32 by delta.
func IsFP4GemvSupported ¶ added in v0.3.0
func IsFP4GemvSupported() bool
IsFP4GemvSupported returns true if the current GPU supports NVFP4 GEMV (requires sm_100+ Blackwell architecture).
func IsFP8GemmSupported ¶ added in v0.3.0
func IsFP8GemmSupported() bool
IsFP8GemmSupported returns true if the current GPU supports FP8 GEMM (requires sm_89+ Ada Lovelace architecture).
func IsPagedAttentionSupported ¶ added in v0.3.0
func IsPagedAttentionSupported() bool
IsPagedAttentionSupported returns true if the paged attention kernel symbol was loaded from libkernels.so.
func IsQ4KSm121Supported ¶ added in v0.3.0
func IsQ4KSm121Supported() bool
IsQ4KSm121Supported reports whether the loaded kernel library contains the sm_121 optimized Q4_K GEMV kernel AND the current GPU is sm_12x (Blackwell). The result is cached after the first call.
func IsRaggedAttentionSupported ¶ added in v0.3.0
func IsRaggedAttentionSupported() bool
IsRaggedAttentionSupported returns true if the ragged attention kernel symbol was loaded from libkernels.so.
func IsSelectiveScanSupported ¶ added in v0.3.0
func IsSelectiveScanSupported() bool
IsSelectiveScanSupported reports whether the selective scan kernel is available.
func MulBroadcast ¶
func MulBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
MulBroadcast launches the broadcast mul kernel.
func MulBroadcast4D ¶
func MulBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error
MulBroadcast4D launches the 4D broadcast mul kernel.
func OffsetMemcpy ¶
OffsetMemcpy copies dim floats from src to dst at offset counter*dim. counter is a GPU-resident int32. Used for GPU-driven KV cache append.
func OffsetMemcpyFP16 ¶
OffsetMemcpyFP16 copies dim floats from F32 src to FP16 dst at offset counter*dim. counter is a GPU-resident int32. Used for GPU-driven FP16 KV cache append.
func PagedAttentionForward ¶ added in v0.3.0
func PagedAttentionForward( Q, O unsafe.Pointer, blockPtrsK, blockPtrsV unsafe.Pointer, blockIndices unsafe.Pointer, seqLen, blockSize, headDim int, numQHeads, numKVHeads int, batch int, stream unsafe.Pointer, ) error
PagedAttentionForward computes scaled dot-product attention with block-table indirection for paged KV caches.
Q: [batch*numQHeads, headDim] -- single query per head. O: [batch*numQHeads, headDim] -- output, same shape as Q. blockPtrsK: device array of float* pointers to K blocks.
Each block holds [blockSize, numKVHeads, headDim] floats.
blockPtrsV: device array of float* pointers to V blocks (same layout). blockIndices: device array [batch * maxNumBlocks] mapping logical block
index to physical block index in blockPtrs arrays.
seqLen: actual number of valid K/V token positions. blockSize: number of token positions per block. headDim: dimension per head. numQHeads: number of query heads per batch element. numKVHeads: number of KV heads per batch element. batch: number of sequences in the batch.
func RMSNorm ¶
func RMSNorm(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error
RMSNorm launches the fused RMSNorm kernel. input: [rows, D], weight: [D], output: [rows, D], scales: [rows].
func RMSNormFP16 ¶
func RMSNormFP16(input, weight, output unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error
RMSNormFP16 launches the FP16 RMSNorm kernel with FP32 accumulation. input: [rows, D], weight: [D], output: [rows, D].
func RaggedAttentionForward ¶ added in v0.3.0
func RaggedAttentionForward( Q, K, V, O unsafe.Pointer, seqLens, cumSeqLens unsafe.Pointer, batch, numQHeads, numKVHeads, headDim int, stream unsafe.Pointer, ) error
RaggedAttentionForward computes scaled dot-product attention for variable-length sequences packed into a single batch (ragged batching). A block-diagonal attention mask prevents cross-sequence attention.
Q: [totalTokens * numQHeads, headDim] -- packed queries. K: [totalTokens * numKVHeads, headDim] -- packed keys. V: [totalTokens * numKVHeads, headDim] -- packed values. O: [totalTokens * numQHeads, headDim] -- output. seqLens: [batch] int32 -- actual sequence length for each sequence. cumSeqLens: [batch] int32 -- cumulative offsets (prefix sums, first = 0). batch: number of sequences. numQHeads: number of query heads. numKVHeads: number of KV heads. headDim: dimension per head.
func Repeat ¶
func Repeat(src, dst unsafe.Pointer, outerSize, axisDim, innerSize, reps int, s unsafe.Pointer) error
Repeat launches the repeat kernel: replicates axisDim elements along an axis. outerSize = product of dims before axis, axisDim = size of axis, innerSize = product of dims after axis.
func ResetCounter ¶
ResetCounter sets a GPU-resident int32 to value.
func RoPESelect ¶
func RoPESelect(cosTable, sinTable, cosOut, sinOut, counter unsafe.Pointer, halfRotary int, s unsafe.Pointer) error
RoPESelect copies halfRotary cos/sin values from the precomputed table at position counter[0]. Used for GPU-driven RoPE angle selection.
func ScaledSoftmaxF32 ¶
func ScaledSoftmaxF32( input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, stream unsafe.Pointer, ) error
ScaledSoftmaxF32 applies fused scaled softmax: output = softmax(input * scale).
func ScaledSoftmaxFP16 ¶
func ScaledSoftmaxFP16( input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, stream unsafe.Pointer, ) error
ScaledSoftmaxFP16 applies fused scaled softmax on FP16 data with FP32 accumulation.
func SelectiveScanForward ¶ added in v0.3.0
func SelectiveScanForward(x, A, B, C, D, y unsafe.Pointer, batch, dModel, dState, seqLen int, s unsafe.Pointer) error
SelectiveScanForward launches the GPU selective scan kernel for Mamba/SSM.
x: [batch, d_model, seq_len] input on device A: [d_model, d_state] state matrix on device B: [batch, d_state, seq_len] input-dependent state on device C: [batch, d_state, seq_len] output-dependent state on device D: [d_model] skip connection on device (may be nil) y: [batch, d_model, seq_len] output on device
func SgemvM1 ¶
SgemvM1 computes y = A*x for M=1 decode (single-token GEMV). y[M], A[M x N] row-major, x[N]. All FP32.
func SubBroadcast ¶
func SubBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
SubBroadcast launches the broadcast sub kernel.
func SubBroadcast4D ¶
func SubBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error
SubBroadcast4D launches the 4D broadcast sub kernel.
func Transpose2D ¶
Transpose2D launches the tiled 2D transpose kernel. Input: [rows, cols] -> Output: [cols, rows].
Types ¶
Source Files
¶
- argmax_purego.go
- counter_purego.go
- dequant_q4k_purego.go
- doc.go
- elementwise_fp16_purego.go
- elementwise_purego.go
- flash_attention2_purego.go
- flash_attention_purego.go
- fp4_gemv_purego.go
- fp8_gemm_purego.go
- fp8_ops_purego.go
- fused_add_rmsnorm_purego.go
- fused_norm_add_purego.go
- fused_qk_norm_rope_purego.go
- fused_rope_purego.go
- fused_swiglu_purego.go
- gather_purego.go
- gemm_q4_purego.go
- gemm_q8_purego.go
- gemv_q4k_purego.go
- gemv_q4k_sm121_purego.go
- gemv_q5_0_purego.go
- gemv_q5k_purego.go
- gemv_q6k_purego.go
- gemv_warp_purego.go
- offset_memcpy_purego.go
- paged_attention_purego.go
- purego.go
- ragged_attention_purego.go
- rmsnorm_purego.go
- rope_select_purego.go
- scaled_softmax_purego.go
- selective_scan_purego.go
- sgemv_m1_purego.go
- transpose_purego.go