Documentation ¶
Overview ¶
Package graph provides a composable execution graph where nodes are Modules.
A Graph is itself a Module — the key composition primitive. Any trained graph can be placed as a node inside a larger graph, enabling hierarchical composition of models (Graph-as-Module).
Build graphs using the fluent API:
g, err := graph.From(encoder).
Through(relu).Tag("hidden").
Through(decoder).Tag("loss").
Build()
result := g.Forward(input) // Graph implements nn.Module
Nodes in the same topological level (no dependencies on each other) execute concurrently via goroutines. This naturally parallelizes Split branches without any special configuration.
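For example, the branches of a Split occupy the same topological level and therefore run in parallel (the head modules here are illustrative):

g, _ := graph.From(encoder).
	Split(headA, headB).   // headA and headB execute concurrently
	Merge(graph.Cat(1)).
	Build()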
Observation ¶
Tagged node outputs are captured during Forward and can be logged, collected into batch buffers, flushed to epoch-level history, and queried as trends. See Graph.Log, Graph.Collect, Graph.Flush, and Graph.Trend in observe.go for details.
Tags already name meaningful nodes in the graph. The observation layer builds on this: every Forward call captures tagged node outputs, making them available for logging, collection, and trend analysis.
Logging ¶
Log prints the current tagged values. The default writes to stderr; use OnLog to install a custom handler (structured logging, etc.).
g.Forward(input)
g.Log("loss") // loss: 0.2341
g.Log() // all tags
Collecting, Recording, and Flushing ¶
Collect snapshots scalar values from tagged graph nodes into a batch buffer (within an epoch). Record injects external scalar values into the same buffer — use this for metrics computed outside the graph (e.g. losses that need both graph outputs and external targets). Flush promotes the batch mean to epoch-level history, then clears the buffer. This two-level structure gives you both fine-grained batch data and coarse-grained epoch trends from the same mechanism.
for epoch := range epochs {
for _, batch := range loader {
g.Forward(batch.Input)
g.Collect("loss") // from graph tag
loss := nn.CrossEntropyLoss(pred, target)
g.Record("ext_loss", loss.Item()) // from outside
}
g.Flush() // batch mean → epoch history
}
Trends ¶
Trend returns the epoch-level time series for a tag, with built-in statistical queries that drive training decisions:
if g.Trend("loss").Stalled(5, 1e-4) {
scheduler.Decay() // reduce LR on plateau
}
if g.Trend("loss").Improving(3) {
g.Unfreeze("decoder") // start fine-tuning
}
current := g.Trend("loss").Latest() // most recent epoch value
Flush timing and ETA ¶
Each Flush records a wall-clock timestamp. This enables built-in ETA calculation and per-epoch duration tracking:
g.Flush()
fmt.Printf("ETA: %s\n", graph.FormatDuration(g.ETA(100)))
fmt.Printf("last epoch: %s\n", graph.FormatDuration(g.LastFlushDuration()))
Sub-graph observation ¶
Since a Graph is a Module, sub-graphs run automatically as part of the parent's Forward — no separate Forward call needed. Use Sub to reach into a sub-graph and observe its internal tags:
// Build a sub-graph with its own tags.
refiner, _ := graph.From(step).Tag("residual").Build()
// Compose into root graph.
g, _ := graph.From(encoder).
Through(refiner).Tag("refine").
Through(lossModule).Tag("loss").
Build()
// Training loop — only Forward on root.
for epoch := range epochs {
for _, batch := range loader {
g.Forward(batch.Input)
g.Collect("loss") // root tag
g.Sub("refine").Collect("residual") // sub-graph tag
}
g.Flush()
g.Sub("refine").Flush()
if g.Trend("loss").Stalled(5, 1e-4) {
scheduler.Decay()
}
}
Visualization ¶
Training visualization: HTML charts, CSV export, and text logs for epoch trends.
PlotHTML generates a self-contained HTML file with interactive training curves — no external dependencies, no npm, no CDN. Open in any browser.
g.PlotHTML("training.html", "loss", "head_0", "head_1")
ExportTrends writes epoch history to CSV for external analysis tools.
g.ExportTrends("metrics.csv", "loss", "accuracy")
WriteLog writes a human-readable training log with per-epoch metrics, timing, and ETA:
g.WriteLog("training.log", 100, "loss", "hit_rate")
Profiling ¶
Profiling captures per-node and per-level execution timings during Forward.
Profiling is opt-in: call Graph.EnableProfiling to start recording. When disabled (the default), profiling adds zero overhead — the bool gate is never true and no timing calls are made.
After each Forward call with profiling enabled, Graph.Profile returns a Profile with per-node durations, per-level wall-clock times, and parallelism efficiency metrics for multi-node levels.
Timing Trends ¶
Profiling integrates with the observation layer for epoch-level analysis. Graph.CollectTimings snapshots tagged node durations into a batch buffer, Graph.FlushTimings promotes the batch mean to epoch history, and Graph.TimingTrend returns a Trend over the timing history:
g.EnableProfiling()
for epoch := range epochs {
for loader.Next() {
g.Forward(input)
g.CollectTimings("encoder", "decoder")
}
g.FlushTimings()
fmt.Printf("encoder: %v (slope: %.4f)\n",
g.Timing("encoder"),
g.TimingTrend("encoder").Slope(5))
}
Index ¶
- Constants
- func Add() nn.Module
- func ArgmaxSelector(inputDim int64, numBranches int) nn.Module
- func Cat(dim int) nn.Module
- func FixedSelector(index int) nn.Module
- func FormatDuration(d time.Duration) string
- func LearnedHalt(inputDim int64) nn.Module
- func Mean() nn.Module
- func Reshape(shape ...int64) nn.Module
- func SigmoidRouter(inputDim int64, numExperts int) nn.Module
- func SoftmaxRouter(inputDim int64, numExperts int) nn.Module
- func StateAdd() nn.Module
- func ThresholdHalt(threshold float32) nn.Module
- type Edge
- type FlowBuilder
- func (fb *FlowBuilder) Also(m nn.Module) *FlowBuilder
- func (fb *FlowBuilder) Build() (*Graph, error)
- func (fb *FlowBuilder) Gate(router nn.Module, experts ...nn.Module) *FlowBuilder
- func (fb *FlowBuilder) Input(names ...string) *FlowBuilder
- func (fb *FlowBuilder) Loop(body nn.Module) *LoopBuilder
- func (fb *FlowBuilder) Map(body nn.Module) *MapBuilder
- func (fb *FlowBuilder) Merge(m nn.Module) *FlowBuilder
- func (fb *FlowBuilder) Split(modules ...nn.Module) *FlowBuilder
- func (fb *FlowBuilder) Switch(router nn.Module, branches ...nn.Module) *FlowBuilder
- func (fb *FlowBuilder) Tag(name string) *FlowBuilder
- func (fb *FlowBuilder) TagGroup(name string) *FlowBuilder
- func (fb *FlowBuilder) Through(m nn.Module) *FlowBuilder
- func (fb *FlowBuilder) Using(refs ...string) *FlowBuilder
- type Graph
- func (g *Graph) Collect(tags ...string)
- func (g *Graph) CollectTimings(tags ...string)
- func (g *Graph) Collected(tag string) []float64
- func (g *Graph) DOT() string
- func (g *Graph) DOTWithProfile() string
- func (g *Graph) DetachState()
- func (g *Graph) Device() *tensor.Device
- func (g *Graph) DisableProfiling()
- func (g *Graph) ETA(totalEpochs int) time.Duration
- func (g *Graph) Elapsed() time.Duration
- func (g *Graph) EnableProfiling()
- func (g *Graph) ExportTimingTrends(path string, tags ...string) error
- func (g *Graph) ExportTrends(path string, tags ...string) error
- func (g *Graph) Flush(tags ...string)
- func (g *Graph) FlushCount() int
- func (g *Graph) FlushTimings(tags ...string)
- func (g *Graph) Forward(inputs ...*autograd.Variable) *autograd.Variable
- func (g *Graph) ForwardCtx(ctx context.Context, inputs ...*autograd.Variable) *autograd.Variable
- func (g *Graph) Freeze(tags ...string)
- func (g *Graph) LastFlushDuration() time.Duration
- func (g *Graph) Log(tags ...string)
- func (g *Graph) OnFlush(fn func(flushed map[string]float64))
- func (g *Graph) OnLog(fn func(values map[string]*autograd.Variable))
- func (g *Graph) OnProfile(fn func(*Profile))
- func (g *Graph) Parameters() []*nn.Parameter
- func (g *Graph) ParametersByTag(tag string) []*nn.Parameter
- func (g *Graph) PlotHTML(path string, tags ...string) error
- func (g *Graph) PlotTimingsHTML(path string, tags ...string) error
- func (g *Graph) Profile() *Profile
- func (g *Graph) Profiling() bool
- func (g *Graph) Record(tag string, values ...float64)
- func (g *Graph) ResetState()
- func (g *Graph) ResetTimingTrend(tags ...string)
- func (g *Graph) ResetTrend(tags ...string)
- func (g *Graph) SVG(path ...string) ([]byte, error)
- func (g *Graph) SVGWithProfile(path ...string) ([]byte, error)
- func (g *Graph) SetDevice(device tensor.Device)
- func (g *Graph) SetTraining(training bool)
- func (g *Graph) Sub(tag string) *Graph
- func (g *Graph) TagGroup(name string) []string
- func (g *Graph) Tagged(tag string) *autograd.Variable
- func (g *Graph) Timing(tag string) time.Duration
- func (g *Graph) TimingTrend(tag string) *Trend
- func (g *Graph) TimingTrends(tags ...string) TrendGroup
- func (g *Graph) Traces(tag string) []*autograd.Variable
- func (g *Graph) Trend(tag string) *Trend
- func (g *Graph) Trends(tags ...string) TrendGroup
- func (g *Graph) Unfreeze(tags ...string)
- func (g *Graph) WriteLog(path string, totalEpochs int, tags ...string) error
- func (g *Graph) ZeroFrozenGrads()
- type LevelTiming
- type LoopBuilder
- type MapBuilder
- type Node
- type NodeTiming
- type Profile
- type Trend
- func (t *Trend) Converged(window int, tol float64) bool
- func (t *Trend) Improving(window int) bool
- func (t *Trend) Last(n int) []float64
- func (t *Trend) Latest() float64
- func (t *Trend) Len() int
- func (t *Trend) Max() float64
- func (t *Trend) Mean() float64
- func (t *Trend) Min() float64
- func (t *Trend) Slope(window int) float64
- func (t *Trend) Stalled(window int, tol float64) bool
- func (t *Trend) Values() []float64
- type TrendGroup
- func (tg TrendGroup) AllConverged(window int, tol float64) bool
- func (tg TrendGroup) AllImproving(window int) bool
- func (tg TrendGroup) AllStalled(window int, tol float64) bool
- func (tg TrendGroup) AnyConverged(window int, tol float64) bool
- func (tg TrendGroup) AnyImproving(window int) bool
- func (tg TrendGroup) AnyStalled(window int, tol float64) bool
- func (tg TrendGroup) MeanSlope(window int) float64
- func (tg TrendGroup) Slopes(window int) []float64
Constants ¶
const (
	// DefaultInput is the port name for single-input modules.
	DefaultInput = "input"

	// DefaultOutput is the port name for single-output modules.
	DefaultOutput = "output"
)
Variables ¶
This section is empty.
Functions ¶
func Add ¶
Add returns a merge module that element-wise adds all inputs. Gradients flow to every input equally (each gets the full upstream gradient).
Used for residual connections and skip connections:
graph.From(encoder).
Split(branchA, branchB).
Merge(graph.Add()).
Build()
func ArgmaxSelector ¶
ArgmaxSelector returns a Switch router with a learnable linear projection. It picks the branch whose logit is highest (argmax).
Selection is non-differentiable — gradients flow through the selected branch only. The projection parameters are included in the graph's Parameters() for training with policy-gradient methods if desired.
Switch(graph.ArgmaxSelector(hidden, 3), branchA, branchB, branchC)
func Cat ¶
Cat returns a merge module that concatenates all inputs along the given dimension.
graph.From(encoder).
Split(branchA, branchB).
Merge(graph.Cat(1)).
Through(nn.MustLinear(combined, hidden)).
Build()
func FixedSelector ¶
FixedSelector returns a Switch router that always selects the same branch. Useful for testing, ablation studies, or static configurations.
Switch(graph.FixedSelector(0), branchA, branchB)
func FormatDuration ¶ added in v0.2.0
FormatDuration formats a duration for training logs: <1s → "42ms", <1m → "1.2s", ≥1m → "2m05s".
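For example (outputs follow the format rules above):

graph.FormatDuration(250 * time.Millisecond) // "250ms"
graph.FormatDuration(95 * time.Second)       // "1m35s"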
func LearnedHalt ¶
LearnedHalt returns a halt condition for Loop.Until with a learnable linear probe. The probe projects the state to a scalar — iteration stops when the output is positive.
This is the Adaptive Computation Time (ACT) pattern: the network learns when to stop iterating.
Loop(body).Until(graph.LearnedHalt(hidden), 20)
func Mean ¶
Mean returns a merge module that averages all inputs element-wise.
graph.From(encoder).
Split(branchA, branchB).
Merge(graph.Mean()).
Build()
func Reshape ¶
Reshape returns a zero-parameter module that reshapes its input to the given shape. This is a graph-level primitive useful for adapting tensor dimensions between modules.
graph.From(encoder).
Through(graph.Reshape(4, 2)). // [1, 8] → [4, 2]
Map(readHead(2)).Each().
Through(graph.Reshape(1, 8)). // [4, 2] → [1, 8]
Build()
func SigmoidRouter ¶
SigmoidRouter returns a Gate router that produces independent sigmoid weights over numExperts experts. Unlike SoftmaxRouter, weights do not sum to 1 — each expert is gated independently between 0 and 1.
Gate(graph.SigmoidRouter(hidden, 2), expertA, expertB)
func SoftmaxRouter ¶
SoftmaxRouter returns a Gate router that produces softmax-normalized weights over numExperts experts.
When the router receives multiple inputs (via Gate.Using), they are summed element-wise before projection — this lets the router see both the stream and any tagged references without changing dimensions.
graph.From(embed).
Tag("ctx").
Through(layer).
Gate(graph.SoftmaxRouter(hidden, 3), expertA, expertB, expertC).Using("ctx").
Build()
func StateAdd ¶
StateAdd returns an additive state cell for use with forward references (Using before Tag).
On the first Forward call, the state is auto-zeroed by the graph, so stream + zeros = stream (pass-through). On subsequent calls, the accumulated state is added to the current stream.
graph.From(embed).
Through(graph.StateAdd()).Using("memory").
Tag("memory").
Build()
func ThresholdHalt ¶
ThresholdHalt returns a halt condition for Loop.Until that signals stop when the maximum element of the state exceeds the threshold.
Loop(body).Until(graph.ThresholdHalt(50), 20)
Types ¶
type Edge ¶
type Edge struct {
// contains filtered or unexported fields
}
Edge connects an output port of one node to an input port of another.
type FlowBuilder ¶
type FlowBuilder struct {
// contains filtered or unexported fields
}
FlowBuilder builds a graph using a fluent API that reads as data flow.
g, err := graph.From(encoder).
Through(attention).
Through(decoder).
Build()
func From ¶
func From(m nn.Module) *FlowBuilder
From starts a new graph flow at the given module. The module's input ports become the graph's inputs.
func (*FlowBuilder) Also ¶
func (fb *FlowBuilder) Also(m nn.Module) *FlowBuilder
Also creates a residual connection: the input passes through the module, and the result is added element-wise back to the original. output = input + module(input)
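For example, a residual block stacks two Also calls (module names are illustrative):

graph.From(embed).
	Also(selfAttn).     // x = x + selfAttn(x)
	Also(feedForward).  // x = x + feedForward(x)
	Build()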
func (*FlowBuilder) Build ¶
func (fb *FlowBuilder) Build() (*Graph, error)
Build finalizes the graph. The current stream's output becomes the graph's output. Returns an error if the flow has open branches or structural problems.
func (*FlowBuilder) Gate ¶
func (fb *FlowBuilder) Gate(router nn.Module, experts ...nn.Module) *FlowBuilder
Gate creates a gated routing construct. A router module produces weights over a set of expert modules. All experts execute on the current stream, and their outputs are combined using the router's weights.
The router is responsible for its own normalization strategy:
- Softmax for standard MoE (weights sum to 1)
- Sigmoid for independent gating (each expert 0-1)
- Top-k + softmax for sparse routing
- Or any other scheme — the merge just applies the weights as-is
The router must output a tensor of shape [..., numExperts]. Each expert receives the current stream and produces an output of the same shape.
Use Using after Gate to wire additional tagged references to the router:
graph.From(encoder).
Through(layer1).Tag("features").
Through(layer2).
Gate(router, expertA, expertB, expertC).Using("features").
Through(decoder).
Build()
func (*FlowBuilder) Input ¶
func (fb *FlowBuilder) Input(names ...string) *FlowBuilder
Input adds named auxiliary inputs to the graph. Each name creates a passthrough node, tags it, and exposes it as a graph-level input. The main flow continues unchanged — downstream nodes consume inputs via FlowBuilder.Using.
Forward receives inputs in declaration order: the From entry first, then each Input name in the order they appear.
g, _ := graph.From(encoder).Tag("image").
Input("case", "context").
Through(decoder).Using("image", "case", "context").
Build()
g.Forward(img, caseLabel, ctx) // three inputs
func (*FlowBuilder) Loop ¶
func (fb *FlowBuilder) Loop(body nn.Module) *LoopBuilder
Loop starts a loop construct that repeats a body module, carrying state between iterations. Call .For(n) to set the iteration count.
graph.From(encoder).
Loop(refinementStep).For(5).
Through(decoder).
Build()
func (*FlowBuilder) Map ¶
func (fb *FlowBuilder) Map(body nn.Module) *MapBuilder
Map starts a map construct that applies a body module independently to each element along dim 0 of a tensor. Results are concatenated back along dim 0.
Call .Over(tag) to iterate over a tagged tensor, or .Each() to iterate over the current stream. Additional Using refs are broadcast to every invocation.
graph.From(positionDecoder).
Map(readHead).Each().Using("image").
Through(decoder).
Build()
func (*FlowBuilder) Merge ¶
func (fb *FlowBuilder) Merge(m nn.Module) *FlowBuilder
Merge combines parallel streams using the given module. The module receives all branch outputs as its variadic inputs.
func (*FlowBuilder) Split ¶
func (fb *FlowBuilder) Split(modules ...nn.Module) *FlowBuilder
Split forks the flow into parallel branches, one per module. Each branch receives the current stream's output as input. Call Merge to recombine the branches.
func (*FlowBuilder) Switch ¶
func (fb *FlowBuilder) Switch(router nn.Module, branches ...nn.Module) *FlowBuilder
Switch creates a hard-routing conditional construct. A router module selects which branch to execute based on its output.
The router must return a scalar (1-element tensor) containing the 0-based branch index. The router owns the selection logic — argmax, sampling, round-robin, or any other strategy. Only the selected branch executes; unselected branches are skipped entirely.
The selection is non-differentiable — gradients flow through the selected branch only. The router does not receive gradient through the selection. For differentiable routing, use Gate instead.
Use Using after Switch to wire additional tagged references to the router:
graph.From(encoder).
Through(layer1).Tag("features").
Through(layer2).
Switch(router, branchA, branchB, branchC).Using("features").
Through(decoder).
Build()
func (*FlowBuilder) Tag ¶
func (fb *FlowBuilder) Tag(name string) *FlowBuilder
Tag names the current position in the flow so it can be referenced later with Using. The name must be unique within the graph.
graph.From(encoder).
Through(layer1).Tag("hidden").
Through(layer2).
Through(crossAttn).Using("hidden").
Build()
func (*FlowBuilder) TagGroup ¶
func (fb *FlowBuilder) TagGroup(name string) *FlowBuilder
TagGroup names each stream in a multi-stream flow with auto-suffixed tags. Given a group name, it creates tags "name_0", "name_1", etc. — one per branch. The group is registered so that Graph.Trends and Graph.TimingTrends can expand it automatically.
TagGroup requires multiple streams (after Split). For single-stream tagging, use FlowBuilder.Tag instead.
graph.From(encoder).
Split(headA, headB, headC).TagGroup("head").
Merge(graph.Mean()).
Build()
// Creates tags: "head_0", "head_1", "head_2"
// Group "head" → ["head_0", "head_1", "head_2"]
func (*FlowBuilder) Through ¶
func (fb *FlowBuilder) Through(m nn.Module) *FlowBuilder
Through passes the flow through a module. Requires a single stream (call Merge first if the flow is split).
func (*FlowBuilder) Using ¶
func (fb *FlowBuilder) Using(refs ...string) *FlowBuilder
Using wires additional inputs from previously tagged points to the preceding node(s). After Through, Gate, Merge, or Also, it targets that single node. After Split, it targets all branch modules — each branch receives the tagged references as extra Forward arguments.
// Single target:
graph.From(encoder).Tag("memory").
Through(crossAttention).Using("memory").
Build()
// All branches:
graph.From(encoder).Tag("memory").
Split(headA, headB, headC).Using("memory").
Merge(concat).
Build()
type Graph ¶
type Graph struct {
// contains filtered or unexported fields
}
Graph is a composition of connected Nodes. It implements nn.Module, enabling Graph-as-Module composition — a graph can be a node in a parent graph.
Nodes that have no dependencies on each other execute in parallel via goroutines. This is determined at Build time from the graph topology — no runtime scheduling overhead.
func (*Graph) Collect ¶
Collect snapshots the current scalar value of the specified tagged nodes into the batch buffer. Call Flush to promote the batch mean to epoch-level history and clear the buffer.
for _, batch := range loader {
g.Forward(batch.Input)
g.Collect("loss")
}
g.Flush("loss")
func (*Graph) CollectTimings ¶
CollectTimings snapshots the execution duration of tagged nodes from the most recent Forward into a timing batch buffer. If no tags are specified, all tagged nodes with timing data are collected.
Call Graph.FlushTimings at epoch boundaries to promote the batch mean to epoch history.
func (*Graph) Collected ¶
Collected returns the raw batch-level buffer for a tag — all values since the last Flush. Returns nil if nothing has been collected.
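For example, inspecting the buffer mid-epoch:

g.Forward(batch.Input)
g.Collect("loss")
vals := g.Collected("loss") // one value per Collect since the last Flush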
func (*Graph) DOT ¶
DOT returns a Graphviz DOT representation of the graph. The output can be rendered with `dot -Tsvg graph.dot -o graph.svg` or pasted into an online viewer like https://dreampuf.github.io/GraphvizOnline.
Composite nodes (Switch, Loop.While, Loop.Until) are expanded into clusters showing their internal structure — branches, body, and condition modules.
g, _ := graph.From(encoder).Through(decoder).Build()
fmt.Println(g.DOT())
func (*Graph) DOTWithProfile ¶
DOTWithProfile returns a timing-annotated DOT representation using the most recent Profile from a profiled Forward call. Nodes are color-coded green→yellow→red by relative execution time, with durations shown in labels and parallelism metrics on level clusters.
Returns the structural DOT (same as Graph.DOT) if no profile is available.
g.EnableProfiling()
g.Forward(input)
fmt.Println(g.DOTWithProfile())
func (*Graph) DetachState ¶
func (g *Graph) DetachState()
DetachState breaks the gradient chain on all forward-reference state buffers and any module-level state in the graph. Call this between training steps to prevent unbounded computation graph growth.
DetachState is recursive: it walks into sub-graphs and calls nn.Detach on any module implementing nn.Detachable. A single call on the top-level graph handles the entire hierarchy.
for epoch := range epochs {
for loader.Next() {
pred := g.Forward(input)
loss := nn.MSELoss(pred, target)
opt.ZeroGrad()
loss.Backward()
opt.Step()
g.DetachState() // prevents graph growth across batches
}
}
func (*Graph) Device ¶ added in v0.2.0
Device returns the device set on this graph, or nil if no device placement has been configured. When nil, no automatic input migration happens during Forward.
func (*Graph) DisableProfiling ¶
func (g *Graph) DisableProfiling()
DisableProfiling turns off timing. Subsequent Forward calls have zero profiling overhead.
func (*Graph) ETA ¶ added in v0.2.0
ETA estimates the remaining wall-clock time based on flush cadence. totalEpochs is the total expected number of Flush calls (epochs). Returns 0 if no flushes have occurred yet.
The estimate includes epoch 0's full duration — training start is recorded on the first Forward call, not the first Flush.
remaining := g.ETA(100)
fmt.Printf("ETA: %s\n", remaining)
func (*Graph) Elapsed ¶ added in v0.2.0
Elapsed returns the wall-clock time since training started (first Forward). Returns 0 if Forward has never been called.
func (*Graph) EnableProfiling ¶
func (g *Graph) EnableProfiling()
EnableProfiling turns on per-node and per-level timing for subsequent Forward calls. Overhead is ~20-50ns per node (two time.Now() calls). Call Graph.Profile after Forward to retrieve the results.
func (*Graph) ExportTimingTrends ¶
ExportTimingTrends writes timing epoch history to a CSV file. Same format as Graph.ExportTrends but uses timing data.
func (*Graph) ExportTrends ¶
ExportTrends writes epoch history to a CSV file. Columns are epoch number followed by one column per tag. Tag groups are expanded. If no tags are specified, all tags with history are exported.
g.ExportTrends("metrics.csv", "loss", "accuracy")
Output:
epoch,loss,accuracy
1,0.5432,0.7123
2,0.4321,0.7856
func (*Graph) Flush ¶
Flush promotes the mean of each tag's batch buffer to the epoch history, then clears the batch buffer. If no tags are specified, all buffered tags are flushed.
If OnFlush has been set, calls the hook with the flushed values (tag → epoch mean).
for epoch := range epochs {
for _, batch := range loader {
g.Forward(batch.Input)
g.Collect("loss")
}
g.Flush() // promotes batch mean → epoch history
}
func (*Graph) FlushCount ¶ added in v0.2.0
FlushCount returns the number of Flush calls that produced data. Each Flush typically corresponds to one training epoch.
func (*Graph) FlushTimings ¶
FlushTimings computes the mean of each tag's timing batch buffer, appends it to the timing epoch history, and clears the buffer. If specific tags are given, only those are flushed; otherwise all buffered tags are flushed.
func (*Graph) Forward ¶
Forward executes the graph, routing variables along edges between nodes. Nodes in the same topological level run concurrently via goroutines. Implements nn.Module.
For cancellation and timeout support, use Graph.ForwardCtx instead.
func (*Graph) ForwardCtx ¶
ForwardCtx executes the graph with a context for cancellation and timeouts. Loops (For, While, Until) and Maps check ctx between iterations; if the context is cancelled or its deadline exceeded, execution returns the context error. The context is also checked between topological levels.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
result := g.ForwardCtx(ctx, input)
func (*Graph) Freeze ¶
Freeze prevents gradient updates for parameters at the tagged node(s). After calling Freeze, subsequent Backward() calls will still compute gradients through frozen nodes (for upstream flow), but their parameter gradients are zeroed before optimizer.Step().
Call after Build():
g.Freeze("encoder") // freeze one tag
g.Freeze("embed", "norm") // freeze multiple tags
Use Unfreeze to re-enable gradient updates.
func (*Graph) LastFlushDuration ¶ added in v0.2.0
LastFlushDuration returns the wall-clock time between the two most recent Flush calls. Returns 0 if fewer than 2 flushes have occurred.
func (*Graph) Log ¶
Log prints the current values of the specified tagged nodes. If no tags are specified, all tagged values are printed. If OnLog has been set, calls the hook instead of printing.
g.Forward(input)
g.Log("loss", "accuracy") // loss: 0.2341 | accuracy: 0.891
func (*Graph) OnFlush ¶
OnFlush sets a custom handler called after each Flush with the epoch values (tag → mean). Pass nil to remove the hook.
g.OnFlush(func(flushed map[string]float64) {
fmt.Printf("epoch loss: %.4f\n", flushed["loss"])
})
func (*Graph) OnLog ¶
OnLog sets a custom handler for Log calls. Pass nil to restore the default (print to stderr).
g.OnLog(func(values map[string]*autograd.Variable) {
fmt.Printf("loss=%.4f\n", scalarValue(values["loss"]))
})
func (*Graph) OnProfile ¶
OnProfile sets a callback invoked after each Forward call when profiling is enabled. The callback receives the completed Profile. Set to nil to remove the hook.
g.OnProfile(func(p *graph.Profile) {
fmt.Print(p) // print the profile summary
})
func (*Graph) Parameters ¶
Parameters collects all learnable parameters from all nodes in the graph. Deduplicates by pointer identity so shared parameters (weight tying) are not counted twice.
func (*Graph) ParametersByTag ¶
ParametersByTag returns parameters belonging to the module at the tagged node. If the module is a sub-graph, its parameters are included recursively. Returns nil for unknown tags or nodes without parameters.
Use this for selective parameter freezing or per-group learning rates:
// Freeze encoder parameters after backward.
for _, p := range g.ParametersByTag("encoder") {
p.ZeroGrad()
}
optimizer.Step()
func (*Graph) PlotHTML ¶
PlotHTML generates a self-contained HTML file with training curves for the specified tags. Tag group names are expanded automatically. If no tags are specified, all tags with epoch history are plotted.
The generated file uses inline JavaScript with HTML5 Canvas — no external dependencies. Open it in any browser.
// Plot specific tags.
g.PlotHTML("training.html", "loss", "accuracy")
// Plot a tag group (expands to head_0, head_1, head_2).
g.PlotHTML("heads.html", "head")
// Plot everything.
g.PlotHTML("all.html")
func (*Graph) PlotTimingsHTML ¶
PlotTimingsHTML generates a self-contained HTML file with timing trend curves. Same as Graph.PlotHTML but uses timing epoch history (from Graph.CollectTimings / Graph.FlushTimings).
g.PlotTimingsHTML("timings.html", "encoder", "decoder")
func (*Graph) Profile ¶
Profile returns the timing data from the most recent Forward call, or nil if profiling is disabled or no Forward has been called.
func (*Graph) Record ¶ added in v0.2.0
Record injects external scalar values into the batch buffer for a tag. Use this for metrics computed outside the graph (e.g. losses that need both graph outputs and external targets).
Recorded values participate in the same Flush/Trend/Plot pipeline as Collected values — no separate handling needed.
g.Forward(input)
letterLoss := nn.CrossEntropyLoss(g.Tagged("logits"), targets)
g.Record("letter_ce", letterLoss.Item())
g.Record("hit_rate", hitRate)
g.Flush() // promotes both Collected and Recorded
g.Trend("letter_ce").Improving(5)
g.PlotHTML("training.html") // plots everything
func (*Graph) ResetState ¶
func (g *Graph) ResetState()
ResetState clears all forward-reference state buffers to nil. Call this when starting inference on a new sequence.
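For example, when running inference over independent sequences (sequences is illustrative):

for _, seq := range sequences {
	g.ResetState() // fresh state for each sequence
	for _, x := range seq {
		pred = g.Forward(x)
	}
}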
func (*Graph) ResetTimingTrend ¶
ResetTimingTrend clears the timing epoch history. If specific tags are given, only those are cleared; otherwise all timing history is cleared.
func (*Graph) ResetTrend ¶
ResetTrend clears the epoch history for the specified tags. If no tags are specified, all epoch history is cleared.
func (*Graph) SVG ¶
SVG renders the graph as SVG using the Graphviz dot command. Returns the SVG content as bytes. If a path is provided, the SVG is also written to that file (parent directories must exist).
Requires the dot binary (from Graphviz) to be installed and in PATH. Install: apt install graphviz (Ubuntu), brew install graphviz (macOS).
svg, _ := g.SVG() // just get the bytes
g.SVG("graph.svg") // write to file
g.SVG("docs/architecture.svg") // write to path
func (*Graph) SVGWithProfile ¶
SVGWithProfile renders a timing-annotated SVG using the most recent Profile. Nodes are color-coded by execution time. See Graph.DOTWithProfile for details.
g.EnableProfiling()
g.Forward(input)
g.SVGWithProfile("profile.svg")
func (*Graph) SetDevice ¶ added in v0.2.0
SetDevice moves all parameters and state buffers to the given device and records the device for automatic input migration during Forward.
SetDevice recurses into sub-graphs and composite modules (loops, switches, maps). Call before creating optimizers — optimizer state tensors are allocated lazily on first Step, matching the parameter device automatically.
if tensor.CUDAAvailable() {
g.SetDevice(tensor.CUDA)
}
optimizer := nn.NewAdam(g.Parameters(), 0.001)
func (*Graph) SetTraining ¶
SetTraining propagates training mode to all modules in the graph. Modules that implement nn.TrainToggler (e.g., Dropout, BatchNorm) will switch behavior. Walks into nn.SubModuler children and nested graphs recursively.
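Typical use around evaluation:

g.SetTraining(true)  // training mode: Dropout and BatchNorm switch behavior
// ... training loop ...
g.SetTraining(false) // evaluation mode
pred := g.Forward(testInput)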
func (*Graph) Sub ¶
Sub returns the sub-graph at a tagged node, or nil if the tag doesn't exist or the node isn't a Graph.
The sub-graph's Forward runs automatically as part of the parent's Forward (Graph-as-Module), so its tagged outputs are already populated. Use Sub to observe the sub-graph's internal metrics without a separate Forward call:
g.Forward(input) // runs everything
g.Sub("encoder").Collect("attention") // inner tag
func (*Graph) TagGroup ¶
TagGroup returns the member tags of a tag group, or nil if the group name is not registered.
func (*Graph) Tagged ¶
Tagged returns the output of a tagged node from the last Forward call. Returns nil if the tag is unknown or Forward hasn't been called.
Tagged values are captured automatically during Forward for every node that has a Tag. No explicit setup is needed.
func (*Graph) Timing ¶
Timing returns the execution duration of a tagged node from the most recent Forward call. Returns zero if profiling is disabled or the tag is not found.
func (*Graph) TimingTrend ¶
TimingTrend returns an epoch-level trend over the timing history of a tagged node. The trend values are mean execution times in seconds, one per flushed epoch. Supports the same queries as value trends: Slope, Stalled, Improving, Converged.
if g.TimingTrend("encoder").Slope(5) > 0.001 {
log.Println("encoder getting slower — possible memory issue")
}
func (*Graph) TimingTrends ¶
func (g *Graph) TimingTrends(tags ...string) TrendGroup
TimingTrends returns a TrendGroup for timing trends of the given tags, expanding any tag group names registered with FlowBuilder.TagGroup.
if g.TimingTrends("head").MeanSlope(5) > 0.001 {
fmt.Println("heads getting slower")
}
func (*Graph) Traces ¶
Traces returns the per-iteration side outputs collected from a nn.Traced loop body at the tagged node. Returns nil if the tag is unknown, the node isn't a loop, or the body doesn't implement Traced.
The slice contains one entry per iteration plus the initial state (captured after Reset, before the first iteration). For a loop with N iterations, Traces returns N+1 entries.
g.Forward(input)
locations := g.Traces("attention") // [initial, step1, step2, ...]
func (*Graph) Trend ¶
Trend returns the epoch-level trend for a tag — one data point per Flush call. Returns an empty Trend if no data has been flushed.
trend := g.Trend("loss")
if trend.Stalled(5, 1e-4) {
scheduler.Decay()
}
func (*Graph) Trends ¶
func (g *Graph) Trends(tags ...string) TrendGroup
Trends returns a TrendGroup for the given tags, expanding any tag group names registered with FlowBuilder.TagGroup.
// With TagGroup("head") → ["head_0", "head_1", "head_2"]:
if g.Trends("head").AllImproving(5) {
fmt.Println("all heads improving")
}
func (*Graph) WriteLog ¶ added in v0.2.0
WriteLog writes a human-readable training log to a text file. Each line shows one epoch with metric values, duration, and ETA. Tag group names are expanded. If no tags are specified, all tags with epoch history are included.
totalEpochs is the expected total number of epochs (for ETA calculation). If 0, ETA is omitted.
g.WriteLog("training.log", 100, "loss", "hit_rate", "lr")
Output:
# goDl training log — 2026-03-08T10:30:00Z
epoch    1  loss=0.5432 hit_rate=71.23% lr=0.001000 [1.2s ETA 1m58s]
epoch    2  loss=0.4321 hit_rate=78.56% lr=0.001000 [1.1s ETA 1m47s]
func (*Graph) ZeroFrozenGrads ¶
func (g *Graph) ZeroFrozenGrads()
ZeroFrozenGrads zeroes out gradients for all frozen parameters. Call between Backward() and optimizer.Step():
loss.Backward()
g.ZeroFrozenGrads()
optimizer.Step()
type LevelTiming ¶
type LevelTiming struct {
Index int // topological level index
WallClock time.Duration // wall-clock time for the entire level
SumNodes time.Duration // sum of all node durations in this level
NumNodes int // number of nodes in this level
}
LevelTiming records the execution time of a topological level. Multi-node levels execute in parallel via goroutines.
func (*LevelTiming) Parallelism ¶
func (lt *LevelTiming) Parallelism() float64
Parallelism returns the ratio of sequential node time to wall-clock time. Values above 1.0 indicate effective parallelism — a value of 2.5 means the level ran 2.5x faster than sequential execution. Returns 1.0 for single-node levels or when wall-clock is zero.
type LoopBuilder ¶
type LoopBuilder struct {
// contains filtered or unexported fields
}
LoopBuilder configures a loop construct in the graph flow. A loop repeats a body module, carrying state between iterations.
The body receives the current state as input and returns the new state. After all iterations, the final state continues downstream.
If Using refs are wired to the loop node, they are forwarded to the body at each iteration. For bodies implementing nn.NamedInputModule, refs are passed as a named map via ForwardNamed. For plain modules, refs are appended as extra positional arguments to Forward.
Three termination modes:
- For(n): fixed iteration count, always runs exactly n times
- While(cond, maxIter): condition checked before body (0..maxIter iterations)
- Until(cond, maxIter): condition checked after body (1..maxIter iterations)
While and Until use the same halt convention: the condition module receives the current state and returns a scalar. Positive (> 0) means halt. They differ only in timing — While can skip the body entirely, Until always runs it at least once.
Backward passes unroll automatically via autograd — each iteration builds its own computation graph, and the backward walk reverses through all of them (backpropagation through time).
func (*LoopBuilder) For ¶
func (lb *LoopBuilder) For(n int) *FlowBuilder
For sets the loop iteration count and wires the loop into the graph. Returns the FlowBuilder for continued chaining.
graph.From(encoder).
Loop(refinementStep).For(5).
Through(decoder).
Build()
func (*LoopBuilder) Until ¶
func (lb *LoopBuilder) Until(cond nn.Module, maxIter int) *FlowBuilder
Until repeats the body until the condition module signals halt, up to maxIter iterations. The body always executes at least once.
After each body execution, the condition module receives the state and returns a scalar. Iteration stops when the scalar is positive (> 0), indicating the halt condition is satisfied. The stop decision is non-differentiable — gradients flow through the body iterations.
The condition module's parameters are included in Parameters(), and SetTraining propagates to it.
graph.From(encoder).
Loop(refinement).Until(haltProbe, 20).
Through(decoder).
Build()
func (*LoopBuilder) While ¶
func (lb *LoopBuilder) While(cond nn.Module, maxIter int) *FlowBuilder
While repeats the body while the condition module says "continue", up to maxIter iterations. The condition is checked before each iteration — if it signals halt immediately, the body never runs and the input passes through unchanged.
The condition module receives the current state and returns a scalar. Positive (> 0) means halt — same convention as Until.
graph.From(encoder).
Loop(refine).While(graph.ThresholdHalt(100), 20).
Through(decoder).
Build()
type MapBuilder ¶
type MapBuilder struct {
// contains filtered or unexported fields
}
MapBuilder configures a map construct in the graph flow. A map slices a tensor along dim 0 and runs a body module on each element independently. Results are concatenated back along dim 0.
Two iteration sources:
- Over(tag): iterate over a tagged tensor (backward ref)
- Each(): iterate over the current stream
Additional Using refs are broadcast — every body invocation receives the same value. If the body implements nn.NamedInputModule, broadcast refs are passed by tag name via ForwardNamed.
graph.From(positionDecoder).Tag("positions").
Map(readHead).Over("positions").Using("image").
Through(decoder).
Build()
func (*MapBuilder) Batched ¶
func (mb *MapBuilder) Batched() *MapBuilder
Batched enables the batched fast path: instead of iterating element by element (Narrow+Cat), the entire source tensor is passed to the body module in one call. This is significantly faster for stateless bodies (Linear, activations, etc.) that handle batch dimensions natively.
Use only when the body module processes each batch element independently. Modules that normalize across the batch (e.g. BatchNorm) will produce different results in batched mode.
graph.From(encoder).
Map(nn.MustLinear(4, 4)).Batched().Each().
Build()
func (*MapBuilder) Each ¶
func (mb *MapBuilder) Each() *FlowBuilder
Each iterates over elements of the current stream (dim 0).
graph.From(positionDecoder).
Map(readHead).Each().Using("image").
Through(decoder).
Build()
func (*MapBuilder) Over ¶
func (mb *MapBuilder) Over(tag string) *FlowBuilder
Over sets the iteration source to a tagged tensor. The tag must reference a point already defined (backward ref). Each element along dim 0 is passed to the body as its stream input.
func (*MapBuilder) Slices ¶
func (mb *MapBuilder) Slices(n int) *FlowBuilder
Slices decomposes the last dimension of the current stream into n equal slices, applies the body to each, and recomposes the results. This handles dynamic batch sizes at runtime.
For input [B, D] with Slices(n): reshape [B, D] → [B*n, D/n], map body over B*n elements, reshape back to [B, D']. D must be divisible by n. D' depends on the body's output dimension.
graph.From(encoder).
Map(readHead(2)).Slices(4). // [B, 8] → 4 positions × [B, 2] → [B, 8]
Through(decoder).
Build()
type Node ¶
type Node struct {
// contains filtered or unexported fields
}
Node is a computation unit in a graph with named input/output ports.
type NodeTiming ¶
type NodeTiming struct {
ID string // internal node ID (e.g. "Linear_0")
Tag string // tag name, empty if untagged
Duration time.Duration // wall-clock time for node.run()
Level int // topological level index
}
NodeTiming records the execution time of a single node in a Forward pass.
type Profile ¶
type Profile struct {
Total time.Duration // total forward pass wall-clock time
Levels []LevelTiming // per-level timing (in execution order)
Nodes []NodeTiming // per-node timing (in execution order)
}
Profile holds timing data from a single Forward pass. Obtain it via Graph.Profile after a forward call with profiling enabled.
type Trend ¶
type Trend struct {
// contains filtered or unexported fields
}
Trend provides statistical queries over a time series of scalar values.
Typically obtained via g.Trend(tag), which returns the epoch-level history for a tag (one value per Flush call). Can also be created from any []float64 with NewTrend.
All query methods accept a window parameter: positive values limit the analysis to the last N data points; zero or negative uses all. Methods return safe zero-values for empty or insufficient data.
Common patterns:
trend := g.Trend("loss")
trend.Improving(5) // is loss decreasing over last 5 epochs?
trend.Stalled(10, 1e-4) // has loss stopped moving?
trend.Converged(5, 1e-5) // has loss stabilized?
trend.Slope(0) // linear trend over all epochs
func NewTrend ¶
NewTrend creates a Trend from a slice of values. The slice is not copied — the Trend is a view over the data.
func (*Trend) Converged ¶
Converged returns true if the variance over the last window values is below tolerance — the metric has stabilized. Returns false if fewer than 2 values are available.
func (*Trend) Improving ¶
Improving returns true if the slope over the last window values is negative — the metric is decreasing (good for loss). Returns false if fewer than 2 values are available.
func (*Trend) Latest ¶ added in v0.2.0
Latest returns the most recent value, or 0 if the series is empty.
func (*Trend) Slope ¶
Slope returns the OLS linear regression slope over the last window values. A negative slope means the values are decreasing. Returns 0 if fewer than 2 values are available. If window <= 0, uses all values.
type TrendGroup ¶
type TrendGroup []*Trend
TrendGroup is a collection of Trends for group queries. Obtained via Graph.Trends or Graph.TimingTrends, which expand tag groups registered with FlowBuilder.TagGroup.
if g.Trends("head").AllImproving(5) {
fmt.Println("all heads improving")
}
func (TrendGroup) AllConverged ¶
func (tg TrendGroup) AllConverged(window int, tol float64) bool
AllConverged returns true if every trend in the group has converged. Returns false for empty groups.
func (TrendGroup) AllImproving ¶
func (tg TrendGroup) AllImproving(window int) bool
AllImproving returns true if every trend in the group is improving. Returns false for empty groups.
func (TrendGroup) AllStalled ¶
func (tg TrendGroup) AllStalled(window int, tol float64) bool
AllStalled returns true if every trend in the group is stalled. Returns false for empty groups.
func (TrendGroup) AnyConverged ¶
func (tg TrendGroup) AnyConverged(window int, tol float64) bool
AnyConverged returns true if at least one trend in the group has converged.
func (TrendGroup) AnyImproving ¶
func (tg TrendGroup) AnyImproving(window int) bool
AnyImproving returns true if at least one trend in the group is improving.
func (TrendGroup) AnyStalled ¶
func (tg TrendGroup) AnyStalled(window int, tol float64) bool
AnyStalled returns true if at least one trend in the group is stalled.
func (TrendGroup) MeanSlope ¶
func (tg TrendGroup) MeanSlope(window int) float64
MeanSlope returns the average slope across all trends in the group. Returns 0 for empty groups.
func (TrendGroup) Slopes ¶
func (tg TrendGroup) Slopes(window int) []float64
Slopes returns the slope of each trend in the group.