stablehlo

package
v0.23.0
Warning

This package is not in the latest version of its module.

Published: Sep 21, 2025 License: Apache-2.0 Imports: 22 Imported by: 0

Documentation

Overview

Package stablehlo implements a GoMLX backend using StableHLO (see github.com/gomlx/stablehlo) as a language to talk to PJRT, the C++ engine for XLA (see github.com/gomlx/gopjrt/pjrt).

The backend is registered as "stablehlo", "shlo" or "hlo" (all aliases to the same backend).

Constants

const BackendName = "stablehlo"

BackendName is the name of the backend.

The stablehlo backend also accepts the "hlo" and "pjrt" aliases.

Variables

var Capabilities = backends.Capabilities{
	Operations: map[backends.OpType]bool{

		backends.OpTypeParameter: true,
		backends.OpTypeConstant:  true,

		backends.OpTypeAbs:        true,
		backends.OpTypeBitCount:   true,
		backends.OpTypeBitwiseNot: true,
		backends.OpTypeCeil:       true,
		backends.OpTypeClz:        true,
		backends.OpTypeCos:        true,
		backends.OpTypeErf:        true,
		backends.OpTypeExp:        true,
		backends.OpTypeExpm1:      true,
		backends.OpTypeFloor:      true,
		backends.OpTypeIsFinite:   true,
		backends.OpTypeIsNaN:      true,
		backends.OpTypeLog1p:      true,
		backends.OpTypeLog:        true,
		backends.OpTypeLogicalNot: true,
		backends.OpTypeLogistic:   true,
		backends.OpTypeNeg:        true,
		backends.OpTypeRound:      true,
		backends.OpTypeRsqrt:      true,
		backends.OpTypeSign:       true,
		backends.OpTypeSin:        true,
		backends.OpTypeSqrt:       true,
		backends.OpTypeTanh:       true,

		backends.OpTypeAdd:        true,
		backends.OpTypeBitwiseAnd: true,
		backends.OpTypeBitwiseOr:  true,
		backends.OpTypeBitwiseXor: true,
		backends.OpTypeDiv:        true,
		backends.OpTypeLogicalAnd: true,
		backends.OpTypeLogicalOr:  true,
		backends.OpTypeLogicalXor: true,
		backends.OpTypeMax:        true,
		backends.OpTypeMin:        true,
		backends.OpTypeMul:        true,
		backends.OpTypePow:        true,
		backends.OpTypeRem:        true,
		backends.OpTypeSub:        true,

		backends.OpTypeEqual:                    true,
		backends.OpTypeEqualTotalOrder:          true,
		backends.OpTypeGreaterOrEqual:           true,
		backends.OpTypeGreaterOrEqualTotalOrder: true,
		backends.OpTypeGreaterThan:              true,
		backends.OpTypeGreaterThanTotalOrder:    true,
		backends.OpTypeLessOrEqual:              true,
		backends.OpTypeLessOrEqualTotalOrder:    true,
		backends.OpTypeLessThan:                 true,
		backends.OpTypeLessThanTotalOrder:       true,
		backends.OpTypeNotEqual:                 true,
		backends.OpTypeNotEqualTotalOrder:       true,

		backends.OpTypeComplex: true,
		backends.OpTypeConj:    true,
		backends.OpTypeImag:    true,
		backends.OpTypeReal:    true,

		backends.OpTypeArgMinMax:             true,
		backends.OpTypeBatchNormForInference: true,
		backends.OpTypeBatchNormForTraining:  true,
		backends.OpTypeBatchNormGradient:     true,
		backends.OpTypeBitcast:               true,
		backends.OpTypeBroadcast:             true,
		backends.OpTypeBroadcastInDim:        true,
		backends.OpTypeClamp:                 true,
		backends.OpTypeConcatenate:           true,
		backends.OpTypeConvertDType:          true,
		backends.OpTypeConvGeneral:           true,
		backends.OpTypeDynamicSlice:          true,
		backends.OpTypeDynamicUpdateSlice:    true,
		backends.OpTypeDot:                   true,
		backends.OpTypeDotGeneral:            true,
		backends.OpTypeFFT:                   true,
		backends.OpTypeGather:                true,
		backends.OpTypeIdentity:              true,
		backends.OpTypeIota:                  true,
		backends.OpTypePad:                   true,
		backends.OpTypeReduceBitwiseAnd:      true,
		backends.OpTypeReduceBitwiseOr:       true,
		backends.OpTypeReduceBitwiseXor:      true,
		backends.OpTypeReduceLogicalAnd:      true,
		backends.OpTypeReduceLogicalOr:       true,
		backends.OpTypeReduceLogicalXor:      true,
		backends.OpTypeReduceMax:             true,
		backends.OpTypeReduceMin:             true,
		backends.OpTypeReduceProduct:         true,
		backends.OpTypeReduceSum:             true,
		backends.OpTypeReduceWindow:          true,
		backends.OpTypeReshape:               true,
		backends.OpTypeReverse:               true,
		backends.OpTypeRngBitGenerator:       true,
		backends.OpTypeScatterSum:            true,
		backends.OpTypeScatterMax:            true,
		backends.OpTypeScatterMin:            true,
		backends.OpTypeSelectAndScatterMax:   true,
		backends.OpTypeSelectAndScatterMin:   true,
		backends.OpTypeShiftLeft:             true,
		backends.OpTypeShiftRightArithmetic:  true,
		backends.OpTypeShiftRightLogical:     true,
		backends.OpTypeSlice:                 true,
		backends.OpTypeTranspose:             true,
		backends.OpTypeWhere:                 true,
	},

	DTypes: map[dtypes.DType]bool{
		dtypes.Bool:       true,
		dtypes.Int8:       true,
		dtypes.Int16:      true,
		dtypes.Int32:      true,
		dtypes.Int64:      true,
		dtypes.Uint8:      true,
		dtypes.Uint16:     true,
		dtypes.Uint32:     true,
		dtypes.Uint64:     true,
		dtypes.Float32:    true,
		dtypes.Float64:    true,
		dtypes.BFloat16:   true,
		dtypes.Complex64:  true,
		dtypes.Complex128: true,
	},
}

Capabilities of the StableHLO backend: the set of supported operations and data types.

var (
	// DefaultPlugins is the list of plugins to use in preference order, if not otherwise specified.
	DefaultPlugins = []string{"cuda", "cpu"}
)

Functions

func GetAvailablePlugins

func GetAvailablePlugins() []string

GetAvailablePlugins lists the available platforms -- it caches and reuses the result in future calls.

Plugins are searched in the PJRT_PLUGIN_LIBRARY_PATH directory -- or directories if it is a ":" separated list. If it is not set, the search covers "/usr/local/lib/gomlx/pjrt" and the standard library directories of the system (on Linux, LD_LIBRARY_PATH and the /etc/ld.so.conf file), in that order.

If there are plugins with the same name but different versions in different directories, it respects the order of the directories given by PJRT_PLUGIN_LIBRARY_PATH or by the system.

See details in pjrt.AvailablePlugins.
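
As a rough illustration of the search order described above, the sketch below splits a ":"-separated value and falls back to the first documented default directory. The helper name pluginSearchDirs is hypothetical (it is not part of this package or of pjrt), and the system library directories are left out for brevity.

```go
package main

import (
	"fmt"
	"strings"
)

// defaultPluginDir is the first fallback directory named in the documentation above.
const defaultPluginDir = "/usr/local/lib/gomlx/pjrt"

// pluginSearchDirs is a hypothetical helper (not part of this package): it
// splits a PJRT_PLUGIN_LIBRARY_PATH-style value on ":" and preserves order.
// For an empty value it returns only defaultPluginDir; the system library
// directories are intentionally omitted from this sketch.
func pluginSearchDirs(envValue string) []string {
	if envValue == "" {
		return []string{defaultPluginDir}
	}
	var dirs []string
	for _, dir := range strings.Split(envValue, ":") {
		if dir != "" { // skip empty entries such as a trailing ":"
			dirs = append(dirs, dir)
		}
	}
	return dirs
}

func main() {
	fmt.Println(pluginSearchDirs("/opt/pjrt:/home/me/pjrt"))
	fmt.Println(pluginSearchDirs(""))
}
```

Because the order is preserved, a plugin found in an earlier directory would take precedence over one with the same name in a later directory.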

func New

func New(config string) (backends.Backend, error)

New returns a new Backend using the config as a configuration. The config string should be the name of the PJRT plugin to use.

func ShapeFromStableHLO

func ShapeFromStableHLO(shape stablehloshapes.Shape) shapes.Shape

ShapeFromStableHLO converts a StableHLO shape to a GoMLX shape.

func ShapeToStableHLO

func ShapeToStableHLO(shape shapes.Shape) stablehloshapes.Shape

ShapeToStableHLO converts a GoMLX shape to a StableHLO shape.

Types

type Backend

type Backend struct {
	DotGeneralConfig
	// contains filtered or unexported fields
}

Backend implements the XLA/PJRT backends.Backend for GoMLX.

func NewWithOptions

func NewWithOptions(config string, options pjrt.NamedValuesMap) (*Backend, error)

NewWithOptions creates a StableHLO backend with the given client options. It allows more control, not available with the default New constructor.

func (*Backend) BufferData

func (backend *Backend) BufferData(buffer backends.Buffer) (flat any, err error)

BufferData implements backends.Backend interface.

For XLA this means allocating the aligned memory and calling pjrt.Client.CreateViewOfDeviceBuffer to create a buffer that shares the memory.

func (*Backend) BufferDeviceNum

func (backend *Backend) BufferDeviceNum(buffer backends.Buffer) (backends.DeviceNum, error)

BufferDeviceNum returns the deviceNum for the buffer.

func (*Backend) BufferFinalize

func (backend *Backend) BufferFinalize(buffer backends.Buffer) error

BufferFinalize implements backends.DataInterface.

func (*Backend) BufferFromFlatData

func (backend *Backend) BufferFromFlatData(deviceNum backends.DeviceNum, flat any, shape shapes.Shape) (backends.Buffer, error)

BufferFromFlatData transfers data from Go given as a flat slice (of the type corresponding to the shape DType) to the deviceNum, and returns the corresponding Buffer.

func (*Backend) BufferShape

func (backend *Backend) BufferShape(buffer backends.Buffer) (shapes.Shape, error)

BufferShape returns the shape for the buffer.

func (*Backend) BufferToFlatData

func (backend *Backend) BufferToFlatData(buffer backends.Buffer, flat any) error

BufferToFlatData transfers the flat values of buffer to the Go flat array. The slice flat must have the exact number of elements required to store the Buffer shape.

See also BufferFromFlatData, BufferShape, and shapes.Shape.Size.
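
The "exact number of elements" requirement can be sketched as the product of the shape's dimensions. The helper flatSize below is hypothetical and only mirrors what a shape size is assumed to compute here; it does not call this package.

```go
package main

import "fmt"

// flatSize returns the number of elements of a tensor with the given
// dimensions: the product of all dimensions (1 for a scalar).
func flatSize(dimensions ...int) int {
	size := 1
	for _, d := range dimensions {
		size *= d
	}
	return size
}

func main() {
	// A Float32 buffer shaped [2, 3] would need a []float32 of length 6.
	flat := make([]float32, flatSize(2, 3))
	fmt.Println(len(flat))
}
```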

func (*Backend) Builder

func (backend *Backend) Builder(name string) backends.Builder

Builder creates a new builder used to define a new computation.

func (*Backend) Capabilities

func (backend *Backend) Capabilities() backends.Capabilities

Capabilities returns information about what is supported by this backend.

func (*Backend) CheckValid

func (backend *Backend) CheckValid() error

CheckValid returns an error if the backend is not valid: if it's nil or has already been finalized.

func (*Backend) Description

func (backend *Backend) Description() string

Description is a longer description of the Backend that can be used to pretty-print.

func (*Backend) Finalize

func (backend *Backend) Finalize()

Finalize releases all the associated resources immediately, and makes the backend invalid.

func (*Backend) HasSharedBuffers

func (backend *Backend) HasSharedBuffers() bool

HasSharedBuffers returns whether this PJRT plugin supports "shared buffers". In PJRT that means supporting pjrt.Client.CreateViewOfDeviceBuffer.

func (*Backend) IsFinalized

func (backend *Backend) IsFinalized() bool

IsFinalized returns true if the backend is in an invalid state.

func (*Backend) Name

func (backend *Backend) Name() string

Name returns the short name of the backend. E.g.: "stablehlo" for the StableHLO/PJRT plugin.

func (*Backend) NewSharedBuffer

func (backend *Backend) NewSharedBuffer(deviceNum backends.DeviceNum, shape shapes.Shape) (buffer backends.Buffer, flat any, err error)

NewSharedBuffer implements backends.Backend interface.

For XLA this means allocating the aligned memory and calling pjrt.Client.CreateViewOfDeviceBuffer to create a buffer that shares the memory.

func (*Backend) NumDevices

func (backend *Backend) NumDevices() backends.DeviceNum

NumDevices returns the number of devices available for this Backend.

func (*Backend) String

func (backend *Backend) String() string

String returns Name().

type Builder

type Builder struct {
	notimplemented.Builder
	// contains filtered or unexported fields
}

Builder keeps track of the computation graph being defined.

func (*Builder) Abs

func (b *Builder) Abs(operand backends.Op) (backends.Op, error)

Abs returns the Op that represents the output of the corresponding operation.

It is special-cased here because StableHLO doesn't define the Abs() of complex numbers.

func (*Builder) Add

func (b *Builder) Add(lhs, rhs backends.Op) (backends.Op, error)

Add returns the element-wise sum of the two values. Standard broadcasting rules apply (see documentation).

func (*Builder) ArgMinMax

func (b *Builder) ArgMinMax(x backends.Op, axis int, outputDType dtypes.DType, isMin bool) (backends.Op, error)

ArgMinMax calculates the "argmin" or "argmax" across an axis of the given input array x.

outputDType defines the dtype of the argmin/argmax output; it doesn't need to be the same as the input dtype. ArgMinMax is a form of reduction on the given axis, and that axis goes away, so the rank of the result is one less than the rank of x.

Examples:

ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=1, isMin=true) -> {1, 0}  // (it chooses the 0 and the -3)
ArgMinMax(x={{2, 0, 7}, {-3, 4, 2}}, axis=0, isMin=false) -> {0, 1, 0} // (it chooses the 2, 4 and 7)
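
The two examples can be reproduced with a plain-Go reference for rank-2 inputs. This is only a sketch of the semantics, not the backend implementation; argMinMax2D is a hypothetical name, and resolving ties to the first index is an assumption of this sketch.

```go
package main

import "fmt"

// argMinMax2D is a plain-Go reference for a rank-2 input. It reduces over
// axis (0 or 1) and returns, for each remaining position, the index of the
// smallest (isMin) or largest value along that axis.
// Ties resolve to the first index in this sketch.
func argMinMax2D(x [][]float64, axis int, isMin bool) []int {
	rows, cols := len(x), len(x[0])
	outer, inner := rows, cols // axis == 1: one result per row
	if axis == 0 {
		outer, inner = cols, rows // axis == 0: one result per column
	}
	result := make([]int, outer)
	for o := 0; o < outer; o++ {
		// at returns the i-th element along the reduced axis.
		at := func(i int) float64 {
			if axis == 0 {
				return x[i][o]
			}
			return x[o][i]
		}
		best := 0
		for i := 1; i < inner; i++ {
			if (isMin && at(i) < at(best)) || (!isMin && at(i) > at(best)) {
				best = i
			}
		}
		result[o] = best
	}
	return result
}

func main() {
	x := [][]float64{{2, 0, 7}, {-3, 4, 2}}
	fmt.Println(argMinMax2D(x, 1, true))  // argmin along axis 1
	fmt.Println(argMinMax2D(x, 0, false)) // argmax along axis 0
}
```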

func (*Builder) BatchNormForInference

func (b *Builder) BatchNormForInference(input, scale, offset, mean, variance backends.Op, epsilon float32, featureAxis int) (backends.Op, error)

BatchNormForInference implements backends.Builder interface.

func (*Builder) BatchNormForTraining

func (b *Builder) BatchNormForTraining(input, scale, offset backends.Op, epsilon float32, featureAxis int) (output, batchMean, batchVar backends.Op, err error)

BatchNormForTraining implements backends.Builder interface.

func (*Builder) BatchNormGradient

func (b *Builder) BatchNormGradient(gradOutput, input, scale, mean, variance backends.Op, epsilon float32, featureAxis int) (gradInput, gradScale, gradOffset backends.Op, err error)

BatchNormGradient implements backends.Builder interface.

func (*Builder) BitCount

func (b *Builder) BitCount(operand backends.Op) (backends.Op, error)

BitCount returns the number of bits that are set to one. Also known as Population Count ("Popcnt") or Hamming Weight.
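
For intuition, the same element-wise population count can be computed in plain Go with the standard math/bits package. The helper bitCount is a hypothetical name and covers uint32 values only.

```go
package main

import (
	"fmt"
	"math/bits"
)

// bitCount applies a population count to every element, mirroring the
// element-wise semantics described above for uint32 values.
func bitCount(values []uint32) []int {
	counts := make([]int, len(values))
	for i, v := range values {
		counts[i] = bits.OnesCount32(v)
	}
	return counts
}

func main() {
	fmt.Println(bitCount([]uint32{0, 1, 0b1011, 0xFFFFFFFF}))
}
```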

func (*Builder) Bitcast

func (b *Builder) Bitcast(x backends.Op, targetDType dtypes.DType) (backends.Op, error)

Bitcast implements backends.Builder interface.

func (*Builder) BitwiseAnd

func (b *Builder) BitwiseAnd(lhs, rhs backends.Op) (backends.Op, error)

BitwiseAnd returns the element-wise bitwise AND operation.

func (*Builder) BitwiseNot

func (b *Builder) BitwiseNot(operand backends.Op) (backends.Op, error)

BitwiseNot returns the element-wise bitwise NOT operation.

func (*Builder) BitwiseOr

func (b *Builder) BitwiseOr(lhs, rhs backends.Op) (backends.Op, error)

BitwiseOr returns the element-wise bitwise OR operation.

func (*Builder) BitwiseXor

func (b *Builder) BitwiseXor(lhs, rhs backends.Op) (backends.Op, error)

BitwiseXor returns the element-wise bitwise XOR operation.

func (*Builder) Broadcast

func (b *Builder) Broadcast(x backends.Op, prefixDims ...int) (backends.Op, error)

Broadcast implements the backends.Builder interface.

func (*Builder) BroadcastInDim

func (b *Builder) BroadcastInDim(x backends.Op, outputShape shapes.Shape, broadcastAxes []int) (backends.Op, error)

BroadcastInDim broadcasts x to an output with the given shape. broadcastAxes has one output axis for each axis of x (len(broadcastAxes) == x.Shape.Rank()). The i-th axis of x is mapped to the broadcastAxes[i]-th axis of the output. broadcastAxes must also be increasing: this operation cannot be used to transpose axes, it only broadcasts and introduces new axes in between. The dimension of the i-th axis of x must be either 1 or the same as the output dimension it is broadcast into. For example, say operand `x = (s32)[2]{1, 2}`; outputShape = `(s32)[2,2]`:

  • Specifying []int{1} as broadcastAxes will generate output {{1, 2}, {1, 2}}
  • On the other hand, specifying []int{0} as broadcastAxes will generate output {{1 , 1}, {2 , 2}}
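
The rank-1 to rank-2 example above can be sketched in plain Go. broadcastVectorTo2D is a hypothetical helper that handles only this special case, with outputAxis standing in for broadcastAxes[0].

```go
package main

import "fmt"

// broadcastVectorTo2D is a plain-Go reference for the rank-1 to rank-2 case
// only. outputAxis plays the role of broadcastAxes[0]: the single axis of x
// is mapped to that axis of the [rows, cols] output.
func broadcastVectorTo2D(x []int, rows, cols, outputAxis int) [][]int {
	out := make([][]int, rows)
	for r := range out {
		out[r] = make([]int, cols)
		for c := range out[r] {
			idx := c // outputAxis == 1: x runs along the columns
			if outputAxis == 0 {
				idx = r // outputAxis == 0: x runs along the rows
			}
			if len(x) == 1 {
				idx = 0 // an axis of dimension 1 is broadcast
			}
			out[r][c] = x[idx]
		}
	}
	return out
}

func main() {
	x := []int{1, 2}
	fmt.Println(broadcastVectorTo2D(x, 2, 2, 1))
	fmt.Println(broadcastVectorTo2D(x, 2, 2, 0))
}
```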

func (*Builder) Ceil

func (b *Builder) Ceil(operand backends.Op) (backends.Op, error)

Ceil returns the Op that represents the output of the corresponding operation.

func (*Builder) CheckValid

func (b *Builder) CheckValid() error

CheckValid returns an error if the backend or the builder are not ok.

E.g.: they have been finalized or the builder has already been compiled.

func (*Builder) Clamp

func (b *Builder) Clamp(min, a backends.Op, max backends.Op) (backends.Op, error)

Clamp returns the element-wise clamping operation.

The values max and min can either be scalars or have the same shape as a.

func (*Builder) Clz

func (b *Builder) Clz(operand backends.Op) (backends.Op, error)

Clz returns the element-wise "count leading zeros" bits of the operand -- for integer values.
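
As a point of comparison, the standard math/bits package computes the same count for a single value. The sketch below applies it element-wise to uint32 values; clz is a hypothetical helper name.

```go
package main

import (
	"fmt"
	"math/bits"
)

// clz counts the leading zero bits of every element. For a zero value the
// count equals the bit width (32 here).
func clz(values []uint32) []int {
	counts := make([]int, len(values))
	for i, v := range values {
		counts[i] = bits.LeadingZeros32(v)
	}
	return counts
}

func main() {
	fmt.Println(clz([]uint32{1, 0x80000000, 0}))
}
```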

func (*Builder) Compile

func (b *Builder) Compile(outputs ...backends.Op) (backends.Executable, error)

func (*Builder) Complex

func (b *Builder) Complex(lhs, rhs backends.Op) (backends.Op, error)

Complex returns the complex number taking lhs as the real part and rhs as the imaginary part. The real (lhs) and imaginary (rhs) parts must have the same dtype, and they must be either `dtypes.Float32` or `dtypes.Float64`. The output will be either `dtypes.Complex64` or `dtypes.Complex128`, depending on the dtype of lhs and rhs. The shapes of lhs and rhs must be the same, or one must be a scalar, in which case the value is broadcast to every other value.

func (*Builder) Concatenate

func (b *Builder) Concatenate(axis int, operands ...backends.Op) (backends.Op, error)

Concatenate operands on the given axis.

All axes that are not being concatenated must have matching dimensions. It doesn't work with scalars -- use ExpandAxes. If there is only one operand, it is returned and this is a no-op.

func (*Builder) Conj

func (b *Builder) Conj(operand backends.Op) (backends.Op, error)

Conj returns the conjugate of a complex number. E.g.: Conj(1+3i) = 1-3i

func (*Builder) Constant

func (b *Builder) Constant(flat any, dimensions ...int) (backends.Op, error)

Constant creates a constant in the graph with the given flat values and the shape defined by the dimensions.

The flat value must be a slice of a basic type supported -- that can be converted to a DType.

The value is copied into the graph. It's recommended that very large tensors, even if constant, be passed as side inputs (or variables, see context package) instead.

func (*Builder) ConvGeneral

func (b *Builder) ConvGeneral(input, kernel backends.Op,
	axes backends.ConvolveAxesConfig, strides []int, paddings [][2]int,
	inputDilations, kernelDilations []int, channelGroupCount, batchGroupCount int) (backends.Op, error)

ConvGeneral implements the backends.Builder interface.

func (*Builder) ConvertDType

func (b *Builder) ConvertDType(x backends.Op, dtype dtypes.DType) (backends.Op, error)

ConvertDType implements backends.Builder interface.

func (*Builder) Cos

func (b *Builder) Cos(operand backends.Op) (backends.Op, error)

Cos returns the Op that represents the output of the corresponding operation.

func (*Builder) Div

func (b *Builder) Div(lhs, rhs backends.Op) (backends.Op, error)

Div returns the element-wise division of the two values. Standard broadcasting rules apply (see documentation).

func (*Builder) Dot

func (b *Builder) Dot(lhs, rhs backends.Op) (backends.Op, error)

Dot returns the "dot product" operation. The exact semantics of this operation depend on the ranks of the operands:

| Input                             | Output         | Semantics                    |
| vector [n] dot vector [n]         | scalar         | vector dot product           |
| matrix [m x k] dot vector [k]     | vector [m]     | matrix-vector multiplication |
| matrix [m x k] dot matrix [k x n] | matrix [m x n] | matrix-matrix multiplication |

The operation performs a sum of products over the second dimension of lhs (or the first if it has rank 1) and the first dimension of rhs. These are the "contracted" dimensions. The contracted dimensions of lhs and rhs must be of the same size. In practice, it can be used to perform dot products between vectors, vector/matrix multiplications, or matrix/matrix multiplications.
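
The three cases can be written out in plain Go to make the contracted dimensions explicit. This is a reference sketch of the semantics only; the function names are hypothetical and nothing here calls the backend.

```go
package main

import "fmt"

// dotVectors returns the sum of products of two vectors of equal length.
func dotVectors(a, b []float64) float64 {
	sum := 0.0
	for i := range a {
		sum += a[i] * b[i]
	}
	return sum
}

// dotMatrixVector contracts the second axis of m ([rows x k]) with v ([k]).
func dotMatrixVector(m [][]float64, v []float64) []float64 {
	out := make([]float64, len(m))
	for r, row := range m {
		out[r] = dotVectors(row, v)
	}
	return out
}

// dotMatrices contracts the second axis of a ([m x k]) with the first axis
// of b ([k x n]), producing an [m x n] result.
func dotMatrices(a, b [][]float64) [][]float64 {
	out := make([][]float64, len(a))
	for i := range a {
		out[i] = make([]float64, len(b[0]))
		for j := range out[i] {
			for k := range b {
				out[i][j] += a[i][k] * b[k][j]
			}
		}
	}
	return out
}

func main() {
	fmt.Println(dotVectors([]float64{1, 2, 3}, []float64{4, 5, 6}))
	m := [][]float64{{1, 2}, {3, 4}}
	fmt.Println(dotMatrixVector(m, []float64{1, 1}))
	fmt.Println(dotMatrices(m, [][]float64{{0, 1}, {1, 0}}))
}
```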

func (*Builder) DotGeneral

func (b *Builder) DotGeneral(lhs backends.Op, lhsContractingAxes, lhsBatchAxes []int, rhs backends.Op, rhsContractingAxes []int, rhsBatchAxes []int) (backends.Op, error)

DotGeneral takes as input lhs (left-hand-side) and rhs (right-hand-side) specifications for a general vector product -- a generalized "Einsum". Each axis can be:

  • Just aligned (batch axes), so the output has the same axes as the inputs. The dimensions must match in lhs and rhs.
  • Crossed (default), in which case the output is the combination (concatenation) of the dimensions.
  • Contracted (contracting axes), where the output multiplies the values and reduce-sums over those dimensions.

It follows that the resulting dimension number starts with the batch dimension, then the 'lhs' non-contracting/non-batch dimension, and finally the 'rhs' non-contracting/non-batch dimension. It provides the basic means of implementing Einsum.
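
A batched matrix multiplication is a common instance: with lhs shaped [B, M, K] and rhs shaped [B, K, N], axis 0 is the batch axis on both sides, lhs axis 2 and rhs axis 1 are contracted, and the output is [B, M, N] (batch, then lhs free axis, then rhs free axis). The plain-Go sketch below, with the hypothetical name batchMatMul, spells this out; it is not the backend implementation.

```go
package main

import "fmt"

// batchMatMul is a plain-Go reference for DotGeneral-style semantics with
// lhsBatchAxes = {0}, rhsBatchAxes = {0}, lhsContractingAxes = {2} and
// rhsContractingAxes = {1}. The output axes are [batch, lhs free, rhs free].
func batchMatMul(lhs, rhs [][][]float64) [][][]float64 {
	out := make([][][]float64, len(lhs))
	for b := range lhs { // batch axis: aligned between lhs and rhs
		out[b] = make([][]float64, len(lhs[b]))
		for m := range lhs[b] { // lhs free ("crossed") axis
			out[b][m] = make([]float64, len(rhs[b][0]))
			for n := range out[b][m] { // rhs free ("crossed") axis
				for k := range lhs[b][m] { // contracted axis: multiply and sum
					out[b][m][n] += lhs[b][m][k] * rhs[b][k][n]
				}
			}
		}
	}
	return out
}

func main() {
	lhs := [][][]float64{{{1, 2}}, {{3, 4}}}     // shape [2, 1, 2]
	rhs := [][][]float64{{{5}, {6}}, {{7}, {8}}} // shape [2, 2, 1]
	fmt.Println(batchMatMul(lhs, rhs))           // shape [2, 1, 1]
}
```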

func (*Builder) DynamicSlice

func (b *Builder) DynamicSlice(operand backends.Op, startIndices []backends.Op, sliceDims []int) (backends.Op, error)

DynamicSlice extracts a slice from the operand at the startIndices position and with the given sliceDims.

  • operand: tensor from where to take the slice.
  • startIndices: scalar tensors, one per axis of operand: len(startIndices) == operand.Rank().
  • sliceDims: static values, fixed to keep the shape of the output static.

The startIndices are adjusted as follows:

adjustedStartIndices[i] = clamp(0, startIndices[i], operand.Dimensions[i] - sliceDims[i])

See description in https://openxla.org/xla/operation_semantics#dynamicslice
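
The start-index adjustment is easiest to see on a rank-1 operand. The sketch below is a plain-Go reference with hypothetical helper names; it shows that an out-of-range start is clamped so the slice always fits inside the operand.

```go
package main

import "fmt"

// clampStart mirrors the adjustment formula above for a single axis:
// clamp(0, start, dim - sliceDim).
func clampStart(start, dim, sliceDim int) int {
	if start > dim-sliceDim {
		start = dim - sliceDim
	}
	if start < 0 {
		start = 0
	}
	return start
}

// dynamicSlice1D is a rank-1 reference: it returns a copy of sliceDim
// elements of operand, starting at the clamped start index.
func dynamicSlice1D(operand []int, start, sliceDim int) []int {
	s := clampStart(start, len(operand), sliceDim)
	return append([]int(nil), operand[s:s+sliceDim]...)
}

func main() {
	operand := []int{10, 20, 30, 40, 50}
	fmt.Println(dynamicSlice1D(operand, 1, 3))  // start within range
	fmt.Println(dynamicSlice1D(operand, 4, 3))  // start clamped down to 2
	fmt.Println(dynamicSlice1D(operand, -2, 3)) // start clamped up to 0
}
```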

func (*Builder) DynamicUpdateSlice

func (b *Builder) DynamicUpdateSlice(operand, update backends.Op, startIndices []backends.Op) (backends.Op, error)

DynamicUpdateSlice updates the operand with the values given in update, at the position given by startIndices.

  • operand: original value to be updated.
  • update: values to "paste" on top of operand, at position startIndices.
  • startIndices: scalar tensors, one per axis of operand: len(startIndices) == operand.Rank().

It returns a value with the same shape as the operand, with the values updated.

The startIndices are adjusted as follows:

adjustedStartIndices[i] = clamp(0, startIndices[i], operand.Dimensions[i] - update.Dimensions[i])
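
For a rank-1 operand the behavior can be sketched in plain Go as follows. dynamicUpdateSlice1D is a hypothetical reference helper, not the backend implementation; it returns a new slice and leaves its inputs untouched.

```go
package main

import "fmt"

// dynamicUpdateSlice1D is a rank-1 reference: it copies operand, clamps
// start to [0, len(operand)-len(update)], and pastes update at that position.
func dynamicUpdateSlice1D(operand, update []int, start int) []int {
	limit := len(operand) - len(update)
	if start > limit {
		start = limit
	}
	if start < 0 {
		start = 0
	}
	out := append([]int(nil), operand...)
	copy(out[start:], update)
	return out
}

func main() {
	operand := []int{0, 0, 0, 0, 0}
	update := []int{7, 8}
	fmt.Println(dynamicUpdateSlice1D(operand, update, 1)) // pasted at index 1
	fmt.Println(dynamicUpdateSlice1D(operand, update, 4)) // start clamped to 3
}
```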

func (*Builder) Equal

func (b *Builder) Equal(lhs, rhs backends.Op) (backends.Op, error)

Equal returns the element-wise equality comparison (lhs == rhs). Standard broadcasting rules apply (see documentation).

func (*Builder) EqualTotalOrder

func (b *Builder) EqualTotalOrder(lhs, rhs backends.Op) (backends.Op, error)

EqualTotalOrder returns the element-wise equality comparison (lhs == rhs). Standard broadcasting rules apply (see documentation).

The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.

func (*Builder) Erf

func (b *Builder) Erf(operand backends.Op) (backends.Op, error)

Erf returns the "error function", defined as erf(x) = 2/sqrt(Pi) * \int_{0}^{x}{e^{-t^2}dt}.
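
The definition erf(x) = 2/sqrt(Pi) * \int_{0}^{x}{e^{-t^2}dt} can be checked numerically against Go's math.Erf. The sketch below integrates with the composite Simpson rule; erfNumeric and erfError are hypothetical helper names, and the choice of 1000 intervals is arbitrary.

```go
package main

import (
	"fmt"
	"math"
)

// erfNumeric approximates 2/sqrt(Pi) * integral from 0 to x of exp(-t^2) dt
// using the composite Simpson rule with n intervals (n must be even).
func erfNumeric(x float64, n int) float64 {
	h := x / float64(n)
	sum := 1.0 + math.Exp(-x*x) // endpoints: exp(0) and exp(-x^2)
	for i := 1; i < n; i++ {
		t := float64(i) * h
		weight := 2.0
		if i%2 == 1 {
			weight = 4.0
		}
		sum += weight * math.Exp(-t*t)
	}
	return 2 / math.Sqrt(math.Pi) * sum * h / 3
}

// erfError returns the absolute difference to the standard library's math.Erf.
func erfError(x float64) float64 {
	return math.Abs(erfNumeric(x, 1000) - math.Erf(x))
}

func main() {
	fmt.Println(erfNumeric(1, 1000), math.Erf(1))
	fmt.Println(erfError(0.5))
}
```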

func (*Builder) Exp

func (b *Builder) Exp(operand backends.Op) (backends.Op, error)

Exp returns the Op that represents the output of the corresponding operation.

func (*Builder) Expm1

func (b *Builder) Expm1(operand backends.Op) (backends.Op, error)

Expm1 returns the Op that represents the output of the corresponding operation.

func (*Builder) FFT

func (b *Builder) FFT(x backends.Op, fftType backends.FFTType, fftLength []int) (backends.Op, error)

FFT implements the Fast Fourier Transform operation. fftType specifies the type of FFT operation to perform. fftLength specifies the length of the transform for each axis.

func (*Builder) Floor

func (b *Builder) Floor(operand backends.Op) (backends.Op, error)

Floor returns the Op that represents the output of the corresponding operation.

func (*Builder) Gather

func (b *Builder) Gather(operand, startIndices backends.Op, indexVectorAxis int, offsetOutputAxes, collapsedSliceAxes, startIndexMap, sliceSizes []int, indicesAreSorted bool) (backends.Op, error)

Gather is a powerful but cumbersome Gather operation. See details in the backend.

Note that the GoMLX backend Gather operation doesn't support batching axes, which StableHLO does. For compatibility, they are simply left empty.

func (*Builder) GreaterOrEqual

func (b *Builder) GreaterOrEqual(lhs, rhs backends.Op) (backends.Op, error)

GreaterOrEqual returns the element-wise comparison lhs >= rhs. Standard broadcasting rules apply (see documentation).

func (*Builder) GreaterOrEqualTotalOrder

func (b *Builder) GreaterOrEqualTotalOrder(lhs, rhs backends.Op) (backends.Op, error)

GreaterOrEqualTotalOrder returns the element-wise comparison lhs >= rhs. Standard broadcasting rules apply (see documentation).

The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.

func (*Builder) GreaterThan

func (b *Builder) GreaterThan(lhs, rhs backends.Op) (backends.Op, error)

GreaterThan returns the element-wise comparison lhs > rhs. Standard broadcasting rules apply (see documentation).

func (*Builder) GreaterThanTotalOrder

func (b *Builder) GreaterThanTotalOrder(lhs, rhs backends.Op) (backends.Op, error)

GreaterThanTotalOrder returns the element-wise comparison lhs > rhs. Standard broadcasting rules apply (see documentation).

The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.

func (*Builder) Identity

func (b *Builder) Identity(x backends.Op) (backends.Op, error)

Identity returns an Op whose output is the same as its input. It's a no-op that can serve as a place-holder.

func (*Builder) Imag

func (b *Builder) Imag(operand backends.Op) (backends.Op, error)

Imag returns the imaginary part of a complex number. It returns 0 if the operand is a float number.

func (*Builder) Iota

func (b *Builder) Iota(shape shapes.Shape, iotaAxis int) (backends.Op, error)

Iota implements backends.Builder interface.

func (*Builder) IsFinite

func (b *Builder) IsFinite(operand backends.Op) (backends.Op, error)

IsFinite tests whether each element of operand is finite, i.e., it is neither positive nor negative infinity, and it is not NaN. It returns the same shape as the input, but with boolean values where each element is true if and only if the corresponding input element is finite.

func (*Builder) IsNaN

func (b *Builder) IsNaN(x backends.Op) (backends.Op, error)

IsNaN implements backends.Builder interface.

func (*Builder) LessOrEqual

func (b *Builder) LessOrEqual(lhs, rhs backends.Op) (backends.Op, error)

LessOrEqual returns the element-wise comparison lhs <= rhs. Standard broadcasting rules apply (see documentation).

func (*Builder) LessOrEqualTotalOrder

func (b *Builder) LessOrEqualTotalOrder(lhs, rhs backends.Op) (backends.Op, error)

LessOrEqualTotalOrder returns the element-wise comparison lhs <= rhs. Standard broadcasting rules apply (see documentation).

The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.

func (*Builder) LessThan

func (b *Builder) LessThan(lhs, rhs backends.Op) (backends.Op, error)

LessThan returns the element-wise comparison lhs < rhs. Standard broadcasting rules apply (see documentation).

func (*Builder) LessThanTotalOrder

func (b *Builder) LessThanTotalOrder(lhs, rhs backends.Op) (backends.Op, error)

LessThanTotalOrder returns the element-wise comparison lhs < rhs. Standard broadcasting rules apply (see documentation).

The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.

func (*Builder) Log

func (b *Builder) Log(operand backends.Op) (backends.Op, error)

Log returns the Op that represents the output of the corresponding operation.

func (*Builder) Log1p

func (b *Builder) Log1p(operand backends.Op) (backends.Op, error)

Log1p returns the expression log(x+1).

func (*Builder) LogicalAnd

func (b *Builder) LogicalAnd(lhs, rhs backends.Op) (backends.Op, error)

LogicalAnd returns the element-wise logical AND operation.

func (*Builder) LogicalNot

func (b *Builder) LogicalNot(operand backends.Op) (backends.Op, error)

LogicalNot returns the Op that represents the output of the corresponding operation.

func (*Builder) LogicalOr

func (b *Builder) LogicalOr(lhs, rhs backends.Op) (backends.Op, error)

LogicalOr returns the element-wise logical OR operation.

func (*Builder) LogicalXor

func (b *Builder) LogicalXor(lhs, rhs backends.Op) (backends.Op, error)

LogicalXor returns the element-wise logical XOR operation.

func (*Builder) Logistic

func (b *Builder) Logistic(operand backends.Op) (backends.Op, error)

Logistic returns the element-wise expression 1/(1+exp(-x)). Also known as the Sigmoid function.
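
For a single value, the expression can be written directly in Go. The sketch below (logistic is a hypothetical name) also shows the symmetry logistic(x) + logistic(-x) = 1, which follows from the formula.

```go
package main

import (
	"fmt"
	"math"
)

// logistic computes 1/(1+exp(-x)) for a single value.
func logistic(x float64) float64 {
	return 1 / (1 + math.Exp(-x))
}

func main() {
	for _, x := range []float64{-2, 0, 2} {
		fmt.Println(x, logistic(x))
	}
	fmt.Println(logistic(2) + logistic(-2))
}
```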

func (*Builder) Max

func (b *Builder) Max(lhs, rhs backends.Op) (backends.Op, error)

Max returns the element-wise highest value among the two.

func (*Builder) Min

func (b *Builder) Min(lhs, rhs backends.Op) (backends.Op, error)

Min returns the element-wise smallest value among the two.

func (*Builder) Mul

func (b *Builder) Mul(lhs, rhs backends.Op) (backends.Op, error)

Mul returns the element-wise multiplication of the two values. Standard broadcasting rules apply (see documentation).

func (*Builder) Neg

func (b *Builder) Neg(operand backends.Op) (backends.Op, error)

Neg returns the Op that represents the output of the corresponding operation.

func (*Builder) NotEqual

func (b *Builder) NotEqual(lhs, rhs backends.Op) (backends.Op, error)

NotEqual returns the element-wise inequality comparison (lhs != rhs). Standard broadcasting rules apply (see documentation).

func (*Builder) NotEqualTotalOrder

func (b *Builder) NotEqualTotalOrder(lhs, rhs backends.Op) (backends.Op, error)

NotEqualTotalOrder returns the element-wise inequality comparison (lhs != rhs). Standard broadcasting rules apply (see documentation).

The "TotalOrder" version of the operation enforces `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.

func (*Builder) OpShape

func (b *Builder) OpShape(op backends.Op) (shapes.Shape, error)

OpShape returns the shape of a computation Op.

func (*Builder) Pad

func (b *Builder) Pad(x, fillValue backends.Op, axesConfig ...backends.PadAxis) (backends.Op, error)

Pad injects padding on the start, end, or interior (in between each element) of the given operand x. There must be at most `x.Rank()` axesConfig values. Missing PadAxis values are assumed to be zeros, that is, no padding for those axes.
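
The three kinds of padding can be illustrated on a rank-1 operand. In the sketch below, padAxis is a hypothetical stand-in for backends.PadAxis (its field names are assumptions), and only non-negative padding amounts are handled.

```go
package main

import "fmt"

// padAxis is a hypothetical stand-in for backends.PadAxis: the amount of
// padding at the start, at the end, and between consecutive elements.
type padAxis struct {
	start, end, interior int
}

// pad1D is a rank-1 reference that handles non-negative padding only.
func pad1D(x []int, fill int, cfg padAxis) []int {
	var out []int
	for i := 0; i < cfg.start; i++ {
		out = append(out, fill)
	}
	for i, v := range x {
		if i > 0 { // interior padding goes between elements only
			for j := 0; j < cfg.interior; j++ {
				out = append(out, fill)
			}
		}
		out = append(out, v)
	}
	for i := 0; i < cfg.end; i++ {
		out = append(out, fill)
	}
	return out
}

func main() {
	fmt.Println(pad1D([]int{1, 2, 3}, 0, padAxis{start: 1, end: 2, interior: 1}))
	fmt.Println(pad1D([]int{1, 2, 3}, 0, padAxis{})) // zero config: no padding
}
```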

func (*Builder) Parameter

func (b *Builder) Parameter(name string, shape shapes.Shape) (backends.Op, error)

Parameter creates an input parameter for the computation.

When the computation is executed, parameter values must be fed in the same order in which the parameters were created.

func (*Builder) Pow

func (b *Builder) Pow(lhs, rhs backends.Op) (backends.Op, error)

Pow returns the Op that represents the output of the corresponding operation.

func (*Builder) Real

func (b *Builder) Real(operand backends.Op) (backends.Op, error)

Real returns the real part of a complex number. It returns the operand unchanged if it is a float number.

func (*Builder) ReduceBitwiseAnd

func (b *Builder) ReduceBitwiseAnd(x backends.Op, axes ...int) (backends.Op, error)

ReduceBitwiseAnd implements the corresponding method of the backends.Builder interface.

func (*Builder) ReduceBitwiseOr

func (b *Builder) ReduceBitwiseOr(x backends.Op, axes ...int) (backends.Op, error)

ReduceBitwiseOr implements the corresponding method of the backends.Builder interface.

func (*Builder) ReduceBitwiseXor

func (b *Builder) ReduceBitwiseXor(x backends.Op, axes ...int) (backends.Op, error)

ReduceBitwiseXor implements the corresponding method of the backends.Builder interface.

func (*Builder) ReduceLogicalAnd

func (b *Builder) ReduceLogicalAnd(x backends.Op, axes ...int) (backends.Op, error)

ReduceLogicalAnd implements the corresponding method of the backends.Builder interface.

func (*Builder) ReduceLogicalOr

func (b *Builder) ReduceLogicalOr(x backends.Op, axes ...int) (backends.Op, error)

ReduceLogicalOr implements the corresponding method of the backends.Builder interface.

func (*Builder) ReduceLogicalXor

func (b *Builder) ReduceLogicalXor(x backends.Op, axes ...int) (backends.Op, error)

ReduceLogicalXor implements the corresponding method of the backends.Builder interface.

func (*Builder) ReduceMax

func (b *Builder) ReduceMax(x backends.Op, axes ...int) (backends.Op, error)

ReduceMax implements the corresponding method of the backends.Builder interface.

func (*Builder) ReduceMin

func (b *Builder) ReduceMin(x backends.Op, axes ...int) (backends.Op, error)

ReduceMin implements the corresponding method of the backends.Builder interface.

func (*Builder) ReduceProduct

func (b *Builder) ReduceProduct(x backends.Op, axes ...int) (backends.Op, error)

ReduceProduct implements the corresponding method of the backends.Builder interface.

func (*Builder) ReduceSum

func (b *Builder) ReduceSum(x backends.Op, axes ...int) (backends.Op, error)

ReduceSum implements the corresponding method of the backends.Builder interface.

func (*Builder) ReduceWindow

func (b *Builder) ReduceWindow(x backends.Op, reductionType backends.ReduceOpType, windowDimensions, strides, baseDilations, windowDilations []int, paddings [][2]int) (backends.Op, error)

ReduceWindow runs a reduction function of the type given by reductionType, which can be ReduceMaxNode, ReduceSumNode or ReduceMultiplyNode.

The parameter windowDimensions must be set and have a value for each axis. If strides is nil, it is assumed to be the same as windowDimensions -- that is, the strides jump a window at a time. If baseDilations or windowDilations is nil, it is assumed to be 1 for every axis (no dilation). If paddings is nil, it is assumed to be 0 for every axis.
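The defaults described above can be sketched in plain Go on a 1-D slice. This is only an illustration of the windowing arithmetic, not the backend's implementation: reduceWindowSum1D is a hypothetical helper, a stride of 0 stands in for a nil strides slice, and dilations are left out.

```go
package main

import "fmt"

// reduceWindowSum1D is a hypothetical 1-D illustration of a windowed sum.
// A stride of 0 plays the role of a nil strides slice and defaults to the
// window size, so consecutive windows do not overlap.
func reduceWindowSum1D(x []float64, window, stride int, padding [2]int) []float64 {
	if window <= 0 {
		return nil // the window size must always be set
	}
	if stride == 0 {
		stride = window
	}
	// Zero is the identity for a sum, so padded positions contribute nothing.
	padded := make([]float64, padding[0]+len(x)+padding[1])
	copy(padded[padding[0]:], x)

	var out []float64
	for start := 0; start+window <= len(padded); start += stride {
		sum := 0.0
		for _, v := range padded[start : start+window] {
			sum += v
		}
		out = append(out, sum)
	}
	return out
}

func main() {
	x := []float64{1, 2, 3, 4, 5, 6}
	fmt.Println(reduceWindowSum1D(x, 2, 0, [2]int{}))     // [3 7 11]
	fmt.Println(reduceWindowSum1D(x, 2, 1, [2]int{}))     // [3 5 7 9 11]
	fmt.Println(reduceWindowSum1D(x, 2, 0, [2]int{1, 1})) // [1 5 9 6]
}
```

With the default stride each input element belongs to exactly one window, which is the usual pooling layout.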

func (*Builder) Rem

func (b *Builder) Rem(lhs, rhs backends.Op) (backends.Op, error)

Rem returns the remainder operation, also known as modulo (or Mod for short). Note that, despite the name, XLA implements Mod, not the IEEE 754 remainder operation.
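The difference between the two definitions can be shown in plain Go. The sketch below assumes that "Mod" here means the truncated (C fmod-style) modulo, where the result takes the sign of the dividend. The helpers truncatedMod and ieeeRemainder are hypothetical, use naive formulas, and are only meant for small, exactly representable values.

```go
package main

import (
	"fmt"
	"math"
)

// truncatedMod rounds the quotient toward zero, so the result has the sign
// of the dividend a. Naive formula, for illustration only.
func truncatedMod(a, b float64) float64 {
	return a - b*math.Trunc(a/b)
}

// ieeeRemainder rounds the quotient to the nearest integer (ties to even),
// as the IEEE 754 remainder operation does. Naive formula, for illustration only.
func ieeeRemainder(a, b float64) float64 {
	return a - b*math.RoundToEven(a/b)
}

func main() {
	fmt.Println(truncatedMod(5, 3), math.Mod(5, 3))         // 2 2
	fmt.Println(truncatedMod(-5, 3), math.Mod(-5, 3))       // -2 -2
	fmt.Println(ieeeRemainder(5, 3), math.Remainder(5, 3))  // -1 -1
	fmt.Println(ieeeRemainder(-5, 3), math.Remainder(-5, 3)) // 1 1
}
```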

func (*Builder) Reshape

func (b *Builder) Reshape(x backends.Op, dimensions ...int) (backends.Op, error)

Reshape implements the backends.Builder interface.

func (*Builder) Reverse

func (b *Builder) Reverse(x backends.Op, axes ...int) (backends.Op, error)

Reverse implements the backends.Builder interface.

func (*Builder) RngBitGenerator

func (b *Builder) RngBitGenerator(state backends.Op, shape shapes.Shape) (newState backends.Op, values backends.Op, err error)

RngBitGenerator generates the given shape filled with random bits.

It takes as input a state (usually [3]uint64) and returns the updated state and the generated values (with random bits).

Currently, the backend only supports the Philox algorithm. See https://dl.acm.org/doi/10.1145/2063384.2063405

func (*Builder) Round

func (b *Builder) Round(operand backends.Op) (backends.Op, error)

Round returns the Op that represents the output of the corresponding operation.

func (*Builder) Rsqrt

func (b *Builder) Rsqrt(operand backends.Op) (backends.Op, error)

Rsqrt returns the element-wise reciprocal of the square root, 1/sqrt(operand).

func (*Builder) ScatterMax

func (b *Builder) ScatterMax(operand, scatterIndices, updates backends.Op, indexVectorAxis int,
	updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (backends.Op, error)

ScatterMax scatters values from updates to the positions of operand pointed to by scatterIndices, combining them by taking the Max.

func (*Builder) ScatterMin

func (b *Builder) ScatterMin(operand, scatterIndices, updates backends.Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (backends.Op, error)

ScatterMin scatters values from updates to the positions of operand pointed to by scatterIndices, combining them by taking the Min.

func (*Builder) ScatterSum

func (b *Builder) ScatterSum(operand, scatterIndices, updates backends.Op, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int, indicesAreSorted, uniqueIndices bool) (backends.Op, error)

ScatterSum adds values from updates to the positions of operand pointed to by scatterIndices.
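The three scatter variants differ only in how an update is combined with the value already at the target position. A heavily simplified 1-D sketch follows. It assumes scalar indices and ignores indexVectorAxis, the window axes and the sorted/unique hints; scatter1D and add are hypothetical helpers, not part of this package.

```go
package main

import (
	"fmt"
	"math"
)

// add is the combiner that corresponds to a sum-style scatter.
func add(a, b float64) float64 { return a + b }

// scatter1D is a hypothetical 1-D illustration: each updates[i] is combined
// into a copy of operand at position indices[i]. Repeated indices are combined
// one after another, so the combiner decides the final value.
func scatter1D(operand []float64, indices []int, updates []float64, combine func(a, b float64) float64) []float64 {
	out := append([]float64(nil), operand...)
	for i, idx := range indices {
		out[idx] = combine(out[idx], updates[i])
	}
	return out
}

func main() {
	operand := []float64{0, 6, 0}
	indices := []int{1, 1, 2}
	updates := []float64{5, 7, 1}
	fmt.Println(scatter1D(operand, indices, updates, add))      // [0 18 1]
	fmt.Println(scatter1D(operand, indices, updates, math.Max)) // [0 7 1]
	fmt.Println(scatter1D(operand, indices, updates, math.Min)) // [0 5 0]
}
```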

func (*Builder) SelectAndScatterMax

func (b *Builder) SelectAndScatterMax(operand, source backends.Op, windowDimensions, windowStrides []int, paddings [][2]int) (backends.Op, error)

SelectAndScatterMax runs windows (similar to ReduceWindow) over the operand and selects the highest values to update the output (like ScatterSum).

It selects the values in the window such that it works as the reverse of a PoolMax operation.

Note: "Max" refers to the selection. Once selected, the values are added into the output position.

See details in https://openxla.org/xla/operation_semantics#selectandscatter
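The select-then-add behavior can be sketched on a 1-D slice. This is an assumption-laden illustration rather than the backend's implementation: selectAndScatterMax1D is a hypothetical helper, padding is omitted, and the output is assumed to start at zero.

```go
package main

import "fmt"

// selectAndScatterMax1D is a hypothetical 1-D illustration. For each window
// over operand it selects the position of the largest value and adds the
// matching source value to that position of the output.
func selectAndScatterMax1D(operand, source []float64, window, stride int) []float64 {
	out := make([]float64, len(operand))
	if window <= 0 || stride <= 0 {
		return out
	}
	w := 0 // index of the current window, and of its source value
	for start := 0; start+window <= len(operand) && w < len(source); start += stride {
		best := start
		for i := start + 1; i < start+window; i++ {
			if operand[i] > operand[best] {
				best = i
			}
		}
		out[best] += source[w] // "Max" picks the position; the value is added
		w++
	}
	return out
}

func main() {
	// Non-overlapping windows: each source value lands on its window's maximum.
	fmt.Println(selectAndScatterMax1D([]float64{1, 5, 2, 8, 3, 4}, []float64{10, 20, 30}, 2, 2)) // [0 10 0 20 0 30]
	// Overlapping windows select the same position twice, so the values add up.
	fmt.Println(selectAndScatterMax1D([]float64{1, 9, 2}, []float64{1, 2}, 2, 1)) // [0 3 0]
}
```

This is the pattern used to route gradients back through a max-pooling layer: each pooled output sends its gradient to the input position that produced it.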

func (*Builder) SelectAndScatterMin

func (b *Builder) SelectAndScatterMin(operand, source backends.Op, windowDimensions, windowStrides []int, paddings [][2]int) (backends.Op, error)

SelectAndScatterMin runs windows (similar to ReduceWindow) over the operand and selects the lowest values to update the output (like ScatterSum).

It selects the values in the window such that it works as the reverse of a min-pooling operation.

Note: "Min" refers to the selection. Once selected, the values are added into the output position.

See details in https://openxla.org/xla/operation_semantics#selectandscatter

func (*Builder) ShiftLeft

func (b *Builder) ShiftLeft(lhs, rhs backends.Op) (backends.Op, error)

ShiftLeft shifts left by n bits. It implicitly preserves the sign bit if there is no overflow, so ShiftLeft(-1, 1) = -2.

func (*Builder) ShiftRightArithmetic

func (b *Builder) ShiftRightArithmetic(lhs, rhs backends.Op) (backends.Op, error)

ShiftRightArithmetic shifts right by n bits, preserving the sign bit. So ShiftRightArithmetic(-2, 1) = -1.

func (*Builder) ShiftRightLogical

func (b *Builder) ShiftRightLogical(lhs, rhs backends.Op) (backends.Op, error)

ShiftRightLogical shifts right by n bits, destroying the sign bit.
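The three shift operations (ShiftLeft, ShiftRightArithmetic and ShiftRightLogical) can be compared on a single int32 value using Go's own shift operators. The helper names below are hypothetical and only mirror the method names for readability.

```go
package main

import "fmt"

// shiftLeft shifts zeros in from the right.
func shiftLeft(x int32, n uint) int32 { return x << n }

// shiftRightArithmetic relies on Go's >> for signed integers, which
// replicates the sign bit.
func shiftRightArithmetic(x int32, n uint) int32 { return x >> n }

// shiftRightLogical converts to uint32 first, so >> shifts zeros in from the
// left and the sign bit is lost.
func shiftRightLogical(x int32, n uint) int32 { return int32(uint32(x) >> n) }

func main() {
	fmt.Println(shiftLeft(-1, 1))            // -2
	fmt.Println(shiftRightArithmetic(-2, 1)) // -1
	fmt.Println(shiftRightLogical(-2, 1))    // 2147483647
}
```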

func (*Builder) Sign

func (b *Builder) Sign(operand backends.Op) (backends.Op, error)

Sign returns element-wise +1, +/-0 or -1 depending on the sign of operand. It returns NaN if the input is NaN.
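For a single float64 value, the rule above can be sketched as follows. The function sign is a hypothetical scalar stand-in for the element-wise operation.

```go
package main

import (
	"fmt"
	"math"
)

// sign is a hypothetical scalar version of the element-wise rule:
// +1 for positive values, -1 for negative values, NaN for NaN, and the
// input itself for zeros so that +0 and -0 are preserved.
func sign(x float64) float64 {
	switch {
	case math.IsNaN(x):
		return math.NaN()
	case x > 0:
		return 1
	case x < 0:
		return -1
	default:
		return x
	}
}

func main() {
	fmt.Println(sign(3.5), sign(-2), sign(0))             // 1 -1 0
	fmt.Println(math.Signbit(sign(math.Copysign(0, -1)))) // true: -0 stays -0
	fmt.Println(math.IsNaN(sign(math.NaN())))             // true
}
```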

func (*Builder) Sin

func (b *Builder) Sin(operand backends.Op) (backends.Op, error)

Sin returns the Op that represents the output of the corresponding operation.

func (*Builder) Slice

func (b *Builder) Slice(x backends.Op, starts, limits, strides []int) (backends.Op, error)

Slice implements the backends.Builder interface.

func (*Builder) Sqrt

func (b *Builder) Sqrt(operand backends.Op) (backends.Op, error)

Sqrt returns the Op that represents the output of the corresponding operation.

func (*Builder) Sub

func (b *Builder) Sub(lhs, rhs backends.Op) (backends.Op, error)

Sub returns the element-wise subtraction of the two values. Standard broadcasting rules apply (see documentation).

func (*Builder) Tanh

func (b *Builder) Tanh(operand backends.Op) (backends.Op, error)

Tanh returns the Op that represents the output of the corresponding operation.

func (*Builder) Transpose

func (b *Builder) Transpose(x backends.Op, permutation ...int) (backends.Op, error)

Transpose implements the backends.Builder interface. It transposes the input tensor x according to the given permutation of its axes.
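As a sketch of how a permutation maps to the output shape, the helper below assumes the XLA convention that output axis i takes its size from input axis permutation[i]. transposedShape is a hypothetical helper that works on plain dimension slices, not on backend ops.

```go
package main

import "fmt"

// transposedShape returns the dimensions after a transpose, assuming output
// axis i is taken from input axis permutation[i].
func transposedShape(dims, permutation []int) []int {
	out := make([]int, len(permutation))
	for i, axis := range permutation {
		out[i] = dims[axis]
	}
	return out
}

func main() {
	fmt.Println(transposedShape([]int{2, 3, 5}, []int{2, 0, 1})) // [5 2 3]
	fmt.Println(transposedShape([]int{4, 7}, []int{1, 0}))       // [7 4]
}
```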

func (*Builder) Where

func (b *Builder) Where(condition, onTrue, onFalse backends.Op) (backends.Op, error)

Where implements the backends.Builder interface.

type DotGeneralConfig

type DotGeneralConfig struct {
	// UseTF32 specifies whether to use tf32 (a truncated float32 that NVidia CUDA PJRT is able to use)
	// when doing float32 dot general.
	UseTF32 bool
}

DotGeneralConfig represents the configuration to use for DotGeneral. StableHLO has lots of options (see github.com/gomlx/stablehlo.DotGeneral), and here is what we expose for now.

type Executable

type Executable struct {
	// contains filtered or unexported fields
}

Executable implements backends.Executable for XLA/PJRT (see github.com/gomlx/gopjrt).

func (*Executable) CheckValid

func (e *Executable) CheckValid() error

CheckValid returns an error if the backend or the executable is not ok -- e.g. if either has been finalized or the builder has already been compiled.

func (*Executable) Execute

func (e *Executable) Execute(inputs []backends.Buffer, donate []bool) ([]backends.Buffer, error)

Execute the executable on the default device (0). The number and shapes of the inputs must match those returned by Inputs.

func (*Executable) Finalize

func (e *Executable) Finalize()

Finalize immediately frees the resources associated with the executable.

func (*Executable) Inputs

func (e *Executable) Inputs() (names []string, inputShapes []shapes.Shape)

Inputs returns the parameters' names and shapes, in order created by the Builder.Parameter calls.

func (*Executable) Outputs

func (e *Executable) Outputs() (outputShapes []shapes.Shape)

Outputs returns the computation's output shapes, in the order given to the Builder.Compile call.

type Node

type Node struct {
	// contains filtered or unexported fields
}

Node represents the output of an operation and implements a "backends.Op" interface.
