Documentation
Index
- func ChannelTypeStrings() []string
- func ComparisonDirectionStrings() []string
- func ComparisonTypeStrings() []string
- func DotGeneralPrecisionTypeStrings() []string
- func FFTTypeStrings() []string
- func RNGBitGeneratorAlgorithmStrings() []string
- type ChannelType
- type CollectiveConfig
- type ComparisonDirection
- type ComparisonType
- type ConvolveAxesConfig
- type DotGeneralAlgorithm
- type DotGeneralPrecisionType
- type FFTType
- type FloatPrecisionType
- type RNGBitGeneratorAlgorithm
Constants
This section is empty.
Variables
This section is empty.
Functions
func ChannelTypeStrings added in v0.1.0
func ChannelTypeStrings() []string
ChannelTypeStrings returns a slice of all String values of the enum.
func ComparisonDirectionStrings
func ComparisonDirectionStrings() []string
ComparisonDirectionStrings returns a slice of all String values of the enum.
func ComparisonTypeStrings
func ComparisonTypeStrings() []string
ComparisonTypeStrings returns a slice of all String values of the enum.
func DotGeneralPrecisionTypeStrings
func DotGeneralPrecisionTypeStrings() []string
DotGeneralPrecisionTypeStrings returns a slice of all String values of the enum.
func FFTTypeStrings
func FFTTypeStrings() []string
FFTTypeStrings returns a slice of all String values of the enum.
func RNGBitGeneratorAlgorithmStrings added in v0.2.0
func RNGBitGeneratorAlgorithmStrings() []string
RNGBitGeneratorAlgorithmStrings returns a slice of all String values of the enum.
Types
type ChannelType added in v0.1.0
type ChannelType int
ChannelType defines the communication dimension for a collective op. It is declared as int; its values are meant to match the i64 type used in the StableHLO spec.
const (
// CrossReplica communicates across replicas (data parallelism).
// This is the default.
CrossReplica ChannelType = 0
// CrossPartition communicates across partitions (model parallelism).
CrossPartition ChannelType = 1
)
func ChannelTypeString added in v0.1.0
func ChannelTypeString(s string) (ChannelType, error)
ChannelTypeString retrieves an enum value from the string name of an enum constant. It returns an error if the name is not part of the enum.
func ChannelTypeValues added in v0.1.0
func ChannelTypeValues() []ChannelType
ChannelTypeValues returns all values of the enum.
func (ChannelType) IsAChannelType added in v0.1.0
func (i ChannelType) IsAChannelType() bool
IsAChannelType reports whether the value is listed in the enum definition.
func (ChannelType) String added in v0.1.0
func (i ChannelType) String() string
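The generated helpers above form a round trip between values and names. As a rough sketch, assuming they behave like typical enumer-style generated code, a hand-written local mirror looks like the following. channelType, channelTypeString and isAChannelType are hypothetical names (the real package is not imported), and the fallback format ChannelType(7) for unknown values is also an assumption.

```go
package main

import "fmt"

// channelType is a hypothetical local mirror of ChannelType.
type channelType int

const (
	crossReplica   channelType = 0
	crossPartition channelType = 1
)

var channelTypeNames = map[channelType]string{
	crossReplica:   "CrossReplica",
	crossPartition: "CrossPartition",
}

// String returns the constant name, or a numeric fallback for unknown values.
func (c channelType) String() string {
	if name, ok := channelTypeNames[c]; ok {
		return name
	}
	return fmt.Sprintf("ChannelType(%d)", int(c))
}

// isAChannelType reports whether c is one of the declared constants.
func (c channelType) isAChannelType() bool {
	_, ok := channelTypeNames[c]
	return ok
}

// channelTypeString parses a constant name back into its value.
func channelTypeString(s string) (channelType, error) {
	for v, name := range channelTypeNames {
		if name == s {
			return v, nil
		}
	}
	return 0, fmt.Errorf("%s does not belong to ChannelType values", s)
}

func main() {
	v, err := channelTypeString("CrossPartition")
	fmt.Println(v, int(v), err)                  // CrossPartition 1 <nil>
	fmt.Println(channelType(7).isAChannelType()) // false
	fmt.Println(channelType(7))                  // ChannelType(7)
}
```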
type CollectiveConfig added in v0.1.0
type CollectiveConfig struct {
// ChannelType specifies the communication dimension.
// Defaults to CrossReplica (0).
ChannelType ChannelType
// ChannelID, if non-nil, forces a specific channel ID (the 'handle').
// If nil, a unique ID will be automatically generated.
// This is **required** for MPMD (multi-program, multi-data) to manually link ops across programs.
ChannelID *int
// UseGlobalDeviceIDs changes the interpretation of replica_groups
// from replica IDs to global device IDs.
// This only applies to AllReduce, not CollectiveBroadcast.
// Defaults to false.
UseGlobalDeviceIDs bool
}
CollectiveConfig provides advanced, optional configuration for collective operations. Pass this as the last (optional) argument to collective ops.
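A minimal sketch of the "optional last argument" pattern, assuming it is a trailing variadic parameter. collectiveConfig, allReduce and the nextChannelID counter are hypothetical stand-ins, not the package's API; the sketch only shows how a nil ChannelID falls back to an auto-generated ID while a non-nil one is used as-is.

```go
package main

import "fmt"

// collectiveConfig is a hypothetical local mirror of CollectiveConfig.
type collectiveConfig struct {
	ChannelType        int  // 0 = CrossReplica (default), 1 = CrossPartition
	ChannelID          *int // nil means "generate a unique ID"
	UseGlobalDeviceIDs bool
}

// nextChannelID is a hypothetical counter for auto-generated channel IDs.
var nextChannelID = 1

// allReduce is a hypothetical collective op that takes the config as an
// optional trailing (variadic) argument and resolves its channel ID.
func allReduce(operand string, config ...collectiveConfig) string {
	var cfg collectiveConfig // zero value: CrossReplica, auto ID, replica IDs
	if len(config) > 0 {
		cfg = config[0]
	}
	id := nextChannelID
	if cfg.ChannelID != nil {
		id = *cfg.ChannelID // forced ID, e.g. to link ops across MPMD programs
	} else {
		nextChannelID++
	}
	return fmt.Sprintf("all_reduce(%s) channel_id=%d type=%d global_ids=%t",
		operand, id, cfg.ChannelType, cfg.UseGlobalDeviceIDs)
}

func main() {
	fmt.Println(allReduce("x")) // uses an auto-generated ID
	fixed := 42
	fmt.Println(allReduce("y", collectiveConfig{ChannelType: 1, ChannelID: &fixed}))
}
```

Using a pointer for ChannelID lets the zero value of the struct mean "generate an ID" while still allowing an explicit ID of 0.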
type ComparisonDirection
type ComparisonDirection int
ComparisonDirection is the enum for the direction of the Compare op.
const (
CompareEQ ComparisonDirection = iota
CompareGE
CompareGT
CompareLE
CompareLT
CompareNE
)
func ComparisonDirectionString
func ComparisonDirectionString(s string) (ComparisonDirection, error)
ComparisonDirectionString retrieves an enum value from the string name of an enum constant. It returns an error if the name is not part of the enum.
func ComparisonDirectionValues
func ComparisonDirectionValues() []ComparisonDirection
ComparisonDirectionValues returns all values of the enum.
func (ComparisonDirection) IsAComparisonDirection
func (i ComparisonDirection) IsAComparisonDirection() bool
IsAComparisonDirection reports whether the value is listed in the enum definition.
func (ComparisonDirection) String
func (i ComparisonDirection) String() string
func (ComparisonDirection) ToStableHLO
func (c ComparisonDirection) ToStableHLO() string
type ComparisonType
type ComparisonType int
ComparisonType is the enum for the comparison type of the Compare op.
const (
// CompareFloat is used for floating-point comparisons.
CompareFloat ComparisonType = iota
// CompareTotalOrder enforces a total order on floating-point values:
// `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
CompareTotalOrder
CompareSigned
CompareUnsigned
)
func ComparisonTypeString
func ComparisonTypeString(s string) (ComparisonType, error)
ComparisonTypeString retrieves an enum value from the string name of an enum constant. It returns an error if the name is not part of the enum.
func ComparisonTypeValues
func ComparisonTypeValues() []ComparisonType
ComparisonTypeValues returns all values of the enum.
func (ComparisonType) IsAComparisonType
func (i ComparisonType) IsAComparisonType() bool
IsAComparisonType reports whether the value is listed in the enum definition.
func (ComparisonType) String
func (i ComparisonType) String() string
func (ComparisonType) ToStableHLO
func (c ComparisonType) ToStableHLO() string
ToStableHLO returns the StableHLO representation of the comparison type.
type ConvolveAxesConfig
type ConvolveAxesConfig struct {
InputBatch, InputChannels int
InputSpatial []int
KernelInputChannels, KernelOutputChannels int
KernelSpatial []int
OutputBatch, OutputChannels int
OutputSpatial []int
}
ConvolveAxesConfig defines how the axes of the input, kernel, and output tensors are interpreted. All 3 tensors must have the same number of spatial dimensions (axes). The input and output have batch and channel axes; the kernel has input-channel and output-channel axes (KernelInputChannels and KernelOutputChannels).
See Builder.ConvGeneral.
func (ConvolveAxesConfig) Clone
func (c ConvolveAxesConfig) Clone() ConvolveAxesConfig
Clone returns a deep copy of the structure.
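As an example of the axis layout, a 2D convolution with NHWC input/output and an HWIO kernel would use the indices below. The sketch redefines the struct locally (as convolveAxesConfig) so it runs without the package; its clone shows what a deep copy has to do for the slice fields, and sameSpatialRank is a hypothetical helper checking the documented constraint on spatial dimensions.

```go
package main

import "fmt"

// convolveAxesConfig is a local copy of the ConvolveAxesConfig fields.
type convolveAxesConfig struct {
	InputBatch, InputChannels                 int
	InputSpatial                              []int
	KernelInputChannels, KernelOutputChannels int
	KernelSpatial                             []int
	OutputBatch, OutputChannels               int
	OutputSpatial                             []int
}

// clone returns a deep copy: the spatial slices get fresh backing arrays,
// so mutating the copy cannot affect the original.
func (c convolveAxesConfig) clone() convolveAxesConfig {
	c2 := c // copies the int fields; the slices still share storage here
	c2.InputSpatial = append([]int(nil), c.InputSpatial...)
	c2.KernelSpatial = append([]int(nil), c.KernelSpatial...)
	c2.OutputSpatial = append([]int(nil), c.OutputSpatial...)
	return c2
}

// sameSpatialRank checks that all 3 tensors have the same number of spatial axes.
func (c convolveAxesConfig) sameSpatialRank() bool {
	n := len(c.InputSpatial)
	return len(c.KernelSpatial) == n && len(c.OutputSpatial) == n
}

func main() {
	// NHWC input/output, HWIO kernel.
	nhwc := convolveAxesConfig{
		InputBatch: 0, InputSpatial: []int{1, 2}, InputChannels: 3,
		KernelSpatial: []int{0, 1}, KernelInputChannels: 2, KernelOutputChannels: 3,
		OutputBatch: 0, OutputSpatial: []int{1, 2}, OutputChannels: 3,
	}
	c := nhwc.clone()
	c.InputSpatial[0] = 99
	fmt.Println(nhwc.InputSpatial, c.InputSpatial, nhwc.sameSpatialRank()) // [1 2] [99 2] true
}
```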
type DotGeneralAlgorithm
type DotGeneralAlgorithm struct {
// LhsPrecisionType and RhsPrecisionType are the precisions that the LHS and RHS of the operation are rounded to.
// Precision types are independent of the storage types of the inputs and the output.
LhsPrecisionType, RhsPrecisionType FloatPrecisionType
// AccumulationType defines the type of the accumulator used for the dot product.
AccumulationType FloatPrecisionType
// LhsComponentCount, RhsComponentCount and NumPrimitiveOperations apply to algorithms that
// decompose the LHS and/or RHS into multiple components and perform multiple "primitive" dot operations
// on those values, usually to emulate a higher precision (e.g. "Leveraging the bfloat16 Artificial
// Intelligence Datatype For Higher-Precision Computations": bf16_6x, tf32_3x, etc.;
// https://arxiv.org/pdf/1904.06376).
// For algorithms with no decomposition, these values should be set to 1.
LhsComponentCount, RhsComponentCount, NumPrimitiveOperations int
// AllowImpreciseAccumulation specifies whether accumulation in lower precision is permitted for some steps
// (e.g. CUBLASLT_MATMUL_DESC_FAST_ACCUM).
AllowImpreciseAccumulation bool
}
DotGeneralAlgorithm provides fine-grained control over the algorithm used for the dot product.
type DotGeneralPrecisionType
type DotGeneralPrecisionType int
DotGeneralPrecisionType defines the precision of the dot product.
It controls the tradeoff between speed and accuracy for computations on accelerator backends. It can take one of the values below. At the moment, the semantics of these values are underspecified; the StableHLO maintainers plan to address this in #755 (https://github.com/openxla/stablehlo/issues/755).
const (
// DotGeneralPrecisionDefault is the fastest calculation, but the least accurate approximation to the original number.
DotGeneralPrecisionDefault DotGeneralPrecisionType = iota
DotGeneralPrecisionHigh
DotGeneralPrecisionHighest
)
func DotGeneralPrecisionTypeString
func DotGeneralPrecisionTypeString(s string) (DotGeneralPrecisionType, error)
DotGeneralPrecisionTypeString retrieves an enum value from the string name of an enum constant. It returns an error if the name is not part of the enum.
func DotGeneralPrecisionTypeValues
func DotGeneralPrecisionTypeValues() []DotGeneralPrecisionType
DotGeneralPrecisionTypeValues returns all values of the enum.
func (DotGeneralPrecisionType) IsADotGeneralPrecisionType
func (i DotGeneralPrecisionType) IsADotGeneralPrecisionType() bool
IsADotGeneralPrecisionType reports whether the value is listed in the enum definition.
func (DotGeneralPrecisionType) String
func (i DotGeneralPrecisionType) String() string
func (DotGeneralPrecisionType) ToStableHLO
func (p DotGeneralPrecisionType) ToStableHLO() string
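For reference, the StableHLO spec spells these precisions DEFAULT, HIGH and HIGHEST, and dot_general's precision_config attribute takes one entry per operand. The sketch assumes a mapping from a hypothetical local mirror of the enum (precision, stableHLOName, precisionConfig are illustrative names) to those spellings; the exact string returned by the real ToStableHLO is not documented here and may differ in format.

```go
package main

import "fmt"

// precision is a hypothetical local mirror of DotGeneralPrecisionType,
// declared in the same order as the documented constants.
type precision int

const (
	precisionDefault precision = iota
	precisionHigh
	precisionHighest
)

// stableHLOName maps each value to the StableHLO precision enum spelling.
func (p precision) stableHLOName() string {
	switch p {
	case precisionDefault:
		return "DEFAULT"
	case precisionHigh:
		return "HIGH"
	case precisionHighest:
		return "HIGHEST"
	}
	panic(fmt.Sprintf("unknown precision %d", int(p)))
}

// precisionConfig formats a dot_general precision_config attribute,
// which holds one precision per operand (LHS, RHS).
func precisionConfig(lhs, rhs precision) string {
	return fmt.Sprintf("precision_config = [#stablehlo<precision %s>, #stablehlo<precision %s>]",
		lhs.stableHLOName(), rhs.stableHLOName())
}

func main() {
	fmt.Println(precisionConfig(precisionDefault, precisionHighest))
}
```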
type FFTType
type FFTType int
FFTType defines the type of the FFT operation. See FFT.
func FFTTypeString
FFTTypeString retrieves an enum value from the string name of an enum constant. It returns an error if the name is not part of the enum.
func (FFTType) IsAFFTType
IsAFFTType reports whether the value is listed in the enum definition.
func (FFTType) ToStableHLO
ToStableHLO returns the StableHLO representation of the FFT type.
type FloatPrecisionType
type FloatPrecisionType struct {
// TF32 is used for the TF32 precision type.
TF32 bool
// DType is used for non-TF32 precision types.
// It must be a float type.
DType dtypes.DType
}
FloatPrecisionType defines the precision used during floating-point operations. In particular, modern GPUs accept the TF32 type, which sacrifices some accuracy for significant speed improvements.
func (FloatPrecisionType) ToStableHLO
func (f FloatPrecisionType) ToStableHLO() string
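TF32 keeps float32's 8-bit exponent but only 10 explicit mantissa bits (float32 has 23). A minimal emulation, assuming truncation of the 13 extra bits (hardware typically rounds instead), shows the accuracy given up while the range is preserved; toTF32 is a hypothetical helper, not part of the package.

```go
package main

import (
	"fmt"
	"math"
)

// toTF32 emulates TF32 precision by clearing the 13 low mantissa bits of a
// float32, leaving 10 explicit mantissa bits and the full 8-bit exponent.
func toTF32(x float32) float32 {
	return math.Float32frombits(math.Float32bits(x) &^ 0x1FFF)
}

func main() {
	x := float32(1.0 / 3.0)
	fmt.Println(x, toTF32(x))           // TF32 keeps only about 3 decimal digits
	fmt.Println(toTF32(1e30) > 0.99e30) // true: the float32 exponent range is kept
}
```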
type RNGBitGeneratorAlgorithm added in v0.2.0
type RNGBitGeneratorAlgorithm int
RNGBitGeneratorAlgorithm selects the algorithm used by the RngBitGenerator operation.
const (
RNGDefault RNGBitGeneratorAlgorithm = iota
RNGPhilox
RNGThreeFry
)
func RNGBitGeneratorAlgorithmString added in v0.2.0
func RNGBitGeneratorAlgorithmString(s string) (RNGBitGeneratorAlgorithm, error)
RNGBitGeneratorAlgorithmString retrieves an enum value from the string name of an enum constant. It returns an error if the name is not part of the enum.
func RNGBitGeneratorAlgorithmValues added in v0.2.0
func RNGBitGeneratorAlgorithmValues() []RNGBitGeneratorAlgorithm
RNGBitGeneratorAlgorithmValues returns all values of the enum.
func (RNGBitGeneratorAlgorithm) IsARNGBitGeneratorAlgorithm added in v0.2.0
func (i RNGBitGeneratorAlgorithm) IsARNGBitGeneratorAlgorithm() bool
IsARNGBitGeneratorAlgorithm reports whether the value is listed in the enum definition.
func (RNGBitGeneratorAlgorithm) String added in v0.2.0
func (i RNGBitGeneratorAlgorithm) String() string
Directories
| Path | Synopsis |
|---|---|
| shapes | Package shapes defines Shape and DType and associated tools. |
| shardy | Package shardy provides the types needed to define a distributed computation topology. |