Documentation
¶
Overview ¶
Package stablehlo generates StableHLO MLIR text for compilation by the PJRT backend.
It provides type mapping (Go type names to MLIR dtype strings), SSA value naming (%v0, %v1, ...), shape formatting (tensor<2x3x4xf32>), and StableHLO operation name constants.
Index ¶
- Constants
- func EmitConcat(namer *SSANamer, operands []string, shapes [][]int, axis int, dtype string) (string, string, error)
- func EmitGather(namer *SSANamer, operand, indices string, ...) (string, string, error)
- func EmitKVCacheProgram(ops []ProgramOp, inputSlots []int, inputShapes [][]int, kvSlots []KVCacheSlot, ...) (string, error)
- func EmitMatMul(namer *SSANamer, lhs, rhs string, lhsShape, rhsShape []int, dtype string) (string, string, error)
- func EmitProgram(ops []ProgramOp, inputSlots []int, inputShapes [][]int, dtype string) (string, error)
- func EmitReduceMax(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, ...) (string, string)
- func EmitReduceMean(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, ...) (string, string)
- func EmitReduceSum(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, ...) (string, string)
- func EmitReshape(namer *SSANamer, operand string, inShape, targetShape []int, dtype string) (string, string, error)
- func EmitSlice(namer *SSANamer, operand string, shape, start, limit, strides []int, ...) (string, string, error)
- func EmitSoftmax(namer *SSANamer, input string, inputShape []int, axis int, dtype string) (string, string)
- func EmitTranspose(namer *SSANamer, operand string, shape []int, perm []int, dtype string) (string, string, error)
- func FormatScalarType(dtype string) string
- func FormatTensorType(shape []int, dtype string) string
- func GoDTypeToMLIR(goType string) (string, bool)
- func InferShape(opName string, inputShapes [][]int, attrs map[string]any) ([]int, error)
- func InferStructuralShape(opName string, inputShapes [][]int, attrs map[string]any) ([]int, error)
- type Emitter
- func (e *Emitter) EmitAdd(lhs, rhs string, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitAddScalar(input string, scalar float64, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitBinaryElementwise(opName, lhs, rhs string, shape []int, dtype string) (mlir, outName string)
- func (e *Emitter) EmitCos(input string, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitDiv(lhs, rhs string, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitDivScalar(input string, scalar float64, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitExp(input string, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitLog(input string, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitMul(lhs, rhs string, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitMulScalar(input string, scalar float64, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitNeg(input string, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitOp(opName string, inputs []string, shape []int, dtype string, ...) (string, string, error)
- func (e *Emitter) EmitPow(lhs, rhs string, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitRsqrt(input string, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitScalarOp(elemOp, input string, scalar float64, shape []int, dtype string) (mlir, outName string)
- func (e *Emitter) EmitSin(input string, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitSqrt(input string, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitSub(lhs, rhs string, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitTanh(input string, shape []int, dtype string) (string, string)
- func (e *Emitter) EmitUnary(opName, input string, shape []int, dtype string) (mlir, outName string)
- type KVCacheSlot
- type ProgramOp
- type SSANamer
Constants ¶
const ( DTypeF32 = "f32" DTypeF64 = "f64" DTypeF16 = "f16" DTypeBF16 = "bf16" DTypeF8 = "f8E4M3FN" DTypeI8 = "i8" DTypeI16 = "i16" DTypeI32 = "i32" DTypeI64 = "i64" DTypeUI8 = "ui8" DTypeUI32 = "ui32" DTypeUI64 = "ui64" DTypeBool = "i1" )
MLIR dtype strings for StableHLO tensor types.
const ( OpAdd = "stablehlo.add" OpSubtract = "stablehlo.subtract" OpMultiply = "stablehlo.multiply" OpDivide = "stablehlo.divide" OpDotGeneral = "stablehlo.dot_general" OpTranspose = "stablehlo.transpose" OpReshape = "stablehlo.reshape" OpBroadcastIn = "stablehlo.broadcast_in_dim" OpReduce = "stablehlo.reduce" OpGather = "stablehlo.gather" OpSlice = "stablehlo.slice" OpConcatenate = "stablehlo.concatenate" OpExp = "stablehlo.exponential" OpLog = "stablehlo.log" OpSin = "stablehlo.sine" OpCos = "stablehlo.cosine" OpTanh = "stablehlo.tanh" OpNegate = "stablehlo.negate" OpAbs = "stablehlo.abs" OpSqrt = "stablehlo.sqrt" OpRsqrt = "stablehlo.rsqrt" OpMaximum = "stablehlo.maximum" OpMinimum = "stablehlo.minimum" OpClamp = "stablehlo.clamp" OpSelect = "stablehlo.select" OpCompare = "stablehlo.compare" OpConvert = "stablehlo.convert" OpConstant = "stablehlo.constant" OpIota = "stablehlo.iota" OpPower = "stablehlo.power" )
StableHLO operation name constants.
Variables ¶
This section is empty.
Functions ¶
func EmitConcat ¶
func EmitConcat(namer *SSANamer, operands []string, shapes [][]int, axis int, dtype string) (string, string, error)
EmitConcat emits a stablehlo.concatenate operation along the given axis. operands are the SSA names, shapes are the corresponding tensor shapes.
func EmitGather ¶
func EmitGather(namer *SSANamer, operand, indices string, operandShape, indicesShape, sliceSizes []int, offsetDims, collapsedSliceDims, startIndexMap []int, indexVectorDim int, dtype string, ) (string, string, error)
EmitGather emits a stablehlo.gather operation. operandShape is the shape of the data tensor, indicesShape is the shape of the index tensor. sliceSizes specifies the size of each gathered slice. offsetDims, collapsedSliceDims, startIndexMap are the gather dimension numbers. indexVectorDim is the dimension in the indices tensor that contains the index vector.
func EmitKVCacheProgram ¶
func EmitKVCacheProgram(ops []ProgramOp, inputSlots []int, inputShapes [][]int, kvSlots []KVCacheSlot, dtype string, decode bool) (string, error)
EmitKVCacheProgram emits a StableHLO program with explicit KV cache I/O.
KV cache tensors are added as both function arguments and return values. The function signature becomes:
func.func @main(%regular_args..., %kv_in_0, %kv_in_1, ...) ->
(%regular_output, %kv_out_0, %kv_out_1, ...)
For decode programs (decode=true), each KV cache output is produced by concatenating the KV input with the new KV step along the sequence axis:
kv_out = concat(kv_in, kv_step, axis=seq_axis)
For prefill programs (decode=false), KV cache outputs are passed through directly from the ops that produce them.
func EmitMatMul ¶
func EmitMatMul(namer *SSANamer, lhs, rhs string, lhsShape, rhsShape []int, dtype string) (string, string, error)
EmitMatMul emits a stablehlo.dot_general operation for matrix multiplication. Handles 2D (MxK @ KxN) and batched (BxMxK @ BxKxN) cases. Returns the MLIR line and the SSA name assigned to the result.
func EmitProgram ¶
func EmitProgram(ops []ProgramOp, inputSlots []int, inputShapes [][]int, dtype string) (string, error)
EmitProgram takes a sequence of operations and produces a complete StableHLO MLIR module. inputSlots identifies which slots are function arguments, and inputShapes provides their shapes. The last operation's output slot is used as the function return value.
Slot indices map to SSA names: function arguments get %arg0, %arg1, etc. Intermediate results get %v0, %v1, etc. via the Emitter's SSANamer.
func EmitReduceMax ¶
func EmitReduceMax(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, dtype string) (string, string)
EmitReduceMax emits a StableHLO reduce with a maximum body.
func EmitReduceMean ¶
func EmitReduceMean(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, dtype string) (string, string)
EmitReduceMean emits a ReduceSum followed by a DivScalar to compute the mean. Returns the final result SSA name and the emitted MLIR text for both ops.
func EmitReduceSum ¶
func EmitReduceSum(namer *SSANamer, input string, inputShape []int, axis int, keepDims bool, dtype string) (string, string)
EmitReduceSum emits a StableHLO reduce with an add body.
The generated MLIR has the form:
%result = "stablehlo.reduce"(%input, %init) ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%0 = stablehlo.add %arg0, %arg1 : tensor<f32>
stablehlo.return %0 : tensor<f32>
}) {dimensions = array<i64: axis>} : (inputType, tensor<dtype>) -> outputType
func EmitReshape ¶
func EmitReshape(namer *SSANamer, operand string, inShape, targetShape []int, dtype string) (string, string, error)
EmitReshape emits a stablehlo.reshape operation. targetShape is the desired output shape.
func EmitSlice ¶
func EmitSlice(namer *SSANamer, operand string, shape, start, limit, strides []int, dtype string) (string, string, error)
EmitSlice emits a stablehlo.slice operation with start, limit, and stride indices. strides may be nil, in which case all strides default to 1.
func EmitSoftmax ¶
func EmitSoftmax(namer *SSANamer, input string, inputShape []int, axis int, dtype string) (string, string)
EmitSoftmax decomposes Softmax into 5 StableHLO operations:
- max = ReduceMax(input, axis, keepDims=true)
- shifted = Sub(input, max)
- exp = Exp(shifted)
- sum = ReduceSum(exp, axis, keepDims=true)
- result = Div(exp, sum)
Returns the final result SSA name and the emitted MLIR text.
func EmitTranspose ¶
func EmitTranspose(namer *SSANamer, operand string, shape []int, perm []int, dtype string) (string, string, error)
EmitTranspose emits a stablehlo.transpose operation. perm specifies the axis permutation (e.g., [2, 0, 1]).
func FormatScalarType ¶
FormatScalarType returns the MLIR scalar type string for a dtype. Example: FormatScalarType("f32") returns "f32".
func FormatTensorType ¶
FormatTensorType formats an MLIR tensor type string from a shape and dtype. Example: FormatTensorType([]int{2, 3, 4}, "f32") returns "tensor<2x3x4xf32>". For scalar tensors (empty shape), it returns "tensor<f32>".
func GoDTypeToMLIR ¶
GoDTypeToMLIR maps a Go reflect type name to an MLIR dtype string. Supported mappings:
float32 -> f32
float64 -> f64
float16 -> f16
bfloat16 -> bf16
float8 -> f8E4M3FN
int8 -> i8
int16 -> i16
int32 -> i32
int64 -> i64
uint8 -> ui8
uint32 -> ui32
uint64 -> ui64
Returns the MLIR dtype string and true if the mapping exists, or ("", false) otherwise.
func InferShape ¶
InferShape computes the output shape for a given operation name, input shapes, and optional attributes. It returns an error if the shapes are incompatible.
func InferStructuralShape ¶
InferStructuralShape computes the output shape for structural operations: MatMul, Transpose, Reshape, Concat, Slice, Gather, ReduceSum, ReduceMax, ReduceMean.
attrs supports the following keys depending on the operation:
- "perm" ([]int): axis permutation for Transpose
- "shape" ([]int): target shape for Reshape
- "axis" (int): concatenation axis for Concat, reduction axis for Reduce*
- "start" ([]int): start indices for Slice
- "end" ([]int): end indices for Slice
- "sliceSizes" ([]int): slice sizes for Gather
- "keepDims" (bool): whether to keep the reduced dimension for Reduce*
Types ¶
type Emitter ¶
type Emitter struct {
Namer *SSANamer
}
Emitter generates StableHLO MLIR text from operation inputs. Each emit method takes SSA input names, tensor shapes, and a dtype, and returns the emitted MLIR line(s) plus the output SSA name.
func (*Emitter) EmitAddScalar ¶
func (e *Emitter) EmitAddScalar(input string, scalar float64, shape []int, dtype string) (string, string)
EmitAddScalar emits stablehlo.constant + broadcast_in_dim + add.
func (*Emitter) EmitBinaryElementwise ¶
func (e *Emitter) EmitBinaryElementwise(opName, lhs, rhs string, shape []int, dtype string) (mlir, outName string)
EmitBinaryElementwise emits a binary element-wise op (add, subtract, multiply, divide, power). Both inputs must have the same shape and dtype.
func (*Emitter) EmitDivScalar ¶
func (e *Emitter) EmitDivScalar(input string, scalar float64, shape []int, dtype string) (string, string)
EmitDivScalar emits stablehlo.constant + broadcast_in_dim + divide.
func (*Emitter) EmitMulScalar ¶
func (e *Emitter) EmitMulScalar(input string, scalar float64, shape []int, dtype string) (string, string)
EmitMulScalar emits stablehlo.constant + broadcast_in_dim + multiply.
func (*Emitter) EmitOp ¶
func (e *Emitter) EmitOp(opName string, inputs []string, shape []int, dtype string, attrs map[string]any) (string, string, error)
EmitOp dispatches to the appropriate emit function based on the engine op name. For binary ops, inputs should be [lhs, rhs]. For unary ops, inputs should be [input]. For scalar ops, inputs should be [input] and attrs must contain "scalar" (float64). Returns the emitted MLIR text and the output SSA name.
func (*Emitter) EmitScalarOp ¶
func (e *Emitter) EmitScalarOp(elemOp, input string, scalar float64, shape []int, dtype string) (mlir, outName string)
EmitScalarOp emits a scalar operation as three MLIR instructions:
- stablehlo.constant for the scalar value
- stablehlo.broadcast_in_dim to broadcast to the tensor shape
- The element-wise binary op
Returns all three lines (newline-separated) and the final output SSA name.
type KVCacheSlot ¶
type KVCacheSlot struct {
InputSlot int // slot index where the KV cache is read (becomes a function arg)
OutputSlot int // slot index where the updated KV cache is produced (becomes a return value)
Shape []int // tensor shape (e.g., [num_heads, seq_len, head_dim])
SeqAxis int // axis along which decode concatenation occurs
}
KVCacheSlot describes a stateful KV cache slot that must be rewritten as explicit function I/O for PJRT's pure-functional execution model.
In the original graph, the KV cache is fed back via StatefulInputNode. For PJRT, each KV cache tensor becomes both a function argument (the previous state) and a return value (the updated state).
type ProgramOp ¶
type ProgramOp struct {
OpName string // Engine method name (e.g., "Add", "MatMul", "Softmax")
InputSlots []int // indices into the slot table for inputs
OutputSlot int // index for the output
InputShapes [][]int // shapes of each input
OutputShape []int // shape of the output
Dtype string // "f32", "f16", etc.
Attrs map[string]any // op-specific attributes
}
ProgramOp describes a single operation in a StableHLO program.