shapeinference

package
v0.22.0
Warning

This package is not in the latest version of its module.

Published: Aug 22, 2025 License: Apache-2.0 Imports: 6 Imported by: 0

Documentation

Overview

Package shapeinference calculates the shape resulting from operations, and validates its inputs.

This is useful for new backends, both for testing and for planning the space needed for temporary or output buffers.

It defines a BinaryOp function for shape inference for the majority of binary functions, using the standard broadcasting rules.

The majority of the unary functions don't change the shape, except those that explicitly say that in their name, like Reshape, etc.

For the remaining ops, it defines one function per OpType.


Constants

This section is empty.

Variables

var (
	// BooleanOperations take booleans as input (that is, logical operations).
	BooleanOperations = types.SetWith(
		backends.OpTypeLogicalAnd,
		backends.OpTypeLogicalOr,
		backends.OpTypeLogicalXor,
		backends.OpTypeLogicalNot,
	)

	// BitwiseOperations operate only on integer (binary) numbers and won't work on floats or complex numbers.
	BitwiseOperations = types.SetWith(
		backends.OpTypeBitwiseAnd,
		backends.OpTypeBitwiseOr,
		backends.OpTypeBitwiseXor,
		backends.OpTypeBitwiseNot,
		backends.OpTypeBitCount,
		backends.OpTypeShiftLeft,
		backends.OpTypeShiftRightArithmetic,
		backends.OpTypeShiftRightLogical,
		backends.OpTypeClz,
	)

	// NumberOperations can take any type of number as input: integers, floats, or complex numbers.
	NumberOperations = types.SetWith(
		backends.OpTypeAdd,
		backends.OpTypeSub,
		backends.OpTypeMul,
		backends.OpTypeDiv,
		backends.OpTypePow,
		backends.OpTypeRem,

		backends.OpTypeAbs,
		backends.OpTypeSign,

		backends.OpTypeEqual,
		backends.OpTypeNotEqual,
		backends.OpTypeGreaterOrEqual,
		backends.OpTypeGreaterThan,
		backends.OpTypeLessOrEqual,
		backends.OpTypeLessThan,

		backends.OpTypeEqualTotalOrder,
		backends.OpTypeGreaterOrEqualTotalOrder,
		backends.OpTypeGreaterThanTotalOrder,
		backends.OpTypeLessOrEqualTotalOrder,
		backends.OpTypeLessThanTotalOrder,
	)

	// SignedNumberOperations take only signed numbers as input.
	SignedNumberOperations = types.SetWith(
		backends.OpTypeNeg,
	)

	// FloatOperations operate only on floats (and not on complex numbers).
	FloatOperations = types.SetWith(
		backends.OpTypeErf,
		backends.OpTypeLogistic,
		backends.OpTypeCos,
		backends.OpTypeSin,
		backends.OpTypeTanh,
	)

	// FloatOrComplexOperations operate only on float or complex numbers and won't work on integer or boolean values.
	FloatOrComplexOperations = types.SetWith(
		backends.OpTypeExp,
		backends.OpTypeExpm1,
		backends.OpTypeLog,
		backends.OpTypeLog1p,
		backends.OpTypeCeil,
		backends.OpTypeFloor,
		backends.OpTypeRound,
		backends.OpTypeRsqrt,
		backends.OpTypeSqrt,
		backends.OpTypeIsFinite,
	)

	// ComplexOperations operate only on complex numbers.
	ComplexOperations = types.SetWith(
		backends.OpTypeImag,
		backends.OpTypeReal,
		backends.OpTypeConj,
	)

	// StandardBinaryOperations include all operations that have two operands usually named lhs (left-hand-side) and
	// rhs (right-hand-side) and are usually commutative (invariant to order).
	StandardBinaryOperations = types.SetWith(
		backends.OpTypeAdd,
		backends.OpTypeSub,
		backends.OpTypeMul,
		backends.OpTypeDiv,
		backends.OpTypePow,
		backends.OpTypeRem,
		backends.OpTypeBitwiseAnd,
		backends.OpTypeBitwiseOr,
		backends.OpTypeBitwiseXor,
		backends.OpTypeLogicalAnd,
		backends.OpTypeLogicalOr,
		backends.OpTypeLogicalXor,
		backends.OpTypeMax,
		backends.OpTypeMin,
	)

	// ComparisonOperations include all operations that take two inputs and return booleans with the results of
	// a comparison.
	ComparisonOperations = types.SetWith(
		backends.OpTypeEqual,
		backends.OpTypeNotEqual,
		backends.OpTypeEqualTotalOrder,
		backends.OpTypeGreaterOrEqual,
		backends.OpTypeGreaterOrEqualTotalOrder,
		backends.OpTypeGreaterThan,
		backends.OpTypeGreaterThanTotalOrder,
		backends.OpTypeLessOrEqual,
		backends.OpTypeLessOrEqualTotalOrder,
		backends.OpTypeLessThan,
		backends.OpTypeLessThanTotalOrder,
	)

	// StandardUnaryOperations include all operations that have a single operand as input, and the return shape is the
	// same as the input (so no reductions).
	StandardUnaryOperations = types.SetWith(
		backends.OpTypeLogicalNot,
		backends.OpTypeBitwiseNot,
		backends.OpTypeBitCount,
		backends.OpTypeClz,
		backends.OpTypeErf,
		backends.OpTypeExp,
		backends.OpTypeExpm1,
		backends.OpTypeLog,
		backends.OpTypeLog1p,
		backends.OpTypeLogistic,
		backends.OpTypeCeil,
		backends.OpTypeFloor,
		backends.OpTypeRound,
		backends.OpTypeRsqrt,
		backends.OpTypeSqrt,
		backends.OpTypeImag,
		backends.OpTypeReal,
		backends.OpTypeConj,
		backends.OpTypeCos,
		backends.OpTypeSin,
		backends.OpTypeTanh,
		backends.OpTypeAbs,
		backends.OpTypeNeg,
		backends.OpTypeSign,
	)
)

Functions

func ArgMinMaxOp added in v0.19.3

func ArgMinMaxOp(operand shapes.Shape, axis int, outputDType dtypes.DType) (output shapes.Shape, err error)

ArgMinMaxOp calculates the output shape for an ArgMinMax operation. It will be the shape of the operand minus the "reduce" axis.
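The "operand minus the reduce axis" rule can be sketched with plain dimension slices. This is only an illustration: argMinMaxDims is a hypothetical helper, not part of this package, and it assumes that a negative or out-of-range axis is rejected. In the real function the output dtype would be the given outputDType.

```go
package main

import "fmt"

// argMinMaxDims is a hypothetical helper (not part of shapeinference).
// It returns dims with the reduced axis removed.
// Assumption: axis must be in the range [0, rank).
func argMinMaxDims(dims []int, axis int) ([]int, error) {
	if axis < 0 || axis >= len(dims) {
		return nil, fmt.Errorf("axis %d out of range for rank %d", axis, len(dims))
	}
	out := make([]int, 0, len(dims)-1)
	out = append(out, dims[:axis]...)
	out = append(out, dims[axis+1:]...)
	return out, nil
}

func main() {
	// Reducing axis 1 of a [2 3 4] operand leaves [2 4].
	fmt.Println(argMinMaxDims([]int{2, 3, 4}, 1)) // prints "[2 4] <nil>"
}
```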

func BinaryOp

func BinaryOp(opType backends.OpType, lhsShape, rhsShape shapes.Shape) (output shapes.Shape, err error)

BinaryOp returns the expected output shape for ops in the StandardBinaryOperations set -- those include all operations that have two operands usually named lhs (left-hand-side) and rhs (right-hand-side), and they are usually commutative (invariant to order).

It returns an error if the data type (shape.DType) is invalid for the operation, for example when the operand dtypes don't match or when LogicalAnd is given inputs that are not booleans (dtypes.Bool).
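As a rough illustration of the dimension side of this check, the sketch below broadcasts two dimension lists. It assumes the broadcasting rule is: a scalar broadcasts to the other operand's shape, otherwise the ranks must match and each pair of dimensions must be equal or contain a 1. That rule and the helper name broadcastDims are assumptions made for the example, not taken from the package source, and the dtype validation is left out.

```go
package main

import "fmt"

// broadcastDims is a hypothetical helper (not part of shapeinference).
// Assumed rule: a scalar (rank 0) broadcasts to the other operand;
// otherwise ranks must match and each axis pair must be equal or contain a 1.
func broadcastDims(lhs, rhs []int) ([]int, error) {
	if len(lhs) == 0 {
		return append([]int(nil), rhs...), nil
	}
	if len(rhs) == 0 {
		return append([]int(nil), lhs...), nil
	}
	if len(lhs) != len(rhs) {
		return nil, fmt.Errorf("rank mismatch: %d vs %d", len(lhs), len(rhs))
	}
	out := make([]int, len(lhs))
	for i := range lhs {
		switch {
		case lhs[i] == rhs[i], rhs[i] == 1:
			out[i] = lhs[i]
		case lhs[i] == 1:
			out[i] = rhs[i]
		default:
			return nil, fmt.Errorf("axis %d: cannot broadcast %d with %d", i, lhs[i], rhs[i])
		}
	}
	return out, nil
}

func main() {
	fmt.Println(broadcastDims([]int{4, 1}, []int{4, 3})) // prints "[4 3] <nil>"
	_, err := broadcastDims([]int{4, 2}, []int{4, 3})
	fmt.Println(err != nil) // prints "true"
}
```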

func BroadcastInDimOp

func BroadcastInDimOp(operand, outputShape shapes.Shape, broadcastAxes []int) error

BroadcastInDimOp verifies that the arguments are valid. The output shape is already known, so nothing is returned.

func BroadcastOp

func BroadcastOp(operand shapes.Shape, prefixDims []int) (output shapes.Shape, err error)

BroadcastOp adds the prefixDims to the start of the shape.
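A minimal sketch of the prefixing rule on plain dimension slices follows. The helper name broadcastPrefix is hypothetical, and the check that every prefix dimension is positive is an assumption of this example rather than documented behavior.

```go
package main

import "fmt"

// broadcastPrefix is a hypothetical helper (not part of shapeinference).
// It returns prefixDims followed by the operand dims.
// Assumption: prefix dimensions must be positive.
func broadcastPrefix(dims, prefixDims []int) ([]int, error) {
	for i, d := range prefixDims {
		if d <= 0 {
			return nil, fmt.Errorf("prefixDims[%d]=%d must be positive", i, d)
		}
	}
	out := make([]int, 0, len(prefixDims)+len(dims))
	out = append(out, prefixDims...)
	out = append(out, dims...)
	return out, nil
}

func main() {
	fmt.Println(broadcastPrefix([]int{3, 4}, []int{2, 5})) // prints "[2 5 3 4] <nil>"
}
```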

func ComparisonOp

func ComparisonOp(opType backends.OpType, lhsShape, rhsShape shapes.Shape) (output shapes.Shape, err error)

ComparisonOp returns the broadcast shape with dtype set to Bool, for comparison operations (Equal, LessThan, GreaterOrEqual, etc.).

func ConcatenateOp

func ConcatenateOp(inputs []shapes.Shape, axis int) (output shapes.Shape, err error)

ConcatenateOp calculates the output shape of a Concatenate operation. It takes a slice of input shapes and the dimension along which to concatenate.
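The usual concatenation rule can be sketched as follows: all inputs share the same rank, every axis except the concatenation axis must match, and the concatenation axis is summed. The helper concatDims is hypothetical, and these validation details are assumptions of the example; dtype matching is not modeled.

```go
package main

import "fmt"

// concatDims is a hypothetical helper (not part of shapeinference).
// Assumptions: all inputs have the same rank, axis is in [0, rank), and
// every axis other than the concatenation axis has matching dimensions.
func concatDims(inputs [][]int, axis int) ([]int, error) {
	if len(inputs) == 0 {
		return nil, fmt.Errorf("at least one input is required")
	}
	rank := len(inputs[0])
	if axis < 0 || axis >= rank {
		return nil, fmt.Errorf("axis %d out of range for rank %d", axis, rank)
	}
	out := append([]int(nil), inputs[0]...)
	for n, in := range inputs[1:] {
		if len(in) != rank {
			return nil, fmt.Errorf("input %d has rank %d, want %d", n+1, len(in), rank)
		}
		for i, d := range in {
			if i == axis {
				out[i] += d // Sizes add up along the concatenation axis.
			} else if d != out[i] {
				return nil, fmt.Errorf("input %d: axis %d has size %d, want %d", n+1, i, d, out[i])
			}
		}
	}
	return out, nil
}

func main() {
	fmt.Println(concatDims([][]int{{2, 3}, {2, 5}}, 1)) // prints "[2 8] <nil>"
}
```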

func ConvGeneralOp added in v0.22.0

func ConvGeneralOp(input, kernel shapes.Shape, axes backends.ConvolveAxesConfig,
	strides []int, paddings [][2]int,
	inputDilations, kernelDilations []int,
	channelGroupCount, batchGroupCount int) (shapes.Shape, error)

ConvGeneralOp returns the expected output shape for the ConvGeneral operation.

func GatherOp

func GatherOp(operand, startIndices shapes.Shape, indexVectorAxis int, offsetOutputAxes, collapsedSliceAxes,
	startIndexMap, sliceSizes []int, indicesAreSorted bool) (output shapes.Shape, err error)

GatherOp returns the output shape of a Gather operation.

func ReduceOp

func ReduceOp(operand shapes.Shape, axes []int) (output shapes.Shape, err error)

ReduceOp works for the ReduceMax, ReduceMin, ReduceSum and ReduceProduct ops.
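For these reductions the output keeps the operand's axes that are not reduced, in their original order. The sketch below illustrates that on plain dimension slices. The helper reduceDims is hypothetical; rejecting duplicate or out-of-range axes is an assumption, and what an empty axes list means in the real function is not covered here (the sketch simply leaves the dimensions unchanged).

```go
package main

import "fmt"

// reduceDims is a hypothetical helper (not part of shapeinference).
// It drops the reduced axes and keeps the remaining ones in order.
// Assumption: axes must be unique and in the range [0, rank).
func reduceDims(dims []int, axes []int) ([]int, error) {
	reduced := make([]bool, len(dims))
	for _, a := range axes {
		if a < 0 || a >= len(dims) {
			return nil, fmt.Errorf("axis %d out of range for rank %d", a, len(dims))
		}
		if reduced[a] {
			return nil, fmt.Errorf("axis %d listed more than once", a)
		}
		reduced[a] = true
	}
	out := make([]int, 0, len(dims))
	for i, d := range dims {
		if !reduced[i] {
			out = append(out, d)
		}
	}
	return out, nil
}

func main() {
	fmt.Println(reduceDims([]int{2, 3, 4}, []int{0, 2})) // prints "[3] <nil>"
}
```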

func ReduceWindowOp added in v0.19.3

func ReduceWindowOp(operand shapes.Shape, windowDimensions, strides, baseDilations, windowDilations []int, paddings [][2]int) (shapes.Shape, error)

ReduceWindowOp returns the expected output shape for the operation.

Notice it doesn't take as input the reductionType parameter, since it doesn't affect the output shape.
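The per-axis arithmetic commonly used for windowed reductions (the XLA convention) can be sketched as below. Whether this package follows exactly this formula is an assumption, and windowOutputDim is a hypothetical helper that handles a single axis.

```go
package main

import "fmt"

// windowOutputDim is a hypothetical helper (not part of shapeinference).
// It applies the XLA-style formula for one axis:
//
//	dilatedBase     = (dim-1)*baseDilation + 1
//	padded          = dilatedBase + padLow + padHigh
//	effectiveWindow = (window-1)*windowDilation + 1
//	output          = (padded-effectiveWindow)/stride + 1, or 0 if the window doesn't fit.
func windowOutputDim(dim, window, stride, baseDilation, windowDilation, padLow, padHigh int) int {
	dilatedBase := 0
	if dim > 0 {
		dilatedBase = (dim-1)*baseDilation + 1
	}
	padded := dilatedBase + padLow + padHigh
	effectiveWindow := (window-1)*windowDilation + 1
	if padded < effectiveWindow {
		return 0
	}
	return (padded-effectiveWindow)/stride + 1
}

func main() {
	// Axis of size 8, window 3, stride 2, no dilation, no padding.
	fmt.Println(windowOutputDim(8, 3, 2, 1, 1, 0, 0)) // prints "3"
	// Same, with one element of padding on each side.
	fmt.Println(windowOutputDim(8, 3, 2, 1, 1, 1, 1)) // prints "4"
}
```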

func ReshapeOp

func ReshapeOp(operand shapes.Shape, dims []int) (output shapes.Shape, err error)

ReshapeOp returns the operand reshaped to the given dimensions. The output shape is trivial, but the function also checks that the total size is unchanged.

Notice that backends.Reshape doesn't support auto-scaling dimensions (set to -1), as graph.Reshape does.
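The size check amounts to comparing the products of the dimensions. A minimal sketch follows, using a hypothetical helper checkReshape. Rejecting every negative dimension is this example's way of modeling the lack of -1 support and is an assumption about the exact validation.

```go
package main

import "fmt"

// checkReshape is a hypothetical helper (not part of shapeinference).
// It verifies that newDims holds the same number of elements as dims.
// Assumption: negative dimensions (such as -1) are rejected.
func checkReshape(dims, newDims []int) error {
	size := func(ds []int) int {
		n := 1
		for _, d := range ds {
			n *= d
		}
		return n
	}
	for i, d := range newDims {
		if d < 0 {
			return fmt.Errorf("dims[%d]=%d: auto-scaling dimensions are not supported", i, d)
		}
	}
	if size(dims) != size(newDims) {
		return fmt.Errorf("size mismatch: %d vs %d elements", size(dims), size(newDims))
	}
	return nil
}

func main() {
	fmt.Println(checkReshape([]int{2, 6}, []int{3, 4}))         // prints "<nil>"
	fmt.Println(checkReshape([]int{2, 6}, []int{5, 2}) != nil)  // prints "true"
	fmt.Println(checkReshape([]int{2, 6}, []int{-1, 4}) != nil) // prints "true"
}
```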

func ScatterOp

func ScatterOp(operand, indices, updates shapes.Shape, indexVectorAxis int, updateWindowAxes, insertedWindowAxes, scatterAxesToOperandAxes []int) (output shapes.Shape, err error)

ScatterOp checks that the parameters are consistent. The output shape returned is the unchanged operand -- the scattered updates are applied to the operand, but its shape is unchanged.

The Scatter parameters indicesAreSorted and uniqueIndices play no role in shape inference.

func SliceOp added in v0.19.1

func SliceOp(operand shapes.Shape, starts, limits, strides []int) (output shapes.Shape, err error)

SliceOp calculates the output shape for a Slice operation. It checks that starts, limits, and strides have the correct length (matching operand rank), and that the slice parameters are valid for the operand's dimensions. Strides must be positive.
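For a strided slice, each output dimension is the number of indices start, start+stride, ... that fall below limit, which is ceil((limit-start)/stride). The sketch below illustrates this with a hypothetical helper sliceDims. The exact bounds it enforces (0 <= start <= limit <= dim) are an assumption and may differ from the package's checks.

```go
package main

import "fmt"

// sliceDims is a hypothetical helper (not part of shapeinference).
// Assumptions: starts, limits and strides have one entry per axis,
// 0 <= start <= limit <= dim, and strides are positive.
func sliceDims(dims, starts, limits, strides []int) ([]int, error) {
	rank := len(dims)
	if len(starts) != rank || len(limits) != rank || len(strides) != rank {
		return nil, fmt.Errorf("starts, limits and strides must have length %d", rank)
	}
	out := make([]int, rank)
	for i, dim := range dims {
		start, limit, stride := starts[i], limits[i], strides[i]
		if stride <= 0 {
			return nil, fmt.Errorf("axis %d: stride %d must be positive", i, stride)
		}
		if start < 0 || start > limit || limit > dim {
			return nil, fmt.Errorf("axis %d: invalid range [%d, %d) for size %d", i, start, limit, dim)
		}
		// Ceiling division: count of indices start, start+stride, ... < limit.
		out[i] = (limit - start + stride - 1) / stride
	}
	return out, nil
}

func main() {
	// Axis 0 selects indices 2, 5, 8; axis 1 is taken whole.
	fmt.Println(sliceDims([]int{10, 4}, []int{2, 0}, []int{9, 4}, []int{3, 1})) // prints "[3 4] <nil>"
}
```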

func TransposeOp

func TransposeOp(operand shapes.Shape, permutations []int) (output shapes.Shape, err error)

TransposeOp returns the shape that results from permuting all axes of the operand. There must be one value in permutations for each axis in the operand. The output will have: output.Shape.Dimension[i] = operand.Shape.Dimension[permutations[i]].
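The rule that output dimension i equals operand dimension permutations[i] can be sketched as follows. The helper transposeDims is hypothetical, and requiring permutations to be a true permutation (each axis used exactly once) is an assumption about the validation.

```go
package main

import "fmt"

// transposeDims is a hypothetical helper (not part of shapeinference).
// It computes out[i] = dims[permutations[i]].
// Assumption: permutations must list each operand axis exactly once.
func transposeDims(dims, permutations []int) ([]int, error) {
	if len(permutations) != len(dims) {
		return nil, fmt.Errorf("got %d permutations for rank %d", len(permutations), len(dims))
	}
	seen := make([]bool, len(dims))
	out := make([]int, len(dims))
	for i, p := range permutations {
		if p < 0 || p >= len(dims) {
			return nil, fmt.Errorf("permutations[%d]=%d out of range", i, p)
		}
		if seen[p] {
			return nil, fmt.Errorf("axis %d appears more than once", p)
		}
		seen[p] = true
		out[i] = dims[p]
	}
	return out, nil
}

func main() {
	fmt.Println(transposeDims([]int{2, 3, 4}, []int{2, 0, 1})) // prints "[4 2 3] <nil>"
}
```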

func UnaryOp

func UnaryOp(opType backends.OpType, operand shapes.Shape) (output shapes.Shape, err error)

UnaryOp checks the validity of the data type for StandardUnaryOperations and returns either an error or the output shape, which is the same as the operand.

func WhereOp

func WhereOp(condition, onTrue, onFalse shapes.Shape) (output shapes.Shape, err error)

WhereOp returns the shape resulting from the Where operation.

Shape constraints for the operation:

  1. The onTrue and onFalse must have the exact same shape, or one can be a scalar.
  2. The condition must either be a scalar or match the shape of onTrue or onFalse, except for the DType that must be Bool.
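The two constraints above can be sketched on plain dimension slices. The helpers whereDims and sameDims are hypothetical, the dtype checks (Bool condition, matching value dtypes) are omitted, and the sketch follows the constraints exactly as worded, so a non-scalar condition combined with two scalar values is rejected.

```go
package main

import "fmt"

// sameDims reports whether two dimension lists are identical.
func sameDims(a, b []int) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}

// whereDims is a hypothetical helper (not part of shapeinference).
// An empty dimension list stands for a scalar.
func whereDims(condition, onTrue, onFalse []int) ([]int, error) {
	// Constraint 1: same shape, or one of the values is a scalar.
	out := onTrue
	switch {
	case len(onTrue) == 0:
		out = onFalse
	case len(onFalse) == 0:
		out = onTrue
	case !sameDims(onTrue, onFalse):
		return nil, fmt.Errorf("onTrue %v and onFalse %v differ", onTrue, onFalse)
	}
	// Constraint 2: condition is a scalar or matches the value shape.
	if len(condition) != 0 && !sameDims(condition, out) {
		return nil, fmt.Errorf("condition %v doesn't match %v", condition, out)
	}
	return out, nil
}

func main() {
	fmt.Println(whereDims([]int{2, 3}, []int{2, 3}, nil)) // prints "[2 3] <nil>"
}
```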

Types

This section is empty.
