shapeinference

package
v0.2.0 Latest
Warning

This package is not in the latest version of its module.

Published: Dec 4, 2025 License: Apache-2.0 Imports: 7 Imported by: 0

README

Package shapeinference is a copy of github.com/gomlx/gomlx/backends/shapeinference, with a few modifications.

It is kept separate to avoid a dependency on GoMLX, at the cost of keeping the two copies in sync.

Documentation

Overview

Package shapeinference calculates the shape resulting from operations and validates its inputs.

This can be useful for testing new implementations and for planning buffer space for temporary or output buffers.

It defines a BinaryOp function for shape inference for the majority of binary functions, using the standard broadcasting rules.

Most unary operations don't change the shape, except those whose name explicitly says so, like Reshape.

The remaining operations each get their own shape inference function.

Index

Constants

This section is empty.

Variables

var (
	// BooleanOrBitwiseOperations take booleans (logical operations) or integers (bitwise operations) as input.
	BooleanOrBitwiseOperations = utils.SetWith(
		optypes.And,
		optypes.Or,
		optypes.Xor,
		optypes.Not,
	)

	// BitwiseOperations operate only on integers and won't work on floats or complex numbers.
	BitwiseOperations = utils.SetWith(
		optypes.Popcnt,
		optypes.ShiftLeft,
		optypes.ShiftRightArithmetic,
		optypes.ShiftRightLogical,
		optypes.CountLeadingZeros,
	)

	// NumberOperations can take any type of number as input: integers, floats, or complex numbers.
	NumberOperations = utils.SetWith(
		optypes.Add,
		optypes.Subtract,
		optypes.Multiply,
		optypes.Divide,
		optypes.Power,
		optypes.Remainder,

		optypes.Abs,
		optypes.Sign,

		optypes.Compare,
	)

	// SignedNumberOperations operate only on signed numbers.
	SignedNumberOperations = utils.SetWith(
		optypes.Negate,
	)

	// FloatOperations operate only on floats (not on complex numbers).
	FloatOperations = utils.SetWith(
		optypes.Erf,
		optypes.Logistic,
		optypes.Cosine,
		optypes.Sine,
		optypes.Tanh,
	)

	// FloatOrComplexOperations operate only on float or complex numbers and won't work on integer or boolean values.
	FloatOrComplexOperations = utils.SetWith(
		optypes.Exponential,
		optypes.ExponentialMinusOne,
		optypes.Log,
		optypes.LogPlusOne,
		optypes.Ceil,
		optypes.Floor,
		optypes.RoundNearestEven,
		optypes.Rsqrt,
		optypes.Sqrt,
		optypes.IsFinite,
	)

	// ComplexOperations operate only on complex numbers.
	ComplexOperations = utils.SetWith(
		optypes.Imag,
		optypes.Real,
	)

	// StandardBinaryOperations include all operations that take two operands, usually named lhs (left-hand side)
	// and rhs (right-hand side). Many of them, but not all, are commutative (invariant to operand order).
	StandardBinaryOperations = utils.SetWith(
		optypes.Add,
		optypes.Atan2,
		optypes.Subtract,
		optypes.Multiply,
		optypes.Divide,
		optypes.Power,
		optypes.Remainder,
		optypes.And,
		optypes.Or,
		optypes.Xor,
		optypes.Maximum,
		optypes.Minimum,
		optypes.ShiftLeft,
		optypes.ShiftRightArithmetic,
		optypes.ShiftRightLogical,
	)

	// ComparisonOperations include all operations that take two inputs and return booleans with the results of
	// a comparison.
	// In StableHLO they are merged into a single optypes.Compare, which takes the comparison type as an attribute.
	ComparisonOperations = utils.SetWith(optypes.Compare)

	// StandardUnaryOperations include all operations that have a single operand as input, and the return shape is the
	// same as the input (so no reductions).
	StandardUnaryOperations = utils.SetWith(
		optypes.Not,
		optypes.Popcnt,
		optypes.Cbrt,
		optypes.CountLeadingZeros,
		optypes.Erf,
		optypes.Exponential,
		optypes.ExponentialMinusOne,
		optypes.Log,
		optypes.LogPlusOne,
		optypes.Logistic,
		optypes.Ceil,
		optypes.Floor,
		optypes.RoundNearestEven,
		optypes.RoundNearestAfz,
		optypes.Rsqrt,
		optypes.Sqrt,
		optypes.Cosine,
		optypes.Sine,
		optypes.Tan,
		optypes.Tanh,
		optypes.Abs,
		optypes.Negate,
		optypes.Sign,
	)
)

Functions

func AdjustAxisToRank

func AdjustAxisToRank(axis, rank int) (int, error)

AdjustAxisToRank returns a non-negative axis, converting a negative axis (counted from the end, so -1 is the last axis) to its positive equivalent for the given rank.
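A minimal sketch of the adjustment, assuming axes outside [-rank, rank) are reported as errors. adjustAxisToRank is a hypothetical stand-in written for illustration, not the package's source:

```go
package main

import "fmt"

// adjustAxisToRank is a hypothetical stand-in for AdjustAxisToRank.
// Negative axes count from the end, so -1 is the last axis.
func adjustAxisToRank(axis, rank int) (int, error) {
	adjusted := axis
	if adjusted < 0 {
		adjusted += rank
	}
	// Assumption: anything outside [-rank, rank) is rejected.
	if adjusted < 0 || adjusted >= rank {
		return 0, fmt.Errorf("axis %d out of range for rank %d", axis, rank)
	}
	return adjusted, nil
}

func main() {
	for _, axis := range []int{0, -1, -3, 3} {
		adjusted, err := adjustAxisToRank(axis, 3)
		fmt.Println(axis, adjusted, err)
	}
}
```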

func AllGather added in v0.1.0

func AllGather(operand shapes.Shape, replicaGroups [][]int, allGatherDim int) (output shapes.Shape, err error)

AllGather returns the output shape for an all_gather operation.
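Assuming the StableHLO all_gather rule, the output equals the operand except along allGatherDim, which is multiplied by the size of each replica group. A hypothetical sketch over plain []int dimensions (not the package's shapes.Shape), which also assumes all groups must have the same size:

```go
package main

import (
	"errors"
	"fmt"
)

// allGatherDims is a hypothetical helper: the gathered axis grows by the
// replica group size; every other axis is unchanged.
func allGatherDims(operand []int, replicaGroups [][]int, allGatherDim int) ([]int, error) {
	if allGatherDim < 0 || allGatherDim >= len(operand) {
		return nil, fmt.Errorf("allGatherDim %d out of range for rank %d", allGatherDim, len(operand))
	}
	if len(replicaGroups) == 0 || len(replicaGroups[0]) == 0 {
		return nil, errors.New("replicaGroups must contain at least one non-empty group")
	}
	groupSize := len(replicaGroups[0])
	for _, group := range replicaGroups {
		if len(group) != groupSize {
			return nil, errors.New("all replica groups must have the same size")
		}
	}
	output := append([]int(nil), operand...) // copy; the input is left untouched
	output[allGatherDim] *= groupSize
	return output, nil
}

func main() {
	out, err := allGatherDims([]int{2, 8}, [][]int{{0, 1, 2, 3}}, 0)
	fmt.Println(out, err) // [8 8] <nil>
}
```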

func AllReduce added in v0.1.0

func AllReduce(operands []shapes.Shape, reductionInputs, reductionOutputs []shapes.Shape, replicaGroups [][]int) (
	outputs []shapes.Shape, err error)

AllReduce returns the output shapes for a collective_all_reduce operation. The output shapes are identical to the operand shapes. It also validates the computation function shapes.

func AllToAll added in v0.1.0

func AllToAll(operand shapes.Shape, replicaGroups [][]int, splitDimension, concatDimension, splitCount int) (output shapes.Shape, err error)

AllToAll returns the output shape for an all_to_all operation.
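Assuming the StableHLO all_to_all rule, the split axis is divided into splitCount chunks and the concat axis grows by splitCount. The hypothetical sketch below works on plain []int dimensions and does not validate replicaGroups:

```go
package main

import "fmt"

// allToAllDims is a hypothetical helper sketching the all_to_all output dimensions.
func allToAllDims(operand []int, splitDimension, concatDimension, splitCount int) ([]int, error) {
	rank := len(operand)
	if splitDimension < 0 || splitDimension >= rank || concatDimension < 0 || concatDimension >= rank {
		return nil, fmt.Errorf("split/concat dimension out of range for rank %d", rank)
	}
	if splitCount <= 0 || operand[splitDimension]%splitCount != 0 {
		return nil, fmt.Errorf("size %d of axis %d is not divisible by splitCount %d",
			operand[splitDimension], splitDimension, splitCount)
	}
	output := append([]int(nil), operand...) // copy; the input is left untouched
	output[splitDimension] /= splitCount
	output[concatDimension] *= splitCount
	return output, nil
}

func main() {
	out, err := allToAllDims([]int{4, 6}, 1, 0, 2)
	fmt.Println(out, err) // [8 3] <nil>
}
```

When the split and concat axes are the same, the division and multiplication cancel and the shape is unchanged.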

func ArgMinMax

func ArgMinMax(operand shapes.Shape, axis int, outputDType dtypes.DType) (output shapes.Shape, err error)

ArgMinMax calculates the output shape for an ArgMinMax operation: the operand's shape with the reduced axis removed, and with dtype outputDType.
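A hypothetical sketch of the dimension rule, using plain []int dimensions; the output dtype (outputDType) is not modeled, and the sketch assumes a non-negative axis:

```go
package main

import "fmt"

// argMinMaxDims is a hypothetical helper: the output keeps every operand axis
// except the reduced one.
func argMinMaxDims(operand []int, axis int) ([]int, error) {
	if axis < 0 || axis >= len(operand) {
		return nil, fmt.Errorf("axis %d out of range for rank %d", axis, len(operand))
	}
	output := make([]int, 0, len(operand)-1)
	output = append(output, operand[:axis]...)
	output = append(output, operand[axis+1:]...)
	return output, nil
}

func main() {
	out, err := argMinMaxDims([]int{2, 3, 4}, 1)
	fmt.Println(out, err) // [2 4] <nil>
}
```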

func BinaryOp

func BinaryOp(opType optypes.OpType, lhsShape, rhsShape shapes.Shape) (output shapes.Shape, err error)

BinaryOp returns the expected output shape for ops in the StandardBinaryOperations set: operations that take two operands, usually named lhs (left-hand side) and rhs (right-hand side). Many of them, but not all (e.g. Subtract, Divide), are commutative.

It returns an error if the data type (shape.DType) is invalid for the operation, e.g. the operands have non-matching dtypes, or And is given operands of an unsupported dtype.
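The dimension part of the inference can be sketched as follows. This assumes one common broadcasting rule (a scalar broadcasts to the other operand; otherwise ranks must match and, on each axis, the sizes are equal or one of them is 1), which may not be BinaryOp's exact rule; broadcastDims is hypothetical and dtype checks are omitted:

```go
package main

import "fmt"

// broadcastDims is a hypothetical helper computing broadcast output dimensions.
func broadcastDims(lhs, rhs []int) ([]int, error) {
	// A scalar (rank 0) broadcasts to the other operand's shape.
	if len(lhs) == 0 {
		return append([]int(nil), rhs...), nil
	}
	if len(rhs) == 0 {
		return append([]int(nil), lhs...), nil
	}
	if len(lhs) != len(rhs) {
		return nil, fmt.Errorf("rank mismatch: %v vs %v", lhs, rhs)
	}
	output := make([]int, len(lhs))
	for axis := range lhs {
		l, r := lhs[axis], rhs[axis]
		switch {
		case l == r || r == 1:
			output[axis] = l
		case l == 1:
			output[axis] = r
		default:
			return nil, fmt.Errorf("axis %d: sizes %d and %d cannot be broadcast", axis, l, r)
		}
	}
	return output, nil
}

func main() {
	out, err := broadcastDims([]int{2, 1, 4}, []int{2, 3, 1})
	fmt.Println(out, err) // [2 3 4] <nil>
}
```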

func BitcastConvert

func BitcastConvert(operand shapes.Shape, targetDType dtypes.DType) (outputShape shapes.Shape, err error)

func BroadcastInDim

func BroadcastInDim(operand, targetShape shapes.Shape, axesMapping []int) error

BroadcastInDim verifies that the arguments are valid. The output shape is already known, so nothing is returned.

The axesMapping is changed in place, replacing negative axes with their positive equivalent.
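A sketch of the kind of validation involved, assuming the StableHLO broadcast_in_dim constraints: one mapping entry per operand axis, each target axis used at most once, and each operand dimension either 1 or equal to the mapped target dimension. checkBroadcastInDim is hypothetical and works on []int dimensions; like the documented behavior, it rewrites negative axes in place:

```go
package main

import "fmt"

// checkBroadcastInDim is a hypothetical validator for broadcast_in_dim arguments.
func checkBroadcastInDim(operand, target, axesMapping []int) error {
	if len(axesMapping) != len(operand) {
		return fmt.Errorf("axesMapping has %d entries, operand rank is %d", len(axesMapping), len(operand))
	}
	used := make(map[int]bool)
	for i, axis := range axesMapping {
		if axis < 0 {
			axis += len(target)
		}
		if axis < 0 || axis >= len(target) {
			return fmt.Errorf("axesMapping[%d]=%d out of range for target rank %d", i, axesMapping[i], len(target))
		}
		if used[axis] {
			return fmt.Errorf("target axis %d mapped more than once", axis)
		}
		used[axis] = true
		axesMapping[i] = axis // replace negative axes with their positive equivalent
		if operand[i] != 1 && operand[i] != target[axis] {
			return fmt.Errorf("operand axis %d (size %d) incompatible with target axis %d (size %d)",
				i, operand[i], axis, target[axis])
		}
	}
	return nil
}

func main() {
	mapping := []int{-1}
	err := checkBroadcastInDim([]int{3}, []int{2, 3}, mapping)
	fmt.Println(mapping, err) // [1] <nil>
}
```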

func Clamp

func Clamp(min, operand, max shapes.Shape) (output shapes.Shape, err error)

Clamp returns the shape resulting from the corresponding operation.

func CollectiveBroadcast added in v0.1.0

func CollectiveBroadcast(operand shapes.Shape, replicaGroups [][]int) (output shapes.Shape, err error)

CollectiveBroadcast returns the output shape for a collective_broadcast operation. The output shape is identical to the operand shape.

func CollectivePermute added in v0.1.0

func CollectivePermute(operand shapes.Shape, sourceTargetPairs [][2]int) (output shapes.Shape, err error)

CollectivePermute returns the output shape for a collective_permute operation.

func Compare

func Compare(lhsShape, rhsShape shapes.Shape, direction types.ComparisonDirection, compareType types.ComparisonType) (output shapes.Shape, err error)

Compare returns the broadcast shape of its operands, with dtype set to Bool, for comparison operations (Equal, LessThan, GreaterOrEqual, etc.).

func Complex

func Complex(real, imag shapes.Shape) (output shapes.Shape, err error)

Complex returns the shape resulting from the Complex operation.

func Concatenate

func Concatenate(inputs []shapes.Shape, axis int) (output shapes.Shape, err error)

Concatenate calculates the output shape of a Concatenate operation. It takes a slice of input shapes and the dimension along which to concatenate.
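A hypothetical sketch of the rule on plain []int dimensions: all inputs must have the same rank and agree on every axis except the concatenation axis, whose sizes are summed. A non-negative axis is assumed:

```go
package main

import (
	"errors"
	"fmt"
)

// concatenateDims is a hypothetical helper computing the concatenated dimensions.
func concatenateDims(inputs [][]int, axis int) ([]int, error) {
	if len(inputs) == 0 {
		return nil, errors.New("need at least one input")
	}
	rank := len(inputs[0])
	if axis < 0 || axis >= rank {
		return nil, fmt.Errorf("axis %d out of range for rank %d", axis, rank)
	}
	output := append([]int(nil), inputs[0]...)
	for i, input := range inputs[1:] {
		if len(input) != rank {
			return nil, fmt.Errorf("input #%d has rank %d, expected %d", i+1, len(input), rank)
		}
		for a := range input {
			if a == axis {
				output[a] += input[a] // sizes add up along the concatenation axis
			} else if input[a] != output[a] {
				return nil, fmt.Errorf("input #%d axis %d has size %d, expected %d", i+1, a, input[a], output[a])
			}
		}
	}
	return output, nil
}

func main() {
	out, err := concatenateDims([][]int{{2, 3}, {2, 5}}, 1)
	fmt.Println(out, err) // [2 8] <nil>
}
```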

func Convolve

func Convolve(input, kernel shapes.Shape,
	strides []int, paddings [][2]int, inputDilations, kernelDilations []int,
	inputBatchAxis, inputChannelsAxis int, inputSpatialAxes []int,
	kernelInputChannelsAxis, kernelOutputChannelsAxis int, kernelSpatialAxes []int,
	outputBatchAxis, outputChannelsAxis int, outputSpatialAxes []int,
	channelGroupCount, batchGroupCount int) (shapes.Shape, error)

Convolve returns the expected output shape for the Convolve operation.

func DotGeneral

func DotGeneral(
	lhs shapes.Shape, lhsContractingAxes, lhsBatchAxes []int,
	rhs shapes.Shape, rhsContractingAxes, rhsBatchAxes []int,
	outputDType dtypes.DType) (output shapes.Shape, err error)

DotGeneral returns the shape resulting from the DotGeneral operation.

It also has a side effect on the axes' specifications: it converts negative axes to their corresponding positive axes, and it sorts the axes in ascending order.
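A hypothetical sketch of the output dimensions, assuming the usual XLA/StableHLO dot_general layout: batch axes first, then lhs free axes, then rhs free axes (free axes are those neither contracting nor batch). It assumes the axes are already valid, non-negative indices and ignores outputDType:

```go
package main

import (
	"errors"
	"fmt"
)

// dotGeneralDims is a hypothetical helper computing dot_general output dimensions.
func dotGeneralDims(lhs, lhsContracting, lhsBatch, rhs, rhsContracting, rhsBatch []int) ([]int, error) {
	if len(lhsContracting) != len(rhsContracting) || len(lhsBatch) != len(rhsBatch) {
		return nil, errors.New("lhs and rhs need the same number of contracting and batch axes")
	}
	for i := range lhsContracting {
		if lhs[lhsContracting[i]] != rhs[rhsContracting[i]] {
			return nil, fmt.Errorf("contracting pair #%d: sizes differ", i)
		}
	}
	var output []int
	for i := range lhsBatch {
		if lhs[lhsBatch[i]] != rhs[rhsBatch[i]] {
			return nil, fmt.Errorf("batch pair #%d: sizes differ", i)
		}
		output = append(output, lhs[lhsBatch[i]])
	}
	output = append(output, freeDims(lhs, lhsContracting, lhsBatch)...)
	output = append(output, freeDims(rhs, rhsContracting, rhsBatch)...)
	return output, nil
}

// freeDims returns, in order, the sizes of axes that are neither contracting nor batch axes.
func freeDims(dims, contracting, batch []int) []int {
	skip := make(map[int]bool)
	for _, axis := range contracting {
		skip[axis] = true
	}
	for _, axis := range batch {
		skip[axis] = true
	}
	var free []int
	for axis, size := range dims {
		if !skip[axis] {
			free = append(free, size)
		}
	}
	return free
}

func main() {
	// Batched matrix multiplication: [8, 2, 3] x [8, 3, 5].
	out, err := dotGeneralDims([]int{8, 2, 3}, []int{2}, []int{0}, []int{8, 3, 5}, []int{1}, []int{0})
	fmt.Println(out, err) // [8 2 5] <nil>
}
```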

func FFT

func FFT(x shapes.Shape, fftType types.FFTType, fftLength []int) (output shapes.Shape, err error)

func Gather

func Gather(operand, startIndices shapes.Shape, indexVectorAxis int,
	offsetOutputAxes, collapsedSliceAxes, operandBatchingAxes,
	startIndicesBatchingAxes, startIndexMap,
	sliceSizes []int, indicesAreSorted bool) (output shapes.Shape, err error)

Gather returns the output shape of a Gather operation.

func IsFinite

func IsFinite(operand shapes.Shape) (output shapes.Shape, err error)

func Pad

func Pad(x, fill shapes.Shape, paddingStart, paddingEnd, paddingInterior []int) (outputShape shapes.Shape, err error)
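The listing gives no description for Pad. Assuming it follows StableHLO pad semantics (start, end, and interior padding per axis, with negative start/end padding allowed for cropping), each output dimension is paddingStart + dim + (dim-1)*paddingInterior + paddingEnd. A hypothetical sketch on []int dimensions; the fill value's shape is not modeled:

```go
package main

import "fmt"

// padDims is a hypothetical helper computing padded output dimensions.
func padDims(dims, paddingStart, paddingEnd, paddingInterior []int) ([]int, error) {
	rank := len(dims)
	if len(paddingStart) != rank || len(paddingEnd) != rank || len(paddingInterior) != rank {
		return nil, fmt.Errorf("padding slices must all have length %d", rank)
	}
	output := make([]int, rank)
	for axis, dim := range dims {
		if paddingInterior[axis] < 0 {
			return nil, fmt.Errorf("axis %d: interior padding must be non-negative", axis)
		}
		interior := 0
		if dim > 0 {
			interior = (dim - 1) * paddingInterior[axis] // padding between adjacent elements
		}
		output[axis] = paddingStart[axis] + dim + interior + paddingEnd[axis]
		if output[axis] < 0 {
			return nil, fmt.Errorf("axis %d: negative output size %d", axis, output[axis])
		}
	}
	return output, nil
}

func main() {
	out, err := padDims([]int{3, 4}, []int{1, 0}, []int{2, 0}, []int{1, 0})
	fmt.Println(out, err) // [8 4] <nil>
}
```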

func RealOrImag

func RealOrImag(complexOperand shapes.Shape) (output shapes.Shape, err error)

RealOrImag returns the shape resulting from the Real or Imag operations.
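A sketch of the dtype mapping, assuming the usual convention (as in StableHLO real/imag) that complex64 maps to float32 and complex128 to float64, with dimensions unchanged. Dtypes are plain strings here instead of dtypes.DType, and realOrImagDType is hypothetical:

```go
package main

import "fmt"

// realOrImagDType is a hypothetical helper mapping a complex dtype to the
// float dtype of its real and imaginary parts.
func realOrImagDType(dtype string) (string, error) {
	switch dtype {
	case "complex64":
		return "float32", nil
	case "complex128":
		return "float64", nil
	default:
		return "", fmt.Errorf("real/imag need a complex operand, got %q", dtype)
	}
}

func main() {
	for _, dtype := range []string{"complex64", "complex128", "float32"} {
		out, err := realOrImagDType(dtype)
		fmt.Println(dtype, out, err)
	}
}
```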

func Reduce

func Reduce(inputs, initialValues, reductionInputs, reductionOutputs []shapes.Shape, axes []int) (outputs []shapes.Shape, err error)

Reduce returns the operation's output shapes and checks that all shapes and dtypes are valid. It also normalizes negative axes to positive values in place.
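A hypothetical sketch of the dimension rule for a single input: the reduced axes are removed, and, mirroring the documented side effect, negative axes are rewritten in place. It assumes repeated axes are invalid; initial values and the reduction computation are not modeled:

```go
package main

import "fmt"

// reduceDims is a hypothetical helper returning the dimensions left after
// reducing over axes. It normalizes negative entries of axes in place.
func reduceDims(dims []int, axes []int) ([]int, error) {
	rank := len(dims)
	reduced := make([]bool, rank)
	for i, axis := range axes {
		if axis < 0 {
			axis += rank
		}
		if axis < 0 || axis >= rank || reduced[axis] {
			return nil, fmt.Errorf("invalid or repeated reduce axis %d for rank %d", axes[i], rank)
		}
		reduced[axis] = true
		axes[i] = axis // in-place normalization
	}
	var output []int
	for axis, size := range dims {
		if !reduced[axis] {
			output = append(output, size)
		}
	}
	return output, nil
}

func main() {
	axes := []int{-1, 0}
	out, err := reduceDims([]int{2, 3, 4}, axes)
	fmt.Println(out, axes, err) // [3] [2 0] <nil>
}
```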

func ReduceWindow

func ReduceWindow(inputs, initialValues []shapes.Shape, reductionInputs, reductionOutputs []shapes.Shape,
	windowDimensions, strides, baseDilations, windowDilations []int, paddings [][2]int) (outputs []shapes.Shape, err error)

ReduceWindow returns the expected output shapes for the operation.

Note that it doesn't take a reductionType parameter, since the reduction type doesn't affect the output shape.
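Each output dimension can be computed independently per axis. The sketch below assumes the StableHLO reduce_window formula: dilate the input, add padding, dilate the window, then count how many strided window positions fit. reduceWindowDim is hypothetical:

```go
package main

import "fmt"

// reduceWindowDim is a hypothetical helper computing one output dimension of a
// windowed reduction.
func reduceWindowDim(size, window, stride, baseDilation, windowDilation, padLow, padHigh int) int {
	dilatedInput := 0
	if size > 0 {
		dilatedInput = (size-1)*baseDilation + 1
	}
	paddedInput := padLow + dilatedInput + padHigh
	dilatedWindow := 0
	if window > 0 {
		dilatedWindow = (window-1)*windowDilation + 1
	}
	if paddedInput == 0 || dilatedWindow > paddedInput {
		return 0 // no window position fits
	}
	return (paddedInput-dilatedWindow)/stride + 1
}

func main() {
	// Size 10, window 3, stride 2, no dilation, no padding.
	fmt.Println(reduceWindowDim(10, 3, 2, 1, 1, 0, 0)) // 4
}
```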

func Scatter

func Scatter(inputs []shapes.Shape, scatterIndices shapes.Shape, updates []shapes.Shape,
	updateWindowAxes, insertedWindowAxes []int,
	inputBatchingAxes, scatterIndicesBatchingAxes []int,
	indexedInputAxes []int, indexVectorAxis int,
	updateComputationInputs, updateComputationOutputs []shapes.Shape) (outputs []shapes.Shape, err error)

Scatter checks that the parameters are consistent. The returned output shapes are the same as the inputs: the scattered updates are applied to the inputs, which keep their shapes.

The Scatter options indicesAreSorted and uniqueIndices play no role in shape inference.

func Select

func Select(pred, onTrue, onFalse shapes.Shape) (output shapes.Shape, err error)

Select returns the shape resulting from the Select operation.

pred must be boolean and either a scalar or have the same dimensions as onTrue and onFalse. onTrue and onFalse must have the same shape and dtype.

func Slice

func Slice(operand shapes.Shape, starts, limits, strides []int) (output shapes.Shape, err error)

Slice calculates the output shape for a Slice operation. It checks that starts, limits, and strides have the correct length (matching operand rank), and that the slice parameters are valid for the operand's dimensions. Strides must be positive.
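A hypothetical sketch of these checks on []int dimensions. It assumes 0 <= start <= limit <= dim on each axis; the output size is the number of strided indices in [start, limit), i.e. ceil((limit-start)/stride):

```go
package main

import "fmt"

// sliceDims is a hypothetical helper computing Slice output dimensions.
func sliceDims(dims, starts, limits, strides []int) ([]int, error) {
	rank := len(dims)
	if len(starts) != rank || len(limits) != rank || len(strides) != rank {
		return nil, fmt.Errorf("starts, limits and strides must have length %d", rank)
	}
	output := make([]int, rank)
	for axis := range dims {
		start, limit, stride := starts[axis], limits[axis], strides[axis]
		if stride <= 0 {
			return nil, fmt.Errorf("axis %d: stride must be positive, got %d", axis, stride)
		}
		if start < 0 || start > limit || limit > dims[axis] {
			return nil, fmt.Errorf("axis %d: need 0 <= start <= limit <= %d, got start=%d limit=%d",
				axis, dims[axis], start, limit)
		}
		output[axis] = (limit - start + stride - 1) / stride // ceil((limit-start)/stride)
	}
	return output, nil
}

func main() {
	out, err := sliceDims([]int{10, 4}, []int{1, 0}, []int{8, 4}, []int{3, 1})
	fmt.Println(out, err) // [3 4] <nil>
}
```

In the example, axis 0 keeps indices 1, 4 and 7, hence size 3.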

func Transpose

func Transpose(operand shapes.Shape, permutation []int) (output shapes.Shape, err error)

Transpose permutes all axes of the operand. There must be one value in permutation for each axis of the operand, and the output has output.Dimensions[i] = operand.Dimensions[permutation[i]].
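A hypothetical sketch of the rule on []int dimensions, which also checks that permutation is a true permutation (each axis exactly once):

```go
package main

import "fmt"

// transposeDims is a hypothetical helper computing transposed dimensions.
func transposeDims(dims, permutation []int) ([]int, error) {
	rank := len(dims)
	if len(permutation) != rank {
		return nil, fmt.Errorf("permutation has %d entries, operand rank is %d", len(permutation), rank)
	}
	seen := make([]bool, rank)
	output := make([]int, rank)
	for i, axis := range permutation {
		if axis < 0 || axis >= rank || seen[axis] {
			return nil, fmt.Errorf("%v is not a permutation of %d axes", permutation, rank)
		}
		seen[axis] = true
		output[i] = dims[axis] // output.Dimensions[i] = operand.Dimensions[permutation[i]]
	}
	return output, nil
}

func main() {
	out, err := transposeDims([]int{2, 3, 4}, []int{2, 0, 1})
	fmt.Println(out, err) // [4 2 3] <nil>
}
```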

func UnaryOp

func UnaryOp(opType optypes.OpType, operand shapes.Shape) (output shapes.Shape, err error)

UnaryOp checks the validity of the data type for StandardUnaryOperations and returns either an error or the output shape, which is the same as the operand.

Types

This section is empty.
