distributed

package
v1.38.1 Latest
Published: Mar 30, 2026 License: Apache-2.0 Imports: 22 Imported by: 0

Documentation

Overview

Package distributed provides multi-node distributed training for the Zerfoo ML framework. It implements gradient synchronization, barrier coordination, and tensor broadcasting across a cluster of worker nodes. (Stability: beta)

Architecture

The package is built around the InternalStrategy interface, which defines the collective operations required for distributed training: all-reduce, barrier, and broadcast. Two concrete implementations are provided:

  • GrpcStrategy uses gRPC for CPU-based gradient exchange over the network. Workers register with a coordinator, discover peers, and perform a star-topology all-reduce through bidirectional streaming RPCs.

  • NcclStrategy (build tag: cuda) uses NVIDIA NCCL for GPU-native collective operations. Gradient tensors remain on-device throughout the all-reduce with no CPU round-trip, delivering significantly higher throughput on multi-GPU nodes.

AllReduceStrategy composes two InternalStrategy instances into a hierarchical scheme: a local strategy handles intra-node communication (typically NCCL) while a cross-node strategy handles inter-node communication (typically gRPC). Node leaders participate in both layers.
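The two-layer scheme can be illustrated with a plain-Go sketch. The helper below is illustrative only (the real strategies exchange tensor maps over gRPC or NCCL, not scalars in a slice), but the arithmetic is the same: local reduce, cross-node reduce among leaders, then broadcast back down.

```go
package main

import "fmt"

// hierarchicalAllReduce averages one gradient value per rank in two layers,
// mirroring AllReduceStrategy's scheme: consecutive ranks form a node, each
// node reduces locally, node leaders reduce across nodes, and the global
// average is broadcast back to every rank.
func hierarchicalAllReduce(grads []float64, ranksPerNode int) []float64 {
	n := len(grads)
	// Layer 1: local (intra-node) reduce — each node computes its local sum.
	var nodeSums []float64
	for start := 0; start < n; start += ranksPerNode {
		end := start + ranksPerNode
		if end > n {
			end = n
		}
		sum := 0.0
		for _, g := range grads[start:end] {
			sum += g
		}
		nodeSums = append(nodeSums, sum)
	}
	// Layer 2: cross-node reduce among node leaders.
	total := 0.0
	for _, s := range nodeSums {
		total += s
	}
	avg := total / float64(n)
	// Broadcast the global average back to every rank.
	out := make([]float64, n)
	for i := range out {
		out[i] = avg
	}
	return out
}

func main() {
	// 4 ranks on 2 nodes; the global average of 1, 2, 3, 4 is 2.5.
	fmt.Println(hierarchicalAllReduce([]float64{1, 2, 3, 4}, 2))
}
```

The split matters for bandwidth: only one rank per node (the leader) touches the slower cross-node layer, while the rest communicate over the fast intra-node layer.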

Coordinator and Worker Lifecycle

A coordinator process (defined in the distributed/pb protobuf service) manages worker registration, heartbeats, and peer discovery. Each worker follows this lifecycle:

  1. Create a WorkerNode or GrpcStrategy with the desired configuration.
  2. Call Init (or Start for WorkerNode), which registers with the coordinator, starts a local gRPC server, and connects to all peers.
  3. Use AllReduceGradients, Barrier, and BroadcastTensor during training.
  4. Call Shutdown (or Close) for orderly teardown: unregister from the coordinator, close peer connections, and stop the local gRPC server.

WorkerNode wraps GrpcStrategy with mutex-guarded lifecycle management, health check integration, and compatibility with shutdown.Coordinator.

gRPC Protocol

The protobuf service (distributed/pb) defines three RPCs on the worker service:

  • AllReduce: bidirectional streaming. Each non-root worker sends its gradient tensors to the root (rank 0), which collects all submissions, computes the element-wise average, and streams the result back.

  • Barrier: unary RPC. Each worker calls Barrier on the root, which blocks all callers until every rank has arrived, then releases them.

  • Broadcast: unary RPC. The root sets a tensor via SetBroadcastTensor, and non-root workers retrieve it by calling Broadcast on the root.
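The Barrier semantics above can be reproduced in-process with standard-library primitives. This stand-in is not the package's RPC implementation, but the root-side behavior is the same: every caller blocks until worldSize callers have arrived, then all are released together.

```go
package main

import (
	"fmt"
	"sync"
)

// barrier releases all waiters at once when the last of `world` callers
// arrives, mirroring the root's Barrier RPC handler.
type barrier struct {
	mu      sync.Mutex
	arrived int
	world   int
	release chan struct{}
}

func newBarrier(worldSize int) *barrier {
	return &barrier{world: worldSize, release: make(chan struct{})}
}

func (b *barrier) wait() {
	b.mu.Lock()
	b.arrived++
	if b.arrived == b.world {
		close(b.release) // last rank in: release everyone
		b.mu.Unlock()
		return
	}
	b.mu.Unlock()
	<-b.release // block until the last rank closes the channel
}

func main() {
	const world = 4
	b := newBarrier(world)
	var wg sync.WaitGroup
	for rank := 0; rank < world; rank++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			b.wait() // no goroutine passes until all four arrive
		}()
	}
	wg.Wait()
	fmt.Println("all ranks released")
}
```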

A separate coordinator service handles RegisterWorker, UnregisterWorker, and Heartbeat RPCs for cluster membership.

NCCL Gradient Exchange

When built with the cuda build tag, NcclStrategy provides GPU-native collectives via NCCL. It groups multiple tensor reductions into a single NCCL launch (ncclGroupStart/ncclGroupEnd) for efficiency, synchronizes on a dedicated CUDA stream, and implements barriers as zero-byte all-reduce operations. Use NcclStrategy.InitWithUID for single-process multi-GPU setups where the coordinator can distribute the NCCL UniqueID directly.

TLS

TLSConfig provides optional TLS and mutual TLS (mTLS) for all gRPC connections, including coordinator registration and peer-to-peer traffic. When TLSConfig is nil, plaintext connections are used.
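A stdlib-only sketch of the client-side setup that TLSConfig's ClientCredentials has to perform: load a client certificate for mTLS and a CA pool for verifying the peer. The helper name loadClientTLS is hypothetical; the real package wraps the resulting config in gRPC transport credentials.

```go
package main

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"os"
)

// loadClientTLS builds a *tls.Config for mutual TLS from the three paths
// TLSConfig holds: a CA certificate for verifying the server, plus the
// client's own certificate and private key.
func loadClientTLS(caPath, certPath, keyPath string) (*tls.Config, error) {
	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		return nil, fmt.Errorf("load client cert: %w", err)
	}
	caPEM, err := os.ReadFile(caPath)
	if err != nil {
		return nil, fmt.Errorf("read CA cert: %w", err)
	}
	pool := x509.NewCertPool()
	if !pool.AppendCertsFromPEM(caPEM) {
		return nil, fmt.Errorf("no valid CA certificates in %s", caPath)
	}
	return &tls.Config{
		Certificates: []tls.Certificate{cert},
		RootCAs:      pool,
	}, nil
}

func main() {
	// Without real certificate files this fails cleanly — the kind of error
	// a worker should surface at startup rather than at dial time.
	if _, err := loadClientTLS("ca.pem", "client.pem", "client-key.pem"); err != nil {
		fmt.Println("TLS setup error:", err)
	}
}
```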

Metrics

All strategies and the worker service emit Prometheus-compatible metrics (counters and histograms) through the ztensor metrics.Collector interface. Use SetCollector to wire in a concrete collector.

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func IsNCCLAvailable added in v1.5.0

func IsNCCLAvailable() bool

IsNCCLAvailable checks if libnccl.so can be loaded via dlopen.

func NCCLAllGather added in v1.5.0

func NCCLAllGather(comm *NCCLComm, sendbuf, recvbuf unsafe.Pointer, sendcount int, stream uintptr) error

NCCLAllGather gathers sendbuf from all ranks into recvbuf. recvbuf must hold nRanks * sendcount elements. The stream parameter is a cudaStream_t (use 0 for the default stream).

func NCCLAllGatherFloat64 added in v1.5.0

func NCCLAllGatherFloat64(comm *NCCLComm, sendbuf, recvbuf unsafe.Pointer, sendcount int, stream uintptr) error

NCCLAllGatherFloat64 is like NCCLAllGather but for float64 data.

func NCCLGetUniqueID added in v1.5.0

func NCCLGetUniqueID() ([ncclUniqueIDSize]byte, error)

NCCLGetUniqueID generates a new unique ID for communicator initialization. Exactly one rank should call this and distribute the result to all other ranks.

func NCCLReduceScatter added in v1.5.0

func NCCLReduceScatter(comm *NCCLComm, sendbuf, recvbuf unsafe.Pointer, recvcount int, stream uintptr) error

NCCLReduceScatter reduces across ranks and scatters the result. recvbuf receives recvcount elements; sendbuf must contain nRanks * recvcount elements. The stream parameter is a cudaStream_t (use 0 for the default stream).

func NCCLReduceScatterFloat64 added in v1.5.0

func NCCLReduceScatterFloat64(comm *NCCLComm, sendbuf, recvbuf unsafe.Pointer, recvcount int, stream uintptr) error

NCCLReduceScatterFloat64 is like NCCLReduceScatter but for float64 data.

func NewWorkerService added in v0.2.1

func NewWorkerService(rank, worldSize int32, logger log.Logger) *workerService

NewWorkerService creates a new workerService.

Types

type AllReduceStrategy

type AllReduceStrategy[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

AllReduceStrategy implements hierarchical all-reduce by composing a local (intra-node) strategy with a cross-node strategy.

func NewAllReduceStrategy

func NewAllReduceStrategy[T tensor.Numeric](
	localStrategy, crossNodeStrategy InternalStrategy[T],
) *AllReduceStrategy[T]

NewAllReduceStrategy creates a new AllReduceStrategy.

func (*AllReduceStrategy[T]) AllReduceGradients

func (s *AllReduceStrategy[T]) AllReduceGradients(gradients map[string]*tensor.TensorNumeric[T]) error

AllReduceGradients performs hierarchical all-reduce on gradients.

func (*AllReduceStrategy[T]) Barrier

func (s *AllReduceStrategy[T]) Barrier() error

Barrier synchronizes all workers across all nodes.

func (*AllReduceStrategy[T]) BroadcastTensor

func (s *AllReduceStrategy[T]) BroadcastTensor(t *tensor.TensorNumeric[T], rootRank int) error

BroadcastTensor broadcasts a tensor from the root rank to all other ranks in the distributed system. The tensor is first broadcast within the root's local node, then across node leaders, and finally within each local node to ensure all ranks receive the broadcasted tensor.
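The three phases can be traced with a small simulation over per-rank scalar values (an illustrative stand-in for tensors; the function name and grouping are hypothetical, with consecutive ranks forming a node):

```go
package main

import "fmt"

// broadcastHierarchical mirrors BroadcastTensor's phases: (1) broadcast
// within the root's local node, (2) broadcast across node leaders (the
// first rank of each node), (3) broadcast within every node from its leader.
func broadcastHierarchical(vals []float64, rootRank, ranksPerNode int) {
	node := func(rank int) int { return rank / ranksPerNode }
	leader := func(n int) int { return n * ranksPerNode }
	v := vals[rootRank]
	// Phase 1: within the root's local node.
	rootNode := node(rootRank)
	for r := range vals {
		if node(r) == rootNode {
			vals[r] = v
		}
	}
	// Phase 2: across node leaders.
	for n := 0; leader(n) < len(vals); n++ {
		vals[leader(n)] = v
	}
	// Phase 3: within each local node, from its leader.
	for r := range vals {
		vals[r] = vals[leader(node(r))]
	}
}

func main() {
	vals := []float64{0, 0, 0, 0, 0, 7}
	broadcastHierarchical(vals, 5, 2) // rank 5 (on node 2) holds the value
	fmt.Println(vals)                 // every rank now holds 7
}
```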

func (*AllReduceStrategy[T]) Close added in v0.2.1

func (s *AllReduceStrategy[T]) Close(_ context.Context) error

Close satisfies the shutdown.Closer interface by calling Shutdown.

func (*AllReduceStrategy[T]) Init

func (s *AllReduceStrategy[T]) Init(rank, size int, coordinatorAddress string) error

Init initializes the hierarchical strategy.

func (*AllReduceStrategy[T]) Rank

func (s *AllReduceStrategy[T]) Rank() int

Rank returns the rank from the local strategy.

func (*AllReduceStrategy[T]) SetCollector added in v0.2.1

func (s *AllReduceStrategy[T]) SetCollector(c metrics.Collector)

SetCollector replaces the strategy's metrics collector.

func (*AllReduceStrategy[T]) Shutdown

func (s *AllReduceStrategy[T]) Shutdown()

Shutdown gracefully closes all connections.

func (*AllReduceStrategy[T]) Size

func (s *AllReduceStrategy[T]) Size() int

Size returns the size from the local strategy.

type CoordinatorClient

type CoordinatorClient interface {
	RegisterWorker(ctx context.Context, in *pb.RegisterWorkerRequest, opts ...grpc.CallOption) (*pb.RegisterWorkerResponse, error)
	UnregisterWorker(ctx context.Context, in *pb.UnregisterWorkerRequest, opts ...grpc.CallOption) (*pb.UnregisterWorkerResponse, error)
	Heartbeat(ctx context.Context, in *pb.HeartbeatRequest, opts ...grpc.CallOption) (*pb.HeartbeatResponse, error)
}

CoordinatorClient is an interface for a client of the coordinator service.

type Dialer

type Dialer func(ctx context.Context, target string) (*grpc.ClientConn, error)

Dialer is a function that creates a gRPC client connection.

type GrpcServer

type GrpcServer interface {
	RegisterService(desc *grpc.ServiceDesc, impl interface{})
	Serve(lis net.Listener) error
	Stop()
	GracefulStop()
}

GrpcServer is an interface for a gRPC server.

type GrpcStrategy added in v0.2.1

type GrpcStrategy[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

GrpcStrategy implements InternalStrategy[T] using gRPC transport. It connects to the coordinator for registration, starts a local gRPC server (workerService) for incoming RPCs, and connects to peers for outgoing RPCs.

func NewGrpcStrategy added in v0.2.1

func NewGrpcStrategy[T tensor.Numeric](cfg GrpcStrategyConfig) *GrpcStrategy[T]

NewGrpcStrategy creates a new GrpcStrategy with the given configuration.

func (*GrpcStrategy[T]) AllReduceGradients added in v0.2.1

func (s *GrpcStrategy[T]) AllReduceGradients(gradients map[string]*tensor.TensorNumeric[T]) error

AllReduceGradients performs a star-topology all-reduce. Root (rank 0) collects gradients from all peers, averages them, and sends the result back. Non-root workers send gradients to root and receive the averaged result.
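The root-side arithmetic is a per-name, element-wise average. The sketch below uses map[string][]float64 as a stand-in for the package's map[string]*tensor.TensorNumeric[T]; the helper name is illustrative.

```go
package main

import "fmt"

// averageGradients performs what the root does in the star-topology
// all-reduce: for each gradient name, sum the same-named gradient from
// every worker element-wise, then divide by the worker count.
func averageGradients(workers []map[string][]float64) map[string][]float64 {
	out := map[string][]float64{}
	for name, g0 := range workers[0] {
		avg := make([]float64, len(g0))
		for _, w := range workers {
			for i, v := range w[name] {
				avg[i] += v
			}
		}
		for i := range avg {
			avg[i] /= float64(len(workers))
		}
		out[name] = avg
	}
	return out
}

func main() {
	workers := []map[string][]float64{
		{"layer1.weight": {1, 2}},
		{"layer1.weight": {3, 6}},
	}
	// Each element is averaged across the two workers.
	fmt.Println(averageGradients(workers))
}
```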

func (*GrpcStrategy[T]) Barrier added in v0.2.1

func (s *GrpcStrategy[T]) Barrier() error

Barrier synchronizes all workers via the root's barrier service.

func (*GrpcStrategy[T]) BroadcastTensor added in v0.2.1

func (s *GrpcStrategy[T]) BroadcastTensor(t *tensor.TensorNumeric[T], rootRank int) error

BroadcastTensor broadcasts a tensor from rootRank to all other workers.

func (*GrpcStrategy[T]) Close added in v0.2.1

func (s *GrpcStrategy[T]) Close(_ context.Context) error

Close satisfies the shutdown.Closer interface.

func (*GrpcStrategy[T]) Init added in v0.2.1

func (s *GrpcStrategy[T]) Init(rank, size int, coordinatorAddress string) error

Init registers with the coordinator, starts the local gRPC server, and connects to all peers.

func (*GrpcStrategy[T]) Rank added in v0.2.1

func (s *GrpcStrategy[T]) Rank() int

Rank returns the worker's rank.

func (*GrpcStrategy[T]) Shutdown added in v0.2.1

func (s *GrpcStrategy[T]) Shutdown()

Shutdown gracefully shuts down the strategy.

func (*GrpcStrategy[T]) Size added in v0.2.1

func (s *GrpcStrategy[T]) Size() int

Size returns the total number of workers.

type GrpcStrategyConfig added in v0.2.1

type GrpcStrategyConfig struct {
	WorkerAddress  string
	WorkerID       string
	ServerManager  ServerManager
	NetworkManager NetworkManager
	Dialer         Dialer
	Logger         log.Logger
	Collector      metrics.Collector
	TLS            *TLSConfig
}

GrpcStrategyConfig holds configuration for creating a GrpcStrategy.

type InternalStrategy

type InternalStrategy[T tensor.Numeric] interface {
	// Init initializes the strategy.
	Init(rank int, size int, coordinatorAddress string) error
	// AllReduceGradients performs an all-reduce operation on the gradients.
	AllReduceGradients(gradients map[string]*tensor.TensorNumeric[T]) error
	// Barrier blocks until all workers have reached the barrier.
	Barrier() error
	// BroadcastTensor broadcasts a tensor from the root to all other workers.
	BroadcastTensor(t *tensor.TensorNumeric[T], rootRank int) error
	// Rank returns the rank of the current worker.
	Rank() int
	// Size returns the total number of workers.
	Size() int
	// Shutdown cleans up the resources used by the strategy.
	Shutdown()
}
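For exercising a training loop in a single process, a trivial implementation of this interface suffices. The sketch below uses a simplified, non-generic copy of the interface (map[string][]float64 in place of tensor maps — the method names mirror the real interface, the types do not); at world size 1, all-reduce is the identity and barriers are trivially satisfied.

```go
package main

import "fmt"

// strategy is a simplified stand-in for InternalStrategy, with plain
// float64 slices instead of tensor.TensorNumeric values.
type strategy interface {
	Init(rank, size int, coordinatorAddress string) error
	AllReduceGradients(gradients map[string][]float64) error
	Barrier() error
	BroadcastTensor(t []float64, rootRank int) error
	Rank() int
	Size() int
	Shutdown()
}

// noopStrategy is a world-size-1 implementation: every collective is a no-op.
type noopStrategy struct{ rank, size int }

func (s *noopStrategy) Init(rank, size int, _ string) error {
	s.rank, s.size = rank, size
	return nil
}
func (s *noopStrategy) AllReduceGradients(map[string][]float64) error { return nil }
func (s *noopStrategy) Barrier() error                                { return nil }
func (s *noopStrategy) BroadcastTensor([]float64, int) error          { return nil }
func (s *noopStrategy) Rank() int                                     { return s.rank }
func (s *noopStrategy) Size() int                                     { return s.size }
func (s *noopStrategy) Shutdown()                                     {}

func main() {
	var st strategy = &noopStrategy{}
	if err := st.Init(0, 1, "localhost:0"); err != nil {
		panic(err)
	}
	grads := map[string][]float64{"w": {0.1, 0.2}}
	_ = st.AllReduceGradients(grads) // identity at world size 1
	_ = st.Barrier()
	fmt.Println("rank", st.Rank(), "of", st.Size())
	st.Shutdown()
}
```

A stub like this keeps the training loop's call sites identical whether it runs on one process or a full cluster.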

InternalStrategy defines the interface for a distributed training strategy.

type ListenerFactory

type ListenerFactory func(network, address string) (net.Listener, error)

ListenerFactory is a function that creates a new net.Listener.

type NCCLComm added in v1.5.0

type NCCLComm struct {
	// contains filtered or unexported fields
}

NCCLComm wraps an NCCL communicator loaded via purego dlopen. Unlike internal/nccl (which uses CGo), this implementation has zero CGo overhead and does not require the cuda build tag.

func NCCLCommInitRank added in v1.5.0

func NCCLCommInitRank(nRanks, rank int, id [ncclUniqueIDSize]byte) (*NCCLComm, error)

NCCLCommInitRank initializes a communicator for rank in an nRanks-size group. The id parameter is a 128-byte unique ID that must be the same across all ranks (generated by NCCLGetUniqueID on one rank and distributed to others).

func (*NCCLComm) Destroy added in v1.5.0

func (c *NCCLComm) Destroy() error

Destroy releases the communicator resources.

func (*NCCLComm) NRanks added in v1.5.0

func (c *NCCLComm) NRanks() int

NRanks returns the total number of ranks in this communicator.

func (*NCCLComm) Rank added in v1.5.0

func (c *NCCLComm) Rank() int

Rank returns the rank of this communicator.

type NetworkManager

type NetworkManager interface {
	// ConnectToPeers establishes connections to all other workers in the cluster.
	ConnectToPeers(peers []string, selfRank int, timeout time.Duration) ([]pb.DistributedServiceClient, []*grpc.ClientConn, error)
	// CloseConnections closes all the given connections.
	CloseConnections(conns []*grpc.ClientConn)
}

NetworkManager is an interface for managing network connections between workers.

func NewNetworkManager

func NewNetworkManager(dialer Dialer, clientFactory ServiceClientFactory) NetworkManager

NewNetworkManager creates a new NetworkManager.

type NumericStrategy added in v0.2.1

type NumericStrategy[T tensor.Numeric] = InternalStrategy[T]

NumericStrategy is a generic type alias of InternalStrategy for external use.

type ServerManager

type ServerManager interface {
	Start(workerAddress string, service interface{}, serviceDesc *grpc.ServiceDesc) error
	Stop()
	GracefulStop()
	SetLogger(logger log.Logger)
}

ServerManager is an interface for managing the gRPC server of a worker.

func NewServerManager

func NewServerManager(grpcServer GrpcServer, listenerFactory ListenerFactory) ServerManager

NewServerManager creates a new ServerManager.

type ServiceClientFactory

type ServiceClientFactory func(cc *grpc.ClientConn) pb.DistributedServiceClient

ServiceClientFactory is a function that creates a new DistributedServiceClient.

type TLSConfig added in v0.2.1

type TLSConfig struct {
	// CACertPath is the path to the CA certificate for verifying peers.
	CACertPath string
	// CertPath is the path to the server or client certificate.
	CertPath string
	// KeyPath is the path to the private key for the certificate.
	KeyPath string
}

TLSConfig holds TLS certificate paths for gRPC connections. When nil, plaintext connections are used (for local development).

func (*TLSConfig) ClientCredentials added in v0.2.1

func (tc *TLSConfig) ClientCredentials() (credentials.TransportCredentials, error)

ClientCredentials returns gRPC transport credentials for a TLS client. If tc is nil, returns nil (plaintext mode).

func (*TLSConfig) ServerCredentials added in v0.2.1

func (tc *TLSConfig) ServerCredentials() (credentials.TransportCredentials, error)

ServerCredentials returns gRPC transport credentials for a TLS server. If tc is nil, returns nil (plaintext mode).

type WorkerNode added in v0.2.1

type WorkerNode struct {
	// contains filtered or unexported fields
}

WorkerNode encapsulates a distributed training worker. It manages the gRPC strategy, server, and network connections, and provides orderly startup and shutdown semantics compatible with shutdown.Coordinator.

func NewWorkerNode added in v0.2.1

func NewWorkerNode(cfg WorkerNodeConfig) *WorkerNode

NewWorkerNode creates a new WorkerNode with the given configuration.

func (*WorkerNode) Close added in v0.2.1

func (wn *WorkerNode) Close(_ context.Context) error

Close shuts down the worker node. It satisfies the shutdown.Closer interface. Calling Close on an unstarted or already-closed node is safe.

func (*WorkerNode) Rank added in v0.2.1

func (wn *WorkerNode) Rank() int

Rank returns the worker's rank, or -1 if not started.

func (*WorkerNode) Size added in v0.2.1

func (wn *WorkerNode) Size() int

Size returns the total number of workers, or 0 if not started.

func (*WorkerNode) Start added in v0.2.1

func (wn *WorkerNode) Start(_ context.Context) error

Start initializes the distributed worker: creates a gRPC server and strategy, registers with the coordinator, connects to peers, and optionally registers a health check. The context is used only for cancellation of the start sequence, not the lifetime of the worker.

func (*WorkerNode) Strategy added in v0.2.1

func (wn *WorkerNode) Strategy() InternalStrategy[float32]

Strategy returns the underlying InternalStrategy, or nil if not started.

type WorkerNodeConfig added in v0.2.1

type WorkerNodeConfig struct {
	WorkerAddress      string
	CoordinatorAddress string
	WorldSize          int
	Logger             log.Logger
	Collector          metrics.Collector
	HealthServer       *health.Server
}

WorkerNodeConfig holds configuration for creating a WorkerNode.

Directories

Path	Synopsis
coordinator	Package coordinator provides a distributed training coordinator.
fsdp	Package fsdp implements Fully Sharded Data Parallelism for distributed training.
