Overview
Package shapeinference calculates the shapes resulting from operations and validates their inputs.
This can be useful for testing new optypes and for planning buffer space for temporary or output buffers.
It defines a BinaryOp function that handles shape inference for most binary operations, using the standard broadcasting rules.
Most unary operations don't change the shape, except those whose names say so explicitly, such as Reshape.
Each of the remaining operations has its own shape inference function.
Index
- Variables
- func AdjustAxisToRank(axis, rank int) (int, error)
- func AllGather(operand shapes.Shape, replicaGroups [][]int, allGatherDim int) (output shapes.Shape, err error)
- func AllReduce(operands []shapes.Shape, reductionInputs, reductionOutputs []shapes.Shape, ...) (outputs []shapes.Shape, err error)
- func AllToAll(operand shapes.Shape, replicaGroups [][]int, ...) (output shapes.Shape, err error)
- func ArgMinMax(operand shapes.Shape, axis int, outputDType dtypes.DType) (output shapes.Shape, err error)
- func BinaryOp(opType optypes.OpType, lhsShape, rhsShape shapes.Shape) (output shapes.Shape, err error)
- func BitcastConvert(operand shapes.Shape, targetDType dtypes.DType) (outputShape shapes.Shape, err error)
- func BroadcastInDim(operand, targetShape shapes.Shape, axesMapping []int) error
- func Clamp(min, operand, max shapes.Shape) (output shapes.Shape, err error)
- func CollectiveBroadcast(operand shapes.Shape, replicaGroups [][]int) (output shapes.Shape, err error)
- func CollectivePermute(operand shapes.Shape, sourceTargetPairs [][2]int) (output shapes.Shape, err error)
- func Compare(lhsShape, rhsShape shapes.Shape, direction types.ComparisonDirection, ...) (output shapes.Shape, err error)
- func Complex(real, imag shapes.Shape) (output shapes.Shape, err error)
- func Concatenate(inputs []shapes.Shape, axis int) (output shapes.Shape, err error)
- func Convolve(input, kernel shapes.Shape, strides []int, paddings [][2]int, ...) (shapes.Shape, error)
- func DotGeneral(lhs shapes.Shape, lhsContractingAxes, lhsBatchAxes []int, rhs shapes.Shape, ...) (output shapes.Shape, err error)
- func FFT(x shapes.Shape, fftType types.FFTType, fftLength []int) (output shapes.Shape, err error)
- func Gather(operand, startIndices shapes.Shape, indexVectorAxis int, ...) (output shapes.Shape, err error)
- func IsFinite(operand shapes.Shape) (output shapes.Shape, err error)
- func Pad(x, fill shapes.Shape, paddingStart, paddingEnd, paddingInterior []int) (outputShape shapes.Shape, err error)
- func RealOrImag(complexOperand shapes.Shape) (output shapes.Shape, err error)
- func Reduce(inputs, initialValues, reductionInputs, reductionOutputs []shapes.Shape, ...) (outputs []shapes.Shape, err error)
- func ReduceWindow(inputs, initialValues []shapes.Shape, ...) (outputs []shapes.Shape, err error)
- func Scatter(inputs []shapes.Shape, scatterIndices shapes.Shape, updates []shapes.Shape, ...) (outputs []shapes.Shape, err error)
- func Select(pred, onTrue, onFalse shapes.Shape) (output shapes.Shape, err error)
- func Slice(operand shapes.Shape, starts, limits, strides []int) (output shapes.Shape, err error)
- func Transpose(operand shapes.Shape, permutation []int) (output shapes.Shape, err error)
- func UnaryOp(opType optypes.OpType, operand shapes.Shape) (output shapes.Shape, err error)
Constants
This section is empty.
Variables

    var (
        // BooleanOrBitwiseOperations take booleans as input, a.k.a. logical operations.
        BooleanOrBitwiseOperations = utils.SetWith(
            optypes.And, optypes.Or, optypes.Xor, optypes.Not,
        )

        // BitwiseOperations operate only on integer (binary) numbers and won't work on floats or complex numbers.
        BitwiseOperations = utils.SetWith(
            optypes.Popcnt, optypes.ShiftLeft, optypes.ShiftRightArithmetic,
            optypes.ShiftRightLogical, optypes.CountLeadingZeros,
        )

        // NumberOperations can take any type of number as input: integers, floats, or complex numbers.
        NumberOperations = utils.SetWith(
            optypes.Add, optypes.Subtract, optypes.Multiply, optypes.Divide, optypes.Power,
            optypes.Remainder, optypes.Abs, optypes.Sign, optypes.Compare,
        )

        SignedNumberOperations = utils.SetWith(
            optypes.Negate,
        )

        // FloatOperations operate only on floats (not on complex numbers).
        FloatOperations = utils.SetWith(
            optypes.Erf, optypes.Logistic, optypes.Cosine, optypes.Sine, optypes.Tanh,
        )

        // FloatOrComplexOperations operate only on float or complex numbers and won't work on integer or boolean values.
        FloatOrComplexOperations = utils.SetWith(
            optypes.Exponential, optypes.ExponentialMinusOne, optypes.Log, optypes.LogPlusOne,
            optypes.Ceil, optypes.Floor, optypes.RoundNearestEven, optypes.Rsqrt, optypes.Sqrt,
            optypes.IsFinite,
        )

        // ComplexOperations operate only on complex numbers.
        ComplexOperations = utils.SetWith(
            optypes.Imag, optypes.Real,
        )

        // StandardBinaryOperations include all operations that have two operands, usually named lhs (left-hand side)
        // and rhs (right-hand side). Many of them are commutative (invariant to order), but not all.
        StandardBinaryOperations = utils.SetWith(
            optypes.Add, optypes.Atan2, optypes.Subtract, optypes.Multiply, optypes.Divide,
            optypes.Power, optypes.Remainder, optypes.And, optypes.Or, optypes.Xor,
            optypes.Maximum, optypes.Minimum, optypes.ShiftLeft, optypes.ShiftRightArithmetic,
            optypes.ShiftRightLogical,
        )

        // ComparisonOperations include all operations that take two inputs and return booleans with the results of
        // a comparison.
        // In StableHLO they are merged into a single optypes.Compare, which takes the comparison type as an attribute.
        ComparisonOperations = utils.SetWith(optypes.Compare)

        // StandardUnaryOperations include all operations that have a single operand as input and whose output shape
        // is the same as the input's (so no reductions).
        StandardUnaryOperations = utils.SetWith(
            optypes.Not, optypes.Popcnt, optypes.Cbrt, optypes.CountLeadingZeros, optypes.Erf,
            optypes.Exponential, optypes.ExponentialMinusOne, optypes.Log, optypes.LogPlusOne,
            optypes.Logistic, optypes.Ceil, optypes.Floor, optypes.RoundNearestEven,
            optypes.RoundNearestAfz, optypes.Rsqrt, optypes.Sqrt, optypes.Cosine, optypes.Sine,
            optypes.Tan, optypes.Tanh, optypes.Abs, optypes.Negate, optypes.Sign,
        )
    )
Functions
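func AdjustAxisToRank
func AdjustAxisToRank(axis, rank int) (int, error)
AdjustAxisToRank returns a non-negative axis, converting a negative axis (counted from the end, so -1 is the last axis) to its positive equivalent for the given rank.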
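The conversion can be sketched as follows. This is a minimal, hypothetical re-implementation (lowercase adjustAxisToRank, plain ints), assuming the valid range is [-rank, rank) and that anything outside it is an error; it is not the package's source.

```go
package main

import "fmt"

// adjustAxisToRank is a hypothetical sketch, not the package's code.
// It assumes valid axes lie in [-rank, rank) and maps negative ones
// to the equivalent non-negative axis.
func adjustAxisToRank(axis, rank int) (int, error) {
	if axis < -rank || axis >= rank {
		return 0, fmt.Errorf("axis %d is out of range for rank %d", axis, rank)
	}
	if axis < 0 {
		axis += rank // e.g. -1 refers to the last axis
	}
	return axis, nil
}
```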
func AllGather (added in v0.1.0)
func AllGather(operand shapes.Shape, replicaGroups [][]int, allGatherDim int) (output shapes.Shape, err error)
AllGather returns the output shape for an all_gather operation.
func AllReduce (added in v0.1.0)
func AllReduce(operands []shapes.Shape, reductionInputs, reductionOutputs []shapes.Shape, replicaGroups [][]int) (outputs []shapes.Shape, err error)
AllReduce returns the output shapes for a collective_all_reduce operation. The output shapes are identical to the operand shapes. It also validates the shapes of the reduction computation's inputs and outputs (reductionInputs and reductionOutputs).
func AllToAll (added in v0.1.0)
func AllToAll(operand shapes.Shape, replicaGroups [][]int, splitDimension, concatDimension, splitCount int) (output shapes.Shape, err error)
AllToAll returns the output shape for an all_to_all operation.
func ArgMinMax
func ArgMinMax(operand shapes.Shape, axis int, outputDType dtypes.DType) (output shapes.Shape, err error)
ArgMinMax calculates the output shape for an ArgMinMax operation: the operand's shape with the reduced axis removed, and with dtype outputDType.
func BinaryOp
func BinaryOp(opType optypes.OpType, lhsShape, rhsShape shapes.Shape) (output shapes.Shape, err error)
BinaryOp returns the expected output shape for ops in the StandardBinaryOperations set: all operations that have two operands, usually named lhs (left-hand side) and rhs (right-hand side). Many of them are commutative (invariant to operand order), but not all; Subtract and Divide, for instance, are not.
It returns an error if the data type (shape.DType) is invalid for the operation, e.g. non-matching dtypes, or And applied to floats.
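func BitcastConvert
func BitcastConvert(operand shapes.Shape, targetDType dtypes.DType) (outputShape shapes.Shape, err error)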
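func BroadcastInDim
func BroadcastInDim(operand, targetShape shapes.Shape, axesMapping []int) error
BroadcastInDim verifies that the arguments are valid. The output shape is already known (it is targetShape), so only an error is returned.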
The axesMapping is changed in place, replacing negative axes with their positive equivalent.
func CollectiveBroadcast (added in v0.1.0)
func CollectiveBroadcast(operand shapes.Shape, replicaGroups [][]int) (output shapes.Shape, err error)
CollectiveBroadcast returns the output shape for a collective_broadcast operation. The output shape is identical to the operand shape.
func CollectivePermute (added in v0.1.0)
func CollectivePermute(operand shapes.Shape, sourceTargetPairs [][2]int) (output shapes.Shape, err error)
CollectivePermute returns the output shape for a collective_permute operation.
func Compare
func Compare(lhsShape, rhsShape shapes.Shape, direction types.ComparisonDirection, compareType types.ComparisonType) (output shapes.Shape, err error)
Compare returns the broadcast shape of lhsShape and rhsShape with dtype set to Bool, for comparison operations (Equal, LessThan, GreaterOrEqual, etc.).
func Concatenate
func Concatenate(inputs []shapes.Shape, axis int) (output shapes.Shape, err error)
Concatenate calculates the output shape of a Concatenate operation. It takes a slice of input shapes and the axis along which to concatenate them.
func Convolve
func Convolve(input, kernel shapes.Shape, strides []int, paddings [][2]int, inputDilations, kernelDilations []int, inputBatchAxis, inputChannelsAxis int, inputSpatialAxes []int, kernelInputChannelsAxis, kernelOutputChannelsAxis int, kernelSpatialAxes []int, outputBatchAxis, outputChannelsAxis int, outputSpatialAxes []int, channelGroupCount, batchGroupCount int) (shapes.Shape, error)
Convolve returns the expected output shape for the Convolve operation.
func DotGeneral
func DotGeneral(lhs shapes.Shape, lhsContractingAxes, lhsBatchAxes []int, rhs shapes.Shape, rhsContractingAxes, rhsBatchAxes []int, outputDType dtypes.DType) (output shapes.Shape, err error)
DotGeneral returns the output shape of a DotGeneral operation.
As a side effect, it modifies the axes slices in place: it converts negative axes to their corresponding positive axes and sorts the axes in ascending order.
func Gather
func Gather(operand, startIndices shapes.Shape, indexVectorAxis int, offsetOutputAxes, collapsedSliceAxes, operandBatchingAxes, startIndicesBatchingAxes, startIndexMap, sliceSizes []int, indicesAreSorted bool) (output shapes.Shape, err error)
Gather returns the output shape of a Gather operation.
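func RealOrImag
func RealOrImag(complexOperand shapes.Shape) (output shapes.Shape, err error)
RealOrImag returns the output shape of a Real or Imag operation: the operand's dimensions, with the complex dtype replaced by the dtype of its components (e.g. Complex64 becomes Float32).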
func Reduce
func Reduce(inputs, initialValues, reductionInputs, reductionOutputs []shapes.Shape, axes []int) (outputs []shapes.Shape, err error)
Reduce returns the operation's output shapes and checks that all shapes and dtypes are valid. It also normalizes negative axes to positive values, in place.
func ReduceWindow
func ReduceWindow(inputs, initialValues []shapes.Shape, reductionInputs, reductionOutputs []shapes.Shape, windowDimensions, strides, baseDilations, windowDilations []int, paddings [][2]int) (outputs []shapes.Shape, err error)
ReduceWindow returns the expected output shapes for the operation.
Note that it doesn't take a reductionType parameter, since the type of reduction doesn't affect the output shape.
func Scatter
func Scatter(inputs []shapes.Shape, scatterIndices shapes.Shape, updates []shapes.Shape, updateWindowAxes, insertedWindowAxes []int, inputBatchingAxes, scatterIndicesBatchingAxes []int, indexedInputAxes []int, indexVectorAxis int, updateComputationInputs, updateComputationOutputs []shapes.Shape) (outputs []shapes.Shape, err error)
Scatter checks that the parameters are consistent. The returned output shapes are the unchanged input shapes: the updates are scattered into the inputs without changing their shapes.
The Scatter attributes indicesAreSorted and uniqueIndices don't affect the output shapes, so they are not parameters of this function.
func Select
func Select(pred, onTrue, onFalse shapes.Shape) (output shapes.Shape, err error)
Select returns the shape resulting from the Select operation.
pred must be boolean and either a scalar or the same shape as onTrue and onFalse. onTrue and onFalse must have the same shape and dtype.
func Slice
func Slice(operand shapes.Shape, starts, limits, strides []int) (output shapes.Shape, err error)
Slice calculates the output shape for a Slice operation. It checks that starts, limits, and strides have the correct length (matching the operand's rank), and that the slice parameters are valid for the operand's dimensions. Strides must be positive.
Types
This section is empty.