cublas

package
v0.2.0
Warning

This package is not in the latest version of its module.

Published: Mar 16, 2026 License: Apache-2.0 Imports: 4 Imported by: 0

Documentation

Overview

Package cublas provides low-level purego bindings for the cuBLAS and cuBLASLt libraries. Use Available() to check whether cuBLAS can be loaded at runtime, and LtAvailable() for cuBLASLt.

Constants

This section is empty.

Variables

This section is empty.

Functions

func Available

func Available() bool

Available returns true if the cuBLAS library can be loaded at runtime. The result is cached after the first call.
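The caching described for Available (and LtAvailable) can be pictured as a single load attempt guarded by sync.Once. This is a minimal sketch, not this package's implementation: tryLoad is a hypothetical stand-in for the real library probe and here always succeeds.

```go
package main

import (
	"fmt"
	"sync"
)

// probeCount records how many times the hypothetical loader actually runs.
var probeCount int

// tryLoad is a hypothetical stand-in for opening the shared library
// (for example with purego.Dlopen); in this sketch it always succeeds.
func tryLoad() bool {
	probeCount++
	return true
}

var (
	availOnce sync.Once
	availOK   bool
)

// available runs the probe at most once and returns the cached result afterwards.
func available() bool {
	availOnce.Do(func() { availOK = tryLoad() })
	return availOK
}

func main() {
	for i := 0; i < 3; i++ {
		fmt.Println(available())
	}
	fmt.Println("probes:", probeCount) // probes: 1
}
```

Callers would typically branch on Available() before calling CreateHandle (and on LtAvailable() before LtCreateHandle), falling back to another code path when the library is missing.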

func GemmEx

func GemmEx(h *Handle, m, n, k int, alpha float32,
	a unsafe.Pointer, aType CudaDataType,
	b unsafe.Pointer, bType CudaDataType,
	beta float32,
	c unsafe.Pointer, cType CudaDataType,
	computeType CublasComputeType,
) error

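GemmEx performs mixed-precision general matrix multiplication, C = alpha*A*B + beta*C, where aType, bType and cType give the element types of A, B and C, and computeType sets the accumulation precision. Arguments are given in row-major terms; the function converts to cuBLAS's column-major convention by swapping A/B and m/n.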
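In "mixed precision", inputs are stored in a narrow type while products are accumulated at the compute precision. The CPU sketch below assumes aType = bType = CudaR16BF, cType = CudaR32F and computeType = CublasCompute32F, with the same row-major argument convention. toBF16, fromBF16 and gemmBF16 are hypothetical helpers, not part of this package, and NaN handling is omitted from the conversion.

```go
package main

import (
	"fmt"
	"math"
)

// toBF16 rounds a float32 to bfloat16 (round-to-nearest-even) and returns the raw bits.
// NaN inputs are not special-cased in this sketch.
func toBF16(f float32) uint16 {
	bits := math.Float32bits(f)
	rounding := uint32(0x7FFF) + ((bits >> 16) & 1)
	return uint16((bits + rounding) >> 16)
}

// fromBF16 widens bfloat16 bits back to float32; this direction is exact.
func fromBF16(h uint16) float32 {
	return math.Float32frombits(uint32(h) << 16)
}

// gemmBF16 computes C = alpha*A*B + beta*C for row-major A (m x k) and B (k x n)
// stored as bfloat16, accumulating in float32 and writing a float32 C (m x n).
func gemmBF16(m, n, k int, alpha float32, a, b []uint16, beta float32, c []float32) {
	for i := 0; i < m; i++ {
		for j := 0; j < n; j++ {
			var acc float32 // float32 accumulator: the "compute type"
			for p := 0; p < k; p++ {
				acc += fromBF16(a[i*k+p]) * fromBF16(b[p*n+j])
			}
			c[i*n+j] = alpha*acc + beta*c[i*n+j]
		}
	}
}

func main() {
	a := []uint16{toBF16(1), toBF16(2), toBF16(3), toBF16(4)} // 2x2 row-major
	b := []uint16{toBF16(5), toBF16(6), toBF16(7), toBF16(8)} // 2x2 row-major
	c := make([]float32, 4)
	gemmBF16(2, 2, 2, 1, a, b, 0, c)
	fmt.Println(c) // [19 22 43 50]
}
```

With GemmEx itself, the bfloat16 buffers would first have to be copied to device memory; the call then takes device pointers plus the data-type tags rather than Go slices.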

func LtAvailable

func LtAvailable() bool

LtAvailable returns true if the cuBLASLt library can be loaded at runtime. The result is cached after the first call.

func LtMatmul

func LtMatmul(
	h *LtHandle,
	desc *MatmulDesc,
	alpha unsafe.Pointer,
	a unsafe.Pointer, layoutA *MatrixLayout,
	b unsafe.Pointer, layoutB *MatrixLayout,
	beta unsafe.Pointer,
	c unsafe.Pointer, layoutC *MatrixLayout,
	d unsafe.Pointer, layoutD *MatrixLayout,
	algo *LtMatmulAlgoResult,
	workspace unsafe.Pointer, workspaceSize int,
	stream uintptr,
) error

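LtMatmul computes D = alpha*(op(A)*op(B)) + beta*C using cuBLASLt, where the transpose operations, scale type and any epilogue come from desc. alpha and beta are pointers to host scalars of the scale type. algo is typically one of the results returned by MatmulAlgoGetHeuristic. stream is the CUDA stream handle (0 for the default stream). workspace is optional device memory and may be nil with workspaceSize = 0.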

func Sgemm

func Sgemm(h *Handle, m, n, k int, alpha float32,
	a unsafe.Pointer, b unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

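Sgemm performs single-precision general matrix multiplication, C = alpha*A*B + beta*C. Arguments are given in row-major terms (A is m x k, B is k x n, C is m x n); the function converts to cuBLAS's column-major convention by swapping A/B and m/n.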
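The A/B and m/n swap works because a row-major buffer read as column-major is the transpose of the matrix, so a column-major routine asked for B^T*A^T produces C^T, which in memory is exactly row-major C. The sketch below checks this on the CPU; sgemmColMajor is a plain-Go stand-in for cublasSgemm (not a binding), and the leading dimensions (n, k, n) are what we assume a row-major wrapper would pass.

```go
package main

import "fmt"

// sgemmColMajor is a CPU reference for column-major SGEMM with no transposes:
// C (m x n) = alpha*A (m x k)*B (k x n) + beta*C, where element (i, j) of a
// matrix with leading dimension ld lives at index i + j*ld.
func sgemmColMajor(m, n, k int, alpha float32, a []float32, lda int,
	b []float32, ldb int, beta float32, c []float32, ldc int) {
	for j := 0; j < n; j++ {
		for i := 0; i < m; i++ {
			var acc float32
			for p := 0; p < k; p++ {
				acc += a[i+p*lda] * b[p+j*ldb]
			}
			c[i+j*ldc] = alpha*acc + beta*c[i+j*ldc]
		}
	}
}

// sgemmRowMajor computes row-major C (m x n) = alpha*A (m x k)*B (k x n) + beta*C
// by calling the column-major routine with A/B swapped and m/n swapped.
func sgemmRowMajor(m, n, k int, alpha float32, a, b []float32, beta float32, c []float32) {
	sgemmColMajor(n, m, k, alpha, b, n, a, k, beta, c, n)
}

func main() {
	a := []float32{1, 2, 3, 4, 5, 6}    // A: 2x3 row-major
	b := []float32{7, 8, 9, 10, 11, 12} // B: 3x2 row-major
	c := make([]float32, 4)             // C: 2x2 row-major
	sgemmRowMajor(2, 2, 3, 1, a, b, 0, c)
	fmt.Println(c) // [58 64 139 154]
}
```

No data is copied or transposed: only the roles of the two operands and of m and n change, which is why the trick costs nothing at run time.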

func SgemmNT

func SgemmNT(h *Handle, m, n, k int, alpha float32,
	a unsafe.Pointer, b unsafe.Pointer,
	beta float32, c unsafe.Pointer,
) error

SgemmNT performs single-precision C = alpha*A*B^T + beta*C, where A is [m, k] and B is [n, k] (both row-major). After the row-major A/B swap, B is the first cuBLAS argument, and CUBLAS_OP_T is applied to it.
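Why the transpose lands on the first cuBLAS argument: row-major B (n x k) read as column-major is a k x n matrix, so transposing it restores B, and the column-major product B*A^T is C^T, i.e. row-major C. The CPU sketch below checks this; sgemmColMajorT and sgemmNTRowMajor are hypothetical stand-ins (not bindings), and the leading dimensions (k, k, n) are assumptions about what a row-major wrapper would pass.

```go
package main

import "fmt"

// sgemmColMajorT is a CPU column-major reference in which the first operand may be
// transposed: C (m x n) = alpha*op(A)*B + beta*C. With transA, A is stored as a
// k x m matrix and op(A) = A^T; otherwise A is stored as m x k.
func sgemmColMajorT(transA bool, m, n, k int, alpha float32, a []float32, lda int,
	b []float32, ldb int, beta float32, c []float32, ldc int) {
	for j := 0; j < n; j++ {
		for i := 0; i < m; i++ {
			var acc float32
			for p := 0; p < k; p++ {
				var aip float32
				if transA {
					aip = a[p+i*lda] // element (p, i) of the stored k x m matrix
				} else {
					aip = a[i+p*lda] // element (i, p) of the stored m x k matrix
				}
				acc += aip * b[p+j*ldb]
			}
			c[i+j*ldc] = alpha*acc + beta*c[i+j*ldc]
		}
	}
}

// sgemmNTRowMajor computes row-major C (m x n) = alpha*A*B^T + beta*C for row-major
// A (m x k) and B (n x k): swap A/B and m/n, then transpose the first operand (B).
func sgemmNTRowMajor(m, n, k int, alpha float32, a, b []float32, beta float32, c []float32) {
	sgemmColMajorT(true, n, m, k, alpha, b, k, a, k, beta, c, n)
}

func main() {
	a := []float32{1, 2, 3, 4, 5, 6}    // A: 2x3 row-major
	b := []float32{7, 8, 9, 10, 11, 12} // B: 2x3 row-major
	c := make([]float32, 4)             // C: 2x2 row-major
	sgemmNTRowMajor(2, 2, 3, 1, a, b, 0, c)
	fmt.Println(c) // [50 68 122 167]
}
```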

func SgemmNTStridedBatched

func SgemmNTStridedBatched(h *Handle, m, n, k int, alpha float32,
	a unsafe.Pointer, strideA int64,
	b unsafe.Pointer, strideB int64,
	beta float32,
	c unsafe.Pointer, strideC int64,
	batch int,
) error

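SgemmNTStridedBatched is the batched form of SgemmNT: for each matrix in the batch it computes C = alpha*A*B^T + beta*C, with A [m, k] and B [n, k] (row-major), using strided batched GEMM with CUBLAS_OP_T on the B operand. Strides are element counts between consecutive matrices.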

func SgemmStridedBatched

func SgemmStridedBatched(h *Handle, m, n, k int, alpha float32,
	a unsafe.Pointer, strideA int64,
	b unsafe.Pointer, strideB int64,
	beta float32,
	c unsafe.Pointer, strideC int64,
	batch int,
) error

SgemmStridedBatched performs batched single-precision GEMM on matrices placed at fixed element strides in memory. Like Sgemm, it converts row-major arguments to cuBLAS's column-major convention by swapping A/B and m/n.

Parameters (in row-major terms):

m        - rows of A and C per batch
n        - columns of B and C per batch
k        - columns of A / rows of B
alpha    - scalar multiplier for A*B
a        - device pointer to A[0] (m x k, row-major)
strideA  - element stride between consecutive A matrices
b        - device pointer to B[0] (k x n, row-major)
strideB  - element stride between consecutive B matrices
beta     - scalar multiplier for C
c        - device pointer to C[0] (m x n, row-major), output
strideC  - element stride between consecutive C matrices
batch    - number of matrices in the batch
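The parameter list above amounts to one GEMM per batch index i, with A_i starting at element i*strideA of a (and likewise for B and C). The CPU reference below is a sketch of that semantics in row-major terms, not a binding; sgemmStridedBatchedRowMajor is a hypothetical name, and the packed example uses strides of m*k, k*n and m*n elements.

```go
package main

import "fmt"

// sgemmStridedBatchedRowMajor computes, for each batch index i,
// C_i = alpha*A_i*B_i + beta*C_i, where A_i (m x k), B_i (k x n) and C_i (m x n)
// are row-major and start at element i*strideA, i*strideB and i*strideC.
// Strides count elements, not bytes.
func sgemmStridedBatchedRowMajor(m, n, k int, alpha float32,
	a []float32, strideA int64, b []float32, strideB int64,
	beta float32, c []float32, strideC int64, batch int) {
	for i := 0; i < batch; i++ {
		ao, bo, co := int64(i)*strideA, int64(i)*strideB, int64(i)*strideC
		for r := 0; r < m; r++ {
			for col := 0; col < n; col++ {
				var acc float32
				for p := 0; p < k; p++ {
					acc += a[ao+int64(r*k+p)] * b[bo+int64(p*n+col)]
				}
				idx := co + int64(r*n+col)
				c[idx] = alpha*acc + beta*c[idx]
			}
		}
	}
}

func main() {
	// Two 2x2 products packed back to back (strides 4, 4, 4).
	a := []float32{1, 2, 3, 4, 1, 0, 0, 1}
	b := []float32{5, 6, 7, 8, 2, 3, 4, 5}
	c := make([]float32, 8)
	sgemmStridedBatchedRowMajor(2, 2, 2, 1, a, 4, b, 4, 0, c, 4, 2)
	fmt.Println(c) // [19 22 43 50 2 3 4 5]
}
```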

Types

type CublasComputeType

type CublasComputeType int

CublasComputeType identifies the compute precision for cublasGemmEx.

const (
	CublasCompute32F CublasComputeType = 68 // CUBLAS_COMPUTE_32F
)

type CudaDataType

type CudaDataType int

CudaDataType identifies the element data type for cublasGemmEx.

const (
	CudaR32F     CudaDataType = 0  // CUDA_R_32F  (float32)
	CudaR16F     CudaDataType = 2  // CUDA_R_16F  (float16)
	CudaR16BF    CudaDataType = 14 // CUDA_R_16BF (bfloat16)
	CudaR8F_E4M3 CudaDataType = 28 // CUDA_R_8F_E4M3 (fp8 e4m3)
)

type Handle

type Handle struct {
	// contains filtered or unexported fields
}

Handle wraps a cuBLAS handle (opaque pointer).

func CreateHandle

func CreateHandle() (*Handle, error)

CreateHandle creates a new cuBLAS context handle.

func (*Handle) Destroy

func (h *Handle) Destroy() error

Destroy releases the cuBLAS handle resources.

func (*Handle) SetStream

func (h *Handle) SetStream(streamPtr unsafe.Pointer) error

SetStream associates a CUDA stream with this cuBLAS handle.

type LtComputeType

type LtComputeType int

LtComputeType specifies the compute precision for cublasLt matmul.

const (
	LtComputeF32 LtComputeType = 68 // CUBLAS_COMPUTE_32F
	LtComputeF16 LtComputeType = 64 // CUBLAS_COMPUTE_16F
)

type LtEpilogue

type LtEpilogue int32

LtEpilogue specifies the epilogue operation applied after the matmul.

const (
	LtEpilogueDefault LtEpilogue = 1   // CUBLASLT_EPILOGUE_DEFAULT
	LtEpilogueReLU    LtEpilogue = 2   // CUBLASLT_EPILOGUE_RELU
	LtEpilogueGeLU    LtEpilogue = 32  // CUBLASLT_EPILOGUE_GELU
	LtEpilogueBias    LtEpilogue = 128 // CUBLASLT_EPILOGUE_BIAS
)
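An epilogue fuses an elementwise operation into the pass that writes D, instead of running a second kernel over the result. The CPU sketch below models only the default (identity) and ReLU cases; bias and GELU are not modeled, and matmulWithEpilogue is a hypothetical helper. With this package, the epilogue would instead be chosen by setting LtMatmulDescEpilogue on the MatmulDesc via SetAttribute.

```go
package main

import "fmt"

// epilogue is an elementwise function applied to each output element as it is written.
type epilogue func(float32) float32

func identity(x float32) float32 { return x }

func relu(x float32) float32 {
	if x < 0 {
		return 0
	}
	return x
}

// matmulWithEpilogue computes D = ep(alpha*(A*B) + beta*C) for column-major float32
// matrices A (m x k), B (k x n), C and D (m x n) with tight leading dimensions,
// applying ep in the same loop that writes D.
func matmulWithEpilogue(m, n, k int, alpha float32, a, b []float32,
	beta float32, c, d []float32, ep epilogue) {
	for j := 0; j < n; j++ {
		for i := 0; i < m; i++ {
			var acc float32
			for p := 0; p < k; p++ {
				acc += a[i+p*m] * b[p+j*k]
			}
			d[i+j*m] = ep(alpha*acc + beta*c[i+j*m])
		}
	}
}

func main() {
	a := []float32{1, -1} // A: 2x1
	b := []float32{2, -3} // B: 1x2
	c := make([]float32, 4)
	d := make([]float32, 4)
	matmulWithEpilogue(2, 2, 1, 1, a, b, 0, c, d, identity)
	fmt.Println(d) // [2 -2 -3 3]
	matmulWithEpilogue(2, 2, 1, 1, a, b, 0, c, d, relu)
	fmt.Println(d) // [2 0 0 3]
}
```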

type LtHandle

type LtHandle struct {
	// contains filtered or unexported fields
}

LtHandle wraps a cublasLtHandle_t (opaque pointer).

func LtCreateHandle

func LtCreateHandle() (*LtHandle, error)

LtCreateHandle creates a new cuBLASLt context handle.

func (*LtHandle) Destroy

func (h *LtHandle) Destroy() error

Destroy releases the cuBLASLt handle resources.

type LtMatmulAlgoResult

type LtMatmulAlgoResult struct {
	// contains filtered or unexported fields
}

LtMatmulAlgoResult holds the result of a heuristic algorithm search. The raw bytes correspond to cublasLtMatmulHeuristicResult_t.

func MatmulAlgoGetHeuristic

func MatmulAlgoGetHeuristic(
	h *LtHandle,
	desc *MatmulDesc,
	layoutA, layoutB, layoutC, layoutD *MatrixLayout,
	pref *MatmulPreference,
	requestedCount int,
) ([]LtMatmulAlgoResult, error)

MatmulAlgoGetHeuristic queries cuBLASLt's heuristics for algorithms suited to the given matmul configuration. It requests up to requestedCount candidates; the length of the returned slice is the number actually found.

func (*LtMatmulAlgoResult) AlgoPtr

func (r *LtMatmulAlgoResult) AlgoPtr() unsafe.Pointer

AlgoPtr returns a pointer to the embedded cublasLtMatmulAlgo_t, which is the first field (offset 0) of the heuristic result.

type LtMatmulDescAttribute

type LtMatmulDescAttribute int

LtMatmulDescAttribute identifies attributes of a matmul descriptor.

const (
	// LtMatmulDescScaleType sets the scale type for the matmul operation.
	LtMatmulDescScaleType LtMatmulDescAttribute = 0 // CUBLASLT_MATMUL_DESC_SCALE_TYPE
	// LtMatmulDescPointerMode sets the pointer mode.
	LtMatmulDescPointerMode LtMatmulDescAttribute = 1 // CUBLASLT_MATMUL_DESC_POINTER_MODE
	// LtMatmulDescTransA sets the transpose mode for matrix A.
	LtMatmulDescTransA LtMatmulDescAttribute = 2 // CUBLASLT_MATMUL_DESC_TRANSA
	// LtMatmulDescTransB sets the transpose mode for matrix B.
	LtMatmulDescTransB LtMatmulDescAttribute = 3 // CUBLASLT_MATMUL_DESC_TRANSB
	// LtMatmulDescEpilogue sets the epilogue function.
	LtMatmulDescEpilogue LtMatmulDescAttribute = 5 // CUBLASLT_MATMUL_DESC_EPILOGUE
	// LtMatmulDescEpilogueAuxPointer sets the auxiliary epilogue pointer.
	LtMatmulDescEpilogueAuxPointer LtMatmulDescAttribute = 6 // CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER
	// LtMatmulDescEpilogueAuxLd sets the leading dimension of the auxiliary epilogue buffer.
	LtMatmulDescEpilogueAuxLd LtMatmulDescAttribute = 7 // CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD
	// LtMatmulDescAScalePointer sets the A scale pointer (for FP8).
	LtMatmulDescAScalePointer LtMatmulDescAttribute = 17 // CUBLASLT_MATMUL_DESC_A_SCALE_POINTER
	// LtMatmulDescBScalePointer sets the B scale pointer (for FP8).
	LtMatmulDescBScalePointer LtMatmulDescAttribute = 18 // CUBLASLT_MATMUL_DESC_B_SCALE_POINTER
	// LtMatmulDescDScalePointer sets the D scale pointer (for FP8).
	LtMatmulDescDScalePointer LtMatmulDescAttribute = 20 // CUBLASLT_MATMUL_DESC_D_SCALE_POINTER
)

type LtMatrixLayoutAttribute

type LtMatrixLayoutAttribute int

LtMatrixLayoutAttribute identifies attributes of a matrix layout.

const (
	// LtMatrixLayoutType sets the data type of the matrix.
	LtMatrixLayoutType LtMatrixLayoutAttribute = 0 // CUBLASLT_MATRIX_LAYOUT_TYPE
	// LtMatrixLayoutOrder sets the memory order (row/col major).
	LtMatrixLayoutOrder LtMatrixLayoutAttribute = 1 // CUBLASLT_MATRIX_LAYOUT_ORDER
	// LtMatrixLayoutRows sets the number of rows.
	LtMatrixLayoutRows LtMatrixLayoutAttribute = 2 // CUBLASLT_MATRIX_LAYOUT_ROWS
	// LtMatrixLayoutCols sets the number of columns.
	LtMatrixLayoutCols LtMatrixLayoutAttribute = 3 // CUBLASLT_MATRIX_LAYOUT_COLS
	// LtMatrixLayoutLD sets the leading dimension.
	LtMatrixLayoutLD LtMatrixLayoutAttribute = 4 // CUBLASLT_MATRIX_LAYOUT_LD
	// LtMatrixLayoutBatchCount sets the batch count.
	LtMatrixLayoutBatchCount LtMatrixLayoutAttribute = 5 // CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT
	// LtMatrixLayoutStridedBatchOffset sets the strided batch offset.
	LtMatrixLayoutStridedBatchOffset LtMatrixLayoutAttribute = 6 // CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
)

type LtOrder

type LtOrder int32

LtOrder specifies memory layout order.

const (
	LtOrderCol LtOrder = 0 // CUBLASLT_ORDER_COL
	LtOrderRow LtOrder = 1 // CUBLASLT_ORDER_ROW
)

type MatmulDesc

type MatmulDesc struct {
	// contains filtered or unexported fields
}

MatmulDesc wraps a cublasLtMatmulDesc_t (opaque pointer).

func CreateMatmulDesc

func CreateMatmulDesc(computeType LtComputeType, scaleType CudaDataType) (*MatmulDesc, error)

CreateMatmulDesc creates a new matmul descriptor with the given compute type and scale type. The scale type is the CudaDataType of alpha and beta (e.g. CudaR32F).

func (*MatmulDesc) Destroy

func (d *MatmulDesc) Destroy() error

Destroy releases the matmul descriptor.

func (*MatmulDesc) SetAttribute

func (d *MatmulDesc) SetAttribute(attr LtMatmulDescAttribute, value unsafe.Pointer, sizeInBytes int) error

SetAttribute sets an attribute on the matmul descriptor. value must point to the attribute value, and sizeInBytes is its size.

type MatmulPreference

type MatmulPreference struct {
	// contains filtered or unexported fields
}

MatmulPreference wraps a cublasLtMatmulPreference_t (opaque pointer).

func CreateMatmulPreference

func CreateMatmulPreference() (*MatmulPreference, error)

CreateMatmulPreference creates a new matmul preference descriptor.

func (*MatmulPreference) Destroy

func (p *MatmulPreference) Destroy() error

Destroy releases the matmul preference descriptor.

type MatrixLayout

type MatrixLayout struct {
	// contains filtered or unexported fields
}

MatrixLayout wraps a cublasLtMatrixLayout_t (opaque pointer).

func CreateMatrixLayout

func CreateMatrixLayout(dataType CudaDataType, rows, cols, ld int) (*MatrixLayout, error)

CreateMatrixLayout creates a new matrix layout descriptor. dataType is the element type (e.g. CudaR32F), rows/cols are the matrix dimensions, and ld is the leading dimension.
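The leading dimension is the distance, in elements, between consecutive columns (column order) or consecutive rows (row order); cuBLASLt layouts default to column order unless LtMatrixLayoutOrder is set. The sketch below shows the resulting element offsets and the smallest valid ld. The order type and its constants are hypothetical local mirrors of LtOrderCol and LtOrderRow.

```go
package main

import "fmt"

// order mirrors the LtOrder values (0 = column order, 1 = row order).
type order int32

const (
	orderCol order = 0
	orderRow order = 1
)

// offset returns the element index of (row, col) in a matrix stored with
// the given order and leading dimension ld.
func offset(o order, row, col, ld int) int {
	if o == orderRow {
		return row*ld + col
	}
	return col*ld + row
}

// minLD returns the smallest valid leading dimension for a rows x cols matrix.
func minLD(o order, rows, cols int) int {
	if o == orderRow {
		return cols
	}
	return rows
}

func main() {
	// A tightly packed 3x4 matrix needs ld >= 3 in column order, ld >= 4 in row order.
	fmt.Println(minLD(orderCol, 3, 4), minLD(orderRow, 3, 4))         // 3 4
	fmt.Println(offset(orderCol, 1, 2, 3), offset(orderRow, 1, 2, 4)) // 7 6
}
```

A larger ld than the minimum describes a sub-matrix (or padded rows/columns) inside a bigger allocation without copying.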

func (*MatrixLayout) Destroy

func (l *MatrixLayout) Destroy() error

Destroy releases the matrix layout descriptor.

func (*MatrixLayout) SetAttribute

func (l *MatrixLayout) SetAttribute(attr LtMatrixLayoutAttribute, value unsafe.Pointer, sizeInBytes int) error

SetAttribute sets an attribute on the matrix layout. value must point to the attribute value, and sizeInBytes is its size.
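Both SetAttribute methods take the attribute as an (unsafe.Pointer, sizeInBytes) pair, and the Go variable's type must match the C type cuBLASLt expects for that attribute (for example int32 for CUBLASLT_MATRIX_LAYOUT_ORDER, int64 for CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET). The generic helper below is a hypothetical convenience, not part of this package, that derives both values from one pointer so the size cannot drift from the type.

```go
package main

import (
	"fmt"
	"unsafe"
)

// attrArg returns the (value, sizeInBytes) pair that SetAttribute-style calls
// expect for a Go variable. Hypothetical helper for illustration only.
func attrArg[T any](v *T) (unsafe.Pointer, int) {
	return unsafe.Pointer(v), int(unsafe.Sizeof(*v))
}

func main() {
	ord := int32(1)    // same value as LtOrderRow (CUBLASLT_ORDER_ROW)
	off := int64(4096) // e.g. a strided-batch offset in elements
	p, n := attrArg(&ord)
	fmt.Println(*(*int32)(p), n) // 1 4
	_, n = attrArg(&off)
	fmt.Println(n) // 8
	// With a real *MatrixLayout l, the call would look like:
	//   p, n := attrArg(&ord)
	//   err := l.SetAttribute(LtMatrixLayoutOrder, p, n)
}
```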
