xla

package
v0.22.0 Latest
Warning

This package is not in the latest version of its module.

Published: Aug 22, 2025 License: Apache-2.0 Imports: 17 Imported by: 1

Documentation

Overview

Package xla implements the XLA/PJRT (https://openxla.org/) based backend for GoMLX.

To make it available in your program, import it with:

import _ "github.com/gomlx/gomlx/backends/xla"

It will register itself as an available backend during initialization.

By default, the XLA/PJRT backend loads the requested plugin (by default "cpu") using `dlopen` after the program starts. In some cases, one may prefer to pre-link a plugin with the program instead. There are two options for that (at most one can be selected):

  • Pre-link the CPU PJRT plugin statically: this generates a bigger binary (~200MB larger, and slower to build), but allows one to build a static binary that can be deployed without extra dependencies (except the standard C and C++ libraries, which are available on most machines). To enable it, build with the tag `pjrt_cpu_static` (e.g.: `go build --tags pjrt_cpu_static ...`), or import `github.com/gomlx/gomlx/backends/xla/cpu/static`. Both methods have the same effect.
  • Pre-link the CPU PJRT plugin dynamically: build with the tag `pjrt_cpu_dynamic` (e.g.: `go test --tags pjrt_cpu_dynamic ...`), or import `github.com/gomlx/gomlx/backends/xla/cpu/dynamic`. This is not much different from the default of loading the PJRT plugin after the program starts.

Darwin (macOS): dynamic linking of XLA/PJRT currently does not work, so the CPU PJRT plugin is statically linked by default; there is no need to manually import `github.com/gomlx/gomlx/backends/xla/cpu/static`.

Shared Buffers Support:

XLA/PJRT for CPU allows the "device buffer" (where device=CPU) to be addressed directly, which saves the copy from the "host/local tensor" to the "on-device tensor" when executing a computation. This is enabled by default if the plugin is called "cpu". To force advertising this support for other PJRT plugins, provide the "shared_buffers" option, e.g.: GOMLX_BACKEND="xla:my_pjrt,shared_buffers". To force disabling it, provide the "noshared_buffers" option.
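To make the configuration rule above concrete, here is a minimal sketch of how a value like "xla:my_pjrt,shared_buffers" splits into backend name, plugin name and options, and how the shared-buffers default could be resolved. parseBackendConfig and sharedBuffersEnabled are hypothetical helpers written for illustration; they are not the package's actual parser.

```go
package main

import (
	"fmt"
	"strings"
)

// parseBackendConfig is a hypothetical helper (not part of GoMLX) that splits a
// value like "xla:my_pjrt,shared_buffers" into the backend name ("xla"), the
// plugin name ("my_pjrt") and the set of extra options.
func parseBackendConfig(cfg string) (string, string, map[string]bool) {
	backend, rest, _ := strings.Cut(cfg, ":")
	parts := strings.Split(rest, ",") // parts[0] is the plugin name
	options := make(map[string]bool)
	for _, opt := range parts[1:] {
		if opt != "" {
			options[opt] = true
		}
	}
	return backend, parts[0], options
}

// sharedBuffersEnabled mirrors the documented rule: enabled by default only for
// the "cpu" plugin, overridable with "shared_buffers" or "noshared_buffers".
func sharedBuffersEnabled(plugin string, options map[string]bool) bool {
	enabled := plugin == "cpu"
	if options["shared_buffers"] {
		enabled = true
	}
	if options["noshared_buffers"] {
		enabled = false
	}
	return enabled
}

func main() {
	for _, cfg := range []string{"xla:cpu", "xla:my_pjrt,shared_buffers", "xla:cpu,noshared_buffers"} {
		backend, plugin, opts := parseBackendConfig(cfg)
		fmt.Println(backend, plugin, sharedBuffersEnabled(plugin, opts))
	}
}
```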

Index

Constants

const BackendName = "xla"

Variables

var CPUCapabilities = backends.Capabilities{
	Operations: map[backends.OpType]bool{
		backends.OpTypeParameter:             true,
		backends.OpTypeConstant:              true,
		backends.OpTypeIdentity:              true,
		backends.OpTypeReduceWindow:          true,
		backends.OpTypeRngBitGenerator:       true,
		backends.OpTypeBatchNormForInference: true,
		backends.OpTypeBatchNormForTraining:  true,
		backends.OpTypeBatchNormGradient:     true,
		backends.OpTypeBitCount:              true,

		backends.OpTypeAbs:                      true,
		backends.OpTypeAdd:                      true,
		backends.OpTypeArgMinMax:                true,
		backends.OpTypeBitcast:                  true,
		backends.OpTypeBitwiseAnd:               true,
		backends.OpTypeBitwiseNot:               true,
		backends.OpTypeBitwiseOr:                true,
		backends.OpTypeBitwiseXor:               true,
		backends.OpTypeBroadcast:                true,
		backends.OpTypeBroadcastInDim:           true,
		backends.OpTypeCeil:                     true,
		backends.OpTypeClz:                      true,
		backends.OpTypeComplex:                  true,
		backends.OpTypeConcatenate:              true,
		backends.OpTypeConj:                     true,
		backends.OpTypeConvGeneral:              true,
		backends.OpTypeConvertDType:             true,
		backends.OpTypeCos:                      true,
		backends.OpTypeDiv:                      true,
		backends.OpTypeDot:                      true,
		backends.OpTypeDotGeneral:               true,
		backends.OpTypeDynamicSlice:             true,
		backends.OpTypeDynamicUpdateSlice:       true,
		backends.OpTypeEqual:                    true,
		backends.OpTypeEqualTotalOrder:          true,
		backends.OpTypeErf:                      true,
		backends.OpTypeExp:                      true,
		backends.OpTypeExpm1:                    true,
		backends.OpTypeFFT:                      true,
		backends.OpTypeFloor:                    true,
		backends.OpTypeGather:                   true,
		backends.OpTypeGreaterOrEqual:           true,
		backends.OpTypeGreaterOrEqualTotalOrder: true,
		backends.OpTypeGreaterThan:              true,
		backends.OpTypeGreaterThanTotalOrder:    true,
		backends.OpTypeImag:                     true,
		backends.OpTypeIota:                     true,
		backends.OpTypeIsFinite:                 true,
		backends.OpTypeLessOrEqual:              true,
		backends.OpTypeLessOrEqualTotalOrder:    true,
		backends.OpTypeLessThan:                 true,
		backends.OpTypeLessThanTotalOrder:       true,
		backends.OpTypeLog:                      true,
		backends.OpTypeLog1p:                    true,
		backends.OpTypeLogicalAnd:               true,
		backends.OpTypeLogicalNot:               true,
		backends.OpTypeLogicalOr:                true,
		backends.OpTypeLogicalXor:               true,
		backends.OpTypeLogistic:                 true,
		backends.OpTypeMax:                      true,
		backends.OpTypeMin:                      true,
		backends.OpTypeMul:                      true,
		backends.OpTypeNeg:                      true,
		backends.OpTypeNotEqual:                 true,
		backends.OpTypeNotEqualTotalOrder:       true,
		backends.OpTypePad:                      true,
		backends.OpTypePow:                      true,
		backends.OpTypeReal:                     true,
		backends.OpTypeReduceBitwiseAnd:         true,
		backends.OpTypeReduceBitwiseOr:          true,
		backends.OpTypeReduceBitwiseXor:         true,
		backends.OpTypeReduceLogicalAnd:         true,
		backends.OpTypeReduceLogicalOr:          true,
		backends.OpTypeReduceLogicalXor:         true,
		backends.OpTypeReduceMax:                true,
		backends.OpTypeReduceMin:                true,
		backends.OpTypeReduceProduct:            true,
		backends.OpTypeReduceSum:                true,
		backends.OpTypeRem:                      true,
		backends.OpTypeReshape:                  true,
		backends.OpTypeReverse:                  true,
		backends.OpTypeRound:                    true,
		backends.OpTypeRsqrt:                    true,
		backends.OpTypeScatterMax:               true,
		backends.OpTypeScatterMin:               true,
		backends.OpTypeScatterSum:               true,
		backends.OpTypeSelectAndScatterMax:      true,
		backends.OpTypeSelectAndScatterMin:      true,
		backends.OpTypeSelectAndScatterSum:      true,
		backends.OpTypeShiftLeft:                true,
		backends.OpTypeShiftRightArithmetic:     true,
		backends.OpTypeShiftRightLogical:        true,
		backends.OpTypeSign:                     true,
		backends.OpTypeSin:                      true,
		backends.OpTypeSlice:                    true,
		backends.OpTypeSqrt:                     true,
		backends.OpTypeSub:                      true,
		backends.OpTypeTanh:                     true,
		backends.OpTypeTranspose:                true,
		backends.OpTypeWhere:                    true,
	},

	DTypes: map[dtypes.DType]bool{
		dtypes.Bool:       true,
		dtypes.Int8:       true,
		dtypes.Int16:      true,
		dtypes.Int32:      true,
		dtypes.Int64:      true,
		dtypes.Uint8:      true,
		dtypes.Uint16:     true,
		dtypes.Uint32:     true,
		dtypes.Uint64:     true,
		dtypes.Float16:    true,
		dtypes.Float32:    true,
		dtypes.Float64:    true,
		dtypes.BFloat16:   true,
		dtypes.Complex64:  true,
		dtypes.Complex128: true,
	},
}

CPUCapabilities lists the operations and dtypes supported by the XLA CPU backend.

This is the base value; it can be copied and specialized for specific PJRT plugins that may not support everything.

var (
	// DefaultPlugins is the list of plugins to use in preference order, if not otherwise specified.
	DefaultPlugins = []string{"cuda", "cpu"}
)

Functions

func GetAvailablePlugins

func GetAvailablePlugins() []string

GetAvailablePlugins lists the available plugins -- it caches and reuses the result in future calls.

Plugins are searched for in the PJRT_PLUGIN_LIBRARY_PATH directory -- or directories, if it is a ":"-separated list. If it is not set, it searches "/usr/local/lib/gomlx/pjrt" and then the standard library directories of the system (on Linux, those in LD_LIBRARY_PATH and in the /etc/ld.so.conf file), in that order.

If there are plugins with the same name but different versions in different directories, it respects the order of the directories given by PJRT_PLUGIN_LIBRARY_PATH or by the system.

See details in pjrt.AvailablePlugins.

func New

func New(config string) (backends.Backend, error)

New returns a new Backend configured by config, which should be the name of the PJRT plugin to use.

Types

type Backend

type Backend struct {
	// contains filtered or unexported fields
}

Backend implements the XLA/PJRT backends.Backend for GoMLX.

func NewWithOptions

func NewWithOptions(config string, options pjrt.NamedValuesMap) (*Backend, error)

NewWithOptions creates an XLA Backend with the given client options. It allows more control than the default New constructor.

func (*Backend) BufferData added in v0.16.0

func (backend *Backend) BufferData(buffer backends.Buffer) (flat any, err error)

BufferData implements backends.Backend interface.

For XLA this means allocating the aligned memory and calling pjrt.Client.CreateViewOfDeviceBuffer to create a buffer that shares the memory.

func (*Backend) BufferDeviceNum

func (backend *Backend) BufferDeviceNum(buffer backends.Buffer) (backends.DeviceNum, error)

BufferDeviceNum returns the deviceNum for the buffer.

func (*Backend) BufferFinalize

func (backend *Backend) BufferFinalize(buffer backends.Buffer) error

BufferFinalize implements backends.DataInterface.

func (*Backend) BufferFromFlatData

func (backend *Backend) BufferFromFlatData(deviceNum backends.DeviceNum, flat any, shape shapes.Shape) (backends.Buffer, error)

BufferFromFlatData transfers data from Go given as a flat slice (of the type corresponding to the shape DType) to the deviceNum, and returns the corresponding Buffer.

func (*Backend) BufferShape

func (backend *Backend) BufferShape(buffer backends.Buffer) (shapes.Shape, error)

BufferShape returns the shape for the buffer.

func (*Backend) BufferToFlatData

func (backend *Backend) BufferToFlatData(buffer backends.Buffer, flat any) error

BufferToFlatData transfers the flat values of buffer to the Go flat slice. The slice flat must have the exact number of elements required to store the Buffer shape.

See also BufferFromFlatData, BufferShape, and shapes.Shape.Size.

func (*Backend) Builder

func (backend *Backend) Builder(name string) backends.Builder

Builder creates a new builder used to define a new computation.

func (*Backend) Capabilities added in v0.19.0

func (backend *Backend) Capabilities() backends.Capabilities

Capabilities returns information about what is supported by this backend.

func (*Backend) CheckValid added in v0.19.3

func (backend *Backend) CheckValid() error

CheckValid returns an error if the backend is not valid: if it's nil or has already been finalized.

func (*Backend) Description

func (backend *Backend) Description() string

Description is a longer description of the Backend that can be used to pretty-print.

func (*Backend) Finalize

func (backend *Backend) Finalize()

Finalize releases all the associated resources immediately, and makes the backend invalid.

func (*Backend) HasSharedBuffers added in v0.16.0

func (backend *Backend) HasSharedBuffers() bool

HasSharedBuffers returns whether this PJRT plugin supports "shared buffers". In PJRT that means supporting pjrt.Client.CreateViewOfDeviceBuffer.

func (*Backend) IsFinalized added in v0.20.0

func (backend *Backend) IsFinalized() bool

IsFinalized returns true if the backend is in an invalid state.

func (*Backend) Name

func (backend *Backend) Name() string

Name returns the short name of the backend, e.g.: "xla" for the XLA/PJRT plugin.

func (*Backend) NewSharedBuffer added in v0.16.0

func (backend *Backend) NewSharedBuffer(deviceNum backends.DeviceNum, shape shapes.Shape) (buffer backends.Buffer, flat any, err error)

NewSharedBuffer implements backends.Backend interface.

For XLA this means allocating the aligned memory and calling pjrt.Client.CreateViewOfDeviceBuffer to create a buffer that shares the memory.

func (*Backend) NumDevices

func (backend *Backend) NumDevices() backends.DeviceNum

NumDevices returns the number of devices available for this Backend.

func (*Backend) String added in v0.19.3

func (backend *Backend) String() string

String returns Name().

type Builder

type Builder struct {
	// contains filtered or unexported fields
}

Builder implements the backends.Builder interface using github.com/gomlx/gopjrt/xlabuilder

func (*Builder) Abs

func (b *Builder) Abs(x backends.Op) (backends.Op, error)

Abs returns the element-wise absolute value of x.

func (*Builder) Add

func (b *Builder) Add(x0, x1 backends.Op) (backends.Op, error)

Add returns the element-wise sum of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) ArgMinMax

func (b *Builder) ArgMinMax(x backends.Op, axis int, outputDType dtypes.DType, isMin bool) (backends.Op, error)

ArgMinMax calculates the "argmin" or "argmax" across an axis of the given input array x. outputDType defines the dtype of the returned indices; it doesn't need to be the same as the input dtype. It's a form of reduction on the given axis, and that axis goes away, so the rank of the result is one less than the rank of x. Examples:

ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=1, isMin=true) -> {1, 0}  // (it chooses the 0 and the -3)
ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=0, isMin=false) -> {0, 1, 0} // (it chooses the 2, 4 and 7)

func (*Builder) BatchNormForInference

func (b *Builder) BatchNormForInference(operand, scale, offset, mean, variance backends.Op, epsilon float32, axis int) (backends.Op, error)

BatchNormForInference implements Batch Norm for inference. See details in https://www.tensorflow.org/xla/operation_semantics#batchnorminference. Based on the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.

func (*Builder) BatchNormForTraining

func (b *Builder) BatchNormForTraining(operand, scale, offset backends.Op, epsilon float32, axis int) (normalized, batchMean, batchVariance backends.Op, err error)

BatchNormForTraining implements Batch Norm for training. See details in https://www.tensorflow.org/xla/operation_semantics#batchnormtraining.

It returns the normalized tensor, the batchMean and the batchVariance.

Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.

func (*Builder) BatchNormGradient

func (b *Builder) BatchNormGradient(operand, scale, mean, variance, gradOutput backends.Op, epsilon float32, axis int) (gradOperand, gradScale, gradOffset backends.Op, err error)

BatchNormGradient calculates the BatchNorm gradient. See details in https://openxla.org/xla/operation_semantics#batchnormgrad

The gradOutput is the adjoint gradient, that is, the gradient with respect to the output of the batch normalization.

It returns the three gradients: gradOperand, gradScale and gradOffset.

Based on paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.

func (*Builder) BitCount added in v0.13.0

func (b *Builder) BitCount(x backends.Op) (backends.Op, error)

BitCount returns, element-wise, the number of bits of x that are set to one.

func (*Builder) Bitcast added in v0.17.1

func (b *Builder) Bitcast(x backends.Op, targetDType dtypes.DType) (backends.Op, error)

Bitcast performs an element-wise bit-cast from one dtype to another. It doesn't "convert" anything; it just reinterprets the bits of x.DType() as targetDType. If x.DType() and targetDType use the same number of bytes (targetDType.Size() == x.DType().Size()), the dimensions are unchanged and only the dtype changes. If targetDType.Size() > x.DType().Size(), the last axis of x must have dimension targetDType.Size() / x.DType().Size(), and that last axis is removed from the returned shape. If targetDType.Size() < x.DType().Size(), the returned shape gains an extra last axis with dimension x.DType().Size() / targetDType.Size(). E.g.: Bitcast([1]uint32{0xdeadbeef}, dtypes.Uint16) -> [1][2]uint16{{0xdead, 0xbeef}}. The order of the narrower values follows the platform's in-memory representation, so on little-endian hosts the result may instead be {{0xbeef, 0xdead}}.

func (*Builder) BitwiseAnd added in v0.17.0

func (b *Builder) BitwiseAnd(x0, x1 backends.Op) (backends.Op, error)

BitwiseAnd returns the element-wise bitwise AND operation. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) BitwiseNot added in v0.17.0

func (b *Builder) BitwiseNot(x backends.Op) (backends.Op, error)

BitwiseNot returns the element-wise bitwise NOT operation.

func (*Builder) BitwiseOr added in v0.17.0

func (b *Builder) BitwiseOr(x0, x1 backends.Op) (backends.Op, error)

BitwiseOr returns the element-wise bitwise OR operation. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) BitwiseXor added in v0.17.0

func (b *Builder) BitwiseXor(x0, x1 backends.Op) (backends.Op, error)

BitwiseXor returns the element-wise bitwise XOR operation. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Broadcast

func (b *Builder) Broadcast(x backends.Op, prefixDims ...int) (backends.Op, error)

Broadcast prefixes dimensions to an array by duplicating the data in the array. See BroadcastInDim for a broadcast in between the axes. The new dimensions are inserted on the left, i.e., if prefixDims has values `{a0, ..., aN}` and the operand shape has dimensions {b0, ..., bM}, then the shape of the output has dimensions {a0, ..., aN, b0, ..., bM}. The new dimensions index into copies of the operand, i.e.

output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]

func (*Builder) BroadcastInDim

func (b *Builder) BroadcastInDim(x backends.Op, outputShape shapes.Shape, broadcastAxes []int) (backends.Op, error)

BroadcastInDim broadcasts x to an output with the given shape. broadcastAxes holds one output axis for each axis of x (len(broadcastAxes) == x.Shape.Rank()): the i-th axis of x is mapped to the broadcastAxes[i]-th axis of the output. broadcastAxes must also be increasing: this operation cannot be used to transpose axes, it only broadcasts and introduces new axes in between. It also requires that the dimension of the i-th input axis is either 1 or equal to the output dimension it is broadcast into. For example, say operand `x = (s32)[2]{1, 2}` and outputShape = `(s32)[2,2]`:

  • Specifying []int{1} as broadcastAxes will generate output {{1, 2}, {1, 2}}
  • On the other hand, specifying []int{0} as broadcastAxes will generate output {{1, 1}, {2, 2}}

func (*Builder) Ceil

func (b *Builder) Ceil(x backends.Op) (backends.Op, error)

Ceil returns the element-wise ceiling of x: the smallest integer value not less than each element.

func (*Builder) CheckValid added in v0.19.3

func (b *Builder) CheckValid() error

CheckValid returns an error if the backend or the builder is not valid -- e.g.: if either has been finalized, or the builder has already been compiled.

func (*Builder) Clz

func (b *Builder) Clz(x backends.Op) (backends.Op, error)

Clz returns, element-wise, the number of leading zero bits ("count leading zeros") of x -- for integer values.

func (*Builder) Compile

func (b *Builder) Compile(outputs ...backends.Op) (backends.Executable, error)

func (*Builder) Complex

func (b *Builder) Complex(x0, x1 backends.Op) (backends.Op, error)

Complex returns the complex number taking x0 as the real part and x1 as the imaginary part. x0 and x1 must have the same dtype, either `dtypes.Float32` or `dtypes.Float64`, and the output will be `dtypes.Complex64` or `dtypes.Complex128` respectively. The shapes of x0 and x1 must be the same, or one of them must be a scalar, in which case its value is broadcast to every element of the other. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Concatenate

func (b *Builder) Concatenate(axis int, operands ...backends.Op) (backends.Op, error)

Concatenate concatenates the operands along the given axis. All axes other than the concatenation axis must have matching dimensions. It doesn't work with scalars -- use ExpandDims first. If there is only one operand, it is returned as is (a no-op).

func (*Builder) Conj

func (b *Builder) Conj(x backends.Op) (backends.Op, error)

Conj returns the conjugate of a complex number. E.g.: Conj(1+3i) = 1-3i.

func (*Builder) Constant

func (b *Builder) Constant(flat any, dims ...int) (backends.Op, error)

Constant creates a constant in the graph with the given flat values, and the shape defined by dims.

flat must be a slice of a basic type supported -- that can be converted to a DType.

The value is copied into the graph. For very large tensors, even constant ones, it's recommended to pass them as inputs (or variables, see the context package) instead.

func (*Builder) ConvGeneral added in v0.22.0

func (b *Builder) ConvGeneral(input, kernel backends.Op, axes backends.ConvolveAxesConfig, strides []int, paddings [][2]int, inputDilations, kernelDilations []int, channelGroupCount, batchGroupCount int) (backends.Op, error)

ConvGeneral is a generic Convolution operation with support for:

  • Arbitrary number of spatial axes.
  • Arbitrary transposition of axes.
  • Strides and padding.
  • Dilations of the input.
  • Dilations of the kernel, aka. atrous convolution.
  • Channels grouping (on the input channels).
  • Batch grouping.

Some details in https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution, where operand and filter are called lhs and rhs. (The XLA documentation is unfortunately sparse, so much of this is guesswork.) Also useful: https://arxiv.org/pdf/1603.07285v1.pdf. Note:

  • Another common term for "channels" is "features".
  • "Kernel" is also commonly called "weights" or "filters".
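When reasoning about ConvGeneral shapes, the size of each spatial output axis follows the usual relation from the convolution arithmetic guide linked above, extended with dilations. convOutputDim is a hypothetical helper for shape reasoning, assuming XLA-style input dilation (inserting inputDilation-1 holes between elements); it is not code from the backend.

```go
package main

import "fmt"

// convOutputDim computes the output size of one spatial axis:
//
//	dilated input  = (in-1)*inputDilation + 1
//	dilated kernel = (k-1)*kernelDilation + 1
//	out = (dilated input + padLow + padHigh - dilated kernel) / stride + 1
//
// Integer division floors here because all terms are assumed non-negative.
func convOutputDim(in, kernel, stride, padLow, padHigh, inputDilation, kernelDilation int) int {
	dilatedIn := (in-1)*inputDilation + 1
	dilatedKernel := (kernel-1)*kernelDilation + 1
	return (dilatedIn+padLow+padHigh-dilatedKernel)/stride + 1
}

func main() {
	fmt.Println(convOutputDim(5, 3, 1, 0, 0, 1, 1)) // 3: "valid" convolution
	fmt.Println(convOutputDim(4, 3, 1, 1, 1, 1, 1)) // 4: "same" padding
	fmt.Println(convOutputDim(5, 3, 2, 1, 1, 1, 1)) // 3: stride 2
	fmt.Println(convOutputDim(7, 3, 1, 0, 0, 1, 2)) // 3: atrous, effective kernel 5
}
```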

func (*Builder) ConvertDType

func (b *Builder) ConvertDType(x backends.Op, dtype dtypes.DType) (backends.Op, error)

ConvertDType converts x to dtype.

func (*Builder) Cos

func (b *Builder) Cos(x backends.Op) (backends.Op, error)

Cos returns the element-wise cosine of x.

func (*Builder) Div

func (b *Builder) Div(x0, x1 backends.Op) (backends.Op, error)

Div returns the element-wise division of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Dot

func (b *Builder) Dot(x0, x1 backends.Op) (backends.Op, error)

Dot returns the "dot product" operation. The exact semantics of this operation depend on the ranks of the operands:

| Input                             | Output         | Semantics                    |
| vector [n] dot vector [n]         | scalar         | vector dot product           |
| matrix [m x k] dot vector [k]     | vector [m]     | matrix-vector multiplication |
| matrix [m x k] dot matrix [k x n] | matrix [m x n] | matrix-matrix multiplication |

The operation computes the sum of products over the second dimension of x0 (or the first, if it has rank 1) and the first dimension of x1. These are the "contracted" dimensions, and they must have the same size in x0 and x1. In practice, it can be used to compute dot products between vectors, vector/matrix multiplications or matrix/matrix multiplications. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) DotGeneral

func (b *Builder) DotGeneral(lhs backends.Op, lhsContractingAxes, lhsBatchAxes []int, rhs backends.Op, rhsContractingAxes, rhsBatchAxes []int) (backends.Op, error)

DotGeneral takes as input lhs (left-hand-side) and rhs (right-hand-side) specifications for a general vector product -- a generalized "Einsum". Each axis can be:

  • Just aligned (batch axes), so the output has the same axes as the inputs. The dimensions must match in lhs and rhs.
  • Crossed (default), in which case the output is the combination (concatenation) of the dimensions.
  • Contracted (contracting axes), where the values are multiplied and then summed over those axes, which do not appear in the output.

The output axes are ordered as follows: first the batch axes, then the lhs axes that are neither contracting nor batch, and finally the rhs axes that are neither contracting nor batch. This provides the basic means of implementing Einsum.

func (*Builder) DynamicSlice added in v0.11.1

func (b *Builder) DynamicSlice(operand backends.Op, startIndices []backends.Op, sliceDims []int) (backends.Op, error)

DynamicSlice extracts a sub-array from the input array at dynamic startIndices. The size of the slice in each axis is given by sliceDims, which specifies the slice interval for each axis: [start, start + size). startIndices must hold one scalar index per axis of operand (len(startIndices) == operand rank). See the description in https://openxla.org/xla/operation_semantics#dynamicslice

func (*Builder) DynamicUpdateSlice added in v0.11.1

func (b *Builder) DynamicUpdateSlice(operand, update backends.Op, startIndices []backends.Op) (backends.Op, error)

DynamicUpdateSlice generates a result which is the value of the input array operand, with the slice update overwritten at startIndices. The shape of update determines the shape of the sub-array of the result that is updated. startIndices must hold one scalar index per axis of operand (len(startIndices) == operand rank). See the description in https://openxla.org/xla/operation_semantics#dynamicupdateslice

func (*Builder) Equal

func (b *Builder) Equal(x0, x1 backends.Op) (backends.Op, error)

Equal performs element-wise equality check, returns boolean results with the same dimensions as input. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) EqualTotalOrder

func (b *Builder) EqualTotalOrder(x0, x1 backends.Op) (backends.Op, error)

EqualTotalOrder returns the element-wise operation. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Erf added in v0.12.0

func (b *Builder) Erf(x backends.Op) (backends.Op, error)

Erf returns the element-wise "error function", defined as erf(x) = 2/sqrt(Pi) * \int_{0}^{x}{e^{-t^2}dt}.

func (*Builder) Exp

func (b *Builder) Exp(x backends.Op) (backends.Op, error)

Exp returns the element-wise exponential of x, e^x.

func (*Builder) Expm1

func (b *Builder) Expm1(x backends.Op) (backends.Op, error)

Expm1 returns e^x - 1 element-wise, computed accurately even when x is close to 0.

func (*Builder) FFT

func (b *Builder) FFT(operand backends.Op, fftType backends.FFTType, fftLength []int) (backends.Op, error)

FFT calls the XLA FFT operation, which implements {Forward, Inverse} x {Complex, Real} versions. See documentation in https://www.tensorflow.org/xla/operation_semantics. Under the hood, the CPU FFT is backed by Eigen's TensorFFT and the GPU FFT uses cuFFT.

func (*Builder) Floor

func (b *Builder) Floor(x backends.Op) (backends.Op, error)

Floor returns the element-wise floor of x: the largest integer value not greater than each element.

func (*Builder) Gather

func (b *Builder) Gather(operand, startIndices backends.Op, indexVectorAxis int, offsetOutputAxes, collapsedSliceAxes, startIndexMap, sliceSizes []int, indicesAreSorted bool) (backends.Op, error)

Gather is a powerful but cumbersome Gather operation offered by XLA. Full details in https://www.tensorflow.org/xla/operation_semantics#gather (warning: the description there is circular and hard to follow). The output of Gather has the same DType as the operand, from where the data is pulled. Its shape is composed of 2 parts:

  • Batch axes: they come from the axes of startIndices, except the "indexVectorAxis" (usually the last), which holds the indices into the operand. (*)
  • "Offset axes": these are axes that come from the operand, with sizes given by sliceSizes. Notice that if sliceSizes for an axis is 1, and that axis features in the collapsedSliceAxes list, this axis is omitted from the output.

So in general output.Rank() = startIndices.Rank() - 1 + len(offsetOutputAxes). (*) One exception is if indexVectorAxis == startIndices.Rank(), in which case we assume there is an extra virtual axis in startIndices of size 1, and output.Rank() = startIndices.Rank() + len(offsetOutputAxes). Arguments:

  • operand: the values from where we are gathering. The output DType will follow the operand one.
  • startIndices: the indices we want to gather. The axis pointed to by indexVectorAxis enumerates the indices of the slice to be gathered in the operand array (their values are mapped to the operand axes according to startIndexMap). All other axes are "batch dimensions", and they will have equivalent axes (same dimensions) in the output.
  • indexVectorAxis: which axis of startIndices is collected and used as the start index for the slices to be gathered in the operand. It is typically the last axis of startIndices, so startIndices.Shape.Rank()-1. In the special case where indexVectorAxis == startIndices.Rank(), we assume there is an extra virtual axis in startIndices of size 1, and output.Rank() = startIndices.Rank() + len(offsetOutputAxes).
  • offsetOutputAxes: axes in the _output_ (not in the operand) that will hold the "offset slices", the slices that are not collapsed. They indicate at which positions (axes) of the output these slices show up. Any axis whose sliceSizes value is > 1 must feature here. Notice that every axis of the operand either becomes an "offset axis" in the output, if its slice size is > 1, or is optionally collapsed (or "squeezed") in the output, if its slice size is 1. The axes in the output (given in offsetOutputAxes) are mapped sequentially to the axes of the operand not present in collapsedSliceAxes. One must have Rank(operand) == len(collapsedSliceAxes) + len(offsetOutputAxes).
  • collapsedSliceAxes: for sliceSizes that are 1 in the operand, one may not want to include them in the output. The _operand_ axes included here are marked to be collapsed (removed) in the output. Notice that the corresponding value in sliceSizes must be 1. One must have Rank(operand) == len(collapsedSliceAxes) + len(offsetOutputAxes).
  • startIndexMap: maps which value in startIndices is used for which axis of the operand in the slice to be gathered. Notice that len(startIndexMap) must match startIndices.Shape().Dimensions[indexVectorAxis]. E.g.: if startIndices.shape=(2, 3), indexVectorAxis=1, operand.rank=4 and startIndexMap=[]int{0, 1, 2}, each row of startIndices points to the first 3 axes (0, 1 and 2) of the operand. In many cases this is [0, 1, 2, ..., operand.Shape.Rank()-1], that is, each "index vector" fully defines an element of the operand. In other cases it is only a prefix of the operand's rank. For the operand axes not explicitly set (if len(startIndexMap) < operand.Rank()), the corresponding start index is considered to be 0, and one sets sliceSizes to take the desired slice (typically the full slice).
  • sliceSizes: once the start index from where to gather is resolved, this defines how much data to gather along each axis. The "offset" output axes (see above) have dimensions equal to these sizes for the axes that are not collapsed.
  • indicesAreSorted: can be set to true if it's guaranteed by the user that startIndices are sorted (in ascending order, after scattering their values according to startIndexMap). This allows for some optimizations on some platforms.

func (*Builder) GreaterOrEqual

func (b *Builder) GreaterOrEqual(x0, x1 backends.Op) (backends.Op, error)

GreaterOrEqual performs an element-wise x0 >= x1 comparison and returns boolean results with the same dimensions as the inputs. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) GreaterOrEqualTotalOrder

func (b *Builder) GreaterOrEqualTotalOrder(x0, x1 backends.Op) (backends.Op, error)

GreaterOrEqualTotalOrder returns the element-wise comparison x0 >= x1. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) GreaterThan

func (b *Builder) GreaterThan(x0, x1 backends.Op) (backends.Op, error)

GreaterThan performs an element-wise x0 > x1 comparison and returns boolean results with the same dimensions as the inputs. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) GreaterThanTotalOrder

func (b *Builder) GreaterThanTotalOrder(x0, x1 backends.Op) (backends.Op, error)

GreaterThanTotalOrder returns the element-wise comparison x0 > x1. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Identity

func (b *Builder) Identity(x backends.Op) (backends.Op, error)

Identity returns an Op whose output is the same as its input. It's a no-op that can serve as a placeholder.

func (*Builder) Imag

func (b *Builder) Imag(x backends.Op) (backends.Op, error)

Imag returns the imaginary part of a complex number. It returns 0 if x is a (real) floating-point number.

func (*Builder) Iota

func (b *Builder) Iota(shape shapes.Shape, iotaAxis int) (backends.Op, error)

Iota creates a constant of the given shape with increasing numbers (starting from 0) on the given axis. So Iota([2,2], 1) returns [[0 1][0 1]], while Iota([2,2], 0) returns [[0 0][1 1]].
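A plain-Go sketch of the same semantics for a rank-2 shape: each element holds its own index along iotaAxis. `iota2D` is a hypothetical helper, not part of gomlx.

```go
package main

import "fmt"

// iota2D sketches Iota for a rank-2 shape [rows, cols]: every element holds
// its own index along iotaAxis (0 = rows, 1 = columns).
func iota2D(rows, cols, iotaAxis int) [][]int {
	out := make([][]int, rows)
	for i := range out {
		out[i] = make([]int, cols)
		for j := range out[i] {
			if iotaAxis == 0 {
				out[i][j] = i
			} else {
				out[i][j] = j
			}
		}
	}
	return out
}

func main() {
	fmt.Println(iota2D(2, 2, 1)) // [[0 1] [0 1]]
	fmt.Println(iota2D(2, 2, 0)) // [[0 0] [1 1]]
}
```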

func (*Builder) IsFinite added in v0.13.0

func (b *Builder) IsFinite(x backends.Op) (backends.Op, error)

IsFinite tests whether each element of operand is finite, i.e., is not positive or negative infinity, and is not NaN. It returns an array of boolean values with the same shape as the input, where each element is true if and only if the corresponding input element is finite.

func (*Builder) LessOrEqual

func (b *Builder) LessOrEqual(x0, x1 backends.Op) (backends.Op, error)

LessOrEqual performs an element-wise x0 <= x1 comparison and returns boolean results with the same dimensions as the inputs. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) LessOrEqualTotalOrder

func (b *Builder) LessOrEqualTotalOrder(x0, x1 backends.Op) (backends.Op, error)

LessOrEqualTotalOrder returns the element-wise comparison x0 <= x1. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) LessThan

func (b *Builder) LessThan(x0, x1 backends.Op) (backends.Op, error)

LessThan performs an element-wise x0 < x1 comparison and returns boolean results with the same dimensions as the inputs. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) LessThanTotalOrder

func (b *Builder) LessThanTotalOrder(x0, x1 backends.Op) (backends.Op, error)

LessThanTotalOrder returns the element-wise comparison x0 < x1. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Log

func (b *Builder) Log(x backends.Op) (backends.Op, error)

Log returns the element-wise natural logarithm of x.

func (*Builder) Log1p

func (b *Builder) Log1p(x backends.Op) (backends.Op, error)

Log1p returns the element-wise natural logarithm of x+1, that is, log(x+1).

func (*Builder) LogicalAnd added in v0.17.0

func (b *Builder) LogicalAnd(x0, x1 backends.Op) (backends.Op, error)

LogicalAnd returns the element-wise logical AND operation. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) LogicalNot

func (b *Builder) LogicalNot(x backends.Op) (backends.Op, error)

LogicalNot returns the element-wise logical NOT of x.

func (*Builder) LogicalOr added in v0.17.0

func (b *Builder) LogicalOr(x0, x1 backends.Op) (backends.Op, error)

LogicalOr returns the element-wise logical OR operation. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) LogicalXor added in v0.17.0

func (b *Builder) LogicalXor(x0, x1 backends.Op) (backends.Op, error)

LogicalXor returns the element-wise logical XOR operation. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Logistic

func (b *Builder) Logistic(x backends.Op) (backends.Op, error)

Logistic returns the element-wise expression 1/(1+exp(-x)). Also known as the Sigmoid function.

func (*Builder) Max

func (b *Builder) Max(x0, x1 backends.Op) (backends.Op, error)

Max returns the element-wise maximum of x0 and x1. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Min

func (b *Builder) Min(x0, x1 backends.Op) (backends.Op, error)

Min returns the element-wise minimum of x0 and x1. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Mul

func (b *Builder) Mul(x0, x1 backends.Op) (backends.Op, error)

Mul returns the element-wise multiplication of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Name

func (b *Builder) Name() string

Name of the computation being built.

func (*Builder) Neg

func (b *Builder) Neg(x backends.Op) (backends.Op, error)

Neg returns the element-wise negation of x, that is, -x.

func (*Builder) NotEqual

func (b *Builder) NotEqual(x0, x1 backends.Op) (backends.Op, error)

NotEqual performs an element-wise inequality check (x0 != x1) and returns boolean results with the same dimensions as the inputs. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) NotEqualTotalOrder

func (b *Builder) NotEqualTotalOrder(x0, x1 backends.Op) (backends.Op, error)

NotEqualTotalOrder returns the element-wise comparison x0 != x1. Standard broadcasting rules apply (see documentation). The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) OpShape

func (b *Builder) OpShape(op backends.Op) (shapes.Shape, error)

OpShape returns the shape of a computation Op.

func (*Builder) Pad

func (b *Builder) Pad(x, fillValue backends.Op, axesConfig ...backends.PadAxis) (backends.Op, error)

Pad injects padding at the start, at the end, or in the interior (in between each element) of x, using fillValue as the padding value. There must be at most rank(x) axesConfig values; missing PadAxis values are assumed to be zero, that is, no padding for those axes.
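A plain-Go sketch of the three kinds of padding on a rank-1 input. The `padAxis` struct is a hypothetical stand-in assumed to mirror the start/end/interior amounts carried by backends.PadAxis, and the sketch only handles non-negative padding.

```go
package main

import "fmt"

// padAxis is a hypothetical stand-in for backends.PadAxis: the amount of
// padding at the start, at the end, and between consecutive elements.
type padAxis struct {
	Start, End, Interior int
}

// pad1D sketches Pad on a rank-1 input: Start copies of fill, then the
// elements of x separated by Interior copies of fill, then End copies of fill.
// It assumes all padding amounts are non-negative.
func pad1D(x []float32, fill float32, cfg padAxis) []float32 {
	var out []float32
	for i := 0; i < cfg.Start; i++ {
		out = append(out, fill)
	}
	for i, v := range x {
		if i > 0 {
			for j := 0; j < cfg.Interior; j++ {
				out = append(out, fill)
			}
		}
		out = append(out, v)
	}
	for i := 0; i < cfg.End; i++ {
		out = append(out, fill)
	}
	return out
}

func main() {
	fmt.Println(pad1D([]float32{1, 2, 3}, 0, padAxis{Start: 1, End: 2, Interior: 1}))
	// [0 1 0 2 0 3 0 0]
}
```

The output length along an axis is Start + n + Interior*(n-1) + End, here 1 + 3 + 2 + 2 = 8.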

func (*Builder) Parameter

func (b *Builder) Parameter(name string, shape shapes.Shape) (backends.Op, error)

Parameter creates an input parameter for the computation. When the computation is executed, a value must be fed for each parameter, in the same order the parameters were created.

func (*Builder) Pow

func (b *Builder) Pow(x0, x1 backends.Op) (backends.Op, error)

Pow returns the element-wise x0 raised to the power of x1. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Real

func (b *Builder) Real(x backends.Op) (backends.Op, error)

Real returns the real part of a complex number. It returns x if x is a (real) floating-point number.

func (*Builder) ReduceBitwiseAnd added in v0.17.0

func (b *Builder) ReduceBitwiseAnd(x backends.Op, axes ...int) (backends.Op, error)

ReduceBitwiseAnd is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the bitwise/logical And of the reduced axes. If no axes are given, it reduces the full array.

func (*Builder) ReduceBitwiseOr added in v0.17.0

func (b *Builder) ReduceBitwiseOr(x backends.Op, axes ...int) (backends.Op, error)

ReduceBitwiseOr is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the bitwise/logical Or of the reduced axes. If no axes are given, it reduces the full array.

func (*Builder) ReduceBitwiseXor added in v0.17.0

func (b *Builder) ReduceBitwiseXor(x backends.Op, axes ...int) (backends.Op, error)

ReduceBitwiseXor is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the bitwise/logical Xor of the reduced axes. If no axes are given, it reduces the full array.

func (*Builder) ReduceLogicalAnd added in v0.17.0

func (b *Builder) ReduceLogicalAnd(x backends.Op, axes ...int) (backends.Op, error)

ReduceLogicalAnd is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the bitwise/logical And of the reduced axes. If no axes are given, it reduces the full array.

func (*Builder) ReduceLogicalOr added in v0.17.0

func (b *Builder) ReduceLogicalOr(x backends.Op, axes ...int) (backends.Op, error)

ReduceLogicalOr is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the bitwise/logical Or of the reduced axes. If no axes are given, it reduces the full array.

func (*Builder) ReduceLogicalXor added in v0.17.0

func (b *Builder) ReduceLogicalXor(x backends.Op, axes ...int) (backends.Op, error)

ReduceLogicalXor is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the bitwise/logical Xor of the reduced axes. If no axes are given, it reduces the full array.

func (*Builder) ReduceMax

func (b *Builder) ReduceMax(x backends.Op, axes ...int) (backends.Op, error)

ReduceMax is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the max value. If no axes are given, it reduces the full array.

func (*Builder) ReduceMin

func (b *Builder) ReduceMin(x backends.Op, axes ...int) (backends.Op, error)

ReduceMin is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the min value. If no axes are given, it reduces the full array.

func (*Builder) ReduceProduct

func (b *Builder) ReduceProduct(x backends.Op, axes ...int) (backends.Op, error)

ReduceProduct is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the product of the reduced axes. If no axes are given, it reduces the full array.

func (*Builder) ReduceSum

func (b *Builder) ReduceSum(x backends.Op, axes ...int) (backends.Op, error)

ReduceSum is a shortcut for Reduce with the proper computation and initial value to reduce x on the given axes, by taking the sum of the reduced axes. If no axes are given, it reduces the full array.
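A plain-Go sketch of ReduceSum on a rank-2 input, assuming (as in XLA's Reduce) that the reduced axis is removed from the output shape. `reduceSum2D` is a hypothetical helper; calling ReduceSum with no axes would correspond to summing over both axes into a scalar.

```go
package main

import "fmt"

// reduceSum2D sketches ReduceSum on x of shape [rows, cols] over one axis.
// axis 0 sums each column (output shape [cols]); axis 1 sums each row
// (output shape [rows]). The reduced axis disappears from the output.
func reduceSum2D(x [][]float64, axis int) []float64 {
	if axis == 0 {
		out := make([]float64, len(x[0]))
		for _, row := range x {
			for j, v := range row {
				out[j] += v
			}
		}
		return out
	}
	out := make([]float64, len(x))
	for i, row := range x {
		for _, v := range row {
			out[i] += v
		}
	}
	return out
}

func main() {
	x := [][]float64{{1, 2, 3}, {4, 5, 6}}
	fmt.Println(reduceSum2D(x, 0)) // [5 7 9]
	fmt.Println(reduceSum2D(x, 1)) // [6 15]
}
```

The other Reduce* shortcuts follow the same pattern with a different combining operation and initial value (e.g., max with -Inf, product with 1).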

func (*Builder) ReduceWindow

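func (b *Builder) ReduceWindow(x backends.Op, reductionType backends.ReduceOpType, windowDimensions, strides, baseDilations, windowDilations []int, paddings [][2]int) (backends.Op, error)

ReduceWindow applies the reduction given by reductionType over sliding windows of x. windowDimensions sets the window size along each axis, strides the step between consecutive windows, baseDilations and windowDilations the dilation applied to x and to the window respectively, and paddings the (low, high) padding added to each axis. See details in https://openxla.org/xla/operation_semantics#reducewindow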
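A plain-Go sketch of a windowed max reduction (as used in max-pooling) on a rank-1 input, assuming no padding and no dilations. `reduceWindowMax1D` is a hypothetical helper and does not use the backend's ReduceOpType values.

```go
package main

import "fmt"

// reduceWindowMax1D sketches ReduceWindow with a max reduction on a rank-1
// input, with no padding and no dilations. Window i covers
// x[i*stride : i*stride+window], and the output length is
// (len(x)-window)/stride + 1. It assumes 1 <= window <= len(x) and stride >= 1.
func reduceWindowMax1D(x []float64, window, stride int) []float64 {
	n := (len(x)-window)/stride + 1
	out := make([]float64, n)
	for i := range out {
		start := i * stride
		m := x[start]
		for _, v := range x[start+1 : start+window] {
			if v > m {
				m = v
			}
		}
		out[i] = m
	}
	return out
}

func main() {
	x := []float64{1, 3, 2, 5, 4}
	fmt.Println(reduceWindowMax1D(x, 2, 1)) // [3 3 5 5]
	fmt.Println(reduceWindowMax1D(x, 2, 2)) // [3 5]
}
```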

func (*Builder) Rem

func (b *Builder) Rem(x0, x1 backends.Op) (backends.Op, error)

Rem returns the remainder operation, also known as modulo (or Mod for short). Notice that, despite the name, XLA implements Mod, not the IEEE 754 Remainder operation. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Reshape

func (b *Builder) Reshape(x backends.Op, dimensions ...int) (backends.Op, error)

Reshape reshapes x to the new dimensions. The total size cannot change: it's just a "reinterpretation" of the same flat data. The dtype remains the same; see ConvertDType to actually convert the values.

func (*Builder) Reverse

func (b *Builder) Reverse(x backends.Op, axes ...int) (backends.Op, error)

Reverse returns x with the values along the given axes reversed, that is, the value at index `i` is swapped with the value at index `(dimension_size - 1 - i)`. The shape remains the same.

func (*Builder) RngBitGenerator

func (b *Builder) RngBitGenerator(state backends.Op, shape shapes.Shape) (newState, values backends.Op, err error)

RngBitGenerator generates the given shape filled with random bits. It takes as input the current random number generator (RNG) state, see RngState or RngStateFromSeed. The algorithm is currently hard-coded to Philox.

It returns the new state of the RNG and the generated values (with random bits) with the given shape.

func (*Builder) Round

func (b *Builder) Round(x backends.Op) (backends.Op, error)

Round returns x rounded element-wise to the nearest integer value.

func (*Builder) Rsqrt

func (b *Builder) Rsqrt(x backends.Op) (backends.Op, error)

Rsqrt returns the element-wise reciprocal of the square root, 1/sqrt(x).

func (*Builder) ScatterMax

func (b *Builder) ScatterMax(operand, scatterIndices, updates backends.Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (backends.Op, error)

ScatterMax scatters the values in updates into operand, at the positions pointed to by scatterIndices, combining them by taking the max.

func (*Builder) ScatterMin

func (b *Builder) ScatterMin(operand, scatterIndices, updates backends.Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (backends.Op, error)

ScatterMin scatters the values in updates into operand, at the positions pointed to by scatterIndices, combining them by taking the min.

func (*Builder) ScatterSum added in v0.18.0

func (b *Builder) ScatterSum(operand, scatterIndices, updates backends.Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (backends.Op, error)

ScatterSum scatters the values in updates into operand, at the positions pointed to by scatterIndices, adding them to the values already there.

func (*Builder) SelectAndScatterMax

func (b *Builder) SelectAndScatterMax(operand, source backends.Op, windowDimensions, windowStrides []int, paddings [][2]int) (backends.Op, error)

SelectAndScatterMax runs windows (similar to ReduceWindow) over the operand and selects values to update the output (like ScatterSum). It selects the values in each window such that it works as the reverse of ScatterMax. See details in https://openxla.org/xla/operation_semantics#selectandscatter
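A plain-Go sketch of the select-then-scatter pattern on a rank-1 operand, modeled on the max-pooling gradient example in the XLA operation semantics page. It assumes no padding, a zero-initialized output, and that values scattered to the same position are added; `selectAndScatterMax1D` is a hypothetical helper, not the backend's implementation.

```go
package main

import "fmt"

// selectAndScatterMax1D sketches SelectAndScatterMax on a rank-1 operand:
// for window i it finds the position of the maximum of operand in that
// window (ties keep the first position) and adds source[i] to the output
// at that position. len(source) must equal the number of windows.
func selectAndScatterMax1D(operand, source []float64, window, stride int) []float64 {
	out := make([]float64, len(operand)) // zero-initialized output
	for i, s := range source {
		start := i * stride
		best := start
		for j := start + 1; j < start+window; j++ {
			if operand[j] > operand[best] {
				best = j
			}
		}
		out[best] += s
	}
	return out
}

func main() {
	// Non-overlapping windows [1 3] and [2 5]: maxima at positions 1 and 3.
	fmt.Println(selectAndScatterMax1D([]float64{1, 3, 2, 5}, []float64{10, 20}, 2, 2)) // [0 10 0 20]
	// Overlapping windows [1 3] and [3 2] both select position 1, so values add up.
	fmt.Println(selectAndScatterMax1D([]float64{1, 3, 2}, []float64{10, 20}, 2, 1)) // [0 30 0]
}
```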

func (*Builder) SelectAndScatterMin

func (b *Builder) SelectAndScatterMin(operand, source backends.Op, windowDimensions, windowStrides []int, paddings [][2]int) (backends.Op, error)

SelectAndScatterMin runs windows (similar to ReduceWindow) over the operand and selects values to update the output (like ScatterSum). It selects the values in each window such that it works as the reverse of ScatterMin. See details in https://openxla.org/xla/operation_semantics#selectandscatter

func (*Builder) SelectAndScatterSum

func (b *Builder) SelectAndScatterSum(operand, source backends.Op, windowDimensions, windowStrides []int, paddings [][2]int) (backends.Op, error)

SelectAndScatterSum runs windows (similar to ReduceWindow) over the operand and selects values to update the output (like ScatterSum). It selects the values in each window such that it works as the reverse of ScatterSum. See details in https://openxla.org/xla/operation_semantics#selectandscatter

func (*Builder) ShiftLeft added in v0.17.0

func (b *Builder) ShiftLeft(x0, x1 backends.Op) (backends.Op, error)

ShiftLeft shifts x0 left by x1 bits. It implicitly preserves the sign bit if there is no overflow, so ShiftLeft(-1, 1) = -2. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) ShiftRightArithmetic added in v0.17.0

func (b *Builder) ShiftRightArithmetic(x0, x1 backends.Op) (backends.Op, error)

ShiftRightArithmetic shifts x0 right by x1 bits, preserving the sign bit, so ShiftRightArithmetic(-2, 1) = -1. The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) ShiftRightLogical added in v0.17.0

func (b *Builder) ShiftRightLogical(x0, x1 backends.Op) (backends.Op, error)

ShiftRightLogical shifts x0 right by x1 bits, filling the vacated high bits with zeros, so the sign bit is not preserved. The op is created on the same XlaBuilder as used for x0 and x1.
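The three shift flavors map directly onto Go's shift operators on 32-bit integers: `<<` for ShiftLeft, `>>` on a signed type for ShiftRightArithmetic, and `>>` on the unsigned reinterpretation for ShiftRightLogical. This plain-Go sketch only illustrates the bit-level behavior described above.

```go
package main

import "fmt"

func main() {
	var x int32 = -1
	fmt.Println(x << 1) // -2, like ShiftLeft(-1, 1)

	var y int32 = -2
	// Arithmetic shift: >> on a signed integer replicates the sign bit.
	fmt.Println(y >> 1) // -1, like ShiftRightArithmetic(-2, 1)
	// Logical shift: reinterpret as unsigned so zeros are shifted in.
	fmt.Println(int32(uint32(y) >> 1)) // 2147483647
}
```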

func (*Builder) Sign

func (b *Builder) Sign(x backends.Op) (backends.Op, error)

Sign returns element-wise +1, +/-0 or -1 depending on the sign of x. It returns NaN if the input is NaN.

func (*Builder) Sin

func (b *Builder) Sin(x backends.Op) (backends.Op, error)

Sin returns the element-wise sine of x.

func (*Builder) Slice

func (b *Builder) Slice(x backends.Op, starts, limits, strides []int) (backends.Op, error)

Slice extracts a sub-array from the input array. The sub-array has the same rank as the input and contains the values inside a bounding box within the input array, whose start (inclusive) and limit (exclusive) indices are given by starts and limits. The strides set the input stride of the slice in each axis and must be >= 1. strides is optional: if missing (nil), it is assumed to be 1 for every axis. Examples:

Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={4}, strides=nil) -> {2, 3}
Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={5}, strides={2}) -> {2, 4}
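The two examples above can be reproduced with a plain-Go sketch of a rank-1 Slice. `slice1D` is a hypothetical helper; it uses a stride of 0 to stand in for a missing (nil) strides argument.

```go
package main

import "fmt"

// slice1D sketches Slice on a rank-1 input: it collects x[start],
// x[start+stride], ... for every index strictly below limit.
// A stride below 1 stands in for a missing stride and is treated as 1.
func slice1D(x []int, start, limit, stride int) []int {
	if stride < 1 {
		stride = 1
	}
	var out []int
	for i := start; i < limit; i += stride {
		out = append(out, x[i])
	}
	return out
}

func main() {
	x := []int{0, 1, 2, 3, 4}
	fmt.Println(slice1D(x, 2, 4, 0)) // [2 3]
	fmt.Println(slice1D(x, 2, 5, 2)) // [2 4]
}
```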

func (*Builder) Sqrt

func (b *Builder) Sqrt(x backends.Op) (backends.Op, error)

Sqrt returns the element-wise square root of x.

func (*Builder) Sub

func (b *Builder) Sub(x0, x1 backends.Op) (backends.Op, error)

Sub returns the element-wise subtraction of the two values. Standard broadcasting rules apply (see documentation). The op is created on the same XlaBuilder as used for x0 and x1.

func (*Builder) Tanh

func (b *Builder) Tanh(x backends.Op) (backends.Op, error)

Tanh returns the element-wise hyperbolic tangent of x.

func (*Builder) Transpose

func (b *Builder) Transpose(x backends.Op, permutations ...int) (backends.Op, error)

Transpose permutes the axes of x. There should be one value in permutations for each axis in x. The output will have: output.Shape.Dimensions[i] = x.Shape.Dimensions[permutations[i]].

func (*Builder) Where

func (b *Builder) Where(condition, onTrue, onFalse backends.Op) (backends.Op, error)

Where takes element-wise values from onTrue or onFalse depending on the value of condition (expected to be boolean).

type Executable

type Executable struct {
	// contains filtered or unexported fields
}

Executable implements backends.Executable for XLA/PJRT, using github.com/gomlx/gopjrt.

func (*Executable) CheckValid added in v0.19.3

func (e *Executable) CheckValid() error

CheckValid returns an error if the backend or the executable is not valid, e.g.: if they have been finalized or the builder has already been compiled.

func (*Executable) Execute

func (e *Executable) Execute(inputs []backends.Buffer, donate []bool) ([]backends.Buffer, error)

Execute runs the executable on the default device (0). The number and shapes of the inputs must match those returned by Inputs.

func (*Executable) Finalize

func (e *Executable) Finalize()

Finalize immediately frees the resources associated with the executable.

func (*Executable) Inputs

func (e *Executable) Inputs() (names []string, inputShapes []shapes.Shape)

Inputs returns the parameters' names and shapes, in the order they were created by the Builder.Parameter calls.

func (*Executable) Outputs

func (e *Executable) Outputs() (outputShapes []shapes.Shape)

Outputs returns the computation's output shapes, in the order given to the Builder.Compile call.

Directories

Path Synopsis
cpu
dynamic
Package dynamic links the XLA/PJRT CPU plugin dynamically (as in ".so" libraries) with your binary.
static
Package static links the XLA/PJRT CPU plugin statically with your binary.
