Remove embedding streams from semi-sync (#2731)
Summary:
Pull Request resolved: #2731

Do the embedding lookup on the default stream instead of an extra stream, as using extra streams runs into subtle data races.
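
For context, cross-stream lookups only stay correct if every producer/consumer hand-off is synchronized explicitly. The sketch below is illustrative only and not code from this diff (side_stream_lookup and the tensor names are hypothetical); it shows the wait_stream()/record_stream() bookkeeping a side-stream lookup needs:

import torch

def side_stream_lookup(weight: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
    # Perform an embedding-style lookup on a side stream, with every
    # synchronization step spelled out. Dropping any one of them is the kind
    # of subtle data race that motivates staying on the default stream.
    default_stream = torch.cuda.current_stream()
    side_stream = torch.cuda.Stream()

    # `ids` was produced on the default stream; the side stream must wait for
    # that work before reading it.
    side_stream.wait_stream(default_stream)
    with torch.cuda.stream(side_stream):
        out = weight[ids]  # lookup kernel is enqueued on the side stream
        # Tell the caching allocator that `ids` is also used by the side
        # stream, so its block is not recycled while the kernel is in flight.
        ids.record_stream(side_stream)

    # Consumers on the default stream must wait for the side stream, and
    # `out` (allocated on the side stream) must be marked as used there too.
    default_stream.wait_stream(side_stream)
    out.record_stream(default_stream)
    return out

if __name__ == "__main__" and torch.cuda.is_available():
    table = torch.randn(1000, 64, device="cuda")
    lookup_ids = torch.randint(0, 1000, (128,), device="cuda")
    print(side_stream_lookup(table, lookup_ids).shape)

Missing any of these calls for any tensor that crosses streams can corrupt results or reuse memory too early; doing the lookup on the default (current) stream sidesteps that bookkeeping, which is what the change below does.
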

Reviewed By: dstaay-fb

Differential Revision: D69270806

fbshipit-source-id: c26ddb1886f3b7151d5048bcdc47180e5ee9f67b
sarckk authored and facebook-github-bot committed Feb 7, 2025
1 parent 9269e73 commit c4c9332
Showing 1 changed file with 37 additions and 57 deletions.
torchrec/distributed/train_pipeline/train_pipelines.py (94 changes: 37 additions & 57 deletions)
@@ -768,8 +768,6 @@ def __init__(
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
         logger.debug(f"Starting semi-sync run at batch: {self._start_batch}")
-
-        self._embedding_streams: List[Optional[torch.Stream]] = []
         self._gradients: Dict[str, torch.Tensor] = {}
 
     def _grad_swap(self) -> None:
@@ -779,14 +777,6 @@ def _grad_swap(self) -> None:
                 self._gradients[name] = param.grad.clone()
             param.grad = grad
 
-    def _init_embedding_streams(self) -> None:
-        for _ in self._pipelined_modules:
-            self._embedding_streams.append(
-                (torch.get_device_module(self._device).Stream(priority=0))
-                if self._device.type in ["cuda", "mtia"]
-                else None
-            )
-
     def _validate_optimizer(self) -> None:
         for pipelined_module in self._pipelined_modules:
             pipelined_params = set(pipelined_module.parameters())
@@ -815,7 +805,6 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
             # pyre-ignore [6]
             EmbeddingPipelinedForward,
         )
-        self._init_embedding_streams()
         self.wait_sparse_data_dist(self.contexts[0])
         self._validate_optimizer()
         # pyre-ignore [6]
@@ -916,43 +905,36 @@ def _mlp_forward(
             return self._model_fwd(batch)
 
     def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
-        default_stream = torch.get_device_module(self._device).current_stream()
         assert len(context.embedding_features) == len(context.embedding_tensors)
-        for stream, emb_tensors, embedding_features, detached_emb_tensors in zip(
-            self._embedding_streams,
+        for emb_tensors, embedding_features, detached_emb_tensors in zip(
             context.embedding_tensors,
             context.embedding_features,
             context.detached_embedding_tensors,
         ):
-            with self._stream_context(stream):
-                grads = [tensor.grad for tensor in detached_emb_tensors]
-                if stream:
-                    stream.wait_stream(default_stream)
-                # Some embeddings may never get used in the final loss computation,
-                # so the grads will be `None`. If we don't exclude these, it will fail
-                # with error: "grad can be implicitly created only for scalar outputs"
-                # Alternatively, if the tensor has only 1 element, pytorch can still
-                # figure out how to do autograd
-                embs_to_backprop, grads_to_use, invalid_features = [], [], []
-                assert len(embedding_features) == len(emb_tensors)
-                for features, tensor, grad in zip(
-                    embedding_features, emb_tensors, grads
-                ):
-                    if tensor.numel() == 1 or grad is not None:
-                        embs_to_backprop.append(tensor)
-                        grads_to_use.append(grad)
-                    else:
-                        if isinstance(features, str):
-                            invalid_features.append(features)
-                        elif isinstance(features, Iterable):
-                            invalid_features.extend(features)
-                        else:
-                            invalid_features.append(features)
-                if invalid_features and context.index == 0:
-                    logger.warning(
-                        f"SemiSync, the following features have no gradients: {invalid_features}"
-                    )
-                torch.autograd.backward(embs_to_backprop, grads_to_use)
+            grads = [tensor.grad for tensor in detached_emb_tensors]
+            # Some embeddings may never get used in the final loss computation,
+            # so the grads will be `None`. If we don't exclude these, it will fail
+            # with error: "grad can be implicitly created only for scalar outputs"
+            # Alternatively, if the tensor has only 1 element, pytorch can still
+            # figure out how to do autograd
+            embs_to_backprop, grads_to_use, invalid_features = [], [], []
+            assert len(embedding_features) == len(emb_tensors)
+            for features, tensor, grad in zip(embedding_features, emb_tensors, grads):
+                if tensor.numel() == 1 or grad is not None:
+                    embs_to_backprop.append(tensor)
+                    grads_to_use.append(grad)
+                else:
+                    if isinstance(features, str):
+                        invalid_features.append(features)
+                    elif isinstance(features, Iterable):
+                        invalid_features.extend(features)
+                    else:
+                        invalid_features.append(features)
+            if invalid_features and context.index == 0:
+                logger.warning(
+                    f"SemiSync, the following features have no gradients: {invalid_features}"
+                )
+            torch.autograd.backward(embs_to_backprop, grads_to_use)
 
     def copy_batch_to_gpu(
         self,
@@ -1012,23 +994,21 @@ def start_embedding_lookup(
         """
         if batch is None:
             return
+
         with record_function(f"## start_embedding_lookup {context.index} ##"):
-            _wait_for_events(
-                batch, context, torch.get_device_module(self._device).current_stream()
-            )
+            current_stream = torch.get_device_module(self._device).current_stream()
+            _wait_for_events(batch, context, current_stream)
             for i, module in enumerate(self._pipelined_modules):
-                stream = self._embedding_streams[i]
-                with self._stream_context(stream):
-                    _start_embedding_lookup(
-                        module,
-                        context,
-                        source_stream=self._data_dist_stream,
-                        target_stream=stream,
-                        stream_context=self._stream_context,
-                    )
-                    event = torch.get_device_module(self._device).Event()
-                    event.record()
-                    context.events.append(event)
+                _start_embedding_lookup(
+                    module,
+                    context,
+                    source_stream=self._data_dist_stream,
+                    target_stream=current_stream,
+                    stream_context=self._stream_context,
+                )
+                event = torch.get_device_module(self._device).Event()
+                event.record()
+                context.events.append(event)
 
 
 class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
