Documentation
Overview ¶
Stability: alpha
Package fsdp implements Fully Sharded Data Parallelism (FSDP) for distributed training. It shards model parameters across ranks so each rank holds only 1/worldSize of each parameter, reducing per-GPU memory proportionally. Before forward, AllGather reconstructs full parameters; after backward, ReduceScatter aggregates gradient shards.
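The sharding layout can be sketched in a few lines. This is an illustrative, self-contained example of the 1/worldSize split described above, not package API; `shardRange` is a hypothetical helper and assumes the parameter length is padded to a multiple of worldSize, as FSDP implementations commonly do.

```go
package main

import "fmt"

// shardRange returns the [start, end) slice of a flat parameter of length n
// owned by the given rank when sharded evenly across worldSize ranks.
// Assumes n is a multiple of worldSize (padded if necessary).
func shardRange(n, rank, worldSize int) (start, end int) {
	per := n / worldSize
	return rank * per, (rank + 1) * per
}

func main() {
	const n, worldSize = 1 << 20, 8 // 1M-element parameter across 8 ranks
	s, e := shardRange(n, 3, worldSize)
	fmt.Printf("rank 3 owns [%d, %d): %d of %d elements\n", s, e, e-s, n)
}
```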
Index ¶
- func LoadCheckpoint[T tensor.Numeric](path string, module *ShardedModule[T], rank int) error
- func SaveCheckpoint[T tensor.Numeric](path string, module *ShardedModule[T], rank int) error
- type GradAccum
- type ShardedAdamW
- type ShardedModule
- func (s *ShardedModule[T]) Backward(ctx context.Context, grad *tensor.TensorNumeric[T], ...) ([]*tensor.TensorNumeric[T], error)
- func (s *ShardedModule[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (s *ShardedModule[T]) Parameters() []*graph.Parameter[T]
- func (s *ShardedModule[T]) ReplicatedMemoryBytes() int64
- func (s *ShardedModule[T]) ShardMemoryBytes() int64
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func LoadCheckpoint ¶
LoadCheckpoint reads a GGUF checkpoint on rank 0 and distributes tensor shards to each rank's ShardedModule. Each rank receives its 1/worldSize slice of every parameter.
func SaveCheckpoint ¶
SaveCheckpoint gathers all parameter shards via AllGather and writes them as a GGUF v3 checkpoint file. Only rank 0 writes the file; other ranks participate in AllGather but do not perform I/O.
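The distribution step of LoadCheckpoint can be pictured as slicing each fully materialized tensor into per-rank shards. The sketch below shows only that layout; `splitForRanks` is a hypothetical helper, and the real function additionally handles GGUF parsing and the rank-to-rank transfers.

```go
package main

import "fmt"

// splitForRanks slices a fully materialized parameter (as rank 0 would read
// it from the checkpoint) into the per-rank shards to be distributed.
// Assumes the length is a multiple of worldSize.
func splitForRanks(full []float32, worldSize int) [][]float32 {
	per := len(full) / worldSize
	shards := make([][]float32, worldSize)
	for r := 0; r < worldSize; r++ {
		shards[r] = full[r*per : (r+1)*per]
	}
	return shards
}

func main() {
	full := []float32{1, 2, 3, 4, 5, 6, 7, 8}
	for r, s := range splitForRanks(full, 4) {
		fmt.Printf("rank %d shard: %v\n", r, s)
	}
}
```

SaveCheckpoint is the inverse: AllGather concatenates the shards back into the full tensor before rank 0 writes it.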
Types ¶
type GradAccum ¶
GradAccum accumulates gradients across stepsPerSync micro-steps before triggering a synchronization (e.g., AllReduce, or ReduceScatter for FSDP). This enables effective batch sizes larger than what fits in GPU memory in a single step.
func NewGradAccum ¶
func NewGradAccum[T tensor.Numeric](module *ShardedModule[T], stepsPerSync int) *GradAccum[T]
NewGradAccum creates a GradAccum that accumulates gradients for stepsPerSync micro-steps before triggering synchronization through the ShardedModule.
func (*GradAccum[T]) Accumulate ¶
Accumulate adds a set of gradients (keyed by parameter name) to the accumulator. Returns true when the accumulation window is full and Sync should be called.
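The accumulation window behaves roughly as follows. This is a minimal, self-contained sketch of the logic, not the package's implementation: gradients are summed locally per parameter name, and the boolean result tells the caller when to synchronize.

```go
package main

import "fmt"

// gradAccum is an illustrative stand-in for GradAccum: it sums gradients
// keyed by parameter name for stepsPerSync micro-steps.
type gradAccum struct {
	sums         map[string][]float32
	stepsPerSync int
	step         int
}

// accumulate adds grads into the running sums and reports whether the
// accumulation window is full (i.e., it is time to synchronize).
func (g *gradAccum) accumulate(grads map[string][]float32) bool {
	for name, grad := range grads {
		sum, ok := g.sums[name]
		if !ok {
			sum = make([]float32, len(grad))
			g.sums[name] = sum
		}
		for i, v := range grad {
			sum[i] += v
		}
	}
	g.step++
	return g.step%g.stepsPerSync == 0
}

func main() {
	acc := &gradAccum{sums: map[string][]float32{}, stepsPerSync: 4}
	for micro := 1; micro <= 4; micro++ {
		full := acc.accumulate(map[string][]float32{"w": {1, 2}})
		fmt.Printf("micro-step %d: sync=%v\n", micro, full)
	}
}
```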
type ShardedAdamW ¶
ShardedAdamW implements AdamW where each rank maintains optimizer state only for its own parameter shard (ZeRO Stage 2). Moment tensors are sized 1/worldSize of the full parameter, so total optimizer memory scales as O(params/worldSize).
func NewShardedAdamW ¶
func NewShardedAdamW[T tensor.Numeric](rank, worldSize int, lr, beta1, beta2, eps, wd float32) *ShardedAdamW[T]
NewShardedAdamW creates a sharded AdamW optimizer for the given rank. Each rank maintains moment buffers only for its local parameter shard.
func (*ShardedAdamW[T]) MemoryBytes ¶
func (o *ShardedAdamW[T]) MemoryBytes() int64
MemoryBytes returns the total memory used by moment buffers across all tracked parameter shards.
func (*ShardedAdamW[T]) Step ¶
func (o *ShardedAdamW[T]) Step(shardGrads map[string][]T) map[string][]T
Step performs one AdamW update on the local parameter shards. shardGrads maps parameter names to gradient slices (already reduced and scattered to this rank's shard). Returns the update deltas for each local parameter shard.
func (*ShardedAdamW[T]) StepOnParams ¶
func (o *ShardedAdamW[T]) StepOnParams(shardParams, shardGrads map[string][]T)
StepOnParams performs one AdamW update directly on parameter shard slices. This is the primary entry point for FSDP training: it updates shardParams in-place using the corresponding shardGrads.
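The per-shard update itself is ordinary AdamW restricted to the local slice. Below is a self-contained sketch of one such update step, assuming the standard AdamW formulation with decoupled weight decay; `adamwShardStep` is illustrative, not the package API.

```go
package main

import (
	"fmt"
	"math"
)

// adamwShardStep applies one AdamW update in place to a single parameter
// shard. m and v are this rank's first/second moment buffers, sized like
// the shard (1/worldSize of the full parameter). t is the 1-based step count.
func adamwShardStep(params, grads, m, v []float32, t int, lr, beta1, beta2, eps, wd float32) {
	bc1 := 1 - float32(math.Pow(float64(beta1), float64(t))) // bias correction for m
	bc2 := 1 - float32(math.Pow(float64(beta2), float64(t))) // bias correction for v
	for i := range params {
		m[i] = beta1*m[i] + (1-beta1)*grads[i]
		v[i] = beta2*v[i] + (1-beta2)*grads[i]*grads[i]
		mHat := m[i] / bc1
		vHat := v[i] / bc2
		// Decoupled weight decay: applied to the parameter, not the gradient.
		params[i] -= lr * (mHat/(float32(math.Sqrt(float64(vHat)))+eps) + wd*params[i])
	}
}

func main() {
	params := []float32{1.0}
	grads := []float32{0.5}
	m := make([]float32, 1)
	v := make([]float32, 1)
	adamwShardStep(params, grads, m, v, 1, 1e-3, 0.9, 0.999, 1e-8, 0.01)
	fmt.Printf("updated param: %.6f\n", params[0])
}
```

Because each rank runs this only on its own shard, the moment buffers never materialize at full parameter size anywhere.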
type ShardedModule ¶
ShardedModule wraps a model and shards its parameters across worldSize devices. Before forward: AllGather reconstructs full parameter tensors. After backward: ReduceScatter aggregates gradient shards.
func NewShardedModule ¶
func NewShardedModule[T tensor.Numeric](module training.Model[T], rank, worldSize int, comm *distributed.NCCLComm) *ShardedModule[T]
NewShardedModule creates a ShardedModule that shards all model parameters across worldSize ranks. Each rank retains only its 1/worldSize slice.
func (*ShardedModule[T]) Backward ¶
func (s *ShardedModule[T]) Backward(ctx context.Context, grad *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward performs AllGather, runs the model backward pass, then reduce-scatters gradients so that each rank holds only its gradient shard.
func (*ShardedModule[T]) Forward ¶
func (s *ShardedModule[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward performs AllGather on all parameters, runs the model forward pass, then restores sharded state.
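The semantics of the two collectives that bracket forward and backward can be simulated in a single process. This sketch is illustrative only (real FSDP uses NCCL across devices): `allGather` concatenates every rank's shard into the full tensor, and `reduceScatter` sums per-rank gradients and returns only the slice a given rank owns.

```go
package main

import "fmt"

// allGather concatenates every rank's shard into the full tensor that each
// rank sees during forward/backward.
func allGather(shards [][]float32) []float32 {
	var full []float32
	for _, s := range shards {
		full = append(full, s...)
	}
	return full
}

// reduceScatter sums the full gradient from every rank, then returns only
// the slice owned by the given rank. Assumes length divides evenly.
func reduceScatter(fullGrads [][]float32, rank, worldSize int) []float32 {
	n := len(fullGrads[0])
	sum := make([]float32, n)
	for _, g := range fullGrads {
		for i, v := range g {
			sum[i] += v
		}
	}
	per := n / worldSize
	return sum[rank*per : (rank+1)*per]
}

func main() {
	shards := [][]float32{{1, 2}, {3, 4}} // 2 ranks, shard size 2
	fmt.Println("gathered:", allGather(shards))

	grads := [][]float32{{1, 1, 1, 1}, {2, 2, 2, 2}} // each rank's full gradient
	fmt.Println("rank 0 grad shard:", reduceScatter(grads, 0, 2))
}
```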
func (*ShardedModule[T]) Parameters ¶
func (s *ShardedModule[T]) Parameters() []*graph.Parameter[T]
Parameters returns the underlying model's parameters (in sharded state).
func (*ShardedModule[T]) ReplicatedMemoryBytes ¶
func (s *ShardedModule[T]) ReplicatedMemoryBytes() int64
ReplicatedMemoryBytes returns the memory that would be used by fully replicated (unsharded) parameters.
func (*ShardedModule[T]) ShardMemoryBytes ¶
func (s *ShardedModule[T]) ShardMemoryBytes() int64
ShardMemoryBytes returns the memory used by sharded parameters (per rank).
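The relationship between the two figures is a straight division by worldSize. The numbers below are a hypothetical illustration (a 7B-parameter float32 model on 8 ranks), not output of the package's methods.

```go
package main

import "fmt"

func main() {
	const (
		params    = int64(7_000_000_000) // hypothetical 7B-parameter model
		elemBytes = int64(4)             // float32
		worldSize = int64(8)
	)
	replicated := params * elemBytes // what ReplicatedMemoryBytes measures
	shard := replicated / worldSize  // what ShardMemoryBytes measures, per rank
	fmt.Printf("replicated: %d GiB\n", replicated>>30)
	fmt.Printf("per-rank shard: %d GiB\n", shard>>30)
}
```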