Skip to content

Commit

Permalink
fix attach (#2726)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2726

allow "attach" to avoid calling sparse data distribution, as in some cases "attach" is called outside training loop, no sparse data distribution when calling attach outside training loop can avoid interference with sparse data distribution inside training loop.

Reviewed By: hlin09, ge0405

Differential Revision: D68908008

fbshipit-source-id: bb1b7c362ef0a6de78e3e0bd771d8d22e3e4e985
  • Loading branch information
zhaojuanmao authored and facebook-github-bot committed Feb 7, 2025
1 parent c4c9332 commit 26e0732
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torchrec/distributed/train_pipeline/train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,14 +418,16 @@ def detach(self) -> torch.nn.Module:
self._model_attached = False
return self._model

def attach(self, model: Optional[torch.nn.Module] = None) -> None:
def attach(
self, model: Optional[torch.nn.Module] = None, sparse_dist: bool = True
) -> None:
if model:
self._model = model

self._model_attached = True
if self.contexts:
self._pipeline_model(
batch=self.batches[0],
batch=self.batches[0] if sparse_dist else None,
context=self.contexts[0],
pipelined_forward=PipelinedForward,
)
Expand Down

0 comments on commit 26e0732

Please sign in to comment.