shapes

package
v0.0.0-...-475f5b6 Latest
Warning

This package is not in the latest version of its module.

Published: May 13, 2026 License: Apache-2.0 Imports: 12 Imported by: 0

Documentation

Overview

Package shapes defines Shape and DType and associated tools.

Shape represents the shape (rank, dimensions for each axis, and DType) of either a Tensor or the expected shape of a node in a computation graph. DType indicates the data type for a Tensor's unit element.

Optionally, the shape can also carry axis names.

It also supports "dynamic shapes", where one or more axes have an undefined (DynamicDim == -1) dimension. Every dynamic axis must be named (it must have a non-empty AxisName), so that its value can be inferred when needed (see AxisBindings).

## Immutable Semantics

Shape objects should have immutable semantics after creation: functions that need to mutate a shape should clone it first (see Shape.Clone), mutate the clone, and return the updated (and henceforth immutable) shape. It's OK to simply copy shapes (shallow copy) if they are not meant to be mutated.
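To see why the clone step matters, here is a minimal, self-contained sketch. It uses a hypothetical stand-in `Shape` type with only a `Dimensions` slice (not the package's real definition): a shallow copy shares the underlying array with the original, so mutating it in place would also change the original, while the clone-mutate-return pattern leaves the input untouched.

```go
package main

import "fmt"

// Shape is a hypothetical, minimal stand-in for shapes.Shape.
type Shape struct {
	Dimensions []int
}

// Clone returns a deep copy: the Dimensions slice is duplicated.
func (s Shape) Clone() Shape {
	return Shape{Dimensions: append([]int(nil), s.Dimensions...)}
}

// withFirstDim follows the clone-mutate-return pattern: its input is never modified.
func withFirstDim(s Shape, dim int) Shape {
	s2 := s.Clone()
	s2.Dimensions[0] = dim
	return s2
}

func main() {
	original := Shape{Dimensions: []int{2, 3}}

	// A shallow copy shares the underlying array with the original.
	shallow := original
	fmt.Println(&shallow.Dimensions[0] == &original.Dimensions[0]) // true

	updated := withFirstDim(original, 5)
	fmt.Println(original.Dimensions, updated.Dimensions) // [2 3] [5 3]
}
```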

## Glossary

  • Rank: number of axes (dimensions) of a Tensor.
  • Axis: the index of a dimension of a multidimensional array (tensor). Sometimes used interchangeably with Dimension, but here we refer to a dimension index as an "axis" (plural axes), and to its size as its dimension. An axis may also have an associated name.
  • Dimension: the size of a multidimensional Tensor along one of its axes. See the example below.
  • DynamicDim (-1): special dimension value that indicates the axis is dynamic (unknown at graph building and compilation time). Axes with dynamic dimensions must be named.
  • DType: the data type of the unit element in a tensor. An enumeration defined in github.com/gomlx/compute/dtypes.
  • Scalar: a shape with no axes (or dimensions), holding only a single value of the associated DType.

Example: The multidimensional array `[][]int32{{0, 1, 2}, {3, 4, 5}}`, if converted to a Tensor, would have shape `(int32)[2, 3]`. We say it has rank 2 (so 2 axes), axis 0 has dimension 2, and axis 1 has dimension 3. This shape could be created with `shapes.Make(dtypes.Int32, 2, 3)`.

## Creating a new shape

  • Make(dtype, dimensions ...int): creates a concrete (no dynamic axes), unnamed shape.
  • MakeDynamic(dtype, dimensions []int, axisNames []string): creates a shape with (optional) dynamic axes and axis names. Dynamic axes must have an associated name, while concrete axes are usually unnamed ("").
  • MakeTuple(elements []Shape): creates a shape that represents a tuple of shapes. For internal use only, and subject to change.

## Iterators and Strides

Shapes support several iteration facilities. See Shape.Iter, Shape.IterOn, Shape.IterOnAxes and Shape.Strides.


Constants

const DynamicDim = -1

DynamicDim is the sentinel value used in Shape.Dimensions to indicate an axis whose size is unknown at graph build time and will be resolved at execution time via AxisBindings.

Note: this is the same numeric value as UncheckedAxis (-1), but used in a different context. UncheckedAxis is used in assertion arguments (CheckDims, AssertDims) to mean "don't check this axis". DynamicDim is used in Shape.Dimensions to mean "this axis has an unknown size".

const UncheckedAxis = int(-1)

UncheckedAxis can be used in CheckDims or AssertDims functions for an axis whose dimension doesn't matter.

Variables

This section is empty.

Functions

func Assert

func Assert(shaped HasShape, dtype dtypes.DType, dimensions ...int)

Assert checks that the shape has the given dtype, dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It panics if it doesn't match.

func AssertDims

func AssertDims(shaped HasShape, dimensions ...int)

AssertDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It panics if it doesn't match.

See usage example in package shapes documentation.

func AssertRank

func AssertRank(shaped HasShape, rank int)

AssertRank checks that the shape has the given rank.

It panics if it doesn't match.

See usage example in package shapes documentation.

func AssertScalar

func AssertScalar(shaped HasShape)

AssertScalar checks that the shape is a scalar.

It panics if it doesn't match.

See usage example in package shapes documentation.

func CastAsDType

func CastAsDType(value any, dtype DType) any

CastAsDType casts a numeric value to the Go type corresponding to the DType. If the value is a slice, it is converted to a newly allocated slice of the given DType.

It doesn't work for complex numbers.

func CheckDims

func CheckDims(shaped HasShape, dimensions ...int) error

CheckDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It returns an error if the rank is different or if any of the dimensions don't match.
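The rank-then-dimensions check can be sketched as follows. This is a hypothetical helper operating on a plain `[]int` of dimensions instead of the real `HasShape` interface, to keep it self-contained; `UncheckedAxis` mirrors the documented constant, and the error wording is an assumption.

```go
package main

import "fmt"

// UncheckedAxis mirrors the documented constant: skip checking this axis.
const UncheckedAxis = -1

// checkDims is a hypothetical stand-in for shapes.CheckDims on raw dimensions.
func checkDims(actual []int, want ...int) error {
	if len(actual) != len(want) {
		return fmt.Errorf("rank mismatch: got %d, want %d (dims %v vs %v)", len(actual), len(want), actual, want)
	}
	for axis, w := range want {
		if w == UncheckedAxis {
			continue // any dimension is accepted for this axis
		}
		if actual[axis] != w {
			return fmt.Errorf("axis %d: got dimension %d, want %d", axis, actual[axis], w)
		}
	}
	return nil
}

func main() {
	dims := []int{32, 10, 3}
	fmt.Println(checkDims(dims, UncheckedAxis, 10, 3)) // <nil>
	fmt.Println(checkDims(dims, UncheckedAxis, 10))    // rank mismatch error
	fmt.Println(checkDims(dims, 32, 11, 3))            // axis 1 mismatch error
}
```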

func CheckRank

func CheckRank(shaped HasShape, rank int) error

CheckRank checks that the shape has the given rank.

It returns an error if the rank is different.

func CheckScalar

func CheckScalar(shaped HasShape) error

CheckScalar checks that the shape is a scalar.

It returns an error if the shape is not a scalar.

func ConvertTo

func ConvertTo[T NumberNotComplex](value any) T

ConvertTo converts any scalar (typically returned by `tensor.Local.Value()`) of the supported dtypes to `T`. It returns 0 if value is not a scalar or not a supported number (e.g., bool). It doesn't work if T (the output type) is a complex number. If value is a complex number, it is converted by taking its real part and discarding the imaginary part.
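The documented behavior maps naturally onto a type switch. The sketch below is hypothetical, not the package's source: `NumberNotComplex` is redefined locally and its exact set of types in the real package is an assumption.

```go
package main

import "fmt"

// NumberNotComplex is a local stand-in for the constraint in the signature above.
type NumberNotComplex interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~float32 | ~float64
}

// convertTo converts numeric scalars to T, keeps only the real part of complex
// values, and returns 0 for anything else.
func convertTo[T NumberNotComplex](value any) T {
	switch v := value.(type) {
	case int:
		return T(v)
	case int8:
		return T(v)
	case int16:
		return T(v)
	case int32:
		return T(v)
	case int64:
		return T(v)
	case uint:
		return T(v)
	case uint8:
		return T(v)
	case uint16:
		return T(v)
	case uint32:
		return T(v)
	case uint64:
		return T(v)
	case float32:
		return T(v)
	case float64:
		return T(v)
	case complex64:
		return T(real(v)) // imaginary part discarded
	case complex128:
		return T(real(v)) // imaginary part discarded
	default:
		return 0 // e.g. bool, slices, or other unsupported types
	}
}

func main() {
	fmt.Println(convertTo[float32](int64(3)))        // 3
	fmt.Println(convertTo[int](2.9))                 // 2 (float-to-int conversion truncates)
	fmt.Println(convertTo[float64](complex(1.5, 2))) // 1.5
	fmt.Println(convertTo[int](true))                // 0
}
```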

func UnifyAxisName

func UnifyAxisName(name1, name2 string) (string, error)

UnifyAxisName resolves the output axis name when combining two axes from different shapes.

Rules:

  • "" + "" = "" (both unnamed → unnamed)
  • "name" + "" = "name" (one named → keep the name)
  • "" + "name" = "name" (one named → keep the name)
  • "name" + "name" = "name" (same name → keep it)
  • "a" + "b" = error (different names → conflict)
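The five rules above fit in a single switch. This is a sketch of the logic rather than the package's source; the error message wording is an assumption.

```go
package main

import "fmt"

// unifyAxisName implements the rules listed above.
func unifyAxisName(name1, name2 string) (string, error) {
	switch {
	case name1 == name2:
		return name1, nil // both unnamed, or the same name
	case name1 == "":
		return name2, nil // only the second axis is named
	case name2 == "":
		return name1, nil // only the first axis is named
	default:
		return "", fmt.Errorf("conflicting axis names %q and %q", name1, name2)
	}
}

func main() {
	pairs := [][2]string{{"", ""}, {"batch", ""}, {"", "batch"}, {"batch", "batch"}, {"a", "b"}}
	for _, pair := range pairs {
		name, err := unifyAxisName(pair[0], pair[1])
		fmt.Printf("%q + %q = %q (err: %v)\n", pair[0], pair[1], name, err)
	}
}
```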

func UnifyAxisNames

func UnifyAxisNames(s1, s2 Shape) ([]string, error)

UnifyAxisNames unifies axis names from two shapes of the same rank. It returns the unified axis names, or an error on name conflicts. It returns nil if neither shape has axis names.

func UnsafeSliceForDType deprecated

func UnsafeSliceForDType(dtype DType, unsafePtr unsafe.Pointer, length int) any

UnsafeSliceForDType creates a slice of the corresponding dtype and casts it to any. It uses unsafe.Slice. Set `length` to the number of `DType` elements (not the number of bytes).

Deprecated: use dtypes.UnsafeAnySliceFromBytes instead.

Types

type AxisBindings

type AxisBindings map[string]int

AxisBindings maps axis names to their concrete dimension values. It is used at execution time to resolve dynamic shapes.

func (AxisBindings) Extract

func (b AxisBindings) Extract(template, concrete Shape) error

Extract axis bindings by comparing a template shape (with named dynamic axes) against a concrete shape with all dimensions known (presumably given during execution, when the concrete inputs are given).

Returns an error if the shapes are incompatible: different ranks, different static dimensions, or inconsistent bindings where the same axis name maps to different concrete values.
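The comparison can be sketched as a single pass over the template's axes. The `Shape` here is a hypothetical stand-in with only `Dimensions` and `AxisNames`, and the extra check that a dynamic axis is named reflects the package's documented invariant rather than confirmed behavior of Extract.

```go
package main

import "fmt"

// DynamicDim mirrors the documented sentinel value.
const DynamicDim = -1

// Shape is a hypothetical stand-in holding only the fields this sketch needs.
type Shape struct {
	Dimensions []int
	AxisNames  []string
}

// AxisBindings mirrors the documented map type.
type AxisBindings map[string]int

// Extract binds each dynamic (named) axis of template to the dimension at the same
// position in concrete. It fails on rank mismatches, static-dimension mismatches,
// or an axis name that would be bound to two different values.
func (b AxisBindings) Extract(template, concrete Shape) error {
	if len(template.Dimensions) != len(concrete.Dimensions) {
		return fmt.Errorf("rank mismatch: %d vs %d", len(template.Dimensions), len(concrete.Dimensions))
	}
	for axis, dim := range template.Dimensions {
		got := concrete.Dimensions[axis]
		if dim != DynamicDim {
			if dim != got {
				return fmt.Errorf("axis %d: static dimension %d != %d", axis, dim, got)
			}
			continue
		}
		name := ""
		if axis < len(template.AxisNames) {
			name = template.AxisNames[axis]
		}
		if name == "" {
			return fmt.Errorf("axis %d is dynamic but unnamed", axis)
		}
		if prev, found := b[name]; found && prev != got {
			return fmt.Errorf("axis %q bound to both %d and %d", name, prev, got)
		}
		b[name] = got
	}
	return nil
}

func main() {
	template := Shape{Dimensions: []int{DynamicDim, 512}, AxisNames: []string{"batch", ""}}
	bindings := AxisBindings{}
	err := bindings.Extract(template, Shape{Dimensions: []int{32, 512}})
	fmt.Println(bindings, err) // map[batch:32] <nil>
	// A second concrete input that disagrees on "batch" is rejected.
	fmt.Println(bindings.Extract(template, Shape{Dimensions: []int{16, 512}}))
}
```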

func (AxisBindings) Key

func (b AxisBindings) Key() string

Key returns a deterministic string key suitable for use as a map key. Axis names are sorted alphabetically and formatted as "name=value,name=value".
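Sorting the names is what makes the key deterministic, since Go randomizes map iteration order. A sketch of the documented "name=value,name=value" format follows; it is not the package's source.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// AxisBindings mirrors the documented map type.
type AxisBindings map[string]int

// Key sorts axis names alphabetically so that equal bindings always produce the
// same string, regardless of map iteration order.
func (b AxisBindings) Key() string {
	names := make([]string, 0, len(b))
	for name := range b {
		names = append(names, name)
	}
	sort.Strings(names)
	parts := make([]string, len(names))
	for i, name := range names {
		parts[i] = fmt.Sprintf("%s=%d", name, b[name])
	}
	return strings.Join(parts, ",")
}

func main() {
	fmt.Println(AxisBindings{"seq": 128, "batch": 32}.Key()) // batch=32,seq=128
}
```

Such a key could, for example, index a cache of per-binding results in an ordinary `map[string]...`.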

type HasShape

type HasShape interface {
	Shape() Shape
}

HasShape is an interface for objects that have an associated Shape. `tensor.Tensor` (concrete tensor), `graph.Node` (tensor representation in a computation graph), `context.Variable`, and Shape itself implement the interface.

type Shape

type Shape struct {
	// DType is the data type of the unit element in a tensor.
	DType dtypes.DType

	// Dimensions is the size of each axis. Its length determines the rank.
	// A value of DynamicDim (-1) indicates a dynamic axis whose size is unknown at graph build time.
	Dimensions []int

	// AxisNames holds optional names for each axis. nil means no axis names (the default).
	// When non-nil, len(AxisNames) must equal len(Dimensions).
	// An empty string "" means the axis is unnamed. A non-empty string names the axis.
	AxisNames []string `json:"axis_names,omitempty"`

	// TupleShapes is used if this Shape represents a tuple of elements.
	// Internal use only.
	TupleShapes []Shape `json:"tuple,omitempty"` // Shapes of the tuple, if this is a tuple.
}

Shape represents the shape of either a Tensor or the expected shape of the value from a computation node.

Use Make to create a new shape. See examples in the package documentation.

func ConcatenateDimensions

func ConcatenateDimensions(s1, s2 Shape) (shape Shape)

ConcatenateDimensions concatenates the dimensions of two shapes. The resulting rank is the sum of both ranks. They must have the same dtype. If either of them is a scalar, the resulting shape will be a copy of the other. It doesn't work for tuples.

func FromAnyValue

func FromAnyValue(v any) (shape Shape, err error)

FromAnyValue attempts to convert a Go "any" value to its expected shape. Accepted values are plain-old-data (POD) types (ints, floats, complex) and slices (or multiple levels of slices) of POD.

It returns the expected shape, or an error if the value is not supported.

Example:

shape, err := shapes.FromAnyValue([][]float64{{0, 0}}) // Returns shape (Float64)[1 2]

func GobDeserialize

func GobDeserialize(decoder *gob.Decoder) (s Shape, err error)

GobDeserialize a Shape. Returns new Shape or an error. Handles both the old format (without a format version) and the v1 format (with AxisNames).

func Invalid

func Invalid() Shape

Invalid returns an invalid shape.

Invalid().Ok() == false.

func Make

func Make(dtype dtypes.DType, dimensions ...int) Shape

Make returns a Shape structure filled with the values given.

See MakeDynamic for shapes with dynamic and/or named axes. See MakeTuple for tuple shapes.

func MakeDynamic

func MakeDynamic(dtype dtypes.DType, dimensions []int, axisNames []string) Shape

MakeDynamic creates a Shape with named axes and optional dynamic dimensions.

dimensions can contain DynamicDim (-1) for axes whose size is unknown at graph build time. Dynamic axes must have a non-empty name in axisNames. Static axes (>= 0) may have names too.

axisNames must have the same length as dimensions. Use "" for unnamed axes.

Axis names starting with "=" or "#" are reserved for internal (backend implementation) use and shouldn't be used by end users.

Example:

shapes.MakeDynamic(dtypes.F32, []int{-1, 512}, []string{"batch", ""})
// Creates shape (Float32)[batch=-1 512]
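The MakeDynamic rules can be summarized as a validation pass. `validateDynamic` below is a hypothetical helper, not part of the package, and how the real MakeDynamic reports violations is not documented here. Rejecting the reserved "=" and "#" prefixes outright is a stricter, user-side choice of this sketch.

```go
package main

import (
	"fmt"
	"strings"
)

// DynamicDim mirrors the documented sentinel value.
const DynamicDim = -1

// validateDynamic checks the rules described for MakeDynamic.
func validateDynamic(dimensions []int, axisNames []string) error {
	if len(axisNames) != len(dimensions) {
		return fmt.Errorf("got %d axis names for %d dimensions", len(axisNames), len(dimensions))
	}
	for axis, dim := range dimensions {
		name := axisNames[axis]
		if strings.HasPrefix(name, "=") || strings.HasPrefix(name, "#") {
			return fmt.Errorf("axis %d: name %q uses a reserved prefix", axis, name)
		}
		switch {
		case dim == DynamicDim && name == "":
			return fmt.Errorf("axis %d is dynamic but unnamed", axis)
		case dim < 0 && dim != DynamicDim:
			return fmt.Errorf("axis %d: invalid dimension %d", axis, dim)
		}
	}
	return nil
}

func main() {
	fmt.Println(validateDynamic([]int{DynamicDim, 512}, []string{"batch", ""})) // <nil>
	fmt.Println(validateDynamic([]int{DynamicDim, 512}, []string{"", ""}))      // axis 0 is dynamic but unnamed
}
```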

func MakeTuple

func MakeTuple(elements []Shape) Shape

MakeTuple returns a shape representing a tuple of elements with the given shapes.

func Scalar

func Scalar[T dtypes.Number]() Shape

Scalar returns a scalar Shape for the given type.

func (Shape) Assert

func (s Shape) Assert(dtype dtypes.DType, dimensions ...int)

Assert checks that the shape has the given dtype, dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It panics if it doesn't match.

func (Shape) AssertDims

func (s Shape) AssertDims(dimensions ...int)

AssertDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It panics if it doesn't match.

See usage example in package shapes documentation.

func (Shape) AssertRank

func (s Shape) AssertRank(rank int)

AssertRank checks that the shape has the given rank.

It panics if it doesn't match.

See usage example in package shapes documentation.

func (Shape) AssertScalar

func (s Shape) AssertScalar()

AssertScalar checks that the shape is a scalar.

It panics if it doesn't match.

See usage example in package shapes documentation.

func (Shape) AxisName

func (s Shape) AxisName(axis int) string

AxisName returns the name of the given axis, or "" if unnamed or if AxisNames is nil. axis supports negative indexing.

func (Shape) ByteSize

func (s Shape) ByteSize() int64

ByteSize returns the number of bytes used to store an array of the given shape.

func (Shape) Check

func (s Shape) Check(dtype dtypes.DType, dimensions ...int) error

Check checks that the shape has the given dtype, dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It returns an error if the dtype or rank is different or if any of the dimensions don't match.

func (Shape) CheckDims

func (s Shape) CheckDims(dimensions ...int) error

CheckDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It returns an error if the rank is different or if any of the dimensions don't match.

func (Shape) CheckRank

func (s Shape) CheckRank(rank int) error

CheckRank checks that the shape has the given rank.

It returns an error if the rank is different.

func (Shape) CheckScalar

func (s Shape) CheckScalar() error

CheckScalar checks that the shape is a scalar.

It returns an error if the shape is not a scalar.

func (Shape) Clone

func (s Shape) Clone() (s2 Shape)

Clone returns a new deep copy of the shape.

func (Shape) Dim

func (s Shape) Dim(axis int) int

Dim returns the dimension of the given axis. axis can be negative, in which case it counts from the end, so axis=-1 refers to the last axis. As with slice indexing, it panics for an out-of-bounds axis.
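Negative-axis handling amounts to adding the rank before bounds-checking. This `dim` function is a hypothetical stand-in for Shape.Dim that works on a raw dimensions slice.

```go
package main

import "fmt"

// dim returns dimensions[axis], where negative axes count from the end, and
// panics for out-of-range axes, as slice indexing would.
func dim(dimensions []int, axis int) int {
	adjusted := axis
	if adjusted < 0 {
		adjusted += len(dimensions)
	}
	if adjusted < 0 || adjusted >= len(dimensions) {
		panic(fmt.Sprintf("axis %d out of range for rank %d", axis, len(dimensions)))
	}
	return dimensions[adjusted]
}

func main() {
	dims := []int{2, 3, 4}
	fmt.Println(dim(dims, 0), dim(dims, -1), dim(dims, -3)) // 2 4 2
}
```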

func (Shape) Equal

func (s Shape) Equal(s2 Shape) bool

Equal compares two shapes for equality: dtype and dimensions are compared.

func (Shape) EqualDimensions

func (s Shape) EqualDimensions(s2 Shape) bool

EqualDimensions compares two shapes for equality of dimensions.

DType and axis names are ignored.

func (Shape) GobSerialize

func (s Shape) GobSerialize(encoder *gob.Encoder) (err error)

GobSerialize serializes the shape in binary (gob) format.

Format v1 (current): DType, Dimensions, -1 (version marker), hasAxisNames (bool), [AxisNames if hasAxisNames], numTuples, [sub-shapes...].

Old format: DType, Dimensions, numTuples, [sub-shapes...]. GobDeserialize handles both formats for backward compatibility (old data → new code). Note: new-format data cannot be read by old code (the -1 marker would be interpreted as numTuples, causing a panic), so the format is backward-compatible only, not forward-compatible.

func (Shape) IsAxisDynamic

func (s Shape) IsAxisDynamic(axis int) bool

IsAxisDynamic returns true if the given axis has a dynamic dimension (DynamicDim). axis supports negative indexing (e.g., -1 for the last axis).

func (Shape) IsDynamic

func (s Shape) IsDynamic() bool

IsDynamic returns true if any dimension is dynamic (DynamicDim).

func (Shape) IsScalar

func (s Shape) IsScalar() bool

IsScalar returns whether the shape represents a scalar, that is, whether it has no dimensions (rank == 0).

func (Shape) IsTuple

func (s Shape) IsTuple() bool

IsTuple returns whether the shape represents a tuple.

func (Shape) IsZeroSize

func (s Shape) IsZeroSize() bool

IsZeroSize returns whether any of the dimensions is zero, in which case it's an empty shape, with no data attached to it.

Note that scalars are not zero-sized: they have size one, but rank zero.

func (Shape) Iter

func (s Shape) Iter() iter.Seq2[int, []int]

Iter iterates sequentially over all possible indices of the given shape.

It yields the flat index (counter) and a slice of indices for each axis.

Panics if any dimension is dynamic (DynamicDim).

To avoid allocating a new slice of indices on each step, the yielded indices slice is owned by the Iter() method: don't modify it inside the loop, and copy it if you need to keep it beyond the current iteration.

func (Shape) IterOn

func (s Shape) IterOn(indices []int) iter.Seq2[int, []int]

IterOn iterates over all possible indices of the given shape.

It yields the flat index (counter) and a slice of indices for each axis.

Panics if any dimension is dynamic (DynamicDim).

The iteration updates the indices on the given indices slice. During the iteration the caller shouldn't modify the slice of indices, otherwise it will lead to undefined behavior.

It expects len(indices) == s.Rank(). It will panic otherwise.

func (Shape) IterOnAxes

func (s Shape) IterOnAxes(axesToIterate, strides, indices []int) iter.Seq2[int, []int]

IterOnAxes iterates over all possible indices of the given shape's axesToIterate.

It yields the flat index and the updated indices for all axes of the shape (not only those in axesToIterate). The indices of axes not in axesToIterate are not touched.

Panics if any dimension is dynamic (DynamicDim).

Args:

  • axesToIterate: axes of the shape to iterate over. They must be 0 <= axis < rank. Axes not included here are not touched in the indices.
  • strides: the strides of the shape, as returned by Shape.Strides(). If nil, the value returned by Shape.Strides is used. If you are iterating over a shape many times, pre-calculating the strides saves some time. If provided, it expects len(strides) == s.Rank(); it panics otherwise.
  • indices: the slice yielded during the iteration; it must have length equal to the shape's rank. If it is nil, one is allocated for the iteration. The indices not in axesToIterate are left untouched, but they are used to calculate the flatIdx that is also yielded. If provided, it expects len(indices) == s.Rank(); it panics otherwise.

During the iteration the caller shouldn't modify the slice of indices, otherwise it will lead to undefined behavior.

Example:

// Create a shape with dimensions [2, 3, 4]
shape := Make(dtypes.F32, 2, 3, 4)

// Iterate over the first and last axes (0 and 2)
axesToIterate := []int{0, 2}
indices := make([]int, shape.Rank())
indices[1] = 1  // Fix middle axis to 1

// Each iteration updates indices[0] and indices[2], keeping indices[1]=1,
// for a total of 2*4 = 8 combinations.
for flatIdx, idx := range shape.IterOnAxes(axesToIterate, nil, indices) {
    fmt.Printf("flatIdx=%d, indices=%v\n", flatIdx, idx)
}

func (Shape) Memory deprecated

func (s Shape) Memory() uintptr

Memory is an old alias to ByteSize, kept for backward compatibility.

Deprecated: use ByteSize() instead.

func (Shape) Ok

func (s Shape) Ok() bool

Ok returns whether this is a valid Shape. A zero-value shape, that is, one created with just Shape{}, is invalid.

func (Shape) Rank

func (s Shape) Rank() int

Rank returns the rank of the shape, that is, the number of axes.

func (Shape) Resolve

func (s Shape) Resolve(bindings AxisBindings) (Shape, error)

Resolve returns a new Shape with all dynamic dimensions replaced by their bound values. If the original shape is not dynamic, it is returned as is (not a copy).

The resolved shape retains its AxisNames for provenance/debugging.

Returns an error if a named dynamic axis has no corresponding binding, or if the binding is non-positive.

func (Shape) Shape

func (s Shape) Shape() Shape

Shape returns a shallow copy of itself. It implements the HasShape interface.

func (Shape) Size

func (s Shape) Size() (size int)

Size returns the number of elements (not bytes) for this shape. It's the product of all dimensions.

It panics if s.IsDynamic().

For the number of bytes used to store this shape, see Shape.ByteSize.

func (Shape) Strides

func (s Shape) Strides() (strides []int)

Strides returns the strides for each axis of the shape, assuming a "row-major" layout in memory, the one used everywhere in GoMLX.

Note that the strides are **not in bytes**, but in number of elements.

It panics if any dimension is dynamic (DynamicDim), see Shape.IsDynamic.

func (Shape) String

func (s Shape) String() string

String implements fmt.Stringer and pretty-prints the shape.

func (Shape) TupleSize

func (s Shape) TupleSize() int

TupleSize returns the number of elements in the tuple, if it is a tuple.

func (Shape) WithAxisNames

func (s Shape) WithAxisNames(names ...string) Shape

WithAxisNames returns a copy of the shape with the given axis names set. The number of names must equal the rank.
