diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index d55dc9b03..0591abb09 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -418,14 +418,16 @@ def detach(self) -> torch.nn.Module: self._model_attached = False return self._model - def attach(self, model: Optional[torch.nn.Module] = None) -> None: + def attach( + self, model: Optional[torch.nn.Module] = None, sparse_dist: bool = True + ) -> None: if model: self._model = model self._model_attached = True if self.contexts: self._pipeline_model( - batch=self.batches[0], + batch=self.batches[0] if sparse_dist else None, context=self.contexts[0], pipelined_forward=PipelinedForward, )