We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c908517 commit e23007b — Copy full SHA for e23007b
torchrec/distributed/embedding_lookup.py
@@ -594,6 +594,18 @@ def _merge_variable_batch_embeddings(
594
self, embeddings: List[torch.Tensor], splits: List[List[int]]
595
) -> List[torch.Tensor]:
596
assert len(embeddings) > 1 and len(splits) > 1
597
+
598
+ logger.info(
599
+ "Merge VBE embeddings from the following TBEs "
600
+ f"(world size: {self._world_size}):\n"
601
+ + "\n".join(
602
+ [
603
+ f"\t{module.__class__.__name__}:{len(split)} splits"
604
+ for module, split in zip(self._emb_modules, splits)
605
+ ]
606
+ )
607
608
609
split_embs = [e.split(s) for e, s in zip(embeddings, splits)]
610
combined_embs = [
611
emb
0 commit comments