shapes

package
v0.0.1
Warning

This package is not in the latest version of its module.

Published: Apr 23, 2023 License: Apache-2.0 Imports: 6 Imported by: 0

Documentation

Overview

Package shapes defines Shape and DType and associated tools.

Shape represents the shape (rank, dimensions and DType) of either a Tensor or the expected shape of a node in a computation Graph. DType indicates the type of the unit element of a Tensor (or its representation as a node in a computation Graph).

Shape and DType are used both by the concrete tensor values (see tensor package) and when working on the computation graph (see graph package).

**Asserts**: When coding ML models, one delicate part is keeping track of the shapes of the graph nodes. Unfortunately there is no compile-time checking of shapes, so validation only happens at runtime. To help with that, and also to serve as code documentation, this package provides two variations of _assert_ functionality: the `Assert*` functions panic on a mismatch, while the `Check*` functions return an error instead. Examples:

  1. `AssertRank` and `AssertDims` check that the rank and dimensions of the given object (anything with a `Shape` method) match; otherwise they panic. A `-1` means the dimension is unchecked (it can be anything).

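```go
import (
	. "github.com/gomlx/gomlx/graph"
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/ml/layers"
	"github.com/gomlx/gomlx/types/shapes"
)

// modelGraph builds a single dense layer and uses asserts to document the
// expected shapes of its input and output.
func modelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {
	_ = spec // Not needed here, we know the dataset.
	x := inputs[0]
	shapes.AssertRank(x, 2) // x is shaped [batchSize, features].
	batchSize := x.Shape().Dimensions[0]
	logits := layers.Dense(ctx, x, /* useBias= */ true, /* outputDim= */ 1)
	shapes.AssertDims(logits, batchSize, -1) // -1: second axis is unchecked.
	return []*Node{logits}
}
```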
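The two variations differ only in how a mismatch is reported. The standalone sketch below mimics that split with a hypothetical `shape` type and hypothetical `checkRank`/`assertRank` helpers (an illustration, not the package's implementation): the check variant returns an error, and the assert variant panics with that same error.

```go
package main

import "fmt"

// shape is a hypothetical stand-in for shapes.Shape, reduced to its dimensions.
type shape struct {
	dims []int
}

// checkRank mimics the documented CheckRank behavior: it returns an error
// when the rank (number of dimensions) differs from the expected one.
func checkRank(s shape, rank int) error {
	if len(s.dims) != rank {
		return fmt.Errorf("shape has rank %d, wanted rank %d", len(s.dims), rank)
	}
	return nil
}

// assertRank mimics the documented AssertRank behavior: the same check,
// but it panics instead of returning the error.
func assertRank(s shape, rank int) {
	if err := checkRank(s, rank); err != nil {
		panic(err)
	}
}

func main() {
	s := shape{dims: []int{32, 10}}
	fmt.Println(checkRank(s, 2)) // <nil>
	fmt.Println(checkRank(s, 3)) // shape has rank 2, wanted rank 3

	// Recover from the panic to show how assertRank reports a mismatch.
	defer func() {
		if r := recover(); r != nil {
			fmt.Println("recovered:", r)
		}
	}()
	assertRank(s, 3)
}
```

The `Check*` form suits code that wants to report or wrap the problem; the `Assert*` form keeps model-building code short when a wrong shape is a programming error.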

**Glossary**:

  • **Rank**: number of axes (dimensions) of a Tensor.
  • **Axis**: the index of a dimension in a multi-dimensional Tensor. The term is sometimes used interchangeably with Dimension, but here "axis" (plural "axes") refers to a dimension's index, and "dimension" to its size.
  • **Dimension**: the size of a multi-dimensional Tensor along one of its axes. See the example below.
  • **DType**: the data type of the unit element in a tensor.
  • **Scalar**: a shape with no axes (or dimensions), holding only a single value of the associated DType.

Example:

The multi-dimensional array `[][]int32{{0, 1, 2}, {3, 4, 5}}` if converted to a Tensor
would have shape `(int32)[2 3]`. We say it has **rank 2** (so 2 axes), **axis 0** has
**dimension 2**, and **axis 1** has **dimension 3**. This shape could be created with
`shapes.Make(shapes.Int32, 2, 3)`.
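To make the terminology concrete, the standalone sketch below uses a hypothetical `shape` struct (a reduced stand-in for `shapes.Shape`, not its implementation) to walk the axes of the `(int32)[2 3]` example and to contrast it with a scalar.

```go
package main

import "fmt"

// shape is a hypothetical, reduced stand-in for shapes.Shape.
type shape struct {
	dtype      string
	dimensions []int
}

// rank is the number of axes, i.e. the number of dimensions.
func (s shape) rank() int { return len(s.dimensions) }

func main() {
	// Shape of [][]int32{{0, 1, 2}, {3, 4, 5}}.
	m := shape{dtype: "int32", dimensions: []int{2, 3}}
	fmt.Println("rank:", m.rank()) // rank: 2
	for axis, dim := range m.dimensions {
		// Prints "axis 0 has dimension 2", then "axis 1 has dimension 3".
		fmt.Printf("axis %d has dimension %d\n", axis, dim)
	}

	// A scalar has no axes, so its rank is 0.
	s := shape{dtype: "int32"}
	fmt.Println("scalar rank:", s.rank()) // scalar rank: 0
}
```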


Constants

const (
	I64 = Int64
	F32 = Float32
	F64 = Float64
)
const PRED = Bool

PRED is an alias for Bool; PRED is the name used in `tensorflow/compiler/xla/xla_data.proto`.

const UncheckedAxis = int(-1)

UncheckedAxis can be used in CheckDims or AssertDims functions for an axis whose dimension doesn't matter.
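The standalone sketch below shows how an unchecked axis is skipped during a dimensions check. `uncheckedAxis` copies the documented value (-1); `checkDims` is a hypothetical helper written from the CheckDims description, not the package's code.

```go
package main

import "fmt"

// uncheckedAxis copies shapes.UncheckedAxis: a dimension value of -1
// means "any size is accepted for this axis".
const uncheckedAxis = -1

// checkDims sketches the documented CheckDims behavior: the rank must
// match, and every dimension must match unless it is uncheckedAxis.
func checkDims(actual []int, expected ...int) error {
	if len(actual) != len(expected) {
		return fmt.Errorf("rank mismatch: got %d, want %d", len(actual), len(expected))
	}
	for axis, want := range expected {
		if want == uncheckedAxis {
			continue // This axis may have any dimension.
		}
		if actual[axis] != want {
			return fmt.Errorf("axis %d: got dimension %d, want %d", axis, actual[axis], want)
		}
	}
	return nil
}

func main() {
	logitsDims := []int{64, 1} // e.g. batch size 64, one output.
	fmt.Println(checkDims(logitsDims, 64, uncheckedAxis)) // <nil>
	fmt.Println(checkDims(logitsDims, 32, uncheckedAxis)) // axis 0: got dimension 64, want 32
	fmt.Println(checkDims(logitsDims, 64))                // rank mismatch: got 2, want 1
}
```

Note that an unchecked axis still counts toward the rank: only the size of that axis is ignored.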


Functions

func AssertDims

func AssertDims(shaped HasShape, dimensions ...int)

AssertDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It panics if it doesn't match.

See the usage example in the `shapes` package documentation.

func AssertRank

func AssertRank(shaped HasShape, rank int)

AssertRank checks that the shape has the given rank.

It panics if it doesn't match.

See the usage example in the `shapes` package documentation.

func AssertScalar

func AssertScalar(shaped HasShape)

AssertScalar checks that the shape is a scalar.

It panics if it doesn't match.

See the usage example in the `shapes` package documentation.

func CastAsDType

func CastAsDType(value any, dtype DType) any

CastAsDType casts a numeric value to the Go type corresponding to the given DType. If the value is a slice, it is converted into a newly allocated slice of that DType.
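For slices, the described behavior amounts to allocating a new slice and converting element by element. Below is a minimal generic sketch of that idea, using a hypothetical `castSlice` helper and a `number` constraint copied from the documented `Number` type; it is an illustration, not the package's implementation.

```go
package main

import "fmt"

// number copies the documented shapes.Number constraint.
type number interface {
	float32 | float64 | int
}

// castSlice converts each element of in to type To and returns a newly
// allocated slice; the input slice is left untouched. Converting a float
// to int truncates toward zero, as with any Go conversion.
func castSlice[To, From number](in []From) []To {
	out := make([]To, len(in))
	for i, v := range in {
		out[i] = To(v)
	}
	return out
}

func main() {
	floats := castSlice[float32]([]int{1, 2, 3})
	fmt.Println(floats)        // [1 2 3]
	fmt.Printf("%T\n", floats) // []float32
}
```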

func CheckDims

func CheckDims(shaped HasShape, dimensions ...int) error

CheckDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It returns an error if the rank or any of the checked dimensions differ.

func CheckRank

func CheckRank(shaped HasShape, rank int) error

CheckRank checks that the shape has the given rank.

It returns an error if the rank is different.

func CheckScalar

func CheckScalar(shaped HasShape) error

CheckScalar checks that the shape is a scalar.

It returns an error if the shape is not a scalar.

func LowestValueForDType

func LowestValueForDType(dtype DType) any

func TypeForDType

func TypeForDType(dtype DType) reflect.Type

Types

type DType

type DType int32

DType indicates the type of the unit element of a Tensor (or its representation in a computation graph). It enumerates the known data types. So far only Bool, Int32, Int64, Float32 and Float64 work.

The values of DType must match "tensorflow/compiler/xla/xla_data.pb.h", which is why DType is an int32. TODO: write a small generator script to produce these values automatically.

See the example in the `shapes` package documentation.

const (
	InvalidDType DType = iota
	Bool               // Bool, but also known as PRED in `xla_data.proto`.
	Int8               // S8
	Int16              // S16
	Int32              // S32
	Int64              // S64, in Go represented as int
	UInt8              // U8
	UInt16             // U16
	UInt32             // U32
	UInt64             // U64
	Float16            // F16
	Float32            // F32
	Float64            // F64

	BFloat16   DType = 16 // BF16
	Complex64  DType = 15 // C64
	Complex128 DType = 18 // C128

	Tuple      DType = 13
	OpaqueType DType = 14
	Token      DType = 17
)

DType constants must match `tensorflow/compiler/xla/xla_data.proto`.
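Because the constants follow XLA's numbering, `InvalidDType` takes 0 so that the `iota` run puts `Bool` at 1 (`PRED`) through `Float64` at 12 (`F64`); the remaining types are pinned explicitly. The standalone copy below reproduces the constant block under a local `DType` type (a copy for illustration, not an import of the package) and prints a few of the resulting values.

```go
package main

import "fmt"

// DType is a local copy of the documented shapes.DType constants,
// used only to show the numeric values they take.
type DType int32

const (
	InvalidDType DType = iota // 0
	Bool                      // 1, PRED
	Int8                      // 2, S8
	Int16                     // 3, S16
	Int32                     // 4, S32
	Int64                     // 5, S64
	UInt8                     // 6, U8
	UInt16                    // 7, U16
	UInt32                    // 8, U32
	UInt64                    // 9, U64
	Float16                   // 10, F16
	Float32                   // 11, F32
	Float64                   // 12, F64

	// Values outside the iota run are pinned explicitly.
	Tuple      DType = 13
	OpaqueType DType = 14
	Complex64  DType = 15
	BFloat16   DType = 16
	Token      DType = 17
	Complex128 DType = 18
)

func main() {
	fmt.Println(Bool, Int32, Float32, Float64) // 1 4 11 12
}
```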

func DTypeForType

func DTypeForType(t reflect.Type) DType

func DTypeGeneric

func DTypeGeneric[T Supported]() DType
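Judging from their signatures, `DTypeForType` maps a `reflect.Type` to a DType and `DTypeGeneric` does the same for a type parameter. The sketch below shows one plausible way to do that, assuming the mapping implied elsewhere on this page (Go `int` becomes `Int64`); the lowercase helpers and the trimmed local constants are hypothetical, not the package's code.

```go
package main

import (
	"fmt"
	"reflect"
)

// DType and the constants below are a trimmed local copy of the documented ones.
type DType int32

const (
	InvalidDType DType = 0
	Bool         DType = 1
	Int64        DType = 5
	Float32      DType = 11
	Float64      DType = 12
)

// supported copies the documented shapes.Supported constraint.
type supported interface {
	bool | float32 | float64 | int
}

// dtypeForType maps a reflect.Type to a DType, returning InvalidDType for
// types without a mapping. Go's int is mapped to Int64.
func dtypeForType(t reflect.Type) DType {
	switch t.Kind() {
	case reflect.Bool:
		return Bool
	case reflect.Int:
		return Int64
	case reflect.Float32:
		return Float32
	case reflect.Float64:
		return Float64
	}
	return InvalidDType
}

// dtypeGeneric resolves the DType of the type parameter T via its zero value.
func dtypeGeneric[T supported]() DType {
	var zero T
	return dtypeForType(reflect.TypeOf(zero))
}

func main() {
	fmt.Println(dtypeGeneric[int](), dtypeGeneric[float32]()) // 5 11
	fmt.Println(dtypeForType(reflect.TypeOf("text")))          // 0
}
```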

func (DType) IsFloat

func (dtype DType) IsFloat() bool

IsFloat returns whether dtype is a supported float type. Float types not yet supported return false.

func (DType) IsInt

func (dtype DType) IsInt() bool

IsInt returns whether dtype is a supported integer type. Integer types not yet supported return false.
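A sketch of these two predicates, assuming "supported" means exactly the set listed in the DType description (Bool, Int32, Int64, Float32 and Float64). The trimmed local constants and the method bodies are hypothetical, not the package's code.

```go
package main

import "fmt"

// DType and the constants below are a trimmed local copy of the documented ones.
type DType int32

const (
	Bool    DType = 1
	Int8    DType = 2
	Int32   DType = 4
	Int64   DType = 5
	Float16 DType = 10
	Float32 DType = 11
	Float64 DType = 12
)

// IsFloat reports whether dtype is a supported float type; Float16 is
// assumed not yet supported, so it returns false.
func (dtype DType) IsFloat() bool {
	return dtype == Float32 || dtype == Float64
}

// IsInt reports whether dtype is a supported integer type; integer types
// assumed not yet supported (such as Int8) return false.
func (dtype DType) IsInt() bool {
	return dtype == Int32 || dtype == Int64
}

func main() {
	fmt.Println(Float32.IsFloat(), Float16.IsFloat())      // true false
	fmt.Println(Int64.IsInt(), Int8.IsInt(), Bool.IsInt()) // true false false
}
```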

func (DType) IsSupported

func (dtype DType) IsSupported() bool

func (DType) String

func (i DType) String() string

type GoFloat

type GoFloat interface {
	float32 | float64
}

GoFloat represents the Go floating-point types supported by GoMLX.

type HasShape

type HasShape interface {
	Shape() Shape
}

HasShape is an interface for objects that have an associated Shape. `tensor.Tensor` (concrete tensor), `graph.Node` (tensor representation in a computation graph), `context.Variable` and Shape itself all implement it.
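Because HasShape has a single method, any type with a `Shape() Shape` method can be passed to the package-level `Assert*` and `Check*` functions. The standalone sketch below uses reduced local copies of `Shape` and `HasShape`, a hypothetical user type `namedInput`, and a hypothetical `checkRank` written from the documented behavior, to show the pattern.

```go
package main

import "fmt"

// Shape and HasShape are reduced local copies of the documented types.
type Shape struct {
	Dimensions []int
}

type HasShape interface {
	Shape() Shape
}

// Shape lets Shape itself satisfy HasShape, as documented.
func (s Shape) Shape() Shape { return s }

// namedInput is a hypothetical user type that carries a shape.
type namedInput struct {
	name  string
	shape Shape
}

func (n namedInput) Shape() Shape { return n.shape }

// checkRank accepts anything that implements HasShape.
func checkRank(shaped HasShape, rank int) error {
	if got := len(shaped.Shape().Dimensions); got != rank {
		return fmt.Errorf("rank %d, wanted %d", got, rank)
	}
	return nil
}

func main() {
	in := namedInput{name: "images", shape: Shape{Dimensions: []int{32, 28, 28}}}
	fmt.Println(checkRank(in, 3))                          // <nil>
	fmt.Println(checkRank(Shape{Dimensions: []int{5}}, 2)) // rank 1, wanted 2
}
```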

type MultiDimensionSlice

type MultiDimensionSlice interface {
	bool | int | float32 | float64 |
		[]bool | []int | []float32 | []float64 |
		[][]bool | [][]int | [][]float32 | [][]float64 |
		[][][]bool | [][][]int | [][][]float32 | [][][]float64 |
		[][][][]bool | [][][][]int | [][][][]float32 | [][][][]float64 |
		[][][][][]bool | [][][][][]int | [][][][][]float32 | [][][][][]float64 |
		[][][][][][]bool | [][][][][][]int | [][][][][][]float32 | [][][][][][]float64 |
		[][][][][][][]bool | [][][][][][][]int | [][][][][][][]float32 | [][][][][][][]float64 |
		[][][][][][][][]bool | [][][][][][][][]int | [][][][][][][][]float32 | [][][][][][][][]float64 |
		[][][][][][][][][]bool | [][][][][][][][][]int | [][][][][][][][][]float32 | [][][][][][][][][]float64
}

MultiDimensionSlice lists the Go types a Tensor can be converted to and from. Go generics constraints cannot be defined recursively, so the constraint enumerates scalars and up to 9 levels of nested slices. More levels can be added if needed; the implementation works with any number of levels.
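Since the constraint can only enumerate a fixed depth, code that walks such values can rely on reflection, which handles any nesting depth. Below is a minimal sketch with a hypothetical `dimsOf` helper (not the package's implementation) that infers the dimensions by following the first element at each level.

```go
package main

import (
	"fmt"
	"reflect"
)

// dimsOf returns the dimensions of a (possibly nested) slice by following
// the first element of each level. It does not check that the slice is
// rectangular, and it stops at the first empty level.
func dimsOf(value any) []int {
	dims := []int{}
	v := reflect.ValueOf(value)
	for v.Kind() == reflect.Slice {
		dims = append(dims, v.Len())
		if v.Len() == 0 {
			break
		}
		v = v.Index(0)
	}
	return dims
}

func main() {
	fmt.Println(dimsOf(3.0))                                   // []
	fmt.Println(dimsOf([][]int{{0, 1, 2}, {3, 4, 5}}))         // [2 3]
	fmt.Println(dimsOf([][][]float32{{{1}, {2}}, {{3}, {4}}})) // [2 2 1]
}
```

A production version would also verify that every sub-slice at a given level has the same length before building a Tensor from it.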

type Number

type Number interface {
	float32 | float64 | int
}

Number represents the Go numeric types supported by the graph package. Used as a generics constraint. Note that `int` maps to Int64 in the implementation; since the mapping must be 1:1, the native Go `int64` type is not supported.

type Shape

type Shape struct {
	DType       DType
	Dimensions  []int
	TupleShapes []Shape // Shapes of the tuple, if this is a tuple.
}

Shape represents the shape of either a Tensor or the expected shape of the value from a computation node.

Use Make to create a new shape. See the example in the `shapes` package documentation.

func ConcatenateDimensions

func ConcatenateDimensions(s1, s2 Shape) (shape Shape)

ConcatenateDimensions concatenates the dimensions of two shapes. The resulting rank is the sum of both ranks. Both shapes must have the same dtype. If either of them is a scalar, the resulting shape is a copy of the other. It doesn't work for tuples.
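The documented rules (ranks add up, dtypes must agree, a scalar operand yields a copy of the other shape) can be sketched as follows. The local `shape` type and `concatenateDimensions` helper are hypothetical; the sketch returns an error on a dtype mismatch purely for illustration, whereas the real function returns only a Shape.

```go
package main

import (
	"errors"
	"fmt"
)

// shape is a hypothetical, reduced stand-in for shapes.Shape.
type shape struct {
	dtype      string
	dimensions []int
}

// concatenateDimensions appends s2's dimensions to s1's. The result owns a
// new dimensions slice, so it does not alias either input.
func concatenateDimensions(s1, s2 shape) (shape, error) {
	if s1.dtype != s2.dtype {
		return shape{}, errors.New("dtypes differ")
	}
	dims := make([]int, 0, len(s1.dimensions)+len(s2.dimensions))
	dims = append(dims, s1.dimensions...)
	dims = append(dims, s2.dimensions...)
	return shape{dtype: s1.dtype, dimensions: dims}, nil
}

func main() {
	a := shape{"float32", []int{2, 3}}
	b := shape{"float32", []int{4}}
	c, _ := concatenateDimensions(a, b)
	fmt.Println(c.dimensions) // [2 3 4]

	scalar := shape{"float32", nil}
	d, _ := concatenateDimensions(scalar, a)
	fmt.Println(d.dimensions) // [2 3] (a scalar contributes no axes)
}
```

The scalar rule needs no special case here: a scalar has no dimensions to append, so the result is just a copy of the other operand's dimensions.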

func Make

func Make(dtype DType, dimensions ...int) Shape

Make returns a Shape structure filled with the values given.

func MakeTuple

func MakeTuple(elements []Shape) Shape

MakeTuple returns a shape representing a tuple of elements with the given shapes.

func Scalar

func Scalar[T Number]() Shape

Scalar returns a scalar Shape for the given type.

func (Shape) AssertDims

func (s Shape) AssertDims(dimensions ...int)

AssertDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It panics if it doesn't match.

See the usage example in the `shapes` package documentation.

func (Shape) AssertRank

func (s Shape) AssertRank(rank int)

AssertRank checks that the shape has the given rank.

It panics if it doesn't match.

See the usage example in the `shapes` package documentation.

func (Shape) AssertScalar

func (s Shape) AssertScalar()

AssertScalar checks that the shape is a scalar.

It panics if it doesn't match.

See the usage example in the `shapes` package documentation.

func (Shape) CheckDims

func (s Shape) CheckDims(dimensions ...int) error

CheckDims checks that the shape has the given dimensions and rank. A value of -1 in dimensions means it can take any value and is not checked.

It returns an error if the rank or any of the checked dimensions differ.

func (Shape) CheckRank

func (s Shape) CheckRank(rank int) error

CheckRank checks that the shape has the given rank.

It returns an error if the rank is different.

func (Shape) CheckScalar

func (s Shape) CheckScalar() error

CheckScalar checks that the shape is a scalar.

It returns an error if the shape is not a scalar.

func (Shape) Copy

func (s Shape) Copy() (s2 Shape)

Copy makes a deep copy of the shape.

func (Shape) Eq

func (s Shape) Eq(s2 Shape) bool

Eq compares two shapes for equality: dtype and dimensions are compared.
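A deep Copy has to duplicate the Dimensions slice (otherwise both shapes would share, and could mutate, the same backing array), and Eq compares the dtype plus every dimension. Below is a sketch over a reduced local copy of the struct; it is hypothetical and ignores tuples.

```go
package main

import "fmt"

// DType and Shape are reduced local copies of the documented types.
type DType int32

type Shape struct {
	DType      DType
	Dimensions []int
}

// Copy returns a shape with its own Dimensions slice.
func (s Shape) Copy() Shape {
	dims := make([]int, len(s.Dimensions))
	copy(dims, s.Dimensions)
	return Shape{DType: s.DType, Dimensions: dims}
}

// Eq reports whether the dtype, the rank and every dimension match.
func (s Shape) Eq(s2 Shape) bool {
	if s.DType != s2.DType || len(s.Dimensions) != len(s2.Dimensions) {
		return false
	}
	for i, d := range s.Dimensions {
		if s2.Dimensions[i] != d {
			return false
		}
	}
	return true
}

func main() {
	a := Shape{DType: 11, Dimensions: []int{2, 3}}
	b := a.Copy()
	b.Dimensions[0] = 7 // Does not affect a, because Copy is deep.
	fmt.Println(a.Dimensions, a.Eq(b), a.Eq(a.Copy())) // [2 3] false true
}
```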

func (Shape) IsScalar

func (s Shape) IsScalar() bool

IsScalar returns whether the shape represents a scalar, that is, one with no dimensions (rank == 0).

func (Shape) IsTuple

func (s Shape) IsTuple() bool

IsTuple returns whether the shape represents a tuple.

func (Shape) Ok

func (s Shape) Ok() bool

Ok returns whether this is a valid Shape. A zero-value shape, such as one created with Shape{}, is invalid.

func (Shape) Rank

func (s Shape) Rank() int

Rank returns the rank of the shape, that is, the number of dimensions (axes).

func (Shape) Shape

func (s Shape) Shape() Shape

Shape returns a shallow copy of itself. It implements the HasShape interface.

func (Shape) Size

func (s Shape) Size() (size int)

Size returns the number of DType elements needed for this shape, that is, the product of all dimensions.
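As a worked example of the product rule, the hypothetical standalone `size` function below multiplies the dimensions. It assumes the usual empty-product convention, under which a scalar (no dimensions) has size 1.

```go
package main

import "fmt"

// size returns the number of elements for the given dimensions: the product
// of all dimensions, which is 1 for a scalar (no dimensions).
func size(dimensions ...int) int {
	n := 1
	for _, d := range dimensions {
		n *= d
	}
	return n
}

func main() {
	fmt.Println(size(2, 3))       // 6
	fmt.Println(size())           // 1 (scalar)
	fmt.Println(size(32, 28, 28)) // 25088
}
```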

func (Shape) String

func (s Shape) String() string

String implements fmt.Stringer and pretty-prints the shape.

func (Shape) TupleSize

func (s Shape) TupleSize() int

TupleSize returns the number of elements in the tuple, if it is a tuple.

type Supported

type Supported interface {
	bool | float32 | float64 | int
}

Supported represents the Go types supported by the graph package. Used as a generics constraint. See also Number.
