fsdp

package
v1.38.1

Warning: This package is not in the latest version of its module.
Published: Mar 30, 2026 License: Apache-2.0 Imports: 14 Imported by: 0

Documentation

Overview

Package fsdp implements Fully Sharded Data Parallelism (FSDP) for distributed training.

Stability: alpha

FSDP shards model parameters across ranks so that each rank holds only 1/worldSize of each parameter, reducing per-GPU parameter memory proportionally. Before the forward pass, AllGather reconstructs the full parameters; after the backward pass, ReduceScatter aggregates gradient shards.
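The communication pattern described above can be sketched in plain Go. This is an illustration of the AllGather/ReduceScatter mechanics over ordinary slices, not the package's NCCL-backed implementation; the helper names are hypothetical.

```go
package main

import "fmt"

// allGather reconstructs the full parameter from per-rank shards.
// In real FSDP this is a NCCL AllGather; here it is plain concatenation.
func allGather(shards [][]float32) []float32 {
	var full []float32
	for _, s := range shards {
		full = append(full, s...)
	}
	return full
}

// reduceScatter sums the per-rank full gradients element-wise, then hands
// each rank only its own 1/worldSize slice of the result.
func reduceScatter(fullGrads [][]float32, worldSize int) [][]float32 {
	n := len(fullGrads[0])
	sum := make([]float32, n)
	for _, g := range fullGrads {
		for i, v := range g {
			sum[i] += v
		}
	}
	shardLen := n / worldSize
	out := make([][]float32, worldSize)
	for r := 0; r < worldSize; r++ {
		out[r] = sum[r*shardLen : (r+1)*shardLen]
	}
	return out
}

func main() {
	// worldSize = 2: each rank holds half of a 4-element parameter.
	shards := [][]float32{{1, 2}, {3, 4}}
	fmt.Println(allGather(shards)) // full parameter, reconstructed before forward

	// After backward each rank has a full gradient; ReduceScatter leaves
	// each rank with the summed gradient for its own shard only.
	grads := [][]float32{{1, 1, 1, 1}, {2, 2, 2, 2}}
	fmt.Println(reduceScatter(grads, 2))
}
```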

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func LoadCheckpoint

func LoadCheckpoint[T tensor.Numeric](path string, module *ShardedModule[T], rank int) error

LoadCheckpoint reads a GGUF checkpoint on rank 0 and distributes tensor shards to each rank's ShardedModule. Each rank receives its 1/worldSize slice of every parameter.

func SaveCheckpoint

func SaveCheckpoint[T tensor.Numeric](path string, module *ShardedModule[T], rank int) error

SaveCheckpoint gathers all parameter shards via AllGather and writes them as a GGUF v3 checkpoint file. Only rank 0 writes the file; other ranks participate in AllGather but do not perform I/O.

Types

type GradAccum

type GradAccum[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

GradAccum accumulates gradients across stepsPerSync micro-steps before triggering a synchronization (e.g., AllReduce, or ReduceScatter for FSDP). This allows effective batch sizes larger than what fits in GPU memory.

func NewGradAccum

func NewGradAccum[T tensor.Numeric](module *ShardedModule[T], stepsPerSync int) *GradAccum[T]

NewGradAccum creates a GradAccum that accumulates gradients for stepsPerSync micro-steps before triggering synchronization through the ShardedModule.

func (*GradAccum[T]) Accumulate

func (g *GradAccum[T]) Accumulate(grads map[string][]T) bool

Accumulate adds a set of gradients (keyed by parameter name) to the accumulator. Returns true when the accumulation window is full and Sync should be called.

func (*GradAccum[T]) Reset

func (g *GradAccum[T]) Reset()

Reset clears the accumulator and resets the step counter.

func (*GradAccum[T]) Sync

func (g *GradAccum[T]) Sync() map[string][]T

Sync returns the averaged accumulated gradients (sum / steps) and resets the accumulator. This should be called after Accumulate returns true.
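The Accumulate/Sync contract can be illustrated with a self-contained stand-in. This sketch is not the package's GradAccum (which syncs through a ShardedModule); it only reproduces the documented semantics, and the type name is hypothetical.

```go
package main

import "fmt"

// miniAccum mimics the GradAccum contract: Accumulate sums gradients and
// reports when the window is full; Sync returns sum/steps and resets.
type miniAccum struct {
	sums         map[string][]float32
	steps, limit int
}

func newMiniAccum(stepsPerSync int) *miniAccum {
	return &miniAccum{sums: map[string][]float32{}, limit: stepsPerSync}
}

// Accumulate adds gradients keyed by parameter name; returns true once
// stepsPerSync micro-steps have been accumulated.
func (a *miniAccum) Accumulate(grads map[string][]float32) bool {
	for name, g := range grads {
		if a.sums[name] == nil {
			a.sums[name] = make([]float32, len(g))
		}
		for i, v := range g {
			a.sums[name][i] += v
		}
	}
	a.steps++
	return a.steps >= a.limit
}

// Sync returns the averaged gradients (sum / steps) and resets the window.
func (a *miniAccum) Sync() map[string][]float32 {
	out := map[string][]float32{}
	for name, s := range a.sums {
		avg := make([]float32, len(s))
		for i, v := range s {
			avg[i] = v / float32(a.steps)
		}
		out[name] = avg
	}
	a.sums, a.steps = map[string][]float32{}, 0
	return out
}

func main() {
	acc := newMiniAccum(2) // sync every 2 micro-steps
	acc.Accumulate(map[string][]float32{"w": {1, 2}})
	if acc.Accumulate(map[string][]float32{"w": {3, 4}}) {
		fmt.Println(acc.Sync()["w"]) // averaged over the 2 micro-steps
	}
}
```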

type ShardedAdamW

type ShardedAdamW[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

ShardedAdamW implements AdamW where each rank only maintains optimizer state for its own parameter shard (ZeRO Stage 2). Moment tensors are sized 1/worldSize of the full parameter, so total optimizer memory scales as O(params/worldSize).
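The memory saving is easy to make concrete. This arithmetic sketch (the function and the model size are illustrative assumptions, not part of the package API) shows how the two float32 moment buffers shrink with worldSize:

```go
package main

import "fmt"

// adamWMomentBytes estimates moment-buffer memory for AdamW, which keeps
// two float32 moments (m and v) per parameter element. With ZeRO Stage 2
// sharding, each rank stores moments only for its 1/worldSize shard.
func adamWMomentBytes(numParams, worldSize int64) int64 {
	const bytesPerElem = 4 // float32
	const moments = 2      // first and second moment
	return numParams / worldSize * moments * bytesPerElem
}

func main() {
	params := int64(7_000_000_000)           // e.g. a 7B-parameter model (illustrative)
	fmt.Println(adamWMomentBytes(params, 1)) // replicated optimizer state
	fmt.Println(adamWMomentBytes(params, 8)) // sharded across 8 ranks: 1/8 as much
}
```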

func NewShardedAdamW

func NewShardedAdamW[T tensor.Numeric](rank, worldSize int, lr, beta1, beta2, eps, wd float32) *ShardedAdamW[T]

NewShardedAdamW creates a sharded AdamW optimizer for the given rank. Each rank maintains moment buffers only for its local parameter shard.

func (*ShardedAdamW[T]) MemoryBytes

func (o *ShardedAdamW[T]) MemoryBytes() int64

MemoryBytes returns the total memory used by moment buffers across all tracked parameter shards.

func (*ShardedAdamW[T]) Step

func (o *ShardedAdamW[T]) Step(shardGrads map[string][]T) map[string][]T

Step performs one AdamW update on the local parameter shards. shardGrads maps parameter names to gradient slices (already reduced/scattered to this rank's shard). Returns the resulting update deltas for each parameter shard.

func (*ShardedAdamW[T]) StepOnParams

func (o *ShardedAdamW[T]) StepOnParams(shardParams, shardGrads map[string][]T)

StepOnParams performs one AdamW update directly on parameter shard slices. This is the primary entry point for FSDP training: it updates shardParams in-place using the corresponding shardGrads.
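The update applied per shard is the standard bias-corrected AdamW rule with decoupled weight decay. A self-contained sketch of one such step (the function is an illustration of the math, not the package's StepOnParams, whose moment state lives inside ShardedAdamW):

```go
package main

import (
	"fmt"
	"math"
)

// adamwStep applies one bias-corrected AdamW update in place to a parameter
// shard, mirroring the StepOnParams contract: params is mutated, and the
// moment buffers m and v are per-shard state carried across steps.
func adamwStep(params, grads, m, v []float32, t int, lr, b1, b2, eps, wd float32) {
	bc1 := 1 - float32(math.Pow(float64(b1), float64(t))) // bias correction for m
	bc2 := 1 - float32(math.Pow(float64(b2), float64(t))) // bias correction for v
	for i := range params {
		m[i] = b1*m[i] + (1-b1)*grads[i]
		v[i] = b2*v[i] + (1-b2)*grads[i]*grads[i]
		mHat := m[i] / bc1
		vHat := v[i] / bc2
		// Adam step plus decoupled weight decay.
		params[i] -= lr * (mHat/(float32(math.Sqrt(float64(vHat)))+eps) + wd*params[i])
	}
}

func main() {
	params := []float32{1, 1}
	m := make([]float32, 2)
	v := make([]float32, 2)
	adamwStep(params, []float32{0.5, -0.5}, m, v, 1, 0.001, 0.9, 0.999, 1e-8, 0.01)
	fmt.Println(params) // each element nudged opposite its gradient's sign
}
```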

type ShardedModule

type ShardedModule[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

ShardedModule wraps a model and shards its parameters across worldSize devices. Before forward: AllGather reconstructs full parameter tensors. After backward: ReduceScatter aggregates gradient shards.

func NewShardedModule

func NewShardedModule[T tensor.Numeric](module training.Model[T], rank, worldSize int, comm *distributed.NCCLComm) *ShardedModule[T]

NewShardedModule creates a ShardedModule that shards all model parameters across worldSize ranks. Each rank retains only its 1/worldSize slice.

func (*ShardedModule[T]) Backward

func (s *ShardedModule[T]) Backward(ctx context.Context, grad *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)

Backward performs AllGather, runs the model backward pass, then ReduceScatters gradients so each rank holds its gradient shard.

func (*ShardedModule[T]) Forward

func (s *ShardedModule[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)

Forward performs AllGather on all parameters, runs the model forward pass, then restores sharded state.

func (*ShardedModule[T]) Parameters

func (s *ShardedModule[T]) Parameters() []*graph.Parameter[T]

Parameters returns the underlying model's parameters (in sharded state).

func (*ShardedModule[T]) ReplicatedMemoryBytes

func (s *ShardedModule[T]) ReplicatedMemoryBytes() int64

ReplicatedMemoryBytes returns the memory that would be used by fully replicated (unsharded) parameters.

func (*ShardedModule[T]) ShardMemoryBytes

func (s *ShardedModule[T]) ShardMemoryBytes() int64

ShardMemoryBytes returns the memory used by sharded parameters (per rank).
