Documentation ¶
Overview ¶
Package distributed provides multi-node distributed training for the Zerfoo ML framework. It implements gradient synchronization, barrier coordination, and tensor broadcasting across a cluster of worker nodes. (Stability: beta)
Architecture ¶
The package is built around the InternalStrategy interface, which defines the collective operations required for distributed training: all-reduce, barrier, and broadcast. Two concrete implementations are provided:
GrpcStrategy uses gRPC for CPU-based gradient exchange over the network. Workers register with a coordinator, discover peers, and perform a star-topology all-reduce through bidirectional streaming RPCs.
[NcclStrategy] (build tag: cuda) uses NVIDIA NCCL for GPU-native collective operations. Gradient tensors remain on-device throughout the all-reduce with no CPU round-trip, delivering significantly higher throughput on multi-GPU nodes.
AllReduceStrategy composes two InternalStrategy instances into a hierarchical scheme: a local strategy handles intra-node communication (typically NCCL) while a cross-node strategy handles inter-node communication (typically gRPC). Node leaders participate in both layers.
Coordinator and Worker Lifecycle ¶
A coordinator process (defined in the distributed/pb protobuf service) manages worker registration, heartbeats, and peer discovery. Each worker follows this lifecycle:
- Create a WorkerNode or GrpcStrategy with the desired configuration.
- Call Init (or Start for WorkerNode), which registers with the coordinator, starts a local gRPC server, and connects to all peers.
- Use AllReduceGradients, Barrier, and BroadcastTensor during training.
- Call Shutdown (or Close) for orderly teardown: unregister from the coordinator, close peer connections, and stop the local gRPC server.
WorkerNode wraps GrpcStrategy with mutex-guarded lifecycle management, health check integration, and compatibility with shutdown.Coordinator.
gRPC Protocol ¶
The protobuf service (distributed/pb) defines three RPCs on the worker service:
AllReduce: bidirectional streaming. Each non-root worker sends its gradient tensors to the root (rank 0), which collects all submissions, computes the element-wise average, and streams the result back.
Barrier: unary RPC. Each worker calls Barrier on the root, which blocks all callers until every rank has arrived, then releases them.
Broadcast: unary RPC. The root sets a tensor via SetBroadcastTensor, and non-root workers retrieve it by calling Broadcast on the root.
A separate coordinator service handles RegisterWorker, UnregisterWorker, and Heartbeat RPCs for cluster membership.
NCCL Gradient Exchange ¶
When built with the cuda build tag, [NcclStrategy] provides GPU-native collectives via NCCL. It groups multiple tensor reductions into a single NCCL launch (ncclGroupStart/ncclGroupEnd) for efficiency, synchronizes on a dedicated CUDA stream, and implements barriers as zero-byte all-reduce operations. Use [NcclStrategy.InitWithUID] for single-process multi-GPU setups where the coordinator can distribute the NCCL UniqueID directly.
TLS ¶
TLSConfig provides optional TLS and mutual TLS (mTLS) for all gRPC connections, including coordinator registration and peer-to-peer traffic. When TLSConfig is nil, plaintext connections are used.
Metrics ¶
All strategies and the worker service emit Prometheus-compatible metrics (counters and histograms) through the ztensor metrics.Collector interface. Use SetCollector to wire in a concrete collector.
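A minimal sketch of the wiring pattern. The Collector interface and method below are hypothetical stand-ins — the real method set is defined by ztensor's metrics.Collector — and only the instrumentation pattern is the point:

```go
package main

import "fmt"

// Collector is a hypothetical stand-in for ztensor's metrics.Collector;
// the real interface's methods are defined in ztensor, not here.
type Collector interface {
	IncCounter(name string)
}

// memCollector counts events in memory, as a Prometheus counter would.
type memCollector struct{ counts map[string]int }

func (c *memCollector) IncCounter(name string) { c.counts[name]++ }

// instrumentedAllReduce shows the pattern: bump a counter around each
// collective so metrics reflect call volume.
func instrumentedAllReduce(c Collector) {
	c.IncCounter("distributed_allreduce_total")
	// ... actual gradient exchange would happen here ...
}

func main() {
	c := &memCollector{counts: map[string]int{}}
	for i := 0; i < 3; i++ {
		instrumentedAllReduce(c)
	}
	fmt.Println(c.counts["distributed_allreduce_total"])
}
```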
Index ¶
- func IsNCCLAvailable() bool
- func NCCLAllGather(comm *NCCLComm, sendbuf, recvbuf unsafe.Pointer, sendcount int, stream uintptr) error
- func NCCLAllGatherFloat64(comm *NCCLComm, sendbuf, recvbuf unsafe.Pointer, sendcount int, stream uintptr) error
- func NCCLGetUniqueID() ([ncclUniqueIDSize]byte, error)
- func NCCLReduceScatter(comm *NCCLComm, sendbuf, recvbuf unsafe.Pointer, recvcount int, stream uintptr) error
- func NCCLReduceScatterFloat64(comm *NCCLComm, sendbuf, recvbuf unsafe.Pointer, recvcount int, stream uintptr) error
- func NewWorkerService(rank, worldSize int32, logger log.Logger) *workerService
- type AllReduceStrategy
- func (s *AllReduceStrategy[T]) AllReduceGradients(gradients map[string]*tensor.TensorNumeric[T]) error
- func (s *AllReduceStrategy[T]) Barrier() error
- func (s *AllReduceStrategy[T]) BroadcastTensor(t *tensor.TensorNumeric[T], rootRank int) error
- func (s *AllReduceStrategy[T]) Close(_ context.Context) error
- func (s *AllReduceStrategy[T]) Init(rank, size int, coordinatorAddress string) error
- func (s *AllReduceStrategy[T]) Rank() int
- func (s *AllReduceStrategy[T]) SetCollector(c metrics.Collector)
- func (s *AllReduceStrategy[T]) Shutdown()
- func (s *AllReduceStrategy[T]) Size() int
- type CoordinatorClient
- type Dialer
- type GrpcServer
- type GrpcStrategy
- func (s *GrpcStrategy[T]) AllReduceGradients(gradients map[string]*tensor.TensorNumeric[T]) error
- func (s *GrpcStrategy[T]) Barrier() error
- func (s *GrpcStrategy[T]) BroadcastTensor(t *tensor.TensorNumeric[T], rootRank int) error
- func (s *GrpcStrategy[T]) Close(_ context.Context) error
- func (s *GrpcStrategy[T]) Init(rank, size int, coordinatorAddress string) error
- func (s *GrpcStrategy[T]) Rank() int
- func (s *GrpcStrategy[T]) Shutdown()
- func (s *GrpcStrategy[T]) Size() int
- type GrpcStrategyConfig
- type InternalStrategy
- type ListenerFactory
- type NCCLComm
- type NetworkManager
- type NumericStrategy
- type ServerManager
- type ServiceClientFactory
- type TLSConfig
- type WorkerNode
- type WorkerNodeConfig
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func IsNCCLAvailable ¶ added in v1.5.0
func IsNCCLAvailable() bool
IsNCCLAvailable checks if libnccl.so can be loaded via dlopen.
func NCCLAllGather ¶ added in v1.5.0
func NCCLAllGather(comm *NCCLComm, sendbuf, recvbuf unsafe.Pointer, sendcount int, stream uintptr) error
NCCLAllGather gathers sendbuf from all ranks into recvbuf. recvbuf must be nRanks * len(sendbuf) elements. The stream parameter is a cudaStream_t (use 0 for the default stream).
func NCCLAllGatherFloat64 ¶ added in v1.5.0
func NCCLAllGatherFloat64(comm *NCCLComm, sendbuf, recvbuf unsafe.Pointer, sendcount int, stream uintptr) error
NCCLAllGatherFloat64 is like NCCLAllGather but for float64 data.
func NCCLGetUniqueID ¶ added in v1.5.0
func NCCLGetUniqueID() ([ncclUniqueIDSize]byte, error)
NCCLGetUniqueID generates a new unique ID for communicator initialization. Exactly one rank should call this and distribute the result to all other ranks.
func NCCLReduceScatter ¶ added in v1.5.0
func NCCLReduceScatter(comm *NCCLComm, sendbuf, recvbuf unsafe.Pointer, recvcount int, stream uintptr) error
NCCLReduceScatter reduces across ranks and scatters the result. recvbuf gets sendcount/nRanks elements per rank. The stream parameter is a cudaStream_t (use 0 for the default stream).
func NCCLReduceScatterFloat64 ¶ added in v1.5.0
func NCCLReduceScatterFloat64(comm *NCCLComm, sendbuf, recvbuf unsafe.Pointer, recvcount int, stream uintptr) error
NCCLReduceScatterFloat64 is like NCCLReduceScatter but for float64 data.
func NewWorkerService ¶ added in v0.2.1
func NewWorkerService(rank, worldSize int32, logger log.Logger) *workerService
NewWorkerService creates a new workerService.
Types ¶
type AllReduceStrategy ¶
AllReduceStrategy composes a local and a cross-node InternalStrategy into a hierarchical all-reduce.
func NewAllReduceStrategy ¶
func NewAllReduceStrategy[T tensor.Numeric](localStrategy, crossNodeStrategy InternalStrategy[T]) *AllReduceStrategy[T]
NewAllReduceStrategy creates a new AllReduceStrategy.
func (*AllReduceStrategy[T]) AllReduceGradients ¶
func (s *AllReduceStrategy[T]) AllReduceGradients(gradients map[string]*tensor.TensorNumeric[T]) error
AllReduceGradients performs hierarchical all-reduce on gradients.
func (*AllReduceStrategy[T]) Barrier ¶
func (s *AllReduceStrategy[T]) Barrier() error
Barrier synchronizes all workers across all nodes.
func (*AllReduceStrategy[T]) BroadcastTensor ¶
func (s *AllReduceStrategy[T]) BroadcastTensor(t *tensor.TensorNumeric[T], rootRank int) error
BroadcastTensor broadcasts a tensor from the root rank to all other ranks in the distributed system. The tensor is first broadcast within the root's local node, then across node leaders, and finally within each local node to ensure all ranks receive the broadcasted tensor.
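The three stages can be traced with plain ints in place of tensors; `hierarchicalBroadcast` is an illustrative helper, not a package function, with rank 0 of each node acting as its leader and node 0 holding the global root:

```go
package main

import "fmt"

// hierarchicalBroadcast copies a value from the global root to every rank
// in three stages: within the root's node, across node leaders, then
// within each local node.
func hierarchicalBroadcast(nodes [][]int, value int) {
	// Stage 1: within the root's local node.
	for r := range nodes[0] {
		nodes[0][r] = value
	}
	// Stage 2: across node leaders (rank 0 of every node).
	for n := range nodes {
		nodes[n][0] = nodes[0][0]
	}
	// Stage 3: within each local node, leader to the remaining ranks.
	for n := range nodes {
		for r := range nodes[n] {
			nodes[n][r] = nodes[n][0]
		}
	}
}

func main() {
	nodes := [][]int{{0, 0}, {0, 0}}
	hierarchicalBroadcast(nodes, 42)
	fmt.Println(nodes) // every rank now holds 42
}
```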
func (*AllReduceStrategy[T]) Close ¶ added in v0.2.1
func (s *AllReduceStrategy[T]) Close(_ context.Context) error
Close satisfies the shutdown.Closer interface by calling Shutdown.
func (*AllReduceStrategy[T]) Init ¶
func (s *AllReduceStrategy[T]) Init(rank, size int, coordinatorAddress string) error
Init initializes the hierarchical strategy.
func (*AllReduceStrategy[T]) Rank ¶
func (s *AllReduceStrategy[T]) Rank() int
Rank returns the rank from the local strategy.
func (*AllReduceStrategy[T]) SetCollector ¶ added in v0.2.1
func (s *AllReduceStrategy[T]) SetCollector(c metrics.Collector)
SetCollector replaces the strategy's metrics collector.
func (*AllReduceStrategy[T]) Shutdown ¶
func (s *AllReduceStrategy[T]) Shutdown()
Shutdown gracefully closes all connections.
func (*AllReduceStrategy[T]) Size ¶
func (s *AllReduceStrategy[T]) Size() int
Size returns the size from the local strategy.
type CoordinatorClient ¶
type CoordinatorClient interface {
RegisterWorker(ctx context.Context, in *pb.RegisterWorkerRequest, opts ...grpc.CallOption) (*pb.RegisterWorkerResponse, error)
UnregisterWorker(ctx context.Context, in *pb.UnregisterWorkerRequest, opts ...grpc.CallOption) (*pb.UnregisterWorkerResponse, error)
Heartbeat(ctx context.Context, in *pb.HeartbeatRequest, opts ...grpc.CallOption) (*pb.HeartbeatResponse, error)
}
CoordinatorClient is an interface for a client of the coordinator service.
type GrpcServer ¶
type GrpcServer interface {
RegisterService(desc *grpc.ServiceDesc, impl interface{})
Serve(lis net.Listener) error
Stop()
GracefulStop()
}
GrpcServer is an interface for a gRPC server.
type GrpcStrategy ¶ added in v0.2.1
GrpcStrategy implements InternalStrategy[T] using gRPC transport. It connects to the coordinator for registration, starts a local gRPC server (workerService) for incoming RPCs, and connects to peers for outgoing RPCs.
func NewGrpcStrategy ¶ added in v0.2.1
func NewGrpcStrategy[T tensor.Numeric](cfg GrpcStrategyConfig) *GrpcStrategy[T]
NewGrpcStrategy creates a new GrpcStrategy with the given configuration.
func (*GrpcStrategy[T]) AllReduceGradients ¶ added in v0.2.1
func (s *GrpcStrategy[T]) AllReduceGradients(gradients map[string]*tensor.TensorNumeric[T]) error
AllReduceGradients performs a star-topology all-reduce. Root (rank 0) collects gradients from all peers, averages them, and sends the result back. Non-root workers send gradients to root and receive the averaged result.
func (*GrpcStrategy[T]) Barrier ¶ added in v0.2.1
func (s *GrpcStrategy[T]) Barrier() error
Barrier synchronizes all workers via the root's barrier service.
func (*GrpcStrategy[T]) BroadcastTensor ¶ added in v0.2.1
func (s *GrpcStrategy[T]) BroadcastTensor(t *tensor.TensorNumeric[T], rootRank int) error
BroadcastTensor broadcasts a tensor from rootRank to all other workers.
func (*GrpcStrategy[T]) Close ¶ added in v0.2.1
func (s *GrpcStrategy[T]) Close(_ context.Context) error
Close satisfies the shutdown.Closer interface.
func (*GrpcStrategy[T]) Init ¶ added in v0.2.1
func (s *GrpcStrategy[T]) Init(rank, size int, coordinatorAddress string) error
Init registers with the coordinator, starts the local gRPC server, and connects to all peers.
func (*GrpcStrategy[T]) Rank ¶ added in v0.2.1
func (s *GrpcStrategy[T]) Rank() int
Rank returns the worker's rank.
func (*GrpcStrategy[T]) Shutdown ¶ added in v0.2.1
func (s *GrpcStrategy[T]) Shutdown()
Shutdown gracefully shuts down the strategy.
func (*GrpcStrategy[T]) Size ¶ added in v0.2.1
func (s *GrpcStrategy[T]) Size() int
Size returns the total number of workers.
type GrpcStrategyConfig ¶ added in v0.2.1
type GrpcStrategyConfig struct {
WorkerAddress string
WorkerID string
ServerManager ServerManager
NetworkManager NetworkManager
Dialer Dialer
Logger log.Logger
Collector metrics.Collector
TLS *TLSConfig
}
GrpcStrategyConfig holds configuration for creating a GrpcStrategy.
type InternalStrategy ¶
type InternalStrategy[T tensor.Numeric] interface {
	// Init initializes the strategy.
	Init(rank int, size int, coordinatorAddress string) error
	// AllReduceGradients performs an all-reduce operation on the gradients.
	AllReduceGradients(gradients map[string]*tensor.TensorNumeric[T]) error
	// Barrier blocks until all workers have reached the barrier.
	Barrier() error
	// BroadcastTensor broadcasts a tensor from the root to all other workers.
	BroadcastTensor(t *tensor.TensorNumeric[T], rootRank int) error
	// Rank returns the rank of the current worker.
	Rank() int
	// Size returns the total number of workers.
	Size() int
	// Shutdown cleans up the resources used by the strategy.
	Shutdown()
}
InternalStrategy defines the interface for a distributed training strategy.
type ListenerFactory ¶
ListenerFactory is a function that creates a new net.Listener.
type NCCLComm ¶ added in v1.5.0
type NCCLComm struct {
// contains filtered or unexported fields
}
NCCLComm wraps an NCCL communicator loaded via purego dlopen. Unlike internal/nccl (which uses CGo), this implementation has zero CGo overhead and does not require the cuda build tag.
func NCCLCommInitRank ¶ added in v1.5.0
NCCLCommInitRank initializes a communicator for rank in an nRanks-size group. The id parameter is a 128-byte unique ID that must be the same across all ranks (generated by NCCLGetUniqueID on one rank and distributed to others).
type NetworkManager ¶
type NetworkManager interface {
// ConnectToPeers establishes connections to all other workers in the cluster.
ConnectToPeers(peers []string, selfRank int, timeout time.Duration) ([]pb.DistributedServiceClient, []*grpc.ClientConn, error)
// CloseConnections closes all the given connections.
CloseConnections(conns []*grpc.ClientConn)
}
NetworkManager is an interface for managing network connections between workers.
func NewNetworkManager ¶
func NewNetworkManager(dialer Dialer, clientFactory ServiceClientFactory) NetworkManager
NewNetworkManager creates a new NetworkManager.
type NumericStrategy ¶ added in v0.2.1
type NumericStrategy[T tensor.Numeric] = InternalStrategy[T]
NumericStrategy is a generic type alias of InternalStrategy for external use.
type ServerManager ¶
type ServerManager interface {
Start(workerAddress string, service interface{}, serviceDesc *grpc.ServiceDesc) error
Stop()
GracefulStop()
SetLogger(logger log.Logger)
}
ServerManager is an interface for managing the gRPC server of a worker.
func NewServerManager ¶
func NewServerManager(grpcServer GrpcServer, listenerFactory ListenerFactory) ServerManager
NewServerManager creates a new ServerManager.
type ServiceClientFactory ¶
type ServiceClientFactory func(cc *grpc.ClientConn) pb.DistributedServiceClient
ServiceClientFactory is a function that creates a new DistributedServiceClient.
type TLSConfig ¶ added in v0.2.1
type TLSConfig struct {
// CACertPath is the path to the CA certificate for verifying peers.
CACertPath string
// CertPath is the path to the server or client certificate.
CertPath string
// KeyPath is the path to the private key for the certificate.
KeyPath string
}
TLSConfig holds TLS certificate paths for gRPC connections. When nil, plaintext connections are used (for local development).
func (*TLSConfig) ClientCredentials ¶ added in v0.2.1
func (tc *TLSConfig) ClientCredentials() (credentials.TransportCredentials, error)
ClientCredentials returns gRPC transport credentials for a TLS client. If tc is nil, returns nil (plaintext mode).
func (*TLSConfig) ServerCredentials ¶ added in v0.2.1
func (tc *TLSConfig) ServerCredentials() (credentials.TransportCredentials, error)
ServerCredentials returns gRPC transport credentials for a TLS server. If tc is nil, returns nil (plaintext mode).
type WorkerNode ¶ added in v0.2.1
type WorkerNode struct {
// contains filtered or unexported fields
}
WorkerNode encapsulates a distributed training worker. It manages the gRPC strategy, server, and network connections, and provides orderly startup and shutdown semantics compatible with shutdown.Coordinator.
func NewWorkerNode ¶ added in v0.2.1
func NewWorkerNode(cfg WorkerNodeConfig) *WorkerNode
NewWorkerNode creates a new WorkerNode with the given configuration.
func (*WorkerNode) Close ¶ added in v0.2.1
func (wn *WorkerNode) Close(_ context.Context) error
Close shuts down the worker node. It satisfies the shutdown.Closer interface. Calling Close on an unstarted or already-closed node is safe.
func (*WorkerNode) Rank ¶ added in v0.2.1
func (wn *WorkerNode) Rank() int
Rank returns the worker's rank, or -1 if not started.
func (*WorkerNode) Size ¶ added in v0.2.1
func (wn *WorkerNode) Size() int
Size returns the total number of workers, or 0 if not started.
func (*WorkerNode) Start ¶ added in v0.2.1
func (wn *WorkerNode) Start(_ context.Context) error
Start initializes the distributed worker: creates a gRPC server and strategy, registers with the coordinator, connects to peers, and optionally registers a health check. The context is used only for cancellation of the start sequence, not the lifetime of the worker.
func (*WorkerNode) Strategy ¶ added in v0.2.1
func (wn *WorkerNode) Strategy() InternalStrategy[float32]
Strategy returns the underlying InternalStrategy, or nil if not started.
Source Files ¶
Directories ¶
| Path | Synopsis |
|---|---|
| coordinator | Package coordinator provides a distributed training coordinator. |
| fsdp | Package fsdp implements Fully Sharded Data Parallelism for distributed training. |