dtypes

package
v0.0.0-...-475f5b6
Warning: this package is not in the latest version of its module.
Published: May 13, 2026 License: Apache-2.0 Imports: 11 Imported by: 0

Documentation

Overview

Package dtypes includes the DType enum for all supported data types for GoMLX and the compute backends.

It includes several converters to/from Go native types (using generic functions and reflect.Type), the min/max values for each type, helpers to create, copy and cast slices of any DType, etc.

It also includes some constraint interfaces to be used with generics (Number, NumberNotComplex, GoFloat).

## Half-precision data types

Float16 and BFloat16 support in Go uses the simple implementations in github.com/gomlx/compute/dtypes/float16 and github.com/gomlx/compute/dtypes/bfloat16.

Constants

const (
	// INVALID (or PJRT_Buffer_Type_INVALID) is the C enum name for InvalidDType.
	INVALID = InvalidDType

	// PRED (or PJRT_Buffer_Type_PRED) is the C enum name for Bool.
	PRED = Bool

	// S8 (or PJRT_Buffer_Type_S8) is the C enum name for Int8.
	S8 = Int8

	// S16 (or PJRT_Buffer_Type_S16) is the C enum name for Int16.
	S16 = Int16

	// S32 (or PJRT_Buffer_Type_S32) is the C enum name for Int32.
	S32 = Int32

	// S64 (or PJRT_Buffer_Type_S64) is the C enum name for Int64.
	S64 = Int64

	// U8 (or PJRT_Buffer_Type_U8) is the C enum name for Uint8.
	U8 = Uint8

	// U16 (or PJRT_Buffer_Type_U16) is the C enum name for Uint16.
	U16 = Uint16

	// U32 (or PJRT_Buffer_Type_U32) is the C enum name for Uint32.
	U32 = Uint32

	// U64 (or PJRT_Buffer_Type_U64) is the C enum name for Uint64.
	U64 = Uint64

	// F16 (or PJRT_Buffer_Type_F16) is the C enum name for Float16.
	F16 = Float16

	// F32 (or PJRT_Buffer_Type_F32) is the C enum name for Float32.
	F32 = Float32

	// F64 (or PJRT_Buffer_Type_F64) is the C enum name for Float64.
	F64 = Float64

	// BF16 (or PJRT_Buffer_Type_BF16) is the C enum name for BFloat16.
	BF16 = BFloat16

	// C64 (or PJRT_Buffer_Type_C64) is the C enum name for Complex64.
	C64 = Complex64

	// C128 (or PJRT_Buffer_Type_C128) is the C enum name for Complex128.
	C128 = Complex128

	// S4 (or PJRT_Buffer_Type_S4) is the C enum name for Int4.
	S4 = Int4

	// U4 (or PJRT_Buffer_Type_U4) is the C enum name for Uint4.
	U4 = Uint4

	// S2 (or PJRT_Buffer_Type_S2) is the C enum name for Int2.
	S2 = Int2

	// U2 (or PJRT_Buffer_Type_U2) is the C enum name for Uint2.
	U2 = Uint2

	// U1 (or PJRT_Buffer_Type_U1) is the C enum name for Uint1.
	U1 = Uint1

	// S1 (or PJRT_Buffer_Type_S1) is the C enum name for Int1.
	S1 = Int1
)

Aliases from PJRT C API.

const MaxDType = 64

MaxDType is the maximum number of DTypes that there can be.

Variables

var (
	FloatDTypes     = dtypeSetWith(Float32, Float64, Float16, BFloat16)
	Float16DTypes   = dtypeSetWith(Float16, BFloat16)
	ComplexDTypes   = dtypeSetWith(Complex64, Complex128)
	IntDTypes       = dtypeSetWith(Int64, Int32, Int16, Int8, Int4, Int2, Int1, Uint1, Uint2, Uint4, Uint8, Uint16, Uint32, Uint64)
	UnsignedDTypes  = dtypeSetWith(Uint8, Uint16, Uint32, Uint64, Uint4, Uint2, Uint1)
	SupportedDTypes = dtypeSetWith(Bool,
		Float16, BFloat16, Float32, Float64,
		Int64, Int32, Int16, Int8, Int4, Int2, Int1,
		Uint64, Uint32, Uint16, Uint8, Uint4, Uint2, Uint1,
		Complex64, Complex128)
)
var MapOfNames = map[string]DType{
	"InvalidDType": InvalidDType,
	"INVALID":      InvalidDType,
	"Bool":         Bool,
	"PRED":         Bool,

	"S1":   Int1,
	"Int1": Int1,
	"I1":   Int1,

	"I2":   Int2,
	"Int2": Int2,
	"S2":   Int2,

	"I4":   Int4,
	"Int4": Int4,
	"S4":   Int4,

	"I8":   Int8,
	"Int8": Int8,
	"S8":   Int8,

	"I16":   Int16,
	"Int16": Int16,
	"S16":   Int16,

	"I32":   Int32,
	"Int32": Int32,
	"S32":   Int32,

	"I64":   Int64,
	"Int64": Int64,
	"S64":   Int64,

	"Uint1":  Uint1,
	"U1":     Uint1,
	"Uint2":  Uint2,
	"U2":     Uint2,
	"Uint4":  Uint4,
	"U4":     Uint4,
	"Uint8":  Uint8,
	"U8":     Uint8,
	"Uint16": Uint16,
	"U16":    Uint16,
	"Uint32": Uint32,
	"U32":    Uint32,
	"Uint64": Uint64,
	"U64":    Uint64,

	"Float16":       Float16,
	"F16":           Float16,
	"Float32":       Float32,
	"F32":           Float32,
	"Float64":       Float64,
	"F64":           Float64,
	"BFloat16":      BFloat16,
	"BF16":          BFloat16,
	"Complex64":     Complex64,
	"C64":           Complex64,
	"Complex128":    Complex128,
	"C128":          Complex128,
	"F8E5M2":        F8E5M2,
	"F8E4M3FN":      F8E4M3FN,
	"F8E4M3B11FNUZ": F8E4M3B11FNUZ,
	"F8E5M2FNUZ":    F8E5M2FNUZ,
	"F8E4M3FNUZ":    F8E4M3FNUZ,
	"TOKEN":         TOKEN,
	"F8E4M3":        F8E4M3,
	"F8E3M4":        F8E3M4,
	"F8E8M0FNU":     F8E8M0FNU,
	"F4E2M1FN":      F4E2M1FN,
}

MapOfNames maps names to their dtypes. It also includes aliases for the various dtypes, and it is later extended to include the lower-case version of each name.
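A minimal sketch of why the lower-case aliases matter: once every key is duplicated in lower case, lookups such as "bf16" or "float32" succeed. The helper withLowerCaseNames is hypothetical and uses plain int32 values in place of DType; it is not the package's initialization code.

```go
package main

import (
	"fmt"
	"strings"
)

// withLowerCaseNames returns a copy of names that also contains the lower-case
// version of every key (hypothetical helper, not part of the dtypes package).
func withLowerCaseNames(names map[string]int32) map[string]int32 {
	out := make(map[string]int32, 2*len(names))
	for name, value := range names {
		out[name] = value
		out[strings.ToLower(name)] = value
	}
	return out
}

func main() {
	// A few entries in the style of MapOfNames, using the DType numbers listed below
	// (Float32 = 11, BFloat16 = 13).
	names := withLowerCaseNames(map[string]int32{"Float32": 11, "F32": 11, "BFloat16": 13, "BF16": 13})
	v, ok := names["bf16"]
	fmt.Println(v, ok) // 13 true
}
```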

Functions

func CopyAnySlice

func CopyAnySlice(dst, src any)

CopyAnySlice copies the contents of src to dst; both must be slices of the same DType.

Unsafe: dst and src must be slices of the same dtype.

func DTypeStrings

func DTypeStrings() []string

DTypeStrings returns a slice of all String values of the enum.

func MakeAnySlice

func MakeAnySlice(dtype DType, length int) any

MakeAnySlice creates a slice of the given dtype and length, cast to any.

For sub-byte types (Uint1, Uint2, Uint4, Int1, Int2, Int4) it returns a slice of uint8 with an adjusted length, that is, enough bytes to hold that many bits, crumbs, or nibbles.
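The length adjustment for packed types amounts to rounding the total number of bits up to whole bytes. The sketch below is a hypothetical stand-in for MakeAnySlice: it takes a reflect.Type and a bit width instead of a DType, and is only meant to illustrate the documented behavior.

```go
package main

import (
	"fmt"
	"reflect"
)

// makeAnySlice is a hypothetical stand-in for MakeAnySlice. elemType is the Go type
// of one element and bits is the number of bits per value: 1, 2 or 4 for packed
// sub-byte dtypes, or 8*size for byte-aligned ones.
func makeAnySlice(elemType reflect.Type, bits, length int) any {
	if bits < 8 {
		// Packed values share bytes: round the total number of bits up to whole bytes.
		return make([]uint8, (length*bits+7)/8)
	}
	return reflect.MakeSlice(reflect.SliceOf(elemType), length, length).Interface()
}

func main() {
	floats := makeAnySlice(reflect.TypeOf(float32(0)), 32, 3).([]float32)
	nibbles := makeAnySlice(reflect.TypeOf(uint8(0)), 4, 5).([]uint8) // e.g. 5 Int4 values
	fmt.Println(len(floats), len(nibbles))                              // 3 3
}
```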

func UnsafeAnySliceFromBytes

func UnsafeAnySliceFromBytes(bytesPtr unsafe.Pointer, dtype DType, length int) any

UnsafeAnySliceFromBytes casts a pointer to a buffer of bytes to a slice of the given dtype and length pointing to the same data.

For sub-byte types (Uint1, Uint2, Uint4, Int1, Int2, Int4) it returns a slice of uint8 whose length is adjusted to hold that many (packed) elements. The length argument is always given in number of elements (bits, crumbs, or nibbles in this case).

Unsafe: bytesPtr must have enough data to hold the []dtype of the given length.

func UnsafeByteSlice

func UnsafeByteSlice[E Supported](flat []E) []byte

UnsafeByteSlice casts a slice of any of the supported Go types to a slice of bytes.

func UnsafeByteSliceFromAny

func UnsafeByteSliceFromAny(flatAny any) []byte

UnsafeByteSliceFromAny casts a slice of any of the supported Go types (passed as type any) to a slice of bytes.

func UnsafeSliceFromBytes

func UnsafeSliceFromBytes[E Supported](bytesPtr unsafe.Pointer, length int) []E

UnsafeSliceFromBytes casts a pointer to a buffer of bytes to a slice of the given E type and length pointing to the same data.

If length is zero or bytesPtr is nil, it returns nil.

Unsafe: bytesPtr must have enough data to hold the []E of the given length.
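A minimal sketch of the zero-copy casting these functions describe, written with the standard unsafe.Slice and unsafe.SliceData (Go 1.20+). byteSlice and sliceFromBytes are hypothetical names, not the package's implementation; they only show that both views alias the same memory and why the caller must guarantee size and alignment.

```go
package main

import (
	"fmt"
	"unsafe"
)

// byteSlice views the memory of flat as a []byte without copying.
// Writes through the result modify flat.
func byteSlice[E any](flat []E) []byte {
	if len(flat) == 0 {
		return nil
	}
	var zero E
	size := int(unsafe.Sizeof(zero))
	return unsafe.Slice((*byte)(unsafe.Pointer(unsafe.SliceData(flat))), len(flat)*size)
}

// sliceFromBytes views length values of type E starting at bytesPtr. The caller must
// guarantee the memory holds at least length*sizeof(E) bytes, suitably aligned for E.
func sliceFromBytes[E any](bytesPtr unsafe.Pointer, length int) []E {
	if length == 0 || bytesPtr == nil {
		return nil
	}
	return unsafe.Slice((*E)(bytesPtr), length)
}

func main() {
	values := []int32{1, 2, 3}
	raw := byteSlice(values)
	view := sliceFromBytes[int32](unsafe.Pointer(unsafe.SliceData(raw)), len(values))
	view[1] = 42                  // same memory: values[1] changes too
	fmt.Println(len(raw), values) // 12 [1 42 3]
}
```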

Types

type DType

type DType int32

DType is an enum that represents the data type of any buffer, scalar, tensor, multi-dimensional array, etc.

Not all DTypes are supported by all backends. This is a collection of all data types GoMLX knows about, plus some it doesn't handle yet.

const (
	// InvalidDType is an invalid primitive type to serve as default.
	InvalidDType DType = 0

	// Bool represents predicates which are two-state booleans.
	Bool DType = 1

	// Int8 represents signed 8-bit integral values.
	Int8 DType = 2

	// Int16 represents signed 16-bit integral values.
	Int16 DType = 3

	// Int32 represents signed 32-bit integral values.
	Int32 DType = 4

	// Int64 represents signed 64-bit integral values.
	Int64 DType = 5

	// Uint8 represents unsigned 8-bit integral values.
	Uint8 DType = 6

	// Uint16 represents unsigned 16-bit integral values.
	Uint16 DType = 7

	// Uint32 represents unsigned 32-bit integral values.
	Uint32 DType = 8

	// Uint64 represents unsigned 64-bit integral values.
	Uint64 DType = 9

	// Float16 represents 16-bit floating-point values, a "half-precision" floating-point format.
	// It is referred to as IEEE 754 2008 "binary16", see https://en.wikipedia.org/wiki/Half-precision_floating-point_format
	Float16 DType = 10

	// Float32 represents 32-bit floating-point values.
	Float32 DType = 11

	// Float64 represents 64-bit floating-point values, also known as "double-precision" floating-point format.
	Float64 DType = 12

	// BFloat16 represents truncated 16-bit floating-point format, a "half-precision" floating-point format.
	// This is similar to IEEE's 16 bit floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
	// and 7 bits for the mantissa.
	//
	// This format is a shortened (16-bit) version of the 32-bit IEEE 754 single-precision floating-point format (binary32)
	// https://en.wikipedia.org/wiki/Tensor_Processing_Unit.
	BFloat16 DType = 13

	// Complex64 represents complex values.
	//
	// Paired F32 (real, imag), as in std::complex<float>.
	Complex64 DType = 14

	// Complex128 represents complex values.
	// Paired F64 (real, imag), as in std::complex<double>.
	Complex128 DType = 15

	// F8E5M2 represents an 8-bit floating-point format with 5 exponent bits and 2 mantissa bits.
	F8E5M2 DType = 16

	// F8E4M3FN represents an 8-bit floating-point format with 4 exponent bits and 3 mantissa bits.
	F8E4M3FN DType = 17

	// F8E4M3B11FNUZ represents an 8-bit floating-point format with 4 exponent bits and 3 mantissa bits.
	F8E4M3B11FNUZ DType = 18

	// F8E5M2FNUZ represents an 8-bit floating-point format with 5 exponent bits and 2 mantissa bits.
	F8E5M2FNUZ DType = 19

	// F8E4M3FNUZ represents an 8-bit floating-point format with 4 exponent bits and 3 mantissa bits.
	F8E4M3FNUZ DType = 20

	// Int4 represents a 4-bit integer type.
	//
	// This is assumed to be a "packed" type (multiple 4-bit "nibbles" per byte): at least this is how it is stored
	// in Go tensors (in a slice of bytes); different backends may have their own internal representation.
	Int4 DType = 21

	// Uint4 represents a 4-bit unsigned integer type.
	//
	// This is assumed to be a "packed" type (multiple 4-bit "nibbles" per byte): at least this is how it is stored
	// in Go tensors (in a slice of bytes); different backends may have their own internal representation.
	Uint4 DType = 22

	// TOKEN represents a token type.
	TOKEN DType = 23

	// Int2 represents a 2-bit integer type.
	//
	// This is assumed to be a "packed" type (multiple 2-bit "crumbs" per byte): at least this is how it is stored
	// in Go tensors (in a slice of bytes); different backends may have their own internal representation.
	Int2 DType = 24

	// Uint2 represents a 2-bit unsigned integer type.
	//
	// This is assumed to be a "packed" type (multiple 2-bit "crumbs" per byte): at least this is how it is stored
	// in Go tensors (in a slice of bytes); different backends may have their own internal representation.
	Uint2 DType = 25

	// F8E4M3 represents an 8-bit floating-point format with 4 exponent bits and 3 mantissa bits.
	F8E4M3 DType = 26

	// F8E3M4 represents an 8-bit floating-point format with 3 exponent bits and 4 mantissa bits.
	F8E3M4 DType = 27

	// F8E8M0FNU represents an 8-bit floating-point format with 8 exponent bits and no mantissa bits.
	F8E8M0FNU DType = 28

	// F4E2M1FN represents 4-bit MX floating-point format.
	F4E2M1FN DType = 29

	// Int1 represents a 1-bit integer type.
	//
	// This is assumed to be a "packed" type (8 values per byte): at least this is how it is stored
	// in Go tensors (in a slice of bytes); different backends may have their own internal representation.
	Int1 DType = 30

	// Uint1 represents a 1-bit unsigned integer type.
	//
	// This is assumed to be a "packed" type (8 values per byte): at least this is how it is stored
	// in Go tensors (in a slice of bytes); different backends may have their own internal representation.
	Uint1 DType = 31
)

func DTypeString

func DTypeString(s string) (DType, error)

DTypeString retrieves an enum value from the enum constant's string name. It returns an error if the name is not part of the enum.

func DTypeValues

func DTypeValues() []DType

DTypeValues returns all values of the enum.

func FromAny

func FromAny(value any) DType

FromAny introspects the underlying type of value and returns the corresponding DType. Non-scalar or unsupported types return InvalidDType.

func FromGenericsType

func FromGenericsType[T Supported]() DType

FromGenericsType returns the DType enum for the given type that this package knows about.

func FromGoType

func FromGoType(t reflect.Type) DType

FromGoType returns the DType for the given "reflect.Type". It panics for unknown DType values.
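The three lookups above (FromAny, FromGenericsType, FromGoType) can all be funneled through one reflect.Type lookup. The sketch below assumes a mapping by reflect.Kind and covers only four types; dtype, lookup, fromAny, fromGenericsType and fromGoType are hypothetical stand-ins, and the real package may treat named types and Go's int differently.

```go
package main

import (
	"fmt"
	"reflect"
)

// dtype mirrors a few of the DType values listed above (hypothetical stand-in).
type dtype int32

const (
	invalidDType dtype = 0
	int32DType   dtype = 4
	int64DType   dtype = 5
	float32DType dtype = 11
	float64DType dtype = 12
)

// lookup maps a scalar reflect.Type to a dtype, or invalidDType if it is not handled.
func lookup(t reflect.Type) dtype {
	switch t.Kind() {
	case reflect.Int32:
		return int32DType
	case reflect.Int64:
		return int64DType
	case reflect.Float32:
		return float32DType
	case reflect.Float64:
		return float64DType
	}
	return invalidDType // slices, structs, strings, ...
}

// fromGoType panics for unknown types, as FromGoType is documented to do.
func fromGoType(t reflect.Type) dtype {
	if d := lookup(t); d != invalidDType {
		return d
	}
	panic(fmt.Sprintf("no dtype for Go type %s", t))
}

// fromAny returns invalidDType for nil, non-scalar or unsupported values.
func fromAny(value any) dtype {
	if value == nil {
		return invalidDType
	}
	return lookup(reflect.TypeOf(value))
}

// fromGenericsType resolves the dtype from a type parameter, without needing a value.
func fromGenericsType[T any]() dtype {
	return fromGoType(reflect.TypeOf((*T)(nil)).Elem())
}

func main() {
	fmt.Println(fromAny(float32(1)), fromAny([]float32{1}), fromGenericsType[int64]()) // 11 0 5
}
```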

func (DType) Bits

func (dtype DType) Bits() int

Bits returns the number of bits for the given DType. This is mostly relevant for the "packed" storage types (Int4, Int2, Int1, Uint4, Uint2, Uint1). Bool is never packed, hence it returns 8.

func (DType) GoStr

func (dtype DType) GoStr() string

GoStr converts dtype to the corresponding Go type and converts that to a string. Notice the names differ from the DType names (so the `Int64` dtype is simply `int` in Go).

Sub-byte packed values (Int2, Uint2, Int4, Uint4, Int1, Uint1) are packed as "uint8", so that's what is returned.

func (DType) GoType

func (dtype DType) GoType() reflect.Type

GoType returns the Go `reflect.Type` corresponding to the tensor DType.

func (DType) HighestValue

func (dtype DType) HighestValue() any

HighestValue returns the highest value for dtype, converted to the corresponding Go type. For float values it returns positive infinity. There is no highest value for complex numbers, since they are not ordered.

For the packed sub-byte types (Int4, Int2, Uint4, Uint2, Int1, Uint1), the highest value is returned as a byte, with every packed value in it set to the highest value.
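"Every packed value set to the highest value" means the bit pattern of one value is repeated in each slot of the byte. A minimal sketch, assuming signed values are stored as two's-complement bit patterns; fillByte is a hypothetical helper, and the exact bytes the package returns are not confirmed here.

```go
package main

import "fmt"

// fillByte replicates the bit pattern of one packed value (its low `bits` bits)
// into every slot of a byte. bits must be 1, 2 or 4.
func fillByte(pattern uint8, bits int) uint8 {
	mask := uint8(1)<<bits - 1
	var b uint8
	for shift := 0; shift < 8; shift += bits {
		b |= (pattern & mask) << shift
	}
	return b
}

func main() {
	fmt.Printf("%#x %#x %#x\n",
		fillByte(15, 4),  // Uint4 highest (15) in both nibbles
		fillByte(7, 4),   // Int4 highest (+7) in both nibbles
		fillByte(0x8, 4), // Int4 lowest (-8) is 0b1000 in two's complement
	) // 0xff 0x77 0x88
}
```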

func (DType) IsADType

func (i DType) IsADType() bool

IsADType returns true if the value is listed in the enum definition, false otherwise.

func (DType) IsComplex

func (dtype DType) IsComplex() bool

IsComplex returns whether dtype is a supported complex number type.

func (DType) IsFloat

func (dtype DType) IsFloat() bool

IsFloat returns whether dtype is a supported float type; float types not yet supported return false. It also returns false for complex numbers.

func (DType) IsFloat16

func (dtype DType) IsFloat16() bool

IsFloat16 returns whether dtype is a supported float with 16 bits: Float16 or BFloat16. Same as IsHalfPrecision.

func (DType) IsHalfPrecision

func (dtype DType) IsHalfPrecision() bool

IsHalfPrecision returns whether dtype is a supported float with 16 bits: Float16 or BFloat16. Same as IsFloat16.

func (DType) IsInt

func (dtype DType) IsInt() bool

IsInt returns whether dtype is a supported integer type; integer types not yet supported return false. This includes the unsigned integer types.

func (DType) IsPacked

func (dtype DType) IsPacked() bool

IsPacked returns whether the dtype uses less than a byte per value and is therefore "packed" into bytes in memory, e.g. Int4, Int2, Int1, Uint4, Uint2, Uint1. Packing is always "little-endian": the lower bits hold the earlier values in a sequence.
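A minimal sketch of the little-endian packing described above, for unsigned values of 1, 2 or 4 bits: value i lands at bit offset (i % valuesPerByte) * bits of byte i / valuesPerByte. pack and unpack are hypothetical helpers that illustrate the layout; they are not functions of this package.

```go
package main

import "fmt"

// pack stores unsigned values of `bits` bits each (1, 2 or 4) into bytes,
// little-endian within each byte: the first value occupies the lowest bits.
func pack(values []uint8, bits int) []uint8 {
	perByte := 8 / bits
	mask := uint8(1)<<bits - 1
	packed := make([]uint8, (len(values)+perByte-1)/perByte)
	for i, v := range values {
		packed[i/perByte] |= (v & mask) << ((i % perByte) * bits)
	}
	return packed
}

// unpack reads back value i from a packed buffer.
func unpack(packed []uint8, bits, i int) uint8 {
	perByte := 8 / bits
	mask := uint8(1)<<bits - 1
	return (packed[i/perByte] >> ((i % perByte) * bits)) & mask
}

func main() {
	packed := pack([]uint8{1, 2, 3}, 4)
	fmt.Println(packed, unpack(packed, 4, 2)) // [33 3] 3, i.e. bytes 0x21 0x03
}
```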

func (DType) IsPromotableTo

func (dtype DType) IsPromotableTo(target DType) bool

IsPromotableTo returns whether dtype can be promoted to target.

For example, Int32 can be promoted to Int64, but not to Uint64.

See https://openxla.org/stablehlo/spec#functions_on_types for reference.
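One rule consistent with the Int32 example above is: promotion is allowed only within the same kind, treating signed and unsigned integers as different kinds, and only to an equal or wider bit width. The sketch below encodes that assumed rule with hypothetical types (kind, typeInfo, isPromotable); it is not necessarily the package's exact logic.

```go
package main

import "fmt"

// kind groups dtypes that may promote into one another (hypothetical classification).
type kind int

const (
	boolKind kind = iota
	signedKind
	unsignedKind
	floatKind
	complexKind
)

// typeInfo is a hypothetical description of a dtype: its kind and bit width.
type typeInfo struct {
	kind kind
	bits int
}

// isPromotable allows promotion only within the same kind and never to fewer bits.
func isPromotable(from, to typeInfo) bool {
	return from.kind == to.kind && from.bits <= to.bits
}

func main() {
	int32T, int64T := typeInfo{signedKind, 32}, typeInfo{signedKind, 64}
	uint64T := typeInfo{unsignedKind, 64}
	fmt.Println(isPromotable(int32T, int64T), isPromotable(int32T, uint64T)) // true false
}
```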

func (DType) IsSupported

func (dtype DType) IsSupported() bool

IsSupported returns whether dtype is supported by `gopjrt`.

func (DType) IsUnsigned

func (dtype DType) IsUnsigned() bool

IsUnsigned returns whether dtype is one of the unsigned types (for now, only integer types).

func (DType) LowestValue

func (dtype DType) LowestValue() any

LowestValue returns the lowest value for dtype, converted to the corresponding Go type. For float values it returns negative infinity. There is no lowest value for complex numbers, since they are not ordered.

For the packed sub-byte types (Int4, Int2, Uint4, Uint2, Int1, Uint1), the lowest value is returned as a byte, with every packed value in it set to the lowest value.

func (DType) MarshalJSON

func (i DType) MarshalJSON() ([]byte, error)

MarshalJSON implements the json.Marshaler interface for DType.

func (DType) MarshalText

func (i DType) MarshalText() ([]byte, error)

MarshalText implements the encoding.TextMarshaler interface for DType.

func (DType) MarshalYAML

func (i DType) MarshalYAML() (interface{}, error)

MarshalYAML implements a YAML Marshaler for DType.

func (DType) Memory deprecated

func (dtype DType) Memory() uintptr

Memory returns the number of bytes for the given DType. It's an alias to Size, converted to uintptr.

Deprecated: use Size() instead.

func (DType) RealDType

func (dtype DType) RealDType() DType

RealDType returns the real component of complex dtypes. For float dtypes, it returns itself.

It returns InvalidDType for other non-(complex or float) dtypes.

func (DType) Size

func (dtype DType) Size() int

Size returns the number of bytes for the given DType, or 0 if the dtype uses fraction(s) of bytes.

func (DType) SizeForDimensions

func (dtype DType) SizeForDimensions(dimensions ...int) int

SizeForDimensions returns the size in bytes used for the given dimensions. This is safer than Size when the dtype uses an underlying size that is not a multiple of 8 bits.

It also works for scalar (one element) shapes, where the list of dimensions is empty.

For packed types, it assumes padding of the last byte where needed. E.g.: Int4.SizeForDimensions(1) -> 1, even though only 4 bits of that byte are used.
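The computation described above can be sketched as "multiply the dimensions, multiply by the bits per element, round up to whole bytes". sizeForDimensions is a hypothetical helper that takes a bit width instead of a DType; it reproduces the Int4 example from the text.

```go
package main

import "fmt"

// sizeForDimensions computes the bytes needed for a tensor whose elements take `bits`
// bits each, padding the last byte for packed types. A scalar has no dimensions and
// holds exactly one element.
func sizeForDimensions(bits int, dimensions ...int) int {
	elements := 1
	for _, d := range dimensions {
		elements *= d
	}
	return (elements*bits + 7) / 8
}

func main() {
	fmt.Println(
		sizeForDimensions(4, 1),    // one Int4 value still takes a whole byte
		sizeForDimensions(4, 3, 3), // 9 nibbles = 36 bits, padded to 5 bytes
		sizeForDimensions(32),      // a Float32 scalar: 4 bytes
	) // 1 5 4
}
```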

func (DType) SmallestNonZeroValueForDType

func (dtype DType) SmallestNonZeroValueForDType() any

SmallestNonZeroValueForDType returns the smallest non-zero value of dtype, converted to the corresponding Go type. It is only useful for float types. There is no smallest non-zero value for complex numbers, since they are not ordered.

For the packed sub-byte types (Int4, Int2, Uint4, Uint2, Int1, Uint1), the value is returned as a byte, with all values set to the lowest non-zero value.

func (DType) String

func (i DType) String() string

func (*DType) UnmarshalJSON

func (i *DType) UnmarshalJSON(data []byte) error

UnmarshalJSON implements the json.Unmarshaler interface for DType.

func (*DType) UnmarshalText

func (i *DType) UnmarshalText(text []byte) error

UnmarshalText implements the encoding.TextUnmarshaler interface for DType.

func (*DType) UnmarshalYAML

func (i *DType) UnmarshalYAML(unmarshal func(interface{}) error) error

UnmarshalYAML implements a YAML Unmarshaler for DType.

func (DType) Values

func (DType) Values() []string

func (DType) ValuesPerStorageUnit

func (dtype DType) ValuesPerStorageUnit() int

ValuesPerStorageUnit returns the number of values that fit in a storage unit for packed dtypes (uint8/byte for Int2, Int4, etc.). The "storage unit" is returned by dtype.GoType().

type DTypeSet

type DTypeSet [MaxDType]bool

DTypeSet represents a set of DTypes.
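Because DType values are small integers below MaxDType, a set can be a fixed-size array of booleans indexed by the DType itself, with constant-time membership checks and no allocation. A minimal sketch with hypothetical stand-ins (dtype, dtypeSet, setWith, has); the package's own unexported dtypeSetWith plays a similar role in the Variables above, but its code is not shown here.

```go
package main

import "fmt"

const maxDType = 64

// dtype and dtypeSet are hypothetical stand-ins for DType and DTypeSet.
type dtype int32

type dtypeSet [maxDType]bool

// setWith builds a set containing the given dtypes.
func setWith(dtypes ...dtype) dtypeSet {
	var s dtypeSet
	for _, d := range dtypes {
		s[d] = true
	}
	return s
}

// has reports membership, treating out-of-range values as absent.
func (s dtypeSet) has(d dtype) bool {
	return d >= 0 && int(d) < len(s) && s[d]
}

func main() {
	// Float16=10, Float32=11, Float64=12, BFloat16=13, as in the DType constants.
	floats := setWith(11, 12, 10, 13)
	fmt.Println(floats.has(11), floats.has(4)) // true false
}
```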

type GoFloat

type GoFloat interface {
	float32 | float64
}

GoFloat represents a continuous Go numeric type supported by GoMLX. It doesn't include complex numbers.

type HalfPrecision

type HalfPrecision[T any] interface {
	float16.Float16 | bfloat16.BFloat16
	Float64() float64
	Float32() float32
	Neg() T
}

HalfPrecision is an interface that represents half-precision floating point numbers, specifically float16 and bfloat16.

It includes the methods to convert to float64 and float32, so it can be used in generic methods.

It's a generic constraint, and usually used like this:

func myGeneric[T HalfPrecision[T]](v T, ...) { ... }

type HalfPrecisionPtr

type HalfPrecisionPtr[T HalfPrecision[T]] interface {
	*T
	SetFloat32(float32)
	SetFloat64(float64)
}

HalfPrecisionPtr is a pointer to a HalfPrecision wrapper type. It is used when one needs to set the value of a HalfPrecision type from a float32 or float64.
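A minimal sketch of how the two constraints are typically combined: HalfPrecision for reading values through their methods, and a pointer constraint whose core type is *T for calling pointer-receiver setters. To stay self-contained it uses a hypothetical bf16 type (upper 16 bits of a float32, truncating conversion) instead of the real float16 and bfloat16 packages, and local copies of the constraints named halfPrecision and halfPrecisionPtr.

```go
package main

import (
	"fmt"
	"math"
)

// bf16 is a hypothetical stand-in for bfloat16.BFloat16: the upper 16 bits of a float32.
type bf16 uint16

func (b bf16) Float32() float32 { return math.Float32frombits(uint32(b) << 16) }
func (b bf16) Float64() float64 { return float64(b.Float32()) }
func (b bf16) Neg() bf16        { return b ^ 0x8000 } // flip the sign bit

// SetFloat32 truncates f to bfloat16 (no rounding, for brevity).
func (b *bf16) SetFloat32(f float32) { *b = bf16(math.Float32bits(f) >> 16) }
func (b *bf16) SetFloat64(f float64) { b.SetFloat32(float32(f)) }

// halfPrecision and halfPrecisionPtr mirror the shape of HalfPrecision and
// HalfPrecisionPtr, with bf16 as the only member of the type set.
type halfPrecision[T any] interface {
	bf16
	Float64() float64
	Float32() float32
	Neg() T
}

type halfPrecisionPtr[T halfPrecision[T]] interface {
	*T
	SetFloat32(float32)
	SetFloat64(float64)
}

// sum reads values through the methods guaranteed by halfPrecision.
func sum[T halfPrecision[T]](values []T) float32 {
	var total float32
	for _, v := range values {
		total += v.Float32()
	}
	return total
}

// fromFloat32 writes values through the pointer constraint: P is *T, so P(&out[i])
// gives access to the pointer-receiver SetFloat32 method.
func fromFloat32[T halfPrecision[T], P halfPrecisionPtr[T]](values []float32) []T {
	out := make([]T, len(values))
	for i, f := range values {
		P(&out[i]).SetFloat32(f)
	}
	return out
}

func main() {
	halves := fromFloat32[bf16]([]float32{1.5, 2}) // P is inferred as *bf16
	fmt.Println(sum(halves), halves[0].Neg().Float32()) // 3.5 -1.5
}
```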

type Number

type Number interface {
	float32 | float64 | int | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | complex64 | complex128
}

Number represents the Go numeric types corresponding to supported DTypes. Used as a generics constraint.

It includes complex numbers. It doesn't include float16.Float16 or bfloat16.BFloat16 because they are not native number types.

type NumberComplex

type NumberComplex interface {
	complex64 | complex128
}

NumberComplex represents the Go complex types corresponding to supported DTypes.

type NumberHalfPrecision

type NumberHalfPrecision interface {
	float16.Float16 | bfloat16.BFloat16
}

NumberHalfPrecision represents the Go half-precision types corresponding to supported DTypes.

type NumberNotComplex

type NumberNotComplex interface {
	float32 | float64 | int | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
}

NumberNotComplex represents the Go numeric types corresponding to supported DTypes, excluding complex numbers. Used as a generics constraint.

See also Number.

type Supported

type Supported interface {
	bool | float16.Float16 | bfloat16.BFloat16 |
		float32 | float64 | int | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 |
		complex64 | complex128
}

Supported lists the Go types that `gopjrt` knows how to convert; more types can be converted manually. Used as a generics constraint.

Notice Go's `int` type is not portable, since it may translate to dtypes Int32 or Int64 depending on the platform.
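The portability issue can be made concrete with the standard strconv.IntSize constant, which is 32 or 64 depending on the platform. intDTypeName is a hypothetical helper, not part of this package; portable code can sidestep the question by using int32 or int64 explicitly.

```go
package main

import (
	"fmt"
	"strconv"
)

// intDTypeName reports which DType a Go `int` corresponds to on the current
// platform, based on strconv.IntSize (32 or 64).
func intDTypeName() string {
	if strconv.IntSize == 64 {
		return "Int64"
	}
	return "Int32"
}

func main() {
	fmt.Println("int maps to", intDTypeName())
}
```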

Directories

Path      Synopsis
bfloat16  Package bfloat16 is a trivial implementation for the bfloat16 type, based on https://github.com/x448/float16 and the pending issue in https://github.com/x448/float16/issues/22
float16   Package float16 implements the IEEE 754 half-precision floating-point format (binary16).
