stablehlo

package
v1.3.0
Warning

This package is not in the latest version of its module.

Published: Apr 3, 2026 License: Apache-2.0 Imports: 4 Imported by: 0

Overview

Package stablehlo generates StableHLO MLIR text for compilation by the PJRT backend.

It provides type mapping (Go type names to MLIR dtype strings), SSA value naming (%v0, %v1, ...), tensor type formatting (tensor<2x3x4xf32>), StableHLO operation name constants, and emitters for single operations and complete programs.

Constants

const (
	DTypeF32  = "f32"
	DTypeF64  = "f64"
	DTypeF16  = "f16"
	DTypeBF16 = "bf16"
	DTypeF8   = "f8E4M3FN"
	DTypeI8   = "i8"
	DTypeI16  = "i16"
	DTypeI32  = "i32"
	DTypeI64  = "i64"
	DTypeUI8  = "ui8"
	DTypeUI32 = "ui32"
	DTypeUI64 = "ui64"
	DTypeBool = "i1"
)

MLIR dtype strings for StableHLO tensor types.

const (
	OpAdd         = "stablehlo.add"
	OpSubtract    = "stablehlo.subtract"
	OpMultiply    = "stablehlo.multiply"
	OpDivide      = "stablehlo.divide"
	OpDotGeneral  = "stablehlo.dot_general"
	OpTranspose   = "stablehlo.transpose"
	OpReshape     = "stablehlo.reshape"
	OpBroadcastIn = "stablehlo.broadcast_in_dim"
	OpReduce      = "stablehlo.reduce"
	OpGather      = "stablehlo.gather"
	OpSlice       = "stablehlo.slice"
	OpConcatenate = "stablehlo.concatenate"
	OpExp         = "stablehlo.exponential"
	OpLog         = "stablehlo.log"
	OpSin         = "stablehlo.sine"
	OpCos         = "stablehlo.cosine"
	OpTanh        = "stablehlo.tanh"
	OpNegate      = "stablehlo.negate"
	OpAbs         = "stablehlo.abs"
	OpSqrt        = "stablehlo.sqrt"
	OpRsqrt       = "stablehlo.rsqrt"
	OpMaximum     = "stablehlo.maximum"
	OpMinimum     = "stablehlo.minimum"
	OpClamp       = "stablehlo.clamp"
	OpSelect      = "stablehlo.select"
	OpCompare     = "stablehlo.compare"
	OpConvert     = "stablehlo.convert"
	OpConstant    = "stablehlo.constant"
	OpIota        = "stablehlo.iota"
	OpPower       = "stablehlo.power"
)

StableHLO operation name constants.

Variables

This section is empty.

Functions

func EmitConcat

func EmitConcat(namer *SSANamer, operands []string, shapes [][]int, axis int, dtype string) (string, string, error)

EmitConcat emits a stablehlo.concatenate operation along the given axis. operands holds the SSA names of the inputs, and shapes holds their tensor shapes in the same order.

func EmitGather

func EmitGather(namer *SSANamer, operand, indices string,
	operandShape, indicesShape, sliceSizes []int,
	offsetDims, collapsedSliceDims, startIndexMap []int,
	indexVectorDim int,
	dtype string,
) (string, string, error)

EmitGather emits a stablehlo.gather operation. operandShape is the shape of the data tensor, and indicesShape is the shape of the index tensor. sliceSizes gives the size of each gathered slice. offsetDims, collapsedSliceDims, and startIndexMap are the gather dimension numbers, and indexVectorDim is the dimension of the indices tensor that holds each index vector.

func EmitKVCacheProgram

func EmitKVCacheProgram(ops []ProgramOp, inputSlots []int, inputShapes [][]int, kvSlots []KVCacheSlot, dtype string, decode bool) (string, error)

EmitKVCacheProgram emits a StableHLO program with explicit KV cache I/O.

KV cache tensors are added as both function arguments and return values. The function signature becomes:

func.func @main(%regular_args..., %kv_in_0, %kv_in_1, ...) ->
    (regular_output, %kv_out_0, %kv_out_1, ...)

For decode programs (decode=true), each KV cache output is produced by concatenating the KV input with the new KV step along the sequence axis:

kv_out = concat(kv_in, kv_step, axis=seq_axis)

For prefill programs (decode=false), KV cache outputs are passed through directly from the ops that produce them.

func EmitMatMul

func EmitMatMul(namer *SSANamer, lhs, rhs string, lhsShape, rhsShape []int, dtype string) (string, string, error)

EmitMatMul emits a stablehlo.dot_general operation for matrix multiplication. Handles 2D (MxK @ KxN) and batched (BxMxK @ BxKxN) cases. Returns the MLIR line and the SSA name assigned to the result.

func EmitProgram

func EmitProgram(ops []ProgramOp, inputSlots []int, inputShapes [][]int, dtype string) (string, error)

EmitProgram takes a sequence of operations and produces a complete StableHLO MLIR module. inputSlots identifies which slots are function arguments, and inputShapes provides their shapes. The last operation's output slot is used as the function return value.

Slot indices map to SSA names: function arguments get %arg0, %arg1, etc. Intermediate results get %v0, %v1, etc. via the Emitter's SSANamer.
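The slot-to-name mapping can be sketched as follows, under one simplifying assumption: every op produces exactly one SSA value. Ops that expand into several instructions, such as the scalar ops or Softmax, consume more names from the SSANamer, so real %vN numbers will not line up one-to-one with op indices. The op type and assignNames are hypothetical.

```go
package main

import "fmt"

// op is a pared-down stand-in for ProgramOp, keeping only the slot fields.
type op struct {
	Name       string
	InputSlots []int
	OutputSlot int
}

// assignNames gives argument slots %arg0, %arg1, ... in inputSlots order and each
// op's output slot the next %vN. It also checks that every input slot is defined
// before it is read.
func assignNames(ops []op, inputSlots []int) (map[int]string, error) {
	names := make(map[int]string)
	for i, s := range inputSlots {
		names[s] = fmt.Sprintf("%%arg%d", i)
	}
	for i, o := range ops {
		for _, in := range o.InputSlots {
			if _, ok := names[in]; !ok {
				return nil, fmt.Errorf("op %d (%s) reads slot %d before it is defined", i, o.Name, in)
			}
		}
		names[o.OutputSlot] = fmt.Sprintf("%%v%d", i) // simplification: one SSA value per op
	}
	return names, nil
}

func main() {
	ops := []op{
		{Name: "MatMul", InputSlots: []int{0, 1}, OutputSlot: 2},
		{Name: "Tanh", InputSlots: []int{2}, OutputSlot: 3},
	}
	names, err := assignNames(ops, []int{0, 1})
	fmt.Println(names, err)
	// The last op's output slot is the function's return value.
	fmt.Println("return:", names[ops[len(ops)-1].OutputSlot])
}
```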

func EmitReduceMax

func EmitReduceMax(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, dtype string) (string, string)

EmitReduceMax emits a StableHLO reduce with a maximum body.

func EmitReduceMean

func EmitReduceMean(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, dtype string) (string, string)

EmitReduceMean emits a ReduceSum followed by a DivScalar to compute the mean. Returns the final result SSA name and the emitted MLIR text for both ops.

func EmitReduceSum

func EmitReduceSum(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, dtype string) (string, string)

EmitReduceSum emits a StableHLO reduce with an add body.

The generated MLIR has the form:

%result = "stablehlo.reduce"(%input, %init) ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
  %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
  stablehlo.return %0 : tensor<f32>
}) {dimensions = array<i64: axis>} : (inputType, tensor<dtype>) -> outputType

func EmitReshape

func EmitReshape(namer *SSANamer, operand string, inShape, targetShape []int, dtype string) (string, string, error)

EmitReshape emits a stablehlo.reshape operation. targetShape is the desired output shape.

func EmitSlice

func EmitSlice(namer *SSANamer, operand string, shape, start, limit, strides []int, dtype string) (string, string, error)

EmitSlice emits a stablehlo.slice operation with start, limit, and stride indices. strides may be nil, in which case all strides default to 1.

func EmitSoftmax

func EmitSoftmax(namer *SSANamer, input string, inputShape []int, axis int, dtype string) (string, string)

EmitSoftmax decomposes Softmax into five steps. Subtracting the maximum before exponentiating keeps Exp from overflowing for large inputs:

  1. max = ReduceMax(input, axis, keepDims=true)
  2. shifted = Sub(input, max)
  3. exp = Exp(shifted)
  4. sum = ReduceSum(exp, axis, keepDims=true)
  5. result = Div(exp, sum)

Returns the final result SSA name and the emitted MLIR text.

func EmitTranspose

func EmitTranspose(namer *SSANamer, operand string, shape []int, perm []int, dtype string) (string, string, error)

EmitTranspose emits a stablehlo.transpose operation. perm specifies the axis permutation (e.g., [2, 0, 1]).

func FormatScalarType

func FormatScalarType(dtype string) string

FormatScalarType returns the MLIR scalar type string for a dtype. Example: FormatScalarType("f32") returns "f32".

func FormatTensorType

func FormatTensorType(shape []int, dtype string) string

FormatTensorType formats an MLIR tensor type string from a shape and dtype. Example: FormatTensorType([]int{2, 3, 4}, "f32") returns "tensor<2x3x4xf32>". For scalar tensors (empty shape), it returns "tensor<f32>".
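Here is a minimal sketch that reproduces the documented examples. formatTensorType is a stand-in written for illustration, not the package source.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// formatTensorType joins the dimensions and the dtype with "x" inside tensor<...>.
// An empty shape leaves only the dtype, giving tensor<f32>.
func formatTensorType(shape []int, dtype string) string {
	parts := make([]string, 0, len(shape)+1)
	for _, d := range shape {
		parts = append(parts, strconv.Itoa(d))
	}
	parts = append(parts, dtype)
	return "tensor<" + strings.Join(parts, "x") + ">"
}

func main() {
	fmt.Println(formatTensorType([]int{2, 3, 4}, "f32"), formatTensorType(nil, "f32"))
}
```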

func GoDTypeToMLIR

func GoDTypeToMLIR(goType string) (string, bool)

GoDTypeToMLIR maps a Go reflect type name to an MLIR dtype string. Supported mappings:

float32  -> f32
float64  -> f64
float16  -> f16
bfloat16 -> bf16
float8   -> f8E4M3FN
int8     -> i8
int16    -> i16
int32    -> i32
int64    -> i64
uint8    -> ui8
uint32   -> ui32
uint64   -> ui64

Returns the MLIR dtype string and true if the mapping exists, or ("", false) otherwise.
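The following table-driven sketch implements the same mapping. goDTypeToMLIR is a local stand-in, and unlisted names fall through to ("", false) as documented. float16, bfloat16, and float8 are not predeclared Go types, so those keys presumably come from type names defined in other packages.

```go
package main

import "fmt"

// goToMLIR copies the documented mapping table.
var goToMLIR = map[string]string{
	"float32":  "f32",
	"float64":  "f64",
	"float16":  "f16",
	"bfloat16": "bf16",
	"float8":   "f8E4M3FN",
	"int8":     "i8",
	"int16":    "i16",
	"int32":    "i32",
	"int64":    "i64",
	"uint8":    "ui8",
	"uint32":   "ui32",
	"uint64":   "ui64",
}

// goDTypeToMLIR looks up goType; a missing key yields the zero value "" and false.
func goDTypeToMLIR(goType string) (string, bool) {
	s, ok := goToMLIR[goType]
	return s, ok
}

func main() {
	fmt.Println(goDTypeToMLIR("float32"))
}
```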

func InferShape

func InferShape(opName string, inputShapes [][]int, attrs map[string]any) ([]int, error)

InferShape computes the output shape for a given operation name, input shapes, and optional attributes. It returns an error if the shapes are incompatible.

func InferStructuralShape

func InferStructuralShape(opName string, inputShapes [][]int, attrs map[string]any) ([]int, error)

InferStructuralShape computes the output shape for structural operations: MatMul, Transpose, Reshape, Concat, Slice, Gather, ReduceSum, ReduceMax, ReduceMean.

attrs supports the following keys depending on the operation:

  • "perm" ([]int): axis permutation for Transpose
  • "shape" ([]int): target shape for Reshape
  • "axis" (int): concatenation axis for Concat, reduction axis for Reduce*
  • "start" ([]int): start indices for Slice
  • "end" ([]int): end indices for Slice
  • "sliceSizes" ([]int): slice sizes for Gather
  • "keepDims" (bool): whether to keep the reduced dimension for Reduce*

Types

type Emitter

type Emitter struct {
	Namer *SSANamer
}

Emitter generates StableHLO MLIR text from operation inputs. Each emit method takes SSA input names, tensor shapes, and a dtype, and returns the emitted MLIR line(s) plus the output SSA name.

func NewEmitter

func NewEmitter() *Emitter

NewEmitter creates an Emitter with a fresh SSANamer.

func (*Emitter) EmitAdd

func (e *Emitter) EmitAdd(lhs, rhs string, shape []int, dtype string) (string, string)

EmitAdd emits stablehlo.add.

func (*Emitter) EmitAddScalar

func (e *Emitter) EmitAddScalar(input string, scalar float64, shape []int, dtype string) (string, string)

EmitAddScalar emits stablehlo.constant + broadcast_in_dim + add.

func (*Emitter) EmitBinaryElementwise

func (e *Emitter) EmitBinaryElementwise(opName, lhs, rhs string, shape []int, dtype string) (mlir, outName string)

EmitBinaryElementwise emits a binary element-wise op (add, subtract, multiply, divide, power). Both inputs must have the same shape and dtype.

func (*Emitter) EmitCos

func (e *Emitter) EmitCos(input string, shape []int, dtype string) (string, string)

EmitCos emits stablehlo.cosine.

func (*Emitter) EmitDiv

func (e *Emitter) EmitDiv(lhs, rhs string, shape []int, dtype string) (string, string)

EmitDiv emits stablehlo.divide.

func (*Emitter) EmitDivScalar

func (e *Emitter) EmitDivScalar(input string, scalar float64, shape []int, dtype string) (string, string)

EmitDivScalar emits stablehlo.constant + broadcast_in_dim + divide.

func (*Emitter) EmitExp

func (e *Emitter) EmitExp(input string, shape []int, dtype string) (string, string)

EmitExp emits stablehlo.exponential.

func (*Emitter) EmitLog

func (e *Emitter) EmitLog(input string, shape []int, dtype string) (string, string)

EmitLog emits stablehlo.log.

func (*Emitter) EmitMul

func (e *Emitter) EmitMul(lhs, rhs string, shape []int, dtype string) (string, string)

EmitMul emits stablehlo.multiply.

func (*Emitter) EmitMulScalar

func (e *Emitter) EmitMulScalar(input string, scalar float64, shape []int, dtype string) (string, string)

EmitMulScalar emits stablehlo.constant + broadcast_in_dim + multiply.

func (*Emitter) EmitNeg

func (e *Emitter) EmitNeg(input string, shape []int, dtype string) (string, string)

EmitNeg emits stablehlo.negate.

func (*Emitter) EmitOp

func (e *Emitter) EmitOp(opName string, inputs []string, shape []int, dtype string, attrs map[string]any) (string, string, error)

EmitOp dispatches to the appropriate emit function based on the engine op name. For binary ops, inputs should be [lhs, rhs]. For unary ops, inputs should be [input]. For scalar ops, inputs should be [input] and attrs must contain "scalar" (float64). Returns the emitted MLIR text and the output SSA name.
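The input and attribute contract can be sketched as a validation step that runs before any MLIR is emitted. The engine op names here (Add, MulScalar, and so on) are guesses derived from the Emitter method names and may not match the real engine. checkOp is hypothetical.

```go
package main

import (
	"errors"
	"fmt"
)

// Hypothetical engine op names, guessed from the Emitter method names.
var (
	binaryOps = map[string]string{
		"Add": "stablehlo.add", "Sub": "stablehlo.subtract", "Mul": "stablehlo.multiply",
		"Div": "stablehlo.divide", "Pow": "stablehlo.power",
	}
	unaryOps = map[string]string{
		"Exp": "stablehlo.exponential", "Log": "stablehlo.log", "Sin": "stablehlo.sine",
		"Cos": "stablehlo.cosine", "Tanh": "stablehlo.tanh", "Sqrt": "stablehlo.sqrt",
		"Rsqrt": "stablehlo.rsqrt", "Neg": "stablehlo.negate",
	}
	scalarOps = map[string]string{
		"AddScalar": "stablehlo.add", "MulScalar": "stablehlo.multiply", "DivScalar": "stablehlo.divide",
	}
)

// checkOp applies the EmitOp input contract and returns the StableHLO op to
// emit, plus the scalar operand for scalar ops.
func checkOp(opName string, inputs []string, attrs map[string]any) (string, float64, error) {
	if op, ok := binaryOps[opName]; ok {
		if len(inputs) != 2 {
			return "", 0, fmt.Errorf("%s: want [lhs, rhs], got %d inputs", opName, len(inputs))
		}
		return op, 0, nil
	}
	if op, ok := unaryOps[opName]; ok {
		if len(inputs) != 1 {
			return "", 0, fmt.Errorf("%s: want [input], got %d inputs", opName, len(inputs))
		}
		return op, 0, nil
	}
	if op, ok := scalarOps[opName]; ok {
		if len(inputs) != 1 {
			return "", 0, fmt.Errorf("%s: want [input], got %d inputs", opName, len(inputs))
		}
		scalar, isFloat := attrs["scalar"].(float64) // fails for a missing key or an int value
		if !isFloat {
			return "", 0, errors.New(opName + `: attrs["scalar"] must be a float64`)
		}
		return op, scalar, nil
	}
	return "", 0, fmt.Errorf("unsupported op %q", opName)
}

func main() {
	fmt.Println(checkOp("MulScalar", []string{"%v3"}, map[string]any{"scalar": 0.5}))
}
```

The float64 requirement matters in practice. Storing the untyped constant 2 in a map[string]any produces an int, so the type assertion fails. Write 2.0 instead.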

func (*Emitter) EmitPow

func (e *Emitter) EmitPow(lhs, rhs string, shape []int, dtype string) (string, string)

EmitPow emits stablehlo.power.

func (*Emitter) EmitRsqrt

func (e *Emitter) EmitRsqrt(input string, shape []int, dtype string) (string, string)

EmitRsqrt emits stablehlo.rsqrt.

func (*Emitter) EmitScalarOp

func (e *Emitter) EmitScalarOp(elemOp, input string, scalar float64, shape []int, dtype string) (mlir, outName string)

EmitScalarOp emits a scalar operation as three MLIR instructions:

  1. stablehlo.constant for the scalar value
  2. stablehlo.broadcast_in_dim to broadcast to the tensor shape
  3. The element-wise binary op

Returns all three lines (newline-separated) and the final output SSA name.

func (*Emitter) EmitSin

func (e *Emitter) EmitSin(input string, shape []int, dtype string) (string, string)

EmitSin emits stablehlo.sine.

func (*Emitter) EmitSqrt

func (e *Emitter) EmitSqrt(input string, shape []int, dtype string) (string, string)

EmitSqrt emits stablehlo.sqrt.

func (*Emitter) EmitSub

func (e *Emitter) EmitSub(lhs, rhs string, shape []int, dtype string) (string, string)

EmitSub emits stablehlo.subtract.

func (*Emitter) EmitTanh

func (e *Emitter) EmitTanh(input string, shape []int, dtype string) (string, string)

EmitTanh emits stablehlo.tanh.

func (*Emitter) EmitUnary

func (e *Emitter) EmitUnary(opName, input string, shape []int, dtype string) (mlir, outName string)

EmitUnary emits a unary element-wise op (exponential, log, sine, cosine, tanh, sqrt, rsqrt, negate).

type KVCacheSlot

type KVCacheSlot struct {
	InputSlot  int   // slot index where the KV cache is read (becomes a function arg)
	OutputSlot int   // slot index where the updated KV cache is produced (becomes a return value)
	Shape      []int // tensor shape (e.g., [num_heads, seq_len, head_dim])
	SeqAxis    int   // axis along which decode concatenation occurs
}

KVCacheSlot describes a stateful KV cache slot that must be rewritten as explicit function I/O for PJRT's pure-functional execution model.

In the original graph, the KV cache is fed back via StatefulInputNode. For PJRT, each KV cache tensor becomes both a function argument (the previous state) and a return value (the updated state).

type ProgramOp

type ProgramOp struct {
	OpName      string         // Engine method name (e.g., "Add", "MatMul", "Softmax")
	InputSlots  []int          // indices into the slot table for inputs
	OutputSlot  int            // index for the output
	InputShapes [][]int        // shapes of each input
	OutputShape []int          // shape of the output
	Dtype       string         // "f32", "f16", etc.
	Attrs       map[string]any // op-specific attributes
}

ProgramOp describes a single operation in a StableHLO program.

type SSANamer

type SSANamer struct {
	// contains filtered or unexported fields
}

SSANamer generates monotonically increasing SSA value names (%v0, %v1, ...).

func (*SSANamer) Count

func (n *SSANamer) Count() int

Count returns the current counter value (number of names issued so far).

func (*SSANamer) NextName

func (n *SSANamer) NextName() string

NextName returns the next SSA value name and advances the counter.
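Here is a minimal sketch of the documented behavior. ssaNamer is a local stand-in because the real struct's fields are unexported.

```go
package main

import "fmt"

// ssaNamer hands out %v0, %v1, ... and remembers how many names it has issued.
type ssaNamer struct{ next int }

// NextName returns the next SSA value name and advances the counter.
func (n *ssaNamer) NextName() string {
	name := fmt.Sprintf("%%v%d", n.next)
	n.next++
	return name
}

// Count reports how many names have been issued so far.
func (n *ssaNamer) Count() int { return n.next }

func main() {
	var n ssaNamer
	// Arguments are evaluated left to right, so this prints %v0 %v1 2.
	fmt.Println(n.NextName(), n.NextName(), n.Count())
}
```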
