Documentation
¶
Overview ¶
Package stablehlo implements a GoMLX backend using StableHLO (see github.com/gomlx/stablehlo) as a language to talk to PJRT, the C++ engine for XLA (see github.com/gomlx/gopjrt/pjrt).
The backend is registered as "stablehlo", "shlo" or "hlo" (all aliases to the same backend).
Index ¶
- Constants
- Variables
- func GetAvailablePlugins() []string
- func New(config string) (backends.Backend, error)
- func ShapeFromStableHLO(shape stablehloshapes.Shape) shapes.Shape
- func ShapeToStableHLO(shape shapes.Shape) stablehloshapes.Shape
- type Backend
- func (backend *Backend) BufferData(buffer backends.Buffer) (flat any, err error)
- func (backend *Backend) BufferDeviceNum(buffer backends.Buffer) (backends.DeviceNum, error)
- func (backend *Backend) BufferFinalize(buffer backends.Buffer) error
- func (backend *Backend) BufferFromFlatData(deviceNum backends.DeviceNum, flat any, shape shapes.Shape) (backends.Buffer, error)
- func (backend *Backend) BufferShape(buffer backends.Buffer) (shapes.Shape, error)
- func (backend *Backend) BufferToFlatData(buffer backends.Buffer, flat any) error
- func (backend *Backend) Builder(name string) backends.Builder
- func (backend *Backend) Capabilities() backends.Capabilities
- func (backend *Backend) CheckValid() error
- func (backend *Backend) Description() string
- func (backend *Backend) Finalize()
- func (backend *Backend) HasSharedBuffers() bool
- func (backend *Backend) IsFinalized() bool
- func (backend *Backend) Name() string
- func (backend *Backend) NewSharedBuffer(deviceNum backends.DeviceNum, shape shapes.Shape) (buffer backends.Buffer, flat any, err error)
- func (backend *Backend) NumDevices() backends.DeviceNum
- func (backend *Backend) String() string
- type Builder
- func (b *Builder) Abs(operand backends.Op) (backends.Op, error)
- func (b *Builder) Add(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) ArgMinMax(x backends.Op, axis int, outputDType dtypes.DType, isMin bool) (backends.Op, error)
- func (b *Builder) BatchNormForInference(input, scale, offset, mean, variance backends.Op, epsilon float32, ...) (backends.Op, error)
- func (b *Builder) BatchNormForTraining(input, scale, offset backends.Op, epsilon float32, featureAxis int) (output, batchMean, batchVar backends.Op, err error)
- func (b *Builder) BatchNormGradient(gradOutput, input, scale, mean, variance backends.Op, epsilon float32, ...) (gradInput, gradScale, gradOffset backends.Op, err error)
- func (b *Builder) BitCount(operand backends.Op) (backends.Op, error)
- func (b *Builder) Bitcast(x backends.Op, targetDType dtypes.DType) (backends.Op, error)
- func (b *Builder) BitwiseAnd(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) BitwiseNot(operand backends.Op) (backends.Op, error)
- func (b *Builder) BitwiseOr(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) BitwiseXor(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) Broadcast(x backends.Op, prefixDims ...int) (backends.Op, error)
- func (b *Builder) BroadcastInDim(x backends.Op, outputShape shapes.Shape, broadcastAxes []int) (backends.Op, error)
- func (b *Builder) Ceil(operand backends.Op) (backends.Op, error)
- func (b *Builder) CheckValid() error
- func (b *Builder) Clamp(min, a backends.Op, max backends.Op) (backends.Op, error)
- func (b *Builder) Clz(operand backends.Op) (backends.Op, error)
- func (b *Builder) Compile(outputs ...backends.Op) (backends.Executable, error)
- func (b *Builder) Complex(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) Concatenate(axis int, operands ...backends.Op) (backends.Op, error)
- func (b *Builder) Conj(operand backends.Op) (backends.Op, error)
- func (b *Builder) Constant(flat any, dimensions ...int) (backends.Op, error)
- func (b *Builder) ConvGeneral(input, kernel backends.Op, axes backends.ConvolveAxesConfig, strides []int, ...) (backends.Op, error)
- func (b *Builder) ConvertDType(x backends.Op, dtype dtypes.DType) (backends.Op, error)
- func (b *Builder) Cos(operand backends.Op) (backends.Op, error)
- func (b *Builder) Div(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) Dot(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) DotGeneral(lhs backends.Op, lhsContractingAxes, lhsBatchAxes []int, rhs backends.Op, ...) (backends.Op, error)
- func (b *Builder) DynamicSlice(operand backends.Op, startIndices []backends.Op, sliceDims []int) (backends.Op, error)
- func (b *Builder) DynamicUpdateSlice(operand, update backends.Op, startIndices []backends.Op) (backends.Op, error)
- func (b *Builder) Equal(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) EqualTotalOrder(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) Erf(operand backends.Op) (backends.Op, error)
- func (b *Builder) Exp(operand backends.Op) (backends.Op, error)
- func (b *Builder) Expm1(operand backends.Op) (backends.Op, error)
- func (b *Builder) FFT(x backends.Op, fftType backends.FFTType, fftLength []int) (backends.Op, error)
- func (b *Builder) Floor(operand backends.Op) (backends.Op, error)
- func (b *Builder) Gather(operand, startIndices backends.Op, indexVectorAxis int, ...) (backends.Op, error)
- func (b *Builder) GreaterOrEqual(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) GreaterOrEqualTotalOrder(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) GreaterThan(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) GreaterThanTotalOrder(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) Identity(x backends.Op) (backends.Op, error)
- func (b *Builder) Imag(operand backends.Op) (backends.Op, error)
- func (b *Builder) Iota(shape shapes.Shape, iotaAxis int) (backends.Op, error)
- func (b *Builder) IsFinite(operand backends.Op) (backends.Op, error)
- func (b *Builder) IsNaN(x backends.Op) (backends.Op, error)
- func (b *Builder) LessOrEqual(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) LessOrEqualTotalOrder(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) LessThan(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) LessThanTotalOrder(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) Log(operand backends.Op) (backends.Op, error)
- func (b *Builder) Log1p(operand backends.Op) (backends.Op, error)
- func (b *Builder) LogicalAnd(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) LogicalNot(operand backends.Op) (backends.Op, error)
- func (b *Builder) LogicalOr(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) LogicalXor(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) Logistic(operand backends.Op) (backends.Op, error)
- func (b *Builder) Max(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) Min(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) Mul(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) Neg(operand backends.Op) (backends.Op, error)
- func (b *Builder) NotEqual(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) NotEqualTotalOrder(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) OpShape(op backends.Op) (shapes.Shape, error)
- func (b *Builder) Pad(x, fillValue backends.Op, axesConfig ...backends.PadAxis) (backends.Op, error)
- func (b *Builder) Parameter(name string, shape shapes.Shape) (backends.Op, error)
- func (b *Builder) Pow(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) Real(operand backends.Op) (backends.Op, error)
- func (b *Builder) ReduceBitwiseAnd(x backends.Op, axes ...int) (backends.Op, error)
- func (b *Builder) ReduceBitwiseOr(x backends.Op, axes ...int) (backends.Op, error)
- func (b *Builder) ReduceBitwiseXor(x backends.Op, axes ...int) (backends.Op, error)
- func (b *Builder) ReduceLogicalAnd(x backends.Op, axes ...int) (backends.Op, error)
- func (b *Builder) ReduceLogicalOr(x backends.Op, axes ...int) (backends.Op, error)
- func (b *Builder) ReduceLogicalXor(x backends.Op, axes ...int) (backends.Op, error)
- func (b *Builder) ReduceMax(x backends.Op, axes ...int) (backends.Op, error)
- func (b *Builder) ReduceMin(x backends.Op, axes ...int) (backends.Op, error)
- func (b *Builder) ReduceProduct(x backends.Op, axes ...int) (backends.Op, error)
- func (b *Builder) ReduceSum(x backends.Op, axes ...int) (backends.Op, error)
- func (b *Builder) ReduceWindow(x backends.Op, reductionType backends.ReduceOpType, ...) (backends.Op, error)
- func (b *Builder) Rem(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) Reshape(x backends.Op, dimensions ...int) (backends.Op, error)
- func (b *Builder) Reverse(x backends.Op, axes ...int) (backends.Op, error)
- func (b *Builder) RngBitGenerator(state backends.Op, shape shapes.Shape) (newState backends.Op, values backends.Op, err error)
- func (b *Builder) Round(operand backends.Op) (backends.Op, error)
- func (b *Builder) Rsqrt(operand backends.Op) (backends.Op, error)
- func (b *Builder) ScatterMax(operand, scatterIndices, updates backends.Op, indexVectorAxis int, ...) (backends.Op, error)
- func (b *Builder) ScatterMin(operand, scatterIndices, updates backends.Op, indexVectorAxis int, ...) (backends.Op, error)
- func (b *Builder) ScatterSum(operand, scatterIndices, updates backends.Op, indexVectorAxis int, ...) (backends.Op, error)
- func (b *Builder) SelectAndScatterMax(operand, source backends.Op, windowDimensions, windowStrides []int, ...) (backends.Op, error)
- func (b *Builder) SelectAndScatterMin(operand, source backends.Op, windowDimensions, windowStrides []int, ...) (backends.Op, error)
- func (b *Builder) ShiftLeft(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) ShiftRightArithmetic(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) ShiftRightLogical(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) Sign(operand backends.Op) (backends.Op, error)
- func (b *Builder) Sin(operand backends.Op) (backends.Op, error)
- func (b *Builder) Slice(x backends.Op, starts, limits, strides []int) (backends.Op, error)
- func (b *Builder) Sqrt(operand backends.Op) (backends.Op, error)
- func (b *Builder) Sub(lhs, rhs backends.Op) (backends.Op, error)
- func (b *Builder) Tanh(operand backends.Op) (backends.Op, error)
- func (b *Builder) Transpose(x backends.Op, permutation ...int) (backends.Op, error)
- func (b *Builder) Where(condition, onTrue, onFalse backends.Op) (backends.Op, error)
- type DotGeneralConfig
- type Executable
- func (e *Executable) CheckValid() error
- func (e *Executable) Execute(inputs []backends.Buffer, donate []bool) ([]backends.Buffer, error)
- func (e *Executable) Finalize()
- func (e *Executable) Inputs() (names []string, inputShapes []shapes.Shape)
- func (e *Executable) Outputs() (outputShapes []shapes.Shape)
- type Node
Constants ¶
const BackendName = "stablehlo"
BackendName is the name of the backend.
The stablehlo backend also accepts the "hlo" and "pjrt" aliases.
Variables ¶
var Capabilities = backends.Capabilities{ Operations: map[backends.OpType]bool{ backends.OpTypeParameter: true, backends.OpTypeConstant: true, backends.OpTypeAbs: true, backends.OpTypeBitCount: true, backends.OpTypeBitwiseNot: true, backends.OpTypeCeil: true, backends.OpTypeClz: true, backends.OpTypeCos: true, backends.OpTypeErf: true, backends.OpTypeExp: true, backends.OpTypeExpm1: true, backends.OpTypeFloor: true, backends.OpTypeIsFinite: true, backends.OpTypeIsNaN: true, backends.OpTypeLog1p: true, backends.OpTypeLog: true, backends.OpTypeLogicalNot: true, backends.OpTypeLogistic: true, backends.OpTypeNeg: true, backends.OpTypeRound: true, backends.OpTypeRsqrt: true, backends.OpTypeSign: true, backends.OpTypeSin: true, backends.OpTypeSqrt: true, backends.OpTypeTanh: true, backends.OpTypeAdd: true, backends.OpTypeBitwiseAnd: true, backends.OpTypeBitwiseOr: true, backends.OpTypeBitwiseXor: true, backends.OpTypeDiv: true, backends.OpTypeLogicalAnd: true, backends.OpTypeLogicalOr: true, backends.OpTypeLogicalXor: true, backends.OpTypeMax: true, backends.OpTypeMin: true, backends.OpTypeMul: true, backends.OpTypePow: true, backends.OpTypeRem: true, backends.OpTypeSub: true, backends.OpTypeEqual: true, backends.OpTypeEqualTotalOrder: true, backends.OpTypeGreaterOrEqual: true, backends.OpTypeGreaterOrEqualTotalOrder: true, backends.OpTypeGreaterThan: true, backends.OpTypeGreaterThanTotalOrder: true, backends.OpTypeLessOrEqual: true, backends.OpTypeLessOrEqualTotalOrder: true, backends.OpTypeLessThan: true, backends.OpTypeLessThanTotalOrder: true, backends.OpTypeNotEqual: true, backends.OpTypeNotEqualTotalOrder: true, backends.OpTypeComplex: true, backends.OpTypeConj: true, backends.OpTypeImag: true, backends.OpTypeReal: true, backends.OpTypeArgMinMax: true, backends.OpTypeBatchNormForInference: true, backends.OpTypeBatchNormForTraining: true, backends.OpTypeBatchNormGradient: true, backends.OpTypeBitcast: true, backends.OpTypeBroadcast: true, 
backends.OpTypeBroadcastInDim: true, backends.OpTypeClamp: true, backends.OpTypeConcatenate: true, backends.OpTypeConvertDType: true, backends.OpTypeConvGeneral: true, backends.OpTypeDynamicSlice: true, backends.OpTypeDynamicUpdateSlice: true, backends.OpTypeDot: true, backends.OpTypeDotGeneral: true, backends.OpTypeFFT: true, backends.OpTypeGather: true, backends.OpTypeIdentity: true, backends.OpTypeIota: true, backends.OpTypePad: true, backends.OpTypeReduceBitwiseAnd: true, backends.OpTypeReduceBitwiseOr: true, backends.OpTypeReduceBitwiseXor: true, backends.OpTypeReduceLogicalAnd: true, backends.OpTypeReduceLogicalOr: true, backends.OpTypeReduceLogicalXor: true, backends.OpTypeReduceMax: true, backends.OpTypeReduceMin: true, backends.OpTypeReduceProduct: true, backends.OpTypeReduceSum: true, backends.OpTypeReduceWindow: true, backends.OpTypeReshape: true, backends.OpTypeReverse: true, backends.OpTypeRngBitGenerator: true, backends.OpTypeScatterSum: true, backends.OpTypeScatterMax: true, backends.OpTypeScatterMin: true, backends.OpTypeSelectAndScatterMax: true, backends.OpTypeSelectAndScatterMin: true, backends.OpTypeShiftLeft: true, backends.OpTypeShiftRightArithmetic: true, backends.OpTypeShiftRightLogical: true, backends.OpTypeSlice: true, backends.OpTypeTranspose: true, backends.OpTypeWhere: true, }, DTypes: map[dtypes.DType]bool{ dtypes.Bool: true, dtypes.Int8: true, dtypes.Int16: true, dtypes.Int32: true, dtypes.Int64: true, dtypes.Uint8: true, dtypes.Uint16: true, dtypes.Uint32: true, dtypes.Uint64: true, dtypes.Float32: true, dtypes.Float64: true, dtypes.BFloat16: true, dtypes.Complex64: true, dtypes.Complex128: true, }, }
Capabilities of the StableHLO backend: the set of supported operations and data types.
var ( // DefaultPlugins is the list of plugins to use in preference order, if not otherwise specified. DefaultPlugins = []string{"cuda", "cpu"} )
Functions ¶
func GetAvailablePlugins ¶
func GetAvailablePlugins() []string
GetAvailablePlugins lists the available platforms -- it caches and reuses the result in future calls.
Plugins are searched for in the PJRT_PLUGIN_LIBRARY_PATH directory -- or directories, if it is a ":" separated list. If it is not set, it will search in "/usr/local/lib/gomlx/pjrt" and then in the system's standard library directories (on Linux, those in LD_LIBRARY_PATH and in the /etc/ld.so.conf file), in that order.
If there are plugins with the same name but different versions in different directories, it respects the order of the directories given by PJRT_PLUGIN_LIBRARY_PATH or by the system.
See details in pjrt.AvailablePlugins.
func New ¶
New returns a new Backend using the config as a configuration. The config string should be the name of the PJRT plugin to use.
func ShapeFromStableHLO ¶
func ShapeFromStableHLO(shape stablehloshapes.Shape) shapes.Shape
ShapeFromStableHLO converts a StableHLO shape to a GoMLX shape.
func ShapeToStableHLO ¶
func ShapeToStableHLO(shape shapes.Shape) stablehloshapes.Shape
ShapeToStableHLO converts a GoMLX shape to a StableHLO shape.
Types ¶
type Backend ¶
type Backend struct {
DotGeneralConfig
// contains filtered or unexported fields
}
Backend implements the XLA/PJRT backends.Backend for GoMLX.
func NewWithOptions ¶
func NewWithOptions(config string, options pjrt.NamedValuesMap) (*Backend, error)
NewWithOptions creates a StableHLO backend with the given client options. It allows more control, not available with the default New constructor.
func (*Backend) BufferData ¶
BufferData implements backends.Backend interface.
For XLA this means allocating the aligned memory and calling pjrt.Client.CreateViewOfDeviceBuffer to create a buffer that shares the memory.
func (*Backend) BufferDeviceNum ¶
BufferDeviceNum returns the deviceNum for the buffer.
func (*Backend) BufferFinalize ¶
BufferFinalize implements backends.DataInterface.
func (*Backend) BufferFromFlatData ¶
func (backend *Backend) BufferFromFlatData(deviceNum backends.DeviceNum, flat any, shape shapes.Shape) (backends.Buffer, error)
BufferFromFlatData transfers data from Go given as a flat slice (of the type corresponding to the shape DType) to the deviceNum, and returns the corresponding Buffer.
func (*Backend) BufferShape ¶
BufferShape returns the shape for the buffer.
func (*Backend) BufferToFlatData ¶
BufferToFlatData transfers the flat values of buffer to the Go flat array. The slice flat must have the exact number of elements required to store the Buffer shape.
See also BufferFromFlatData, BufferShape, and shapes.Shape.Size.
func (*Backend) Capabilities ¶
func (backend *Backend) Capabilities() backends.Capabilities
Capabilities returns information about what is supported by this backend.
func (*Backend) CheckValid ¶
CheckValid returns an error if the backend is not valid: if it's nil or has already been finalized.
func (*Backend) Description ¶
Description is a longer description of the Backend that can be used to pretty-print.
func (*Backend) Finalize ¶
func (backend *Backend) Finalize()
Finalize releases all the associated resources immediately, and makes the backend invalid.
func (*Backend) HasSharedBuffers ¶
HasSharedBuffers returns whether this PJRT plugin supports "shared buffers". In PJRT that means supporting pjrt.Client.CreateViewOfDeviceBuffer.
func (*Backend) IsFinalized ¶
IsFinalized returns true if the backend is in an invalid state.
func (*Backend) Name ¶
Name returns the short name of the backend. E.g.: "stablehlo" for the StableHLO/PJRT plugin.
func (*Backend) NewSharedBuffer ¶
func (backend *Backend) NewSharedBuffer(deviceNum backends.DeviceNum, shape shapes.Shape) (buffer backends.Buffer, flat any, err error)
NewSharedBuffer implements backends.Backend interface.
For XLA this means allocating the aligned memory and calling pjrt.Client.CreateViewOfDeviceBuffer to create a buffer that shares the memory.
func (*Backend) NumDevices ¶
NumDevices returns the number of devices available for this Backend.
type Builder ¶
type Builder struct {
notimplemented.Builder
// contains filtered or unexported fields
}
Builder keeps track of the computation graph being defined.
func (*Builder) Abs ¶
Abs returns the Op that represents the output of the corresponding operation.
It is special-cased here because StableHLO doesn't define the Abs() of complex numbers.
func (*Builder) Add ¶
Add returns the element-wise sum of the two values. Standard broadcasting rules apply (see documentation).
func (*Builder) ArgMinMax ¶
func (b *Builder) ArgMinMax(x backends.Op, axis int, outputDType dtypes.DType, isMin bool) (backends.Op, error)
ArgMinMax calculates the "argmin" or "argmax" across an axis of the given input array x.
outputDType defines the dtype of the argmin/argmax output; it doesn't need to be the same as the input's dtype. It's a form of reduction on the given axis, and that axis goes away, so the rank of the result is one less than the rank of x.
Examples:
ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=1, isMin=true) -> {1, 0} // (it chooses the 0 and the -3)
ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=0, isMin=false) -> {0, 1, 0} // (it chooses the 2, 4 and 7)
func (*Builder) BatchNormForInference ¶
func (b *Builder) BatchNormForInference(input, scale, offset, mean, variance backends.Op, epsilon float32, featureAxis int) (backends.Op, error)
BatchNormForInference implements backends.Builder interface.
func (*Builder) BatchNormForTraining ¶
func (b *Builder) BatchNormForTraining(input, scale, offset backends.Op, epsilon float32, featureAxis int) (output, batchMean, batchVar backends.Op, err error)
BatchNormForTraining implements backends.Builder interface.
func (*Builder) BatchNormGradient ¶
func (b *Builder) BatchNormGradient(gradOutput, input, scale, mean, variance backends.Op, epsilon float32, featureAxis int) (gradInput, gradScale, gradOffset backends.Op, err error)
BatchNormGradient implements backends.Builder interface.
func (*Builder) BitCount ¶
BitCount returns the number of bits that are set to one. Also known as Population Count ("Popcnt") or Hamming Weight.
func (*Builder) BitwiseAnd ¶
BitwiseAnd returns the element-wise bitwise AND operation.
func (*Builder) BitwiseNot ¶
BitwiseNot returns the element-wise bitwise NOT operation.
func (*Builder) BitwiseXor ¶
BitwiseXor returns the element-wise bitwise XOR operator.
func (*Builder) BroadcastInDim ¶
func (b *Builder) BroadcastInDim(x backends.Op, outputShape shapes.Shape, broadcastAxes []int) (backends.Op, error)
BroadcastInDim broadcasts x to an output with the given shape. broadcastAxes has one output axis for each axis of x (len(broadcastAxes) == x.Shape.Rank()): the i-th axis of x is mapped to the broadcastAxes[i]-th axis of the output. broadcastAxes must also be increasing: this operation cannot be used to transpose axes; it only broadcasts and introduces new axes in-between. It also requires that the dimension of the i-th input axis is either 1 or equal to the dimension of the output axis it is broadcast into. For example, say operand `x = (s32)[2]{1, 2}` and outputShape = `(s32)[2,2]`:
- Specifying []int{1} as broadcastAxes will generate output {{1, 2}, {1, 2}}
- On the other hand, specifying []int{0} as broadcastAxes will generate output {{1, 1}, {2, 2}}
func (*Builder) Ceil ¶
Ceil returns the Op that represents the output of the corresponding operation.
func (*Builder) CheckValid ¶
CheckValid returns an error if the backend or the builder are not ok.
E.g.: they have been finalized or the builder has already been compiled.
func (*Builder) Clamp ¶
Clamp returns the element-wise clamping operation.
The values min and max can either be scalars or have the same shape as the operand being clamped (a).
func (*Builder) Clz ¶
Clz returns, element-wise, the number of leading zero bits ("count leading zeros") of the operand -- for integer values only.
func (*Builder) Complex ¶
Complex returns the complex number taking lhs (x0) as the real part and rhs (x1) as the imaginary part. The real (x0) and imaginary (x1) parts must have the same dtype, and it must be either `dtypes.Float32` or `dtypes.Float64`. The output will be either `dtypes.Complex64` or `dtypes.Complex128`, depending on the dtype of x0 and x1. The shapes of the real and imaginary parts must be the same, or one of them must be a scalar, in which case its value is broadcast to every element of the other.
func (*Builder) Concatenate ¶
Concatenate operands on the given axis.
All axes other than the concatenation axis must have matching dimensions across operands. It doesn't work with scalars -- use ExpandAxes. If there is only one operand, it is returned and this is a no-op.
func (*Builder) Constant ¶
Constant creates a constant in the graph with the given flat values and the shape defined by the dimensions.
The flat value must be a slice of a basic type supported -- that can be converted to a DType.
The value is copied into the graph. For very large tensors, even constant ones, it's recommended to pass them as side inputs (or variables, see the context package) instead.
func (*Builder) ConvGeneral ¶
func (b *Builder) ConvGeneral(input, kernel backends.Op, axes backends.ConvolveAxesConfig, strides []int, paddings [][2]int, inputDilations, kernelDilations []int, channelGroupCount, batchGroupCount int) (backends.Op, error)
ConvGeneral implements the backends.Builder interface.
func (*Builder) ConvertDType ¶
ConvertDType implements backends.Builder interface.
func (*Builder) Div ¶
Div returns the element-wise division of the two values. Standard broadcasting rules apply (see documentation).
func (*Builder) Dot ¶
Dot returns the "dot product" operation. The exact semantics of this operation depend on the ranks of the operands:

| Input | Output | Semantics |
|---|---|---|
| vector [n] dot vector [n] | scalar | vector dot product |
| matrix [m x k] dot vector [k] | vector [m] | matrix-vector multiplication |
| matrix [m x k] dot matrix [k x n] | matrix [m x n] | matrix-matrix multiplication |

The operation performs the sum of products over the second dimension of lhs (x0) -- or the first, if it has rank 1 -- and the first dimension of rhs (x1). These are the "contracted" dimensions. The contracted dimensions of lhs and rhs must be of the same size. In practice, it can be used to perform dot products between vectors, vector/matrix multiplications, or matrix/matrix multiplications.
func (*Builder) DotGeneral ¶
func (b *Builder) DotGeneral(lhs backends.Op, lhsContractingAxes, lhsBatchAxes []int, rhs backends.Op, rhsContractingAxes []int, rhsBatchAxes []int) (backends.Op, error)
DotGeneral takes as input lhs (left-hand-side) and rhs (right-hand-side) specifications for a general vector product -- a generalized "Einsum". Each axis can be:
- Just aligned (batch axes), so the output has the same axes as the inputs. The dimensions must match in lhs and rhs.
- Crossed (default), in which case the output is the combination (concatenation) of the dimensions.
- Contracted (contracting axes), where the values are multiplied and then summed over those axes, which therefore don't appear in the output.
It follows that the output axes start with the batch axes, then the 'lhs' non-contracting/non-batch axes, and finally the 'rhs' non-contracting/non-batch axes. It provides the basic means of implementing Einsum.
func (*Builder) DynamicSlice ¶
func (b *Builder) DynamicSlice(operand backends.Op, startIndices []backends.Op, sliceDims []int) (backends.Op, error)
DynamicSlice extracts a slice from the operand at the startIndices position, with the size given by sliceDims.
- operand: tensor from which to take the slice. - startIndices: scalar tensors, one per axis of operand: len(startIndices) == operand.Rank(). - sliceDims: the slice sizes; static values, fixed to keep the shape of the output static.
The startIndices are adjusted as follows:
adjustedStartIndices[i] = clamp(0, startIndices[i], operand.Dimensions[i] - sliceDims[i])
See description in https://openxla.org/xla/operation_semantics#dynamicslice
func (*Builder) DynamicUpdateSlice ¶
func (b *Builder) DynamicUpdateSlice(operand, update backends.Op, startIndices []backends.Op) (backends.Op, error)
DynamicUpdateSlice updates the operand with the values given in update, at the position given by startIndices.
- operand: original value to be updated. - update: values to "paste" on top of operand, at position startIndices. - startIndices: scalar tensors, one per axis of operand: len(startIndices) == operand.Rank(). The size of the updated region is given by the (static) shape of update.
It returns a value with the same shape as the operand, with the values updated.
The startIndices are adjusted as follows:
adjustedStartIndices[i] = clamp(0, startIndices[i], operand.Dimensions[i] - update.Dimensions[i])
func (*Builder) Equal ¶
Equal returns the element-wise operation. Standard broadcasting rules apply (see documentation).
func (*Builder) EqualTotalOrder ¶
EqualTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation).
The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
func (*Builder) Erf ¶
Erf returns the "error function", defined as $\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2}\,dt$.
func (*Builder) Expm1 ¶
Expm1 returns the Op that represents the output of the corresponding operation.
func (*Builder) FFT ¶
func (b *Builder) FFT(x backends.Op, fftType backends.FFTType, fftLength []int) (backends.Op, error)
FFT implements the Fast Fourier Transform operation. fftType specifies the type of FFT operation to perform. fftLength specifies the length of the transform for each axis.
func (*Builder) Floor ¶
Floor returns the Op that represents the output of the corresponding operation.
func (*Builder) Gather ¶
func (b *Builder) Gather(operand, startIndices backends.Op, indexVectorAxis int, offsetOutputAxes, collapsedSliceAxes, startIndexMap, sliceSizes []int, indicesAreSorted bool) (backends.Op, error)
Gather is a powerful but cumbersome Gather operation. See details in the backend.
Notice GoMLX backend Gather operation doesn't support batching axes, which StableHLO does. For compatibility, we simply leave them empty.
func (*Builder) GreaterOrEqual ¶
GreaterOrEqual returns the element-wise operation. Standard broadcasting rules apply (see documentation).
func (*Builder) GreaterOrEqualTotalOrder ¶
GreaterOrEqualTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation).
The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
func (*Builder) GreaterThan ¶
GreaterThan returns the element-wise operation. Standard broadcasting rules apply (see documentation).
func (*Builder) GreaterThanTotalOrder ¶
GreaterThanTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation).
The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
func (*Builder) Identity ¶
Identity returns an Op whose output is the same as its input. It's a no-op that can serve as a place-holder.
func (*Builder) Imag ¶
Imag returns the imaginary part of a complex number. It returns 0 if the operand is a float number.
func (*Builder) IsFinite ¶
IsFinite tests whether each element of operand is finite, i.e., if it is not positive nor negative infinity, and it is not NaN. It returns the same shape as the input, but with boolean values where each element is true if and only if the corresponding input element is finite.
func (*Builder) LessOrEqual ¶
LessOrEqual returns the element-wise operation. Standard broadcasting rules apply (see documentation).
func (*Builder) LessOrEqualTotalOrder ¶
LessOrEqualTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation).
The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
func (*Builder) LessThan ¶
LessThan returns the element-wise operation. Standard broadcasting rules apply (see documentation).
func (*Builder) LessThanTotalOrder ¶
LessThanTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation).
The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
func (*Builder) LogicalAnd ¶
LogicalAnd returns the element-wise logical AND operation.
func (*Builder) LogicalNot ¶
LogicalNot returns the Op that represents the output of the corresponding operation.
func (*Builder) LogicalXor ¶
LogicalXor returns the element-wise logical XOR operator.
func (*Builder) Logistic ¶
Logistic returns the element-wise expression 1/(1+exp(-x)). Also known as the Sigmoid function.
func (*Builder) Mul ¶
Mul returns the element-wise multiplication of the two values. Standard broadcasting rules apply (see documentation).
func (*Builder) NotEqual ¶
NotEqual returns the element-wise operation. Standard broadcasting rules apply (see documentation).
func (*Builder) NotEqualTotalOrder ¶
NotEqualTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation).
The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
func (*Builder) Pad ¶
func (b *Builder) Pad(x, fillValue backends.Op, axesConfig ...backends.PadAxis) (backends.Op, error)
Pad injects padding on the start, end, or interior (in between each element) of the given operand. There must be at most `operand.Rank()` axesConfig values. Missing PadAxis are assumed to be zeros, that is, no padding for those axes.
func (*Builder) Parameter ¶
Parameter creates an input parameter for the computation.
During the computation's execution, this value must be fed in the same order in which it was created.
func (*Builder) Real ¶
Real returns the real part of a complex number. It returns x unchanged if x is a float.
func (*Builder) ReduceBitwiseAnd ¶
ReduceBitwiseAnd implements the corresponding method of the backends.Builder interface.
func (*Builder) ReduceBitwiseOr ¶
ReduceBitwiseOr implements the corresponding method of the backends.Builder interface.
func (*Builder) ReduceBitwiseXor ¶
ReduceBitwiseXor implements the corresponding method of the backends.Builder interface.
func (*Builder) ReduceLogicalAnd ¶
ReduceLogicalAnd implements the corresponding method of the backends.Builder interface.
func (*Builder) ReduceLogicalOr ¶
ReduceLogicalOr implements the corresponding method of the backends.Builder interface.
func (*Builder) ReduceLogicalXor ¶
ReduceLogicalXor implements the corresponding method of the backends.Builder interface.
func (*Builder) ReduceMax ¶
ReduceMax implements the corresponding method of the backends.Builder interface.
func (*Builder) ReduceMin ¶
ReduceMin implements the corresponding method of the backends.Builder interface.
func (*Builder) ReduceProduct ¶
ReduceProduct implements the corresponding method of the backends.Builder interface.
func (*Builder) ReduceSum ¶
ReduceSum implements the corresponding method of the backends.Builder interface.
func (*Builder) ReduceWindow ¶
func (b *Builder) ReduceWindow(x backends.Op, reductionType backends.ReduceOpType, windowDimensions, strides, baseDilations, windowDilations []int, paddings [][2]int) (backends.Op, error)
ReduceWindow runs a reduction function of the type given by reductionType, which can be ReduceMaxNode, ReduceSumNode, or ReduceMultiplyNode.
The parameter windowDimensions must be set and have a value for each axis. If strides is nil, it's assumed to be the same as windowDimensions -- that is, the strides jump a window at a time. If baseDilations, windowDilations are nil, they are assumed to be 1 (no dilation). If paddings is nil, they are assumed to be 0.
func (*Builder) Rem ¶
Rem returns the remainder operation, also known as modulo (or Mod for short). Notice that, despite the name, XLA implements Mod, not the IEEE 754 remainder operation.
func (*Builder) RngBitGenerator ¶
func (b *Builder) RngBitGenerator(state backends.Op, shape shapes.Shape) (newState backends.Op, values backends.Op, err error)
RngBitGenerator generates the given shape filled with random bits.
It takes as input a state (usually [3]uint64) and returns the updated state and the generated values (with random bits).
Currently, the backend only supports the Philox algorithm. See https://dl.acm.org/doi/10.1145/2063384.2063405
func (*Builder) Round ¶
Round returns the Op that represents the output of the corresponding operation.
func (*Builder) Rsqrt ¶
Rsqrt returns the element-wise reciprocal of square root operation 1/sqrt(x).
func (*Builder) ScatterMax ¶
func (b *Builder) ScatterMax(operand, scatterIndices, updates backends.Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (backends.Op, error)
ScatterMax scatters values from updates, pointed to by scatterIndices, into operand by taking the Max.
func (*Builder) ScatterMin ¶
func (b *Builder) ScatterMin(operand, scatterIndices, updates backends.Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (backends.Op, error)
ScatterMin scatters values from updates, pointed to by scatterIndices, into operand by taking the Min.
func (*Builder) ScatterSum ¶
func (b *Builder) ScatterSum(operand, scatterIndices, updates backends.Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (backends.Op, error)
ScatterSum scatters values from updates, pointed to by scatterIndices, into operand by summing them.
func (*Builder) SelectAndScatterMax ¶
func (b *Builder) SelectAndScatterMax(operand, source backends.Op, windowDimensions, windowStrides []int, paddings [][2]int) (backends.Op, error)
SelectAndScatterMax runs windows (similar to ReduceWindow) over the operand and selects the largest values to update the output (like ScatterSum).
It selects the values in the window such that it works as the reverse of a PoolMax operation.
Note: "Max" refers to the selection. After selected, the values are added into the output position.
See details in https://openxla.org/xla/operation_semantics#selectandscatter
func (*Builder) SelectAndScatterMin ¶
func (b *Builder) SelectAndScatterMin(operand, source backends.Op, windowDimensions, windowStrides []int, paddings [][2]int) (backends.Op, error)
SelectAndScatterMin runs windows (similar to ReduceWindow) over the operand and selects the lowest values to update the output (like ScatterSum).
It selects the values in the window such that it works as the reverse of a PoolMin operation.
Note: "Min" refers to the selection. After selected, values are added into the output position.
See details in https://openxla.org/xla/operation_semantics#selectandscatter
func (*Builder) ShiftLeft ¶
ShiftLeft shifts left by n bits. It implicitly preserves the sign bit if there is no overflow, so ShiftLeft(-1, 1) = -2.
func (*Builder) ShiftRightArithmetic ¶
ShiftRightArithmetic shifts right by n bits, preserving the sign bit. So ShiftRightArithmetic(-2, 1) = -1.
func (*Builder) ShiftRightLogical ¶
ShiftRightLogical shifts right by n bits, filling the vacated high bits with zeros, so the sign bit is not preserved.
func (*Builder) Sign ¶
Sign returns element-wise +1, +/-0 or -1 depending on the sign of x. It returns NaN if the input is NaN.
func (*Builder) Sqrt ¶
Sqrt returns the element-wise square root of the operand.
func (*Builder) Sub ¶
Sub returns the element-wise subtraction of the two values. Standard broadcasting rules apply (see documentation).
func (*Builder) Tanh ¶
Tanh returns the element-wise hyperbolic tangent of the operand.
type DotGeneralConfig ¶
type DotGeneralConfig struct {
// UseTF32 specifies whether to use tf32 (a truncated float32 that NVidia CUDA PJRT is able to use)
// when doing float32 dot general.
UseTF32 bool
}
DotGeneralConfig represents the configuration to use for DotGeneral. StableHLO has lots of options (see github.com/gomlx/stablehlo.DotGeneral), and here is what we expose for now.
type Executable ¶
type Executable struct {
// contains filtered or unexported fields
}
Executable implements backends.Executable for XLA/PJRT (github.com/gomlx/gopjrt).
func (*Executable) CheckValid ¶
func (e *Executable) CheckValid() error
CheckValid returns an error if the backend or the executable are not ok -- e.g.: if they have been finalized or the builder has already been compiled.
func (*Executable) Execute ¶
Execute runs the executable on the default device (0). The number and shapes of the inputs must match those returned by Inputs.
func (*Executable) Finalize ¶
func (e *Executable) Finalize()
Finalize immediately frees resources associated with the executable.
func (*Executable) Inputs ¶
func (e *Executable) Inputs() (names []string, inputShapes []shapes.Shape)
Inputs returns the parameters' names and shapes, in the order they were created by the Builder.Parameter calls.
func (*Executable) Outputs ¶
func (e *Executable) Outputs() (outputShapes []shapes.Shape)
Outputs returns the computation's output shapes, in the order given to the Builder.Compile call.