Overview ¶
Package dtypes includes the DType enum for all supported data types for GoMLX and the compute backends.
It includes several converters to and from Go native types (using generic functions and reflect.Type), the min/max values for each type, helpers to manipulate slices of the Go types corresponding to each DType, etc.
It also includes some constraint interfaces to be used with generics (Number, NumberNotComplex, GoFloat).
## Half-precision data types
Float16 and BFloat16 support in Go uses the simple implementations in github.com/gomlx/compute/dtypes/float16 and github.com/gomlx/compute/dtypes/bfloat16.
Index ¶
- Constants
- Variables
- func CopyAnySlice(dst, src any)
- func DTypeStrings() []string
- func MakeAnySlice(dtype DType, length int) any
- func UnsafeAnySliceFromBytes(bytesPtr unsafe.Pointer, dtype DType, length int) any
- func UnsafeByteSlice[E Supported](flat []E) []byte
- func UnsafeByteSliceFromAny(flatAny any) []byte
- func UnsafeSliceFromBytes[E Supported](bytesPtr unsafe.Pointer, length int) []E
- type DType
- func (dtype DType) Bits() int
- func (dtype DType) GoStr() string
- func (dtype DType) GoType() reflect.Type
- func (dtype DType) HighestValue() any
- func (i DType) IsADType() bool
- func (dtype DType) IsComplex() bool
- func (dtype DType) IsFloat() bool
- func (dtype DType) IsFloat16() bool
- func (dtype DType) IsHalfPrecision() bool
- func (dtype DType) IsInt() bool
- func (dtype DType) IsPacked() bool
- func (dtype DType) IsPromotableTo(target DType) bool
- func (dtype DType) IsSupported() bool
- func (dtype DType) IsUnsigned() bool
- func (dtype DType) LowestValue() any
- func (i DType) MarshalJSON() ([]byte, error)
- func (i DType) MarshalText() ([]byte, error)
- func (i DType) MarshalYAML() (interface{}, error)
- func (dtype DType) Memory() uintptr (deprecated)
- func (dtype DType) RealDType() DType
- func (dtype DType) Size() int
- func (dtype DType) SizeForDimensions(dimensions ...int) int
- func (dtype DType) SmallestNonZeroValueForDType() any
- func (i DType) String() string
- func (i *DType) UnmarshalJSON(data []byte) error
- func (i *DType) UnmarshalText(text []byte) error
- func (i *DType) UnmarshalYAML(unmarshal func(interface{}) error) error
- func (DType) Values() []string
- func (dtype DType) ValuesPerStorageUnit() int
- type DTypeSet
- type GoFloat
- type HalfPrecision
- type HalfPrecisionPtr
- type Number
- type NumberComplex
- type NumberHalfPrecision
- type NumberNotComplex
- type Supported
Constants ¶
const (
	// INVALID (or PJRT_Buffer_Type_INVALID) is the C enum name for InvalidDType.
	INVALID = InvalidDType
	// PRED (or PJRT_Buffer_Type_PRED) is the C enum name for Bool.
	PRED = Bool
	// S8 (or PJRT_Buffer_Type_S8) is the C enum name for Int8.
	S8 = Int8
	// S16 (or PJRT_Buffer_Type_S16) is the C enum name for Int16.
	S16 = Int16
	// S32 (or PJRT_Buffer_Type_S32) is the C enum name for Int32.
	S32 = Int32
	// S64 (or PJRT_Buffer_Type_S64) is the C enum name for Int64.
	S64 = Int64
	// U8 (or PJRT_Buffer_Type_U8) is the C enum name for Uint8.
	U8 = Uint8
	// U16 (or PJRT_Buffer_Type_U16) is the C enum name for Uint16.
	U16 = Uint16
	// U32 (or PJRT_Buffer_Type_U32) is the C enum name for Uint32.
	U32 = Uint32
	// U64 (or PJRT_Buffer_Type_U64) is the C enum name for Uint64.
	U64 = Uint64
	// F16 (or PJRT_Buffer_Type_F16) is the C enum name for Float16.
	F16 = Float16
	// F32 (or PJRT_Buffer_Type_F32) is the C enum name for Float32.
	F32 = Float32
	// F64 (or PJRT_Buffer_Type_F64) is the C enum name for Float64.
	F64 = Float64
	// BF16 (or PJRT_Buffer_Type_BF16) is the C enum name for BFloat16.
	BF16 = BFloat16
	// C64 (or PJRT_Buffer_Type_C64) is the C enum name for Complex64.
	C64 = Complex64
	// C128 (or PJRT_Buffer_Type_C128) is the C enum name for Complex128.
	C128 = Complex128
	// S4 (or PJRT_Buffer_Type_S4) is the C enum name for Int4.
	S4 = Int4
	// U4 (or PJRT_Buffer_Type_U4) is the C enum name for Uint4.
	U4 = Uint4
	// S2 (or PJRT_Buffer_Type_S2) is the C enum name for Int2.
	S2 = Int2
	// U2 (or PJRT_Buffer_Type_U2) is the C enum name for Uint2.
	U2 = Uint2
	// U1 (or PJRT_Buffer_Type_U1) is the C enum name for Uint1.
	U1 = Uint1
	// S1 (or PJRT_Buffer_Type_S1) is the C enum name for Int1.
	S1 = Int1
)
Aliases from the PJRT C API.
const MaxDType = 64
MaxDType is the maximum number of DTypes that can be defined.
Variables ¶
var (
	FloatDTypes     = dtypeSetWith(Float32, Float64, Float16, BFloat16)
	Float16DTypes   = dtypeSetWith(Float16, BFloat16)
	ComplexDTypes   = dtypeSetWith(Complex64, Complex128)
	IntDTypes       = dtypeSetWith(Int64, Int32, Int16, Int8, Int4, Int2, Int1, Uint1, Uint2, Uint4, Uint8, Uint16, Uint32, Uint64)
	UnsignedDTypes  = dtypeSetWith(Uint8, Uint16, Uint32, Uint64, Uint4, Uint2, Uint1)
	SupportedDTypes = dtypeSetWith(Bool, Float16, BFloat16, Float32, Float64, Int64, Int32, Int16, Int8, Int4, Int2, Int1, Uint64, Uint32, Uint16, Uint8, Uint4, Uint2, Uint1, Complex64, Complex128)
)
var MapOfNames = map[string]DType{
	"InvalidDType": InvalidDType, "INVALID": InvalidDType,
	"Bool": Bool, "PRED": Bool,
	"S1": Int1, "Int1": Int1, "I1": Int1,
	"I2": Int2, "Int2": Int2, "S2": Int2,
	"I4": Int4, "Int4": Int4, "S4": Int4,
	"I8": Int8, "Int8": Int8, "S8": Int8,
	"I16": Int16, "Int16": Int16, "S16": Int16,
	"I32": Int32, "Int32": Int32, "S32": Int32,
	"I64": Int64, "Int64": Int64, "S64": Int64,
	"Uint1": Uint1, "U1": Uint1,
	"Uint2": Uint2, "U2": Uint2,
	"Uint4": Uint4, "U4": Uint4,
	"Uint8": Uint8, "U8": Uint8,
	"Uint16": Uint16, "U16": Uint16,
	"Uint32": Uint32, "U32": Uint32,
	"Uint64": Uint64, "U64": Uint64,
	"Float16": Float16, "F16": Float16,
	"Float32": Float32, "F32": Float32,
	"Float64": Float64, "F64": Float64,
	"BFloat16": BFloat16, "BF16": BFloat16,
	"Complex64": Complex64, "C64": Complex64,
	"Complex128": Complex128, "C128": Complex128,
	"F8E5M2": F8E5M2,
	"F8E4M3FN": F8E4M3FN,
	"F8E4M3B11FNUZ": F8E4M3B11FNUZ,
	"F8E5M2FNUZ": F8E5M2FNUZ,
	"F8E4M3FNUZ": F8E4M3FNUZ,
	"TOKEN": TOKEN,
	"F8E4M3": F8E4M3,
	"F8E3M4": F8E3M4,
	"F8E8M0FNU": F8E8M0FNU,
	"F4E2M1FN": F4E2M1FN,
}
MapOfNames maps names to their dtypes. It also includes aliases for the various dtypes, and it is later extended during initialization to include the lower-case version of each name.
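The lower-case extension of MapOfNames can be pictured with the following minimal sketch. It is a hypothetical illustration, not the package's initialization code. It uses a local `DType` stand-in (with values copied from the enum) and only a small subset of the names, and it assumes a lower-case name should resolve to the same dtype as its original spelling.

```go
package main

import (
	"fmt"
	"strings"
)

// DType is a local stand-in for dtypes.DType, used only in this sketch.
type DType int32

// Values copied from the DType enum.
const (
	Bool     DType = 1
	Float32  DType = 11
	BFloat16 DType = 13
)

// mapOfNames is a hypothetical subset of MapOfNames, including aliases.
var mapOfNames = map[string]DType{
	"Bool": Bool, "PRED": Bool,
	"Float32": Float32, "F32": Float32,
	"BFloat16": BFloat16, "BF16": BFloat16,
}

// addLowerCaseNames adds the lower-case version of every name already in m.
// Keys are collected first so the map is not extended while ranging over it.
func addLowerCaseNames(m map[string]DType) {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	for _, k := range keys {
		m[strings.ToLower(k)] = m[k]
	}
}

func main() {
	addLowerCaseNames(mapOfNames)
	fmt.Println(mapOfNames["bf16"] == BFloat16, mapOfNames["pred"] == Bool) // true true
}
```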
Functions ¶
func CopyAnySlice ¶
func CopyAnySlice(dst, src any)
CopyAnySlice copies the contents of src to dst, where both are slices passed as `any`.
Unsafe: dst and src must be slices of the same dtype.
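The following sketch shows one way a copy between two slices held as `any` can work. It is hypothetical and is not the package's implementation, which may use a type switch instead. It relies on the standard `reflect.Copy` and adds explicit checks for the "same dtype" requirement above. Like the built-in `copy`, it copies the minimum of both lengths.

```go
package main

import (
	"fmt"
	"reflect"
)

// copyAnySlice is a hypothetical CopyAnySlice-like helper: it copies the elements of
// src into dst, both given as `any`. It panics if either argument is not a slice or
// if their element types differ. Like the built-in copy, it copies min(len(dst), len(src)).
func copyAnySlice(dst, src any) {
	dv, sv := reflect.ValueOf(dst), reflect.ValueOf(src)
	if dv.Kind() != reflect.Slice || sv.Kind() != reflect.Slice {
		panic("copyAnySlice: dst and src must be slices")
	}
	if dv.Type().Elem() != sv.Type().Elem() {
		panic(fmt.Sprintf("copyAnySlice: element types differ: %s vs %s",
			dv.Type().Elem(), sv.Type().Elem()))
	}
	reflect.Copy(dv, sv)
}

func main() {
	src := []float32{1, 2, 3}
	dst := make([]float32, 3)
	copyAnySlice(dst, src)
	fmt.Println(dst) // [1 2 3]
}
```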
func DTypeStrings ¶
func DTypeStrings() []string
DTypeStrings returns a slice of all String values of the enum.
func MakeAnySlice ¶
func MakeAnySlice(dtype DType, length int) any
MakeAnySlice creates a slice of the given dtype and length, cast to any.
For sub-byte types (Uint1, Uint2, Uint4, Int1, Int2, Int4) it returns a slice of uint8 with an adjusted length: enough bytes to hold that many bits/crumbs/nibbles.
func UnsafeAnySliceFromBytes ¶
func UnsafeAnySliceFromBytes(bytesPtr unsafe.Pointer, dtype DType, length int) any
UnsafeAnySliceFromBytes casts a pointer to a buffer of bytes to a slice of the given dtype and length pointing to the same data.
For sub-byte types (Uint1, Uint2, Uint4, Int1, Int2, Int4) it returns a slice of uint8 whose length is adjusted to hold that many packed elements. The length argument is always given in number of elements (bits, crumbs or nibbles in this case).
Unsafe: bytesPtr must have enough data to hold the []dtype of the given length.
func UnsafeByteSlice ¶
func UnsafeByteSlice[E Supported](flat []E) []byte
UnsafeByteSlice casts a slice of any of the supported Go types to a slice of bytes.
func UnsafeByteSliceFromAny ¶
func UnsafeByteSliceFromAny(flatAny any) []byte
UnsafeByteSliceFromAny casts a slice of any of the supported Go types (passed as type any) to a slice of bytes.
func UnsafeSliceFromBytes ¶
func UnsafeSliceFromBytes[E Supported](bytesPtr unsafe.Pointer, length int) []E
UnsafeSliceFromBytes casts a pointer to a buffer of bytes to a slice of the given E type and length pointing to the same data.
If length is zero or bytesPtr is nil, it returns nil.
Unsafe: bytesPtr must have enough data to hold the []E of the given length.
Types ¶
type DType ¶
type DType int32
DType is an enum that represents the data type of any buffer, scalar, tensor, multi-dimensional array, etc.
Not all DTypes are supported by all backends. This is a collection of all data types GoMLX knows about, plus some it doesn't handle yet.
const (
	// InvalidDType is an invalid primitive type to serve as default.
	InvalidDType DType = 0
	// Bool represents predicates which are two-state booleans.
	Bool DType = 1
	// Int8 represents signed 8-bit integral values.
	Int8 DType = 2
	// Int16 represents signed 16-bit integral values.
	Int16 DType = 3
	// Int32 represents signed 32-bit integral values.
	Int32 DType = 4
	// Int64 represents signed 64-bit integral values.
	Int64 DType = 5
	// Uint8 represents unsigned 8-bit integral values.
	Uint8 DType = 6
	// Uint16 represents unsigned 16-bit integral values.
	Uint16 DType = 7
	// Uint32 represents unsigned 32-bit integral values.
	Uint32 DType = 8
	// Uint64 represents unsigned 64-bit integral values.
	Uint64 DType = 9
	// Float16 represents 16-bit floating-point values, a "half-precision" floating-point format.
	// It is referred to as IEEE 754 2008 "binary16", see https://en.wikipedia.org/wiki/Half-precision_floating-point_format
	Float16 DType = 10
	// Float32 represents 32-bit floating-point values.
	Float32 DType = 11
	// Float64 represents 64-bit floating-point values, also known as "double-precision" floating-point format.
	Float64 DType = 12
	// BFloat16 represents truncated 16-bit floating-point format, a "half-precision" floating-point format.
	// This is similar to IEEE's 16-bit floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
	// and 7 bits for the mantissa.
	//
	// This format is a shortened (16-bit) version of the 32-bit IEEE 754 single-precision floating-point format (binary32)
	// https://en.wikipedia.org/wiki/Tensor_Processing_Unit.
	BFloat16 DType = 13
	// Complex64 represents complex values.
	//
	// Paired F32 (real, imag), as in std::complex<float>.
	Complex64 DType = 14
	// Complex128 represents complex values.
	// Paired F64 (real, imag), as in std::complex<double>.
	Complex128 DType = 15
	// F8E5M2 represents truncated 8-bit floating-point formats.
	F8E5M2 DType = 16
	// F8E4M3FN represents truncated 8-bit floating-point formats.
	F8E4M3FN DType = 17
	// F8E4M3B11FNUZ represents truncated 8-bit floating-point formats.
	F8E4M3B11FNUZ DType = 18
	// F8E5M2FNUZ represents truncated 8-bit floating-point formats.
	F8E5M2FNUZ DType = 19
	// F8E4M3FNUZ represents truncated 8-bit floating-point formats.
	F8E4M3FNUZ DType = 20
	// Int4 represents a 4-bit integer type.
	//
	// This is assumed to be a "packed" type (multiple 4-bit "nibbles" per byte): at least this is how it is stored
	// in Go tensors (in a slice of bytes); different backends may have their own varying internal representation.
	Int4 DType = 21
	// Uint4 represents a 4-bit unsigned integer type.
	//
	// This is assumed to be a "packed" type (multiple 4-bit "nibbles" per byte): at least this is how it is stored
	// in Go tensors (in a slice of bytes); different backends may have their own varying internal representation.
	Uint4 DType = 22
	// TOKEN represents a token type.
	TOKEN DType = 23
	// Int2 represents a 2-bit integer type.
	//
	// This is assumed to be a "packed" type (multiple 2-bit "crumbs" per byte): at least this is how it is stored
	// in Go tensors (in a slice of bytes); different backends may have their own varying internal representation.
	Int2 DType = 24
	// Uint2 represents a 2-bit unsigned integer type.
	//
	// This is assumed to be a "packed" type (multiple 2-bit "crumbs" per byte): at least this is how it is stored
	// in Go tensors (in a slice of bytes); different backends may have their own varying internal representation.
	Uint2 DType = 25
	// F8E4M3 represents truncated 8-bit floating-point formats.
	F8E4M3 DType = 26
	// F8E3M4 represents truncated 8-bit floating-point formats.
	F8E3M4 DType = 27
	// F8E8M0FNU represents truncated 8-bit floating-point formats.
	F8E8M0FNU DType = 28
	// F4E2M1FN represents 4-bit MX floating-point format.
	F4E2M1FN DType = 29
	// Int1 represents a 1-bit integer type.
	//
	// This is assumed to be a "packed" type (8 values per byte): at least this is how it is stored
	// in Go tensors (in a slice of bytes); different backends may have their own varying internal representation.
	Int1 DType = 30
	// Uint1 represents a 1-bit unsigned integer type.
	//
	// This is assumed to be a "packed" type (8 values per byte): at least this is how it is stored
	// in Go tensors (in a slice of bytes); different backends may have their own varying internal representation.
	Uint1 DType = 31
)
func DTypeString ¶
DTypeString retrieves an enum value from its string name. It returns an error if the name is not part of the enum.
func FromAny ¶
FromAny introspects the underlying type of the given value (passed as any) and returns the corresponding DType. Non-scalar or unsupported types return InvalidDType.
func FromGenericsType ¶
FromGenericsType returns the DType enum for the given type that this package knows about.
func FromGoType ¶
FromGoType returns the DType for the given "reflect.Type". It panics for unknown DType values.
func (DType) Bits ¶
Bits returns the number of bits for the given DType. It is mostly relevant for the "packed" storage types (Int4, Int2, Int1, Uint4, Uint2, Uint1). Bool is never packed and hence returns 8.
func (DType) GoStr ¶
GoStr converts dtype to the corresponding Go type and converts that to a string. Notice the names differ from the DType names (so the `Int64` dtype is simply `int` in Go).
Sub-byte packed values (Int2, Uint2, Int4, Uint4, Int1, Uint1) are packed as "uint8", so that's what is returned.
func (DType) HighestValue ¶
HighestValue returns the highest value for dtype, converted to the corresponding Go type. For float types it returns positive infinity. There is no highest value for complex numbers, since they are not ordered.
For the packed sub-byte types (Int4, Int2, Uint4, Uint2, Int1, Uint1), the highest value is returned as a byte with every packed value in it set to the highest value.
func (DType) IsADType ¶
IsADType returns true if the value is listed in the enum definition, and false otherwise.
func (DType) IsFloat ¶
IsFloat returns whether dtype is a supported float -- float types not yet supported will return false. It returns false for complex numbers.
func (DType) IsFloat16 ¶
IsFloat16 returns whether dtype is a supported float with 16 bits: Float16 or BFloat16. Same as IsHalfPrecision.
func (DType) IsHalfPrecision ¶
IsHalfPrecision returns whether dtype is a supported float with 16 bits: Float16 or BFloat16. Same as IsFloat16.
func (DType) IsInt ¶
IsInt returns whether dtype is a supported integer type; integer types not yet supported return false. This includes the unsigned integer types.
func (DType) IsPacked ¶
IsPacked returns whether the dtype uses less than a byte per value and is therefore "packed" into bytes in memory, e.g.: Int4, Int2, Int1, Uint4, Uint2, Uint1. The packing is always "little-endian": the lower bits of each byte hold the earlier values in a sequence.
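The "little-endian" packing rule can be illustrated for 4-bit unsigned values with the sketch below. The helpers are hypothetical and not part of the package. They assume the layout stated above (the first value goes in the lower nibble) and zero padding in an unused upper half of the last byte.

```go
package main

import "fmt"

// packUint4 is a hypothetical helper: values[2*i] goes into the lower 4 bits of
// byte i and values[2*i+1] into the upper 4 bits. With an odd number of values,
// the upper half of the last byte stays zero. Only the lower 4 bits of each input are used.
func packUint4(values []uint8) []byte {
	packed := make([]byte, (len(values)+1)/2)
	for i, v := range values {
		shift := uint(i%2) * 4
		packed[i/2] |= (v & 0x0F) << shift
	}
	return packed
}

// unpackUint4 reverses packUint4, returning the first n values.
func unpackUint4(packed []byte, n int) []uint8 {
	values := make([]uint8, n)
	for i := range values {
		shift := uint(i%2) * 4
		values[i] = (packed[i/2] >> shift) & 0x0F
	}
	return values
}

func main() {
	packed := packUint4([]uint8{1, 2, 3})
	fmt.Printf("%08b %08b\n", packed[0], packed[1]) // 00100001 00000011
	fmt.Println(unpackUint4(packed, 3))             // [1 2 3]
}
```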
func (DType) IsPromotableTo ¶
IsPromotableTo returns whether dtype can be promoted to target.
For example, Int32 can be promoted to Int64, but not to Uint64.
See https://openxla.org/stablehlo/spec#functions_on_types for reference.
func (DType) IsSupported ¶
IsSupported returns whether dtype is supported by `gopjrt`.
func (DType) IsUnsigned ¶
IsUnsigned returns whether dtype is one of the unsigned types (currently only integer types).
func (DType) LowestValue ¶
LowestValue returns the lowest value for dtype, converted to the corresponding Go type. For float types it returns negative infinity. There is no lowest value for complex numbers, since they are not ordered.
For the packed sub-byte types (Int4, Int2, Uint4, Uint2, Int1, Uint1), the lowest value is returned as a byte with every packed value in it set to the lowest value.
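A byte with every packed slot set to the same value, as HighestValue and LowestValue describe, can be computed by the sketch below. `replicate` is a hypothetical helper, not part of the package. The example values assume that signed packed types use two's complement, giving Int4 the range [-8, 7] and Int2 the range [-2, 1].

```go
package main

import "fmt"

// replicate is a hypothetical helper that fills a byte with copies of a bits-wide
// value. Only the lower `bits` bits of value are used, so a negative value is taken
// in two's complement. It assumes bits is one of 1, 2, 4 or 8.
func replicate(value int, bits uint) byte {
	mask := byte(1<<bits - 1)
	slot := byte(value) & mask
	var out byte
	for shift := uint(0); shift < 8; shift += bits {
		out |= slot << shift
	}
	return out
}

func main() {
	// Assumed ranges: Int4 in [-8, 7], Uint4 in [0, 15], Int2 in [-2, 1].
	fmt.Printf("Int4  highest: %#02x, lowest: %#02x\n", replicate(7, 4), replicate(-8, 4))
	fmt.Printf("Uint4 highest: %#02x\n", replicate(15, 4))
	fmt.Printf("Int2  highest: %#02x, lowest: %#02x\n", replicate(1, 2), replicate(-2, 2))
}
```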
func (DType) MarshalJSON ¶
MarshalJSON implements the json.Marshaler interface for DType.
func (DType) MarshalText ¶
MarshalText implements the encoding.TextMarshaler interface for DType.
func (DType) MarshalYAML ¶
MarshalYAML implements a YAML Marshaler for DType.
func (DType) RealDType ¶
RealDType returns the dtype of the real component of complex dtypes (e.g., Float32 for Complex64). For float dtypes, it returns itself.
It returns InvalidDType for all other (non-complex, non-float) dtypes.
func (DType) Size ¶
Size returns the number of bytes for the given DType, or 0 if the dtype uses fraction(s) of bytes.
func (DType) SizeForDimensions ¶
SizeForDimensions returns the size in bytes used for the given dimensions. It is safer than Size for dtypes whose underlying size is not a multiple of 8 bits (for which Size returns 0).
It also works for scalar (one element) shapes, where the list of dimensions is empty.
For packed types, it assumes the last byte is padded where needed. E.g.: Int4.SizeForDimensions(1) returns 1, even though only 4 of the byte's 8 bits are used.
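The byte count described above can be computed with the following minimal sketch. `sizeForDimensions` is a hypothetical helper that takes the number of bits per value directly instead of a DType. It multiplies the dimensions together (an empty list means a scalar, i.e. one element) and rounds the total number of bits up to whole bytes.

```go
package main

import "fmt"

// sizeForDimensions is a hypothetical sketch: the element count is the product of the
// dimensions (1 for a scalar with no dimensions), and the total number of bits is
// rounded up to whole bytes, so a partially used last byte counts as padding.
func sizeForDimensions(bitsPerValue int, dimensions ...int) int {
	numElements := 1
	for _, dim := range dimensions {
		numElements *= dim
	}
	return (numElements*bitsPerValue + 7) / 8
}

func main() {
	fmt.Println(sizeForDimensions(4))        // Int4 scalar: 1
	fmt.Println(sizeForDimensions(4, 3))     // 3 nibbles: 2
	fmt.Println(sizeForDimensions(32, 2, 3)) // Float32 [2, 3]: 24
}
```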
func (DType) SmallestNonZeroValueForDType ¶
SmallestNonZeroValueForDType returns the smallest non-zero value for dtype, converted to the corresponding Go type. It is only useful for float types. There is no smallest non-zero value for complex numbers, since they are not ordered.
For the packed sub-byte types (Int4, Int2, Uint4, Uint2, Int1, Uint1), the value is returned as a byte with every packed value in it set to the smallest non-zero value.
func (*DType) UnmarshalJSON ¶
UnmarshalJSON implements the json.Unmarshaler interface for DType.
func (*DType) UnmarshalText ¶
UnmarshalText implements the encoding.TextUnmarshaler interface for DType.
func (*DType) UnmarshalYAML ¶
UnmarshalYAML implements a YAML Unmarshaler for DType.
func (DType) ValuesPerStorageUnit ¶
ValuesPerStorageUnit returns the number of values that fit in a storage unit for packed dtypes (uint8/byte for Int2, Int4, etc.). The "storage unit" is returned by dtype.GoType().
type GoFloat ¶
GoFloat represents a continuous Go numeric type supported by GoMLX. It doesn't include complex numbers.
type HalfPrecision ¶
type HalfPrecision[T any] interface {
	float16.Float16 | bfloat16.BFloat16
	Float64() float64
	Float32() float32
	Neg() T
}
HalfPrecision is an interface that represents half-precision floating point numbers, specifically float16 and bfloat16.
It includes the methods to convert to float64 and float32, so it can be used in generic methods.
It's a generic constraint, and usually used like this:
func myGeneric[T HalfPrecision[T]](v T, ...) { ... }
type HalfPrecisionPtr ¶
type HalfPrecisionPtr[T HalfPrecision[T]] interface {
	*T
	SetFloat32(float32)
	SetFloat64(float64)
}
HalfPrecisionPtr is a pointer to a HalfPrecision wrapper type. It is used when one needs to set the value of a HalfPrecision type from a float32 or float64.
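The pairing of a value constraint with a pointer constraint is a standard Go generics pattern. The self-contained sketch below shows it on hypothetical stand-ins. `Half` is not float16.Float16 (it simply stores a float32), and `halfLike`/`halfLikePtr` only mirror the shape of HalfPrecision/HalfPrecisionPtr. The second type parameter gives access to the pointer-receiver setter, and the caller only names the value type because the pointer type is inferred.

```go
package main

import "fmt"

// Half is a hypothetical stand-in for a half-precision type; for simplicity it stores a float32.
type Half struct{ value float32 }

// Float32 converts to float32 (value receiver).
func (h Half) Float32() float32 { return h.value }

// SetFloat32 sets the value (pointer receiver, so it can modify h).
func (h *Half) SetFloat32(f float32) { h.value = f }

// halfLike mirrors the shape of HalfPrecision: a type element plus methods.
type halfLike interface {
	Half
	Float32() float32
}

// halfLikePtr mirrors the shape of HalfPrecisionPtr: *T plus the setter.
type halfLikePtr[T halfLike] interface {
	*T
	SetFloat32(float32)
}

// fromFloat32s converts a []float32 to a []T by calling the setter through PT.
func fromFloat32s[T halfLike, PT halfLikePtr[T]](src []float32) []T {
	dst := make([]T, len(src))
	for i, f := range src {
		PT(&dst[i]).SetFloat32(f)
	}
	return dst
}

func main() {
	halves := fromFloat32s[Half]([]float32{0.5, 2}) // PT is inferred as *Half.
	fmt.Println(halves[0].Float32(), halves[1].Float32()) // 0.5 2
}
```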
type Number ¶
type Number interface {
float32 | float64 | int | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | complex64 | complex128
}
Number represents the Go numeric types corresponding to supported DTypes. Used as a generics constraint.
It includes complex numbers. It doesn't include float16.Float16 or bfloat16.BFloat16 because they are not native number types.
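Number works as a generics constraint in code like the following sketch. The constraint is copied verbatim so the block compiles on its own, and `sum` is a hypothetical function. Because `+` is defined for every type in the union, including the complex types, the same function body serves all of them.

```go
package main

import "fmt"

// Number is copied from the constraint above so this sketch is self-contained.
type Number interface {
	float32 | float64 | int | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | complex64 | complex128
}

// sum is a hypothetical generic function: + is defined for every type in Number.
func sum[T Number](values []T) T {
	var total T
	for _, v := range values {
		total += v
	}
	return total
}

func main() {
	fmt.Println(sum([]int32{1, 2, 3}))        // 6
	fmt.Println(sum([]complex64{1 + 2i, 3i})) // (1+5i)
}
```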
type NumberComplex ¶
type NumberComplex interface {
complex64 | complex128
}
NumberComplex represents the Go complex types corresponding to supported DTypes.
type NumberHalfPrecision ¶
NumberHalfPrecision represents the Go half-precision types corresponding to supported DTypes.
type NumberNotComplex ¶
type NumberNotComplex interface {
float32 | float64 | int | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
}
NumberNotComplex represents the Go numeric types corresponding to supported DTypes, excluding complex numbers. Used as a generics constraint.
See also Number.
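Excluding complex numbers matters when a generic function needs ordering. In the sketch below, the hypothetical `maxOf` uses `>`, so it compiles with NumberNotComplex (copied verbatim) but would not compile with Number, because complex types are not ordered.

```go
package main

import "fmt"

// NumberNotComplex is copied from the constraint above so this sketch is self-contained.
type NumberNotComplex interface {
	float32 | float64 | int | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
}

// maxOf is a hypothetical generic function that relies on ordered comparison.
// It assumes values is non-empty.
func maxOf[T NumberNotComplex](values []T) T {
	m := values[0]
	for _, v := range values[1:] {
		if v > m {
			m = v
		}
	}
	return m
}

func main() {
	fmt.Println(maxOf([]int8{-3, 7, 2}))      // 7
	fmt.Println(maxOf([]float32{0.5, -1.25})) // 0.5
}
```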
type Supported ¶
type Supported interface {
bool | float16.Float16 | bfloat16.BFloat16 |
float32 | float64 | int | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 |
complex64 | complex128
}
Supported lists the Go types that `gopjrt` knows how to convert (more types can be converted manually). Used as a generics constraint.
Notice Go's `int` type is not portable, since it may translate to dtypes Int32 or Int64 depending on the platform.
Directories ¶
| Path | Synopsis |
|---|---|
| bfloat16 | Package bfloat16 is a trivial implementation for the bfloat16 type, based on https://github.com/x448/float16 and the pending issue in https://github.com/x448/float16/issues/22 |
| float16 | Package float16 implements the IEEE 754 half-precision floating-point format (binary16). |