types

package
v0.2.0
Warning

This package is not in the latest version of its module.

Published: Dec 4, 2025 License: Apache-2.0 Imports: 5 Imported by: 1

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func ChannelTypeStrings added in v0.1.0

func ChannelTypeStrings() []string

ChannelTypeStrings returns a slice containing the String value of every constant in the enum.

func ComparisonDirectionStrings

func ComparisonDirectionStrings() []string

ComparisonDirectionStrings returns a slice containing the String value of every constant in the enum.

func ComparisonTypeStrings

func ComparisonTypeStrings() []string

ComparisonTypeStrings returns a slice containing the String value of every constant in the enum.

func DotGeneralPrecisionTypeStrings

func DotGeneralPrecisionTypeStrings() []string

DotGeneralPrecisionTypeStrings returns a slice containing the String value of every constant in the enum.

func FFTTypeStrings

func FFTTypeStrings() []string

FFTTypeStrings returns a slice containing the String value of every constant in the enum.

func RNGBitGeneratorAlgorithmStrings added in v0.2.0

func RNGBitGeneratorAlgorithmStrings() []string

RNGBitGeneratorAlgorithmStrings returns a slice containing the String value of every constant in the enum.

Types

type ChannelType added in v0.1.0

type ChannelType int

ChannelType defines the communication dimension for a collective op. Although it is declared as int, its values are meant to match the i64 type used for channel types in the StableHLO spec.

const (
	// CrossReplica communicates across replicas (data parallelism).
	// This is the default.
	CrossReplica ChannelType = 0

	// CrossPartition communicates across partitions (model parallelism).
	CrossPartition ChannelType = 1
)

func ChannelTypeString added in v0.1.0

func ChannelTypeString(s string) (ChannelType, error)

ChannelTypeString returns the enum value whose constant name is s. It returns an error if s is not the name of an enum constant.

func ChannelTypeValues added in v0.1.0

func ChannelTypeValues() []ChannelType

ChannelTypeValues returns all values of the enum.

func (ChannelType) IsAChannelType added in v0.1.0

func (i ChannelType) IsAChannelType() bool

IsAChannelType reports whether the value is one of the constants in the enum definition.

func (ChannelType) String added in v0.1.0

func (i ChannelType) String() string
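The generated helpers on each enum (`XStrings`, `XString`, `XValues`, `IsAX` and `String`) appear to follow the pattern produced by the enumer code generator. The stand-alone sketch below mirrors that pattern for ChannelType so the round trip can be seen in one place. It is not the package's generated code, and it assumes the String values are the constant identifiers (`"CrossReplica"`, `"CrossPartition"`), which this page does not confirm.

```go
package main

import "fmt"

// ChannelType mirrors the documented enum (same constants, same values).
type ChannelType int

const (
	CrossReplica   ChannelType = 0
	CrossPartition ChannelType = 1
)

// channelTypeNames assumes String values equal the constant identifiers.
var channelTypeNames = map[ChannelType]string{
	CrossReplica:   "CrossReplica",
	CrossPartition: "CrossPartition",
}

// String returns the constant name, or a numeric fallback for unknown values.
func (i ChannelType) String() string {
	if name, ok := channelTypeNames[i]; ok {
		return name
	}
	return fmt.Sprintf("ChannelType(%d)", int(i))
}

// IsAChannelType reports whether i is one of the declared constants.
func (i ChannelType) IsAChannelType() bool {
	_, ok := channelTypeNames[i]
	return ok
}

// ChannelTypeString parses a constant name back into its value.
func ChannelTypeString(s string) (ChannelType, error) {
	for v, name := range channelTypeNames {
		if name == s {
			return v, nil
		}
	}
	return 0, fmt.Errorf("%q is not a ChannelType name", s)
}

// ChannelTypeValues returns every declared value in order.
func ChannelTypeValues() []ChannelType {
	return []ChannelType{CrossReplica, CrossPartition}
}

func main() {
	// Round trip: value -> String -> parsed value.
	for _, v := range ChannelTypeValues() {
		parsed, err := ChannelTypeString(v.String())
		fmt.Println(v, parsed == v, err == nil)
	}
	fmt.Println(ChannelType(7), ChannelType(7).IsAChannelType())
}
```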

type CollectiveConfig added in v0.1.0

type CollectiveConfig struct {
	// ChannelType specifies the communication dimension.
	// Defaults to CrossReplica (0).
	ChannelType ChannelType

	// ChannelID, if non-nil, forces a specific channel ID (the 'handle').
	// If nil, a unique ID will be automatically generated.
	// This is **required** for MPMD (multi-program, multi-data) to manually link ops across programs.
	ChannelID *int

	// UseGlobalDeviceIDs changes the interpretation of replica_groups
	// from replica IDs to global device IDs.
	// This only applies to AllReduce, not CollectiveBroadcast.
	// Defaults to false.
	UseGlobalDeviceIDs bool
}

CollectiveConfig provides advanced, optional configuration for collective operations. Pass this as the last (optional) argument to collective ops.
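ChannelID is a pointer so that nil can mean "generate an ID for me", while any non-nil value, including 0, pins the handle. The following is a minimal sketch of how a builder might consume the struct, assuming a simple incrementing counter for auto-generated IDs. `channelIDAllocator` and `resolve` are hypothetical names, and the struct is mirrored locally rather than imported from the package.

```go
package main

import "fmt"

type ChannelType int

const (
	CrossReplica   ChannelType = 0
	CrossPartition ChannelType = 1
)

// CollectiveConfig mirrors the documented struct.
type CollectiveConfig struct {
	ChannelType        ChannelType
	ChannelID          *int
	UseGlobalDeviceIDs bool
}

// channelIDAllocator is hypothetical: it stands in for whatever the builder
// uses to generate unique IDs when ChannelID is nil.
type channelIDAllocator struct{ next int }

// resolve returns the pinned ID when one is given, otherwise a fresh one.
func (a *channelIDAllocator) resolve(cfg CollectiveConfig) int {
	if cfg.ChannelID != nil {
		return *cfg.ChannelID
	}
	a.next++
	return a.next
}

func main() {
	alloc := &channelIDAllocator{}

	// The zero value means CrossReplica with an auto-generated ID.
	var defaults CollectiveConfig
	fmt.Println(defaults.ChannelType == CrossReplica, alloc.resolve(defaults))

	// MPMD-style: pin the ID so ops in separate programs share the same handle.
	shared := 42
	pinned := CollectiveConfig{ChannelType: CrossPartition, ChannelID: &shared}
	fmt.Println(alloc.resolve(pinned))
}
```

Pinning the same ChannelID in two separately built programs is what lets their collectives be linked in the MPMD case the struct comment mentions.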

type ComparisonDirection

type ComparisonDirection int

ComparisonDirection enum defined for the Compare op.

const (
	CompareEQ ComparisonDirection = iota
	CompareGE
	CompareGT
	CompareLE
	CompareLT
	CompareNE
)

func ComparisonDirectionString

func ComparisonDirectionString(s string) (ComparisonDirection, error)

ComparisonDirectionString returns the enum value whose constant name is s. It returns an error if s is not the name of an enum constant.

func ComparisonDirectionValues

func ComparisonDirectionValues() []ComparisonDirection

ComparisonDirectionValues returns all values of the enum.

func (ComparisonDirection) IsAComparisonDirection

func (i ComparisonDirection) IsAComparisonDirection() bool

IsAComparisonDirection reports whether the value is one of the constants in the enum definition.

func (ComparisonDirection) String

func (i ComparisonDirection) String() string

func (ComparisonDirection) ToStableHLO

func (c ComparisonDirection) ToStableHLO() string

type ComparisonType

type ComparisonType int

ComparisonType enum defined for the Compare op.

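const (
	// CompareFloat is used for floating point comparisons.
	CompareFloat ComparisonType = iota

	// CompareTotalOrder is the floating point comparison that enforces the total order `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
	CompareTotalOrder

	// CompareSigned is used for signed integer comparisons.
	CompareSigned

	// CompareUnsigned is used for unsigned integer comparisons.
	CompareUnsigned
)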

func ComparisonTypeString

func ComparisonTypeString(s string) (ComparisonType, error)

ComparisonTypeString returns the enum value whose constant name is s. It returns an error if s is not the name of an enum constant.

func ComparisonTypeValues

func ComparisonTypeValues() []ComparisonType

ComparisonTypeValues returns all values of the enum.

func (ComparisonType) IsAComparisonType

func (i ComparisonType) IsAComparisonType() bool

IsAComparisonType reports whether the value is one of the constants in the enum definition.

func (ComparisonType) String

func (i ComparisonType) String() string

func (ComparisonType) ToStableHLO

func (c ComparisonType) ToStableHLO() string

ToStableHLO returns the StableHLO representation of the comparison type.

type ConvolveAxesConfig

type ConvolveAxesConfig struct {
	InputBatch, InputChannels int
	InputSpatial              []int

	KernelInputChannels, KernelOutputChannels int
	KernelSpatial                             []int

	OutputBatch, OutputChannels int
	OutputSpatial               []int
}

ConvolveAxesConfig defines how the axes of the input, kernel and output tensors are interpreted. All three tensors must have the same number of spatial dimensions (axes). The input and output each have a batch axis and a channel axis; the kernel instead has an input-channel axis and an output-channel axis.

See Builder.ConvGeneral.

func (ConvolveAxesConfig) Clone

Clone returns a deep copy of the structure.
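Because InputSpatial, KernelSpatial and OutputSpatial are slices, a plain struct copy would share their backing arrays, which is why a deep copy is needed. The sketch below mirrors the struct locally and assumes Clone has a value receiver returning a ConvolveAxesConfig (its signature is not shown on this page). It also adds a hypothetical `validate` method for the same-number-of-spatial-axes rule, and builds a config for a 2D convolution with an NHWC input/output and an HWIO kernel.

```go
package main

import "fmt"

// ConvolveAxesConfig mirrors the documented struct.
type ConvolveAxesConfig struct {
	InputBatch, InputChannels int
	InputSpatial              []int

	KernelInputChannels, KernelOutputChannels int
	KernelSpatial                             []int

	OutputBatch, OutputChannels int
	OutputSpatial               []int
}

// Clone returns a deep copy. The value receiver already copies the int
// fields; the slices are re-allocated so the copy shares no backing arrays.
func (c ConvolveAxesConfig) Clone() ConvolveAxesConfig {
	c.InputSpatial = append([]int(nil), c.InputSpatial...)
	c.KernelSpatial = append([]int(nil), c.KernelSpatial...)
	c.OutputSpatial = append([]int(nil), c.OutputSpatial...)
	return c
}

// validate (hypothetical) checks that all three tensors have the same
// number of spatial axes.
func (c ConvolveAxesConfig) validate() error {
	n := len(c.InputSpatial)
	if len(c.KernelSpatial) != n || len(c.OutputSpatial) != n {
		return fmt.Errorf("spatial axes mismatch: input=%d, kernel=%d, output=%d",
			n, len(c.KernelSpatial), len(c.OutputSpatial))
	}
	return nil
}

func main() {
	// 2D convolution: input/output are NHWC, kernel is HWIO.
	nhwc := ConvolveAxesConfig{
		InputBatch: 0, InputSpatial: []int{1, 2}, InputChannels: 3,
		KernelSpatial: []int{0, 1}, KernelInputChannels: 2, KernelOutputChannels: 3,
		OutputBatch: 0, OutputSpatial: []int{1, 2}, OutputChannels: 3,
	}
	c := nhwc.Clone()
	c.InputSpatial[0] = 99 // modifies only the clone
	fmt.Println(nhwc.InputSpatial, c.InputSpatial, nhwc.validate())
}
```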

type DotGeneralAlgorithm

type DotGeneralAlgorithm struct {
	// LhsPrecisionType and RhsPrecisionType are the precisions that the LHS and RHS of the operation are rounded to.
	// Precision types are independent of the storage types of the inputs and the output.
	LhsPrecisionType, RhsPrecisionType FloatPrecisionType

	// AccumulationType defines the type of the accumulator used for the dot product.
	AccumulationType FloatPrecisionType

	// LhsComponentCount, RhsComponentCount and NumPrimitiveOperations apply to algorithms that
	// decompose the LHS and/or RHS into multiple components and perform multiple "primitive" dot operations on them,
	// usually to emulate a higher precision (e.g. bf16_6x or tf32_3x; see "Leveraging the bfloat16 Artificial
	// Intelligence Datatype For Higher-Precision Computations", https://arxiv.org/pdf/1904.06376).
	// For algorithms with no decomposition, these values should be set to 1.
	LhsComponentCount, RhsComponentCount, NumPrimitiveOperations int

	// AllowImpreciseAccumulation specifies whether accumulation in lower precision is permitted for some steps
	// (e.g. CUBLASLT_MATMUL_DESC_FAST_ACCUM).
	AllowImpreciseAccumulation bool
}

DotGeneralAlgorithm defines fine-control of the algorithm used for the dot product.

type DotGeneralPrecisionType

type DotGeneralPrecisionType int

DotGeneralPrecisionType defines the precision of the dot product.

It controls the tradeoff between speed and accuracy for computations on accelerator backends. It can take one of the values below. At the moment the semantics of these values are underspecified; the StableHLO project plans to address this in issue #755 (https://github.com/openxla/stablehlo/issues/755).

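const (
	// DotGeneralPrecisionDefault is the fastest calculation, but the least accurate approximation to the original number.
	DotGeneralPrecisionDefault DotGeneralPrecisionType = iota

	// DotGeneralPrecisionHigh is a slower calculation, but a more accurate approximation to the original number.
	DotGeneralPrecisionHigh

	// DotGeneralPrecisionHighest is the slowest calculation, but the most accurate approximation to the original number.
	DotGeneralPrecisionHighest
)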

func DotGeneralPrecisionTypeString

func DotGeneralPrecisionTypeString(s string) (DotGeneralPrecisionType, error)

DotGeneralPrecisionTypeString returns the enum value whose constant name is s. It returns an error if s is not the name of an enum constant.

func DotGeneralPrecisionTypeValues

func DotGeneralPrecisionTypeValues() []DotGeneralPrecisionType

DotGeneralPrecisionTypeValues returns all values of the enum.

func (DotGeneralPrecisionType) IsADotGeneralPrecisionType

func (i DotGeneralPrecisionType) IsADotGeneralPrecisionType() bool

IsADotGeneralPrecisionType reports whether the value is one of the constants in the enum definition.

func (DotGeneralPrecisionType) String

func (i DotGeneralPrecisionType) String() string

func (DotGeneralPrecisionType) ToStableHLO

func (p DotGeneralPrecisionType) ToStableHLO() string

type FFTType

type FFTType int

FFTType defines the type of the FFT operation, see FFT.

const (
	// FFTForward - complex in, complex out.
	FFTForward FFTType = iota

	// FFTInverse - complex in, complex out.
	FFTInverse

	// FFTForwardReal - real in, fft_length / 2 + 1 complex out.
	FFTForwardReal

	// FFTInverseReal - fft_length / 2 + 1 complex in, real out.
	FFTInverseReal
)

func FFTTypeString

func FFTTypeString(s string) (FFTType, error)

FFTTypeString returns the enum value whose constant name is s. It returns an error if s is not the name of an enum constant.

func FFTTypeValues

func FFTTypeValues() []FFTType

FFTTypeValues returns all values of the enum.

func (FFTType) IsAFFTType

func (i FFTType) IsAFFTType() bool

IsAFFTType reports whether the value is one of the constants in the enum definition.

func (FFTType) String

func (i FFTType) String() string

func (FFTType) ToStableHLO

func (t FFTType) ToStableHLO() string

ToStableHLO returns the StableHLO representation of the FFT type.
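The two real-valued variants change the length of the innermost transformed axis. A small hypothetical helper, `lastDimSizes`, makes the arithmetic explicit for a single fft_length; for multi-dimensional transforms the halving applies only to the innermost axis. It mirrors the FFTType constants locally and is not part of the package.

```go
package main

import "fmt"

// FFTType mirrors the documented enum (same order, same values).
type FFTType int

const (
	FFTForward FFTType = iota
	FFTInverse
	FFTForwardReal
	FFTInverseReal
)

// lastDimSizes returns the innermost input and output lengths for a
// transform of length fftLength, and whether each side is complex.
func lastDimSizes(t FFTType, fftLength int) (inLen int, inComplex bool, outLen int, outComplex bool) {
	switch t {
	case FFTForward, FFTInverse:
		return fftLength, true, fftLength, true
	case FFTForwardReal:
		// Real input; only the non-redundant half of the spectrum is kept.
		return fftLength, false, fftLength/2 + 1, true
	case FFTInverseReal:
		// Half spectrum in, full-length real signal out.
		return fftLength/2 + 1, true, fftLength, false
	}
	panic(fmt.Sprintf("unknown FFTType %d", int(t)))
}

func main() {
	for _, t := range []FFTType{FFTForward, FFTInverse, FFTForwardReal, FFTInverseReal} {
		in, inC, out, outC := lastDimSizes(t, 8)
		fmt.Printf("type %d: in %d (complex=%v) -> out %d (complex=%v)\n", int(t), in, inC, out, outC)
	}
}
```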

type FloatPrecisionType

type FloatPrecisionType struct {
	// TF32 is used for the TF32 precision type.
	TF32 bool

	// DType is used for non-TF32 precision types.
	// It must be a float type.
	DType dtypes.DType
}

FloatPrecisionType defines the precision used during floating point operations. In particular, modern GPUs accept the TF32 type, which sacrifices some accuracy for significant speed improvements.

func (FloatPrecisionType) ToStableHLO

func (f FloatPrecisionType) ToStableHLO() string

type RNGBitGeneratorAlgorithm added in v0.2.0

type RNGBitGeneratorAlgorithm int

RNGBitGeneratorAlgorithm selects the algorithm used by the RngBitGenerator operation.

const (
	RNGDefault RNGBitGeneratorAlgorithm = iota
	RNGPhilox
	RNGThreeFry
)

func RNGBitGeneratorAlgorithmString added in v0.2.0

func RNGBitGeneratorAlgorithmString(s string) (RNGBitGeneratorAlgorithm, error)

RNGBitGeneratorAlgorithmString returns the enum value whose constant name is s. It returns an error if s is not the name of an enum constant.

func RNGBitGeneratorAlgorithmValues added in v0.2.0

func RNGBitGeneratorAlgorithmValues() []RNGBitGeneratorAlgorithm

RNGBitGeneratorAlgorithmValues returns all values of the enum.

func (RNGBitGeneratorAlgorithm) IsARNGBitGeneratorAlgorithm added in v0.2.0

func (i RNGBitGeneratorAlgorithm) IsARNGBitGeneratorAlgorithm() bool

IsARNGBitGeneratorAlgorithm reports whether the value is one of the constants in the enum definition.

func (RNGBitGeneratorAlgorithm) String added in v0.2.0

func (i RNGBitGeneratorAlgorithm) String() string

Directories

Path Synopsis
Package shapes defines Shape and DType and associated tools.
Package shardy provides the types needed to define a distributed computation topology.
