xla

package
v0.13.0
Warning

This package is not in the latest version of its module.

Published: Oct 7, 2024 License: Apache-2.0 Imports: 17 Imported by: 1

Documentation

Overview

Package xla implements the XLA/PJRT (https://openxla.org/) based backend for GoMLX.

Simply import it with import _ "github.com/gomlx/gomlx/backends/xla" to make it available in your program. It will register itself as an available backend during initialization.

Constants

const BackendName = "xla"

Variables

var (
	// DefaultPlugins is the list of plugins to use in preference order, if not otherwise specified.
	DefaultPlugins = []string{"cuda", "cpu"}
)

Functions

func GetAvailablePlugins

func GetAvailablePlugins() []string

GetAvailablePlugins lists the available platforms -- it caches and reuses the result in future calls.

Plugins are searched in the PJRT_PLUGIN_LIBRARY_PATH directory -- or directories, if it is a ":" separated list. If it is not set, it will search in "/usr/local/lib/gomlx/pjrt" and the standard library directories of the system (on Linux, LD_LIBRARY_PATH and the /etc/ld.so.conf file), in that order.

If there are plugins with the same name but different versions in different directories, it respects the order of the directories given by PJRT_PLUGIN_LIBRARY_PATH or by the system.

See details in pjrt.AvailablePlugins.
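
As a rough illustration of the search order described above, the following plain-Go sketch splits a ":" separated value and falls back to the default directory. The helper pluginSearchDirs is hypothetical and not part of this package, and it leaves out the system library directories that the real search also visits.

```go
package main

import (
	"fmt"
	"strings"
)

// defaultPluginDir is the fallback directory named in the documentation.
const defaultPluginDir = "/usr/local/lib/gomlx/pjrt"

// pluginSearchDirs is a hypothetical helper (not part of this package).
// It splits the value of PJRT_PLUGIN_LIBRARY_PATH on ":" and keeps the order.
// If the value is empty it returns only the default directory; the system
// library directories are intentionally omitted from this sketch.
func pluginSearchDirs(envValue string) []string {
	if envValue == "" {
		return []string{defaultPluginDir}
	}
	var dirs []string
	for _, dir := range strings.Split(envValue, ":") {
		if dir != "" {
			dirs = append(dirs, dir)
		}
	}
	return dirs
}

func main() {
	fmt.Println(pluginSearchDirs("/opt/pjrt:/home/me/pjrt"))
	fmt.Println(pluginSearchDirs(""))
}
```

Because the order of the returned directories is preserved, a plugin found in an earlier directory would take precedence over one with the same name in a later directory.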

func New

func New(pluginName string) backends.Backend

New returns a new Backend configured by pluginName, which should be the name of the PJRT plugin to use.

Types

type Backend

type Backend struct {
	// contains filtered or unexported fields
}

Backend implements the XLA/PJRT backends.Backend for GoMLX.

func NewWithOptions

func NewWithOptions(pluginName string, options pjrt.NamedValuesMap) *Backend

NewWithOptions creates a Backend with the given client options. It allows more control than the default New constructor.

func (*Backend) AssertValid

func (backend *Backend) AssertValid()

AssertValid will panic if the backend is not valid: if it's nil or has already been finalized.

func (*Backend) BufferDeviceNum

func (backend *Backend) BufferDeviceNum(buffer backends.Buffer) backends.DeviceNum

BufferDeviceNum returns the deviceNum for the buffer.

func (*Backend) BufferFinalize

func (backend *Backend) BufferFinalize(buffer backends.Buffer)

BufferFinalize allows the client to inform the backend that the buffer is no longer needed and that its associated resources can be freed immediately.

func (*Backend) BufferFromFlatData

func (backend *Backend) BufferFromFlatData(deviceNum backends.DeviceNum, flat any, shape shapes.Shape) backends.Buffer

BufferFromFlatData transfers data from Go given as a flat slice (of the type corresponding to the shape DType) to the deviceNum, and returns the corresponding Buffer.

func (*Backend) BufferShape

func (backend *Backend) BufferShape(buffer backends.Buffer) shapes.Shape

BufferShape returns the shape for the buffer.

func (*Backend) BufferToFlatData

func (backend *Backend) BufferToFlatData(buffer backends.Buffer, flat any)

BufferToFlatData transfers the flat values of buffer to the Go flat array. The slice flat must have the exact number of elements required to store the Buffer shape. See BufferShape, and shapes.Shape.Size.
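
The size requirement can be pictured with a small plain-Go sketch. The helper flatSize is a hypothetical stand-in for shapes.Shape.Size and simply multiplies the dimensions; it does not touch the backend.

```go
package main

import "fmt"

// flatSize returns the number of elements of a buffer with the given
// dimensions: the product of all dimensions (1 for a scalar).
// It is a hypothetical stand-in for shapes.Shape.Size.
func flatSize(dims ...int) int {
	size := 1
	for _, d := range dims {
		size *= d
	}
	return size
}

func main() {
	dims := []int{2, 3, 4}
	// Allocate a destination slice with exactly the required length.
	flat := make([]float32, flatSize(dims...))
	fmt.Println(len(flat)) // prints 24
}
```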

func (*Backend) Builder

func (backend *Backend) Builder(name string) backends.Builder

Builder creates a new builder used to define a new computation.

func (*Backend) Description

func (backend *Backend) Description() string

Description is a longer description of the Backend that can be used to pretty-print.

func (*Backend) Finalize

func (backend *Backend) Finalize()

Finalize releases all the associated resources immediately, and makes the backend invalid.

func (*Backend) Name

func (backend *Backend) Name() string

Name returns the short name of the backend, e.g. "xla" for the XLA/PJRT plugin.

func (*Backend) NumDevices

func (backend *Backend) NumDevices() backends.DeviceNum

NumDevices returns the number of devices available for this Backend.

func (*Backend) SupressLogging

func (backend *Backend) SupressLogging(supressLogging bool) *Backend

SupressLogging suppresses logging during compilation of a graph.

type Builder

type Builder struct {
	// contains filtered or unexported fields
}

Builder implements the backends.Builder interface using github.com/gomlx/gopjrt/xlabuilder

func (*Builder) Abs

func (b *Builder) Abs(x backends.Op) backends.Op

Abs returns the Op that represents the output of the corresponding operation.

func (*Builder) Add

func (b *Builder) Add(x0, x1 backends.Op) backends.Op

Add returns the element-wise sum of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) And

func (b *Builder) And(x0, x1 backends.Op) backends.Op

And returns the element-wise logic "and" operator. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) ArgMinMax

func (b *Builder) ArgMinMax(x backends.Op, axis int, outputDType dtypes.DType, isMin bool) backends.Op

ArgMinMax calculates the "argmin" or "argmax" across an axis of the given input array x. outputDType defines the dtype of the argmin/argmax output; it doesn't need to be the same as the input dtype. It's a form of reduction on the given axis, and that axis goes away, so the rank of the result is one less than the rank of x. Examples:

ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=1, isMin=true) -> {1, 0}  // (it chooses the 0 and the -3)
ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=0, isMin=false) -> {0, 1, 0} // (it chooses the 2, 4 and 7)
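
The two examples can be reproduced with a plain-Go sketch for rank-2 inputs. The function argMinMax2D is hypothetical; it only illustrates the semantics and does not use the backend.

```go
package main

import "fmt"

// argMinMax2D mimics ArgMinMax for a rank-2 input stored as rows.
// It is a plain-Go illustration, not the backend implementation.
func argMinMax2D(x [][]float64, axis int, isMin bool) []int {
	rows, cols := len(x), len(x[0])
	better := func(a, b float64) bool {
		if isMin {
			return a < b
		}
		return a > b
	}
	if axis == 1 {
		// Reduce across columns: one index per row.
		out := make([]int, rows)
		for r := 0; r < rows; r++ {
			for c := 1; c < cols; c++ {
				if better(x[r][c], x[r][out[r]]) {
					out[r] = c
				}
			}
		}
		return out
	}
	// Reduce across rows: one index per column.
	out := make([]int, cols)
	for c := 0; c < cols; c++ {
		for r := 1; r < rows; r++ {
			if better(x[r][c], x[out[c]][c]) {
				out[c] = r
			}
		}
	}
	return out
}

func main() {
	x := [][]float64{{2, 0, 7}, {-3, 4, 2}}
	fmt.Println(argMinMax2D(x, 1, true))  // prints [1 0]
	fmt.Println(argMinMax2D(x, 0, false)) // prints [0 1 0]
}
```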

func (*Builder) AssertValid

func (b *Builder) AssertValid()

AssertValid panics if the backend or the builder are not ok -- e.g.: if they have been finalized or the builder has already been compiled.

func (*Builder) BatchNormForInference

func (b *Builder) BatchNormForInference(operand, scale, offset, mean, variance backends.Op, epsilon float32, axis int) backends.Op

BatchNormForInference implements Batch Norm for inference. See details in https://www.tensorflow.org/xla/operation_semantics#batchnorminference.

Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.

func (*Builder) BatchNormForTraining

func (b *Builder) BatchNormForTraining(operand, scale, offset backends.Op, epsilon float32, axis int) (normalized, batchMean, batchVariance backends.Op)

BatchNormForTraining implements Batch Norm for training. See details in https://www.tensorflow.org/xla/operation_semantics#batchnormtraining.

It returns the normalized tensor, the batchMean and the batchVariance.

Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.

func (*Builder) BatchNormGradient

func (b *Builder) BatchNormGradient(operand, scale, mean, variance, gradOutput backends.Op, epsilon float32, axis int) (gradOperand, gradScale, gradOffset backends.Op)

BatchNormGradient calculates the BatchNorm gradient. See details in https://openxla.org/xla/operation_semantics#batchnormgrad

The gradOutput is the adjoint gradient, that is, the gradient with respect to the output of the batch normalization.

It returns a tuple with the three elements: gradOperand, gradScale and gradOffset.

Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.

func (*Builder) BitCount added in v0.13.0

func (b *Builder) BitCount(x backends.Op) backends.Op

BitCount returns the number of bits that are set to one.

func (*Builder) Broadcast

func (b *Builder) Broadcast(x backends.Op, prefixDims ...int) backends.Op

Broadcast prefixes dimensions to an array by duplicating the data in the array. See BroadcastInDim for a broadcast in between the axes. The new dimensions are inserted on the left, i.e., if prefixDims has values `{a0, ..., aN}` and the operand shape has dimensions {b0, ..., bM} then the shape of the output has dimensions {a0, ..., aN, b0, ..., bM}. The new dimensions index into copies of the operand, i.e.

output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
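
A minimal plain-Go sketch of this rule, assuming row-major flat storage purely for illustration: the operand data is copied once per combination of prefix indices. The helper broadcastPrefix is hypothetical and not part of this package.

```go
package main

import "fmt"

// broadcastPrefix returns the flat (row-major) data and the dimensions of
// a prefix broadcast: the operand is copied once per prefix index.
// It is a plain-Go illustration, not the backend implementation.
func broadcastPrefix(flat []int, dims []int, prefixDims ...int) ([]int, []int) {
	copies := 1
	for _, d := range prefixDims {
		copies *= d
	}
	out := make([]int, 0, copies*len(flat))
	for i := 0; i < copies; i++ {
		out = append(out, flat...)
	}
	// Output dimensions are the prefix dimensions followed by the operand's.
	outDims := append(append([]int{}, prefixDims...), dims...)
	return out, outDims
}

func main() {
	out, outDims := broadcastPrefix([]int{1, 2}, []int{2}, 3)
	fmt.Println(out, outDims) // prints [1 2 1 2 1 2] [3 2]
}
```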

func (*Builder) BroadcastInDim

func (b *Builder) BroadcastInDim(x backends.Op, outputShape shapes.Shape, broadcastAxes []int) backends.Op

BroadcastInDim broadcasts x to an output with the given shape. broadcastAxes has one output axis for each axis of x (len(broadcastAxes) == x.Shape.Rank()). The i-th axis of x is mapped to the broadcastAxes[i]-th dimension of the output. broadcastAxes must also be increasing: this operation cannot be used to transpose axes, it will only broadcast and introduce new axes in-between. This also requires that the dimension of the i-th input axis is either 1 or the same as the output dimension it's broadcasting into. For example, say operand `x = (s32)[2]{1, 2}`; outputShape = `(s32)[2,2]`:

  • Specifying []int{1} as broadcastAxes will generate output {{1, 2}, {1, 2}}
  • On the other hand, specifying []int{0} as broadcastAxes will generate output {{1, 1}, {2, 2}}
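
The two cases above can be checked with a plain-Go sketch restricted to a rank-1 operand and a rank-2 output. The function broadcastInDim1Dto2D is hypothetical and only illustrates the axis mapping.

```go
package main

import "fmt"

// broadcastInDim1Dto2D broadcasts a rank-1 x into a [rows, cols] output.
// broadcastAxis is the single entry of broadcastAxes: the output axis that
// the only axis of x maps to. Plain-Go illustration, not the backend code.
func broadcastInDim1Dto2D(x []int, rows, cols, broadcastAxis int) [][]int {
	out := make([][]int, rows)
	for r := range out {
		out[r] = make([]int, cols)
		for c := range out[r] {
			idx := c
			if broadcastAxis == 0 {
				idx = r
			}
			if len(x) == 1 {
				idx = 0 // an axis of size 1 is broadcast to every position
			}
			out[r][c] = x[idx]
		}
	}
	return out
}

func main() {
	x := []int{1, 2}
	fmt.Println(broadcastInDim1Dto2D(x, 2, 2, 1)) // prints [[1 2] [1 2]]
	fmt.Println(broadcastInDim1Dto2D(x, 2, 2, 0)) // prints [[1 1] [2 2]]
}
```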

func (*Builder) Ceil

func (b *Builder) Ceil(x backends.Op) backends.Op

Ceil returns the Op that represents the output of the corresponding operation.

func (*Builder) Clz

func (b *Builder) Clz(x backends.Op) backends.Op

Clz returns element-wise the "count leading zeros" bits of input node x -- for integer values.
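
For intuition, the scalar equivalents of Clz and BitCount exist in Go's standard math/bits package. The sketch below uses 32-bit unsigned values and Go's convention that the leading-zero count of 0 is the bit width; whether the backend treats a zero input the same way is an assumption, not something stated here.

```go
package main

import (
	"fmt"
	"math/bits"
)

// clz32 counts the leading zero bits of a 32-bit value (32 for x == 0).
func clz32(x uint32) int { return bits.LeadingZeros32(x) }

// bitCount32 counts the bits set to one in a 32-bit value.
func bitCount32(x uint32) int { return bits.OnesCount32(x) }

func main() {
	for _, x := range []uint32{0, 1, 6, 1 << 31} {
		fmt.Printf("x=%d clz=%d bitcount=%d\n", x, clz32(x), bitCount32(x))
	}
}
```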

func (*Builder) Compile

func (b *Builder) Compile(outputs ...backends.Op) backends.Executable

func (*Builder) Complex

func (b *Builder) Complex(x0, x1 backends.Op) backends.Op

Complex returns the complex number taking x0 as the real part and x1 as the imaginary part. The real (x0) and imaginary (x1) must have the same dtype, and they must be either `dtypes.Float32` or `dtypes.Float64`. The output will be either `dtypes.Complex64` or `dtypes.Complex128`, depending on x0 and x1 dtypes. The shapes of `real` or `imaginary` must be the same, or one must be a scalar, in which case the value is broadcast to every other value. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Concatenate

func (b *Builder) Concatenate(axis int, operands ...backends.Op) backends.Op

Concatenate operands on the given axis. All axes that are not being concatenated must match dimensions. It doesn't work with scalars -- use ExpandDims. If there is only one operand, it is returned and this is a no-op.

func (*Builder) Conj

func (b *Builder) Conj(x backends.Op) backends.Op

Conj returns the conjugate of a complex number. E.g: Conj(1+3i) = 1-3i

func (*Builder) Constant

func (b *Builder) Constant(flat any, dims ...int) backends.Op

Constant creates a constant in the graph with the given flat values, and the shape defined by dims.

flat must be a slice of a basic type supported -- that can be converted to a DType.

The value is copied into the graph. It's recommended that for very large tensors, even if constants, that they are passed as side inputNodes (or variables, see context package) instead.

func (*Builder) ConvGeneralDilated

func (b *Builder) ConvGeneralDilated(operand, filter backends.Op, axes backends.ConvolveAxesConfig, strides []int, paddings [][2]int, inputDilation, filterDilation []int, filterGroupCount, batchGroupCount int) backends.Op

ConvGeneralDilated is a generic Convolution operation offered by XLA. featureAxisAfter defines whether the features (aka. channels or depth) axis comes after the spatial dimension. Example: a 2D input can be one of the two:

  • featureAxisAfter=false: input=[batch_size, features, height, width], filter=[output_features, input_features, height, width]
  • featureAxisAfter=true: input=[batch_size, height, width, features], filter=[output_features, height, width, input_features]

Some details in https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution. There operand and filter are called lhs and rhs. (XLA documentation is unfortunately poor, much is guess-work). Also useful, https://arxiv.org/pdf/1603.07285v1.pdf.

func (*Builder) ConvertDType

func (b *Builder) ConvertDType(x backends.Op, dtype dtypes.DType) backends.Op

ConvertDType of x to dtype.

func (*Builder) Cos

func (b *Builder) Cos(x backends.Op) backends.Op

Cos returns the Op that represents the output of the corresponding operation.

func (*Builder) Div

func (b *Builder) Div(x0, x1 backends.Op) backends.Op

Div returns the element-wise division of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Dot

func (b *Builder) Dot(x0, x1 backends.Op) backends.Op

Dot returns the "dot product" operation. The exact semantics of this operation depend on the ranks of the operands:

| Input                             | Output         | Semantics                    |
| vector [n] dot vector [n]         | scalar         | vector dot product           |
| matrix [m x k] dot vector [k]     | vector [m]     | matrix-vector multiplication |
| matrix [m x k] dot matrix [k x n] | matrix [m x n] | matrix-matrix multiplication |

The operation performs sum of products over the second dimension of x0 (or the first if it has rank 1) and the first dimension of x1. These are the "contracted" dimensions. The contracted dimensions of x0 and x1 must be of the same size. In practice, it can be used to perform dot products between vectors, vector/matrix multiplications or matrix/matrix multiplications. The op is created on the same XlaBuilder as used for x0 and x1.
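
The matrix-matrix case can be sketched in plain Go as follows. The function matMul is hypothetical and only illustrates which dimensions are contracted; it is not how the backend computes the result.

```go
package main

import "fmt"

// matMul computes [m x k] dot [k x n] -> [m x n], contracting the second
// dimension of a with the first dimension of b.
func matMul(a, b [][]float64) [][]float64 {
	m, k, n := len(a), len(b), len(b[0])
	out := make([][]float64, m)
	for i := 0; i < m; i++ {
		out[i] = make([]float64, n)
		for j := 0; j < n; j++ {
			for p := 0; p < k; p++ {
				out[i][j] += a[i][p] * b[p][j]
			}
		}
	}
	return out
}

func main() {
	a := [][]float64{{1, 2}, {3, 4}}
	b := [][]float64{{5, 6}, {7, 8}}
	fmt.Println(matMul(a, b)) // prints [[19 22] [43 50]]
}
```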

func (*Builder) DotGeneral

func (b *Builder) DotGeneral(lhs backends.Op, lhsContractingAxes, lhsBatchAxes []int, rhs backends.Op, rhsContractingAxes, rhsBatchAxes []int) backends.Op

DotGeneral takes as input lhs (left-hand-side) and rhs (right-hand-side) specifications for a general vector product -- a generalized "Einsum". Each axis can be:

  • Just aligned (batch axes), so the output has the same axes as the inputs. The dimensions must match in lhs and rhs.
  • Crossed (default), in which case the output is the combination (concatenation) of the dimensions.
  • Contracted (contracting axes), where the values are multiplied and then sum-reduced over those dimensions.

It follows that the resulting dimension number starts with the batch dimension, then the 'lhs' non-contracting/non-batch dimension, and finally the 'rhs' non-contracting/non-batch dimension. It provides the basic means of implementing Einsum.
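
The ordering of the output dimensions can be sketched with a small shape calculation in plain Go. The function dotGeneralOutputDims is hypothetical; it only derives the output dimensions and performs no arithmetic on values.

```go
package main

import "fmt"

// contains reports whether axis is listed in axes.
func contains(axes []int, axis int) bool {
	for _, a := range axes {
		if a == axis {
			return true
		}
	}
	return false
}

// dotGeneralOutputDims returns the output dimensions in the order described:
// batch dimensions, then the remaining lhs dimensions, then the remaining
// rhs dimensions. Plain-Go illustration, not the backend implementation.
func dotGeneralOutputDims(lhsDims, lhsContracting, lhsBatch, rhsDims, rhsContracting, rhsBatch []int) []int {
	var out []int
	for _, axis := range lhsBatch {
		out = append(out, lhsDims[axis])
	}
	for axis, d := range lhsDims {
		if !contains(lhsContracting, axis) && !contains(lhsBatch, axis) {
			out = append(out, d)
		}
	}
	for axis, d := range rhsDims {
		if !contains(rhsContracting, axis) && !contains(rhsBatch, axis) {
			out = append(out, d)
		}
	}
	return out
}

func main() {
	// Batched matrix multiplication: [8, 3, 4] x [8, 4, 5] -> [8, 3, 5].
	out := dotGeneralOutputDims(
		[]int{8, 3, 4}, []int{2}, []int{0},
		[]int{8, 4, 5}, []int{1}, []int{0})
	fmt.Println(out) // prints [8 3 5]
}
```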

func (*Builder) DynamicSlice added in v0.11.1

func (b *Builder) DynamicSlice(operand backends.Op, startIndices []backends.Op, sliceDims []int) backends.Op

DynamicSlice extracts a sub-array from the input array at dynamic start_indices. The size of the slice in each axis is passed in sliceDims, which specify the slice intervals for each axis: [start, start + size). The shape of startIndices must be rank == 1, with dimension size equal to the rank of operand. See description in https://openxla.org/xla/operation_semantics#dynamicslice

func (*Builder) DynamicUpdateSlice added in v0.11.1

func (b *Builder) DynamicUpdateSlice(operand, update backends.Op, startIndices []backends.Op) backends.Op

DynamicUpdateSlice generates a result which is the value of the input array operand, with a slice update overwritten at startIndices. The shape of update determines the shape of the sub-array of the result which is updated. The shape of startIndices must be rank == 1, with dimension size equal to the rank of operand. See description in https://openxla.org/xla/operation_semantics#dynamicupdateslice

func (*Builder) Equal

func (b *Builder) Equal(x0, x1 backends.Op) backends.Op

Equal performs element-wise equality check, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) EqualTotalOrder

func (b *Builder) EqualTotalOrder(x0, x1 backends.Op) backends.Op

EqualTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.
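
One way to picture the total order is the usual bit-pattern trick for IEEE 754 doubles, sketched below in plain Go. This is an illustration of the ordering only; it is an assumption that it matches the backend bit for bit, and the helper names are hypothetical.

```go
package main

import (
	"fmt"
	"math"
)

// Sample values, exposed as variables so they are easy to reuse.
var (
	negZero = math.Copysign(0, -1)
	posInf  = math.Inf(1)
	posNaN  = math.NaN()
	negNaN  = math.Copysign(math.NaN(), -1)
)

// totalOrderKey maps a float64 to a uint64 whose unsigned order matches
// -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN.
func totalOrderKey(x float64) uint64 {
	b := math.Float64bits(x)
	if b>>63 == 1 {
		return ^b // negative: flip all bits so larger magnitudes sort first
	}
	return b | 1<<63 // non-negative: set the top bit to sort after negatives
}

func lessThanTotalOrder(a, b float64) bool { return totalOrderKey(a) < totalOrderKey(b) }
func equalTotalOrder(a, b float64) bool    { return totalOrderKey(a) == totalOrderKey(b) }

func main() {
	fmt.Println(negZero == 0, equalTotalOrder(negZero, 0)) // prints true false
	fmt.Println(lessThanTotalOrder(negZero, 0))            // prints true
	fmt.Println(lessThanTotalOrder(posInf, posNaN))        // prints true
	fmt.Println(lessThanTotalOrder(negNaN, math.Inf(-1)))  // prints true
}
```

Under the regular comparison operators -0 equals +0 and NaN compares false with everything, which is why the "TotalOrder" variants give different answers for these inputs.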

func (*Builder) Erf added in v0.12.0

func (b *Builder) Erf(x backends.Op) backends.Op

Erf returns the "error function", defined as erf(x) = 2/sqrt(Pi) * \int_{0}^{x}{e^{-t^2}dt}.
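
As a sanity check of the definition, using the standard normalization 2/sqrt(Pi), the integral can be approximated numerically and compared with Go's math.Erf. The midpoint rule and the helper names below are illustrative choices, not part of this package.

```go
package main

import (
	"fmt"
	"math"
)

// erfNumeric approximates erf(x) = 2/sqrt(Pi) * integral from 0 to x of
// exp(-t^2) dt, using the midpoint rule with the given number of steps.
func erfNumeric(x float64, steps int) float64 {
	h := x / float64(steps)
	sum := 0.0
	for i := 0; i < steps; i++ {
		t := (float64(i) + 0.5) * h
		sum += math.Exp(-t * t)
	}
	return 2 / math.Sqrt(math.Pi) * sum * h
}

// erfError returns the absolute difference to the standard library value.
func erfError(x float64, steps int) float64 {
	return math.Abs(erfNumeric(x, steps) - math.Erf(x))
}

func main() {
	fmt.Println(erfNumeric(1, 10000), math.Erf(1))
	fmt.Println(erfError(1, 10000) < 1e-6) // prints true
}
```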

func (*Builder) Exp

func (b *Builder) Exp(x backends.Op) backends.Op

Exp returns the Op that represents the output of the corresponding operation.

func (*Builder) Expm1

func (b *Builder) Expm1(x backends.Op) backends.Op

Expm1 returns the Op that represents the output of the corresponding operation.

func (*Builder) FFT

func (b *Builder) FFT(operand backends.Op, fftType backends.FFTType, fftLength []int) backends.Op

FFT calls the XLA FFT operation, which implements {Forward, Inverse} x {Complex, Real} versions. See documentation in https://www.tensorflow.org/xla/operation_semantics. Under the hood, CPU FFT is backed by Eigen's TensorFFT and GPU FFT uses cuFFT.

func (*Builder) Floor

func (b *Builder) Floor(x backends.Op) backends.Op

Floor returns the Op that represents the output of the corresponding operation.

func (*Builder) Gather

func (b *Builder) Gather(operand, startIndices backends.Op, indexVectorAxis int, offsetAxes, collapsedSliceAxes, startIndexMap, sliceSizes []int, indicesAreSorted bool) backends.Op

Gather is a powerful but cumbersome Gather operation offered by XLA. Full details in https://www.tensorflow.org/xla/operation_semantics#gather. (Warning: it's poorly described, with many undefined terms) Arguments:

  • startIndices: are the indices we want to gather. There will be one axis which enumerates the indices in the operand array, typically the last one. All other axes are "batch dimensions" and they will have equivalent axes in the output.
  • indexVectorAxis: typically the last axis of startIndices, so startIndices.Shape.Rank()-1. Usually, one has the dimension of the indexVectorAxis equal to the full rank of the operand. That is: startIndices.Shape.Dimensions[indexVectorAxis] = operand.Shape.Rank(). Let's call "one index vector" a value of startIndices formed by a slice across indexVectorAxis.
  • startIndexMap: for each "index vector" from startIndices, this maps each element of the vector to an axis of the operand. Typically, this is [0, 1, 2, ..., operand.Shape.Rank()-1], that is, each "index vector" fully defines an element on the operand. If one is gathering slices of the operand (as opposed to individual values), one can skip some of those axes from startIndexMap, and the index for those axes is considered 0, and set sliceSizes to take the slice one wants (typically the full slice).
  • sliceSizes: the "index vector" described above points to the data in the operand to be gathered. Then sliceSizes indicates how much data to gather. One value per axis of the operand must be set. For gathering individual values, set these all to 1.
  • collapsedSliceAxes: the slice gathered for each "index vector" (with sizes sliceSizes), often has dimension one for most (or all, in case of gathering individual items) axes. collapsedSliceAxes allows one to collapse those axes, so they don't show up in the output. Usually, collapse all axes that are size one. These are axes within the rank of operand (from 0 to operand.Shape.Rank()-1).
  • offsetAxes: for those gathered slices not collapsed (with collapsedSliceAxes), this maps them to a position in the output array. Typically, these will be consecutive numbers starting with indexVectorAxis. So, the output will have the same prefix shape (the "batch dimensions") as the startIndices array, and the suffix shape will be the gathered slices mapped to these `offsetAxes`. There must be one value per axis not collapsed with collapsedSliceAxes -- the value itself is an axis in the output shape.

func (*Builder) GreaterOrEqual

func (b *Builder) GreaterOrEqual(x0, x1 backends.Op) backends.Op

GreaterOrEqual performs element-wise comparison, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) GreaterOrEqualTotalOrder

func (b *Builder) GreaterOrEqualTotalOrder(x0, x1 backends.Op) backends.Op

GreaterOrEqualTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) GreaterThan

func (b *Builder) GreaterThan(x0, x1 backends.Op) backends.Op

GreaterThan performs element-wise comparison, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) GreaterThanTotalOrder

func (b *Builder) GreaterThanTotalOrder(x0, x1 backends.Op) backends.Op

GreaterThanTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Identity

func (b *Builder) Identity(x backends.Op) backends.Op

Identity returns an Op whose output is the same as its input. It's a no-op that can serve as a place-holder.

func (*Builder) Imag

func (b *Builder) Imag(x backends.Op) backends.Op

Imag returns the imaginary part of a complex number. It returns 0 if x is a float number.

func (*Builder) Iota

func (b *Builder) Iota(shape shapes.Shape, iotaAxis int) backends.Op

Iota creates a constant of the given shape with increasing numbers (starting from 0) on the given axis. So Iota([2,2], 1) returns [[0 1][0 1]], while Iota([2,2], 0) returns [[0 0][1 1]].
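
The two cases can be reproduced with a plain-Go sketch for rank-2 shapes. The function iota2D is hypothetical and only illustrates along which axis the values increase.

```go
package main

import "fmt"

// iota2D builds a [rows, cols] array whose values count up from 0 along
// iotaAxis and are constant along the other axis.
func iota2D(rows, cols, iotaAxis int) [][]int {
	out := make([][]int, rows)
	for r := range out {
		out[r] = make([]int, cols)
		for c := range out[r] {
			if iotaAxis == 0 {
				out[r][c] = r
			} else {
				out[r][c] = c
			}
		}
	}
	return out
}

func main() {
	fmt.Println(iota2D(2, 2, 1)) // prints [[0 1] [0 1]]
	fmt.Println(iota2D(2, 2, 0)) // prints [[0 0] [1 1]]
}
```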

func (*Builder) IsFinite added in v0.13.0

func (b *Builder) IsFinite(x backends.Op) backends.Op

IsFinite tests whether each element of operand is finite, i.e., is not positive or negative infinity, and is not NaN. It returns an array of boolean values with the same shape as the input, where each element is true if and only if the corresponding input element is finite.

func (*Builder) LessOrEqual

func (b *Builder) LessOrEqual(x0, x1 backends.Op) backends.Op

LessOrEqual performs element-wise comparison, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) LessOrEqualTotalOrder

func (b *Builder) LessOrEqualTotalOrder(x0, x1 backends.Op) backends.Op

LessOrEqualTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) LessThan

func (b *Builder) LessThan(x0, x1 backends.Op) backends.Op

LessThan performs element-wise comparison, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) LessThanTotalOrder

func (b *Builder) LessThanTotalOrder(x0, x1 backends.Op) backends.Op

LessThanTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Log

func (b *Builder) Log(x backends.Op) backends.Op

Log returns the Op that represents the output of the corresponding operation.

func (*Builder) Log1p

func (b *Builder) Log1p(x backends.Op) backends.Op

Log1p returns the expression log(x+1).

func (*Builder) LogicalNot

func (b *Builder) LogicalNot(x backends.Op) backends.Op

LogicalNot returns the Op that represents the output of the corresponding operation.

func (*Builder) Logistic

func (b *Builder) Logistic(x backends.Op) backends.Op

Logistic returns the element-wise expression 1/(1+exp(-x)). Also known as the Sigmoid function.
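
A scalar version of the expression, written in plain Go for illustration (the backend applies it element-wise to the whole array):

```go
package main

import (
	"fmt"
	"math"
)

// logistic evaluates 1/(1+exp(-x)) for a single value.
func logistic(x float64) float64 {
	return 1 / (1 + math.Exp(-x))
}

func main() {
	for _, x := range []float64{-2, 0, 2} {
		fmt.Printf("logistic(%v) = %.4f\n", x, logistic(x))
	}
}
```

The function maps 0 to 0.5 and satisfies logistic(x) + logistic(-x) = 1, up to floating-point rounding.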

func (*Builder) Max

func (b *Builder) Max(x0, x1 backends.Op) backends.Op

Max returns the element-wise highest value among the two. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Min

func (b *Builder) Min(x0, x1 backends.Op) backends.Op

Min returns the element-wise smallest value among the two. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Mul

func (b *Builder) Mul(x0, x1 backends.Op) backends.Op

Mul returns the element-wise multiplication of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Name

func (b *Builder) Name() string

Name of the computation being built.

func (*Builder) Neg

func (b *Builder) Neg(x backends.Op) backends.Op

Neg returns the Op that represents the output of the corresponding operation.

func (*Builder) NotEqual

func (b *Builder) NotEqual(x0, x1 backends.Op) backends.Op

NotEqual performs element-wise inequality check, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) NotEqualTotalOrder

func (b *Builder) NotEqualTotalOrder(x0, x1 backends.Op) backends.Op

NotEqualTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) OpShape

func (b *Builder) OpShape(op backends.Op) shapes.Shape

OpShape returns the shape of a computation Op.

func (*Builder) Or

func (b *Builder) Or(x0, x1 backends.Op) backends.Op

Or returns the element-wise logic "or" operator. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Pad

func (b *Builder) Pad(x, fillValue backends.Op, axesConfig ...backends.PadAxis) backends.Op

Pad injects padding on the start, end or interior (in between each element) of the given operand. There must be at most `operand.Rank()` axesConfig values. Missing PadAxis are assumed to be zeros, that is, no padding for those axes.
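
The three kinds of padding can be illustrated on a rank-1 array in plain Go. The struct padAxis and its field names are hypothetical stand-ins for backends.PadAxis, and the sketch handles only non-negative padding amounts.

```go
package main

import "fmt"

// padAxis is a hypothetical stand-in for the per-axis padding configuration:
// padding at the start, at the end, and between consecutive elements.
type padAxis struct {
	start, end, interior int
}

// pad1D pads a rank-1 array with fill according to cfg.
func pad1D(x []int, fill int, cfg padAxis) []int {
	var out []int
	for i := 0; i < cfg.start; i++ {
		out = append(out, fill)
	}
	for i, v := range x {
		if i > 0 {
			// Interior padding goes in between each pair of elements.
			for j := 0; j < cfg.interior; j++ {
				out = append(out, fill)
			}
		}
		out = append(out, v)
	}
	for i := 0; i < cfg.end; i++ {
		out = append(out, fill)
	}
	return out
}

func main() {
	fmt.Println(pad1D([]int{1, 2, 3}, 0, padAxis{start: 1, end: 2, interior: 1}))
	// prints [0 1 0 2 0 3 0 0]
}
```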

func (*Builder) Parameter

func (b *Builder) Parameter(name string, shape shapes.Shape) backends.Op

Parameter creates an input parameter for the computation. During execution of the computation this value will need to be fed, in the same order it is created.

func (*Builder) Pow

func (b *Builder) Pow(x0, x1 backends.Op) backends.Op

Pow returns the Op that represents the output of the corresponding operation. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Real

func (b *Builder) Real(x backends.Op) backends.Op

Real returns the real part of a complex number. It returns x if x is a float number.

func (*Builder) ReduceAnd added in v0.11.1

func (b *Builder) ReduceAnd(x backends.Op, axes ...int) backends.Op

ReduceAnd is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the logical-and of the reduced axes. It only works for booleans. If no axes are given, it reduces the full array.

func (*Builder) ReduceMax

func (b *Builder) ReduceMax(x backends.Op, axes ...int) backends.Op

ReduceMax is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the max value. If no axes are given, it reduces the full array.

func (*Builder) ReduceMin

func (b *Builder) ReduceMin(x backends.Op, axes ...int) backends.Op

ReduceMin is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the min value. If no axes are given, it reduces the full array.

func (*Builder) ReduceOr added in v0.11.1

func (b *Builder) ReduceOr(x backends.Op, axes ...int) backends.Op

ReduceOr is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the logical-or of the reduced axes. It only works for booleans. If no axes are given, it reduces the full array.

func (*Builder) ReduceProduct

func (b *Builder) ReduceProduct(x backends.Op, axes ...int) backends.Op

ReduceProduct is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the product of the reduced axes. If no axes are given, it reduces the full array.

func (*Builder) ReduceSum

func (b *Builder) ReduceSum(x backends.Op, axes ...int) backends.Op

ReduceSum is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the sum of the reduced axes. If no axes are given, it reduces the full array.
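
The effect of the axes argument can be illustrated on a rank-2 array in plain Go. The function reduceSum2D is hypothetical; it returns a one-element slice when the full array is reduced.

```go
package main

import "fmt"

// reduceSum2D sums a rank-2 array over the given axes. With exactly one
// axis the other axis is kept; with no axes (or both) the full array is
// reduced to a single value.
func reduceSum2D(x [][]float64, axes ...int) []float64 {
	if len(axes) != 1 {
		total := 0.0
		for _, row := range x {
			for _, v := range row {
				total += v
			}
		}
		return []float64{total}
	}
	if axes[0] == 0 {
		// Axis 0 goes away: one sum per column.
		out := make([]float64, len(x[0]))
		for _, row := range x {
			for c, v := range row {
				out[c] += v
			}
		}
		return out
	}
	// Axis 1 goes away: one sum per row.
	out := make([]float64, len(x))
	for r, row := range x {
		for _, v := range row {
			out[r] += v
		}
	}
	return out
}

func main() {
	x := [][]float64{{1, 2, 3}, {4, 5, 6}}
	fmt.Println(reduceSum2D(x, 0)) // prints [5 7 9]
	fmt.Println(reduceSum2D(x, 1)) // prints [6 15]
	fmt.Println(reduceSum2D(x))    // prints [21]
}
```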

func (*Builder) ReduceWindow

func (b *Builder) ReduceWindow(x backends.Op, reductionType backends.ReduceOpType, windowDimensions, strides, baseDilations, windowDilations []int, paddings [][2]int) backends.Op

func (*Builder) Rem

func (b *Builder) Rem(x0, x1 backends.Op) backends.Op

Rem returns the remainder operation, also known as modulo (or Mod for short). Note that, despite the name, XLA implements Mod, not the IEEE 754 Remainder operation. The op is created on the same XlaBuilder as used for x0 and x1.
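
The difference between the two operations can be seen with Go's standard library, assuming that Rem follows the truncated convention of math.Mod, where the result takes the sign of the dividend. The wrapper names below are hypothetical.

```go
package main

import (
	"fmt"
	"math"
)

// truncMod is the truncated modulo: the result has the sign of a.
func truncMod(a, b float64) float64 { return math.Mod(a, b) }

// ieeeRemainder is the IEEE 754 remainder: a - b*n, where n is the integer
// nearest to a/b, so the result can be negative for positive inputs.
func ieeeRemainder(a, b float64) float64 { return math.Remainder(a, b) }

func main() {
	fmt.Println(truncMod(5, 3), ieeeRemainder(5, 3))   // prints 2 -1
	fmt.Println(truncMod(-5, 3), ieeeRemainder(-5, 3)) // prints -2 1
}
```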

func (*Builder) Reshape

func (b *Builder) Reshape(x backends.Op, dimensions ...int) backends.Op

Reshape reshapes x to the new dimensions. Total size cannot change, it's just a "reinterpretation" of the same flat data. The dtype remains the same, see ConvertDType to actually convert the values.

func (*Builder) Reverse

func (b *Builder) Reverse(x backends.Op, axes ...int) backends.Op

Reverse returns x with the values for the given dimensions reversed, that is, the value indexed at `i` will be swapped with the value indexed at `(dimension_size - 1 - i)`. The shape remains the same.
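
For a rank-1 array the index rule reduces to the following plain-Go sketch. The function reverse1D is hypothetical and returns a new slice rather than modifying its input.

```go
package main

import "fmt"

// reverse1D places the value at index i at index len(x)-1-i.
func reverse1D(x []int) []int {
	out := make([]int, len(x))
	for i, v := range x {
		out[len(x)-1-i] = v
	}
	return out
}

func main() {
	fmt.Println(reverse1D([]int{1, 2, 3, 4})) // prints [4 3 2 1]
}
```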

func (*Builder) RngBitGenerator

func (b *Builder) RngBitGenerator(state backends.Op, shape shapes.Shape) (newState, values backends.Op)

RngBitGenerator generates the given shape filled with random bits. It takes as input the current random number generator (RNG) state, see RngState or RngStateFromSeed. The algorithm is hard-coded to use the Philox algorithm for now.

It returns the new state of the RNG and the generated values (with random bits) with the given shape.

func (*Builder) Round

func (b *Builder) Round(x backends.Op) backends.Op

Round returns the Op that represents the output of the corresponding operation.

func (*Builder) Rsqrt

func (b *Builder) Rsqrt(x backends.Op) backends.Op

Rsqrt returns the element-wise reciprocal of square root operation 1/sqrt(x).

func (*Builder) ScatterAdd

func (b *Builder) ScatterAdd(operand, scatterIndices, updates backends.Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) backends.Op

ScatterAdd values from updates pointed by scatterIndices to operand.

func (*Builder) ScatterMax

func (b *Builder) ScatterMax(operand, scatterIndices, updates backends.Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) backends.Op

ScatterMax scatters values from updates pointed by scatterIndices to operand, by taking the Max.

func (*Builder) ScatterMin

func (b *Builder) ScatterMin(operand, scatterIndices, updates backends.Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) backends.Op

ScatterMin scatters values from updates pointed by scatterIndices to operand, by taking the Min.

func (*Builder) SelectAndScatterMax

func (b *Builder) SelectAndScatterMax(operand, source backends.Op, windowDimensions, windowStrides []int, paddings [][2]int) backends.Op

SelectAndScatterMax runs windows (similar to ReduceWindow) over the operand and selects values to update the output (like ScatterAdd). It selects the values in the window such that it works as reverse for ScatterMax. See details in https://openxla.org/xla/operation_semantics#selectandscatter

func (*Builder) SelectAndScatterMin

func (b *Builder) SelectAndScatterMin(operand, source backends.Op, windowDimensions, windowStrides []int, paddings [][2]int) backends.Op

SelectAndScatterMin runs windows (similar to ReduceWindow) over the operand and selects values to update the output (like ScatterAdd). It selects the values in the window such that it works as reverse for ScatterMin. See details in https://openxla.org/xla/operation_semantics#selectandscatter

func (*Builder) SelectAndScatterSum

func (b *Builder) SelectAndScatterSum(operand, source backends.Op, windowDimensions, windowStrides []int, paddings [][2]int) backends.Op

SelectAndScatterSum runs windows (similar to ReduceWindow) over the operand and selects values to update the output (like ScatterAdd). It selects the values in the window such that it works as reverse for ScatterSum. See details in https://openxla.org/xla/operation_semantics#selectandscatter

func (*Builder) Sign

func (b *Builder) Sign(x backends.Op) backends.Op

Sign returns element-wise +1, +/-0 or -1 depending on the sign of x. It returns NaN if the input is NaN.

func (*Builder) Sin

func (b *Builder) Sin(x backends.Op) backends.Op

Sin returns the Op that represents the output of the corresponding operation.

func (*Builder) Slice

func (b *Builder) Slice(x backends.Op, starts, limits, strides []int) backends.Op

Slice extracts a sub-array from the input array. The sub-array is of the same rank as the input and contains the values inside a bounding box within the input array, where the dimensions and indices of the bounding box are given as arguments to the slice operation. The strides set the input stride of the slice in each axis and must be >= 1. Strides are optional: if missing, a stride of 1 is assumed for every dimension. Examples:

Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={4}, strides=nil) -> {2, 3}
Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={5}, strides={2}) -> {2, 4}

func (*Builder) Sqrt

func (b *Builder) Sqrt(x backends.Op) backends.Op

Sqrt returns the Op that represents the output of the corresponding operation.

func (*Builder) Sub

func (b *Builder) Sub(x0, x1 backends.Op) backends.Op

Sub returns the element-wise subtraction of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Tanh

func (b *Builder) Tanh(x backends.Op) backends.Op

Tanh returns the Op that represents the output of the corresponding operation.

func (*Builder) Transpose

func (b *Builder) Transpose(x backends.Op, permutations ...int) backends.Op

Transpose axes of x. There should be one value in permutations for each axis in x. The output will have: output.Shape.Dimension[i] = x.Shape.Dimension[permutations[i]].

func (*Builder) Where

func (b *Builder) Where(condition, onTrue, onFalse backends.Op) backends.Op

Where takes element-wise values from onTrue or onFalse depending on the value of condition (expected to be boolean).

func (*Builder) Xor

func (b *Builder) Xor(x0, x1 backends.Op) backends.Op

Xor returns the element-wise logic "xor" operator. The op is created on the same XlaBuilder as used for x0 and x1.

type Executable

type Executable struct {
	// contains filtered or unexported fields
}

Executable implements backends.Executable for XLA/PJRT, using github.com/gomlx/gopjrt.

func (*Executable) AssertValid

func (e *Executable) AssertValid()

AssertValid panics if the backend or the executable are not ok -- e.g.: if they have been finalized or the builder has already been compiled.

func (*Executable) Execute

func (e *Executable) Execute(inputs []backends.Buffer, donate []bool) []backends.Buffer

Execute the executable on the default device (0). The number and shapes of the inputs must match those returned by Inputs.

func (*Executable) Finalize

func (e *Executable) Finalize()

Finalize immediately frees resources associated with the executable.

func (*Executable) Inputs

func (e *Executable) Inputs() (names []string, inputShapes []shapes.Shape)

Inputs returns the list of parameter names and shapes, in the order created by the Builder.Parameter calls.

func (*Executable) Outputs

func (e *Executable) Outputs() (outputShapes []shapes.Shape)

Outputs returns the list of the shapes of the outputs of the computation, in the order given to the Builder.Compile call.
