From 26e073249399e3b49da6cf9274b827965961bedf Mon Sep 17 00:00:00 2001 From: Yanli Zhao Date: Thu, 6 Feb 2025 19:21:47 -0800 Subject: [PATCH] fix attach (#2726) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2726 Allow "attach" to skip sparse data distribution, since in some cases "attach" is called outside the training loop; skipping sparse data distribution in that case avoids interfering with the sparse data distribution performed inside the training loop. Reviewed By: hlin09, ge0405 Differential Revision: D68908008 fbshipit-source-id: bb1b7c362ef0a6de78e3e0bd771d8d22e3e4e985 --- torchrec/distributed/train_pipeline/train_pipelines.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index d55dc9b03..0591abb09 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -418,14 +418,16 @@ def detach(self) -> torch.nn.Module: self._model_attached = False return self._model - def attach(self, model: Optional[torch.nn.Module] = None) -> None: + def attach( + self, model: Optional[torch.nn.Module] = None, sparse_dist: bool = True + ) -> None: if model: self._model = model self._model_attached = True if self.contexts: self._pipeline_model( - batch=self.batches[0], + batch=self.batches[0] if sparse_dist else None, context=self.contexts[0], pipelined_forward=PipelinedForward, )