Commit e23007b

Shuangping Liu authored and facebook-github-bot committed
Add logging when merging VBE embeddings from multiple TBEs (#3304)
Summary: Pull Request resolved: #3304

Add logging in `_merge_variable_batch_embeddings` so that we can tell whether a model has multiple TBEs for lookup with VBE enabled. This function can take significantly long when the world size and the number of features are large, and it can sometimes lead to GPU inefficiencies. See this [doc](https://docs.google.com/document/d/1h5YyeCjYmmN-CIFB98CrBf1uMksidPbNvM1rl8yZeds/edit?addon_store&fbclid=IwY2xjawL3TWBleHRuA2FlbQIxMQBicmlkETEzSzN5RFhBN0hyd0RZa2ZnAR6U1Lg7P5B-BgHUp_-eFqgx-__zTPYiXWSS0eZc9SmJdgiJeZv-fwnrCaSSLA_aem_YhBf3lxMTQ_xrKbE-GXpnQ&tab=t.0) for more context.

Reviewed By: TroyGarden

Differential Revision: D80686878

fbshipit-source-id: bf63fc5ec5073c5e5a9dd6d1c4cb4bc3306d4c86
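For illustration, the new log message renders one line per TBE module, using the module class name and the number of splits it contributes. With two TBEs and a world size of 128 it would look roughly like this (module names and split counts here are made up for the example):

```
Merge VBE embeddings from the following TBEs (world size: 128):
	BatchedFusedEmbeddingBag:3072 splits
	BatchedFusedEmbeddingBag:1024 splits
```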
1 parent c908517 · commit e23007b

File tree

1 file changed: 12 additions, 0 deletions


torchrec/distributed/embedding_lookup.py

Lines changed: 12 additions & 0 deletions
@@ -594,6 +594,18 @@ def _merge_variable_batch_embeddings(
```diff
         self, embeddings: List[torch.Tensor], splits: List[List[int]]
     ) -> List[torch.Tensor]:
         assert len(embeddings) > 1 and len(splits) > 1
+
+        logger.info(
+            "Merge VBE embeddings from the following TBEs "
+            f"(world size: {self._world_size}):\n"
+            + "\n".join(
+                [
+                    f"\t{module.__class__.__name__}:{len(split)} splits"
+                    for module, split in zip(self._emb_modules, splits)
+                ]
+            )
+        )
+
         split_embs = [e.split(s) for e, s in zip(embeddings, splits)]
         combined_embs = [
             emb
```
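For context, here is a minimal standalone sketch of where the new log line fires and what the surrounding merge does. This is not TorchRec's actual implementation: the real method lives on the grouped lookup module and reads state such as `self._world_size` and `self._emb_modules`, and the free-function signature, the `features_per_tbe` argument, and the interleaving body below are assumptions reconstructed from the visible context lines.

```python
import logging
from typing import List

import torch

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def merge_variable_batch_embeddings(
    embeddings: List[torch.Tensor],
    splits: List[List[int]],
    world_size: int,
    features_per_tbe: List[int],
) -> List[torch.Tensor]:
    """Sketch of merging VBE outputs from several TBEs (assumed layout).

    Assumes each TBE emits one flat 1-D tensor and that its `splits`
    entry holds world_size * num_features chunk sizes in rank-major order.
    """
    assert len(embeddings) > 1 and len(splits) > 1
    # The log line added by this commit: one entry per TBE, so models
    # that pay the multi-TBE VBE merge cost are visible in the logs.
    logger.info(
        "Merge VBE embeddings from the following TBEs "
        f"(world size: {world_size}):\n"
        + "\n".join(
            f"\tTBE[{i}]: {len(s)} splits" for i, s in enumerate(splits)
        )
    )
    # Chop each TBE output into its per-(rank, feature) chunks.
    split_embs = [e.split(s) for e, s in zip(embeddings, splits)]
    # Interleave chunks rank by rank across TBEs; the chunk count grows
    # with world_size * num_features, which is why the merge can be slow.
    combined_embs = [
        emb
        for rank in range(world_size)
        for n, embs in zip(features_per_tbe, split_embs)
        for emb in embs[n * rank : n * rank + n]
    ]
    return [torch.cat(combined_embs)]


if __name__ == "__main__":
    # Toy run: two TBEs, world size 2, one feature each, 1-D "embeddings".
    embs = [torch.arange(4.0), torch.arange(10.0, 14.0)]
    splits = [[2, 2], [2, 2]]
    (merged,) = merge_variable_batch_embeddings(
        embs, splits, world_size=2, features_per_tbe=[1, 1]
    )
    print(merged)  # tensor([ 0.,  1., 10., 11.,  2.,  3., 12., 13.])
```

In the real module the per-TBE line prints the TBE wrapper's class name (`module.__class__.__name__`) rather than an index, as the diff above shows.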
