Documentation ¶
Overview ¶
Package compute defines abstractions for building, compiling, transferring data (buffers) and executing machine learning computation graphs in GoMLX.
The central interface is Backend, which is built around four supporting interfaces:
- DataInterface: Manages tensor data storage, memory allocation, and the transfer of data buffers to and from the backend. It abstracts the complexities of different hardware accelerators (e.g., CPU, GPU, TPU) and their respective memory models.
- Builder: Provides the API for constructing computation graphs. A graph is typically composed of a "main" function, potentially alongside other helper functions.
- Function: Represents a discrete sub-graph or logical unit of operations within a larger computation.
- Executable: Represents a compiled, ready-to-run computation graph that can be invoked with inputs on the target backend, yielding the resulting data buffers.
While conceptually inspired by OpenXLA's StableHLO, the API has diverged somewhat. It aims to be backend-agnostic, supporting a diverse range of execution environments (CPU, GPU, TPU, etc.) and optimization strategies (e.g., JIT compilation with static shapes vs. dynamic shapes with less optimization).
Backend Implementations ¶
A Backend is not required to implement every defined operation. If a backend encounters an unsupported operation, it should gracefully return an ErrNotImplemented. There is also a [Backend.Capabilities] method that returns what the backend supports (e.g. supported dtypes, ops, or if it supports dynamic shapes).
Computations that do not rely on the missing operation work normally. In some cases, a computation can handle the error by falling back to an alternative implementation that works around the missing op (the strategy used for some fused ops).
The [notimplemented] package provides a default implementation (one that returns ErrNotImplemented) for all methods of Builder. It's good practice for any Backend to wrap it: doing so greatly simplifies the implementation of the backend and provides forward compatibility (if new ops are added, an existing backend doesn't need to change; the new ops simply fail gracefully).
Error Messages with Stack Traces ¶
By convention, errors should be wrapped with a stack trace using the "github.com/pkg/errors" package (for now, until a better option is found).
Example:
import "github.com/pkg/errors"
...
return errors.Wrapf(ErrNotImplemented, "...")
[notimplemented]: pkg.go.dev/github.com/gomlx/compute/notimplemented
Index ¶
- Constants
- Variables
- func FFTTypeStrings() []string
- func IsNotImplemented(err error) bool
- func List() []string
- func OpTypeStrings() []string
- func ReduceOpTypeStrings() []string
- func Register(name string, constructor Constructor)
- type ActivationType
- type AxesLayout
- type Backend
- type Buffer
- type Builder
- type Capabilities
- type CollectiveOps
- type Constructor
- type ConvolveAxesConfig
- type DataInterface
- type DeviceNum
- type DotGeneralConfig
- type Executable
- type FFTType
- type Function
- type FusedOps
- type GGMLQuantType
- type Mesh
- type OpType
- type PadAxis
- type Quantization
- type QuantizationScheme
- type ReduceOpType
- type ScaledDotProductAttentionConfig
- type ShardingSpec
- type StandardOps
- type Value
Constants ¶
const ConfigEnvVar = "GOMLX_BACKEND"
ConfigEnvVar is the name of the environment variable with the default backend configuration to use: "GOMLX_BACKEND".
The format of the configuration is "<backend_name>:<backend_configuration>". The "<backend_name>" is the name of a registered backend (e.g.: "xla") and "<backend_configuration>" is backend-specific (e.g.: for xla backend, it is the pjrt plugin name).
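As a rough illustration of the format above, the sketch below splits a configuration string at the first colon. `splitConfig` is a hypothetical helper, not part of the package, and the "cuda" plugin name is only an example value; how the package itself handles a string without a colon is not stated here, so treating it as a bare backend name is an assumption.

```go
package main

import (
	"fmt"
	"strings"
)

// splitConfig is a hypothetical helper: it splits "<backend_name>:<backend_configuration>"
// at the first colon. A string without a colon is treated as a bare backend name
// with an empty backend configuration (an assumption of this sketch).
func splitConfig(config string) (name, backendConfig string) {
	name, backendConfig, _ = strings.Cut(config, ":")
	return name, backendConfig
}

func main() {
	name, cfg := splitConfig("xla:cuda")
	fmt.Printf("backend=%q config=%q\n", name, cfg)
}
```

Splitting at the first colon only means the backend-specific part may itself contain colons.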
const MainName = "main"
Main function name, created by Builder.Main().
Variables ¶
var DefaultConfig = "xla"
DefaultConfig is the default backend configuration, used when the ConfigEnvVar environment variable is not defined.
See NewWithConfig for the format of the configuration string.
var ErrNotImplemented = stderrors.New("op not implemented")
ErrNotImplemented indicates an op is not implemented (typically a fused op, but normal ops may return this as well) for the given configuration (e.g. unsupported dtype or backend). Backends should wrap this error so InternalFusedOpCaller can distinguish "not supported" from genuine bugs and fall back to the decomposed implementation.
It doesn't contain a stack trace; attach one with errors.Wrapf(ErrNotImplemented, "...") when using it.
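The fallback behavior described for fused ops might look roughly like the sketch below. It is not the InternalFusedOpCaller implementation: `fusedSoftmax`, `decomposedSoftmax`, and `softmax` are hypothetical functions operating on plain slices, and stdlib `%w` wrapping stands in for "github.com/pkg/errors".

```go
package main

import (
	"errors"
	"fmt"
	"math"
)

// ErrNotImplemented mirrors the package's sentinel error (stand-in for this sketch).
var ErrNotImplemented = errors.New("op not implemented")

// fusedSoftmax stands in for a backend that lacks the fused op.
func fusedSoftmax(x []float64) ([]float64, error) {
	return nil, fmt.Errorf("FusedSoftmax on this backend: %w", ErrNotImplemented)
}

// decomposedSoftmax is the fallback built from primitive operations.
func decomposedSoftmax(x []float64) []float64 {
	maxV := math.Inf(-1)
	for _, v := range x {
		maxV = math.Max(maxV, v)
	}
	out := make([]float64, len(x))
	sum := 0.0
	for i, v := range x {
		out[i] = math.Exp(v - maxV) // subtract the max for numerical stability
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

// softmax tries the fused op and falls back only on ErrNotImplemented;
// any other error is treated as a genuine failure and returned as-is.
func softmax(x []float64) ([]float64, error) {
	out, err := fusedSoftmax(x)
	if err == nil {
		return out, nil
	}
	if errors.Is(err, ErrNotImplemented) {
		return decomposedSoftmax(x), nil
	}
	return nil, err
}

func main() {
	out, err := softmax([]float64{0, 0})
	fmt.Println(out, err)
}
```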
var IQ4NLLookupTable = [16]float32{
-127, -104, -83, -65, -49, -35, -22, -10,
1, 13, 25, 38, 53, 69, 89, 113,
}
IQ4NLLookupTable contains the 16 fixed IQ4_NL non-linear dequantization values. These map 4-bit nibble indices to pre-normalization integer values (not final floats). Final dequantized value = per-block scale * IQ4NLLookupTable[nibble]. Values from llama.cpp's kvalues_iq4nl.
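The dequantization formula above can be sketched for a single 4-bit index as follows. `dequantIQ4NL` is a hypothetical helper; how nibbles are packed into bytes within a block is not covered here.

```go
package main

import "fmt"

// IQ4NLLookupTable is copied from the documentation above.
var IQ4NLLookupTable = [16]float32{
	-127, -104, -83, -65, -49, -35, -22, -10,
	1, 13, 25, 38, 53, 69, 89, 113,
}

// dequantIQ4NL is a hypothetical helper applying the documented formula:
// value = per-block scale * IQ4NLLookupTable[nibble].
// Only the low 4 bits of nibble are used.
func dequantIQ4NL(scale float32, nibble uint8) float32 {
	return scale * IQ4NLLookupTable[nibble&0x0F]
}

func main() {
	fmt.Println(dequantIQ4NL(0.5, 0), dequantIQ4NL(0.5, 8))
}
```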
var NF4LookupTable = [16]float32{
-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
-0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0,
0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
}
NF4LookupTable contains the 16 fixed QLoRA NormalFloat4 dequantization values. Used by both the fused executor and the decomposed graph-level fallback.
var RNGStateShape = shapes.Make(dtypes.Uint64, 3) //nolint:mnd // This is a constant.
RNGStateShape is the default shape for the random number generator state. It depends on the algorithm, but for now we are using Philox.
Functions ¶
func FFTTypeStrings ¶
func FFTTypeStrings() []string
FFTTypeStrings returns a slice of all String values of the enum
func IsNotImplemented ¶
func IsNotImplemented(err error) bool
IsNotImplemented checks whether the error is an ErrNotImplemented.
func OpTypeStrings ¶
func OpTypeStrings() []string
OpTypeStrings returns a slice of all String values of the enum
func ReduceOpTypeStrings ¶
func ReduceOpTypeStrings() []string
ReduceOpTypeStrings returns a slice of all String values of the enum
func Register ¶
func Register(name string, constructor Constructor)
Register registers a backend under the given name, along with its default constructor. The constructor takes a configuration string as input, which is passed along when the backend is constructed.
To be safe, call Register during initialization of a package.
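A registry of this kind is often a map from name to constructor, populated from `init` functions. The sketch below is a simplified model, not the package's code: the `Backend` interface is cut down to one method, the `Constructor` signature shown is an assumption (its definition is not reproduced on this page), `toyBackend` is hypothetical, and sorting the names in `List` is a choice made here for deterministic output.

```go
package main

import (
	"fmt"
	"sort"
)

// Backend and Constructor are simplified stand-ins; the Constructor signature is assumed.
type Backend interface{ Name() string }
type Constructor func(config string) (Backend, error)

// registry maps backend names to their constructors.
var registry = map[string]Constructor{}

// Register associates a backend name with its constructor.
func Register(name string, constructor Constructor) {
	registry[name] = constructor
}

// List returns the registered names, sorted here for deterministic output.
func List() []string {
	names := make([]string, 0, len(registry))
	for name := range registry {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// toyBackend is a hypothetical backend used only to exercise the registry.
type toyBackend struct{ config string }

func (toyBackend) Name() string { return "toy" }

// Registering from init makes the backend available before main runs.
func init() {
	Register("toy", func(config string) (Backend, error) {
		return toyBackend{config: config}, nil
	})
}

func main() {
	fmt.Println(List())
	backend, err := registry["toy"]("some-config")
	if err != nil {
		fmt.Println("construction failed:", err)
		return
	}
	fmt.Println(backend.Name())
}
```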
Types ¶
type ActivationType ¶
type ActivationType int
ActivationType specifies the activation function for fused operations.
const (
ActivationNone ActivationType = iota
ActivationGelu
ActivationRelu
ActivationSilu
ActivationHardSwish
ActivationTanh
)
func (ActivationType) String ¶
func (a ActivationType) String() string
String returns the name of the activation type.
type AxesLayout ¶
type AxesLayout int
AxesLayout specifies the ordering of axes in 4D attention tensors.
const (
// AxesLayoutBHSD is the [batch, heads, seq, dim] layout used by PyTorch's F.scaled_dot_product_attention,
// ONNX, and most inference runtimes.
AxesLayoutBHSD AxesLayout = iota
// AxesLayoutBSHD is the [batch, seq, heads, dim] layout used internally by MultiHeadAttention
// (after Dense projections which produce [batch, seq, heads, dim]).
AxesLayoutBSHD
)
func (AxesLayout) HeadsAxis ¶
func (l AxesLayout) HeadsAxis() int
HeadsAxis returns the axis index for the heads dimension.
func (AxesLayout) SeqAxis ¶
func (l AxesLayout) SeqAxis() int
SeqAxis returns the axis index for the sequence dimension.
func (AxesLayout) String ¶
func (l AxesLayout) String() string
String returns the name of the layout.
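Given the two AxesLayout orderings documented above, the axis accessors can be sketched as below. The type is re-declared locally so the example is self-contained; the method bodies are an illustration derived from the layout descriptions, not copied from the package.

```go
package main

import "fmt"

// AxesLayout is re-declared here so the sketch compiles on its own.
type AxesLayout int

const (
	AxesLayoutBHSD AxesLayout = iota // [batch, heads, seq, dim]
	AxesLayoutBSHD                   // [batch, seq, heads, dim]
)

// HeadsAxis returns the index of the heads axis for the layout.
func (l AxesLayout) HeadsAxis() int {
	if l == AxesLayoutBSHD {
		return 2
	}
	return 1
}

// SeqAxis returns the index of the sequence axis for the layout.
func (l AxesLayout) SeqAxis() int {
	if l == AxesLayoutBSHD {
		return 1
	}
	return 2
}

func main() {
	for _, l := range []AxesLayout{AxesLayoutBHSD, AxesLayoutBSHD} {
		fmt.Printf("layout=%d heads=%d seq=%d\n", l, l.HeadsAxis(), l.SeqAxis())
	}
}
```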
type Backend ¶
type Backend interface {
// Name returns the short name of the backend. E.g.: "xla" for the Xla/PJRT plugin.
Name() string
// String returns the same as Name.
String() string
// Description is a longer description of the Backend that can be used to pretty-print.
Description() string
// NumDevices returns the number of devices available for this Backend.
NumDevices() int
// DeviceDescription returns a description of the device at the given deviceNum.
DeviceDescription(deviceNum DeviceNum) string
// Capabilities returns information about what is supported by this backend.
Capabilities() Capabilities
// Builder creates a new builder used to define a newly named computation.
Builder(name string) Builder
// DataInterface is the sub-interface that defines the API to transfer Buffer to/from accelerators for the backend.
DataInterface
// Finalize releases all the associated resources immediately and makes the backend invalid.
// Any operation on a Backend after Finalize is called is undefined, except IsFinalized.
Finalize()
// IsFinalized returns true if the backend is finalized.
//
// Tensors stored on a backend may hold a reference to a finalized backend, and when being garbage collected,
// check whether it is finalized before requesting the backend to finalize its buffers.
IsFinalized() bool
}
Backend is the API that needs to be implemented by a compute backend. See package compute for more information.
func MustNew ¶
func MustNew() Backend
MustNew returns a new default Backend or panics if it fails.
The default is:
1. The environment $GOMLX_BACKEND (ConfigEnvVar) is used as a configuration if defined.
2. Next, it uses the variable DefaultConfig as the configuration.
3. The first registered backend is used with an empty configuration.
It fails if no backends were registered.
func New ¶
New returns a new default Backend or an error if it fails.
The default is:
1. The environment $GOMLX_BACKEND (ConfigEnvVar) is used as a configuration if defined.
2. Next, it uses the variable DefaultConfig as the configuration.
3. The first registered backend is used with an empty configuration.
It fails if no backends were registered.
func NewOrErr ¶
NewOrErr returns a new default Backend or an error if it fails.
The default is:
1. The environment $GOMLX_BACKEND (ConfigEnvVar) is used as a configuration if defined.
2. Next, it uses the variable DefaultConfig as the configuration.
3. The first registered backend is used with an empty configuration.
It fails if no backends were registered.
Deprecated: at the next version this function will be removed. Use New instead.
func NewWithConfig ¶
NewWithConfig takes a configuration string formatted as "<backend_name>:<backend_configuration>".
The "<backend_name>" is the name of a registered backend (e.g.: "xla") and "<backend_configuration>" is backend-specific (e.g.: for the xla backend, it is the PJRT plugin name).
type Buffer ¶
type Buffer interface {
// Backend returns the compute Backend that owns and manages this buffer.
Backend() Backend
// Finalize allows the client to inform the backend that the buffer is no longer needed and associated
// resources can be freed immediately, as opposed to waiting for a garbage collection.
//
// A finalized buffer should never be used again.
Finalize() error
// Shape returns the shape for the buffer.
Shape() (shapes.Shape, error)
// DeviceNum returns the deviceNum for the buffer.
DeviceNum() (DeviceNum, error)
// ToFlatData transfers the flat values of a buffer to a Go flat slice.
// The slice flat must have the exact number of elements required to store the Buffer shape,
// and be a slice of the corresponding DType -- see DType.GoType().
ToFlatData(flat any) error
// Data returns a slice pointing to the buffer storage memory directly.
// This only works if the backend's HasSharedBuffers returns true.
Data() (flat any, err error)
// CopyToDevice copies the buffer to the deviceNum.
//
// Accelerators often have a much faster bus on which to transfer data, so this is expected to be potentially
// much faster than copying to the host and to the new device.
CopyToDevice(deviceNum DeviceNum) (bufferOnDevice Buffer, err error)
}
Buffer represents actual data (a tensor) stored in the accelerator that will execute the graph. It's used as input/output of computation execution. A Buffer is always associated with a DeviceNum, even if there is only one device.
type Builder ¶
type Builder interface {
// Name of the computation being built.
Name() string
// Main returns the main function of this computation, named MainName.
// Operations added to Main become part of the compiled computation.
// This is the default function where all operations should be added
// unless explicitly building a sub-function.
Main() Function
// NewFunction creates a new named function within this builder.
// These are top-level functions that can be called from the main function.
//
// The name must be unique, and different from MainName (== "main"), the main function's name.
//
// These functions can be called from the main function or other functions.
//
// See also Function.Closure() to create unnamed local functions used in ops like While, If and others.
//
// Returns an error if the backend doesn't support sub-functions.
NewFunction(name string) (Function, error)
// OpShape returns the shape of a computation Op.
// Notice this is not an operation and doesn't change the graph being built.
//
// One can use the shape and create a constant out of it.
OpShape(op Value) (shapes.Shape, error)
// DistributedSPMD creates a computation that will be executed on multiple devices in SPMD fashion
// (SPMD = single program, multiple data).
//
// Use DeviceAssignment to assign the devices to the computation -- the default assignment is incremental
// devices starting from 0.
DistributedSPMD(numDevices int) error
// DistributedAutoSharding creates a computation that will be executed on multiple devices with auto-sharding.
// This currently aims at XLA Shardy [1] framework. But other backends can implement it with the same semantics,
// if appropriate.
//
// [1] https://github.com/openxla/shardy
DistributedAutoSharding(meshes ...Mesh) error
// DeviceAssignment assigns the concrete devices to the computation.
//
// The number of devices must match the number of devices in the computation.
// Usually, that is 1. But if DistributedSPMD was used, it can be more.
DeviceAssignment(devices ...DeviceNum) error
// Compile the computation built. This immediately invalidates the Builder
// and returns an Executable that can be used to run the computation.
//
// The Main function must have had Return() called before compilation.
Compile() (Executable, error)
}
Builder defines the interface for building a computation.
A Builder manages one or more Functions, with Main() being the primary entry point that gets compiled into an Executable. Operations are added to Functions (not directly to Builder), and Function.Return() must be called before Builder.Compile().
Each Builder can also:
- Not implement standard operations by returning an error -- this restricts what type of models it can support. See Backend.Capabilities and package github.com/gomlx/compute/notimplemented.
- Support specialized operations beyond those defined in this interface -- this requires careful interface casting by the caller.
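The lifecycle rules above (operations go into a Function, Return() must precede Compile(), and Compile() invalidates the Builder) can be modeled with a toy. `toyBuilder` and `toyFunction` are hypothetical stand-ins that only track state; they record op names as strings instead of building a graph.

```go
package main

import (
	"errors"
	"fmt"
)

// toyFunction models only the lifecycle of a function, not real graph construction.
type toyFunction struct {
	ops      []string
	returned bool
}

// Op records an operation; it fails once Return has been called.
func (f *toyFunction) Op(name string) error {
	if f.returned {
		return errors.New("function already returned and can no longer be modified")
	}
	f.ops = append(f.ops, name)
	return nil
}

// Return marks the function as finished; it may be called only once.
func (f *toyFunction) Return() error {
	if f.returned {
		return errors.New("return called more than once")
	}
	f.returned = true
	return nil
}

// toyBuilder owns a single main function.
type toyBuilder struct {
	main     toyFunction
	compiled bool
}

// Main returns the builder's main function.
func (b *toyBuilder) Main() *toyFunction { return &b.main }

// Compile returns the recorded ops as the "executable" and invalidates the builder.
func (b *toyBuilder) Compile() ([]string, error) {
	if b.compiled {
		return nil, errors.New("builder was already compiled")
	}
	if !b.main.returned {
		return nil, errors.New("main function must call Return before Compile")
	}
	b.compiled = true
	return b.main.ops, nil
}

func main() {
	b := &toyBuilder{}
	fn := b.Main()
	_ = fn.Op("Parameter(x)")
	_ = fn.Op("Add(x, x)")
	if _, err := b.Compile(); err != nil {
		fmt.Println("compile before return:", err)
	}
	_ = fn.Return()
	exec, err := b.Compile()
	fmt.Println(exec, err)
}
```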
type Capabilities ¶
type Capabilities struct {
// Operations supported by a backend.
// If not listed, it's assumed to be false, hence not supported.
Operations map[OpType]bool
// Functions indicates whether the backend supports functions (top-level functions or closures).
// Without functions, it's not possible to support Call() op or any other
// op that takes as input a closure (While, If, etc.)
Functions bool
// DTypes list the data types supported by a backend.
// If not listed, it's assumed to be false, hence not supported.
DTypes map[dtypes.DType]bool
// PreferConstantsForVariables indicates that the backend prefers context variables
// (model weights) to be embedded as constants in the computation graph rather than
// passed as parameters (inputs) at execution time. This enables optimizations like
// weight blob storage and eliminates per-inference data transfer overhead.
// When true, libraries like onnx-gomlx should use graph.Const() instead of
// Variable.ValueGraph() for model weights.
PreferConstantsForVariables bool
}
Capabilities holds mappings of what is supported by a backend.
func (Capabilities) Clone ¶
func (c Capabilities) Clone() Capabilities
Clone makes a deep copy of the Capabilities.
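A deep copy matters here because Capabilities holds maps, which a plain struct assignment would share. The sketch below shows one way to do it; the struct is a simplified stand-in that uses string keys in place of OpType and dtypes.DType, and `cloneMap` is a hypothetical helper.

```go
package main

import "fmt"

// Capabilities is a simplified stand-in: string keys replace OpType and dtypes.DType.
type Capabilities struct {
	Operations map[string]bool
	Functions  bool
	DTypes     map[string]bool
}

// cloneMap copies a map entry by entry, preserving nil.
func cloneMap(m map[string]bool) map[string]bool {
	if m == nil {
		return nil
	}
	out := make(map[string]bool, len(m))
	for k, v := range m {
		out[k] = v
	}
	return out
}

// Clone returns a copy whose maps are independent of the receiver's maps.
// The receiver is passed by value, so scalar fields are already copied.
func (c Capabilities) Clone() Capabilities {
	c.Operations = cloneMap(c.Operations)
	c.DTypes = cloneMap(c.DTypes)
	return c
}

func main() {
	base := Capabilities{Operations: map[string]bool{"Add": true}}
	extended := base.Clone()
	extended.Operations["Mul"] = true
	fmt.Println(len(base.Operations), len(extended.Operations))
}
```

This is the pattern a derived backend could use to start from another backend's capabilities and adjust them without affecting the original.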
type CollectiveOps ¶
type CollectiveOps interface {
// AllReduce is a distributed (multi-device) operation that reduces the operands
// across replica groups.
//
// - operands: list of operands to be replicated -- often this operation is called over all the parameters
// of a model, hence the option to pass a variable number of parameters to them.
// - reductionType: how the operands should be reduced.
// - replicaGroups: a collection of replica groups: each replica group ([]int) is a collection of devices that
// will participate in the distributed operation. The devices are given as indices (hence []int) into the
// device assignments (not absolute DeviceNum).
AllReduce(operands []Value, reductionType ReduceOpType, replicaGroups [][]int) ([]Value, error)
}
CollectiveOps is an interface for collective operations, that is, operations executed across multiple devices.
EXPERIMENTAL: currently the best supported distribution model is "GSPMD" (Global Single Program Multiple Data), which automatically distributes the computation across all devices, without having to use explicit collective operations.
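To make the replicaGroups argument of AllReduce concrete, the sketch below simulates a sum reduction on the host with one scalar per device. `allReduceSum` is a hypothetical function; leaving devices outside every group unchanged is a choice of this sketch, not documented behavior.

```go
package main

import "fmt"

// allReduceSum simulates a sum AllReduce on the host: values[d] is the operand held
// by the device at index d of the device assignment, and every device in a replica
// group receives the sum over that group. Devices not listed in any group keep
// their value (an assumption of this sketch).
func allReduceSum(values []float64, replicaGroups [][]int) []float64 {
	out := make([]float64, len(values))
	copy(out, values)
	for _, group := range replicaGroups {
		sum := 0.0
		for _, d := range group {
			sum += values[d]
		}
		for _, d := range group {
			out[d] = sum
		}
	}
	return out
}

func main() {
	values := []float64{1, 2, 3, 4}
	fmt.Println(allReduceSum(values, [][]int{{0, 1}, {2, 3}}))
	fmt.Println(allReduceSum(values, [][]int{{0, 1, 2, 3}}))
}
```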
type Constructor ¶
Constructor takes a config string (optionally empty) and returns a Backend.
type ConvolveAxesConfig ¶
type ConvolveAxesConfig struct {
InputBatch, InputChannels int
InputSpatial []int
KernelInputChannels, KernelOutputChannels int
KernelSpatial []int
OutputBatch, OutputChannels int
OutputSpatial []int
}
ConvolveAxesConfig defines the interpretation of the input/kernel/output tensor axes. There must be the same number of spatial dimensions (axes) for each of the 3 tensors. Input and output have batch and channel axes. Kernel has inputChannel and outputChannel axes.
See Builder.ConvGeneral.
func (ConvolveAxesConfig) Clone ¶
func (c ConvolveAxesConfig) Clone() ConvolveAxesConfig
Clone returns a deep copy of the structure.
type DataInterface ¶
type DataInterface interface {
// BufferFromFlatData transfers data from Go given as a flat slice (of the type corresponding to the shape DType)
// to the deviceNum, and returns the corresponding Buffer.
BufferFromFlatData(deviceNum DeviceNum, flat any, shape shapes.Shape) (Buffer, error)
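// HasSharedBuffers returns whether the backend supports "shared buffers": buffers
// that can be used directly by the engine and have a local address that can be read or mutated
// directly by the client.
HasSharedBuffers() bool
// NewSharedBuffer returns a shared buffer that can be used as input for
// computations and directly read or mutated by the clients.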
//
// It panics if the backend doesn't support shared buffers -- see HasSharedBuffers.
NewSharedBuffer(deviceNum DeviceNum, shape shapes.Shape) (buffer Buffer, flat any, err error)
}
DataInterface is the Backend's subinterface that defines the API to transfer Buffer to/from accelerators for the backend.
type DeviceNum ¶
type DeviceNum int
DeviceNum represents which device holds a buffer or should execute a computation. It's up to the backend to interpret it, but it should be at least 0 and less than Backend.NumDevices().
type DotGeneralConfig ¶
type DotGeneralConfig struct {
// AccumulatorDType is the data type of the accumulator during the matrix multiplication.
// Commonly set to Float32 for half-precision (Float16 and BFloat16) operations, or to Int32
// for operations on small quantized values.
//
// If left empty, it defaults to the same dtype as the inputs.
//
// Some backends may not support this option and this will cause it to simply convert the input to the accumulation
// type upfront, which is less efficient.
AccumulatorDType dtypes.DType
// OutputDType is the data type of the output of the matrix multiplication.
//
// If left empty, it defaults to the same dtype as the AccumulatorDType.
//
// Some backends may not support this option and this will cause it to simply convert the input to the output
// type upfront, which is less efficient.
OutputDType dtypes.DType
}
DotGeneralConfig are optional configurations for the DotGeneral operation.
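The defaulting rules in the DotGeneralConfig field comments (accumulator defaults to the input dtype, output defaults to the accumulator dtype) can be sketched as below. `DType` is a string stand-in for dtypes.DType with the empty string meaning "left empty", and `resolveDTypes` is a hypothetical helper.

```go
package main

import "fmt"

// DType is a string stand-in for dtypes.DType; "" means the field was left empty.
type DType string

// DotGeneralConfig is re-declared with the stand-in dtype so the sketch is self-contained.
type DotGeneralConfig struct {
	AccumulatorDType DType
	OutputDType      DType
}

// resolveDTypes applies the documented defaults: the accumulator falls back to the
// input dtype, and the output falls back to the (resolved) accumulator dtype.
func resolveDTypes(input DType, cfg DotGeneralConfig) (accumulator, output DType) {
	accumulator = cfg.AccumulatorDType
	if accumulator == "" {
		accumulator = input
	}
	output = cfg.OutputDType
	if output == "" {
		output = accumulator
	}
	return accumulator, output
}

func main() {
	acc, out := resolveDTypes("bfloat16", DotGeneralConfig{AccumulatorDType: "float32"})
	fmt.Println(acc, out)
}
```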
type Executable ¶
type Executable interface {
// Finalize immediately frees resources associated to the executable.
Finalize()
// Inputs returns the parameters' names and shapes, in the order created by the Function.Parameter calls.
Inputs() (names []string, inputShapes []shapes.Shape)
// Outputs returns the computation's output shapes, in the order given to the main function's Return call.
Outputs() (outputShapes []shapes.Shape)
// Execute the computation.
// The number and shapes of the inputs must match those of the Executable parameters (returned by [Executable.Inputs]).
//
// The inputs marked as donate will become invalid after use.
// This is useful if the input buffer is no longer needed or if updating a variable
// so its Buffer space can be reused as an output Buffer.
//
// Donated buffers are no longer valid after the call.
// If donate is nil, it is assumed to be false for all buffers, and no buffer is donated.
//
// For portable computations (not compiled with a fixed device assignment), the execution runs on the defaultDevice.
// For non-portable computations (where the device assignment is fixed), the defaultDevice is ignored.
//
// For SPMD distributed computations (see [Builder.DistributedSPMD]), the executable is replicated on each device.
// There will be multiple inputs per executable parameter, one per device.
// They are organized as "device-major", that is the input for parameter i on device j is given by inputs[j*numParams + i].
Execute(inputs []Buffer, donate []bool, defaultDevice DeviceNum) ([]Buffer, error)
}
Executable is the API for compiled programs ready to execute.
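The device-major input ordering described for SPMD execution (parameter i on device j at inputs[j*numParams + i]) can be illustrated with a small sketch. Strings stand in for Buffers, and `inputIndex` and `flattenDeviceMajor` are hypothetical helpers.

```go
package main

import "fmt"

// inputIndex returns the position of parameter param for device device
// in the flat, device-major inputs slice.
func inputIndex(device, param, numParams int) int {
	return device*numParams + param
}

// flattenDeviceMajor concatenates per-device inputs (perDevice[j][i] is parameter i
// for device j) into the flat ordering expected for SPMD execution.
func flattenDeviceMajor(perDevice [][]string) []string {
	var flat []string
	for _, deviceInputs := range perDevice {
		flat = append(flat, deviceInputs...)
	}
	return flat
}

func main() {
	perDevice := [][]string{
		{"x@dev0", "w@dev0"},
		{"x@dev1", "w@dev1"},
	}
	flat := flattenDeviceMajor(perDevice)
	fmt.Println(flat)
	fmt.Println(flat[inputIndex(1, 0, 2)])
}
```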
type FFTType ¶
type FFTType int
FFTType selects among the basic types of Fast Fourier Transform (FFT) supported.
func FFTTypeString ¶
FFTTypeString retrieves an enum value from the enum constants string name. Throws an error if the param is not part of the enum.
func (FFTType) IsAFFTType ¶
IsAFFTType returns "true" if the value is listed in the enum definition. "false" otherwise
type Function ¶
type Function interface {
// Name of the function. It will return "" for closures.
Name() string
// Builder returns the builder that this function is part of.
Builder() Builder
// Parent returns the parent function of the current function.
// This is only set for "closures" within another function.
// For top-level functions, like "main", or for backends that don't support functions, this returns nil.
Parent() Function
// Closure returns a new local function, that can be used by certain operations like While, If, Sort.
// Closure functions can access values from their parent function.
Closure() (Function, error)
// StandardOps includes all standard math/ML operations.
StandardOps
// CollectiveOps includes all collective (distributed cross-device) operations.
CollectiveOps
// FusedOps includes optional fused operations for better performance.
FusedOps
// Parameter creates an input parameter for this function.
//
// For the Main function, these become the computation's input parameters
// that must be provided when executing the compiled computation.
//
// For sub-functions, these define the function's input signature.
//
// The sharding defines how the parameter will be sharded for distributed
// operations. Set it to nil if not using distribution.
Parameter(name string, shape shapes.Shape, sharding *ShardingSpec) (Value, error)
// Constant creates a constant in the function with the given flat values
// and the shape defined by the dimensions.
//
// The flat value must be a slice of a basic type supported (that can be
// converted to a DType).
//
// The value is copied into the graph. It's recommended that for very large
// tensors, even if constants, that they are passed as parameters instead.
Constant(flat any, dims ...int) (Value, error)
// Shape returns the shape of the given Value.
//
// Notice, this doesn't create an op on the graph, it's purely for reporting/introspection.
Shape(v Value) (shapes.Shape, error)
// Return marks the outputs of this function.
// Once called, the function can no longer be further modified.
//
// For the Main function, this defines what values will be returned when
// the compiled computation is executed.
//
// For sub-functions, this defines what values are returned when the
// function is called.
//
// The shardings parameter optionally specifies output sharding for
// distributed computation with AutoSharding. Set to nil otherwise.
//
// Return must be called exactly once before Builder.Compile().
Return(outputs []Value, shardings []*ShardingSpec) error
// Call a function with the given inputs.
//
// The function f must be from the same builder.
Call(f Function, inputs ...Value) ([]Value, error)
// Sort sorts one or more tensors along the specified axis using a comparator closure.
//
// The comparator is a closure that takes 2*N scalar inputs (where N is the number of tensors)
// and returns a single boolean. For each pair of positions being compared, it receives
// (lhs_0, lhs_1, ..., lhs_N-1, rhs_0, rhs_1, ..., rhs_N-1) where lhs_i and rhs_i are scalars
// from tensor i at the two positions being compared.
//
// The comparator should return true if lhs should come before rhs in the sorted order.
// For a standard ascending sort on a single tensor, the comparator returns lhs < rhs.
//
// All input tensors must have the same shape. The axis must be valid for the input shape.
// If isStable is true, the sort maintains the relative order of equal elements.
//
// Returns the sorted tensors in the same order as inputs.
Sort(comparator Function, axis int, isStable bool, inputs ...Value) ([]Value, error)
// While executes a loop while a condition is true.
//
// The condition closure (cond) takes N values (the current state) and returns a single
// boolean scalar indicating whether to continue looping.
//
// The body closure takes N values (the current state) and returns N values (the new state).
// The shapes of the outputs must match the shapes of the inputs.
//
// The initialState values are passed to both cond and body on the first iteration.
// On subsequent iterations, the outputs of body become the new state.
//
// Returns the final state values when cond returns false.
While(cond, body Function, initialState ...Value) ([]Value, error)
// If executes one of two branches based on a boolean predicate.
//
// The pred must be a scalar boolean value.
//
// The trueBranch and falseBranch are closures that take no parameters (they can capture
// values from the parent scope) and return N values each. Both branches must return
// the same number of outputs with matching shapes.
//
// Returns the outputs of the executed branch.
If(pred Value, trueBranch, falseBranch Function) ([]Value, error)
}
Function represents a computation function within a Builder.
A Function contains operations (via StandardOps and CollectiveOps), constants, and parameters. Multiple functions can be composed within a Builder, with Main() being the entry point that gets compiled.
Other top-level functions created via Builder.NewFunction() can be used for modular computation, while-loop bodies, conditional branches, reduce operations, etc.
The typical lifecycle is:
- Create parameters via Parameter()
- Build computation using StandardOps/CollectiveOps methods
- Mark outputs via Return()
After all functions of a Builder are finished (and Return() has been called), one compiles the Builder with Builder.Compile().
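The semantics documented for the While op (cond and body closures over an N-value state) can be mirrored by an ordinary host-side loop. This is only a model of the semantics using Go closures and integer state; `whileLoop` is a hypothetical function and does not build any graph.

```go
package main

import "fmt"

// whileLoop mirrors the documented While semantics on the host: cond receives the
// current state and decides whether to continue; body maps the current state to
// the next state with the same number of values. The final state is returned
// once cond reports false.
func whileLoop(cond func(state []int) bool, body func(state []int) []int, initialState ...int) []int {
	state := initialState
	for cond(state) {
		state = body(state)
	}
	return state
}

func main() {
	// State is (i, acc): sum the integers 0..4.
	final := whileLoop(
		func(s []int) bool { return s[0] < 5 },
		func(s []int) []int { return []int{s[0] + 1, s[1] + s[0]} },
		0, 0,
	)
	fmt.Println(final)
}
```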
type FusedOps ¶
type FusedOps interface {
// FusedSoftmax computes softmax along the specified axis.
//
// Note: unlike the generic softmax in GoMLX's graph package, the fused
// softmax only accepts one axis. The axis must be non-negative (the caller
// normalizes negative indices before calling).
FusedSoftmax(x Value, axis int) (Value, error)
// FusedGelu computes Gaussian Error Linear Unit activation.
// If exact is true, the exact GELU (using erf) is computed;
// otherwise the tanh approximation is used.
FusedGelu(x Value, exact bool) (Value, error)
// FusedLayerNorm applies layer normalization over specified axes.
// gamma and beta can be nil if no learned scale/offset.
// epsilon: numerical stability constant (typically 1e-5).
FusedLayerNorm(x Value, axes []int, epsilon float64, gamma, beta Value) (Value, error)
// FusedDense performs fused matmul + optional bias + optional activation.
//
// It computes y = activation(x @ W + bias), where @ is a standard matmul that
// contracts x's last axis with weight's first axis.
//
// - x: [batch..., in_features], weight: [in_features, out_features...],
// - bias: [out_features...] (nil-able).
// - activation: applied after the matmul+bias; set to ActivationNone for no activation.
FusedDense(x, weight, bias Value, activation ActivationType) (Value, error)
// FusedScaledDotProductAttention computes multi-head scaled dot-product attention.
//
// output = softmax(query @ key^T * scale + mask) @ value, computed per-head with GQA support.
//
// Inputs:
// - query, key, value: 4D tensors whose axis ordering is determined by axesLayout.
// For AxesLayoutBHSD: query [batch, numHeads, seqLen, headDim],
// key/value [batch, numKVHeads, kvLen, headDim].
// For AxesLayoutBSHD: query [batch, seqLen, numHeads, headDim],
// key/value [batch, kvLen, numKVHeads, headDim].
// - mask: [seqLen, kvLen] (seqLen is the query sequence length): optional (can be nil) mask
// that can be either boolean or additive (any dtype other than Bool). See also causal below.
// Boolean mask: true = attend, false = ignore.
// Float/additive mask: added to scores before softmax.
// Must be broadcastable to the score tensor shape.
//
// Parameters:
// - numHeads: number of query attention heads
// - numKVHeads: number of key/value attention heads (for GQA; numHeads must be divisible by numKVHeads)
// - axesLayout: determines the axis ordering of query/key/value tensors
// - scale: scaling factor applied to query @ key^T (typically 1/sqrt(headDim))
// - causal: if true, apply causal (lower-triangular) mask. Callers (e.g. attention.Core)
// treat causal and mask as mutually exclusive, folding causal into the mask before calling
// this method when both are needed. Backends may assume they won't both be set.
// - options: optional optimization hints (nil uses defaults). See ScaledDotProductAttentionConfig.
//
// Output: same shape as query.
FusedScaledDotProductAttention(
query, key, value, mask Value,
numHeads, numKVHeads int,
axesLayout AxesLayout,
scale float64,
causal bool,
options *ScaledDotProductAttentionConfig) (Value, error)
// QuantizedEmbeddingLookup performs a quantized embedding lookup (row gather)
// with on-the-fly dequantization.
//
// This is the quantized analogue of Gather for embedding lookups, inspired by
// llama.cpp's ggml_get_rows. For now it is only implemented for the GGML
// quantization scheme, but could be extended for others if/when needed.
//
// Inputs:
// - data: [vocabSize, bytesPerRow] Uint8 with native GGML block layout.
// - indices: integer tensor with last dimension = 1 (same as Gather convention).
// For embeddings: [batch, seqLen, 1].
// - dataQuantization: describes how to dequantize the data rows. Must not be nil.
// Only QuantGGML scheme is currently supported.
//
// Output: float32 tensor with shape [batch..., K] where K = (bytesPerRow / bytesPerBlock) * valuesPerBlock.
// For embeddings with indices [batch, seqLen, 1]: output is [batch, seqLen, K].
QuantizedEmbeddingLookup(data, indices Value,
dataQuantization *Quantization) (Value, error)
// FusedQuantizedDense performs fused dequantization + matmul + optional bias + optional activation.
//
// It computes y = activation(x @ dequant(weights, weightsQuantization) + bias), where the
// dequantization and matmul are fused into a single pass to avoid materializing the
// full float32 weight matrix.
//
// Inputs:
// - x: [batch..., K] float32 input activations.
// - weights: For Linear/NF4: [K, N] with dtype reflecting storage precision (e.g. Int4, Int8).
// For sub-byte types the caller should Bitcast packed uint8 data to the correct dtype
// before calling.
// For GGML: [N, bytesPerRow] Uint8 with native GGML block layout, where N is the
// output-features dimension and bytesPerRow = (K / valuesPerBlock) * bytesPerBlock.
// - bias: [N] float32 (nil-able), added after matmul but before activation.
// - weightsQuantization: describes how to dequantize the weights tensor. Must not be nil.
// - activation: applied after matmul+bias; set to ActivationNone for no activation.
//
// Future: inputQuantization, outputQuantization, and biasQuantization parameters may be
// added to support fully quantized operations where the activations and/or output are
// also quantized.
FusedQuantizedDense(x, weights, bias Value,
weightsQuantization *Quantization,
activation ActivationType) (Value, error)
// FusedAttentionQKVProjection performs fused Query-Key-Value projection: a single large matmul
// merged with a scatter into separate query (Q), key (K), value (V) outputs with optional
// per-projection bias.
//
// The caller is expected to flatten any leading dimensions (e.g. batch and sequence) into a
// single "batch" axis before calling, and reshape the outputs afterwards. For example, with
// BSHD layout the caller reshapes [batch, seqLen, inFeatures] → [batch*seqLen, inFeatures],
// calls this method, then reshapes each output back to [batch, seqLen, ...].
//
// Inputs:
// - x: [batch, inFeatures] (batch may include a merged sequence dimension)
// - wQKV: [inFeatures, queryDim+2*keyValueDim] (Q/K/V weights concatenated along last axis)
// - biasQ: [queryDim] (optional, nil for no bias)
// - biasK: [keyValueDim] (optional, nil for no bias)
// - biasV: [keyValueDim] (optional, nil for no bias)
//
// Parameters:
// - queryDim: output dimension for query projection
// - keyValueDim: output dimension for key and value projections
//
// Outputs: query [batch, queryDim], key [batch, keyValueDim], value [batch, keyValueDim]
FusedAttentionQKVProjection(
x, wQKV, biasQ, biasK, biasV Value,
queryDim, keyValueDim int) (
query, key, value Value, err error)
}
FusedOps defines optional fused operations. Backends may implement these for better performance; the graph layer falls back to decomposed operations when unavailable.
As with standard ops, a backend that doesn't implement a fused op should return ErrNotImplemented (wrapped with a stack).
type GGMLQuantType ¶
type GGMLQuantType int
GGMLQuantType identifies the specific GGML block quantization format. Enum values are aligned with go-highway's gguf.QuantType for future integration.
const (
GGMLQ4_0 GGMLQuantType = iota // 18 bytes/block, 32 values
GGMLQ8_0 // 34 bytes/block, 32 values
GGMLIQ4NL // 18 bytes/block, 32 values (non-linear lookup)
GGMLQ2_K // 84 bytes/block, 256 values
GGMLQ3_K // 110 bytes/block, 256 values
GGMLQ4_K // 144 bytes/block, 256 values
GGMLQ5_K // 176 bytes/block, 256 values
GGMLQ6_K // 210 bytes/block, 256 values
)
func (GGMLQuantType) BytesPerBlock ¶
func (t GGMLQuantType) BytesPerBlock() int
BytesPerBlock returns the byte size of one quantized block.
func (GGMLQuantType) String ¶
func (t GGMLQuantType) String() string
String returns the name of the GGML quantization type.
func (GGMLQuantType) ValuesPerBlock ¶
func (t GGMLQuantType) ValuesPerBlock() int
ValuesPerBlock returns the number of float32 values represented by one block.
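The two block accessors combine into the row-size relations used by QuantizedEmbeddingLookup and FusedQuantizedDense: bytesPerRow = (K / valuesPerBlock) * bytesPerBlock, and K = (bytesPerRow / bytesPerBlock) * valuesPerBlock. The sketch below works through that arithmetic. The blockFormat type and the rowBytes and rowValues helpers are hypothetical; only the block sizes are taken from the constant comments.

```go
package main

import "fmt"

// blockFormat is a hypothetical stand-in for the pair of values returned by
// GGMLQuantType.BytesPerBlock and GGMLQuantType.ValuesPerBlock.
type blockFormat struct {
	bytesPerBlock  int
	valuesPerBlock int
}

// Sizes copied from the comments on the GGMLQuantType constants.
var (
	formatQ40 = blockFormat{bytesPerBlock: 18, valuesPerBlock: 32}
	formatQ80 = blockFormat{bytesPerBlock: 34, valuesPerBlock: 32}
	formatQ4K = blockFormat{bytesPerBlock: 144, valuesPerBlock: 256}
)

// rowBytes returns the bytes needed to store k quantized values in one row.
func rowBytes(f blockFormat, k int) (int, error) {
	if k%f.valuesPerBlock != 0 {
		return 0, fmt.Errorf("k=%d is not a multiple of %d values per block", k, f.valuesPerBlock)
	}
	return k / f.valuesPerBlock * f.bytesPerBlock, nil
}

// rowValues returns how many values a row of bytesPerRow bytes decodes to.
func rowValues(f blockFormat, bytesPerRow int) (int, error) {
	if bytesPerRow%f.bytesPerBlock != 0 {
		return 0, fmt.Errorf("bytesPerRow=%d is not a multiple of %d bytes per block", bytesPerRow, f.bytesPerBlock)
	}
	return bytesPerRow / f.bytesPerBlock * f.valuesPerBlock, nil
}

func main() {
	b, _ := rowBytes(formatQ40, 4096)
	k, _ := rowValues(formatQ40, b)
	fmt.Println(b, k) // prints "2304 4096"
}
```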
type Mesh ¶
type Mesh struct {
Name string
AxesSizes []int
AxesNames []string
// LogicalDeviceAssignment is the logical assignment of devices to the mesh.
// The numbers here correspond to the indices on the "hard" device assignment set with
// Builder.DeviceAssignment() method.
//
// If left empty, the default assignment is incremental devices starting from 0.
LogicalDeviceAssignment []int
}
Mesh represents a mesh of devices, passed to the Builder.DistributedAutoSharding method.
AxesSizes and AxesNames define the mesh topology.
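As an illustration of how the fields relate, the sketch below redeclares the Mesh struct locally and derives the device assignment. It assumes that the mesh spans the product of AxesSizes devices and that AxesSizes and AxesNames have equal length. Neither check is stated by this package, and numDevices and deviceAssignment are hypothetical helpers.

```go
package main

import "fmt"

// Mesh mirrors the fields of the Mesh struct documented above.
type Mesh struct {
	Name                    string
	AxesSizes               []int
	AxesNames               []string
	LogicalDeviceAssignment []int
}

// numDevices returns the product of the axis sizes (assumed mesh size).
func numDevices(m Mesh) int {
	n := 1
	for _, s := range m.AxesSizes {
		n *= s
	}
	return n
}

// deviceAssignment returns the explicit assignment if one is set, otherwise
// the documented default: incremental devices starting from 0.
func deviceAssignment(m Mesh) ([]int, error) {
	if len(m.AxesSizes) != len(m.AxesNames) {
		return nil, fmt.Errorf("mesh %q: %d sizes but %d names", m.Name, len(m.AxesSizes), len(m.AxesNames))
	}
	n := numDevices(m)
	if len(m.LogicalDeviceAssignment) == 0 {
		ids := make([]int, n)
		for i := range ids {
			ids[i] = i
		}
		return ids, nil
	}
	if len(m.LogicalDeviceAssignment) != n {
		return nil, fmt.Errorf("mesh %q: %d devices assigned, want %d", m.Name, len(m.LogicalDeviceAssignment), n)
	}
	return m.LogicalDeviceAssignment, nil
}

func main() {
	m := Mesh{Name: "mesh", AxesSizes: []int{2, 2}, AxesNames: []string{"data", "model"}}
	ids, err := deviceAssignment(m)
	fmt.Println(ids, err) // prints "[0 1 2 3] <nil>"
}
```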
type OpType ¶
type OpType int
OpType is an enum of all generic operations that can be supported by a Backend.Builder.
Notice: nothing precludes a specialized backend Builder from supporting other ops not included here, but using them requires some careful casting of interfaces by the caller.
const (
OpTypeInvalid OpType = iota
OpTypeParameter
OpTypeConstant
OpTypeIdentity
OpTypeReduceWindow
OpTypeRNGBitGenerator
OpTypeBatchNormForInference
OpTypeBatchNormForTraining
OpTypeBatchNormGradient
OpTypeBitCount
OpTypeAbs
OpTypeAdd
OpTypeArgMinMax
OpTypeAtan2
OpTypeBitcast
OpTypeBitwiseAnd
OpTypeBitwiseNot
OpTypeBitwiseOr
OpTypeBitwiseXor
OpTypeBroadcastInDim
OpTypeCall
OpTypeClamp
OpTypeCeil
OpTypeClz
OpTypeComplex
OpTypeConcatenate
OpTypeConj
OpTypeConvGeneral
OpTypeConvertDType
OpTypeCos
OpTypeDiv
OpTypeDot
OpTypeDotGeneral
OpTypeDynamicSlice
OpTypeDynamicUpdateSlice
OpTypeEqual
OpTypeEqualTotalOrder
OpTypeErf
OpTypeExp
OpTypeExpm1
OpTypeFFT
OpTypeFloor
OpTypeGather
OpTypeGreaterOrEqual
OpTypeGreaterOrEqualTotalOrder
OpTypeGreaterThan
OpTypeGreaterThanTotalOrder
OpTypeImag
OpTypeIota
OpTypeIsFinite
OpTypeIsNaN
OpTypeLessOrEqual
OpTypeLessOrEqualTotalOrder
OpTypeLessThan
OpTypeLessThanTotalOrder
OpTypeLog
OpTypeLog1p
OpTypeLogicalAnd
OpTypeLogicalNot
OpTypeLogicalOr
OpTypeLogicalXor
OpTypeLogistic
OpTypeMax
OpTypeMin
OpTypeMul
OpTypeNeg
OpTypeNotEqual
OpTypeNotEqualTotalOrder
OpTypePad
OpTypePow
OpTypeReal
OpTypeReduceBitwiseAnd
OpTypeReduceBitwiseOr
OpTypeReduceBitwiseXor
OpTypeReduceLogicalAnd
OpTypeReduceLogicalOr
OpTypeReduceLogicalXor
OpTypeReduceMax
OpTypeReduceMin
OpTypeReduceProduct
OpTypeReduceSum
OpTypeRem
OpTypeReshape
OpTypeReverse
OpTypeRound
OpTypeRsqrt
OpTypeScatterMax
OpTypeScatterMin
OpTypeScatterSum
OpTypeSelectAndScatterMax
OpTypeSelectAndScatterMin
OpTypeSelectAndScatterSum
OpTypeShiftLeft
OpTypeShiftRightArithmetic
OpTypeShiftRightLogical
OpTypeSign
OpTypeSin
OpTypeSlice
OpTypeSqrt
OpTypeSub
OpTypeTanh
OpTypeTranspose
OpTypeWhere
OpTypeSort
OpTypeWhile
OpTypeIf
// OpTypeCapturedValue represents a value captured from a parent scope in a closure.
// This allows closures to reference values computed in enclosing functions.
OpTypeCapturedValue
OpTypeAllReduce
OpTypeCollectiveBroadcast
OpTypeAllGather
// OpTypeBlockForDotGeneral pre-blocks a tensor for efficient DotGeneral execution.
// This is an internal optimization used by the simplego backend.
OpTypeBlockForDotGeneral
OpTypeFusedSoftmax
OpTypeFusedLayerNorm
OpTypeFusedGelu
OpTypeFusedDense
OpTypeFusedScaledDotProductAttention
OpTypeFusedAttentionQKVProjection
OpTypeFusedQuantizedDense
OpTypeQuantizedEmbeddingLookup
// OpTypeLast should always be kept the last, it is used as a counter/marker for OpType.
OpTypeLast
)
func OpTypeString ¶
OpTypeString retrieves an enum value from the string name of one of the enum constants. It returns an error if the name is not part of the enum.
type PadAxis ¶
type PadAxis struct {
Start, End, Interior int
}
PadAxis defines the amount of padding added before the start of an axis (Start), after its end (End), and in between each pair of consecutive elements (Interior). This is used as a parameter for the Pad operation.
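To make the three fields concrete, here is a small sketch of padding along a single axis. It assumes the XLA-style Pad semantics that this API is modelled on (Interior fill values between each pair of consecutive elements) and non-negative padding amounts. paddedSize and pad1D are hypothetical helpers, not part of this package.

```go
package main

import "fmt"

// PadAxis mirrors the struct documented above.
type PadAxis struct {
	Start, End, Interior int
}

// paddedSize returns the size of an axis of n elements after padding.
func paddedSize(n int, p PadAxis) int {
	interior := 0
	if n > 1 {
		interior = (n - 1) * p.Interior
	}
	return p.Start + n + interior + p.End
}

// pad1D applies the padding to a 1-D slice, using fill for the new positions.
func pad1D(x []int, fill int, p PadAxis) []int {
	out := make([]int, 0, paddedSize(len(x), p))
	for i := 0; i < p.Start; i++ {
		out = append(out, fill)
	}
	for i, v := range x {
		if i > 0 {
			// Interior padding goes between consecutive elements only.
			for j := 0; j < p.Interior; j++ {
				out = append(out, fill)
			}
		}
		out = append(out, v)
	}
	for i := 0; i < p.End; i++ {
		out = append(out, fill)
	}
	return out
}

func main() {
	p := PadAxis{Start: 1, End: 2, Interior: 1}
	fmt.Println(pad1D([]int{1, 2, 3}, 0, p)) // prints "[0 1 0 2 0 3 0 0]"
}
```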
type Quantization ¶
type Quantization struct {
// Scheme: Linear, NF4, or GGML.
Scheme QuantizationScheme
// Scale is the multiplicative factor.
// Shape: [K, NumBlocks] (block-wise), where K is the input-features
// (contracting) dimension of the [K, N] weight matrix and
// NumBlocks = ceil(N / BlockSize).
// Unused for QuantGGML (scales are embedded in the blocks).
Scale Value
// ZeroPoint is the additive offset (only for Linear).
// If nil, the quantization is assumed symmetric.
// Unused for QuantGGML and QuantNF4.
ZeroPoint Value
// BlockAxis is the dimension of the quantized tensor that is blocked.
// This is the output-features dimension (axis 1) of a [K, N] weight matrix.
// Currently only BlockAxis=1 is supported.
// Unused for QuantGGML.
BlockAxis int
// BlockSize is the number of elements in BlockAxis that share one scale.
// If BlockSize == N, it's effectively per-axis quantization.
// Unused for QuantGGML.
BlockSize int
// GGMLType specifies the concrete GGML block format (Q4_0, Q8_0, etc.).
// Only used when Scheme == QuantGGML.
GGMLType GGMLQuantType
}
Quantization describes how a value is quantized, and holds the information to dequantize it.
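For the linear scheme, the fields above are enough to write a reference dequantization. The sketch below assumes a [K, N] Int8 weight matrix blocked along axis 1, a Scale of shape [K, NumBlocks], and the mapping float_value = int_value * scale + zeroPoint given for QuantLinear. The ZeroPoint shape (same as Scale) is an assumption, and numBlocks and dequantLinear are hypothetical helpers that use plain slices rather than backend Values.

```go
package main

import "fmt"

// numBlocks returns ceil(n / blockSize), the number of scale blocks per row.
func numBlocks(n, blockSize int) int {
	return (n + blockSize - 1) / blockSize
}

// dequantLinear converts w ([K, N] int8) to float32, where every blockSize
// consecutive elements along axis 1 share one scale (scale is [K, NumBlocks]).
// zeroPoint may be nil (symmetric quantization); its shape is assumed to
// match scale.
func dequantLinear(w [][]int8, scale, zeroPoint [][]float32, blockSize int) [][]float32 {
	out := make([][]float32, len(w))
	for k, row := range w {
		out[k] = make([]float32, len(row))
		for n, q := range row {
			b := n / blockSize // index of the block that column n belongs to
			v := float32(q) * scale[k][b]
			if zeroPoint != nil {
				v += zeroPoint[k][b]
			}
			out[k][n] = v
		}
	}
	return out
}

func main() {
	w := [][]int8{{1, 2, 3, 4}}
	scale := [][]float32{{0.5, 2}}
	fmt.Println(numBlocks(4, 2), dequantLinear(w, scale, nil, 2)) // prints "2 [[0.5 1 6 8]]"
}
```

A fused implementation would apply the same per-block arithmetic inside the matmul loop instead of materializing the full float32 matrix.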
type QuantizationScheme ¶
type QuantizationScheme int
QuantizationScheme specifies how quantized integer values map to floating-point values.
const (
// QuantLinear is standard affine quantization: float_value = int_value * scale + zeroPoint.
// Used with Int4 weights (symmetric, zeroPoint=nil) or Int8 weights.
QuantLinear QuantizationScheme = iota
// QuantNF4 is 4-bit NormalFloat from QLoRA: nibble indices are looked up in a fixed
// 16-entry table, then multiplied by scale.
QuantNF4
// QuantGGML indicates that the weights are stored in native GGML block format
// (e.g. Q4_0, Q8_0, K-quants). The scales and zero points are embedded in the
// blocks themselves, so Scale/ZeroPoint/BlockAxis/BlockSize in Quantization are
// unused; the GGMLType field specifies the concrete block format.
// Weights must be [N, bytesPerRow] Uint8 with native GGML block layout.
QuantGGML
)
func (QuantizationScheme) String ¶
func (q QuantizationScheme) String() string
String returns the name of the quantization scheme.
type ReduceOpType ¶
type ReduceOpType int
ReduceOpType selects among the basic types of reduction supported.
const (
// ReduceOpUndefined is an undefined value.
ReduceOpUndefined ReduceOpType = iota
// ReduceOpSum reduces by summing all elements being reduced.
ReduceOpSum
// ReduceOpProduct reduces by multiplying all elements being reduced.
ReduceOpProduct
// ReduceOpMax reduces by taking the maximum value.
ReduceOpMax
// ReduceOpMin reduces by taking the minimum value.
ReduceOpMin
)
func ReduceOpTypeString ¶
func ReduceOpTypeString(s string) (ReduceOpType, error)
ReduceOpTypeString retrieves an enum value from the string name of one of the enum constants. It returns an error if s is not part of the enum.
func ReduceOpTypeValues ¶
func ReduceOpTypeValues() []ReduceOpType
ReduceOpTypeValues returns all values of the enum
func (ReduceOpType) IsAReduceOpType ¶
func (i ReduceOpType) IsAReduceOpType() bool
IsAReduceOpType returns "true" if the value is listed in the enum definition. "false" otherwise
func (ReduceOpType) String ¶
func (i ReduceOpType) String() string
type ScaledDotProductAttentionConfig ¶
type ScaledDotProductAttentionConfig struct {
// QuantizedMatmuls: if true, the backend may use dynamic per-head symmetric
// affine quantization (scale-only, no zero point) to convert float32 Q/K/V slices
// to uint8 for the Q@K^T and attn@V matmul stages. Accumulation is done in int32,
// then dequantized back to float32. Softmax and masking remain in float32.
// This matches ONNX DynamicQuantizeLinear semantics and trades some numerical
// precision for throughput on hardware with fast integer dot-product instructions
// (e.g. ARM SDOT/UDOT, x86 VNNI). Backends that do not support quantized matmuls
// ignore this flag and use float arithmetic.
QuantizedMatmuls bool
}
ScaledDotProductAttentionConfig holds optional optimization hints for FusedScaledDotProductAttention. A nil *ScaledDotProductAttentionConfig means "use defaults" (all optimizations disabled).
type ShardingSpec ¶
ShardingSpec holds, for each tensor (or Node) axis, a list of Mesh axis names over which that axis is sharded. Any tensor axis that doesn't have a corresponding entry in the ShardingSpec is considered replicated, as is any tensor axis for which the list of Mesh axes is empty.
The ShardingSpec also holds the Mesh name over which it is defined.
type StandardOps ¶
type StandardOps interface {
// Abs returns the element-wise absolute value of x.
Abs(x Value) (Value, error)
// Add returns the element-wise sum of the two values.
// Standard broadcasting rules apply (see documentation).
Add(lhs, rhs Value) (Value, error)
// Atan2 returns element-wise the arc tangent of y/x, using the signs of both arguments to determine
// the correct quadrant of the result.
// Standard broadcasting rules apply (see documentation).
Atan2(lhs, rhs Value) (Value, error)
// ArgMinMax calculates the "argmin" or "argmax" across an axis of the given input array x.
//
// outputDType defines the output of the argmin/argmax, it doesn't need to be the same as the input.
// It's a form of reduction on the given axis, and that axis goes away.
// So the rank of the result is one less than the rank of x.
//
// If there is a NaN in the slice being examined, it is chosen for ArgMinMax -- this is in line with Jax, TensorFlow, and PyTorch.
//
// Examples:
//
// ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=1, isMin=true) -> {1, 0} // (it chooses the 0 and the -3)
// ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=0, isMin=false) -> {0, 1, 0} // (it chooses the 2, 4, and 7)
ArgMinMax(x Value, axis int, outputDType dtypes.DType, isMin bool) (Value, error)
// BatchNormForInference implements batch normalization for inference.
//
// See details in https://www.tensorflow.org/xla/operation_semantics#batchnorminference.
//
// Based on the paper "Batch Normalization: Accelerating Deep Network Training by Reducing
// Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
BatchNormForInference(operand, scale, offset, mean, variance Value, epsilon float32, featureAxis int) (Value, error)
// BatchNormForTraining implements batch normalization for training.
//
// See details in https://www.tensorflow.org/xla/operation_semantics#batchnormtraining.
//
// It returns the normalized tensor, the batchMean, and the batchVariance.
//
// Based on the paper "Batch Normalization: Accelerating Deep Network Training by Reducing
// Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
BatchNormForTraining(
operand, scale, offset Value,
epsilon float32,
featureAxis int,
) (normalized Value, batchMean Value, batchVariance Value, err error)
// BatchNormGradient calculates the batch normalization gradients with respect to the input, scale, and offset.
//
// See details in https://openxla.org/xla/operation_semantics#batchnormgrad
//
// The gradOutput is the adjoint gradient (the "V" in "VJP"), that is, the gradient with respect to the output of the
// batch normalization.
//
// Based on the paper "Batch Normalization: Accelerating Deep Network Training by Reducing
// Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
BatchNormGradient(
operand, scale, mean, variance, gradOutput Value,
epsilon float32,
featureAxis int,
) (gradOperand Value, gradScale Value, gradOffset Value, err error)
// Bitcast performs an elementwise bitcast operation from a dtype to another dtype.
//
// The Bitcast doesn't "convert", rather it just reinterprets the bits from operand.DType() to the targetDType.
//
// If the element sizes (in bytes/bits) differ, the last dimension is adjusted:
// - Smaller target: a new trailing axis of size (srcBits / dstBits) is appended, so rank is increased by 1.
// - Larger target: the last axis must be equal to (dstBits / srcBits), and the resulting rank is decreased by 1 ("squeezed").
//
// E.g:
//
// Bitcast([1]uint32{0xdeadbeef}, dtypes.UInt16) -> [1][2]uint16{{0xbeef, 0xdead}} // Little-endian encoding.
// Bitcast([1][2]uint16{{0xbeef, 0xdead}}, dtypes.UInt32) -> [1]uint32{0xdeadbeef}
Bitcast(operand Value, targetDType dtypes.DType) (Value, error)
// BitCount returns the number of bits that are set to one.
// Also known as Population Count ("Popcnt") or Hamming Weight.
BitCount(operand Value) (Value, error)
// BitwiseAnd returns the element-wise bitwise AND operation.
BitwiseAnd(lhs, rhs Value) (Value, error)
// BitwiseNot returns the element-wise bitwise NOT operation.
BitwiseNot(x Value) (Value, error)
// BitwiseOr returns the element-wise bitwise OR operation.
BitwiseOr(lhs, rhs Value) (Value, error)
// BitwiseXor returns the element-wise bitwise XOR operator.
BitwiseXor(lhs, rhs Value) (Value, error)
// BroadcastInDim broadcasts x to an output with the given shape.
// broadcastAxes holds one output axis for each axis of x (len(broadcastAxes) == x.Shape.Rank()).
// The i-th axis of x is mapped to the broadcastAxes[i]-th axis of the output.
// broadcastAxes must also be increasing: this operation cannot be used to transpose axes, it will only
// broadcast and introduce new axes in-between.
// This also requires that the dimension of the i-th input axis is either 1 or the same as the
// output dimension it's broadcasting into.
// For example, say operand `x = (s32)[2]{1, 2}`; outputShape = `(s32)[2,2]`:
// - Specifying []int{1} as broadcastAxes will generate output
// {{1, 2},
// {1, 2}}
// - On the other hand, specifying []int{0} as broadcastAxes
// will generate output
// {{1 , 1},
// {2 , 2}}
BroadcastInDim(x Value, outputShape shapes.Shape, broadcastAxes []int) (Value, error)
// Ceil returns the element-wise ceiling of x: the smallest integer value not less than x.
Ceil(x Value) (Value, error)
// Clamp returns the element-wise clamping operation.
//
// The values max and min can either be a scalar or have the same shape as x.
Clamp(min, x, max Value) (Value, error)
// Clz returns element-wise the "count leading zeros" bits of input node x -- for integer values.
Clz(x Value) (Value, error)
// Complex returns the complex number taking lhs as the real part and rhs as the imaginary part.
// The real (lhs) and imaginary (rhs) parts must have the same dtype, and they must be either `dtypes.Float32` or
// `dtypes.Float64`.
// The output will be either `dtypes.Complex64` or `dtypes.Complex128`, depending on the dtype of lhs and rhs.
// The shapes of lhs and rhs must be the same, or one must be a scalar, in which case
// its value is broadcast to every element of the other.
Complex(lhs, rhs Value) (Value, error)
// Concatenate operands on the given axis.
//
// All operands must have matching dimensions on every axis except the one being concatenated.
// It doesn't work with scalars -- use ExpandAxes.
// If there is only one operand, it is returned and this is a no-op.
Concatenate(axis int, operands ...Value) (Value, error)
// Conj returns the conjugate of a complex number. E.g: Conj(1+3i) = 1-3i
Conj(x Value) (Value, error)
// ConvGeneral is a generic Convolution operation with arbitrary number of spatial axes, strides,
// paddings, dilations, and grouping.
//
// Arguments:
//
// - input: it must have one batch and one channel axis, and arbitrary number of spatial axes.
// - kernel: its rank must match the input's spatial axes.
// - axes: defines how the axes of input and kernel are mapped.
// - strides: stride of the convolution window, how it moves. If set, one value per spatial axis,
// and values must be >= 1. If not set, strides default to 1.
// - paddings: padding applied to the start and end of each axis of the input.
// If nil, it defaults to no padding.
// - inputDilations: "virtually" expand the input by inserting `dilation-1` copies of `0` (or whatever
// is the reduction "zero" value) between the elements in each dimension.
// If nil, it's assumed to be 1 (no dilation) for each axis. Values must be >= 1.
// - kernelDilations: "virtually" expand the kernel by inserting `dilation-1` copies of `0` between the
// elements in each dimension.
// If nil, it's assumed to be 1 (no dilation) for each axis. Values must be >= 1.
// Also known as "atrous convolution".
// - channelGroupCount: number of input channels to group together for the convolution.
// (aka "grouped convolution"). If <= 1 it's disabled.
// - batchGroupCount: number of input batches to group together for the convolution.
// If <= 1 it's disabled.
//
// There is a more detailed description in https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution.
// Also useful, https://arxiv.org/pdf/1603.07285v1.pdf.
// Note:
// - Another common term for "channels" is "features".
// - "Kernel" is also commonly called "weights" or "filters".
ConvGeneral(
input, kernel Value,
axes ConvolveAxesConfig,
strides []int,
paddings [][2]int,
inputDilations, kernelDilations []int,
channelGroupCount, batchGroupCount int,
) (Value, error)
// ConvertDType of x to dtype.
ConvertDType(x Value, dtype dtypes.DType) (Value, error)
// Cos returns the element-wise cosine of x.
Cos(x Value) (Value, error)
// Div returns the element-wise division of the two values.
// Standard broadcasting rules apply (see documentation).
Div(lhs, rhs Value) (Value, error)
// DotGeneral takes as input lhs (left-hand-side) and rhs (right-hand-side) specifications
// for a general vector product -- a generalized "Einsum". Each axis can be:
//
// - Just aligned (batch axes), so the output has the same axes as the inputs. The dimensions
// must match in lhs and rhs.
// - Crossed (default), in which case the output is the combination (concatenation) of the
// dimensions.
// - Contracted (contracting axes), where the values are multiplied and then sum-reduced over
// those dimensions.
//
// The resulting shape is [batchIndices..., <lhs cross indices...>, <rhs cross indices...>], the
// indices come in the order they were provided. The output dtype is by default the same as
// the input, except if configured otherwise in config.OutputDType.
//
// It provides the basic means of implementing Einsum.
DotGeneral(
lhs Value,
lhsContractingAxes, lhsBatchAxes []int,
rhs Value,
rhsContractingAxes, rhsBatchAxes []int,
config DotGeneralConfig,
) (Value, error)
// DynamicSlice extracts a slice from the operand at the startIndices position and the given sliceSizes.
//
// - operand: tensor from where to take the slice.
// - startIndices: scalar tensors, one per axis of operand: len(startIndices) == operand.Rank().
// - sliceDims: the size of the slice on each axis. These are static values, fixed to keep the shape of the output static.
//
// The startIndices are adjusted as follows:
//
// adjustedStartIndices[i] = clamp(0, startIndices[i], operand.Dimensions[i] - sliceDims[i])
//
// See description in https://openxla.org/xla/operation_semantics#dynamicslice
DynamicSlice(operand Value, startIndices []Value, sliceDims []int) (Value, error)
// DynamicUpdateSlice updates the operand with the values given in update, at the position given by startIndices.
//
// - operand: original value to be updated.
// - update: values to "paste" on top of operand, at position startIndices.
// - startIndices: scalar tensors, one per axis of operand: len(startIndices) == operand.Rank().
//
// The size of the updated region is given by the (static) shape of update.
//
// It returns a value with the same shape as the operand, with the values updated.
//
// The startIndices are adjusted as follows:
//
// adjustedStartIndices[i] = clamp(0, startIndices[i], operand.Dimensions[i] - update.Dimensions[i])
DynamicUpdateSlice(operand, update Value, startIndices []Value) (Value, error)
// Equal performs element-wise equality check, returns boolean results with the same dimensions as input.
Equal(lhs, rhs Value) (Value, error)
// EqualTotalOrder returns the element-wise equality comparison, using total ordering.
// Standard broadcasting rules apply (see documentation).
// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
EqualTotalOrder(lhs, rhs Value) (Value, error)
// Erf returns the "error function", defined as erf(x) = 2/sqrt(Pi) * \int_{0}^{x}{e^{-t^2}dt}.
Erf(x Value) (Value, error)
// Exp returns the element-wise exponential e^x.
Exp(x Value) (Value, error)
// Expm1 returns the element-wise expression exp(x)-1.
Expm1(x Value) (Value, error)
// FFT calls the XLA FFT operation, which implements {Forward, Inverse} x {Complex, Real} versions.
// See documentation in https://www.tensorflow.org/xla/operation_semantics.
// Under the hood, XLA's CPU FFT is backed by Eigen's TensorFFT, and its GPU FFT uses cuFFT.
FFT(operand Value, fftType FFTType, fftLength []int) (Value, error)
// Floor returns the element-wise floor of x: the largest integer value not greater than x.
Floor(x Value) (Value, error)
// Gather is a powerful but cumbersome Gather operation offered by XLA.
// Full details in https://www.tensorflow.org/xla/operation_semantics#gather or
// in https://openxla.org/stablehlo/spec#gather (StableHLO also adds batch axes).
//
// The output of Gather has the same DType of the operand, from where we are pulling the data.
//
// Its output shape will be composed of 2 parts:
//
// - Batch axes: they come from the axes of startIndices, except the "indexVectorAxis" (usually the last)
// that is used as the indices into the operand. (*)
// - "Offset axes": these are axes that come from the operand, the sizes given by sliceSizes.
// Notice that if sliceSizes for an axis is 1, and that axis is present in the collapsedSliceAxes list, this
// axis gets omitted in the output.
//
// So in general output.Rank() = startIndices.Rank() - 1 + len(offsetOutputAxes).
//
// (*) One exception is if indexVectorAxis == startIndices.Rank(), in which case we assume there is an
// extra implicit axis in startIndices of size 1, in which case output.Rank() = startIndices.Rank() + len(offsetOutputAxes).
//
// Arguments:
// - operand: the values from where we are gathering. The output DType will follow the operand one.
// - startIndices: the indices we want to gather. The axis pointed to by indexVectorAxis
// lists the indices of the slice to be gathered in the operand array (their values are mapped to the axes
// of the operand according to startIndexMap).
// All other axes are "batch dimensions" and they will have equivalent axes (same dimensions) in the output.
// - indexVectorAxis: which of the axes in startIndices is collected and used as the start index for slices
// to be gathered in the operand.
// It is typically the last axis of startIndices, so startIndices.Shape.Rank()-1.
// There is a special case where indexVectorAxis == startIndices.Rank() in which case we assume there is an
// extra virtual axis in startIndices of size 1, in which case output.Rank() = startIndices.Rank() + len(offsetOutputAxes).
// - offsetOutputAxes: _output_ axes (not the operand's) that will hold the "offset slices", slices that are not
// collapsed. It points in which position (axis) in the output these slices should show up.
// The len(offsetOutputAxes) must match the dimension of indexVectorAxis (== startIndices.Dimensions[indexVectorAxis]).
// Notice all axes in the operand will either become an "offset axis" in the output,
// or optionally be collapsed (or "squeezed") in the output, if included in collapsedSliceAxes.
// The axes in the output (given in offsetOutputAxes) are mapped sequentially to the axes in the operand (the axes not present in collapsedSliceAxes).
// One must have Rank(operand) == len(collapsedSliceAxes) + len(offsetOutputAxes).
// - collapsedSliceAxes: _operand_ axes (for which sliceSizes are 1) not to be included in the output.
// One must have sliceSizes[collapsedSliceAxes[i]] == 1 for all i.
// Also, one must have Rank(operand) == len(collapsedSliceAxes) + len(offsetOutputAxes).
// - startIndexMap: this maps which value in startIndices is used for which axis in the operand, to select the slice to be gathered.
// Notice len(startIndexMap) must match the startIndices.Dimensions[indexVectorAxis].
// Also, len(startIndexMap) == len(offsetOutputAxes) -- offsetOutputAxes maps the same axes in the output.
// E.g.: if startIndices.shape=(2, 3), indexVectorAxis=1, and operand.rank=4 and startIndexMap=[]int{0, 1, 2},
// this means each row of the startIndices will point to the first 3 axes (0,1 and 2) in the operand.
// In many cases this is [0, 1, 2, ..., operand.Shape.Rank()-1], that is, each "index vector" fully defines
// an element on the operand. In some this is only a prefix of the operand's rank.
// For those axes in the operand not explicitly set (so if len(startIndexMap) < operand.Rank()), the corresponding
// axis start index is considered to be 0, and one sets the sliceSizes to take the slice one wants (typically the
// full slice).
// - sliceSizes: a size for each of the operand's axes, so len(sliceSizes) == operand.Rank().
// Once the start index from where to gather is resolved, this defines how much data in each axis
// to gather.
// Constraints: sliceSizes[collapsedSliceAxes[i]] == 1, for all i.
// - indicesAreSorted: can be set to true if the user guarantees that startIndices are sorted (in ascending order,
// after scattering its values according to startIndexMap). This allows for some optimizations
// on some platforms.
//
// Out-of-bound (and negative) indices <i> are adjusted with max(min(<i>, axisDimension-1), 0), meaning they
// are taken from the border of the axes.
// TODO: Add batch support: operandBatchingAxes and startIndicesBatchingAxes.
Gather(
operand, startIndices Value,
indexVectorAxis int,
offsetOutputAxes, collapsedSliceAxes, startIndexMap, sliceSizes []int,
indicesAreSorted bool,
) (Value, error)
// GreaterOrEqual performs element-wise comparison, returns boolean results with the same dimensions as input.
GreaterOrEqual(lhs, rhs Value) (Value, error)
// GreaterOrEqualTotalOrder returns the element-wise greater-or-equal comparison, using total ordering.
// Standard broadcasting rules apply (see documentation).
// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
GreaterOrEqualTotalOrder(lhs, rhs Value) (Value, error)
// GreaterThan performs element-wise comparison, returns boolean results with the same dimensions as input.
GreaterThan(lhs, rhs Value) (Value, error)
// GreaterThanTotalOrder returns the element-wise greater-than comparison, using total ordering.
// Standard broadcasting rules apply (see documentation).
// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
GreaterThanTotalOrder(lhs, rhs Value) (Value, error)
// Identity returns an Op whose output is the same as its input.
// It's a no-op that can serve as a place-holder.
Identity(x Value) (Value, error)
// Imag returns the imaginary part of a complex number. It returns 0 if x is a float number.
Imag(x Value) (Value, error)
// Iota creates a constant of the given shape with increasing numbers (starting from 0)
// on the given axis. So Iota([2,2], 1) returns [[0 1][0 1]], while Iota([2,2], 0)
// returns [[0 0][1 1]].
Iota(shape shapes.Shape, iotaAxis int) (Value, error)
// IsFinite tests whether each element of operand is finite, i.e., it is neither positive nor negative infinity, and it is not NaN.
// It returns the same shape as the input, but with boolean values where each element is true if and only if
// the corresponding input element is finite.
IsFinite(x Value) (Value, error)
// IsNaN tests whether each element of operand is NaN ("not a number"). Infinities are not NaN.
IsNaN(x Value) (Value, error)
// LessOrEqual performs element-wise comparison, returns boolean results with the same dimensions as input.
LessOrEqual(lhs, rhs Value) (Value, error)
// LessOrEqualTotalOrder returns the element-wise less-or-equal comparison, using total ordering.
// Standard broadcasting rules apply (see documentation).
// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
LessOrEqualTotalOrder(lhs, rhs Value) (Value, error)
// LessThan performs element-wise comparison, returns boolean results with the same dimensions as input.
LessThan(lhs, rhs Value) (Value, error)
// LessThanTotalOrder returns the element-wise less-than comparison, using total ordering.
// Standard broadcasting rules apply (see documentation).
// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
LessThanTotalOrder(lhs, rhs Value) (Value, error)
// Log returns the element-wise natural logarithm of x.
Log(x Value) (Value, error)
// Log1p returns the expression log(x+1).
Log1p(x Value) (Value, error)
// LogicalAnd returns the element-wise logical AND operation.
LogicalAnd(lhs, rhs Value) (Value, error)
// LogicalNot returns the element-wise logical NOT operation.
LogicalNot(x Value) (Value, error)
// LogicalOr returns the element-wise logical OR operation.
LogicalOr(lhs, rhs Value) (Value, error)
// LogicalXor returns the element-wise logical XOR operator.
LogicalXor(lhs, rhs Value) (Value, error)
// Logistic returns the element-wise expression 1/(1+exp(-x)). Also known as the Sigmoid function.
Logistic(x Value) (Value, error)
// Max returns the element-wise highest value among the two.
Max(lhs, rhs Value) (Value, error)
// Min returns the element-wise smallest value among the two.
Min(lhs, rhs Value) (Value, error)
// Mul returns the element-wise multiplication of the two values.
// Standard broadcasting rules apply (see documentation).
Mul(lhs, rhs Value) (Value, error)
// Neg returns the element-wise negation of x, that is, -x.
Neg(x Value) (Value, error)
// NotEqual performs element-wise inequality check, returns boolean results with the same dimensions as input.
NotEqual(lhs, rhs Value) (Value, error)
// NotEqualTotalOrder performs element-wise inequality check under a total order of floating-point values.
// Standard broadcasting rules apply (see documentation).
// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
NotEqualTotalOrder(lhs, rhs Value) (Value, error)
// Pad injects padding on the start, end, or interior (in between each element) of the given operand.
// There must be at most `x.Rank()` axesConfig values. Missing PadAxis values are assumed to be zero,
// that is, no padding for those axes.
Pad(x, fillValue Value, axesConfig ...PadAxis) (Value, error)
// Pow returns lhs raised to the power of rhs, element-wise.
Pow(lhs, rhs Value) (Value, error)
// Real returns the real part of a complex number. It returns x if x is a floating-point number.
Real(x Value) (Value, error)
// ReduceBitwiseAnd reduces x over the axes selected, performing a BitwiseAnd on the slices reduced.
//
// The returned result rank is decreased by len(axes).
//
// If no axes are given, it reduces the full array.
ReduceBitwiseAnd(x Value, axes ...int) (Value, error)
// ReduceBitwiseOr reduces x over the axes selected, performing a BitwiseOr on the slices reduced.
//
// The returned result rank is decreased by len(axes).
//
// If no axes are given, it reduces the full array.
ReduceBitwiseOr(x Value, axes ...int) (Value, error)
// ReduceBitwiseXor reduces x over the axes selected, performing a BitwiseXor on the slices reduced.
//
// The returned result rank is decreased by len(axes).
//
// If no axes are given, it reduces the full array.
ReduceBitwiseXor(x Value, axes ...int) (Value, error)
// ReduceLogicalAnd reduces x over the axes selected, performing a LogicalAnd on the slices reduced.
//
// The returned result rank is decreased by len(axes).
//
// If no axes are given, it reduces the full array.
ReduceLogicalAnd(x Value, axes ...int) (Value, error)
// ReduceLogicalOr reduces x over the axes selected, performing a LogicalOr on the slices reduced.
//
// The returned result rank is decreased by len(axes).
//
// If no axes are given, it reduces the full array.
ReduceLogicalOr(x Value, axes ...int) (Value, error)
// ReduceLogicalXor reduces x over the axes selected, performing a LogicalXor on the slices reduced.
//
// The returned result rank is decreased by len(axes).
//
// If no axes are given, it reduces the full array.
ReduceLogicalXor(x Value, axes ...int) (Value, error)
// ReduceMax reduces x over the axes selected, taking the Max value of the slices reduced.
//
// The returned result rank is decreased by len(axes).
//
// If no axes are given, it reduces the full array.
ReduceMax(x Value, axes ...int) (Value, error)
// ReduceMin reduces x over the axes selected, taking the Min value of the slices reduced.
//
// The returned result rank is decreased by len(axes).
//
// If no axes are given, it reduces the full array.
ReduceMin(x Value, axes ...int) (Value, error)
// ReduceProduct reduces x over the axes selected, taking the product of the slices reduced.
//
// The returned result rank is decreased by len(axes).
//
// If no axes are given, it reduces the full array.
ReduceProduct(x Value, axes ...int) (Value, error)
// ReduceSum reduces x over the axes selected, taking the sum of the slices reduced.
//
// The returned result rank is decreased by len(axes).
//
// If no axes are given, it reduces the full array.
ReduceSum(x Value, axes ...int) (Value, error)
// ReduceWindow runs a reduction of the type given by reductionType (e.g., max, sum, or product)
// over windows of the input.
//
// - reductionType: the type of reduction to perform. E.g.: [ReduceOpMax], [ReduceOpSum],...
// - windowDimensions: the dimensions of the window, must be defined for each axis.
// - strides: stride over elements in each axis for each window reduction. If nil, it's assumed to be the
// same as the windowDimensions -- that is, the strides jump a window at a time.
// - inputDilations: "virtually" expand the input by introducing "holes" between elements. I.e. if
// inputDilations are 2, then the input is expanded by inserting `2-1` copies of `0` (or whatever
// is the reduction "zero" value) between the elements in each dimension.
// If nil, it's assumed to be 1 (no dilation) for each axis. Values must be >= 1.
// - windowDilations: "virtually" expand the window by introducing "holes" between its elements in each
// dimension. I.e. if windowDilations are 2, `2-1` holes are inserted between neighboring window elements.
// If nil, it's assumed to be 1 (no dilation) for each axis. Values must be >= 1.
// - paddings: virtual padding to be added to the input at the edges (start and end) of each axis.
// If nil, it's assumed to be 0 for each axis.
ReduceWindow(
input Value,
reductionType ReduceOpType,
windowDimensions, strides, inputDilations, windowDilations []int,
paddings [][2]int,
) (Value, error)
// Rem returns the remainder operation, also known as modulo (or Mod for short).
// Notice that, despite the name, XLA implements Mod, not the IEEE 754 remainder operation.
Rem(lhs, rhs Value) (Value, error)
// Reshape reshapes x to the new dimensions.
// Total size cannot change, it's just a "reinterpretation" of the same flat data.
// The dtype remains the same, see ConvertDType to actually convert the values.
Reshape(x Value, dimensions ...int) (Value, error)
// Reverse returns x with the values along the given axes reversed, that is,
// the value at index `i` will be swapped with the value at index `(dimension_size - 1 - i)`.
// The shape remains the same.
Reverse(x Value, axes ...int) (Value, error)
// RNGBitGenerator generates the given shape filled with random bits.
//
// It takes as input a state (usually [3]uint64) and returns the updated state and the generated values (with random bits).
//
// Currently, the backend only supports the Philox algorithm. See https://dl.acm.org/doi/10.1145/2063384.2063405
RNGBitGenerator(state Value, shape shapes.Shape) (newState Value, values Value, err error)
// Round rounds each element of x to the nearest integer.
// Ties are rounded to the nearest even value.
Round(x Value) (Value, error)
// Rsqrt returns the element-wise reciprocal of square root operation 1/sqrt(x).
Rsqrt(x Value) (Value, error)
// ScatterMax scatters values from updates into operand at the positions pointed to by scatterIndices, by taking the Max.
ScatterMax(
operand, scatterIndices, updates Value,
indexVectorAxis int,
updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int,
indicesAreSorted, uniqueIndices bool,
) (Value, error)
// ScatterMin scatters values from updates into operand at the positions pointed to by scatterIndices, by taking the Min.
ScatterMin(
operand, scatterIndices, updates Value,
indexVectorAxis int,
updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int,
indicesAreSorted, uniqueIndices bool,
) (Value, error)
// ScatterSum scatters values from updates into operand at the positions pointed to by scatterIndices, by taking the Sum.
ScatterSum(
operand, scatterIndices, updates Value,
indexVectorAxis int,
updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int,
indicesAreSorted, uniqueIndices bool,
) (Value, error)
// SelectAndScatterMax runs windows (similar to ReduceWindow) over the operand and selects values to update the output (like ScatterSum).
// It selects the values in the window such that it works as the reverse of a PoolMax operation.
// See details in https://openxla.org/xla/operation_semantics#selectandscatter
SelectAndScatterMax(operand, source Value, windowDimensions, windowStrides []int, paddings [][2]int) (Value, error)
// SelectAndScatterMin runs windows (similar to ReduceWindow) over the operand and selects values to update the output (like ScatterSum).
// It selects the values in the window such that it works as the reverse of a PoolMin operation.
// See details in https://openxla.org/xla/operation_semantics#selectandscatter
SelectAndScatterMin(operand, source Value, windowDimensions, windowStrides []int, paddings [][2]int) (Value, error)
// ShiftLeft shifts lhs left by rhs bits. It implicitly preserves the sign bit if there is no overflow. So ShiftLeft(-1, 1) = -2.
ShiftLeft(lhs, rhs Value) (Value, error)
// ShiftRightArithmetic shifts lhs right by rhs bits, preserving the sign bit. So ShiftRightArithmetic(-2, 1) = -1.
ShiftRightArithmetic(lhs, rhs Value) (Value, error)
// ShiftRightLogical shifts lhs right by rhs bits, destroying the sign bit.
ShiftRightLogical(lhs, rhs Value) (Value, error)
// Sign returns element-wise +1, +/-0 or -1 depending on the sign of x. It returns NaN if the input is NaN.
Sign(x Value) (Value, error)
// Sin returns the element-wise sine of x.
Sin(x Value) (Value, error)
// Slice extracts a subarray from the input array.
//
// The subarray is of the same rank as the input and contains the values inside a bounding box within the input array
// where the dimensions and indices of the bounding box are given as arguments to the slice operation.
//
// The strides set the input stride of the slice in each axis and must be >= 1.
// They are optional, and if missing, they are assumed to be 1 for every axis.
//
// The limits are defined on the x axes, and they are exclusive upper bounds, i.e. the slice includes
// elements from starts up to (but not including) limits.
//
// Examples:
// Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={4}, strides=nil) -> {2, 3}
// Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={5}, strides={2}) -> {2, 4}
Slice(x Value, starts, limits, strides []int) (Value, error)
// Sqrt returns the element-wise square root of x.
Sqrt(x Value) (Value, error)
// Sub returns the element-wise subtraction of the two values.
// Standard broadcasting rules apply (see documentation).
Sub(lhs, rhs Value) (Value, error)
// Tanh returns the element-wise hyperbolic tangent of x.
Tanh(x Value) (Value, error)
// Transpose permutes the axes of x.
// There should be one value in permutation for each axis in x.
// The output will have: output.Shape.Dimension[i] = x.Shape.Dimension[permutation[i]].
Transpose(x Value, permutation ...int) (Value, error)
// Where takes element-wise values from onTrue or onFalse depending on the value of the condition.
//
// The condition must be boolean, and onTrue and onFalse must have the same dtype.
//
// If either condition, onTrue or onFalse is a scalar, it will be broadcast to the shape of the other operands.
Where(condition, onTrue, onFalse Value) (Value, error)
}
StandardOps lists the bulk of the operations that a backends.Builder must support.
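The "TotalOrder" comparisons above (LessThanTotalOrder, LessOrEqualTotalOrder, NotEqualTotalOrder) differ from ordinary IEEE 754 comparisons in how they treat NaN and signed zeros. The sketch below illustrates that ordering for float32 in plain Go. It does not use the compute API, and it is not necessarily how any backend implements these operations: the helper names (totalOrderKey and the three comparison functions) are hypothetical and introduced only for illustration.

```go
package main

import (
	"fmt"
	"math"
)

// Sample values for the demonstration (names are illustrative only).
var (
	negNaN  = math.Float32frombits(0xFFC00000) // NaN with the sign bit set
	posNaN  = math.Float32frombits(0x7FC00000) // NaN with the sign bit clear
	negZero = math.Float32frombits(0x80000000) // -0
	negInf  = float32(math.Inf(-1))
	posInf  = float32(math.Inf(1))
)

// totalOrderKey (hypothetical helper) maps a float32 to a uint32 whose
// unsigned ordering matches
// -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN.
func totalOrderKey(x float32) uint32 {
	bits := math.Float32bits(x)
	if bits&0x80000000 != 0 {
		// Sign bit set: invert all bits, so larger magnitudes get smaller keys.
		return ^bits
	}
	// Sign bit clear: set the top bit, so these keys sort above all negatives.
	return bits | 0x80000000
}

// Scalar stand-ins for the element-wise "TotalOrder" comparisons.
func lessThanTotalOrder(a, b float32) bool    { return totalOrderKey(a) < totalOrderKey(b) }
func lessOrEqualTotalOrder(a, b float32) bool { return totalOrderKey(a) <= totalOrderKey(b) }
func notEqualTotalOrder(a, b float32) bool    { return totalOrderKey(a) != totalOrderKey(b) }

func main() {
	// Ordinary IEEE 754 comparison: NaN is unordered, and -0 equals +0.
	fmt.Println(negNaN <= posNaN, negZero < 0) // false false

	// Total order: NaNs are ordered by sign, and -0 sorts before +0.
	fmt.Println(lessOrEqualTotalOrder(negNaN, posNaN), lessThanTotalOrder(negZero, 0)) // true true
	fmt.Println(lessThanTotalOrder(negNaN, negInf), lessThanTotalOrder(posInf, posNaN)) // true true
}
```

Under this scheme a NaN compares equal to itself (same bit pattern), which is what makes the order usable for sorting. The same bit trick would extend to float64 with 64-bit keys.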
Directories ¶
| Path | Synopsis |
|---|---|
| distributed | Package distributed defines the following objects related to cross-device execution: |
| dtypes | Package dtypes includes the DType enum for all supported data types for GoMLX. |
| bfloat16 | Package bfloat16 is a trivial implementation for the bfloat16 type, based on https://github.com/x448/float16 and the pending issue in https://github.com/x448/float16/issues/22 |
| float16 | Package float16 implements the IEEE 754 half-precision floating-point format (binary16). |
| gobackend | Package gobackend implements a native Go compute.Backend: very portable (including WASM) but not very fast. |
| internal | |
| cmd/alternates_generator (command) | |
| cmd/gobackend_dtypemap (command) | |
| cmd/gobackend_ops_generator (command) | gobackend_generator auto-generates parts of the Go backend: |
| cmd/gobackend_opsregistration (command) | gobackends_ops_registration generates a registration system for each of the compute.StandardOps methods, so they can be implemented separately by separate packages, as well as stub methods for `gobackends.Backend` that call the corresponding registered method if configured. |
| cmd/notimplemented_generator (command) | notimplemented_generator generates "notimplemented" stubs for every API of the compute.Backend interface |
| exceptions | Package exceptions provides helper functions to leverage Go's `panic`, `recover` and `defer` as an "exceptions" system. |
| gobackend | Package gobackend implements a compute.Backend using Go only. |
| gobackend/defaultpkgs | Package defaultpkgs imports all the sub-packages that implement the gobackend. |
| gobackend/dot | Package dot implements a general-purpose "dot product" ("Einsum") computation. |
| gobackend/dot/matmul | Package matmul provides the base implementations of matrix multiply for DotGeneral. |
| gobackend/ops | Package ops has the base implementation for most operations for the Go backend. |
| must | Package must provides a set of functions that check for errors and panic on error. |
| notimplemented | Package notimplemented implements a compute.Builder interface that returns a "not implemented" error for all operations. |
| shapeinference | Package shapeinference calculates the shape resulting from operations and validates its inputs. |
| shapes | Package shapes defines Shape and DType and associated tools. |
| support | Package support and its subpackages contain various supporting functionality for gomlx/compute that may also be useful for other users or GoMLX projects, as well as for other compute.Backend implementations. |
| backendparser | Package backendparser parses the compute interfaces (Backend, Builder, Function, etc.) and enumerates their methods. |
| envutil | Package envutil provides utility functions for working with environment variables. |
| humanize | Package humanize provides human-readable representations of numbers, bytes, durations, speed (units/sec), etc. |
| sets | Package sets implements a set type as a `map[T]struct{}` but with better ergonomics. |
| testutil | Package testutil provides utilities for testing a compute.Backend. |
| xslices | Package xslices provides missing functionality to the slices package. |
| xsync | Package xsync implements some extra synchronization tools. |