distributed

package
v1.2.1
Published: Mar 16, 2026 License: Apache-2.0 Imports: 20 Imported by: 0

Documentation

Overview

Package distributed provides distributed training strategies and coordination mechanisms for multi-node machine learning workloads in the Zerfoo framework.

Constants

This section is empty.

Variables

This section is empty.

Functions

func NewWorkerService added in v0.2.1

func NewWorkerService(rank, worldSize int32, logger log.Logger) *workerService

NewWorkerService creates a new workerService.

Types

type AllReduceStrategy

type AllReduceStrategy[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

AllReduceStrategy implements a hierarchical all-reduce that combines a local (intra-node) strategy with a cross-node strategy.

func NewAllReduceStrategy

func NewAllReduceStrategy[T tensor.Numeric](
	localStrategy, crossNodeStrategy InternalStrategy[T],
) *AllReduceStrategy[T]

NewAllReduceStrategy creates a new AllReduceStrategy from a local strategy and a cross-node strategy.

func (*AllReduceStrategy[T]) AllReduceGradients

func (s *AllReduceStrategy[T]) AllReduceGradients(gradients map[string]*tensor.TensorNumeric[T]) error

AllReduceGradients performs hierarchical all-reduce on gradients.

func (*AllReduceStrategy[T]) Barrier

func (s *AllReduceStrategy[T]) Barrier() error

Barrier synchronizes all workers across all nodes.

func (*AllReduceStrategy[T]) BroadcastTensor

func (s *AllReduceStrategy[T]) BroadcastTensor(t *tensor.TensorNumeric[T], rootRank int) error

BroadcastTensor broadcasts a tensor from the root rank to all other ranks in the distributed system. The tensor is first broadcast within the root's local node, then across node leaders, and finally within each local node to ensure all ranks receive the broadcasted tensor.
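
The three phases described above can be simulated in memory to make the data flow concrete. This is only a sketch, not the package's implementation: hierarchicalBroadcast and its parameters are hypothetical, and it assumes ranks are laid out contiguously with ranksPerNode ranks per node and the lowest rank on each node acting as leader.

```go
package main

import "fmt"

// hierarchicalBroadcast simulates the three broadcast phases on a slice in
// which values[r] is the copy held by rank r.
// Assumptions (not confirmed by the package): ranks are contiguous per node,
// and the lowest rank on each node is that node's leader.
func hierarchicalBroadcast(values []float32, rootRank, ranksPerNode int) {
	if ranksPerNode <= 0 || rootRank < 0 || rootRank >= len(values) {
		return // nothing sensible to do for invalid input
	}
	rootLeader := (rootRank / ranksPerNode) * ranksPerNode

	// Phase 1: broadcast within the root's local node.
	for r := rootLeader; r < rootLeader+ranksPerNode && r < len(values); r++ {
		values[r] = values[rootRank]
	}
	// Phase 2: the root node's leader sends to every other node leader.
	for leader := 0; leader < len(values); leader += ranksPerNode {
		values[leader] = values[rootLeader]
	}
	// Phase 3: each leader broadcasts within its own node.
	for leader := 0; leader < len(values); leader += ranksPerNode {
		for r := leader; r < leader+ranksPerNode && r < len(values); r++ {
			values[r] = values[leader]
		}
	}
}

func main() {
	values := []float32{0, 0, 0, 7, 0, 0} // rank 3 holds the tensor value
	hierarchicalBroadcast(values, 3, 2)
	fmt.Println(values) // prints [7 7 7 7 7 7]
}
```

With this layout only one rank per node takes part in the cross-node phase; the other two phases stay within a node.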

func (*AllReduceStrategy[T]) Close added in v0.2.1

func (s *AllReduceStrategy[T]) Close(_ context.Context) error

Close satisfies the shutdown.Closer interface by calling Shutdown.

func (*AllReduceStrategy[T]) Init

func (s *AllReduceStrategy[T]) Init(rank, size int, coordinatorAddress string) error

Init initializes the hierarchical strategy.

func (*AllReduceStrategy[T]) Rank

func (s *AllReduceStrategy[T]) Rank() int

Rank returns the rank from the local strategy.

func (*AllReduceStrategy[T]) SetCollector added in v0.2.1

func (s *AllReduceStrategy[T]) SetCollector(c metrics.Collector)

SetCollector replaces the strategy's metrics collector.

func (*AllReduceStrategy[T]) Shutdown

func (s *AllReduceStrategy[T]) Shutdown()

Shutdown gracefully closes all connections.

func (*AllReduceStrategy[T]) Size

func (s *AllReduceStrategy[T]) Size() int

Size returns the size from the local strategy.

type CoordinatorClient

type CoordinatorClient interface {
	RegisterWorker(ctx context.Context, in *pb.RegisterWorkerRequest, opts ...grpc.CallOption) (*pb.RegisterWorkerResponse, error)
	UnregisterWorker(ctx context.Context, in *pb.UnregisterWorkerRequest, opts ...grpc.CallOption) (*pb.UnregisterWorkerResponse, error)
	Heartbeat(ctx context.Context, in *pb.HeartbeatRequest, opts ...grpc.CallOption) (*pb.HeartbeatResponse, error)
}

CoordinatorClient is an interface for a client of the coordinator service.

type Dialer

type Dialer func(ctx context.Context, target string) (*grpc.ClientConn, error)

Dialer is a function that creates a gRPC client connection.

type GrpcServer

type GrpcServer interface {
	RegisterService(desc *grpc.ServiceDesc, impl interface{})
	Serve(lis net.Listener) error
	Stop()
	GracefulStop()
}

GrpcServer is an interface for a gRPC server.

type GrpcStrategy added in v0.2.1

type GrpcStrategy[T tensor.Numeric] struct {
	// contains filtered or unexported fields
}

GrpcStrategy implements InternalStrategy[T] using gRPC transport. It connects to the coordinator for registration, starts a local gRPC server (workerService) for incoming RPCs, and connects to peers for outgoing RPCs.

func NewGrpcStrategy added in v0.2.1

func NewGrpcStrategy[T tensor.Numeric](cfg GrpcStrategyConfig) *GrpcStrategy[T]

NewGrpcStrategy creates a new GrpcStrategy with the given configuration.

func (*GrpcStrategy[T]) AllReduceGradients added in v0.2.1

func (s *GrpcStrategy[T]) AllReduceGradients(gradients map[string]*tensor.TensorNumeric[T]) error

AllReduceGradients performs a star-topology all-reduce. Root (rank 0) collects gradients from all peers, averages them, and sends the result back. Non-root workers send gradients to root and receive the averaged result.
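
The arithmetic of the star-topology reduction can be sketched without any transport. The example below is illustrative only: starAllReduce is a hypothetical helper, gradients are plain []float32 slices rather than tensor.TensorNumeric values, and every worker is assumed to hold the same parameter names with matching lengths.

```go
package main

import "fmt"

// starAllReduce mimics what the root does in a star topology: collect every
// worker's gradients, average them element-wise, and hand the result back.
// perWorker[r][name] is the gradient rank r holds for parameter name.
// Assumption: all workers hold the same names with equal-length slices.
func starAllReduce(perWorker []map[string][]float32) map[string][]float32 {
	if len(perWorker) == 0 {
		return nil
	}
	avg := make(map[string][]float32)
	// Collect: sum each parameter across workers.
	for _, grads := range perWorker {
		for name, g := range grads {
			if avg[name] == nil {
				avg[name] = make([]float32, len(g))
			}
			for i, v := range g {
				avg[name][i] += v
			}
		}
	}
	// Average: divide by the number of workers.
	n := float32(len(perWorker))
	for _, g := range avg {
		for i := range g {
			g[i] /= n
		}
	}
	// Send back: every worker overwrites its local copy with the average.
	for _, grads := range perWorker {
		for name := range grads {
			copy(grads[name], avg[name])
		}
	}
	return avg
}

func main() {
	workers := []map[string][]float32{
		{"w": {1, 3}},
		{"w": {3, 5}},
	}
	fmt.Println(starAllReduce(workers)["w"]) // prints [2 4]
}
```

In a star topology all traffic passes through the root, so the root's bandwidth limits how far this scales.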

func (*GrpcStrategy[T]) Barrier added in v0.2.1

func (s *GrpcStrategy[T]) Barrier() error

Barrier synchronizes all workers via the root's barrier service.
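
A counting barrier of the kind a root-side barrier service might implement can be sketched in a single process. The barrier type and runWorkers helper below are hypothetical and use goroutines in place of remote workers; the package's RPC-based mechanism may differ.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// barrier releases all callers of Wait once n of them have arrived.
// It is reusable: each generation gets a fresh release channel.
type barrier struct {
	mu      sync.Mutex
	n       int
	arrived int
	release chan struct{}
}

func newBarrier(n int) *barrier {
	return &barrier{n: n, release: make(chan struct{})}
}

// Wait blocks until n callers have reached the barrier.
func (b *barrier) Wait() {
	b.mu.Lock()
	b.arrived++
	ch := b.release
	if b.arrived == b.n {
		// Last arrival: reset for the next round and release everyone.
		b.arrived = 0
		b.release = make(chan struct{})
		close(ch)
	}
	b.mu.Unlock()
	<-ch
}

// runWorkers starts n goroutines that record their arrival, wait at the
// barrier, and then verify that every worker had arrived before release.
func runWorkers(n int) bool {
	b := newBarrier(n)
	var arrived atomic.Int32
	var ok atomic.Bool
	ok.Store(true)
	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			arrived.Add(1)
			b.Wait()
			if int(arrived.Load()) != n {
				ok.Store(false)
			}
		}()
	}
	wg.Wait()
	return ok.Load()
}

func main() {
	fmt.Println(runWorkers(4)) // prints true
}
```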

func (*GrpcStrategy[T]) BroadcastTensor added in v0.2.1

func (s *GrpcStrategy[T]) BroadcastTensor(t *tensor.TensorNumeric[T], rootRank int) error

BroadcastTensor broadcasts a tensor from rootRank to all other workers.

func (*GrpcStrategy[T]) Close added in v0.2.1

func (s *GrpcStrategy[T]) Close(_ context.Context) error

Close satisfies the shutdown.Closer interface.

func (*GrpcStrategy[T]) Init added in v0.2.1

func (s *GrpcStrategy[T]) Init(rank, size int, coordinatorAddress string) error

Init registers with the coordinator, starts the local gRPC server, and connects to all peers.

func (*GrpcStrategy[T]) Rank added in v0.2.1

func (s *GrpcStrategy[T]) Rank() int

Rank returns the worker's rank.

func (*GrpcStrategy[T]) Shutdown added in v0.2.1

func (s *GrpcStrategy[T]) Shutdown()

Shutdown gracefully shuts down the strategy.

func (*GrpcStrategy[T]) Size added in v0.2.1

func (s *GrpcStrategy[T]) Size() int

Size returns the total number of workers.

type GrpcStrategyConfig added in v0.2.1

type GrpcStrategyConfig struct {
	WorkerAddress  string
	WorkerID       string
	ServerManager  ServerManager
	NetworkManager NetworkManager
	Dialer         Dialer
	Logger         log.Logger
	Collector      metrics.Collector
	TLS            *TLSConfig
}

GrpcStrategyConfig holds configuration for creating a GrpcStrategy.

type InternalStrategy

type InternalStrategy[T tensor.Numeric] interface {
	// Init initializes the strategy.
	Init(rank int, size int, coordinatorAddress string) error
	// AllReduceGradients performs an all-reduce operation on the gradients.
	AllReduceGradients(gradients map[string]*tensor.TensorNumeric[T]) error
	// Barrier blocks until all workers have reached the barrier.
	Barrier() error
	// BroadcastTensor broadcasts a tensor from the root to all other workers.
	BroadcastTensor(t *tensor.TensorNumeric[T], rootRank int) error
	// Rank returns the rank of the current worker.
	Rank() int
	// Size returns the total number of workers.
	Size() int
	// Shutdown cleans up the resources used by the strategy.
	Shutdown()
}

InternalStrategy defines the interface for a distributed training strategy.

type ListenerFactory

type ListenerFactory func(network, address string) (net.Listener, error)

ListenerFactory is a function that creates a new net.Listener.

type NetworkManager

type NetworkManager interface {
	// ConnectToPeers establishes connections to all other workers in the cluster.
	ConnectToPeers(peers []string, selfRank int, timeout time.Duration) ([]pb.DistributedServiceClient, []*grpc.ClientConn, error)
	// CloseConnections closes all the given connections.
	CloseConnections(conns []*grpc.ClientConn)
}

NetworkManager is an interface for managing network connections between workers.

func NewNetworkManager

func NewNetworkManager(dialer Dialer, clientFactory ServiceClientFactory) NetworkManager

NewNetworkManager creates a new NetworkManager.

type NumericStrategy added in v0.2.1

type NumericStrategy[T tensor.Numeric] = InternalStrategy[T]

NumericStrategy is a generic type alias for InternalStrategy[T], intended for external use.

type ServerManager

type ServerManager interface {
	Start(workerAddress string, service interface{}, serviceDesc *grpc.ServiceDesc) error
	Stop()
	GracefulStop()
	SetLogger(logger log.Logger)
}

ServerManager is an interface for managing the gRPC server of a worker.

func NewServerManager

func NewServerManager(grpcServer GrpcServer, listenerFactory ListenerFactory) ServerManager

NewServerManager creates a new ServerManager.

type ServiceClientFactory

type ServiceClientFactory func(cc *grpc.ClientConn) pb.DistributedServiceClient

ServiceClientFactory is a function that creates a new DistributedServiceClient.

type TLSConfig added in v0.2.1

type TLSConfig struct {
	// CACertPath is the path to the CA certificate for verifying peers.
	CACertPath string
	// CertPath is the path to the server or client certificate.
	CertPath string
	// KeyPath is the path to the private key for the certificate.
	KeyPath string
}

TLSConfig holds TLS certificate paths for gRPC connections. When nil, plaintext connections are used (for local development).

func (*TLSConfig) ClientCredentials added in v0.2.1

func (tc *TLSConfig) ClientCredentials() (credentials.TransportCredentials, error)

ClientCredentials returns gRPC transport credentials for a TLS client. If tc is nil, returns nil (plaintext mode).

func (*TLSConfig) ServerCredentials added in v0.2.1

func (tc *TLSConfig) ServerCredentials() (credentials.TransportCredentials, error)

ServerCredentials returns gRPC transport credentials for a TLS server. If tc is nil, returns nil (plaintext mode).
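
One plausible way to turn the three certificate paths into a client-side TLS configuration, using only the standard library, is sketched below. The tlsPaths type and clientConfig method are hypothetical stand-ins, and whether the package builds its credentials this way is an assumption. In grpc-go, a *tls.Config like the one returned here can be wrapped with credentials.NewTLS.

```go
package main

import (
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"os"
)

// tlsPaths is a hypothetical stand-in mirroring the documented TLSConfig fields.
type tlsPaths struct {
	CACertPath string
	CertPath   string
	KeyPath    string
}

// clientConfig builds a *tls.Config from the configured paths.
// A nil receiver returns (nil, nil), mirroring the documented plaintext mode.
func (p *tlsPaths) clientConfig() (*tls.Config, error) {
	if p == nil {
		return nil, nil
	}
	caPEM, err := os.ReadFile(p.CACertPath)
	if err != nil {
		return nil, fmt.Errorf("read CA certificate: %w", err)
	}
	pool := x509.NewCertPool()
	if !pool.AppendCertsFromPEM(caPEM) {
		return nil, errors.New("no certificates found in CA file")
	}
	cert, err := tls.LoadX509KeyPair(p.CertPath, p.KeyPath)
	if err != nil {
		return nil, fmt.Errorf("load key pair: %w", err)
	}
	return &tls.Config{
		RootCAs:      pool,                    // used to verify the peer
		Certificates: []tls.Certificate{cert}, // presented to the peer
		MinVersion:   tls.VersionTLS12,
	}, nil
}

func main() {
	var none *tlsPaths // nil means plaintext mode
	cfg, err := none.clientConfig()
	fmt.Println(cfg == nil, err == nil) // prints true true
}
```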

type WorkerNode added in v0.2.1

type WorkerNode struct {
	// contains filtered or unexported fields
}

WorkerNode encapsulates a distributed training worker. It manages the gRPC strategy, server, and network connections, and provides orderly startup and shutdown semantics compatible with shutdown.Coordinator.

func NewWorkerNode added in v0.2.1

func NewWorkerNode(cfg WorkerNodeConfig) *WorkerNode

NewWorkerNode creates a new WorkerNode with the given configuration.

func (*WorkerNode) Close added in v0.2.1

func (wn *WorkerNode) Close(_ context.Context) error

Close shuts down the worker node. It satisfies the shutdown.Closer interface. Calling Close on an unstarted or already-closed node is safe.

func (*WorkerNode) Rank added in v0.2.1

func (wn *WorkerNode) Rank() int

Rank returns the worker's rank, or -1 if not started.

func (*WorkerNode) Size added in v0.2.1

func (wn *WorkerNode) Size() int

Size returns the total number of workers, or 0 if not started.
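
The lifecycle semantics documented for Close, Rank, and Size (safe repeated Close, -1 and 0 as "not started" sentinels) follow a common pattern that can be sketched as below. The node type is a hypothetical stand-in, not WorkerNode itself, and returning the sentinels again after Close is an assumption.

```go
package main

import (
	"fmt"
	"sync"
)

// node is a hypothetical stand-in illustrating the documented semantics.
type node struct {
	mu      sync.Mutex
	started bool
	rank    int
	size    int
}

// Start records the rank and world size and marks the node as started.
func (n *node) Start(rank, size int) {
	n.mu.Lock()
	defer n.mu.Unlock()
	n.started, n.rank, n.size = true, rank, size
}

// Close is idempotent: closing an unstarted or already-closed node is a no-op.
func (n *node) Close() error {
	n.mu.Lock()
	defer n.mu.Unlock()
	if !n.started {
		return nil
	}
	n.started = false // real code would release servers and connections here
	return nil
}

// Rank returns the rank, or -1 if the node is not started.
func (n *node) Rank() int {
	n.mu.Lock()
	defer n.mu.Unlock()
	if !n.started {
		return -1
	}
	return n.rank
}

// Size returns the world size, or 0 if the node is not started.
func (n *node) Size() int {
	n.mu.Lock()
	defer n.mu.Unlock()
	if !n.started {
		return 0
	}
	return n.size
}

func main() {
	n := &node{}
	fmt.Println(n.Rank(), n.Size()) // prints -1 0
	n.Start(2, 4)
	fmt.Println(n.Rank(), n.Size()) // prints 2 4
	_ = n.Close()
	_ = n.Close() // second Close is a no-op
	fmt.Println(n.Rank(), n.Size()) // prints -1 0
}
```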

func (*WorkerNode) Start added in v0.2.1

func (wn *WorkerNode) Start(_ context.Context) error

Start initializes the distributed worker: creates a gRPC server and strategy, registers with the coordinator, connects to peers, and optionally registers a health check. The context is used only for cancellation of the start sequence, not the lifetime of the worker.

func (*WorkerNode) Strategy added in v0.2.1

func (wn *WorkerNode) Strategy() InternalStrategy[float32]

Strategy returns the underlying InternalStrategy, or nil if not started.

type WorkerNodeConfig added in v0.2.1

type WorkerNodeConfig struct {
	WorkerAddress      string
	CoordinatorAddress string
	WorldSize          int
	Logger             log.Logger
	Collector          metrics.Collector
	HealthServer       *health.Server
}

WorkerNodeConfig holds configuration for creating a WorkerNode.

Directories

Path Synopsis
Package coordinator provides a distributed training coordinator.