Documentation
Overview
Package tensors implements a `Tensor`, a representation of a multi-dimensional array.
Tensors are multidimensional arrays (from a scalar with 0 dimensions to arbitrarily large dimensions), defined by their shape (a data type and its axes' dimensions) and their actual content. As a special case, a Tensor can also be a tuple of multiple tensors.
The main use of tensors is as input and output of GoMLX computation graphs.
There are various ways to construct a Tensor from local data:
FromShape(shape shapes.Shape): creates a tensor with the given shape, filled with zeros.
FromScalarAndDimensions[T dtypes.Supported](value T, dimensions ...int): creates a Tensor with the given dimensions, filled with the given scalar value. `T` must be one of the supported types.
FromFlatDataAndDimensions[T dtypes.Supported](data []T, dimensions ...int): creates a Tensor with the given dimensions, and sets its flattened values from the given data. `T` must be one of the supported types. Example:
t := FromFlatDataAndDimensions([]int8{1, 2, 3, 4}, 2, 2) // Tensor with [[1, 2], [3, 4]]
FromValue[S MultiDimensionSlice](value S): generic conversion; works with the supported scalar `DType`s as well as with arbitrarily nested multidimensional slices of them. Slices of rank > 1 must be regular, that is, all sub-slices must have the same shape. Example:
t := FromValue([][]float32{{1, 2}, {3, 5}, {7, 11}})
FromAnyValue(value any): same as FromValue, but non-generic: it takes a value of type `any`. The exception is if `value` is already a tensor: then it is a no-op and returns the tensor itself.
Behind the scenes, a Tensor is a container that keeps different materializations of its value in sync:
- `local`: a copy of the values stored in CPU memory, as a flat Go slice of the underlying dtype.
- `onDevices`: a copy of the values stored on the accelerator device(s) (CPU, GPU, TPU, etc.), a wrapper for an "XLA PJRT buffer" managed by the lower levels (see github.com/gomlx/gopjrt). A tensor can have multiple device backings if there are multiple devices (as in a multi-GPU setup).
The Tensor container is lazy: it won't transfer data from local storage to "on device" until needed. And if one copy is updated, the others are immediately invalidated.
Transferring tensors between local and device storage has a cost and should be avoided when not needed. For example, while training an ML model, one generally does not need to transfer its weights to local storage, except at the end of training to save them. The Tensor keeps the (local/device) copies cached, so they can be used multiple times while the transfer occurs only once.
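A minimal model of this caching behavior, assuming a local copy already exists, is sketched here. `lazyTensor` and its methods are hypothetical and the "devices" are plain Go slices, so it only illustrates the bookkeeping: a repeated request for the same device does not transfer again, and mutating the local copy drops every on-device copy.

```go
package main

import "fmt"

// lazyTensor is a hypothetical model of the local/on-device cache.
// Devices are simulated with plain slices, and transfers are counted.
type lazyTensor struct {
	local     []float32         // assumed up-to-date local copy
	onDevice  map[int][]float32 // deviceNum -> simulated device buffer
	transfers int
}

// materializeOnDevice copies the local data to deviceNum only if it is not cached yet.
func (t *lazyTensor) materializeOnDevice(deviceNum int) {
	if _, ok := t.onDevice[deviceNum]; ok {
		return // already on device: no transfer
	}
	buf := make([]float32, len(t.local))
	copy(buf, t.local)
	t.onDevice[deviceNum] = buf
	t.transfers++
}

// mutateLocal edits the local copy and invalidates every device copy,
// since they are now out-of-date.
func (t *lazyTensor) mutateLocal(fn func(flat []float32)) {
	fn(t.local)
	t.onDevice = map[int][]float32{}
}

func main() {
	t := &lazyTensor{local: []float32{1, 2, 3}, onDevice: map[int][]float32{}}
	t.materializeOnDevice(0)
	t.materializeOnDevice(0) // cached: no second transfer
	fmt.Println(t.transfers) // 1
	t.mutateLocal(func(flat []float32) { flat[0] = 10 })
	t.materializeOnDevice(0)                // the device copy was invalidated: transfer again
	fmt.Println(t.transfers, t.onDevice[0]) // 2 [10 2 3]
}
```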
Index
- Variables
- func ConstFlatData[T dtypes.Supported](t *Tensor, accessFn func(flat []T))
- func CopyFlatData[T dtypes.Supported](t *Tensor) []T
- func MutableFlatData[T dtypes.Supported](t *Tensor, accessFn func(flat []T))
- func ToScalar[T dtypes.Supported](t *Tensor) T
- type HasClient
- type MultiDimensionSlice
- type Tensor
- func FromAnyValue(value any) (t *Tensor)
- func FromBuffer(backend backends.Backend, buffer backends.Buffer) (t *Tensor)
- func FromFlatDataAndDimensions[T dtypes.Supported](data []T, dimensions ...int) (t *Tensor)
- func FromScalar[T dtypes.Supported](value T) (t *Tensor)
- func FromScalarAndDimensions[T dtypes.Supported](value T, dimensions ...int) (t *Tensor)
- func FromShape(shape shapes.Shape) (t *Tensor)
- func FromValue[S MultiDimensionSlice](value S) *Tensor
- func GobDeserialize(decoder *gob.Decoder) (t *Tensor, err error)
- func Load(filePath string) (t *Tensor, err error)
- func (t *Tensor) AssertValid()
- func (t *Tensor) Buffer(backend backends.Backend, deviceNum ...backends.DeviceNum) backends.Buffer
- func (t *Tensor) ConstBytes(accessFn func(data []byte))
- func (t *Tensor) ConstFlatData(accessFn func(flat any))
- func (t *Tensor) DType() dtypes.DType
- func (t *Tensor) DonateBuffer(backend backends.Backend, deviceNum ...backends.DeviceNum) backends.Buffer
- func (t *Tensor) Equal(otherTensor *Tensor) bool
- func (t *Tensor) FinalizeAll()
- func (t *Tensor) FinalizeLocal()
- func (t *Tensor) GoStr() string
- func (t *Tensor) GobSerialize(encoder *gob.Encoder) (err error)
- func (t *Tensor) HasLocal() bool
- func (t *Tensor) InDelta(otherTensor *Tensor, delta float64) bool
- func (t *Tensor) InvalidateOnDevice()
- func (t *Tensor) IsLocal() bool
- func (t *Tensor) IsOnDevice(deviceNum backends.DeviceNum) bool
- func (t *Tensor) IsScalar() bool
- func (t *Tensor) LayoutStrides() (strides []int)
- func (t *Tensor) LocalClone() *Tensor
- func (t *Tensor) MaterializeLocal()
- func (t *Tensor) MaterializeOnDevices(backend backends.Backend, deviceNums ...backends.DeviceNum)
- func (t *Tensor) Memory() uintptr
- func (t *Tensor) MutableBytes(accessFn func(data []byte))
- func (t *Tensor) MutableFlatData(accessFn func(flat any))
- func (t *Tensor) Ok() bool
- func (t *Tensor) Rank() int
- func (t *Tensor) Save(filePath string) (err error)
- func (t *Tensor) Shape() shapes.Shape
- func (t *Tensor) Size() int
- func (t *Tensor) String() string
- func (t *Tensor) StringN(n int) string
- func (t *Tensor) Value() any
Variables
var MaxSizeForString = 500
MaxSizeForString is the size of the largest tensor whose contents are actually returned by String(), if requested.
Functions
func ConstFlatData
ConstFlatData calls accessFn with the flattened data as a slice of the Go type corresponding to the DType type. Even scalar values have a flattened data representation of one element. It locks the Tensor until accessFn returns.
It is the "generics" version of Tensor.ConstFlatData().
This provides accessFn with the actual Tensor data (not a copy), and it's owned by the Tensor, but it should not be changed -- the contents of the corresponding "on device" tensors would go out-of-sync. See Tensor.MutableFlatData to access a mutable version of the flat data.
See Tensor.Size for the number of elements, and Tensor.LayoutStrides to calculate the offset of individual positions, given the indices at each axis.
It panics if the tensor is in an invalid state (if it was finalized), or if it is a tuple.
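The accessor pattern (a callback receiving the storage itself, with the Tensor locked for the duration) can be sketched without GoMLX. `hostTensor` and `constFlatData` are hypothetical, and the use of a read lock (`sync.RWMutex`) is an assumption: the description only says the Tensor is locked until accessFn returns.

```go
package main

import (
	"fmt"
	"sync"
)

// hostTensor is a hypothetical tensor whose flat data is stored as `any`,
// holding a slice of the Go type that matches its dtype.
type hostTensor struct {
	mu   sync.RWMutex
	flat any // e.g. []float32
}

// constFlatData mimics a generic read-only accessor: it asserts that the flat
// data is a []T and calls accessFn while holding a read lock.
// It panics if T doesn't match the stored element type.
func constFlatData[T any](t *hostTensor, accessFn func(flat []T)) {
	t.mu.RLock()
	defer t.mu.RUnlock()
	flat, ok := t.flat.([]T)
	if !ok {
		panic(fmt.Sprintf("tensor holds %T, not %T", t.flat, flat))
	}
	accessFn(flat) // the actual storage, not a copy: do not modify it
}

func main() {
	t := &hostTensor{flat: []float32{1, 2, 3.5}}
	var sum float32
	constFlatData(t, func(flat []float32) {
		for _, v := range flat {
			sum += v
		}
	})
	fmt.Println(sum) // 6.5
}
```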
func CopyFlatData
CopyFlatData returns a copy of the flat data of the Tensor.
It triggers a synchronous transfer from device to local, if the tensor is only on device.
It will panic if the given generic type doesn't match the DType of the tensor.
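A sketch of the copy semantics and the type check, using a hypothetical `copyFlatData` over storage held as `any` (not the GoMLX implementation): modifying the returned slice leaves the tensor's data untouched, and a mismatched type parameter panics.

```go
package main

import "fmt"

// copyFlatData is a hypothetical analogue of CopyFlatData: flat holds the
// tensor's storage as `any`, and the result is an independent copy of type []T.
// It panics if T does not match the stored element type.
func copyFlatData[T any](flat any) []T {
	src, ok := flat.([]T)
	if !ok {
		panic(fmt.Sprintf("tensor holds %T, requested %T", flat, src))
	}
	dst := make([]T, len(src))
	copy(dst, src)
	return dst
}

func main() {
	storage := []int32{1, 2, 3}
	c := copyFlatData[int32](storage)
	c[0] = 99               // changes only the copy
	fmt.Println(storage, c) // [1 2 3] [99 2 3]
}
```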
func MutableFlatData
MutableFlatData invalidates and frees any copy of the data on device, calls accessFn with the local flattened data as a slice of the Go type corresponding to the DType type. The contents of the slice itself can be changed until accessFn returns. During this time the Tensor is locked.
It is the "generics" version of Tensor.MutableFlatData().
Even scalar values have a flattened data representation of one element.
It triggers a synchronous transfer from device to local, if the tensor is only on device.
accessFn receives the actual Tensor data (not a copy), which is owned by the Tensor; its contents can be changed, but the slice should not be used after accessFn returns.
See ConstFlatData for read-only access to the flat data.
See Tensor.Size for the number of elements, and Tensor.LayoutStrides to calculate the offset of individual positions, given the indices at each axis.
It panics if the tensor is in an invalid state (if it was finalized), or if it is a tuple.
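The mutable counterpart can be sketched the same way. `cachedTensor` and `mutableFlatData` are hypothetical; the on-device copy is reduced to a validity flag, and the device-to-local transfer is left out, so only the "lock, invalidate, edit in place" contract is shown.

```go
package main

import (
	"fmt"
	"sync"
)

// cachedTensor is a hypothetical tensor with a local copy and a flag
// recording whether an on-device copy is still valid.
type cachedTensor struct {
	mu            sync.Mutex
	local         []float64
	deviceIsValid bool
}

// mutableFlatData takes the exclusive lock, drops the on-device copy,
// then lets accessFn edit the local data in place.
func mutableFlatData(t *cachedTensor, accessFn func(flat []float64)) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.deviceIsValid = false // the device copy would be out-of-sync after the edit
	accessFn(t.local)
}

func main() {
	t := &cachedTensor{local: []float64{1, 2, 3}, deviceIsValid: true}
	mutableFlatData(t, func(flat []float64) {
		for i := range flat {
			flat[i] *= 2
		}
	})
	fmt.Println(t.local, t.deviceIsValid) // [2 4 6] false
}
```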
Types
type HasClient
HasClient accepts anything that can return an xla.Client. That includes xla.Client itself and graph.Backend.
type MultiDimensionSlice
type MultiDimensionSlice interface {
	bool | float32 | float64 | int | int32 | int64 | uint8 | uint32 | uint64 | complex64 | complex128 |
		[]bool | []float32 | []float64 | []int | []int32 | []int64 | []uint8 | []uint32 | []uint64 | []complex64 | []complex128 |
		[][]bool | [][]float32 | [][]float64 | [][]int | [][]int32 | [][]int64 | [][]uint8 | [][]uint32 | [][]uint64 | [][]complex64 | [][]complex128 |
		[][][]bool | [][][]float32 | [][][]float64 | [][][]int | [][][]int32 | [][][]int64 | [][][]uint8 | [][][]uint32 | [][][]uint64 | [][][]complex64 | [][][]complex128 |
		[][][][]bool | [][][][]float32 | [][][][]float64 | [][][][]int | [][][][]int32 | [][][][]int64 | [][][][]uint8 | [][][][]uint32 | [][][][]uint64 | [][][][]complex64 | [][][][]complex128 |
		[][][][][]bool | [][][][][]float32 | [][][][][]float64 | [][][][][]int | [][][][][]int32 | [][][][][]int64 | [][][][][]uint8 | [][][][][]uint32 | [][][][][]uint64 | [][][][][]complex64 | [][][][][]complex128 |
		[][][][][][]bool | [][][][][][]float32 | [][][][][][]float64 | [][][][][][]int | [][][][][][]int32 | [][][][][][]int64 | [][][][][][]uint8 | [][][][][][]uint32 | [][][][][][]uint64 | [][][][][][]complex64 | [][][][][][]complex128
}
MultiDimensionSlice lists the Go types a Tensor can be converted to/from. Since generic constraint definitions cannot be recursive, the types are enumerated up to 6 levels of nested slices (rank 6). More levels can be added if needed: the implementation works with any nesting depth.
type Tensor
type Tensor struct {
// contains filtered or unexported fields
}
Tensor represents a multidimensional array (from a scalar with 0 dimensions to arbitrarily large dimensions), defined by its shape, which is a data type (dtypes.DType) and its axes' dimensions, and by its actual content, stored as a flat (1D) array of values.
It is a container for "local" (host CPU) and "on-device" backings of the tensor. A locally backed tensor is stored as a flat slice of the underlying DType.
Tensor manages caching of the local and device copies. There is a transfer cost one needs to be aware of when using it for large data: LLMs can be hundreds of GB in size. A cache system prevents duplicate transfers, but it requires some care from the user (see ConstFlatData and MutableFlatData).
More details are in the `tensors` package documentation.
func FromAnyValue
FromAnyValue is a non-generic version of FromValue that returns a *tensors.Tensor (whether it is local or on device is not specified). The input is expected to be either a scalar or a slice of slices with homogeneous dimensions. If the input is already a tensor (local or on device), it is simply returned. Otherwise, a local tensor is returned.
It panics with an error if the type of `value` is unsupported or its shape is not regular.
func FromBuffer
FromBuffer creates a Tensor from a backend's buffer. It requires the deviceNum information as well. The ownership of the buffer is transferred to the new Tensor.
func FromFlatDataAndDimensions
FromFlatDataAndDimensions creates a tensor with the given dimensions, filled with the flattened values given in `data`. The data is copied to the Tensor. The `DType` is inferred from the `data` type.
func FromScalar (added in v0.11.1)
FromScalar creates a local tensor with the given scalar. The `DType` is inferred from the value.
func FromScalarAndDimensions
FromScalarAndDimensions creates a local tensor with the given dimensions, filled with the given scalar value replicated everywhere. The `DType` is inferred from the value.
func FromShape
FromShape returns a Tensor with the given shape, with the data initialized with zeros.
func FromValue
func FromValue[S MultiDimensionSlice](value S) *Tensor
FromValue returns a `Local` tensor constructed from the given multi-dimension slice (or scalar). If the rank of the `value` is larger than 1, the shape of all sub-slices must be the same.
It panics if the shape is not regular.
Note that FromFlatDataAndDimensions is much faster, if speed is a concern.
func GobDeserialize
GobDeserialize reads a Tensor from the given gob decoder. It returns a new local Tensor or an error.
func (*Tensor) AssertValid
func (t *Tensor) AssertValid()
AssertValid panics if the tensor is nil, or if its shape is invalid.
func (*Tensor) Buffer
Buffer returns the backend buffer for the tensor. It triggers the transfer from local to the device, if the tensor is not already stored on the device.
The deviceNum is optional, but at most one can be given. It defaults to 0.
Be careful not to finalize the buffer while it is in use.
func (*Tensor) ConstBytes
ConstBytes calls accessFn with the data as a bytes slice. Even scalar values have a bytes data representation of one element. It locks the Tensor until accessFn returns.
This provides accessFn with the actual Tensor data (not a copy), and it's owned by the Tensor, but it should not be changed -- the contents of the corresponding "on device" tensors would go out-of-sync. See Tensor.MutableBytes to access a mutable version of the data as bytes.
It panics if the tensor is in an invalid state (if it was finalized), or if it is a tuple.
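One way to expose typed data as bytes without copying, as ConstBytes and MutableBytes describe, is `unsafe.Slice` (Go 1.17+). The helper `bytesView` is hypothetical and shown for []float32 only; the byte order is the host's, so the exact byte values are platform dependent.

```go
package main

import (
	"fmt"
	"unsafe"
)

// bytesView returns a []byte that aliases the memory of a []float32, so no
// data is copied: writes through the view change the floats.
func bytesView(flat []float32) []byte {
	if len(flat) == 0 {
		return nil
	}
	n := len(flat) * int(unsafe.Sizeof(flat[0]))
	return unsafe.Slice((*byte)(unsafe.Pointer(&flat[0])), n)
}

func main() {
	flat := []float32{1, 2, 3}
	b := bytesView(flat)
	fmt.Println(len(b)) // 12: 3 elements x 4 bytes
	for i := range b[:4] {
		b[i] = 0 // zero the first element's bytes through the view
	}
	fmt.Println(flat) // [0 2 3]
}
```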
func (*Tensor) ConstFlatData
ConstFlatData calls accessFn with the flattened data as a slice of the Go type corresponding to the DType type. Even scalar values have a flattened data representation of one element. It locks the Tensor until accessFn returns.
It triggers a synchronous transfer from device to local, if the tensor is only on device.
This provides accessFn with the actual Tensor data (not a copy), and it's owned by the Tensor, but it should not be changed -- the contents of the corresponding "on device" tensors would go out-of-sync. See Tensor.MutableFlatData to access a mutable version of the flat data.
See Tensor.Size for the number of elements, and Tensor.LayoutStrides to calculate the offset of individual positions, given the indices at each axis.
It panics if the tensor is in an invalid state (if it was finalized), or if it is a tuple.
func (*Tensor) DType
DType returns the DType of the tensor's shape. It is a shortcut to `Tensor.Shape().DType`.
func (*Tensor) DonateBuffer
func (t *Tensor) DonateBuffer(backend backends.Backend, deviceNum ...backends.DeviceNum) backends.Buffer
DonateBuffer returns the backend buffer for the tensor, and transfers the ownership of the buffer to the caller. This may invalidate the tensor, if there is no other on-device storage or local storage.
Mostly used internally -- by graph.Graph.Run and graph.Exec when the value in the buffer is no longer needed after execution.
It triggers the transfer from local to the device, if the tensor is not already stored on the device.
The deviceNum is optional, but at most one can be given. It defaults to 0.
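The ownership rule can be modeled with a hypothetical `donatableTensor` holding at most one simulated device buffer (not GoMLX code): once the buffer is donated, the tensor remains valid only if it still has a local copy.

```go
package main

import "fmt"

// donatableTensor is a hypothetical model of the donation contract:
// an optional local copy and an optional simulated device buffer.
type donatableTensor struct {
	local  []float32 // nil if there is no local copy
	device []float32 // nil if there is no on-device copy
}

// donate hands the device buffer to the caller and forgets it, so the tensor
// will not free or reuse it. It first "transfers" local data if needed.
func (t *donatableTensor) donate() []float32 {
	if t.device == nil {
		if t.local == nil {
			panic("invalid tensor: no storage")
		}
		buf := make([]float32, len(t.local))
		copy(buf, t.local) // simulated local -> device transfer
		t.device = buf
	}
	buf := t.device
	t.device = nil
	return buf
}

// valid reports whether any storage remains.
func (t *donatableTensor) valid() bool { return t.local != nil || t.device != nil }

func main() {
	onlyDevice := &donatableTensor{device: []float32{1, 2}}
	buf := onlyDevice.donate()
	fmt.Println(buf, onlyDevice.valid()) // [1 2] false

	withLocal := &donatableTensor{local: []float32{3}}
	withLocal.donate()
	fmt.Println(withLocal.valid()) // true: the local copy survives
}
```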
func (*Tensor) Equal
Equal checks whether t == otherTensor. If they are the same pointer, they are considered equal. If the shapes are different, it returns false. If either is invalid (nil), it panics.
Slow implementation: fine for small tensors, but write something specialized for the DType if speed is desired.
func (*Tensor) FinalizeAll
func (t *Tensor) FinalizeAll()
FinalizeAll immediately frees all associated data and leaves the Tensor in an invalid state. The shape is also cleared.
It's the caller's responsibility to ensure the tensor buffers are not being used elsewhere (like in the middle of an execution).
func (*Tensor) FinalizeLocal
func (t *Tensor) FinalizeLocal()
FinalizeLocal immediately frees the local storage copy of the tensor. If there are no on-device copies of the tensor, it becomes invalid.
func (*Tensor) GoStr
GoStr converts the tensor to a string, using a Go-syntax representation that can be copied and pasted back into code.
func (*Tensor) GobSerialize
GobSerialize writes the Tensor to the given gob encoder in binary format.
It triggers a synchronous transfer from device to local, if the tensor is only on device.
It returns an error for I/O errors. It panics for invalid tensors.
func (*Tensor) HasLocal
HasLocal returns whether there is an up-to-date copy of the Tensor in local storage. If false, any access to the data (e.g., Tensor.ConstFlatData) will require a transfer (see Tensor.MaterializeLocal).
func (*Tensor) InDelta
InDelta checks whether Abs(t - otherTensor) < delta for every element. If they are the same pointer, they are considered equal. If the shapes are different, it returns false. If either is invalid (nil), it panics. It also panics if the DType is not a float or complex type.
Slow implementation: fine for small tensors, but write something specialized for the DType if speed is desired.
func (*Tensor) InvalidateOnDevice
func (t *Tensor) InvalidateOnDevice()
InvalidateOnDevice destroys all on-device copies of the Tensor.
It's the caller's responsibility to ensure these buffers are not being used elsewhere (like in the middle of an execution).
This is automatically called when the Tensor is mutated (e.g.: Tensor.MutableFlatData) or when the on-device value is donated to the execution of a graph.
If there is no local copy of the Tensor, this will invalidate the whole tensor.
Usually, this is called automatically. Mostly for internal use.
func (*Tensor) IsLocal
IsLocal returns true if there is a local storage copy of the tensor.
See MaterializeLocal to trigger a transfer/copy to the local storage.
func (*Tensor) IsOnDevice
IsOnDevice checks whether the Tensor has an on-device copy on the given deviceNum.
See MaterializeOnDevices to trigger a transfer/copy to the given device.
func (*Tensor) IsScalar
IsScalar returns whether the tensor represents a scalar value. It is a shortcut to `Tensor.Shape().IsScalar()`.
func (*Tensor) LayoutStrides
LayoutStrides returns the strides for each axis. This can be handy when manipulating the flat data.
func (*Tensor) LocalClone
LocalClone creates a clone of the Tensor value with local backing. It will trigger a transfer from on-device data to local, if the value is not present in local memory yet.
func (*Tensor) MaterializeLocal
func (t *Tensor) MaterializeLocal()
MaterializeLocal will make sure there is a local storage of the tensor. If there isn't already a local copy, this triggers a transfer from an on-device storage to a local copy.
func (*Tensor) MaterializeOnDevices
MaterializeOnDevices transfers a Tensor from local storage to the given devices, if needed. Generally the user doesn't need to call this function: the libraries executing GoMLX computations call it automatically when needed.
- If an up-to-date copy of the Tensor is already on the device(s), this is a no-op.
- If the Tensor has already been used with a different client, this panics: one cannot mix clients on the same Tensor.
- If no deviceNum is given, 0 is assumed, the default device for the client.
TODO: For now this only transfers from local storage to on-device. Implement cross-device copy in gopjrt.
func (*Tensor) Memory
Memory returns the number of bytes used to store the tensor. An alias to Tensor.Shape().Memory().
func (*Tensor) MutableBytes
MutableBytes gives mutable access to the local storage of the tensor's values. It's similar to MutableFlatData, but provides a bytes view of the same data.
It triggers a synchronous transfer from device to local, if the tensor is only on device, and it invalidates the device storage, since it's assumed to become out-of-date.
This provides the actual Tensor data (not a copy); the bytes slice is owned by the Tensor, but its contents can be changed.
See Tensor.ConstBytes for constant access to the data as bytes -- that doesn't invalidate the device storage.
func (*Tensor) MutableFlatData
MutableFlatData invalidates and frees any copy of the data on device, calls accessFn with the local flattened data as a slice of the Go type corresponding to the DType type. The contents of the slice itself can be changed until accessFn returns. During this time the Tensor is locked.
Even scalar values have a flattened data representation of one element.
It triggers a synchronous transfer from device to local, if the tensor is only on device, and it invalidates the device storage, since it's assumed to become out-of-date.
accessFn receives the actual Tensor data (not a copy); the slice is owned by the Tensor, but its contents can be changed.
See Tensor.ConstFlatData for read-only access to the flat data.
See Tensor.Size for the number of elements, and Tensor.LayoutStrides to calculate the offset of individual positions, given the indices at each axis.
It panics if the tensor is in an invalid state (if it was finalized), or if it is a tuple.
func (*Tensor) Ok
Ok returns whether the Tensor is in a valid state: it is not nil, and it hasn't been finalized.
func (*Tensor) Rank
Rank returns the rank of the tensor's shape. It is a shortcut to `Tensor.Shape().Rank()`.
func (*Tensor) Save
Save writes the tensor to the given file path.
It returns an error for I/O errors. It may panic if the tensor is invalid (`nil` or already finalized).
func (*Tensor) Size
Size returns the number of elements in the tensor. It is a shortcut to `Tensor.Shape().Size()`.
func (*Tensor) StringN
StringN converts to string, displaying at most n elements. TODO: nice pretty-print version, even for large tensors.
func (*Tensor) Value
Value returns a multidimensional slice (except if shape is a scalar) containing a copy of the values stored in the tensor. This is expensive, and usually only used for smaller tensors in tests and to print results.
If the local tensor is empty it panics with the corresponding error.