codegen

package
v0.1.15 Latest
Warning

This package is not in the latest version of its module.
Published: Jun 15, 2026 License: AGPL-3.0 Imports: 17 Imported by: 0

Documentation

Overview

Package codegen lowers each kernel's SINK-rooted UOp tree, linearizes the instruction sequence, and renders device source (WGSL first).

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func ApplyOpt

func ApplyOpt(sink uop.UOp, opt Opt) uop.UOp

ApplyOpt applies an optimization to a kernel SINK-rooted AST.

OptVec4Load is the one kind whose eligibility depends on buffer shapes: the real K/N extents cannot be recovered from a post-OptTile AST, which carries only ceil(K/TS) and TS. Because ApplyOpt has no buffer table, it refuses OptVec4Load and returns sink unchanged. Use ApplyOptBufs or ApplyOpts for that kind.

func ApplyOptBufs added in v0.1.12

func ApplyOptBufs(sink uop.UOp, opt Opt, bufs []schedule.Buffer) uop.UOp

ApplyOptBufs applies an optimization to a kernel SINK-rooted AST, with the kernel's buffer table available for opts whose eligibility depends on buffer shapes/dtypes (OptVec4Load). For every other kind it behaves exactly like ApplyOpt.

func ApplyOpts

func ApplyOpts(item schedule.ExecItem, opts []Opt) schedule.ExecItem

ApplyOpts applies a sequence of optimizations to a kernel's AST. When any opt changes the AST, the item's pre-rendered WGSL and the other render-derived fields are invalidated (via ExecItem.SetAst) so executors re-render from the opted AST instead of running the stale schedule-cache pre-render; an all-no-op sequence keeps the cached render intact.

func BeamApplyToItems

func BeamApplyToItems(items []schedule.ExecItem, exec backend.Executor, bench backend.Benchmarker) []schedule.ExecItem

BeamApplyToItems applies beam-cached opts to each item in the schedule slice.

Default mode (ANNEAL_BEAM unset or "0"): reads the disk cache (loaded once at package init); applies cached opts on hit after the WGSL-hash guard; falls back to identity on miss. No search runs; no GPU work added; the per-call overhead is one O(1) map lookup plus a cheap WGSL render on cache hits.

Search mode (ANNEAL_BEAM="1"): runs BeamSearch for every cache-missing kernel, persists the winner (opts + WGSL hash) to ~/.cache/anneal/beam_cache.json, and applies it. Blocks synchronously for the duration of each search.

Value-identity guard (cache hits with non-empty opts): the opted kernel's WGSL is rendered and its FNV-64a hash is compared with the stored hash. A mismatch means an SK collision (two different kernels share the same 64-bit structural key); the guard logs a warning and falls back to identity. An undetected collision requires both 64-bit hashes to collide at once: roughly 2⁻⁶⁴ once an SK collision has occurred, or roughly 2⁻¹²⁸ per kernel pair if the two hashes are independent.

The returned slice shares Bufs and SymVars with the input but may carry a different Ast and pre-filled WGSL/LocalSize/WorkgroupCount so the executor skips re-render.

func BeamCacheReset

func BeamCacheReset()

BeamCacheReset clears all cache entries. Used by tests to isolate runs.

func BeamCacheStore

func BeamCacheStore(sk uint64, opts []Opt)

BeamCacheStore records a winning opt sequence for kernel SK. Pass opts=nil to record that identity was the winner.

func BeamDiskCacheInject

func BeamDiskCacheInject(sk uint64, opts []Opt, wgslHash string)

BeamDiskCacheInject stores a synthetic entry keyed by SK. Used by tests to exercise the SK-collision guard without needing a real hash collision.

func BeamDiskCacheReset

func BeamDiskCacheReset()

BeamDiskCacheReset clears the in-memory disk cache without deleting the file. Used by tests to inject synthetic entries and isolate runs.

func BeamWGSLHash

func BeamWGSLHash(wgsl string) string

BeamWGSLHash computes the normalized FNV-64a WGSL hash used by the value-identity guard. Exported so tests can compute the expected stored hash when injecting fake disk-cache entries via BeamDiskCacheInject.
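As a rough illustration of a normalized FNV-64a hash, the sketch below uses the standard hash/fnv package. The normalization step (collapsing whitespace runs) and the hex output format are assumptions made for this example; the documentation does not say what BeamWGSLHash actually normalizes or how it encodes the result.

```go
package main

import (
	"fmt"
	"hash/fnv"
	"strings"
)

// wgslHash is a hypothetical stand-in for BeamWGSLHash. Collapsing
// whitespace and formatting as 16 hex digits are illustrative assumptions.
func wgslHash(wgsl string) string {
	normalized := strings.Join(strings.Fields(wgsl), " ")
	h := fnv.New64a()
	h.Write([]byte(normalized)) // hash.Hash.Write never returns an error
	return fmt.Sprintf("%016x", h.Sum64())
}

func main() {
	a := wgslHash("fn main() {\n\treturn;\n}")
	b := wgslHash("fn main() { return; }")
	fmt.Println(a == b) // whitespace-only differences hash identically here
}
```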

func BufferByteSize added in v0.1.9

func BufferByteSize(elems int64, d *uop.DType) uint64

BufferByteSize returns the total GPU buffer size in bytes for a buffer of elems logical elements of dtype d. For image dtypes the buffer holds ceil(elems/4) vec4 slots of 16 bytes each; for every other dtype the size is elems * SizeBytes, the formula call sites used before image dtypes were introduced.
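The sizing rule can be worked through with a small sketch. The function below takes a hypothetical isImage flag and sizeBytes value in place of a real *uop.DType, so it illustrates only the arithmetic, not the package's signature.

```go
package main

import "fmt"

// bufferByteSize mirrors the documented rule with hypothetical parameters:
// isImage stands in for "d is an image dtype" and sizeBytes for the
// per-element size of a non-image dtype.
func bufferByteSize(elems int64, isImage bool, sizeBytes uint64) uint64 {
	if isImage {
		slots := (elems + 3) / 4 // ceil(elems/4) for non-negative elems
		return uint64(slots) * 16
	}
	return uint64(elems) * sizeBytes
}

func main() {
	fmt.Println(bufferByteSize(10, true, 4))  // 3 vec4 slots of 16 bytes: 48
	fmt.Println(bufferByteSize(10, false, 4)) // 10 elements of 4 bytes: 40
}
```

Note that the image case rounds up, so a 10-element image buffer is larger than its 40-byte non-image counterpart.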

func CompileWGSL

func CompileWGSL(item schedule.ExecItem) (schedule.RenderResult, error)

CompileWGSL converts a kernel's SINK AST to WGSL.

func KernelHasTiledReduce added in v0.1.4

func KernelHasTiledReduce(sink uop.UOp) bool

KernelHasTiledReduce reports whether sink contains at least one OpReduce tagged with "tile:" (set by OptTile). This is the prerequisite for the emitTiledReduce lowerer path that actually consumes AxisUpcast positions per (mr, nr) iteration. Without it, AxisUpcast lowering assigns the placeholder expression "0" and each thread silently writes only lane 0 of its factor-wide stripe.

Used by applyUpcast (fail-loud assertion at opt-application time), by ActionSpace (pre-filter so BEAM's probe never triggers the assertion), and by spray-mode callers that need to apply OptUpcast across a kernel list containing non-matmul items.

func KernelSK

func KernelSK(item schedule.ExecItem) uint64

KernelSK returns the structural key of the SINK-rooted kernel AST in item. The key is stable under arena append-only growth: StructuralKeys mixes children's SK values (not arena positions), so the original node's SK is invariant once built.
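The stability property can be illustrated with a toy arena. The node type and the mixing scheme below are hypothetical (FNV-64a over the op code and the children's keys); the real StructuralKeys implementation may mix different fields. The point is only that a key built from children's keys, not from arena indices, survives append-only growth.

```go
package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// node is a hypothetical arena entry: an op code plus the arena indices of
// its children.
type node struct {
	op   uint32
	kids []int
}

// structuralKey hashes a node from its op and its children's keys, never
// from arena positions, so equal subtrees get equal keys wherever they live.
func structuralKey(arena []node, idx int) uint64 {
	h := fnv.New64a()
	var buf [8]byte
	binary.LittleEndian.PutUint64(buf[:], uint64(arena[idx].op))
	h.Write(buf[:])
	for _, k := range arena[idx].kids {
		binary.LittleEndian.PutUint64(buf[:], structuralKey(arena, k))
		h.Write(buf[:])
	}
	return h.Sum64()
}

func main() {
	arena := []node{{op: 1}, {op: 2}, {op: 3, kids: []int{0, 1}}}
	before := structuralKey(arena, 2)
	// Append the same tree again at new positions, as an append-only arena would.
	arena = append(arena, node{op: 1}, node{op: 2}, node{op: 3, kids: []int{3, 4}})
	fmt.Println(before == structuralKey(arena, 2)) // original key unchanged
	fmt.Println(before == structuralKey(arena, 5)) // structurally equal copy matches
}
```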

func RenderWGSL

func RenderWGSL(item schedule.ExecItem) schedule.RenderResult

RenderWGSL converts a kernel's SINK AST to a WGSL compute shader string.

func Vec4LoadParams added in v0.1.12

func Vec4LoadParams(sink uop.UOp) map[int]bool

Vec4LoadParams returns the set of param indices that must bind as array<vec4<f32>> because OptVec4Load tagged the kernel's tiled reduce. Returns nil for any kernel without a ":vec4"-tagged tilable reduce. This is the single derivation point shared by the renderer (binding emission in wgsl.go) and tests; the lowerer reads the same tag via tileTagParse.

Types

type BeamConfig

type BeamConfig struct {
	Width    int // beam width: candidates kept per depth round
	MaxDepth int // maximum opt-sequence length to explore
	Warmup   int // per-candidate benchmark warmup iterations
	Iters    int // per-candidate benchmark measurement iterations
}

BeamConfig parameterises the beam search.

func DefaultBeamConfig

func DefaultBeamConfig() BeamConfig

DefaultBeamConfig returns sensible defaults. BEAM_WIDTH and MAX_DEPTH can be overridden via environment variables.

type BeamResult

type BeamResult struct {
	Opts       []Opt   // winning opt sequence; nil means identity
	MinMicros  float64 // winner's min-of-N timing
	BaseMicros float64 // identity baseline min-of-N timing
	Searched   int     // candidates successfully benchmarked
	WallNs     int64   // search wall-clock nanoseconds
	FromCache  bool    // true if result came from the beam cache
}

BeamResult holds the output of a beam search for one kernel.

func BeamSearch

func BeamSearch(
	exec backend.Executor,
	bench backend.Benchmarker,
	item schedule.ExecItem,
	cfg BeamConfig,
) BeamResult

BeamSearch runs a beam search over opt sequences for a single kernel item. exec runs kernels for value-identity checks; bench times each candidate.

Correctness invariant: every returned opt sequence produces output bit-identical to the identity baseline (max-abs-diff == 0) on a small fixed test input. Any candidate that fails this check is silently dropped.

Termination bound: the search stops when no candidate at depth D+1 improves on the current best, or when MaxDepth is reached. Total candidates evaluated ≤ MaxDepth × Width × |ActionSpace|, so the search is bounded and finite.
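To make the termination bound concrete, here is a generic beam search sketch over sequences of abstract action indices. Everything in it is hypothetical: the config and cand types, the cost callback (a stand-in for benchmarking), and the toy cost function. It omits the value-identity check and the cache, and is not the package's BeamSearch.

```go
package main

import (
	"fmt"
	"sort"
)

type config struct{ width, maxDepth int }

type cand struct {
	seq  []int
	cost float64
}

// beam explores sequences of action indices in [0, nActions). It returns
// the best sequence found and the number of candidates evaluated.
func beam(cfg config, nActions int, cost func([]int) float64) ([]int, int) {
	best := cand{seq: nil, cost: cost(nil)} // identity baseline
	frontier := []cand{best}
	evaluated := 0
	for depth := 0; depth < cfg.maxDepth; depth++ {
		var next []cand
		for _, c := range frontier {
			for a := 0; a < nActions; a++ {
				seq := append(append([]int{}, c.seq...), a)
				next = append(next, cand{seq, cost(seq)})
				evaluated++
			}
		}
		sort.Slice(next, func(i, j int) bool { return next[i].cost < next[j].cost })
		if len(next) == 0 || next[0].cost >= best.cost {
			break // no deeper candidate improves on the current best
		}
		best = next[0]
		if len(next) > cfg.width {
			next = next[:cfg.width] // keep only the beam
		}
		frontier = next
	}
	return best.seq, evaluated
}

func main() {
	// Toy cost: action 1 saves time, but only the first two uses help,
	// and every extra step adds a small overhead.
	cost := func(seq []int) float64 {
		ones := 0
		for _, a := range seq {
			if a == 1 {
				ones++
			}
		}
		if ones > 2 {
			ones = 2
		}
		return 100 - 10*float64(ones) + float64(len(seq))
	}
	cfg := config{width: 2, maxDepth: 4}
	seq, n := beam(cfg, 3, cost)
	fmt.Println(seq, n, n <= cfg.maxDepth*cfg.width*3)
}
```

Because the frontier never holds more than width candidates, each depth evaluates at most width × nActions sequences, which gives the maxDepth × width × nActions ceiling.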

type Instr

type Instr struct {
	Kind InstrKind

	// InstrBoundsCheck, InstrStore (scalar guard)
	TotalN int64

	// InstrGIDVar, InstrLoopBegin
	RangeID   int
	RangeSize int64

	// InstrGIDVar only
	Stride    int64
	Component int // 0:x, 1:y, 2:z
	Level     int // 0:Global (gid), 1:Workgroup (wid), 2:Local (lid)

	// InstrGIDVar (symbolic-stride case): WGSL u32 expression to use as the
	// divisor instead of the literal Stride int64. Set when the stride product
	// to the inner side of this range involves a symbolic factor (Slice 7b:
	// non-outermost symbolic dim). When non-empty, supersedes Stride; the
	// renderer emits `(base / <StrideExpr>) % rangeSize` (or `base / <StrideExpr>`
	// when Symbolic). When empty, the renderer falls back to the int64 Stride
	// path (byte-identical Slice 1–7a output).
	StrideExpr string

	// InstrBoundsCheck, InstrGIDVar, InstrLoopBegin: true when the range size is
	// symbolic (read from the params_n storage buffer at runtime).
	Symbolic bool

	// InstrLoopBegin (symbolic only): which params_n slot holds the loop
	// bound when the bound is a bare DefineVar. Used only when SymBoundExpr
	// is empty; the ALU-bound path populates SymBoundExpr directly and
	// supersedes this. InstrBoundsCheck always populates SymBoundExpr and
	// ignores this field.
	SymParamIdx int

	// InstrLoopBegin / InstrBoundsCheck (symbolic only): the symbolic bound
	// rendered as a WGSL u32 expression (e.g. "(params_n.n0 * 4u)" for
	// reshape-merge derived bounds). When non-empty, supersedes
	// SymParamIdx — the renderer uses this expression directly as the
	// loop bound / dispatch multiplier. Populated by renderSymBoundExpr
	// in the lowerer for ALU-typed OpRange bounds.
	SymBoundExpr string

	// InstrLoopBegin (symbolic only): true when rules.IndexDtypeForBound for
	// this loop's bound would have selected Int64 (vmax doesn't fit in int32).
	// WGSL has no i64, so the renderer emits an acknowledging comment but
	// still produces i32 — mirroring tinygrad PR #8268's WebGPU edge case.
	// On a future non-WebGPU backend the dtype decision would be honored.
	Int64Downgraded bool

	// InstrGIDVar (symbolic only): the rendered WGSL u32 expression for this
	// range's bound, used to emit a per-axis guard `if (r{N} >= <expr>) { return; }`
	// after the r{N} let-binding. Populated for sym ranges in multi-dim sym
	// dispatch; empty for static (which uses ins.RangeSize as the literal) and
	// for the legacy 1D-flatten path. Cooperates with the static-path
	// `if (r{N} >= RangeSize)` guard at wgsl.go:207.
	AxisGuardExpr string

	// InstrAccInit, InstrAccUpdate
	AccIdx   int
	WGSLType string // for InstrAccInit
	Identity string // for InstrAccInit
	AccOp    uop.Op // for InstrAccUpdate

	// InstrLet, InstrDefineLocal
	NodeIdx uint32
	DType   *uop.DType

	// InstrDefineLocal
	LocalName string
	LocalSize int

	// InstrLet, InstrAccUpdate, InstrStore, InstrIf, InstrAssign
	Expr      string
	IndexExpr string // for InstrStore, InstrAssign (LHS)

	// Name overrides the auto-derived `t{NodeIdx}` naming for InstrLet
	// (used by the B3 register-blocking codegen to emit named rA_k_mr /
	// rB_k_nr per-K register loads).
	Name string
}

Instr is one linearized instruction in the kernel. Fields are interpreted according to Kind; unused fields are zero.
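The Stride and RangeSize fields feed the (base / Stride) % RangeSize pattern that InstrGIDVar renders. A minimal sketch of that decomposition follows, assuming row-major strides where the innermost range has stride 1; the decompose helper is hypothetical and not part of the package.

```go
package main

import "fmt"

// decompose splits a flat index into per-range indices using
// (flat / stride) % size, with strides derived from the range sizes
// (innermost range has stride 1).
func decompose(flat int64, sizes []int64) []int64 {
	idx := make([]int64, len(sizes))
	stride := int64(1)
	for i := len(sizes) - 1; i >= 0; i-- {
		idx[i] = (flat / stride) % sizes[i]
		stride *= sizes[i]
	}
	return idx
}

func main() {
	// 7 = 0*12 + 1*4 + 3 for ranges of size 2, 3, 4.
	fmt.Println(decompose(7, []int64{2, 3, 4})) // [0 1 3]
}
```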

func Lower

func Lower(item schedule.ExecItem) ([]Instr, [3]int, [3]int, [3]schedule.DimDispatch)

Lower converts one kernel's SINK AST into a linear instruction sequence. Instructions are in emit order; loop nesting depth is tracked by the renderer. The final return value (symDispatch) carries per-dimension runtime extent information for symbolic kernels; entries with non-empty SymFactors instruct the executor to override workgroupCount[d] per binding.

type InstrKind

type InstrKind int

InstrKind identifies the type of a linearized instruction.

const (
	// InstrBoundsCheck emits: if (gid_x >= <SymBoundExpr>) { return; }
	// for symbolic kernels (Slice 7b: SymBoundExpr is the full trailingProduct
	// over all loopRanges, possibly involving multiple sym vars). For static
	// kernels the bound is encoded by workgroupCount * workgroupSize already
	// matching totalOut, so the bounds check is a no-op (Symbolic=false).
	InstrBoundsCheck InstrKind = iota
	// InstrGIDVar emits: let r_RangeID: i32 = i32((gid.x / Stride) % RangeSize);
	InstrGIDVar
	// InstrLoopBegin emits: for (var r_RangeID: i32 = 0; r_RangeID < RangeSize; r_RangeID++) {
	InstrLoopBegin
	// InstrLoopEnd emits: }
	InstrLoopEnd
	// InstrAccInit emits: var acc_AccIdx: WGSLType = Identity;
	InstrAccInit
	// InstrAccUpdate emits: acc_AccIdx = combine(AccOp, acc_AccIdx, Expr);
	InstrAccUpdate
	// InstrLet emits: let t_NodeIdx: WGSLType = Expr;
	InstrLet
	// InstrStore emits: data0[IndexExpr] = Expr;
	InstrStore
	// InstrDefineLocal emits: var<workgroup> LocalName: array<WGSLType, LocalSize>;
	InstrDefineLocal
	// InstrBarrier emits: workgroupBarrier();
	InstrBarrier
	// InstrIf emits: if (Expr) {
	InstrIf
	// InstrEndIf emits: }
	InstrEndIf
	// InstrAssign emits: IndexExpr = Expr;
	InstrAssign
	// InstrImgLaneBegin opens the image slot-dispatch lane pass: guards the
	// thread to one vec4 output slot (gid_x), declares the thread-private
	// vec4 accumulator, and opens the 4-lane loop with the tail mask
	// `if (_img_flat < TotalN)`. TotalN is the logical element count.
	InstrImgLaneBegin
	// InstrImgLaneStore assigns Expr to the _img_out component selected by
	// the current _img_lane (static-swizzle cascade on a private var).
	InstrImgLaneStore
	// InstrImgLaneEnd closes the tail mask and lane loop, then writes the
	// whole vec4 slot once: data0[gid_x] = _img_out;
	InstrImgLaneEnd
)
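Since the instruction list is flat and the renderer tracks nesting, a renderer can derive indentation from a simple depth counter. The sketch below uses a hypothetical, pared-down instr type with pre-rendered text, so it shows only the depth bookkeeping and not how the package formats each InstrKind.

```go
package main

import (
	"fmt"
	"strings"
)

type kind int

const (
	loopBegin kind = iota
	loopEnd
	accInit
	accUpdate
	store
)

// instr is a pared-down, hypothetical version of Instr.
type instr struct {
	kind kind
	text string
}

// render walks the instructions in emit order and derives indentation from
// a depth counter that loopBegin increments and loopEnd decrements.
// It assumes loopBegin and loopEnd are balanced.
func render(prog []instr) string {
	var b strings.Builder
	depth := 0
	for _, in := range prog {
		if in.kind == loopEnd {
			depth--
		}
		b.WriteString(strings.Repeat("  ", depth))
		b.WriteString(in.text)
		b.WriteByte('\n')
		if in.kind == loopBegin {
			depth++
		}
	}
	return b.String()
}

func main() {
	prog := []instr{
		{accInit, "var acc0: f32 = 0.0;"},
		{loopBegin, "for (var r1: i32 = 0; r1 < 4; r1++) {"},
		{accUpdate, "acc0 = acc0 + data1[r1];"},
		{loopEnd, "}"},
		{store, "data0[0] = acc0;"},
	}
	fmt.Print(render(prog))
}
```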

type Opt

type Opt struct {
	Kind OptKind
	Axis int
	Arg  int
}

Opt captures one kernel optimization: an op kind, an axis index, and an int arg.

func ActionSpace

func ActionSpace(sink uop.UOp) []Opt

ActionSpace returns every Opt that non-trivially transforms sink. "Non-trivial" means the returned UOp has a different arena index than sink (i.e. ApplyOpt created new nodes rather than returning sink unchanged). Call on the live (possibly already-optimised) sink at each search depth.

Without a buffer table, shape-gated opts (OptVec4Load) are never proposed; BeamSearch uses ActionSpaceBufs with the item's Bufs.

func ActionSpaceBufs added in v0.1.12

func ActionSpaceBufs(sink uop.UOp, bufs []schedule.Buffer) []Opt

ActionSpaceBufs is ActionSpace with the kernel's buffer table, enabling opts whose eligibility depends on buffer shapes/dtypes (OptVec4Load).

func BeamCacheLookup

func BeamCacheLookup(sk uint64) ([]Opt, bool)

BeamCacheLookup returns the cached opt sequence for kernel SK. Returns (opts, true) on hit; opts may be nil (identity won).
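The distinction between a miss and a hit whose winner was identity maps naturally onto Go's comma-ok map lookup. The sketch below uses a hypothetical opt type and package-level map; it is not the package's cache, only an illustration of the (opts, true) with nil opts convention.

```go
package main

import "fmt"

// opt is a hypothetical stand-in for Opt.
type opt struct{ kind, axis, arg int }

var cache = map[uint64][]opt{}

// store records the winner for sk; nil means identity won.
func store(sk uint64, opts []opt) { cache[sk] = opts }

// lookup uses the comma-ok form so a stored nil (identity won) is still
// reported as a hit.
func lookup(sk uint64) ([]opt, bool) {
	opts, ok := cache[sk]
	return opts, ok
}

func main() {
	store(1, []opt{{kind: 2, axis: 0, arg: 16}})
	store(2, nil) // identity was the winner
	for _, sk := range []uint64{1, 2, 3} {
		opts, hit := lookup(sk)
		fmt.Println(sk, hit, len(opts))
	}
}
```

Callers therefore need to branch on the boolean, not on whether the returned slice is nil.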

type OptKind

type OptKind int

OptKind identifies the type of a kernel optimization.

const (
	// OptIdentity returns the kernel AST unchanged.
	OptIdentity OptKind = iota
	// OptLocal splits an axis into workgroup and local dimensions.
	OptLocal
	// OptTile blocks a reduction axis and uses shared memory tiling.
	OptTile
	// OptUpcast splits a parallel (output) axis into outer + inner-unrolled
	// micro-tile stripes. Each thread covers `factor` sequential outputs in
	// that dim, enabling register blocking when composed with OptTile.
	OptUpcast
	// OptVectorize splits a parallel axis into outer + AxisVectorize(width) inner.
	// The inner dimension is emitted as SIMD vec4 operations in the lowerer.
	// Rejects AxisReduce and Symbolic axes. Only width=4 has a working lowerer.
	// Compose after OptTile+OptUpcast for the B3.7 register-blocking+vec4 pipeline.
	OptVectorize
	// OptVec4Load rebinds BOTH input params of an OptTile-tagged tilable-matmul
	// kernel as array<vec4<f32>> storage buffers and rewrites the shared-memory
	// tile fills in emitTiledReduce to whole-vec4 global loads (4 elements per
	// load — a genuine 128-bit `device float4*` load under the naga MSL
	// lowering, unlike OptVectorize's compute-only vec4 packaging). Axis and
	// Arg are unused. Compose after OptTile (any of the tile-only / +OptUpcast
	// / +OptVectorize stacks). Eligibility needs the kernel's buffer shapes, so
	// apply through ApplyOptBufs / ApplyOpts; bare ApplyOpt refuses.
	OptVec4Load
)
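Several of these kinds (OptLocal, OptUpcast, OptVectorize) are described as splitting one axis into an outer and an inner part. The index arithmetic behind such a split can be sketched as follows, assuming the axis extent is divisible by the factor; the split helper is hypothetical and says nothing about how the package rewrites the AST.

```go
package main

import "fmt"

// split maps an index on an axis to (outer, inner) for a given factor, so
// that index == outer*factor + inner.
func split(index, factor int) (outer, inner int) {
	return index / factor, index % factor
}

func main() {
	const size, factor = 8, 4
	for outer := 0; outer < size/factor; outer++ {
		// One "thread" per outer index covers factor consecutive outputs.
		for inner := 0; inner < factor; inner++ {
			index := outer*factor + inner
			o, i := split(index, factor)
			fmt.Println(index, o, i)
		}
	}
}
```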

type WGSLTypeInfo added in v0.1.4

type WGSLTypeInfo struct {
	Scalar     string
	BufferElem string
	SizeBytes  uint64
}

WGSLTypeInfo carries per-dtype WGSL metadata in one place: the scalar type literal used in expression context, the storage buffer element type (which promotes bool and bf16 to u32), and the per-element byte size on the GPU.

Single source of truth for the three previously split lookups (wgslDType, wgslBufferElemType, elemBytes). Backend code that needs only the byte size reads WGSLTypeInfoFor(d).SizeBytes; the renderer reads .Scalar and .BufferElem.

func WGSLTypeInfoFor added in v0.1.4

func WGSLTypeInfoFor(d *uop.DType) WGSLTypeInfo

WGSLTypeInfoFor returns the WGSL metadata for a dtype. Pointer dtypes are unwrapped to their base. Nil and Void are treated as f32 (matches the pre-consolidation behaviour of wgslDType, wgslBufferElemType, and elemBytes).
