Skip to content

Commit b0171d3

Browse files
committed
fix all reduce fake div
1 parent e3e1fa2 commit b0171d3

2 files changed

Lines changed: 20 additions & 6 deletions

File tree

src/zeroband/diloco.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,25 @@ def _init_offloaded_optimizer(self, model):
8181
)
8282
self._logger.debug("offload model to cpu")
8383

84-
def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str = "outer"):
84+
def sync_pseudo_gradient(
85+
self, model: nn.Module, fake: bool = False, flag: str = "outer", num_effective_peers: int | None = None
86+
):
8587
"""
8688
Sync the pseudo gradient from the local process group to the global process group
8789
"""
8890
_start_time = time.perf_counter()
89-
self._logger.debug("sync pseudo gradient %s", " fake" if fake else "")
9091

92+
world_size_pre_init = self.elastic_device_mesh.global_pg.size()
9193
self.elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False)
94+
world_size_post_init = self.elastic_device_mesh.global_pg.size()
95+
96+
if world_size_pre_init == world_size_post_init and num_effective_peers is not None:
97+
world_size = num_effective_peers
98+
else:
99+
world_size = world_size_post_init
100+
101+
self._logger.debug("sync pseudo gradient %s with world size %d", " fake" if fake else "", world_size)
102+
92103
global_pg = self.elastic_device_mesh.global_pg
93104
for i in range(self.config.retry_all_reduce):
94105
for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
@@ -98,7 +109,7 @@ def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str =
98109
param_offloaded.grad.to_local().copy_(param_offloaded.data.to_local())
99110
param_offloaded.grad.to_local().sub_(param.data.to_local().to(param_offloaded.data.device))
100111
try:
101-
self.offloaded_grad_flat_tensor.div_(global_pg.size())
112+
self.offloaded_grad_flat_tensor.div_(world_size)
102113
_collective_start_time = time.perf_counter()
103114
self._logger.debug("Waiting on barrier")
104115
self.elastic_device_mesh.monitored_barrier(flag)
@@ -198,12 +209,12 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
198209
# )
199210
return offloaded_params
200211

201-
def step(self, model: nn.Module, fake: bool = False, flag: str = "outer"):
212+
def step(self, model: nn.Module, fake: bool = False, num_effective_peers: int | None = None, flag: str = "outer"):
202213
"""
203214
Step the optimizer
204215
"""
205216
time_start = time.perf_counter()
206-
self.sync_pseudo_gradient(model, fake=fake, flag=flag)
217+
self.sync_pseudo_gradient(model, fake=fake, flag=flag, num_effective_peers=num_effective_peers)
207218
self._logger.info(f"all reduce pseudo gradient in: {time.perf_counter() - time_start} seconds")
208219

209220
if self.outer_optimizer is not None:

src/zeroband/train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,10 @@ def train(config: Config):
316316
time_start_outer = time.perf_counter()
317317

318318
if config.diloco is not None:
319+
# this is a patch for now to allow live recovery worker to not affect the all reduce at all
320+
num_effective_peers = elastic_device_mesh.global_pg.size()
319321
elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=True)
322+
320323
# at the beginning of the inner steps we allow joiner to arrive.
321324
# We maybe reinit before the all reduce but only to allow leaving, not to join anymore
322325

@@ -459,7 +462,7 @@ def train(config: Config):
459462
ckpt_manager.cache_inner_optimizer()
460463

461464
time_start_inner = time.perf_counter()
462-
diloco.step(model, flag=training_progress.outer_step)
465+
diloco.step(model=model, flag=training_progress.outer_step, num_effective_peers=num_effective_peers)
463466
diloco_time = time.perf_counter() - time_start_inner
464467

465468
if config.train.log_model_hash:

0 commit comments

Comments (0)