@@ -81,14 +81,25 @@ def _init_offloaded_optimizer(self, model):
         )
         self._logger.debug("offload model to cpu")
 
-    def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str = "outer"):
+    def sync_pseudo_gradient(
+        self, model: nn.Module, fake: bool = False, flag: str = "outer", num_effective_peers: int | None = None
+    ):
         """
         Sync the pseudo gradient from the local process group to the global process group
         """
         _start_time = time.perf_counter()
-        self._logger.debug("sync pseudo gradient %s", " fake" if fake else "")
 
+        world_size_pre_init = self.elastic_device_mesh.global_pg.size()
         self.elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False)
+        world_size_post_init = self.elastic_device_mesh.global_pg.size()
+
+        if world_size_pre_init == world_size_post_init and num_effective_peers is not None:
+            world_size = num_effective_peers
+        else:
+            world_size = world_size_post_init
+
+        self._logger.debug("sync pseudo gradient %s with world size %d", " fake" if fake else "", world_size)
+
         global_pg = self.elastic_device_mesh.global_pg
         for i in range(self.config.retry_all_reduce):
             for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
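The heart of this hunk is the choice of divisor for pseudo-gradient averaging: when the global process group was not reinitialized (its size is unchanged across `maybe_reinit_global_pg`), a caller-supplied `num_effective_peers` overrides the raw group size. A minimal standalone sketch of that selection, where `pick_world_size`, `pg_size_before`, and `pg_size_after` are illustrative names rather than the repo's API:

```python
# Sketch of the divisor selection added above; only num_effective_peers
# comes from the diff, the rest is illustrative.
def pick_world_size(pg_size_before: int, pg_size_after: int, num_effective_peers: int | None) -> int:
    # If group membership did not change during reinit, a caller that
    # tracked dead or straggling peers may override the raw group size.
    if pg_size_before == pg_size_after and num_effective_peers is not None:
        return num_effective_peers
    # Otherwise trust the freshly (re)initialized group size.
    return pg_size_after

assert pick_world_size(4, 4, 3) == 3      # stable group, one peer known dead
assert pick_world_size(4, 5, 3) == 5      # membership changed: trust the group
assert pick_world_size(4, 4, None) == 4   # no override supplied
```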
@@ -98,7 +109,7 @@ def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str = "outer"):
                 param_offloaded.grad.to_local().copy_(param_offloaded.data.to_local())
                 param_offloaded.grad.to_local().sub_(param.data.to_local().to(param_offloaded.data.device))
             try:
-                self.offloaded_grad_flat_tensor.div_(global_pg.size())
+                self.offloaded_grad_flat_tensor.div_(world_size)
                 _collective_start_time = time.perf_counter()
                 self._logger.debug("Waiting on barrier")
                 self.elastic_device_mesh.monitored_barrier(flag)
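For context, dividing the flat gradient buffer by `world_size` before the collective turns a SUM all-reduce into an average over that many peers, which is why the divisor choice above matters. A hedged sketch of the divide-then-reduce pattern, assuming a plain initialized `torch.distributed` process group rather than the repo's elastic mesh:

```python
import torch
import torch.distributed as dist

def all_reduce_average(flat_grads: torch.Tensor, world_size: int, retries: int = 3) -> None:
    # Pre-divide so the SUM reduction yields a mean over `world_size` peers;
    # this mirrors the div_-before-all-reduce order used in the diff.
    flat_grads.div_(world_size)
    for _attempt in range(retries):
        try:
            dist.all_reduce(flat_grads, op=dist.ReduceOp.SUM)
            return
        except RuntimeError:
            # The real code repairs the process group between attempts
            # (retry_all_reduce); here we simply retry.
            continue
    raise RuntimeError("all-reduce failed after retries")
```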
@@ -198,12 +209,12 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
         # )
         return offloaded_params
 
-    def step(self, model: nn.Module, fake: bool = False, flag: str = "outer"):
+    def step(self, model: nn.Module, fake: bool = False, num_effective_peers: int | None = None, flag: str = "outer"):
         """
         Step the optimizer
         """
         time_start = time.perf_counter()
-        self.sync_pseudo_gradient(model, fake=fake, flag=flag)
+        self.sync_pseudo_gradient(model, fake=fake, flag=flag, num_effective_peers=num_effective_peers)
         self._logger.info(f"all reduce pseudo gradient in: {time.perf_counter() - time_start} seconds")
 
         if self.outer_optimizer is not None:
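A hypothetical call site for the new parameter (the `outer_opt` and `count_live_peers` names are illustrative, not part of this diff): a trainer that detects unreachable peers before the outer step can pass the live count through so the pseudo gradient is averaged only over peers that actually contributed.

```python
# Illustrative usage; only step() and num_effective_peers come from the diff.
live_peers = count_live_peers()  # e.g. peers that passed a recent heartbeat
outer_opt.step(model, num_effective_peers=live_peers)
```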