Documentation
¶
Overview ¶
Package kernels provides Go wrappers for custom CUDA kernels. Build libkernels.a first: cd internal/cuda/kernels && make. All functional code requires the "cuda" build tag. Stability: stable
Index ¶
- func Add(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func AddBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
- func AddBroadcast4D(a, b, c unsafe.Pointer, ...) error
- func AddFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func AddScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func Argmax(input unsafe.Pointer, result unsafe.Pointer, scratch unsafe.Pointer, n int, ...) error
- func Cos(a, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func DequantFP8E4M3ToFP16(input, output unsafe.Pointer, scale float32, n int, s unsafe.Pointer) error
- func DequantQ4KF32(src, dst unsafe.Pointer, rows, K int, stream unsafe.Pointer) error
- func Div(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func DivBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
- func DivBroadcast4D(a, b, c unsafe.Pointer, ...) error
- func DivFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func DivScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func Exp(a, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func F32ToFP16(src, dst unsafe.Pointer, n int, s unsafe.Pointer) error
- func FP8Add(a, b, c unsafe.Pointer, scaleA, scaleB float32, n int, s unsafe.Pointer) error
- func FP8Mul(a, b, c unsafe.Pointer, scaleA, scaleB float32, n int, s unsafe.Pointer) error
- func FP8RMSNorm(input, weight, output unsafe.Pointer, scale, eps float32, rows, D int, ...) error
- func FP16ToF32(src, dst unsafe.Pointer, n int, s unsafe.Pointer) error
- func Fill(data unsafe.Pointer, value float32, n int, s unsafe.Pointer) error
- func FlashAttentionDecode(Q, K, V, O unsafe.Pointer, numBH, maxKVLen, headDim, kvLen int, ...) error
- func FlashAttentionForward(Q, K, V, O unsafe.Pointer, batch, heads, seqLen, headDim int, causal bool, ...) error
- func FusedAddRMSNormF32(input, residual, weight, normedOut, sumOut unsafe.Pointer, eps float32, ...) error
- func FusedNormAddF32(input, weight, residual, output unsafe.Pointer, eps float32, rows, D int, ...) error
- func FusedQKNormRoPEF32(input, weightQ, weightK, cosAngles, sinAngles, output unsafe.Pointer, ...) error
- func FusedRoPEF32(input, cosAngles, sinAngles, output unsafe.Pointer, ...) error
- func FusedSwiGLUF32(w1, w3, output unsafe.Pointer, n int, stream unsafe.Pointer) error
- func Gather(table unsafe.Pointer, indices unsafe.Pointer, output unsafe.Pointer, ...) error
- func GatherI32(table unsafe.Pointer, indices unsafe.Pointer, output unsafe.Pointer, ...) error
- func GemmQ4F32(A_q4, B, C unsafe.Pointer, M, K, N, dataOffset int, stream unsafe.Pointer) error
- func GemmQ8F32(A_q8, B, C unsafe.Pointer, M, K, N int, stream unsafe.Pointer) error
- func GemvQ4KF32(W_q4k, x, y unsafe.Pointer, M, K int, stream unsafe.Pointer) error
- func IncrementCounter(counter unsafe.Pointer, delta int, s unsafe.Pointer) error
- func Log(a, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func Mul(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func MulBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
- func MulBroadcast4D(a, b, c unsafe.Pointer, ...) error
- func MulFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func MulScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func OffsetMemcpy(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s unsafe.Pointer) error
- func OffsetMemcpyFP16(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s unsafe.Pointer) error
- func Pow(base, exp, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func PowScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func RMSNorm(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, ...) error
- func RMSNormFP16(input, weight, output unsafe.Pointer, eps float32, rows, D int, ...) error
- func Repeat(src, dst unsafe.Pointer, outerSize, axisDim, innerSize, reps int, ...) error
- func ResetCounter(counter unsafe.Pointer, value int, s unsafe.Pointer) error
- func RoPESelect(cosTable, sinTable, cosOut, sinOut, counter unsafe.Pointer, halfRotary int, ...) error
- func Rsqrt(a, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func ScaledSoftmaxF32(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, ...) error
- func ScaledSoftmaxFP16(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, ...) error
- func SgemvM1(y, A, x unsafe.Pointer, M, N int, s unsafe.Pointer) error
- func Sin(a, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func Softmax(input, output unsafe.Pointer, outer, inner, axisSize int, s unsafe.Pointer) error
- func Sqrt(a, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func Sub(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func SubBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
- func SubBroadcast4D(a, b, c unsafe.Pointer, ...) error
- func SubFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func SubScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func SumAxis(input, output unsafe.Pointer, outer, inner, axisSize int, s unsafe.Pointer) error
- func Tanh(a, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func TanhPrime(a, upstream, c unsafe.Pointer, n int, s unsafe.Pointer) error
- func Transpose2D(input, output unsafe.Pointer, rows, cols int, s unsafe.Pointer) error
- func TransposeND(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ...) error
- type KernelLib
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func AddBroadcast ¶
func AddBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
AddBroadcast launches the broadcast add kernel.
func AddBroadcast4D ¶
func AddBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error
AddBroadcast4D launches the 4D broadcast add kernel.
func Argmax ¶
func Argmax(input unsafe.Pointer, result unsafe.Pointer, scratch unsafe.Pointer, n int, s unsafe.Pointer) error
Argmax launches the GPU argmax kernel. input: [n] float32 on device, result: single int32 on device, scratch: device temp storage of at least 2*ceil(n/256)*4 bytes.
func DequantFP8E4M3ToFP16 ¶
func DequantFP8E4M3ToFP16(input, output unsafe.Pointer, scale float32, n int, s unsafe.Pointer) error
DequantFP8E4M3ToFP16 launches the FP8 E4M3 -> FP16 dequantization kernel. input: n bytes of FP8 E4M3 data, output: n FP16 values, scale: per-tensor scale factor.
func DequantQ4KF32 ¶
func DequantQ4KF32(src, dst unsafe.Pointer, rows, K int, stream unsafe.Pointer) error
DequantQ4KF32 dequantizes Q4_K super-blocks to FP32 in global memory. src is raw Q4_K super-blocks, dst is [rows, K] FP32.
func DivBroadcast ¶
func DivBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
DivBroadcast launches the broadcast div kernel.
func DivBroadcast4D ¶
func DivBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error
DivBroadcast4D launches the 4D broadcast div kernel.
func FP8Add ¶
func FP8Add(a, b, c unsafe.Pointer, scaleA, scaleB float32, n int, s unsafe.Pointer) error
FP8Add launches the FP8 dequant+add kernel: c[i] = dequant(a[i])*scaleA + dequant(b[i])*scaleB. a, b: FP8 E4M3 inputs, c: FP16 output.
func FP8Mul ¶
func FP8Mul(a, b, c unsafe.Pointer, scaleA, scaleB float32, n int, s unsafe.Pointer) error
FP8Mul launches the FP8 dequant+mul kernel: c[i] = (dequant(a[i])*scaleA) * (dequant(b[i])*scaleB). a, b: FP8 E4M3 inputs, c: FP16 output.
func FP8RMSNorm ¶
func FP8RMSNorm(input, weight, output unsafe.Pointer, scale, eps float32, rows, D int, s unsafe.Pointer) error
FP8RMSNorm launches the FP8 dequant+RMSNorm kernel. input: FP8 E4M3 [rows, D], weight: FP16 [D], output: FP16 [rows, D]. Dequantizes input on load, computes RMSNorm with FP32 accumulation, writes FP16.
func FlashAttentionDecode ¶
func FlashAttentionDecode(Q, K, V, O unsafe.Pointer, numBH, maxKVLen, headDim, kvLen int, kvLenPtr unsafe.Pointer, numQueryHeads, numKVHeads int, stream unsafe.Pointer) error
FlashAttentionDecode computes single-query attention for autoregressive decode. Supports GQA: numQueryHeads may differ from numKVHeads, but numQueryHeads must be a multiple of numKVHeads.
Q: [batch*numQueryHeads, 1, headDim] -- single query per head. K: [batch*numKVHeads, maxKVLen, headDim] -- pre-allocated KV cache buffer. V: [batch*numKVHeads, maxKVLen, headDim] O: [batch*numQueryHeads, 1, headDim] -- output.
kvLen is the actual KV sequence length (used when kvLenPtr is nil). kvLenPtr is a GPU-resident int32 pointer; when non-nil the kernel reads the KV length from GPU memory at runtime, making it compatible with CUDA graph replay (the value is not frozen at capture time).
func FlashAttentionForward ¶
func FlashAttentionForward(Q, K, V, O unsafe.Pointer, batch, heads, seqLen, headDim int, causal bool, stream unsafe.Pointer) error
FlashAttentionForward computes scaled dot-product attention using a fused tiled kernel. All tensors are in [batch, heads, seq_len, head_dim] layout. When causal is true, a causal mask is applied: attention scores above the diagonal (keys at later positions than the query) are masked out.
func FusedAddRMSNormF32 ¶
func FusedAddRMSNormF32(input, residual, weight, normedOut, sumOut unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error
FusedAddRMSNormF32 performs fused residual add + RMSNorm in one kernel launch. input: [rows, D] (read-only), residual: [rows, D] (read-only), weight: [D], normedOut: [rows, D], sumOut: [rows, D].
func FusedNormAddF32 ¶
func FusedNormAddF32(input, weight, residual, output unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error
FusedNormAddF32 applies RMSNorm then adds residual in one kernel launch. output = rmsnorm(input, weight, eps) + residual. input: [rows, D], weight: [D], residual: [rows, D], output: [rows, D].
func FusedQKNormRoPEF32 ¶
func FusedQKNormRoPEF32(input, weightQ, weightK, cosAngles, sinAngles, output unsafe.Pointer, eps float32, totalHeads, headDim, numQHeads, halfRotary int, s unsafe.Pointer) error
FusedQKNormRoPEF32 applies per-head RMSNorm + RoPE to combined Q+K heads. input: [totalHeads, headDim], weightQ/weightK: [headDim], cosAngles/sinAngles: [halfRotary], output: [totalHeads, headDim].
func FusedRoPEF32 ¶
func FusedRoPEF32(input, cosAngles, sinAngles, output unsafe.Pointer, batch, seqLen, headDim, halfRotary, cosStride int, stream unsafe.Pointer) error
FusedRoPEF32 applies fused rotary positional embedding (RoPE) to FP32 data.
func FusedSwiGLUF32 ¶
func FusedSwiGLUF32(w1, w3, output unsafe.Pointer, n int, stream unsafe.Pointer) error
FusedSwiGLUF32 applies fused SwiGLU activation: output[i] = w1[i] * sigmoid(w1[i]) * w3[i].
func Gather ¶
func Gather(table unsafe.Pointer, indices unsafe.Pointer, output unsafe.Pointer, N, D, V int, s unsafe.Pointer) error
Gather launches the embedding table gather kernel with int64 indices. table: [V, D], indices: [N] int64, output: [N, D].
func GatherI32 ¶
func GatherI32(table unsafe.Pointer, indices unsafe.Pointer, output unsafe.Pointer, N, D, V int, s unsafe.Pointer) error
GatherI32 launches the embedding table gather kernel with int32 indices. table: [V, D], indices: [N] int32, output: [N, D].
func GemmQ4F32 ¶
func GemmQ4F32(A_q4, B, C unsafe.Pointer, M, K, N, dataOffset int, stream unsafe.Pointer) error
GemmQ4F32 performs Q4_0 dequant-GEMM: C = dequant(A_q4) * B. A_q4 is in GPU separated layout (scales then data), B is [K, N] FP32, C is [M, N] FP32. dataOffset is the byte offset from A_q4 to the packed data region.
func GemmQ8F32 ¶
func GemmQ8F32(A_q8, B, C unsafe.Pointer, M, K, N int, stream unsafe.Pointer) error
GemmQ8F32 performs Q8_0 dequant-GEMM: C = dequant(A_q8) * B. A_q8 is packed Q8_0 blocks, B is [K, N] FP32, C is [M, N] FP32.
func GemvQ4KF32 ¶
func GemvQ4KF32(W_q4k, x, y unsafe.Pointer, M, K int, stream unsafe.Pointer) error
GemvQ4KF32 performs Q4_K fused dequant-GEMV: y = dequant(W_q4k) * x. W_q4k is raw Q4_K super-blocks, x is [K] FP32, y is [M] FP32.
func IncrementCounter ¶
func IncrementCounter(counter unsafe.Pointer, delta int, s unsafe.Pointer) error
IncrementCounter atomically increments a GPU-resident int32 by delta.
func MulBroadcast ¶
func MulBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
MulBroadcast launches the broadcast mul kernel.
func MulBroadcast4D ¶
func MulBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error
MulBroadcast4D launches the 4D broadcast mul kernel.
func OffsetMemcpy ¶
func OffsetMemcpy(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s unsafe.Pointer) error
OffsetMemcpy copies dim floats from src to dst at offset counter*dim. counter is a GPU-resident int32. Used for GPU-driven KV cache append.
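func OffsetMemcpyFP16 ¶
func OffsetMemcpyFP16(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s unsafe.Pointer) error
OffsetMemcpyFP16 converts dim F32 values from src to FP16 and writes them to dst at offset counter*dim. counter is a GPU-resident int32. Used for GPU-driven FP16 KV cache append.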
func RMSNorm ¶
func RMSNorm(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error
RMSNorm launches the fused RMSNorm kernel. input: [rows, D], weight: [D], output: [rows, D], scales: [rows].
func RMSNormFP16 ¶
func RMSNormFP16(input, weight, output unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error
RMSNormFP16 launches the FP16 RMSNorm kernel with FP32 accumulation. input: [rows, D], weight: [D], output: [rows, D].
func Repeat ¶
func Repeat(src, dst unsafe.Pointer, outerSize, axisDim, innerSize, reps int, s unsafe.Pointer) error
Repeat launches the repeat kernel: replicates the axisDim elements along an axis reps times. outerSize = product of dims before axis, axisDim = size of axis, innerSize = product of dims after axis, reps = number of repetitions.
func ResetCounter ¶
func ResetCounter(counter unsafe.Pointer, value int, s unsafe.Pointer) error
ResetCounter sets a GPU-resident int32 to value.
func RoPESelect ¶
func RoPESelect(cosTable, sinTable, cosOut, sinOut, counter unsafe.Pointer, halfRotary int, s unsafe.Pointer) error
RoPESelect copies halfRotary cos/sin values from the precomputed table at position counter[0]. Used for GPU-driven RoPE angle selection.
func ScaledSoftmaxF32 ¶
func ScaledSoftmaxF32(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, stream unsafe.Pointer) error
ScaledSoftmaxF32 applies fused scaled softmax: output = softmax(input * scale).
func ScaledSoftmaxFP16 ¶
func ScaledSoftmaxFP16(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, stream unsafe.Pointer) error
ScaledSoftmaxFP16 applies fused scaled softmax on FP16 data with FP32 accumulation.
func SgemvM1 ¶
func SgemvM1(y, A, x unsafe.Pointer, M, N int, s unsafe.Pointer) error
SgemvM1 computes y = A*x, the single-token (batch size 1) GEMV used during decode. y: [M], A: [M, N] row-major, x: [N]. All FP32.
func SubBroadcast ¶
func SubBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error
SubBroadcast launches the broadcast sub kernel.
func SubBroadcast4D ¶
func SubBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error
SubBroadcast4D launches the 4D broadcast sub kernel.
func Transpose2D ¶
func Transpose2D(input, output unsafe.Pointer, rows, cols int, s unsafe.Pointer) error
Transpose2D launches the tiled 2D transpose kernel. Input: [rows, cols] -> Output: [cols, rows].
Types ¶
Source Files
¶
- argmax_purego.go
- counter_purego.go
- dequant_q4k_purego.go
- doc.go
- elementwise_fp16_purego.go
- elementwise_purego.go
- flash_attention_purego.go
- fp8_ops_purego.go
- fused_add_rmsnorm_purego.go
- fused_norm_add_purego.go
- fused_qk_norm_rope_purego.go
- fused_rope_purego.go
- fused_swiglu_purego.go
- gather_purego.go
- gemm_q4_purego.go
- gemm_q8_purego.go
- gemv_q4k_purego.go
- offset_memcpy_purego.go
- purego.go
- rmsnorm_purego.go
- rope_select_purego.go
- scaled_softmax_purego.go
- sgemv_m1_purego.go
- transpose_purego.go