kernels

package kernels
Version: v1.25.2

Warning: This package is not in the latest version of its module.

Published: Mar 26, 2026 | License: Apache-2.0 | Imports: 5 | Imported by: 0

Documentation

Overview

Package kernels provides Go wrappers for custom CUDA kernels. Stability: stable.

Build libkernels.a first:

	cd internal/cuda/kernels && make

All functional code requires the "cuda" build tag.

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func Add

func Add(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

Add launches the elementwise add kernel: c = a + b.

func AddBroadcast

func AddBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error

AddBroadcast launches the broadcast add kernel.
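The parameter list suggests a strided 2D broadcast. The CPU reference below is a sketch under stated assumptions, not part of this package: addBroadcastRef is a hypothetical name, c is assumed to be a contiguous [M, D] row-major output, and a and b are assumed to be read through element strides (saRow, saCol) and (sbRow, sbCol), where a stride of 0 repeats one element along that dimension.

```go
package main

// addBroadcastRef is a hypothetical CPU reference for AddBroadcast.
// ASSUMPTION: c is contiguous [M, D] row-major; a and b are addressed with
// element strides, and a stride of 0 broadcasts along that dimension.
func addBroadcastRef(a, b []float32, saRow, saCol, sbRow, sbCol, M, D int) []float32 {
	c := make([]float32, M*D)
	for i := 0; i < M; i++ {
		for j := 0; j < D; j++ {
			c[i*D+j] = a[i*saRow+j*saCol] + b[i*sbRow+j*sbCol]
		}
	}
	return c
}
```

For example, adding a bias vector b of length D to every row of a [M, D] matrix a would use saRow=D, saCol=1, sbRow=0, sbCol=1. Under the same assumption, SubBroadcast, MulBroadcast and DivBroadcast would differ only in the operator, and the 4D variants take one stride per dimension.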

func AddBroadcast4D

func AddBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error

AddBroadcast4D launches the 4D broadcast add kernel.

func AddFP16

func AddFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

AddFP16 launches the FP16 elementwise add kernel: c = a + b.

func AddScalar

func AddScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error

AddScalar launches the scalar add kernel: c = a + scalar.

func Argmax

func Argmax(input unsafe.Pointer, result unsafe.Pointer,
	scratch unsafe.Pointer, n int, s unsafe.Pointer) error

Argmax launches the GPU argmax kernel. input: [n] float32 on device, result: single int32 on device, scratch: device temp storage of at least 2*ceil(n/256)*4 bytes.
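The scratch requirement can be computed on the host before allocating device memory. A minimal sketch follows; argmaxScratchBytes and argmaxRef are hypothetical helpers, and the reading of the formula as one 4-byte value plus one 4-byte index per 256-element block is an assumption. The tie-breaking rule in argmaxRef (lowest index wins) is also assumed, since the kernel's behavior on ties is not documented.

```go
package main

// argmaxScratchBytes returns the documented minimum scratch size for Argmax:
// 2*ceil(n/256)*4 bytes. Presumably this holds one partial max and one
// partial index (4 bytes each) per 256-element block.
func argmaxScratchBytes(n int) int {
	blocks := (n + 255) / 256 // ceil(n/256) in integer arithmetic
	return 2 * blocks * 4
}

// argmaxRef is a CPU reference returning the index of the largest element.
// ASSUMPTION: ties resolve to the lowest index.
func argmaxRef(x []float32) int32 {
	best := 0
	for i := 1; i < len(x); i++ {
		if x[i] > x[best] {
			best = i
		}
	}
	return int32(best)
}
```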

func Cos

func Cos(a, c unsafe.Pointer, n int, s unsafe.Pointer) error

Cos launches the elementwise cos kernel: c = cos(a).

func DequantFP8E4M3ToFP16

func DequantFP8E4M3ToFP16(input, output unsafe.Pointer, scale float32, n int, s unsafe.Pointer) error

DequantFP8E4M3ToFP16 launches the FP8 E4M3 -> FP16 dequantization kernel. input: n bytes of FP8 E4M3 data, output: n FP16 values, scale: per-tensor scale factor.

func DequantQ4KF32

func DequantQ4KF32(
	src, dst unsafe.Pointer,
	rows, K int,
	stream unsafe.Pointer,
) error

DequantQ4KF32 dequantizes Q4_K super-blocks to FP32 in global memory. src is raw Q4_K super-blocks, dst is [rows, K] FP32.

func Div

func Div(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

Div launches the elementwise divide kernel: c = a / b.

func DivBroadcast

func DivBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error

DivBroadcast launches the broadcast div kernel.

func DivBroadcast4D

func DivBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error

DivBroadcast4D launches the 4D broadcast div kernel.

func DivFP16

func DivFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

DivFP16 launches the FP16 elementwise divide kernel: c = a / b.

func DivScalar

func DivScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error

DivScalar launches the scalar divide kernel: c = a / scalar.

func Exp

func Exp(a, c unsafe.Pointer, n int, s unsafe.Pointer) error

Exp launches the elementwise exp kernel: c = exp(a).

func F32ToFP16

func F32ToFP16(src, dst unsafe.Pointer, n int, s unsafe.Pointer) error

F32ToFP16 converts n float32 elements to FP16 on GPU.

func FP8Add

func FP8Add(a, b, c unsafe.Pointer, scaleA, scaleB float32, n int, s unsafe.Pointer) error

FP8Add launches the FP8 dequant+add kernel: c[i] = dequant(a[i])*scaleA + dequant(b[i])*scaleB. a, b: FP8 E4M3 inputs, c: FP16 output.

func FP8Mul

func FP8Mul(a, b, c unsafe.Pointer, scaleA, scaleB float32, n int, s unsafe.Pointer) error

FP8Mul launches the FP8 dequant+mul kernel: c[i] = dequant(a[i])*scaleA * dequant(b[i])*scaleB. a, b: FP8 E4M3 inputs, c: FP16 output.

func FP8RMSNorm

func FP8RMSNorm(input, weight, output unsafe.Pointer, scale, eps float32, rows, D int, s unsafe.Pointer) error

FP8RMSNorm launches the FP8 dequant+RMSNorm kernel. input: FP8 E4M3 [rows, D], weight: FP16 [D], output: FP16 [rows, D]. Dequantizes input on load, computes RMSNorm with FP32 accumulation, writes FP16.

func FP16ToF32

func FP16ToF32(src, dst unsafe.Pointer, n int, s unsafe.Pointer) error

FP16ToF32 converts n FP16 elements to float32 on GPU.

func Fill

func Fill(data unsafe.Pointer, value float32, n int, s unsafe.Pointer) error

Fill launches the fill kernel: sets all elements to value.

func FlashAttentionDecode

func FlashAttentionDecode(
	Q, K, V, O unsafe.Pointer,
	numBH, maxKVLen, headDim, kvLen int,
	kvLenPtr unsafe.Pointer,
	numQueryHeads, numKVHeads int,
	stream unsafe.Pointer,
) error

FlashAttentionDecode computes single-query attention for autoregressive decode. It supports grouped-query attention (GQA): numQueryHeads may differ from numKVHeads, but numQueryHeads must be a multiple of numKVHeads.

Tensor layouts:

	Q: [batch*numQueryHeads, 1, headDim]       single query per head
	K: [batch*numKVHeads, maxKVLen, headDim]   pre-allocated KV cache buffer
	V: [batch*numKVHeads, maxKVLen, headDim]   same layout as K
	O: [batch*numQueryHeads, 1, headDim]       output

kvLen is the actual KV sequence length (used when kvLenPtr is nil). kvLenPtr is a GPU-resident int32 pointer; when non-nil the kernel reads the KV length from GPU memory at runtime, making it compatible with CUDA graph replay (the value is not frozen at capture time).
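A CPU reference for one batch element makes the GQA head mapping and the role of kvLen concrete. This is a sketch, not the kernel's algorithm: attentionDecodeRef is a hypothetical name, it does not tile or use online softmax, and it assumes a 1/sqrt(headDim) score scale and the usual GQA mapping in which query head h reads KV head h / (numQueryHeads/numKVHeads). Only the first kvLen of the maxKVLen cache slots are read.

```go
package main

import "math"

// attentionDecodeRef computes single-query attention for one batch element.
// q: [numQ, headDim]; k, v: [numKV, maxKVLen, headDim]; returns [numQ, headDim].
// ASSUMPTIONS: scores scaled by 1/sqrt(headDim); query head h uses KV head h/(numQ/numKV).
func attentionDecodeRef(q, k, v []float32, numQ, numKV, maxKVLen, headDim, kvLen int) []float32 {
	group := numQ / numKV // numQ must be a multiple of numKV
	scale := 1 / math.Sqrt(float64(headDim))
	out := make([]float32, numQ*headDim)
	scores := make([]float64, kvLen)
	for h := 0; h < numQ; h++ {
		kvHead := h / group
		qRow := q[h*headDim : (h+1)*headDim]
		maxS := math.Inf(-1)
		for t := 0; t < kvLen; t++ { // slots >= kvLen are unused cache space
			base := (kvHead*maxKVLen + t) * headDim
			var dot float64
			for d := 0; d < headDim; d++ {
				dot += float64(qRow[d]) * float64(k[base+d])
			}
			scores[t] = dot * scale
			maxS = math.Max(maxS, scores[t])
		}
		var denom float64
		for t := range scores {
			scores[t] = math.Exp(scores[t] - maxS) // subtract max for numerical stability
			denom += scores[t]
		}
		for t := 0; t < kvLen; t++ {
			base := (kvHead*maxKVLen + t) * headDim
			w := scores[t] / denom
			for d := 0; d < headDim; d++ {
				out[h*headDim+d] += float32(w * float64(v[base+d]))
			}
		}
	}
	return out
}
```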

func FlashAttentionForward

func FlashAttentionForward(
	Q, K, V, O unsafe.Pointer,
	batch, heads, seqLen, headDim int,
	causal bool,
	stream unsafe.Pointer,
) error

FlashAttentionForward computes scaled dot-product attention using a fused tiled kernel. All tensors are in [batch, heads, seq_len, head_dim] layout. When causal is true, an upper-triangular mask is applied so that each query position attends only to key positions at or before it.
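The causal mask can be pictured as a seqLen x seqLen boolean grid. The sketch below (hypothetical name causalMask) marks query i as blocked from key j exactly when j > i, i.e. the strict upper triangle; this is the standard meaning of a causal mask, but the kernel applies it internally and never materializes such a grid.

```go
package main

// causalMask returns a row-major [seqLen*seqLen] grid where
// masked[i*seqLen+j] is true when query i must not attend to key j (j > i).
func causalMask(seqLen int) []bool {
	m := make([]bool, seqLen*seqLen)
	for i := 0; i < seqLen; i++ {
		for j := i + 1; j < seqLen; j++ {
			m[i*seqLen+j] = true
		}
	}
	return m
}
```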

func FusedAddRMSNormF32

func FusedAddRMSNormF32(input, residual, weight, normedOut, sumOut unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error

FusedAddRMSNormF32 performs fused residual add + RMSNorm in one kernel launch. input: [rows, D] (read-only), residual: [rows, D] (read-only), weight: [D], normedOut: [rows, D], sumOut: [rows, D].

func FusedNormAddF32

func FusedNormAddF32(input, weight, residual, output unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error

FusedNormAddF32 applies RMSNorm then adds residual in one kernel launch. output = rmsnorm(input, weight, eps) + residual. input: [rows, D], weight: [D], residual: [rows, D], output: [rows, D].

func FusedQKNormRoPEF32

func FusedQKNormRoPEF32(input, weightQ, weightK, cosAngles, sinAngles, output unsafe.Pointer, eps float32, totalHeads, headDim, numQHeads, halfRotary int, s unsafe.Pointer) error

FusedQKNormRoPEF32 applies per-head RMSNorm + RoPE to combined Q+K heads. input: [totalHeads, headDim], weightQ/weightK: [headDim], cosAngles/sinAngles: [halfRotary], output: [totalHeads, headDim].

func FusedRoPEF32

func FusedRoPEF32(
	input, cosAngles, sinAngles, output unsafe.Pointer,
	batch, seqLen, headDim, halfRotary, cosStride int,
	stream unsafe.Pointer,
) error

FusedRoPEF32 applies fused rotary positional embedding (RoPE) to FP32 data.
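The halfRotary parameter (also used by FusedQKNormRoPEF32) implies that halfRotary cos/sin values rotate 2*halfRotary elements of each head. The element pairing is not documented; this CPU sketch assumes the "rotate-half" convention used by GPT-NeoX-style models, where element i is paired with element i+halfRotary, and assumes elements beyond 2*halfRotary pass through unchanged. ropeRotateHalfRef is a hypothetical name, and cos/sin hold the angles for a single position.

```go
package main

// ropeRotateHalfRef applies RoPE to one head vector x for one position.
// ASSUMPTION: rotate-half pairing (i, i+halfRotary); indices >= 2*halfRotary
// are copied unchanged (partial rotary).
func ropeRotateHalfRef(x, cos, sin []float32, halfRotary int) []float32 {
	out := make([]float32, len(x))
	copy(out, x)
	for i := 0; i < halfRotary; i++ {
		x0, x1 := x[i], x[i+halfRotary]
		out[i] = x0*cos[i] - x1*sin[i]
		out[i+halfRotary] = x0*sin[i] + x1*cos[i]
	}
	return out
}
```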

func FusedSwiGLUF32

func FusedSwiGLUF32(
	w1, w3, output unsafe.Pointer,
	n int,
	stream unsafe.Pointer,
) error

FusedSwiGLUF32 applies fused SwiGLU activation: output[i] = w1[i] * sigmoid(w1[i]) * w3[i].

func Gather

func Gather(table unsafe.Pointer, indices unsafe.Pointer,
	output unsafe.Pointer, N, D, V int, s unsafe.Pointer) error

Gather launches the embedding table gather kernel with int64 indices. table: [V, D], indices: [N] int64, output: [N, D].

func GatherI32

func GatherI32(table unsafe.Pointer, indices unsafe.Pointer,
	output unsafe.Pointer, N, D, V int, s unsafe.Pointer) error

GatherI32 launches the embedding table gather kernel with int32 indices. table: [V, D], indices: [N] int32, output: [N, D].
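Both gather variants copy one table row per index: output[n, :] = table[indices[n], :]. A CPU reference sketch for the int64 variant follows (gatherRef is a hypothetical name; GatherI32 would differ only in the index type). It panics on out-of-range indices, which is a choice of this sketch; how the GPU kernel handles them is not documented.

```go
package main

// gatherRef copies table rows selected by indices.
// table: [V, D] row-major; indices: [N]; returns [N, D].
func gatherRef(table []float32, indices []int64, D, V int) []float32 {
	out := make([]float32, len(indices)*D)
	for n, idx := range indices {
		if idx < 0 || idx >= int64(V) {
			panic("gatherRef: index out of range")
		}
		row := int(idx)
		copy(out[n*D:(n+1)*D], table[row*D:(row+1)*D])
	}
	return out
}
```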

func GemmQ4F32

func GemmQ4F32(
	A_q4, B, C unsafe.Pointer,
	M, K, N, dataOffset int,
	stream unsafe.Pointer,
) error

GemmQ4F32 performs Q4_0 dequant-GEMM: C = dequant(A_q4) * B. A_q4 is in GPU separated layout (scales then data), B is [K, N] FP32, C is [M, N] FP32. dataOffset is the byte offset from A_q4 to the packed data region.

func GemmQ8F32

func GemmQ8F32(
	A_q8, B, C unsafe.Pointer,
	M, K, N int,
	stream unsafe.Pointer,
) error

GemmQ8F32 performs Q8_0 dequant-GEMM: C = dequant(A_q8) * B. A_q8 is packed Q8_0 blocks, B is [K, N] FP32, C is [M, N] FP32.

func GemvQ4KF32

func GemvQ4KF32(
	W_q4k, x, y unsafe.Pointer,
	M, K int,
	stream unsafe.Pointer,
) error

GemvQ4KF32 performs Q4_K fused dequant-GEMV: y = dequant(W_q4k) * x. W_q4k is raw Q4_K super-blocks, x is [K] FP32, y is [M] FP32.

func IncrementCounter

func IncrementCounter(counter unsafe.Pointer, delta int, s unsafe.Pointer) error

IncrementCounter atomically increments a GPU-resident int32 by delta.

func Log

func Log(a, c unsafe.Pointer, n int, s unsafe.Pointer) error

Log launches the elementwise log kernel: c = log(a).

func Mul

func Mul(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

Mul launches the elementwise multiply kernel: c = a * b.

func MulBroadcast

func MulBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error

MulBroadcast launches the broadcast mul kernel.

func MulBroadcast4D

func MulBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error

MulBroadcast4D launches the 4D broadcast mul kernel.

func MulFP16

func MulFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

MulFP16 launches the FP16 elementwise multiply kernel: c = a * b.

func MulScalar

func MulScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error

MulScalar launches the scalar multiply kernel: c = a * scalar.

func OffsetMemcpy

func OffsetMemcpy(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s unsafe.Pointer) error

OffsetMemcpy copies dim floats from src to dst at offset counter*dim. counter is a GPU-resident int32. Used for GPU-driven KV cache append.
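On the host, the same append reads as writing row counter of a [maxSeqLen, dim] cache. In this sketch (offsetMemcpyRef is a hypothetical name) the counter is passed by value, whereas the kernel reads it from GPU memory; treating maxSeqLen as a bound that rejects the write is an assumption, since the kernel's overflow behavior is not documented.

```go
package main

// offsetMemcpyRef writes src[:dim] into dst[counter*dim : (counter+1)*dim],
// where dst is a [maxSeqLen, dim] cache. It returns false instead of writing
// when counter is outside [0, maxSeqLen) (an assumption of this sketch).
func offsetMemcpyRef(dst, src []float32, counter int32, dim, maxSeqLen int) bool {
	if counter < 0 || int(counter) >= maxSeqLen {
		return false
	}
	off := int(counter) * dim
	copy(dst[off:off+dim], src[:dim])
	return true
}
```

Presumably the counter is then advanced with IncrementCounter after each decode step, so the same launch sequence stays valid under CUDA graph replay.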

func OffsetMemcpyFP16

func OffsetMemcpyFP16(dst, src, counter unsafe.Pointer, dim, maxSeqLen int, s unsafe.Pointer) error

OffsetMemcpyFP16 converts dim float32 values from src to FP16 and writes them to dst at offset counter*dim. counter is a GPU-resident int32. Used for GPU-driven FP16 KV cache append.

func Pow

func Pow(base, exp, c unsafe.Pointer, n int, s unsafe.Pointer) error

Pow launches the elementwise power kernel: c = pow(base, exp).

func PowScalar

func PowScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error

PowScalar launches the scalar power kernel: c = pow(a, scalar).

func RMSNorm

func RMSNorm(input, weight, output, scales unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error

RMSNorm launches the fused RMSNorm kernel. input: [rows, D], weight: [D], output: [rows, D], scales: [rows].

func RMSNormFP16

func RMSNormFP16(input, weight, output unsafe.Pointer, eps float32, rows, D int, s unsafe.Pointer) error

RMSNormFP16 launches the FP16 RMSNorm kernel with FP32 accumulation. input: [rows, D], weight: [D], output: [rows, D].

func Repeat

func Repeat(src, dst unsafe.Pointer, outerSize, axisDim, innerSize, reps int, s unsafe.Pointer) error

Repeat launches the repeat kernel, which replicates the axisDim elements of one axis reps times. outerSize is the product of the dimensions before the axis, axisDim is the size of the axis, and innerSize is the product of the dimensions after the axis.

func ResetCounter

func ResetCounter(counter unsafe.Pointer, value int, s unsafe.Pointer) error

ResetCounter sets a GPU-resident int32 to value.

func RoPESelect

func RoPESelect(cosTable, sinTable, cosOut, sinOut, counter unsafe.Pointer,
	halfRotary int, s unsafe.Pointer) error

RoPESelect copies halfRotary cos/sin values from the precomputed table at position counter[0]. Used for GPU-driven RoPE angle selection.

func Rsqrt

func Rsqrt(a, c unsafe.Pointer, n int, s unsafe.Pointer) error

Rsqrt launches the elementwise rsqrt kernel: c = 1/sqrt(a).

func ScaledSoftmaxF32

func ScaledSoftmaxF32(
	input, output unsafe.Pointer,
	outer, inner, axisSize int,
	scale float32,
	stream unsafe.Pointer,
) error

ScaledSoftmaxF32 applies fused scaled softmax: output = softmax(input * scale).

func ScaledSoftmaxFP16

func ScaledSoftmaxFP16(
	input, output unsafe.Pointer,
	outer, inner, axisSize int,
	scale float32,
	stream unsafe.Pointer,
) error

ScaledSoftmaxFP16 applies fused scaled softmax on FP16 data with FP32 accumulation.

func SgemvM1

func SgemvM1(y, A, x unsafe.Pointer, M, N int, s unsafe.Pointer) error

SgemvM1 computes y = A*x for M=1 decode (single-token GEMV). y[M], A[M x N] row-major, x[N]. All FP32.
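With A stored row-major as [M x N], each output is a dot product of one row with x: y[i] = sum_j A[i*N+j] * x[j]. A straightforward CPU reference (sgemvRef is a hypothetical name) for comparing against SgemvM1:

```go
package main

// sgemvRef computes y = A*x with A row-major [M x N], x [N], y [M], all FP32.
func sgemvRef(A, x []float32, M, N int) []float32 {
	y := make([]float32, M)
	for i := 0; i < M; i++ {
		var acc float32
		for j := 0; j < N; j++ {
			acc += A[i*N+j] * x[j]
		}
		y[i] = acc
	}
	return y
}
```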

func Sin

func Sin(a, c unsafe.Pointer, n int, s unsafe.Pointer) error

Sin launches the elementwise sin kernel: c = sin(a).

func Softmax

func Softmax(input, output unsafe.Pointer, outer, inner, axisSize int, s unsafe.Pointer) error

Softmax launches the softmax kernel along an axis.

func Sqrt

func Sqrt(a, c unsafe.Pointer, n int, s unsafe.Pointer) error

Sqrt launches the elementwise sqrt kernel: c = sqrt(a).

func Sub

func Sub(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

Sub launches the elementwise subtract kernel: c = a - b.

func SubBroadcast

func SubBroadcast(a, b, c unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, s unsafe.Pointer) error

SubBroadcast launches the broadcast sub kernel.

func SubBroadcast4D

func SubBroadcast4D(a, b, c unsafe.Pointer, d0, d1, d2, d3, sa0, sa1, sa2, sa3, sb0, sb1, sb2, sb3 int, s unsafe.Pointer) error

SubBroadcast4D launches the 4D broadcast sub kernel.

func SubFP16

func SubFP16(a, b, c unsafe.Pointer, n int, s unsafe.Pointer) error

SubFP16 launches the FP16 elementwise subtract kernel: c = a - b.

func SubScalar

func SubScalar(a unsafe.Pointer, scalar float32, c unsafe.Pointer, n int, s unsafe.Pointer) error

SubScalar launches the scalar subtract kernel: c = a - scalar.

func SumAxis

func SumAxis(input, output unsafe.Pointer, outer, inner, axisSize int, s unsafe.Pointer) error

SumAxis launches the sum-reduction kernel along an axis.

func Tanh

func Tanh(a, c unsafe.Pointer, n int, s unsafe.Pointer) error

Tanh launches the elementwise tanh kernel: c = tanh(a).

func TanhPrime

func TanhPrime(a, upstream, c unsafe.Pointer, n int, s unsafe.Pointer) error

TanhPrime launches the tanh derivative kernel: c = (1 - tanh(a)^2) * upstream.

func Transpose2D

func Transpose2D(input, output unsafe.Pointer, rows, cols int, s unsafe.Pointer) error

Transpose2D launches the tiled 2D transpose kernel. Input: [rows, cols] -> Output: [cols, rows].

func TransposeND

func TransposeND(input, output unsafe.Pointer,
	inStrides, outStrides, perm []int32,
	ndim, total int, s unsafe.Pointer) error

TransposeND launches the general N-D transpose kernel.
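The parameters suggest a stride-driven gather: each output linear index is decomposed into coordinates with outStrides, and those coordinates are mapped back to an input offset through perm and inStrides. The exact conventions are not documented, so this CPU sketch assumes row-major element strides for both tensors and NumPy-style perm semantics (output axis k is input axis perm[k]). transposeNDRef is a hypothetical name.

```go
package main

// transposeNDRef is a hypothetical CPU reference for TransposeND.
// ASSUMPTIONS: inStrides/outStrides are row-major element strides, and output
// axis k corresponds to input axis perm[k].
func transposeNDRef(in []float32, inStrides, outStrides, perm []int32, ndim, total int) []float32 {
	out := make([]float32, total)
	for o := 0; o < total; o++ {
		rem, src := o, 0
		for k := 0; k < ndim; k++ {
			coord := rem / int(outStrides[k]) // k-th output coordinate
			rem %= int(outStrides[k])
			src += coord * int(inStrides[perm[k]])
		}
		out[o] = in[src]
	}
	return out
}
```

For a 2D input of shape [2, 3] (inStrides [3, 1]) and perm [1, 0], the output has shape [3, 2] with outStrides [2, 1], which matches what Transpose2D produces.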

Types

type KernelLib

type KernelLib struct {
	// contains filtered or unexported fields
}

KernelLib holds dlopen'd function pointers for custom CUDA kernels compiled into libkernels.so.
