From c4c93325c9d0e66c032746cf9b8a32d60af75cfe Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin
Date: Thu, 6 Feb 2025 17:57:00 -0800
Subject: [PATCH] Remove embedding streams from semi-sync (#2731)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2731

Do embedding lookup on default stream instead of extra stream as using
extra streams runs into subtle data races.

Reviewed By: dstaay-fb

Differential Revision: D69270806

fbshipit-source-id: c26ddb1886f3b7151d5048bcdc47180e5ee9f67b
---
 .../train_pipeline/train_pipelines.py        | 94 ++++++++-----------
 1 file changed, 37 insertions(+), 57 deletions(-)

diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index f7b94b37b..d55dc9b03 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -768,8 +768,6 @@ def __init__(
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
         logger.debug(f"Starting semi-sync run at batch: {self._start_batch}")
-
-        self._embedding_streams: List[Optional[torch.Stream]] = []
         self._gradients: Dict[str, torch.Tensor] = {}
 
     def _grad_swap(self) -> None:
@@ -779,14 +777,6 @@ def _grad_swap(self) -> None:
                 self._gradients[name] = param.grad.clone()
             param.grad = grad
 
-    def _init_embedding_streams(self) -> None:
-        for _ in self._pipelined_modules:
-            self._embedding_streams.append(
-                (torch.get_device_module(self._device).Stream(priority=0))
-                if self._device.type in ["cuda", "mtia"]
-                else None
-            )
-
     def _validate_optimizer(self) -> None:
         for pipelined_module in self._pipelined_modules:
             pipelined_params = set(pipelined_module.parameters())
@@ -815,7 +805,6 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
             # pyre-ignore [6]
             EmbeddingPipelinedForward,
         )
-        self._init_embedding_streams()
         self.wait_sparse_data_dist(self.contexts[0])
         self._validate_optimizer()
         # pyre-ignore [6]
@@ -916,43 +905,36 @@ def _mlp_forward(
             return self._model_fwd(batch)
 
     def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
-        default_stream = torch.get_device_module(self._device).current_stream()
         assert len(context.embedding_features) == len(context.embedding_tensors)
-        for stream, emb_tensors, embedding_features, detached_emb_tensors in zip(
-            self._embedding_streams,
+        for emb_tensors, embedding_features, detached_emb_tensors in zip(
             context.embedding_tensors,
             context.embedding_features,
             context.detached_embedding_tensors,
         ):
-            with self._stream_context(stream):
-                grads = [tensor.grad for tensor in detached_emb_tensors]
-                if stream:
-                    stream.wait_stream(default_stream)
-                # Some embeddings may never get used in the final loss computation,
-                # so the grads will be `None`. If we don't exclude these, it will fail
-                # with error: "grad can be implicitly created only for scalar outputs"
-                # Alternatively, if the tensor has only 1 element, pytorch can still
-                # figure out how to do autograd
-                embs_to_backprop, grads_to_use, invalid_features = [], [], []
-                assert len(embedding_features) == len(emb_tensors)
-                for features, tensor, grad in zip(
-                    embedding_features, emb_tensors, grads
-                ):
-                    if tensor.numel() == 1 or grad is not None:
-                        embs_to_backprop.append(tensor)
-                        grads_to_use.append(grad)
+            grads = [tensor.grad for tensor in detached_emb_tensors]
+            # Some embeddings may never get used in the final loss computation,
+            # so the grads will be `None`. If we don't exclude these, it will fail
+            # with error: "grad can be implicitly created only for scalar outputs"
+            # Alternatively, if the tensor has only 1 element, pytorch can still
+            # figure out how to do autograd
+            embs_to_backprop, grads_to_use, invalid_features = [], [], []
+            assert len(embedding_features) == len(emb_tensors)
+            for features, tensor, grad in zip(embedding_features, emb_tensors, grads):
+                if tensor.numel() == 1 or grad is not None:
+                    embs_to_backprop.append(tensor)
+                    grads_to_use.append(grad)
+                else:
+                    if isinstance(features, str):
+                        invalid_features.append(features)
+                    elif isinstance(features, Iterable):
+                        invalid_features.extend(features)
                     else:
-                        if isinstance(features, str):
-                            invalid_features.append(features)
-                        elif isinstance(features, Iterable):
-                            invalid_features.extend(features)
-                        else:
-                            invalid_features.append(features)
-                if invalid_features and context.index == 0:
-                    logger.warning(
-                        f"SemiSync, the following features have no gradients: {invalid_features}"
-                    )
-                torch.autograd.backward(embs_to_backprop, grads_to_use)
+                        invalid_features.append(features)
+            if invalid_features and context.index == 0:
+                logger.warning(
+                    f"SemiSync, the following features have no gradients: {invalid_features}"
+                )
+            torch.autograd.backward(embs_to_backprop, grads_to_use)
 
     def copy_batch_to_gpu(
         self,
@@ -1012,23 +994,21 @@ def start_embedding_lookup(
         """
         if batch is None:
             return
+
         with record_function(f"## start_embedding_lookup {context.index} ##"):
-            _wait_for_events(
-                batch, context, torch.get_device_module(self._device).current_stream()
-            )
+            current_stream = torch.get_device_module(self._device).current_stream()
+            _wait_for_events(batch, context, current_stream)
             for i, module in enumerate(self._pipelined_modules):
-                stream = self._embedding_streams[i]
-                with self._stream_context(stream):
-                    _start_embedding_lookup(
-                        module,
-                        context,
-                        source_stream=self._data_dist_stream,
-                        target_stream=stream,
-                        stream_context=self._stream_context,
-                    )
-                    event = torch.get_device_module(self._device).Event()
-                    event.record()
-                    context.events.append(event)
+                _start_embedding_lookup(
+                    module,
+                    context,
+                    source_stream=self._data_dist_stream,
+                    target_stream=current_stream,
+                    stream_context=self._stream_context,
+                )
+                event = torch.get_device_module(self._device).Event()
+                event.record()
+                context.events.append(event)
 
 
 class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
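
Why the extra streams were racy: when the lookup runs on a side stream, every
consumer on the default stream has to wait on that stream explicitly, and the
output tensor has to be recorded so the caching allocator does not recycle its
memory early. The sketch below is a minimal standalone illustration of that
ordering requirement; it assumes a CUDA device, the tensor names and shapes
are made up, and it is generic PyTorch, not TorchRec code.

    import torch

    side_stream = torch.cuda.Stream()
    weights = torch.randn(1024, 64, device="cuda")

    with torch.cuda.stream(side_stream):
        out = weights * 2.0  # stand-in for a lookup launched on the side stream

    # Without this wait, kernels queued on the default stream may read `out`
    # before the side stream has finished producing it.
    torch.cuda.current_stream().wait_stream(side_stream)
    # Without this, the allocator may hand `out`'s memory to another
    # allocation while the default stream still needs it.
    out.record_stream(torch.cuda.current_stream())

    loss = out.sum()  # now safe to consume on the default stream

Running the lookup on the current (default) stream, as this patch does, drops
both obligations at the cost of less compute/lookup overlap.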
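
The `None`-grad filtering that embedding_backward performs before calling
torch.autograd.backward can be reproduced in isolation. The sketch below uses
made-up stand-in tensors (not TorchRec code) to show why an embedding output
that never reaches the loss, and whose detached copy therefore has `grad is
None`, must be excluded: passing a non-scalar tensor without a gradient raises
"grad can be implicitly created only for scalar outputs".

    import torch

    # Stand-ins for two embedding lookups feeding the dense part of the model.
    table_a = torch.randn(4, requires_grad=True)
    table_b = torch.randn(4, requires_grad=True)
    emb_a, emb_b = table_a * 2.0, table_b * 2.0

    # The dense model consumes detached copies, as in the semi-sync pipeline.
    det_a = emb_a.detach().requires_grad_()
    det_b = emb_b.detach().requires_grad_()

    loss = det_a.sum()  # only det_a contributes to the loss
    loss.backward()
    assert det_b.grad is None  # unused embedding -> no gradient

    # Keep only tensors that are scalar or have a gradient; including emb_b
    # with a None grad would raise the error quoted above.
    embs, grads = [], []
    for emb, det in [(emb_a, det_a), (emb_b, det_b)]:
        if emb.numel() == 1 or det.grad is not None:
            embs.append(emb)
            grads.append(det.grad)

    torch.autograd.backward(embs, grads)  # backprops into table_a only
    assert table_a.grad is not None and table_b.grad is None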