backends

package
v0.13.0
Published: Oct 7, 2024 License: Apache-2.0 Imports: 8 Imported by: 7

Documentation

Overview

Package backends defines the interface that a computation building and execution system needs to implement to be used by GoMLX.

It is based on OpenXLA's API for now, since it's the only implementation.

A backend that doesn't implement every operation can simply return a "Not implemented" error for any op, and it will still work for computations that don't require those operations.

To simplify error handling, all functions are expected to throw (panic) with a stack trace in case of errors. See package github.com/gomlx/exceptions.

Index

Constants

const GOMLX_BACKEND = "GOMLX_BACKEND"

GOMLX_BACKEND is the environment variable with the default backend configuration to use.

The format of config is "<backend_name>:<backend_configuration>". The "<backend_name>" is the name of a registered backend (e.g.: "xla") and "<backend_configuration>" is backend specific (e.g.: for xla backend, it is the pjrt plugin name).

Variables

var DefaultConfig string

DefaultConfig is the default backend configuration to use, if set.

See NewWithConfig for the format of the configuration string.

Functions

func NotImplemented

func NotImplemented()

NotImplemented panics with a not implemented error, for backends that don't implement all ops. It allows users of the backend to capture the exception and handle it differently.

func Register

func Register(name string, constructor Constructor)

Register backend with the given name, and a default constructor that takes as input a configuration string that is passed along to the backend constructor.

To be safe, call Register during initialization of a package.

Types

type Backend

type Backend interface {
	// Name returns the short name of the backend. E.g.: "xla" for the Xla/PJRT plugin.
	Name() string

	// Description is a longer description of the Backend that can be used to pretty-print.
	Description() string

	// NumDevices returns the number of devices available for this Backend.
	NumDevices() DeviceNum

	// Builder creates a new builder used to define a new named computation.
	Builder(name string) Builder

	// DataInterface is the sub-interface that defines the API to transfer Buffer to/from accelerators for the backend.
	DataInterface

	// Finalize releases all the associated resources immediately, and makes the backend invalid.
	Finalize()
}

Backend is the API that needs to be implemented by a GoMLX backend.

func New

func New() Backend

New returns a new default Backend.

The default is:

1. The environment variable GOMLX_BACKEND is used as the configuration, if defined.
2. Next, the variable DefaultConfig is used as the configuration, if defined.
3. Otherwise, the first registered backend is used with an empty configuration.

It panics if no backend was registered.

func NewWithConfig

func NewWithConfig(config string) Backend

NewWithConfig takes a configuration string formatted as "<backend_name>:<backend_configuration>" and returns a new Backend.

The "<backend_name>" is the name of a registered backend (e.g.: "xla") and "<backend_configuration>" is backend specific (e.g.: for the xla backend, it is the pjrt plugin name).

type Buffer

type Buffer any

Buffer represents the actual data (a tensor) stored in the accelerator that is actually going to execute the graph. It's used as input/output of computation execution. A Buffer is always associated with a DeviceNum, even if there is only one device.

It is opaque from the GoMLX perspective, but the backend's DataInterface methods take this value as input and need to be able to return its shape and device (see DataInterface.BufferShape and DataInterface.BufferDeviceNum).

type Builder

type Builder interface {
	// Compile the computation built. This immediately invalidates the Builder and returns an Executable that
	// can now be used to run the computation.
	//
	// It is given the list of outputs.
	Compile(outputs ...Op) Executable

	// Name of the computation being built.
	Name() string

	// OpShape returns the shape of a computation Op.
	OpShape(op Op) shapes.Shape

	// Parameter creates an input parameter for the computation.
	// During execution of the computation this value will need to be fed, in the same order it is created.
	Parameter(name string, shape shapes.Shape) Op

	// Constant creates a constant in the graph with the given flat values, and the shape defined by dims.
	//
	// flat must be a slice of a basic type supported -- that can be converted to a DType.
	//
	// The value is copied into the graph. It's recommended that for very large tensors,
	// even if constants, that they are passed as side inputNodes (or variables, see context package) instead.
	Constant(flat any, dims ...int) Op

	// Identity returns an Op whose output is the same as its input.
	// It's a no-op that can serve as a place-holder.
	Identity(x Op) Op

	// ReduceWindow runs a reduction function of the type given by reductionType,
	// it can be either ReduceMaxNode, ReduceSumNode or ReduceMultiplyNode.
	//
	// The parameter windowDimensions must be set and have a value for each axis.
	// If strides is nil, it's assumed to be the same as windowDimensions -- that is, the strides jump a window at a time.
	// If baseDilations, windowDilations are nil, they are assumed to be 1 (no dilation).
	// If paddings are nil they are assumed to be 0.
	ReduceWindow(x Op, reductionType ReduceOpType, windowDimensions, strides, baseDilations, windowDilations []int, paddings [][2]int) Op

	// RngBitGenerator generates the given shape filled with random bits.
	// It takes as input the current random number generator (RNG) state, see RngState or RngStateFromSeed.
	// The algorithm is hard-coded to use Philox algorithm for now.
	//
	// It returns the new state of the RNG and the generated values (with random bits) with the given shape.
	RngBitGenerator(state Op, shape shapes.Shape) (newState, values Op)

	// BatchNormForInference implements Batch Norm for inference. See details in
	// https://www.tensorflow.org/xla/operation_semantics#batchnorminference.
	//
	// Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing
	// Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
	BatchNormForInference(operand, scale, offset, mean, variance Op, epsilon float32, axis int) Op

	// BatchNormForTraining implements Batch Norm for training. See details in
	// https://www.tensorflow.org/xla/operation_semantics#batchnormtraining.
	//
	// It returns the normalized tensor, the batchMean and the batchVariance.
	//
	// Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing
	// Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
	BatchNormForTraining(operand, scale, offset Op, epsilon float32, axis int) (normalized, batchMean, batchVariance Op)

	// BatchNormGradient calculates the BatchNorm gradient. See details in
	// https://openxla.org/xla/operation_semantics#batchnormgrad
	//
	// The gradOutput is the adjoint gradient, that is, the gradient with respect to the output of the
	// batch normalization.
	//
	// It returns the 3 gradients: gradOperand, gradScale and gradOffset.
	//
	// Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing
	// Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
	BatchNormGradient(operand, scale, mean, variance, gradOutput Op, epsilon float32, axis int) (gradOperand, gradScale, gradOffset Op)

	// BitCount returns the number of bits that are set to one.
	BitCount(operand Op) Op

	// StandardOps include automatically generated list of operations for the Builder.
	StandardOps
}

Builder is the sub-interface that defines the operations the backend must support to build a computation.

type Constructor

type Constructor func(config string) Backend

Constructor takes a config string (optionally empty) and returns a Backend.

type ConvolveAxesConfig

type ConvolveAxesConfig struct {
	InputBatch, InputChannel int
	InputSpatial             []int

	KernelInputChannel, KernelOutputChannel int
	KernelSpatial                           []int

	OutputBatch, OutputChannel int
	OutputSpatial              []int
}

ConvolveAxesConfig defines the interpretation of the input/kernel/output tensor axes. There must be the same number of spatial dimensions (axes) for each of the 3 tensors. Input and output have batch and channel axes. Kernel has inputChannel and outputChannel axes.

See Builder.ConvGeneralDilated
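As an illustration, a hypothetical channels-last ("NHWC") 2D layout could be configured as follows; the struct is copied locally and the axis choices are an example, not a requirement:

```go
package main

import "fmt"

// Mirrors backends.ConvolveAxesConfig, for illustration.
type ConvolveAxesConfig struct {
	InputBatch, InputChannel int
	InputSpatial             []int

	KernelInputChannel, KernelOutputChannel int
	KernelSpatial                           []int

	OutputBatch, OutputChannel int
	OutputSpatial              []int
}

func main() {
	// Hypothetical channels-last 2D convolution layout:
	// input [batch, height, width, channel], kernel [height, width, in, out].
	cfg := ConvolveAxesConfig{
		InputBatch: 0, InputSpatial: []int{1, 2}, InputChannel: 3,
		KernelSpatial: []int{0, 1}, KernelInputChannel: 2, KernelOutputChannel: 3,
		OutputBatch: 0, OutputSpatial: []int{1, 2}, OutputChannel: 3,
	}
	// All three tensors must have the same number of spatial axes.
	fmt.Println(len(cfg.InputSpatial) == len(cfg.KernelSpatial)) // true
}
```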

func (ConvolveAxesConfig) Clone

Clone returns a deep copy of the structure.

type DataInterface

type DataInterface interface {
	// BufferFinalize allows client to inform backend that buffer is no longer needed and associated resources can be
	// freed immediately.
	BufferFinalize(buffer Buffer)

	// BufferShape returns the shape for the buffer.
	BufferShape(buffer Buffer) shapes.Shape

	// BufferDeviceNum returns the deviceNum for the buffer.
	BufferDeviceNum(buffer Buffer) DeviceNum

	// BufferToFlatData transfers the flat values of buffer to the Go flat array. The slice flat must have
	// the exact number of elements required to store the Buffer shape. See BufferShape, and shapes.Shape.Size.
	BufferToFlatData(buffer Buffer, flat any)

	// BufferFromFlatData transfers data from Go given as a flat slice (of the type corresponding to the shape DType)
	// to the deviceNum, and returns the corresponding Buffer.
	BufferFromFlatData(deviceNum DeviceNum, flat any, shape shapes.Shape) Buffer
}

DataInterface is the sub-interface that defines the API to transfer Buffers to/from accelerators for the backend.

type DeviceNum

type DeviceNum int

DeviceNum represents which device holds a buffer, or should execute a computation. It's up to the backend to interpret it, but it should be between 0 and Backend.NumDevices.

type Executable

type Executable interface {
	// Finalize immediately frees resources associated to the executable.
	Finalize()

	// Inputs returns the list of parameters names and shapes, in order created by the Builder.Parameter calls.
	Inputs() (names []string, inputShapes []shapes.Shape)

	// Outputs returns the list of the shapes of the outputs of the computation, in order given to the Builder.Compile call.
	Outputs() (outputShapes []shapes.Shape)

	// Execute the executable on the default device (0).
	// The number and shapes of the inputs must match those returned by Inputs.
	//
	// The inputs marked in donate will become invalid after use.
	// This is useful if the input buffer is no longer needed or if updating a variable
	// so its Buffer space can be reused as an output Buffer.
	//
	// Donated buffers are no longer valid after the call.
	// If donate is nil, it is assumed to be false for all buffers, and no buffer is donated.
	Execute(inputs []Buffer, donate []bool) []Buffer
}

Executable is the API for compiled programs ready to execute.

type FFTType

type FFTType int

const (
	// FFTForward - complex in, complex out.
	FFTForward FFTType = iota

	// FFTInverse - complex in, complex out.
	FFTInverse

	// FFTForwardReal - real in, fft_length / 2 + 1 complex out
	FFTForwardReal

	// FFTInverseReal - fft_length / 2 + 1 complex in
	FFTInverseReal
)
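Per the comments above, the real-input variants produce (or consume) fft_length/2 + 1 complex values. A tiny sketch of that size computation:

```go
package main

import "fmt"

// realFFTOutputLen returns the number of complex outputs of
// FFTForwardReal for a given fftLength: fft_length/2 + 1.
func realFFTOutputLen(fftLength int) int {
	return fftLength/2 + 1
}

func main() {
	fmt.Println(realFFTOutputLen(8)) // 5
}
```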

func (FFTType) String

func (i FFTType) String() string

type Op

type Op any

Op represents the output of an operation during computation graph building time. It is opaque from the GoMLX perspective, but the backend methods take this value as input, and the backend needs to be able to implement Builder.OpShape to return its shape.

type PadAxis

type PadAxis struct {
	Start, End, Interior int
}

PadAxis defines the amount of padding preceding one axis (Start), at the end of the axis (End), or in between the elements (Interior). This is used as a parameter for the Pad operation.

type ReduceOpType

type ReduceOpType int

ReduceOpType selects among the basic types of reduction supported, see XlaBuilder.ReduceComputation.

const (
	// ReduceOpUndefined is an undefined value.
	ReduceOpUndefined ReduceOpType = iota

	// ReduceOpSum reduces by summing all elements being reduced.
	ReduceOpSum

	// ReduceOpProduct reduces by multiplying all elements being reduced.
	ReduceOpProduct

	// ReduceOpMax reduces by taking the maximum value.
	ReduceOpMax

	// ReduceOpMin reduces by taking the minimum value.
	ReduceOpMin
)

func (ReduceOpType) String

func (i ReduceOpType) String() string

type StandardOps

type StandardOps interface {
	// Abs returns the Op that represents the output of the corresponding operation.
	Abs(x Op) Op

	// Add returns the element-wise sum of the two values.
	// Standard broadcasting rules apply (see documentation).
	// The op is created on the same XlaBuilder as used for x0 and x1.
	Add(x0, x1 Op) Op

	// And returns the element-wise logic "and" operator.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	And(x0, x1 Op) Op

	// ArgMinMax calculates the "argmin" or "argmax" across an axis of the given input array x.
	// outputDType defines the output of the argmin/argmax, it doesn't need to be the same as the input.
	// It's a form of reduction on the given axis, and that axis goes away. So the rank of the result is one less than
	// the rank of x.
	// Examples:
	// 	ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=1, isMin=true) -> {1, 0}  // (it chooses the 0 and the -3)
	// 	ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=0, isMin=false) -> {0, 1, 0} // (it chooses the 2, 4 and 7)
	ArgMinMax(x Op, axis int, outputDType dtypes.DType, isMin bool) Op

	// Broadcast prefixes dimensions to an array by duplicating the data in the array.
	// See BroadcastInDim for a broadcast in between the axes.
	// The new dimensions dims are inserted on the left, i.e., if
	// prefixDims has values `{a0, ..., aN}` and the operand shape
	// has dimensions {b0, ..., bM} then the shape of the output has
	// dimensions {a0, ..., aN, b0, ..., bM}.
	// The new dimensions index into copies of the operand, i.e.
	// 	output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
	Broadcast(x Op, prefixDims ...int) Op

	// BroadcastInDim broadcasts x to an output with the given shape.
	// broadcastAxes has an output axes value for each x axes (len(broadcastAxes) == x.Shape.Rank()).
	// The i-th axis of x is mapped to the broadcastAxes[i]-th dimension of the output.
	// broadcastAxes must be also increasing: this operation cannot be used to transpose axes, it will only
	// broadcast and introduce new axes in-between.
	// This also requires that the i-th input axis is either 1 or is the same as the
	// output dimension it's broadcasting into.
	// For example, say operand `x = (s32)[2]{1, 2}`; outputShape = `(s32)[2,2]`:
	//   - Specifying []int{1} as broadcastAxes will generate output
	//     {{1, 2},
	//     {1, 2}}
	//   - On the other hand, specifying []int{0} as broadcastAxes
	//     will generate output
	//     {{1 , 1},
	//     {2 , 2}}
	BroadcastInDim(x Op, outputShape shapes.Shape, broadcastAxes []int) Op

	// Ceil returns the Op that represents the output of the corresponding operation.
	Ceil(x Op) Op

	// Clz returns element-wise the "count leading zeros" bits of input node x -- for integer values.
	Clz(x Op) Op

	// Complex returns the complex number taking x0 as the real part and x1 as the imaginary part.
	// The real (x0) and imaginary (x1) must have the same dtype, and they must be either `dtypes.Float32` or
	// `dtypes.Float64`.
	// The output will be either `dtypes.Complex64` or `dtypes.Complex128`, depending on x0 and x1 dtypes.
	// The shapes of `real` or `imaginary` must be the same, or one must be a scalar, in which case
	// the value is broadcast to every other value.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	Complex(x0, x1 Op) Op

	// Concatenate concatenates the operands along the given axis.
	// All axes that are not being concatenated must match dimensions.
	// It doesn't work with scalars -- use ExpandDims.
	// If there is only one operand, it is returned and this is a no-op.
	Concatenate(axis int, operands ...Op) Op

	// Conj returns the conjugate of a complex number. E.g: Conj(1+3i) = 1-3i
	Conj(x Op) Op

	// ConvGeneralDilated is a generic Convolution operation offered by XLA.
	// featureAxisAfter defines whether the features (aka. channels or depth) axis comes after the
	// spatial dimension. Example: a 2D input can be one of the two:
	//   - featureAxisAfter=false: input=[batch_size, features, height, width], filter=[output_features, input_features, height, width]
	//   - featureAxisAfter=true:  input=[batch_size, height, width, features], filter=[output_features, height, width, input_features]
	// Some details in https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution.
	// There, operand and filter are called lhs and rhs.
	// (XLA documentation is unfortunately poor, much is guess-work).
	// Also useful, https://arxiv.org/pdf/1603.07285v1.pdf.
	ConvGeneralDilated(operand, filter Op, axes ConvolveAxesConfig, strides []int, paddings [][2]int, inputDilation, filterDilation []int, filterGroupCount, batchGroupCount int) Op

	// ConvertDType of x to dtype.
	ConvertDType(x Op, dtype dtypes.DType) Op

	// Cos returns the Op that represents the output of the corresponding operation.
	Cos(x Op) Op

	// Div returns the element-wise division of the two values.
	// Standard broadcasting rules apply (see documentation).
	// The op is created on the same XlaBuilder as used for x0 and x1.
	Div(x0, x1 Op) Op

	// Dot returns the "dot product" operation.
	// The exact semantics of this operation depend on the ranks of the operands:
	// | Input | Output | Semantics |
	// | vector [n] dot vector [n] | scalar | vector dot product |
	// | matrix [m x k] dot vector [k] | vector [m] | matrix-vector multiplication |
	// | matrix [m x k] dot matrix [k x n] | matrix [m x n] | matrix-matrix multiplication |
	// The operation performs sum of products over the second dimension of x0 (or the first if it has rank 1) and
	// the first dimension of x1.
	// These are the "contracted" dimensions.
	// The contracted dimensions of x0 and x1 must be of the same size.
	// In practice, it can be used to perform dot products between vectors, vector/matrix multiplications or
	// matrix/matrix multiplications.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	Dot(x0, x1 Op) Op

	// DotGeneral takes as input lhs (left-hand-side) and rhs (right-hand-side) specifications
	// for a general vector product -- a generalized "Einsum". Each axis can be:
	//   - Just aligned (batch axes), so the output has the same axes as the inputs. The dimensions
	//     must match in lhs and rhs.
	//   - Crossed (default), in which case the output is the combination (concatenation) of the
	//     dimensions.
	//   - Contracted (contracting axes), where the values are multiplied and sum-reduced over
	//     those dimensions, which disappear from the output.
	// It follows that the resulting dimension number starts with the batch dimension, then the 'lhs'
	// non-contracting/non-batch dimension, and finally the 'rhs' non-contracting/non-batch dimension.
	// It provides the basic means of implementing Einsum.
	DotGeneral(lhs Op, lhsContractingAxes, lhsBatchAxes []int, rhs Op, rhsContractingAxes, rhsBatchAxes []int) Op

	// DynamicSlice extracts a sub-array from the input array at dynamic start_indices.
	// The size of the slice in each axis is passed in sliceDims, which specify the slice
	// intervals for each axis: [start, start + size).
	// The shape of startIndices must be rank == 1, with dimension size equal to the rank of operand.
	// See description in https://openxla.org/xla/operation_semantics#dynamicslice
	DynamicSlice(operand Op, startIndices []Op, sliceDims []int) Op

	// DynamicUpdateSlice generates a result which is the value of the input array operand, with a slice update overwritten
	// at startIndices.
	// The shape of update determines the shape of the sub-array of the result which is updated.
	// The shape of startIndices must be rank == 1, with dimension size equal to the rank of operand.
	// See description in https://openxla.org/xla/operation_semantics#dynamicupdateslice
	DynamicUpdateSlice(operand, update Op, startIndices []Op) Op

	// Equal performs element-wise equality check, returns boolean results with the same dimensions as input.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	Equal(x0, x1 Op) Op

	// EqualTotalOrder returns the element-wise operation.
	// Standard broadcasting rules apply (see documentation).
	// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	EqualTotalOrder(x0, x1 Op) Op

	// Erf returns the "error function", defined as erf(x) = 2/Pi * \int_{0}^{x}{e^{-t^2}dt}.
	Erf(x Op) Op

	// Exp returns the Op that represents the output of the corresponding operation.
	Exp(x Op) Op

	// Expm1 returns the Op that represents the output of the corresponding operation.
	Expm1(x Op) Op

	// FFT calls the XLA FFT operation, which implements {Forward, Inverse} x {Complex, Real} versions.
	// See documentation in https://www.tensorflow.org/xla/operation_semantics.
	// Underlying, CPU FFT is backed by Eigen's TensorFFT and GPU FFT uses cuFFT.
	FFT(operand Op, fftType FFTType, fftLength []int) Op

	// Floor returns the Op that represents the output of the corresponding operation.
	Floor(x Op) Op

	// Gather is a powerful but cumbersome Gather operation offered by XLA.
	// Full details in https://www.tensorflow.org/xla/operation_semantics#gather.
	// (Warning: it's poorly described, with many undefined terms)
	// Arguments:
	//   - startIndices: are the indices we want to gather. There will be one axis which enumerates the indices
	//     in the operand array, typically the last one. All other axes are "batch dimensions" and they will have
	//     equivalent axes in the output.
	//   - indexVectorAxis: typically the last axis of startIndices, so startIndices.Shape.Rank()-1.
	//     Usually, one has the dimension of the indexVectorAxis equal to the full rank of the operand.
	//     That is: startIndices.Shape.Dimensions[indexVectorAxis] = operand.Shape.Rank()
	//     Let's call "one index vector" a value of startIndices formed by a slice across indexVectorAxis.
	//   - startIndexMap: for each "index vector" from startIndices, this maps each element of the vector to an
	//     axis of the operand. Typically, this is [0, 1, 2, ..., operand.Shape.Rank()-1], that is, each
	//     "index vector" fully defines an element on the operand. If one is gathering slices of the operand (as
	//     opposed to individual values), one can skip some of those axes from startIndexMap, and the index for those
	//     axes is considered 0, and set sliceSizes to take the slice one wants (typically the full slice).
	//   - sliceSizes: the "index vector" described above points to the data in the operand to be gathered. Then sliceSizes
	//     indicates how much data to gather. One value per axis of the operand must be set. For gathering individual
	//     values, set these all to 1.
	//   - collapsedSliceAxes: the slice gathered for each "index vector" (with sizes sliceSizes), often has dimension one
	//     for most (or all, in case of gathering individual items) axes. collapsedSliceAxes allows one to collapse those
	//     axes, so they don't show up in the output. Usually, collapse all axes that are size one.
	//     These are axes within the rank of operand (from 0 to operand.Shape.Rank()-1).
	//   - offsetAxes: for those gathered slices not collapsed (with collapsedSliceAxes), this maps them to a position in
	//     the output array. Typically, these will be consecutive numbers starting with indexVectorAxis. So, the output
	//     will have the same prefix shape (the "batch dimensions") as the startIndices array, and the suffix shape will
	//     be the gathered slices mapped to these `offsetAxes`. There must be one value per axis not collapsed with
	//     collapsedSliceAxes -- the value itself is an axis in the output shape.
	Gather(operand, startIndices Op, indexVectorAxis int, offsetAxes, collapsedSliceAxes, startIndexMap, sliceSizes []int, indicesAreSorted bool) Op

	// GreaterOrEqual performs element-wise comparison, returns boolean results with the same dimensions as input.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	GreaterOrEqual(x0, x1 Op) Op

	// GreaterOrEqualTotalOrder returns the element-wise operation.
	// Standard broadcasting rules apply (see documentation).
	// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	GreaterOrEqualTotalOrder(x0, x1 Op) Op

	// GreaterThan performs element-wise comparison, returns boolean results with the same dimensions as input.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	GreaterThan(x0, x1 Op) Op

	// GreaterThanTotalOrder returns the element-wise operation.
	// Standard broadcasting rules apply (see documentation).
	// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	GreaterThanTotalOrder(x0, x1 Op) Op

	// Imag returns the imaginary part of a complex number. It returns 0 if x is a float.
	Imag(x Op) Op

	// Iota creates a constant of the given shape with increasing numbers (starting from 0)
	// on the given axis. So Iota([2,2], 1) returns [[0 1][0 1]], while Iota([2,2], 0)
	// returns [[0 0][1 1]].
	Iota(shape shapes.Shape, iotaAxis int) Op

	// IsFinite tests whether each element of operand is finite, i.e., is not positive or negative infinity, and is not NaN.
	// It returns an array of boolean values with the same shape as the input, where each element is true if and only if
	// the corresponding input element is finite.
	IsFinite(x Op) Op

	// LessOrEqual performs element-wise comparison, returns boolean results with the same dimensions as input.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	LessOrEqual(x0, x1 Op) Op

	// LessOrEqualTotalOrder returns the element-wise operation.
	// Standard broadcasting rules apply (see documentation).
	// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	LessOrEqualTotalOrder(x0, x1 Op) Op

	// LessThan performs element-wise comparison, returns boolean results with the same dimensions as input.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	LessThan(x0, x1 Op) Op

	// LessThanTotalOrder returns the element-wise operation.
	// Standard broadcasting rules apply (see documentation).
	// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	LessThanTotalOrder(x0, x1 Op) Op

	// Log returns the Op that represents the output of the corresponding operation.
	Log(x Op) Op

	// Log1p returns the expression log(x+1).
	Log1p(x Op) Op

	// LogicalNot returns the Op that represents the output of the corresponding operation.
	LogicalNot(x Op) Op

	// Logistic returns the element-wise expression 1/(1+exp(-x)). Also known as the Sigmoid function.
	Logistic(x Op) Op

	// Max returns the element-wise highest value among the two.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	Max(x0, x1 Op) Op

	// Min returns the element-wise smallest value among the two.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	Min(x0, x1 Op) Op

	// Mul returns the element-wise multiplication of the two values.
	// Standard broadcasting rules apply (see documentation).
	// The op is created on the same XlaBuilder as used for x0 and x1.
	Mul(x0, x1 Op) Op

	// Neg returns the Op that represents the output of the corresponding operation.
	Neg(x Op) Op

	// NotEqual performs element-wise inequality check, returns boolean results with the same dimensions as input.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	NotEqual(x0, x1 Op) Op

	// NotEqualTotalOrder returns the element-wise operation.
	// Standard broadcasting rules apply (see documentation).
	// The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	NotEqualTotalOrder(x0, x1 Op) Op

	// Or returns the element-wise logic "or" operator.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	Or(x0, x1 Op) Op

	// Pad injects padding on the start, end or interior (in between each element) of the given operand.
	// There must be at most `operand.Rank()` axesConfig values. Missing PadAxis are assumed to be zeros,
	// that is, no padding for those axes.
	Pad(x, fillValue Op, axesConfig ...PadAxis) Op

	// Pow returns the Op that represents the output of the corresponding operation.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	Pow(x0, x1 Op) Op

	// Real returns the real part of a complex number. It returns x if x is a float.
	Real(x Op) Op

	// ReduceAnd is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the logical-and of the reduced axes.
	// It only works for booleans.
	// If no axes are given, it reduces the full array.
	ReduceAnd(x Op, axes ...int) Op

	// ReduceMax is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the max value.
	// If no axes are given, it reduces the full array.
	ReduceMax(x Op, axes ...int) Op

	// ReduceMin is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the min value.
	// If no axes are given, it reduces the full array.
	ReduceMin(x Op, axes ...int) Op

	// ReduceOr is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the logical-or of the reduced axes.
	// It only works for booleans.
	// If no axes are given, it reduces the full array.
	ReduceOr(x Op, axes ...int) Op

	// ReduceProduct is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the product of the reduced axes.
	// If no axes are given, it reduces the full array.
	ReduceProduct(x Op, axes ...int) Op

	// ReduceSum is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the sum of the reduced axes.
	// If no axes are given, it reduces the full array.
	ReduceSum(x Op, axes ...int) Op

	// Rem returns the remainder operation, also known as modulo (or Mod for short).
	// Notice that despite the name, XLA implements Mod, not the IEEE754 Remainder operation.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	Rem(x0, x1 Op) Op

	// Reshape reshapes x to the new dimensions.
	// Total size cannot change, it's just a "reinterpretation" of the same flat data.
	// The dtype remains the same, see ConvertDType to actually convert the values.
	Reshape(x Op, dimensions ...int) Op

	// Reverse returns x with the values for the given dimensions reversed, that is,
	// the value indexed at `i` will be swapped with the value indexed at `(dimension_size - 1 - i)`.
	// The shape remains the same.
	Reverse(x Op, axes ...int) Op

	// Round returns the Op that represents the output of the corresponding operation.
	Round(x Op) Op

	// Rsqrt returns the element-wise reciprocal of square root operation 1/sqrt(x).
	Rsqrt(x Op) Op

	// ScatterAdd adds values from updates, pointed to by scatterIndices, into operand.
	ScatterAdd(operand, scatterIndices, updates Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) Op

	// ScatterMax scatters values from updates pointed to by scatterIndices into operand, by taking the Max.
	ScatterMax(operand, scatterIndices, updates Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) Op

	// ScatterMin scatters values from updates pointed to by scatterIndices into operand, by taking the Min.
	ScatterMin(operand, scatterIndices, updates Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) Op

	// SelectAndScatterMax runs windows (similar to ReduceWindow) over the operand and selects values to update the output (like ScatterAdd).
	// It selects the values in each window such that it works as the reverse of a PoolMax operation.
	// See details in https://openxla.org/xla/operation_semantics#selectandscatter
	SelectAndScatterMax(operand, source Op, windowDimensions, windowStrides []int, paddings [][2]int) Op

	// SelectAndScatterMin runs windows (similar to ReduceWindow) over the operand and selects values to update the output (like ScatterAdd).
	// It selects the values in each window such that it works as the reverse of a PoolMin operation.
	// See details in https://openxla.org/xla/operation_semantics#selectandscatter
	SelectAndScatterMin(operand, source Op, windowDimensions, windowStrides []int, paddings [][2]int) Op

	// SelectAndScatterSum runs windows (similar to ReduceWindow) over the operand and selects values to update the output (like ScatterAdd).
	// It selects the values in each window such that it works as the reverse of a PoolSum operation.
	// See details in https://openxla.org/xla/operation_semantics#selectandscatter
	SelectAndScatterSum(operand, source Op, windowDimensions, windowStrides []int, paddings [][2]int) Op

	// Sign returns element-wise +1, +/-0 or -1 depending on the sign of x. It returns NaN if the input is NaN.
	Sign(x Op) Op

	// Sin returns the element-wise sine of x.
	Sin(x Op) Op

	// Slice extracts a sub-array from the input array.
	// The sub-array is of the same rank as the input and contains the values inside a bounding box within the input array
	// where the dimensions and indices of the bounding box are given as arguments to the slice operation.
	// The strides set the input stride of the slice in each axis and must be >= 1.
	// They are optional; if missing, a stride of 1 is assumed for every axis.
	// Examples:
	// 	Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={4}, strides=nil) -> {2, 3}
	// 	Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={5}, strides={2}) -> {2, 4}
	Slice(x Op, starts, limits, strides []int) Op

	// Sqrt returns the element-wise square root of x.
	Sqrt(x Op) Op

	// Sub returns the element-wise subtraction of the two values.
	// Standard broadcasting rules apply (see documentation).
	// The op is created on the same XlaBuilder as used for x0 and x1.
	Sub(x0, x1 Op) Op

	// Tanh returns the element-wise hyperbolic tangent of x.
	Tanh(x Op) Op

	// Transpose axes of x.
	// There should be one value in permutations for each axis in x.
	// The output will have: output.Shape.Dimension[i] = x.Shape.Dimension[permutations[i]].
	Transpose(x Op, permutations ...int) Op

	// Where takes element-wise values from onTrue or onFalse depending on the value of condition (expected to be boolean).
	Where(condition, onTrue, onFalse Op) Op

	// Xor returns the element-wise logical "xor" operation.
	// The op is created on the same XlaBuilder as used for x0 and x1.
	Xor(x0, x1 Op) Op
}

Directories

Path Synopsis
Package xla implements the XLA/PJRT (https://openxla.org/) based backend for GoMLX.
