Overview ¶
Package shapes defines Shape and DType and associated tools.
Shape represents the shape (rank, dimensions and DType) of either a Tensor or the expected shape of a node in a computation Graph. DType indicates the type of the unit element of a Tensor (or its representation as a node in a computation Graph).
Shape and DType are used both by the concrete tensor values (see tensor package) and when working on the computation graph (see graph package).
**Asserts**: When coding ML models, one delicate part is keeping track of the shapes of the graph nodes. There is no compile-time checking of shapes, so validation only happens at runtime. To help with this, and also to serve as code documentation, this package provides two variations of _assert_ functionality. Examples:
- `AssertRank` and `AssertDims` check that the rank and dimensions of the given object (anything with a `Shape` method) match the expected values, and panic otherwise. A `-1` dimension means that axis is unchecked (it can be anything).
```go
func modelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {
	_ = spec // Not needed here, we know the dataset.
	// inputs is a slice of nodes: the shape checks apply to its first element.
	shapes.AssertRank(inputs[0], 2)
	batchSize := inputs[0].Shape().Dimensions[0]
	logits := layers.Dense(ctx, inputs[0], /* useBias= */ true, /* outputDim= */ 1)
	// The batch axis must be preserved; the second axis is left unchecked (-1).
	shapes.AssertDims(logits, batchSize, -1)
	return []*Node{logits}
}
```
**Glossary**:
- **Rank**: number of axes (dimensions) of a Tensor.
- **Axis**: the index of a dimension of a multi-dimensional Tensor. It is sometimes used interchangeably with Dimension, but here we refer to a dimension's index as its "axis" (plural "axes"), and to its size as its "dimension".
- **Dimension**: the size of a multi-dimensional Tensor along one of its axes. See the example below.
- **DType**: the data type of the unit element in a tensor.
- **Scalar**: a shape with no axes (or dimensions), holding only a single value of the associated DType.
Example:
The multi-dimensional array `[][]int32{{0, 1, 2}, {3, 4, 5}}` if converted to a Tensor
would have shape `(int32)[2 3]`. We say it has **rank 2** (so 2 axes), **axis 0** has
**dimension 2**, and **axis 1** has **dimension 3**. This shape could be created with
`shapes.Make(shapes.Int32, 2, 3)`.
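The glossary example can also be checked mechanically. The rank is the nesting depth of the Go slice type, and each dimension is the slice length at that depth. The sketch below uses a hypothetical helper, `dimensionsOf`, which is not part of package shapes. It relies on the standard `reflect` package and assumes the slice is rectangular, meaning all sub-slices at one depth have the same length.

```go
package main

import (
	"fmt"
	"reflect"
)

// dimensionsOf is a hypothetical helper (not part of package shapes) that returns
// the dimensions of a rectangular multi-dimensional Go slice. The rank comes from
// the slice type, so empty slices still report their full rank (with 0 sizes).
// Only the first element at each depth is inspected: rectangularity is assumed.
func dimensionsOf(value any) []int {
	var dims []int
	v := reflect.ValueOf(value)
	for t := v.Type(); t.Kind() == reflect.Slice; t = t.Elem() {
		if !v.IsValid() {
			dims = append(dims, 0) // Below an empty slice: size unknown, use 0.
			continue
		}
		dims = append(dims, v.Len())
		if v.Len() > 0 {
			v = v.Index(0)
		} else {
			v = reflect.Value{}
		}
	}
	return dims
}

func main() {
	dims := dimensionsOf([][]int32{{0, 1, 2}, {3, 4, 5}})
	size := 1
	for _, d := range dims {
		size *= d
	}
	fmt.Printf("rank=%d dims=%v size=%d\n", len(dims), dims, size)
	// Prints: rank=2 dims=[2 3] size=6
}
```

A Go scalar such as `int32(7)` yields no dimensions, which matches the definition of a scalar as rank 0.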
Index ¶
- Constants
- func AssertDims(shaped HasShape, dimensions ...int)
- func AssertRank(shaped HasShape, rank int)
- func AssertScalar(shaped HasShape)
- func CastAsDType(value any, dtype DType) any
- func CheckDims(shaped HasShape, dimensions ...int) error
- func CheckRank(shaped HasShape, rank int) error
- func CheckScalar(shaped HasShape) error
- func LowestValueForDType(dtype DType) any
- func TypeForDType(dtype DType) reflect.Type
- type DType
- type GoFloat
- type HasShape
- type MultiDimensionSlice
- type Number
- type Shape
- func (s Shape) AssertDims(dimensions ...int)
- func (s Shape) AssertRank(rank int)
- func (s Shape) AssertScalar()
- func (s Shape) CheckDims(dimensions ...int) error
- func (s Shape) CheckRank(rank int) error
- func (s Shape) CheckScalar() error
- func (s Shape) Copy() (s2 Shape)
- func (s Shape) Eq(s2 Shape) bool
- func (s Shape) IsScalar() bool
- func (s Shape) IsTuple() bool
- func (s Shape) Ok() bool
- func (s Shape) Rank() int
- func (s Shape) Shape() Shape
- func (s Shape) Size() (size int)
- func (s Shape) String() string
- func (s Shape) TupleSize() int
- type Supported
Constants ¶
const (
	I64 = Int64
	F32 = Float32
	F64 = Float64
)
const PRED = Bool
PRED is an alias for Bool, the name used in `tensorflow/compiler/xla/xla_data.proto`.
const UncheckedAxis = int(-1)
UncheckedAxis can be used in CheckDims or AssertDims functions for an axis whose dimension doesn't matter.
Functions ¶
func AssertDims ¶
func AssertDims(shaped HasShape, dimensions ...int)
AssertDims checks that the shape has the given rank and dimensions. A value of -1 (UncheckedAxis) in dimensions means that axis can take any value and is not checked.
It panics if it doesn't match.
See usage example in package shapes documentation.
func AssertRank ¶
func AssertRank(shaped HasShape, rank int)
AssertRank checks that the shape has the given rank.
It panics if it doesn't match.
See usage example in package shapes documentation.
func AssertScalar ¶
func AssertScalar(shaped HasShape)
AssertScalar checks that the shape is a scalar.
It panics if it doesn't match.
See usage example in package shapes documentation.
func CastAsDType ¶
func CastAsDType(value any, dtype DType) any
CastAsDType casts a numeric value to the Go type corresponding to the given DType. If the value is a slice, it is converted to a newly allocated slice of the given DType.
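The slice case of CastAsDType can be illustrated with plain Go conversions. The sketch below uses a hypothetical generic helper, `castSlice`, and a hypothetical `numeric` constraint. Neither is part of the package, and the real function works on `any` values and a DType rather than on type parameters. What the sketch shows is the "newly allocated slice" behavior: it allocates a new slice and converts each element.

```go
package main

import "fmt"

// numeric is a hypothetical constraint for this sketch, not the package's Number type.
type numeric interface {
	~int | ~int32 | ~int64 | ~float32 | ~float64
}

// castSlice is a hypothetical helper: it allocates a new slice of the target type
// and converts element by element, so the input slice is never modified or aliased.
func castSlice[To, From numeric](values []From) []To {
	out := make([]To, len(values))
	for i, v := range values {
		out[i] = To(v) // Float-to-integer conversion truncates toward zero.
	}
	return out
}

func main() {
	in := []float64{1.5, -2.7, 3}
	out := castSlice[int32](in)
	fmt.Println(out) // [1 -2 3]
}
```

Go leaves float-to-integer conversions of out-of-range values implementation-dependent, so a real cast function may need to decide how to handle them.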
func CheckDims ¶
func CheckDims(shaped HasShape, dimensions ...int) error
CheckDims checks that the shape has the given rank and dimensions. A value of -1 (UncheckedAxis) in dimensions means that axis can take any value and is not checked.
It returns an error if the rank differs, or if any of the checked dimensions differ.
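The CheckDims rules amount to an exact rank comparison followed by a per-axis comparison that skips UncheckedAxis. The sketch below expresses these rules with a hypothetical stand-in type, `miniShape`, that holds only dimensions. The error messages are illustrative and not the package's own.

```go
package main

import "fmt"

// UncheckedAxis mirrors the package constant: -1 means "any dimension".
const UncheckedAxis = -1

// miniShape is a hypothetical stand-in for shapes.Shape (DType omitted for brevity).
type miniShape struct {
	Dimensions []int
}

// checkDims sketches the documented CheckDims behavior: the rank must match
// exactly, and each dimension must match unless UncheckedAxis is given for it.
func (s miniShape) checkDims(dimensions ...int) error {
	if len(s.Dimensions) != len(dimensions) {
		return fmt.Errorf("rank mismatch: shape has rank %d, wanted %d", len(s.Dimensions), len(dimensions))
	}
	for axis, want := range dimensions {
		if want == UncheckedAxis {
			continue // This axis may have any size.
		}
		if got := s.Dimensions[axis]; got != want {
			return fmt.Errorf("axis %d has dimension %d, wanted %d", axis, got, want)
		}
	}
	return nil
}

func main() {
	s := miniShape{Dimensions: []int{32, 10}}
	fmt.Println(s.checkDims(32, UncheckedAxis)) // <nil>
	fmt.Println(s.checkDims(UncheckedAxis, 5))  // axis 1 has dimension 10, wanted 5
	fmt.Println(s.checkDims(32))                // rank mismatch: shape has rank 2, wanted 1
}
```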
func CheckRank ¶
func CheckRank(shaped HasShape, rank int) error
CheckRank checks that the shape has the given rank.
It returns an error if the rank is different.
func CheckScalar ¶
func CheckScalar(shaped HasShape) error
CheckScalar checks that the shape is a scalar.
It returns an error if shape is not a scalar.
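Each Assert* function has a Check* counterpart that returns an error instead of panicking. One plausible way to pair them is to implement the assert as a thin panic wrapper over the check. The sketch below shows that pattern with a hypothetical `miniShape` type and a hypothetical `tryAssert` helper that recovers the panic for inspection. The package may structure its code differently.

```go
package main

import (
	"errors"
	"fmt"
)

// miniShape is a hypothetical stand-in for shapes.Shape.
type miniShape struct{ Dimensions []int }

// checkScalar returns an error unless the shape has rank 0.
func (s miniShape) checkScalar() error {
	if len(s.Dimensions) != 0 {
		return fmt.Errorf("shape %v is not a scalar", s.Dimensions)
	}
	return nil
}

// assertScalar panics with the error reported by checkScalar, if any.
func (s miniShape) assertScalar() {
	if err := s.checkScalar(); err != nil {
		panic(err)
	}
}

// tryAssert is a hypothetical helper that runs fn and turns a panic back into an error.
func tryAssert(fn func()) (err error) {
	defer func() {
		if r := recover(); r != nil {
			if e, ok := r.(error); ok {
				err = e
			} else {
				err = errors.New(fmt.Sprint(r))
			}
		}
	}()
	fn()
	return nil
}

func main() {
	fmt.Println(tryAssert(miniShape{}.assertScalar))                      // <nil>
	fmt.Println(tryAssert(miniShape{Dimensions: []int{3}}.assertScalar)) // shape [3] is not a scalar
}
```

Use the Check* form when the caller can recover from a mismatch. Use the Assert* form in model-building code, where a wrong shape is a programming error.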
func LowestValueForDType ¶
func LowestValueForDType(dtype DType) any
func TypeForDType ¶
func TypeForDType(dtype DType) reflect.Type
Types ¶
type DType ¶
type DType int32
DType indicates the type of the unit element of a Tensor (or its representation in a computation graph). It enumerates the known data types. So far only Bool, Int32, Int64, Float32 and Float64 work.
The values of DType must match "tensorflow/compiler/xla/xla_data.pb.h", which is why DType is an int32. TODO: write a small generator script to produce these automatically.
See example in package shapes documentation.
const (
	InvalidDType DType = iota
	Bool       // Bool, but also known as PRED in `xla_data.proto`.
	Int8       // S8
	Int16      // S16
	Int32      // S32
	Int64      // S64, in Go represented as int
	UInt8      // U8
	UInt16     // U16
	UInt32     // U32
	UInt64     // U64
	Float16    // F16
	Float32    // F32
	Float64    // F64
	BFloat16   DType = 16 // BF16
	Complex64  DType = 15 // C64
	Complex128 DType = 18 // C128
	Tuple      DType = 13
	OpaqueType DType = 14
	Token      DType = 17
)
DType constants must match `tensorflow/compiler/xla/xla_data.proto`.
func DTypeForType ¶
func DTypeGeneric ¶
func (DType) IsFloat ¶
IsFloat returns whether dtype is a supported float type. Float types that are not yet supported return false.
func (DType) IsInt ¶
IsInt returns whether dtype is a supported integer type. Integer types that are not yet supported return false.
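To see how "supported" narrows IsFloat and IsInt, the sketch below copies the first part of the DType constant block, whose values follow XLA's PrimitiveType numbering. It then implements hypothetical `isFloat`/`isInt` functions under one assumption: the supported set is exactly the one listed in the DType description (Bool, Int32, Int64, Float32 and Float64). The real methods may differ as support grows.

```go
package main

import "fmt"

// DType reproduces the first part of the package's enumeration (XLA numbering).
type DType int32

const (
	InvalidDType DType = iota
	Bool
	Int8
	Int16
	Int32
	Int64
	UInt8
	UInt16
	UInt32
	UInt64
	Float16
	Float32
	Float64
)

// isFloat and isInt are hypothetical re-implementations that assume only
// Bool, Int32, Int64, Float32 and Float64 are supported, so types such as
// Int8 or Float16 report false even though they are ints or floats.
func isFloat(dtype DType) bool { return dtype == Float32 || dtype == Float64 }

func isInt(dtype DType) bool { return dtype == Int32 || dtype == Int64 }

func main() {
	for _, dt := range []DType{Bool, Int8, Int32, Int64, Float16, Float32, Float64} {
		fmt.Printf("dtype=%d isInt=%v isFloat=%v\n", dt, isInt(dt), isFloat(dt))
	}
}
```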
func (DType) IsSupported ¶
type HasShape ¶
type HasShape interface {
Shape() Shape
}
HasShape is an interface for objects that have an associated Shape. `tensor.Tensor` (concrete tensor), `graph.Node` (tensor representation in a computation graph), `context.Variable` and Shape itself all implement it.
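Because HasShape only requires a `Shape()` method, user types can also satisfy it. The sketch below uses hypothetical, trimmed-down stand-ins for `Shape` and `HasShape`, a hypothetical `imageBatch` type, and a hypothetical `rankOf` function. It shows a function written against the interface accepting both a custom type and a Shape.

```go
package main

import "fmt"

// Shape is a hypothetical, trimmed-down stand-in for shapes.Shape.
type Shape struct {
	Dimensions []int
}

// Shape lets Shape itself satisfy HasShape, as the real type does.
func (s Shape) Shape() Shape { return s }

// HasShape matches the interface documented above.
type HasShape interface {
	Shape() Shape
}

// imageBatch is a hypothetical user type describing a batch of images.
type imageBatch struct {
	n, height, width, channels int
}

// Shape makes imageBatch usable wherever a HasShape is accepted.
func (b imageBatch) Shape() Shape {
	return Shape{Dimensions: []int{b.n, b.height, b.width, b.channels}}
}

// rankOf works for any HasShape.
func rankOf(x HasShape) int { return len(x.Shape().Dimensions) }

func main() {
	fmt.Println(rankOf(imageBatch{n: 8, height: 28, width: 28, channels: 1})) // 4
	fmt.Println(rankOf(Shape{}))                                               // 0
}
```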
type MultiDimensionSlice ¶
type MultiDimensionSlice interface {
bool | int | float32 | float64 |
[]bool | []int | []float32 | []float64 |
[][]bool | [][]int | [][]float32 | [][]float64 |
[][][]bool | [][][]int | [][][]float32 | [][][]float64 |
[][][][]bool | [][][][]int | [][][][]float32 | [][][][]float64 |
[][][][][]bool | [][][][][]int | [][][][][]float32 | [][][][][]float64 |
[][][][][][]bool | [][][][][][]int | [][][][][][]float32 | [][][][][][]float64 |
[][][][][][][]bool | [][][][][][][]int | [][][][][][][]float32 | [][][][][][][]float64 |
[][][][][][][][]bool | [][][][][][][][]int | [][][][][][][][]float32 | [][][][][][][][]float64 |
[][][][][][][][][]bool | [][][][][][][][][]int | [][][][][][][][][]float32 | [][][][][][][][][]float64
}
MultiDimensionSlice lists the Go types a Tensor can be converted to/from. Go generic constraints cannot be defined recursively, so the nesting levels are enumerated explicitly, up to 9 levels of slices. Feel free to add more if needed: the implementation works with any nesting depth.
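The enumeration limit applies only to the constraint. A function body that walks the type with reflection does not depend on how many levels are listed. The sketch below uses a hypothetical, shortened constraint, `smallMultiDimSlice`, and a hypothetical `rankOf` function to show both points.

```go
package main

import (
	"fmt"
	"reflect"
)

// smallMultiDimSlice is a hypothetical, shortened version of MultiDimensionSlice:
// since a constraint cannot refer to itself, each nesting depth is listed explicitly.
type smallMultiDimSlice interface {
	float32 | []float32 | [][]float32 | [][][]float32
}

// rankOf only accepts the listed types, but its body uses reflection, so it
// would keep working unchanged if more nesting levels were added to the constraint.
func rankOf[T smallMultiDimSlice](value T) int {
	rank := 0
	for t := reflect.TypeOf(value); t.Kind() == reflect.Slice; t = t.Elem() {
		rank++
	}
	return rank
}

func main() {
	fmt.Println(rankOf(float32(1)))                  // 0
	fmt.Println(rankOf([][]float32{{1, 2}, {3, 4}})) // 2
	// rankOf([][][][]float32{}) would not compile: depth 4 is not in the constraint.
}
```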
type Number ¶
Number represents the Go numeric types supported by the graph package, and is used as a generics constraint. Note that "int" becomes int64 in the implementation. Since the mapping must be 1:1, the native Go int64 type is not supported.
type Shape ¶
type Shape struct {
DType DType
Dimensions []int
TupleShapes []Shape // Shapes of the tuple, if this is a tuple.
}
Shape represents the shape of either a Tensor or the expected shape of the value from a computation node.
Use Make to create a new shape. See example in package shapes documentation.
func ConcatenateDimensions ¶
ConcatenateDimensions concatenates the dimensions of two shapes. The resulting rank is the sum of both ranks, and both shapes must have the same DType. If either of them is a scalar, the result is a copy of the other. It doesn't work for tuples.
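The documented rules for ConcatenateDimensions can be sketched with hypothetical stand-ins for `DType` and `Shape` and a hypothetical `concatenateDimensions` function. The sketch returns an error on a DType mismatch, whereas the real function's failure mode may differ. Tuples are not handled. The scalar rule needs no special case: a scalar contributes no dimensions, so the result equals the other shape.

```go
package main

import (
	"errors"
	"fmt"
)

// DType and Shape are hypothetical, trimmed-down stand-ins for the package types.
type DType int32

type Shape struct {
	DType      DType
	Dimensions []int
}

// concatenateDimensions requires equal dtypes and returns a's dimensions followed
// by b's, so the rank is the sum of both ranks. A fresh slice is allocated so the
// result never aliases either input. Tuple shapes are not handled in this sketch.
func concatenateDimensions(a, b Shape) (Shape, error) {
	if a.DType != b.DType {
		return Shape{}, errors.New("shapes must have the same dtype")
	}
	dims := make([]int, 0, len(a.Dimensions)+len(b.Dimensions))
	dims = append(dims, a.Dimensions...)
	dims = append(dims, b.Dimensions...)
	return Shape{DType: a.DType, Dimensions: dims}, nil
}

func main() {
	const float32DType DType = 11 // Float32 in the DType enumeration.
	s, err := concatenateDimensions(
		Shape{DType: float32DType, Dimensions: []int{2, 3}},
		Shape{DType: float32DType, Dimensions: []int{4}},
	)
	fmt.Println(s.Dimensions, err) // [2 3 4] <nil>
}
```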
func (Shape) AssertDims ¶
func (s Shape) AssertDims(dimensions ...int)
AssertDims checks that the shape has the given rank and dimensions. A value of -1 (UncheckedAxis) in dimensions means that axis can take any value and is not checked.
It panics if it doesn't match.
See usage example in package shapes documentation.
func (Shape) AssertRank ¶
func (s Shape) AssertRank(rank int)
AssertRank checks that the shape has the given rank.
It panics if it doesn't match.
See usage example in package shapes documentation.
func (Shape) AssertScalar ¶
func (s Shape) AssertScalar()
AssertScalar checks that the shape is a scalar.
It panics if it doesn't match.
See usage example in package shapes documentation.
func (Shape) CheckDims ¶
func (s Shape) CheckDims(dimensions ...int) error
CheckDims checks that the shape has the given rank and dimensions. A value of -1 (UncheckedAxis) in dimensions means that axis can take any value and is not checked.
It returns an error if the rank differs, or if any of the checked dimensions differ.
func (Shape) CheckRank ¶
func (s Shape) CheckRank(rank int) error
CheckRank checks that the shape has the given rank.
It returns an error if the rank is different.
func (Shape) CheckScalar ¶
func (s Shape) CheckScalar() error
CheckScalar checks that the shape is a scalar.
It returns an error if shape is not a scalar.
func (Shape) IsScalar ¶
func (s Shape) IsScalar() bool
IsScalar returns whether the shape represents a scalar, that is, whether it has no dimensions (rank == 0).
func (Shape) Ok ¶
func (s Shape) Ok() bool
Ok returns whether this is a valid Shape. A zero-value shape, that is, one created with just Shape{}, is invalid.
func (Shape) Size ¶
func (s Shape) Size() (size int)
Size returns the number of elements of DType needed for this shape: the product of all dimensions.
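The last three methods reduce to simple rules, sketched below on a hypothetical stand-in `Shape`. Size is a product, and the empty product is 1, so a scalar has size 1. IsScalar is a rank-0 test. For Ok, the sketch assumes validity means a DType other than InvalidDType (value 0), which would explain why `Shape{}` is invalid. The real Ok may check more than this, for example for tuples.

```go
package main

import "fmt"

type DType int32

// InvalidDType is 0, so it is also the zero value of DType.
const InvalidDType DType = 0

// Shape is a hypothetical, trimmed-down stand-in for shapes.Shape.
type Shape struct {
	DType      DType
	Dimensions []int
}

// Size is the product of all dimensions; a scalar (no dimensions) has size 1.
func (s Shape) Size() int {
	size := 1
	for _, d := range s.Dimensions {
		size *= d
	}
	return size
}

// IsScalar reports whether the shape has rank 0.
func (s Shape) IsScalar() bool { return len(s.Dimensions) == 0 }

// Ok assumes a shape is valid when its DType is not InvalidDType,
// which makes the zero value Shape{} invalid.
func (s Shape) Ok() bool { return s.DType != InvalidDType }

func main() {
	const float32DType DType = 11
	m := Shape{DType: float32DType, Dimensions: []int{2, 3}}
	scalar := Shape{DType: float32DType}
	fmt.Println(m.Size(), m.IsScalar(), m.Ok())                // 6 false true
	fmt.Println(scalar.Size(), scalar.IsScalar(), scalar.Ok()) // 1 true true
	fmt.Println(Shape{}.Ok())                                  // false
}
```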