Documentation ¶
Overview ¶
Package backends defines the interfaces that a computation building and execution system needs to implement to be used by GoMLX.
It is based on 3 interfaces:
- DataInterface: handles how data (tensors) is stored in buffers by the backend. This is handled differently by different backends, and even by different accelerators using the same backend.
- Builder: how computation graphs are built.
- Executable: how executable computations are executed.
It is based on OpenXLA's API for now.
A backend that doesn't implement every operation can simply return an "<op> not implemented" error for any op, and it would still work for computations that don't require those operations. The backend/notimplemented package helps bootstrap any new backend implementation by providing a "Not Implemented" implementation for all methods of the Builder interface.
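For example, a new backend can embed the notimplemented Builder and override only the ops it supports. A minimal sketch, assuming notimplemented exposes an embeddable Builder type (the package mybackend, the type myBuilder and its methods are hypothetical):

package mybackend

import (
	"github.com/gomlx/gomlx/backends"
	"github.com/gomlx/gomlx/backends/notimplemented"
)

// myBuilder only implements the ops it overrides; every other Builder method
// is inherited from notimplemented.Builder and returns a "not implemented" error.
type myBuilder struct {
	notimplemented.Builder
	name string
}

func (b *myBuilder) Name() string { return b.name }

// Add is one of the few ops this sketch actually supports.
func (b *myBuilder) Add(x0, x1 backends.Op) (backends.Op, error) {
	// ... backend-specific graph building goes here ...
	return nil, nil
}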
Index ¶
- Constants
- Variables
- func FFTTypeStrings() []string
- func List() []string
- func OpTypeStrings() []string
- func ReduceOpTypeStrings() []string
- func Register(name string, constructor Constructor)
- type Backend
- type Buffer
- type Builder
- type Capabilities
- type Constructor
- type ConvolveAxesConfig
- type DataInterface
- type DeviceNum
- type Executable
- type FFTType
- type Op
- type OpType
- type PadAxis
- type ReduceOpType
- type StandardOps
Constants ¶
const ConfigEnvVar = "GOMLX_BACKEND"
ConfigEnvVar is the name of the environment variable with the default backend configuration to use: "GOMLX_BACKEND".
The format of the configuration is "<backend_name>:<backend_configuration>". The "<backend_name>" is the name of a registered backend (e.g.: "xla") and "<backend_configuration>" is backend-specific (e.g.: for xla backend, it is the pjrt plugin name).
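For instance, a program can select the backend for the whole process before creating it. A sketch, where "xla:cpu" is an illustrative configuration value (the "xla" backend with a "cpu" PJRT plugin); it assumes the usual os import:

if err := os.Setenv(backends.ConfigEnvVar, "xla:cpu"); err != nil {
	panic(err)
}
backend, err := backends.New() // Picks up $GOMLX_BACKEND set above.
if err != nil {
	panic(err)
}
defer backend.Finalize()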
const GOMLX_BACKEND = ConfigEnvVar
GOMLX_BACKEND is an alias to ConfigEnvVar. Deprecated: use ConfigEnvVar instead; this alias will be removed in future versions.
Variables ¶
var DefaultConfig = "xla"
DefaultConfig is the default backend configuration to use if none is specified otherwise (e.g., through the ConfigEnvVar environment variable).
See NewWithConfig for the format of the configuration string.
Functions ¶
func FFTTypeStrings ¶ added in v0.17.1
func FFTTypeStrings() []string
FFTTypeStrings returns a slice of all String values of the enum
func OpTypeStrings ¶ added in v0.19.0
func OpTypeStrings() []string
OpTypeStrings returns a slice of all String values of the enum
func ReduceOpTypeStrings ¶ added in v0.17.1
func ReduceOpTypeStrings() []string
ReduceOpTypeStrings returns a slice of all String values of the enum
func Register ¶
func Register(name string, constructor Constructor)
Register registers a backend with the given name and a default constructor; the constructor takes a configuration string that is passed along to the backend being created.
To be safe, call Register during package initialization.
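For example, a backend package can register itself in an init function. A sketch, assuming the Constructor signature shown under type Constructor below (mybackend and newBackend are hypothetical):

package mybackend

import "github.com/gomlx/gomlx/backends"

func init() {
	// Registering in init makes the backend available to backends.New
	// for any program that imports this package (even with a blank import).
	backends.Register("mybackend", func(config string) (backends.Backend, error) {
		return newBackend(config) // newBackend is this package's hypothetical constructor.
	})
}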
Types ¶
type Backend ¶
type Backend interface {
// Name returns the short name of the backend. E.g.: "xla" for the Xla/PJRT plugin.
Name() string
// String returns the same as Name.
String() string
// Description is a longer description of the Backend that can be used to pretty-print.
Description() string
// NumDevices returns the number of devices available for this Backend.
NumDevices() DeviceNum
// Capabilities returns information about what is supported by this backend.
Capabilities() Capabilities
// Builder creates a new builder used to define a new named computation.
Builder(name string) Builder
// DataInterface is the sub-interface that defines the API to transfer Buffer to/from accelerators for the backend.
DataInterface
// Finalize releases all the associated resources immediately and makes the backend invalid.
// Any operation on a Backend after Finalize is called is undefined, except IsFinalized.
Finalize()
// IsFinalized returns true if the backend is finalized.
//
// Tensors stored on a backend may hold a reference to a finalized backend, and when being garbage collected,
// check whether it is finalized before requesting the backend to finalize its buffers.
IsFinalized() bool
}
Backend is the API that needs to be implemented by a GoMLX backend.
func MustNew ¶ added in v0.20.0
func MustNew() Backend
MustNew returns a new default Backend or panics if it fails.
The default is:
1. The environment variable $GOMLX_BACKEND (ConfigEnvVar) is used as the configuration if defined.
2. Next, the variable DefaultConfig is used as the configuration.
3. Otherwise, the first registered backend is used with an empty configuration.
It fails if no backends were registered.
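Typical usage, a sketch:

backend := backends.MustNew()
defer backend.Finalize()
fmt.Printf("backend: %s -- %s\n", backend.Name(), backend.Description())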
func New ¶
func New() (Backend, error)
New returns a new default Backend or an error if it fails.
The default is:
1. The environment variable $GOMLX_BACKEND (ConfigEnvVar) is used as the configuration if defined.
2. Next, the variable DefaultConfig is used as the configuration.
3. Otherwise, the first registered backend is used with an empty configuration.
It fails if no backends were registered.
func NewOrErr ¶ added in v0.20.0 (deprecated)
func NewOrErr() (Backend, error)
NewOrErr returns a new default Backend or an error if it fails.
The default is:
1. The environment variable $GOMLX_BACKEND (ConfigEnvVar) is used as the configuration if defined.
2. Next, the variable DefaultConfig is used as the configuration.
3. Otherwise, the first registered backend is used with an empty configuration.
It fails if no backends were registered.
Deprecated: at the next version this function will be removed. Use New instead.
func NewWithConfig ¶
func NewWithConfig(config string) (Backend, error)
NewWithConfig creates a new Backend using the given configuration string.
The format of config is "<backend_name>:<backend_configuration>". The "<backend_name>" is the name of a registered backend (e.g.: "xla") and "<backend_configuration>" is backend-specific (e.g.: for xla backend, it is the pjrt plugin name).
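A sketch of explicit selection, bypassing $GOMLX_BACKEND and DefaultConfig ("xla:cuda" is an illustrative configuration value):

backend, err := backends.NewWithConfig("xla:cuda")
if err != nil {
	panic(err)
}
defer backend.Finalize()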
type Buffer ¶
type Buffer any
Buffer represents the actual data (a tensor) stored on the device that is going to execute the graph. It's used as input/output of computation execution. A Buffer is always associated with a DeviceNum, even if there is only one device.
It is opaque from GoMLX's perspective: GoMLX simply passes it back to the backend methods that take a Buffer as input.
type Builder ¶
type Builder interface {
// Compile the computation built. This immediately invalidates the Builder and returns an Executable that
// can now be used to run the computation.
//
// It is given the list of outputs.
Compile(outputs ...Op) (Executable, error)
// Name of the computation being built.
Name() string
// OpShape returns the shape of a computation Op.
// Notice this is not an operation and doesn't change the graph being built.
//
// One can use the shape and create a constant out of it.
OpShape(op Op) (shapes.Shape, error)
// Parameter creates an input parameter for the computation.
// During execution of a compiled computation (returned by Builder.Compile) this value will need to be fed
// in the same order it is created.
Parameter(name string, shape shapes.Shape) (Op, error)
// Constant creates a constant in the graph with the given flat values, and the shape defined by dims.
//
// The flat value must be a slice of a basic type supported -- that can be converted to a DType.
//
// The value is copied into the graph. For very large tensors, even if constant,
// it's recommended to pass them as side inputs (or variables, see the context package) instead.
Constant(flat any, dims ...int) (Op, error)
// Identity returns an Op whose output is the same as its input.
// It's a no-op that can serve as a place-holder.
Identity(x Op) (Op, error)
// ReduceWindow runs a reduction function of the type given by reductionType,
// it can be ReduceOpMax, ReduceOpSum, ReduceOpProduct or ReduceOpMin.
//
// The parameter windowDimensions must be set and have a value for each axis.
// If strides is nil, it's assumed to be the same as windowDimensions -- that is, the strides jump a window at a time.
// If baseDilations, windowDilations are nil, they are assumed to be 1 (no dilation).
// If paddings is nil, they are assumed to be 0.
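//
// Example (a sketch, assuming x is an Op holding an NHWC-shaped image batch):
// a 2x2 max-pooling over the spatial axes can be written as
//
//	pooled, err := builder.ReduceWindow(x, backends.ReduceOpMax,
//		[]int{1, 2, 2, 1}, nil, nil, nil, nil)
//
// where the nil strides default to the window dimensions.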
ReduceWindow(x Op, reductionType ReduceOpType, windowDimensions, strides, baseDilations, windowDilations []int, paddings [][2]int) (Op, error)
// RngBitGenerator generates the given shape filled with random bits.
// It takes the current random number generator (RNG) state, see RngState or RngStateFromSeed.
// The algorithm is hard-coded to use Philox algorithm for now.
//
// It returns the new state of the RNG and the generated values (with random bits) with the given shape.
RngBitGenerator(state Op, shape shapes.Shape) (newState Op, values Op, err error)
// BatchNormForInference implements Batch Norm for inference. See details in
// https://www.tensorflow.org/xla/operation_semantics#batchnorminference.
//
// Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing
// Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
BatchNormForInference(operand, scale, offset, mean, variance Op, epsilon float32, axis int) (Op, error)
// BatchNormForTraining implements Batch Norm for training. See details in
// https://www.tensorflow.org/xla/operation_semantics#batchnormtraining.
//
// It returns the normalized tensor, the batchMean and the batchVariance.
//
// Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing
// Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
BatchNormForTraining(operand, scale, offset Op, epsilon float32, axis int) (normalized Op, batchMean Op, batchVariance Op, err error)
// BatchNormGradient calculates the BatchNorm gradient. See details in
// https://openxla.org/xla/operation_semantics#batchnormgrad
//
// The gradOutput is the adjoint gradient, that is, the gradient with respect to the output of the
// batch normalization.
//
// It returns a tuple with the 3 gradients: gradOperand, gradScale and gradOffset.
//
// Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing
// Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
BatchNormGradient(operand, scale, mean, variance, gradOutput Op, epsilon float32, axis int) (gradOperand Op, gradScale Op, gradOffset Op, err error)
// BitCount returns the number of bits that are set to one.
BitCount(operand Op) (Op, error)
// StandardOps includes the automatically generated list of standard operations for the Builder.
StandardOps
}
Builder defines the set of ops to support building a computation. It is the sub-interface of Backend.
Each Builder can also:
- Not implement standard operations by returning an error -- this restricts what type of models it can support. See Backend.Capabilities and package github.com/gomlx/gomlx/backends/notimplemented
- Support specialized operations beyond those defined in this interface -- this requires careful interface casting by the caller (in package github.com/gomlx/gomlx/graph) and fallback to backends that don't support these specialized ops.
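Putting it together, a minimal sketch that builds and compiles f(x) = x+1. It assumes the usual GoMLX imports for shapes (github.com/gomlx/gomlx/types/shapes) and dtypes (github.com/gomlx/gopjrt/dtypes):

func buildAddOne(backend backends.Backend) (backends.Executable, error) {
	builder := backend.Builder("add_one")
	x, err := builder.Parameter("x", shapes.Make(dtypes.Float32, 3))
	if err != nil {
		return nil, err
	}
	one, err := builder.Constant([]float32{1}) // No dims given: a scalar constant.
	if err != nil {
		return nil, err
	}
	sum, err := builder.Add(x, one) // The scalar broadcasts over x.
	if err != nil {
		return nil, err
	}
	return builder.Compile(sum) // Invalidates builder and returns the Executable.
}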
type Capabilities ¶ added in v0.19.0
type Capabilities struct {
// Operations supported by a backend.
// If not listed, it's assumed to be false, hence not supported.
Operations map[OpType]bool
// DTypes list the data types supported by a backend.
// If not listed, it's assumed to be false, hence not supported.
DTypes map[dtypes.DType]bool
}
Capabilities holds mappings of what is supported by a backend.
func (Capabilities) Clone ¶ added in v0.19.0
func (c Capabilities) Clone() Capabilities
Clone makes a deep copy of the Capabilities.
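A sketch of how a caller might consult Capabilities before choosing a code path (assumes the dtypes import as usual for GoMLX):

caps := backend.Capabilities()
if !caps.Operations[backends.OpTypeFFT] || !caps.DTypes[dtypes.Complex64] {
	// Fall back to a backend or code path that doesn't need FFT over complex64.
}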
type Constructor ¶
type Constructor func(config string) (Backend, error)
Constructor takes a config string (optionally empty) and returns a Backend or an error.
type ConvolveAxesConfig ¶
type ConvolveAxesConfig struct {
InputBatch, InputChannel int
InputSpatial []int
KernelInputChannel, KernelOutputChannel int
KernelSpatial []int
OutputBatch, OutputChannel int
OutputSpatial []int
}
ConvolveAxesConfig defines the interpretation of the input/kernel/output tensor axes. There must be the same number of spatial dimensions (axes) in each of the 3 tensors. Input and output have batch and channel axes. The kernel has inputChannel and outputChannel axes.
See Builder.ConvGeneralDilated
func (ConvolveAxesConfig) Clone ¶
func (c ConvolveAxesConfig) Clone() ConvolveAxesConfig
Clone returns a deep copy of the structure.
type DataInterface ¶
type DataInterface interface {
// BufferFinalize allows the client to inform backend that buffer is no longer needed and associated resources can be
// freed immediately -- as opposed to waiting for a GC.
//
// A finalized buffer should never be used again. Preferably, the caller should set its references to it to nil.
BufferFinalize(buffer Buffer) error
// BufferShape returns the shape for the buffer.
BufferShape(buffer Buffer) (shapes.Shape, error)
// BufferDeviceNum returns the deviceNum for the buffer.
BufferDeviceNum(buffer Buffer) (DeviceNum, error)
// BufferToFlatData transfers the flat values of buffer to the Go flat array.
// The slice flat must have the exact number of elements required to store the Buffer shape.
//
// See also FlatDataToBuffer, BufferShape, and shapes.Shape.Size.
BufferToFlatData(buffer Buffer, flat any) error
// BufferFromFlatData transfers data from Go given as a flat slice (of the type corresponding to the shape DType)
// to the deviceNum, and returns the corresponding Buffer.
BufferFromFlatData(deviceNum DeviceNum, flat any, shape shapes.Shape) (Buffer, error)
// HasSharedBuffers returns whether the backend supports "shared buffers": buffers
// that can be used directly by the engine and have a local address that can be read or mutated
// directly by the client.
HasSharedBuffers() bool
// NewSharedBuffer creates a buffer that can be used as input to
// computations and directly read or mutated by the clients.
//
// It panics if the backend doesn't support shared buffers -- see HasSharedBuffers.
//
// The shared buffer should not be mutated while it is used by an execution.
// Also, the shared buffer cannot be "donated" during execution.
//
// When done, to release the memory, call BufferFinalize on the returned buffer.
//
// It returns a handle to the buffer and a slice of the corresponding data type pointing
// to the shared data.
NewSharedBuffer(deviceNum DeviceNum, shape shapes.Shape) (buffer Buffer, flat any, err error)
// BufferData returns a slice pointing to the buffer storage memory directly.
//
// This only works if HasSharedBuffers is true, that is, if the backend engine runs on CPU, or
// shares CPU memory.
//
// The returned slice becomes invalid after the buffer is destroyed.
BufferData(buffer Buffer) (flat any, err error)
}
DataInterface is the Backend's subinterface that defines the API to transfer Buffer to/from accelerators for the backend.
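A sketch of a round trip through the DataInterface, uploading three float32 values to device 0 and reading them back (error handling elided; shapes/dtypes imports as usual for GoMLX):

buf, _ := backend.BufferFromFlatData(0, []float32{1, 2, 3}, shapes.Make(dtypes.Float32, 3))
flat := make([]float32, 3) // Must match the buffer's shape size exactly.
_ = backend.BufferToFlatData(buf, flat)
_ = backend.BufferFinalize(buf) // Eagerly release the device memory.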
type DeviceNum ¶
type DeviceNum int
DeviceNum represents which device holds a buffer or should execute a computation. It's up to the backend to interpret it, but it should be between 0 and Backend.NumDevices.
type Executable ¶
type Executable interface {
// Finalize immediately frees resources associated to the executable.
Finalize()
// Inputs returns the parameters' names and shapes, in order created by the Builder.Parameter calls.
Inputs() (names []string, inputShapes []shapes.Shape)
// Outputs returns the computation's output shapes, in the order given to the Builder.Compile call.
Outputs() (outputShapes []shapes.Shape)
// Execute the executable on the default device (0).
// The number and shapes of the inputs must match those returned by Inputs.
//
// The inputs marked in donate will become invalid after use.
// This is useful if the input buffer is no longer needed or if updating a variable
// so its Buffer space can be reused as an output Buffer.
//
// Donated buffers are no longer valid after the call.
// If donate is nil, it is assumed to be false for all buffers, and no buffer is donated.
Execute(inputs []Buffer, donate []bool) ([]Buffer, error)
}
Executable is the API for compiled programs ready to execute.
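Continuing the sketches above (assuming exec is the Executable returned by Builder.Compile and buf is an input Buffer), executing with the input donated, since it's no longer needed and its space may be reused for the output:

outputs, err := exec.Execute([]backends.Buffer{buf}, []bool{true})
if err != nil {
	panic(err)
}
result := outputs[0] // buf was donated and is no longer valid.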
type FFTType ¶
type FFTType int
func FFTTypeString ¶ added in v0.17.1
func FFTTypeString(s string) (FFTType, error)
FFTTypeString retrieves an enum value from the enum constants string name. Throws an error if the param is not part of the enum.
func FFTTypeValues ¶ added in v0.17.1
func FFTTypeValues() []FFTType
FFTTypeValues returns all values of the enum
func (FFTType) IsAFFTType ¶ added in v0.17.1
func (i FFTType) IsAFFTType() bool
IsAFFTType returns "true" if the value is listed in the enum definition. "false" otherwise
type Op ¶
type Op any
Op represents the output of an operation, during the computation graph building time.
It is opaque from GoMLX's perspective: GoMLX simply passes Ops back as inputs to the other Builder methods.
type OpType ¶ added in v0.19.0
type OpType int
OpType is an enum of all generic operations that can be supported by a Backend.Builder.
Notice: nothing precludes a specialized backend Builder to support other ops not included here. It requires some careful casting of interfaces by the caller (presumably in package github.com/gomlx/gomlx/graph) and fallback to backends that don't support the specialized op.
const (
OpTypeInvalid OpType = iota
OpTypeParameter
OpTypeConstant
OpTypeIdentity
OpTypeReduceWindow
OpTypeRngBitGenerator
OpTypeBatchNormForInference
OpTypeBatchNormForTraining
OpTypeBatchNormGradient
OpTypeBitCount
OpTypeAbs
OpTypeAdd
OpTypeArgMinMax
OpTypeBitcast
OpTypeBitwiseAnd
OpTypeBitwiseNot
OpTypeBitwiseOr
OpTypeBitwiseXor
OpTypeBroadcast
OpTypeBroadcastInDim
OpTypeCeil
OpTypeClz
OpTypeComplex
OpTypeConcatenate
OpTypeConj
OpTypeConvGeneralDilated
OpTypeConvertDType
OpTypeCos
OpTypeDiv
OpTypeDot
OpTypeDotGeneral
OpTypeDynamicSlice
OpTypeDynamicUpdateSlice
OpTypeEqual
OpTypeEqualTotalOrder
OpTypeErf
OpTypeExp
OpTypeExpm1
OpTypeFFT
OpTypeFloor
OpTypeGather
OpTypeGreaterOrEqual
OpTypeGreaterOrEqualTotalOrder
OpTypeGreaterThan
OpTypeGreaterThanTotalOrder
OpTypeImag
OpTypeIota
OpTypeIsFinite
OpTypeLessOrEqual
OpTypeLessOrEqualTotalOrder
OpTypeLessThan
OpTypeLessThanTotalOrder
OpTypeLog
OpTypeLog1p
OpTypeLogicalAnd
OpTypeLogicalNot
OpTypeLogicalOr
OpTypeLogicalXor
OpTypeLogistic
OpTypeMax
OpTypeMin
OpTypeMul
OpTypeNeg
OpTypeNotEqual
OpTypeNotEqualTotalOrder
OpTypePad
OpTypePow
OpTypeReal
OpTypeReduceBitwiseAnd
OpTypeReduceBitwiseOr
OpTypeReduceBitwiseXor
OpTypeReduceLogicalAnd
OpTypeReduceLogicalOr
OpTypeReduceLogicalXor
OpTypeReduceMax
OpTypeReduceMin
OpTypeReduceProduct
OpTypeReduceSum
OpTypeRem
OpTypeReshape
OpTypeReverse
OpTypeRound
OpTypeRsqrt
OpTypeScatterMax
OpTypeScatterMin
OpTypeScatterSum
OpTypeSelectAndScatterMax
OpTypeSelectAndScatterMin
OpTypeSelectAndScatterSum
OpTypeShiftLeft
OpTypeShiftRightArithmetic
OpTypeShiftRightLogical
OpTypeSign
OpTypeSin
OpTypeSlice
OpTypeSqrt
OpTypeSub
OpTypeTanh
OpTypeTranspose
OpTypeWhere
// OpTypeLast should always be kept the last, it is used as a counter/marker for OpType.
OpTypeLast
)
func OpTypeString ¶ added in v0.19.0
OpTypeString retrieves an enum value from the enum constants string name. Throws an error if the param is not part of the enum.
func OpTypeValues ¶ added in v0.19.0
func OpTypeValues() []OpType
OpTypeValues returns all values of the enum
type PadAxis ¶
type PadAxis struct {
Start, End, Interior int
}
PadAxis defines the amount of padding preceding one axis (Start), at the end of axis (End) or in between the inputs (Interior). This is used as a parameter for the Pad operation.
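For an axis of input dimension n, the padded dimension is Start + n + (n-1)*Interior + End. A sketch, assuming builder, an Op x of shape [3] and a scalar zero Op already exist:

// Pad [3] -> [7]: one leading, one trailing, and one interior zero
// between elements: 1 + 3 + (3-1)*1 + 1 = 7.
padded, err := builder.Pad(x, zero, backends.PadAxis{Start: 1, End: 1, Interior: 1})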
type ReduceOpType ¶
type ReduceOpType int
ReduceOpType selects among the basic types of reduction supported, see XlaBuilder.ReduceComputation.
const (
// ReduceOpUndefined is an undefined value.
ReduceOpUndefined ReduceOpType = iota
// ReduceOpSum reduces by summing all elements being reduced.
ReduceOpSum
// ReduceOpProduct reduces by multiplying all elements being reduced.
ReduceOpProduct
// ReduceOpMax reduces by taking the maximum value.
ReduceOpMax
// ReduceOpMin reduces by taking the minimum value.
ReduceOpMin
)
func ReduceOpTypeString ¶ added in v0.17.1
func ReduceOpTypeString(s string) (ReduceOpType, error)
ReduceOpTypeString retrieves an enum value from the enum constants string name. Throws an error if the param is not part of the enum.
func ReduceOpTypeValues ¶ added in v0.17.1
func ReduceOpTypeValues() []ReduceOpType
ReduceOpTypeValues returns all values of the enum
func (ReduceOpType) IsAReduceOpType ¶ added in v0.17.1
func (i ReduceOpType) IsAReduceOpType() bool
IsAReduceOpType returns "true" if the value is listed in the enum definition. "false" otherwise
func (ReduceOpType) String ¶
func (i ReduceOpType) String() string
type StandardOps ¶
type StandardOps interface {
// Abs returns the Op that represents the output of the corresponding operation.
Abs(x Op) (Op, error)
// Add returns the element-wise sum of the two values.
// Standard broadcasting rules apply (see documentation).
// The op is created on the same XlaBuilder as used for x0 and x1.
Add(x0, x1 Op) (Op, error)
// ArgMinMax calculates the "argmin" or "argmax" across an axis of the given input array x.
// outputDType defines the output of the argmin/argmax, it doesn't need to be the same as the input.
// It's a form of reduction on the given axis, and that axis goes away. So the rank of the result is one less than
// the rank of x.
// Examples:
// ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=1, isMin=true) -> {1, 0} // (it chooses the 0 and the -3)
// ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=0, isMin=false) -> {0, 1, 0} // (it chooses the 2, 4 and 7)
ArgMinMax(x Op, axis int, outputDType dtypes.DType, isMin bool) (Op, error)
// Bitcast performs an elementwise bit-cast operation from a dtype to another dtype.
// The bitcast doesn't "convert" anything, it just reinterprets the bits from x.DType() to the targetDType.
// If x.DType() and targetDType use the same number of bytes (targetDType.Size() = x.DType().Size()),
// the dimensions are not changed, simply the dtype is changed.
// If targetDType.Size() > x.DType().Size(), it requires x's last axis to have a dimension of targetDType.Size() / x.DType().Size(),
// and the returned shape will trim the last axis.
// If targetDType.Size() < x.DType().Size(), the returned shape will have an extra axis in the end, with dimension of
// x.DType().Size() / targetDType.Size().
// E.g: Bitcast([1]uint32{0xdeadbeef}, dtypes.UInt16) -> [1][2]uint16{{0xdead, 0xbeef}}
Bitcast(x Op, targetDType dtypes.DType) (Op, error)
// BitwiseAnd returns the element-wise bitwise AND operation.
// The op is created on the same XlaBuilder as used for x0 and x1.
BitwiseAnd(x0, x1 Op) (Op, error)
// BitwiseNot returns the element-wise bitwise NOT operation.
BitwiseNot(x Op) (Op, error)
// BitwiseOr returns the element-wise bitwise OR operation.
// The op is created on the same XlaBuilder as used for x0 and x1.
BitwiseOr(x0, x1 Op) (Op, error)
// BitwiseXor returns the element-wise bitwise XOR operator.
// The op is created on the same XlaBuilder as used for x0 and x1.
BitwiseXor(x0, x1 Op) (Op, error)
// Broadcast prefixes dimensions to an array by duplicating the data in the array.
// See BroadcastInDim for broadcasting in between the axes.
// The new dimensions are inserted on the left, i.e., if prefixDims has values
// `{a0, ..., aN}` and the operand shape has dimensions `{b0, ..., bM}`, then
// the shape of the output has dimensions `{a0, ..., aN, b0, ..., bM}`.
// The new dimensions index into copies of the operand, i.e.
// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
Broadcast(x Op, prefixDims ...int) (Op, error)
// BroadcastInDim broadcasts x to an output with the given shape.
// broadcastAxes has one output-axis value for each axis of x (len(broadcastAxes) == x.Shape.Rank()).
// The i-th axis of x is mapped to the broadcastAxes[i]-th axis of the output.
// broadcastAxes must also be increasing: this operation cannot be used to transpose axes; it will only
// broadcast and introduce new axes in-between.
// This also requires that the i-th input axis is either 1 or is the same as the
// output dimension it's broadcasting into.
// For example, say operand `x = (s32)[2]{1, 2}`; outputShape = `(s32)[2,2]`:
// - Specifying []int{1} as broadcastAxes will generate output
// {{1, 2},
// {1, 2}}
// - On the other hand, specifying []int{0} as broadcastAxes
// will generate output
// {{1 , 1},
// {2 , 2}}
BroadcastInDim(x Op, outputShape shapes.Shape, broadcastAxes []int) (Op, error)
// Ceil returns the Op that represents the output of the corresponding operation.
Ceil(x Op) (Op, error)
// Clz returns element-wise the "count leading zeros" bits of input node x -- for integer values.
Clz(x Op) (Op, error)
// Complex returns the complex number taking x0 as the real part and x1 as the imaginary part.
// The real (x0) and imaginary (x1) must have the same dtype, and they must be either `dtypes.Float32` or
// `dtypes.Float64`.
// The output will be either `dtypes.Complex64` or `dtypes.Complex128`, depending on x0 and x1 dtypes.
// The shapes of x0 and x1 must be the same, or one of them must be a scalar, in which case
// its value is broadcast to every other value.
// The op is created on the same XlaBuilder as used for x0 and x1.
Complex(x0, x1 Op) (Op, error)
// Concatenate results on the given axis.
// All axes that are not being concatenated must match dimensions.
// It doesn't work with scalars -- use ExpandDims.
// If there is only one operand, it is returned and this is a no-op.
Concatenate(axis int, operands ...Op) (Op, error)
// Conj returns the conjugate of a complex number. E.g: Conj(1+3i) = 1-3i
Conj(x Op) (Op, error)
// ConvGeneralDilated is a generic Convolution operation offered by XLA.
// featureAxisAfter defines whether the features (aka. channels or depth) axis comes after the
// spatial dimension. Example: a 2D input can be one of the two:
// - featureAxisAfter=false: input=[batch_size, features, height, width], filter=[output_features, input_features, height, width]
// - featureAxisAfter=true: input=[batch_size, height, width, features], filter=[output_features, height, width, input_features]
// Some details in https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution.
// There, the operand and filter are called lhs and rhs.
// (XLA documentation is unfortunately poor, much is guess-work).
// Also useful, https://arxiv.org/pdf/1603.07285v1.pdf.
ConvGeneralDilated(operand, filter Op, axes ConvolveAxesConfig, strides []int, paddings [][2]int, inputDilation, filterDilation []int, filterGroupCount, batchGroupCount int) (Op, error)
// ConvertDType of x to dtype.
ConvertDType(x Op, dtype dtypes.DType) (Op, error)
// Cos returns the Op that represents the output of the corresponding operation.
Cos(x Op) (Op, error)
// Div returns the element-wise division of the two values.
// Standard broadcasting rules apply (see documentation).
// The op is created on the same XlaBuilder as used for x0 and x1.
Div(x0, x1 Op) (Op, error)
// Dot returns the "dot product" operation.
// The exact semantics of this operation depend on the ranks of the operands:
// | Input                             | Output         | Semantics                    |
// | vector [n] dot vector [n]         | scalar         | vector dot product           |
// | matrix [m x k] dot vector [k]     | vector [m]     | matrix-vector multiplication |
// | matrix [m x k] dot matrix [k x n] | matrix [m x n] | matrix-matrix multiplication |
// The operation performs sum of products over the second dimension of x0 (or the first if it has rank 1) and
// the first dimension of x1.
// These are the "contracted" dimensions.
// The contracted dimensions of x0 and x1 must be of the same size.
// In practice, it can be used to perform dot products between vectors, vector/matrix multiplications or
// matrix/matrix multiplications.
// The op is created on the same XlaBuilder as used for x0 and x1.
Dot(x0, x1 Op) (Op, error)
// DotGeneral takes as input lhs (left-hand-side) and rhs (right-hand-side) specifications
// for a general vector product -- a generalized "Einsum". Each axis can be:
// - Just aligned (batch axes), so the output has the same axes as the inputs. The dimensions
// must match in lhs and rhs.
// - Crossed (default), in which case the output is the combination (concatenation) of the
// dimensions.
// - Contracted (contracting axes), where the output does multiply the values and reduce sum
// those dimensions.
// It follows that the resulting dimension number starts with the batch dimension, then the 'lhs'
// non-contracting/non-batch dimension, and finally the 'rhs' non-contracting/non-batch dimension.
// It provides the basic means of implementing Einsum.
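//
// Example (a sketch, with hypothetical Ops lhs of shape [b,i,j] and rhs of shape [b,j,k]):
// a batched matrix multiplication, einsum "bij,bjk->bik", contracts axis j and batches over axis b:
//
//	out, err := builder.DotGeneral(lhs, []int{2}, []int{0}, rhs, []int{1}, []int{0})
//
// The output has shape [b,i,k]: batch axes first, then lhs axis i, then rhs axis k.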
DotGeneral(lhs Op, lhsContractingAxes, lhsBatchAxes []int, rhs Op, rhsContractingAxes, rhsBatchAxes []int) (Op, error)
// DynamicSlice extracts a sub-array from the input array at dynamic start_indices.
// The size of the slice in each axis is passed in sliceDims, which specify the slice
// intervals for each axis: [start, start + size).
// The shape of startIndices must be rank == 1, with dimension size equal to the rank of operand.
// See description in https://openxla.org/xla/operation_semantics#dynamicslice
DynamicSlice(operand Op, startIndices []Op, sliceDims []int) (Op, error)
// DynamicUpdateSlice generates a result which is the value of the input array operand, with a slice update overwritten
// at startIndices.
// The shape of update determines the shape of the sub-array of the result which is updated.
// The shape of startIndices must be rank == 1, with dimension size equal to the rank of operand.
// See description in https://openxla.org/xla/operation_semantics#dynamicupdateslice
DynamicUpdateSlice(operand, update Op, startIndices []Op) (Op, error)
// Equal performs element-wise equality check, returns boolean results with the same dimensions as input.
// The op is created on the same XlaBuilder as used for x0 and x1.
Equal(x0, x1 Op) (Op, error)
// EqualTotalOrder returns the element-wise operation.
// Standard broadcasting rules apply (see documentation).
// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
// The op is created on the same XlaBuilder as used for x0 and x1.
EqualTotalOrder(x0, x1 Op) (Op, error)
// Erf returns the "error function", defined as erf(x) = 2/Pi * \int_{0}^{x}{e^{-t^2}dt}.
Erf(x Op) (Op, error)
// Exp returns the Op that represents the output of the corresponding operation.
Exp(x Op) (Op, error)
// Expm1 returns the Op that represents the output of the corresponding operation.
Expm1(x Op) (Op, error)
// FFT calls the XLA FFT operation, which implements {Forward, Inverse} x {Complex, Real} versions.
// See documentation in https://www.tensorflow.org/xla/operation_semantics.
// Under the hood, the CPU FFT is backed by Eigen's TensorFFT, and the GPU FFT uses cuFFT.
FFT(operand Op, fftType FFTType, fftLength []int) (Op, error)
// Floor returns the Op that represents the output of the corresponding operation.
Floor(x Op) (Op, error)
// Gather is a powerful but cumbersome Gather operation offered by XLA.
// Full details in https://www.tensorflow.org/xla/operation_semantics#gather.
// (Warning: it's circular and cumbersome)
// The output of Gather has the same DType of the operand, from where we are pulling the data.
// Its shape will be composed of 2 parts:
// - Batch axes: they come from the axes of startIndices, except the "indexVectorAxis" (usually the last)
// that is used as the indices into the operand. (*)
// - "Offset axes": these are axes that come from the operand, the sizes given by sliceSizes. Notice
// that if sliceSizes for an axis is 1, and that axis features in the collapsedSliceAxes list, this
// axis gets omitted in the output.
// So in general output.Rank() = startIndices.Rank() - 1 + len(offsetAxes).
// (*) One exception is if indexVectorAxis == startIndices.Rank(), in which case we assume there is an
// extra virtual axis in startIndices of size 1, in which case output.Rank() = startIndices.Rank() + len(offsetAxes).
// Arguments:
// - operand: the values from where we are gathering. The output DType will follow the operand one.
// - startIndices: are the indices we want to gather. There will be one axis pointed by indexVector axis which
// enumerates the indices of the slice to be gathered in the operand array (their values are mapped to the axis
// in the operand according to startIndexMap).
// All other axes are "batch dimensions" and they will have equivalent axes (same dimensions) in the output.
// - indexVectorAxis: which of the axis in startIndices is collected and used as the start index for slices
// to be gathered in the operand.
// It is typically the last axis of startIndices, so startIndices.Shape.Rank()-1.
// There is a special case where indexVectorAxis == startIndices.Rank() in which case we assume there is an
// extra virtual axis in startIndices of size 1, in which case output.Rank() = startIndices.Rank() + len(offsetAxes).
// - offsetOutputAxes: axes in the _output_ (not on the operand) that will hold the "offset slices", slices that are not
// collapsed. It points in which position (axis) in the output these slices should show up. Any axis in sliceSizes
// that is > 1 must feature here.
// Notice all axes in the operand will either become an "offset axis" in the output, if their slice size > 1,
// or optionally collapsed (or "squeezed") in the output, if their slice size == 1. We map the axes in the output
// (given in offsetAxes) to the axes in the operand (the axes not present in collapsedSliceAxes) sequentially.
// One must have Rank(operand) == len(collapsedSliceAxes) + len(offsetAxes).
// - collapsedSliceAxes: for sliceSizes that are 1 in the operand, one may not want to include them in the output.
// The _operand_ axes included here are marked to be collapsed (removed) in the output. Notice, the corresponding
// value in sliceSizes must be 1.
// One must have Rank(operand) == len(collapsedSliceAxes) + len(offsetOutputAxes).
// - startIndexMap: this maps which value in startIndices is used for which axis index in the slice to be gathered.
// Notice len(startIndexMap) must match the startIndices.Shape().Dimensions[indexVectorAxis].
// E.g: if startIndices.shape=(2, 3), indexVectorAxis=1, operand.rank=4 and startIndexMap=[]int{0, 1, 2},
// this means each row of startIndices will point to the first 3 axes (0, 1 and 2) of the operand.
// In many cases this is [0, 1, 2, ..., operand.Shape.Rank()-1], that is, each "index vector" fully defines
// an element on the operand. In some cases this is only a prefix of the operand's rank.
// For those axes in the operand not explicitly set (so if len(startIndexMap) < operand.Rank()), the corresponding
// axis start index is considered to be 0, and one sets the sliceSizes to take the slice one wants (typically the
// full slice).
// - sliceSizes: once the start index from where to gather is resolved, this defines how much data in each axis
// to gather. The "offset" output axes (see above) will have dimensions equal to these numbers for the axes
// that are not collapsed.
// - indicesAreSorted: can be set to true if it is guaranteed that startIndices are sorted (in ascending order,
// after scattering its values according to start_index_map) by the user. This allows for some optimizations
// in some platforms.
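//
// Example (a sketch): gathering rows of an embedding table of shape [vocab, dim] using
// indices of shape [m, 1] (indexVectorAxis=1) yields an output of shape [m, dim]:
//
//	out, err := builder.Gather(table, indices,
//		1,             // indexVectorAxis: the last axis of startIndices.
//		[]int{1},      // offsetOutputAxes: the slice's dim axis becomes output axis 1.
//		[]int{0},      // collapsedSliceAxes: the vocab axis (slice size 1) is dropped.
//		[]int{0},      // startIndexMap: each index selects a position along operand axis 0.
//		[]int{1, dim}, // sliceSizes: one row, the full dim (dim is a hypothetical int variable).
//		false)         // indicesAreSorted.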
Gather(operand, startIndices Op, indexVectorAxis int, offsetOutputAxes, collapsedSliceAxes, startIndexMap, sliceSizes []int, indicesAreSorted bool) (Op, error)
// GreaterOrEqual performs element-wise comparison, returns boolean results with the same dimensions as input.
// The op is created on the same XlaBuilder as used for x0 and x1.
GreaterOrEqual(x0, x1 Op) (Op, error)
// GreaterOrEqualTotalOrder returns the element-wise operation.
// Standard broadcasting rules apply (see documentation).
// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
// The op is created on the same XlaBuilder as used for x0 and x1.
GreaterOrEqualTotalOrder(x0, x1 Op) (Op, error)
// GreaterThan performs element-wise comparison, returns boolean results with the same dimensions as input.
// The op is created on the same XlaBuilder as used for x0 and x1.
GreaterThan(x0, x1 Op) (Op, error)
// GreaterThanTotalOrder returns the element-wise operation.
// Standard broadcasting rules apply (see documentation).
// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
// The op is created on the same XlaBuilder as used for x0 and x1.
GreaterThanTotalOrder(x0, x1 Op) (Op, error)
// Imag returns the imaginary part of a complex number. It returns 0 if x is a float number.
Imag(x Op) (Op, error)
// Iota creates a constant of the given shape with increasing numbers (starting from 0)
// on the given axis. So Iota([2,2], 1) returns [[0 1][0 1]], while Iota([2,2], 0)
// returns [[0 0][1 1]].
Iota(shape shapes.Shape, iotaAxis int) (Op, error)
// IsFinite tests whether each element of operand is finite, i.e., is not positive or negative infinity, and is not NaN.
// It returns an array of boolean values with the same shape as the input, where each element is true if and only if
// the corresponding input element is finite.
IsFinite(x Op) (Op, error)
// LessOrEqual performs element-wise comparison, returns boolean results with the same dimensions as input.
// The op is created on the same XlaBuilder as used for x0 and x1.
LessOrEqual(x0, x1 Op) (Op, error)
// LessOrEqualTotalOrder returns the element-wise operation.
// Standard broadcasting rules apply (see documentation).
// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
// The op is created on the same XlaBuilder as used for x0 and x1.
LessOrEqualTotalOrder(x0, x1 Op) (Op, error)
// LessThan performs element-wise comparison, returns boolean results with the same dimensions as input.
// The op is created on the same XlaBuilder as used for x0 and x1.
LessThan(x0, x1 Op) (Op, error)
// LessThanTotalOrder returns the element-wise operation.
// Standard broadcasting rules apply (see documentation).
// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
// The op is created on the same XlaBuilder as used for x0 and x1.
LessThanTotalOrder(x0, x1 Op) (Op, error)
// Log returns the Op that represents the output of the corresponding operation.
Log(x Op) (Op, error)
// Log1p returns the expression log(x+1).
Log1p(x Op) (Op, error)
// LogicalAnd returns the element-wise logical AND operation.
// The op is created on the same XlaBuilder as used for x0 and x1.
LogicalAnd(x0, x1 Op) (Op, error)
// LogicalNot returns the Op that represents the output of the corresponding operation.
LogicalNot(x Op) (Op, error)
// LogicalOr returns the element-wise logical OR operation.
// The op is created on the same XlaBuilder as used for x0 and x1.
LogicalOr(x0, x1 Op) (Op, error)
// LogicalXor returns the element-wise logical XOR operator.
// The op is created on the same XlaBuilder as used for x0 and x1.
LogicalXor(x0, x1 Op) (Op, error)
// Logistic returns the element-wise expression 1/(1+exp(-x)). Also known as the Sigmoid function.
Logistic(x Op) (Op, error)
// Max returns the element-wise highest value among the two.
// The op is created on the same XlaBuilder as used for x0 and x1.
Max(x0, x1 Op) (Op, error)
// Min returns the element-wise smallest value among the two.
// The op is created on the same XlaBuilder as used for x0 and x1.
Min(x0, x1 Op) (Op, error)
// Mul returns the element-wise multiplication of the two values.
// Standard broadcasting rules apply (see documentation).
// The op is created on the same XlaBuilder as used for x0 and x1.
Mul(x0, x1 Op) (Op, error)
// Neg returns the Op that represents the output of the corresponding operation.
Neg(x Op) (Op, error)
// NotEqual performs element-wise inequality check, returns boolean results with the same dimensions as input.
// The op is created on the same XlaBuilder as used for x0 and x1.
NotEqual(x0, x1 Op) (Op, error)
// NotEqualTotalOrder returns the element-wise operation.
// Standard broadcasting rules apply (see documentation).
// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
// The op is created on the same XlaBuilder as used for x0 and x1.
NotEqualTotalOrder(x0, x1 Op) (Op, error)
// Pad injects padding on the start, end or interior (in between each element) of the given operand.
// There must be at most `operand.Rank()` axesConfig values. Missing PadAxis are assumed to be zeros,
// that is, no padding for those axes.
Pad(x, fillValue Op, axesConfig ...PadAxis) (Op, error)
// Pow returns the Op that represents the output of the corresponding operation.
// The op is created on the same XlaBuilder as used for x0 and x1.
Pow(x0, x1 Op) (Op, error)
// Real returns the real part of a complex number. It returns x itself if x is a float number.
Real(x Op) (Op, error)
// ReduceBitwiseAnd is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the bitwise/logical And of the reduced axes.
// If no axes are given, it reduces the full array.
ReduceBitwiseAnd(x Op, axes ...int) (Op, error)
// ReduceBitwiseOr is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the bitwise/logical Or of the reduced axes.
// If no axes are given, it reduces the full array.
ReduceBitwiseOr(x Op, axes ...int) (Op, error)
// ReduceBitwiseXor is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the bitwise/logical Xor of the reduced axes.
// If no axes are given, it reduces the full array.
ReduceBitwiseXor(x Op, axes ...int) (Op, error)
// ReduceLogicalAnd is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the bitwise/logical And of the reduced axes.
// If no axes are given, it reduces the full array.
ReduceLogicalAnd(x Op, axes ...int) (Op, error)
// ReduceLogicalOr is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the bitwise/logical Or of the reduced axes.
// If no axes are given, it reduces the full array.
ReduceLogicalOr(x Op, axes ...int) (Op, error)
// ReduceLogicalXor is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the bitwise/logical Xor of the reduced axes.
// If no axes are given, it reduces the full array.
ReduceLogicalXor(x Op, axes ...int) (Op, error)
// ReduceMax is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the max value.
// If no axes are given, it reduces the full array.
ReduceMax(x Op, axes ...int) (Op, error)
// ReduceMin is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the min value.
// If no axes are given, it reduces the full array.
ReduceMin(x Op, axes ...int) (Op, error)
// ReduceProduct is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the product of the reduced axes.
// If no axes are given, it reduces the full array.
ReduceProduct(x Op, axes ...int) (Op, error)
// ReduceSum is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the sum of the reduced axes.
// If no axes are given, it reduces the full array.
ReduceSum(x Op, axes ...int) (Op, error)
// Rem returns the remainder operation, also known as modulo (or Mod for short).
// Notice that despite the name, XLA implements Mod, not the IEEE-754 Remainder operation.
// The op is created on the same XlaBuilder as used for x0 and x1.
Rem(x0, x1 Op) (Op, error)
// Reshape reshapes x to the new dimensions.
// Total size cannot change, it's just a "reinterpretation" of the same flat data.
// The dtype remains the same, see ConvertDType to actually convert the values.
Reshape(x Op, dimensions ...int) (Op, error)
// Reverse returns x with the values for the given dimensions reversed, that is,
// the value at index `i` will be swapped with the value at index `(dimension_size - 1 - i)`.
// The shape remains the same.
Reverse(x Op, axes ...int) (Op, error)
// Round returns the Op that represents the output of the corresponding operation.
Round(x Op) (Op, error)
// Rsqrt returns the element-wise reciprocal of square root operation 1/sqrt(x).
Rsqrt(x Op) (Op, error)
// ScatterMax scatters values from updates pointed to by scatterIndices into operand, taking the Max.
ScatterMax(operand, scatterIndices, updates Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (Op, error)
// ScatterMin scatters values from updates pointed to by scatterIndices into operand, taking the Min.
ScatterMin(operand, scatterIndices, updates Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (Op, error)
// ScatterSum adds values from updates pointed to by scatterIndices into operand.
ScatterSum(operand, scatterIndices, updates Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (Op, error)
// SelectAndScatterMax runs windows (similar to ReduceWindow) over the operand and selects values to update the output (like ScatterSum).
// It selects the values in each window such that it works as the reverse of a PoolMax operation.
// See details in https://openxla.org/xla/operation_semantics#selectandscatter
SelectAndScatterMax(operand, source Op, windowDimensions, windowStrides []int, paddings [][2]int) (Op, error)
// SelectAndScatterMin runs windows (similar to ReduceWindow) over the operand and selects values to update the output (like ScatterSum).
// It selects the values in each window such that it works as the reverse of a PoolMin operation.
// See details in https://openxla.org/xla/operation_semantics#selectandscatter
SelectAndScatterMin(operand, source Op, windowDimensions, windowStrides []int, paddings [][2]int) (Op, error)
// SelectAndScatterSum runs windows (similar to ReduceWindow) over the operand and selects values to update the output (like ScatterSum).
// It selects the values in each window such that it works as the reverse of a PoolSum operation.
// See details in https://openxla.org/xla/operation_semantics#selectandscatter
SelectAndScatterSum(operand, source Op, windowDimensions, windowStrides []int, paddings [][2]int) (Op, error)
// ShiftLeft n bits. It implicitly preserves the sign bit, if there is no overflow. So ShiftLeft(-1, 1) = -2.
// The op is created on the same XlaBuilder as used for x0 and x1.
ShiftLeft(x0, x1 Op) (Op, error)
// ShiftRightArithmetic shifts right by n bits, preserving the sign bit. So ShiftRight(-2, 1) = -1.
// The op is created on the same XlaBuilder as used for x0 and x1.
ShiftRightArithmetic(x0, x1 Op) (Op, error)
// ShiftRightLogical shifts right by n bits, destroying the sign bit.
// The op is created on the same XlaBuilder as used for x0 and x1.
ShiftRightLogical(x0, x1 Op) (Op, error)
// Sign returns element-wise +1, +/-0 or -1 depending on the sign of x. It returns NaN if the input is NaN.
Sign(x Op) (Op, error)
// Sin returns the Op that represents the output of the corresponding operation.
Sin(x Op) (Op, error)
// Slice extracts a sub-array from the input array.
// The sub-array is of the same rank as the input and contains the values inside a bounding box within the input array
// where the dimensions and indices of the bounding box are given as arguments to the slice operation.
// The strides set the input stride of the slice in each axis and must be >= 1.
// It is optional, and if missing it is assumed to be 1 for every dimension.
// Examples:
// Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={4}, strides=nil) -> {2, 3}
// Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={5}, strides={2}) -> {2, 4}
Slice(x Op, starts, limits, strides []int) (Op, error)
// Sqrt returns the Op that represents the output of the corresponding operation.
Sqrt(x Op) (Op, error)
// Sub returns the element-wise subtraction of the two values.
// Standard broadcasting rules apply (see documentation).
// The op is created on the same XlaBuilder as used for x0 and x1.
Sub(x0, x1 Op) (Op, error)
// Tanh returns the Op that represents the output of the corresponding operation.
Tanh(x Op) (Op, error)
// Transpose axes of x.
// There should be one value in permutations for each axis in x.
// The output will have: output.Shape.Dimensions[i] = x.Shape.Dimensions[permutations[i]].
Transpose(x Op, permutations ...int) (Op, error)
// Where takes element-wise values from onTrue or onFalse depending on the value of condition (expected to be boolean).
Where(condition, onTrue, onFalse Op) (Op, error)
}
Source Files ¶
Directories ¶
| Path | Synopsis |
|---|---|
| default | Package _default includes the default backends, namely XLA and SimpleGo. |
| notimplemented | Package notimplemented implements a backends.Builder interface that throws a "Not implemented" exception for all operations. |
| shapeinference | Package shapeinference calculates the shape resulting from operations, and validates its inputs. |
| simplego | Package simplego implements a simple, and not very fast, but very portable backend for GoMLX. |
| xla | Package xla implements the XLA/PJRT (https://openxla.org/) based backend for GoMLX. |
| xla/cpu/dynamic | Package dynamic links the XLA/PJRT CPU plugin dynamically (as in ".so" libraries) with your binary. |
| xla/cpu/static | Package static links the XLA/PJRT CPU plugin statically with your binary. |