
Commit 7fec2e7

jiayisuse authored and facebook-github-bot committed
create input dist module for sharded ec (#1923)
Summary:
Pull Request resolved: #1923

Create a separate input dist module for sharded QEC, to enable the input dist split.

Reviewed By: gnahzg

Differential Revision: D56455884

fbshipit-source-id: 59751bc734575325cee269146388a16f968f2bcd
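
For illustration, a minimal sketch of the standalone call pattern this split enables: build the new ShardedQuantEcInputDist directly and call it on a KeyedJaggedTensor. Identifiers come from the diff below; the sharding_type_to_sharding map, the feature names, and the sample values are placeholders, not a tested configuration.

    # Hedged sketch: the per-sharding input dist now lives in its own module and
    # can be built and called on its own. The sharding map is assumed to have been
    # built elsewhere (e.g. by the sharder); the feature data is illustrative.
    sqec_input_dist = ShardedQuantEcInputDist(
        input_feature_names=["f1", "f2"],
        sharding_type_to_sharding=sharding_type_to_sharding,
        device=torch.device("cpu"),
        feature_device=torch.device("cpu"),
    )

    features = KeyedJaggedTensor(
        keys=["f1", "f2"],
        values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
        offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
    )

    # forward() returns (List[KJTList], List[KeyedJaggedTensor], List[Optional[Tensor]]).
    kjt_lists, features_by_sharding, unbucketize_permutes = sqec_input_dist(features)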
1 parent b1d3329 commit 7fec2e7

File tree

1 file changed: +142 -56 lines changed

torchrec/distributed/quant_embedding.py

Lines changed: 142 additions & 56 deletions
@@ -574,34 +574,6 @@ def _generate_permute_coordinates_per_feature_per_sharding(
                 torch.tensor(permuted_coordinates)
             )
 
-    def _create_input_dist(
-        self,
-        input_feature_names: List[str],
-        device: torch.device,
-        input_dist_device: Optional[torch.device] = None,
-    ) -> None:
-        feature_names: List[str] = []
-        self._feature_splits: List[int] = []
-        for sharding in self._sharding_type_to_sharding.values():
-            self._input_dists.append(
-                sharding.create_input_dist(device=input_dist_device)
-            )
-            feature_names.extend(sharding.feature_names())
-            self._feature_splits.append(len(sharding.feature_names()))
-        self._features_order: List[int] = []
-        for f in feature_names:
-            self._features_order.append(input_feature_names.index(f))
-        self._features_order = (
-            []
-            if self._features_order == list(range(len(self._features_order)))
-            else self._features_order
-        )
-        self.register_buffer(
-            "_features_order_tensor",
-            torch.tensor(self._features_order, device=device, dtype=torch.int32),
-            persistent=False,
-        )
-
     def _create_lookups(
         self,
         fused_params: Optional[Dict[str, Any]],
@@ -627,46 +599,34 @@ def input_dist(
         features: KeyedJaggedTensor,
     ) -> ListOfKJTList:
         if self._has_uninitialized_input_dist:
-            self._create_input_dist(
+            self._intput_dist = ShardedQuantEcInputDist(
                 input_feature_names=features.keys() if features is not None else [],
-                device=features.device(),
-                input_dist_device=self._device,
+                sharding_type_to_sharding=self._sharding_type_to_sharding,
+                device=self._device,
+                feature_device=features.device(),
             )
             self._has_uninitialized_input_dist = False
         if self._has_uninitialized_output_dist:
             self._create_output_dist(features.device())
             self._has_uninitialized_output_dist = False
-        ret: List[KJTList] = []
+
+        (
+            input_dist_result_list,
+            features_by_sharding,
+            unbucketize_permute_tensor_list,
+        ) = self._intput_dist(features)
+
         with torch.no_grad():
-            features_by_sharding = []
-            if self._features_order:
-                features = features.permute(
-                    self._features_order,
-                    self._features_order_tensor,
-                )
-            features_by_sharding = (
-                [features]
-                if len(self._feature_splits) == 1
-                else features.split(self._feature_splits)
-            )
+            for i in range(len(self._sharding_type_to_sharding)):
 
-            for i in range(len(self._input_dists)):
-                input_dist = self._input_dists[i]
-                input_dist_result = input_dist.forward(features_by_sharding[i])
-                ret.append(input_dist_result)
                 ctx.sharding_contexts.append(
                     InferSequenceShardingContext(
-                        features=input_dist_result,
+                        features=input_dist_result_list[i],
                         features_before_input_dist=features_by_sharding[i],
-                        unbucketize_permute_tensor=(
-                            input_dist.unbucketize_permute_tensor
-                            if isinstance(input_dist, InferRwSparseFeaturesDist)
-                            or isinstance(input_dist, InferCPURwSparseFeaturesDist)
-                            else None
-                        ),
+                        unbucketize_permute_tensor=unbucketize_permute_tensor_list[i],
                     )
                 )
-        return ListOfKJTList(ret)
+        return input_dist_result_list
 
     def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
         return (
@@ -680,7 +640,10 @@ def compute(
     ) -> List[List[torch.Tensor]]:
         ret: List[List[torch.Tensor]] = []
 
-        for lookup, features in zip(self._lookups, dist_input):
+        # for lookup, features in zip(self._lookups, dist_input):
+        for i in range(len(self._lookups)):
+            lookup = self._lookups[i]
+            features = dist_input[i]
             ret.append(lookup.forward(features))
         return ret
 
@@ -848,3 +811,126 @@ def shard(
     @property
     def module_type(self) -> Type[QuantEmbeddingCollection]:
         return QuantEmbeddingCollection
+
+
+class ShardedQuantEcInputDist(torch.nn.Module):
+    """
+    This module implements distributed inputs of a ShardedQuantEmbeddingCollection.
+
+    Args:
+        input_feature_names (List[str]): EmbeddingCollection feature names.
+        sharding_type_to_sharding (Dict[
+            str,
+            EmbeddingSharding[
+                InferSequenceShardingContext,
+                KJTList,
+                List[torch.Tensor],
+                List[torch.Tensor],
+            ],
+        ]): map from sharding type to EmbeddingSharding.
+        device (Optional[torch.device]): default compute device.
+        feature_device (Optional[torch.device]): runtime feature device.
+
+    Example::
+
+        sqec_input_dist = ShardedQuantEcInputDist(
+            sharding_type_to_sharding={
+                ShardingType.TABLE_WISE: InferTwSequenceEmbeddingSharding(
+                    [],
+                    ShardingEnv(
+                        world_size=2,
+                        rank=0,
+                        pg=0,
+                    ),
+                    torch.device("cpu")
+                )
+            },
+            device=torch.device("cpu"),
+        )
+
+        features = KeyedJaggedTensor(
+            keys=["f1", "f2"],
+            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
+        )
+
+        sqec_input_dist(features)
+    """
+
+    def __init__(
+        self,
+        input_feature_names: List[str],
+        sharding_type_to_sharding: Dict[
+            str,
+            EmbeddingSharding[
+                InferSequenceShardingContext,
+                KJTList,
+                List[torch.Tensor],
+                List[torch.Tensor],
+            ],
+        ],
+        device: Optional[torch.device] = None,
+        feature_device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__()
+        self._sharding_type_to_sharding = sharding_type_to_sharding
+        self._input_dists = torch.nn.ModuleList([])
+        self._feature_splits: List[int] = []
+        self._features_order: List[int] = []
+
+        self._has_features_permute: bool = True
+
+        feature_names: List[str] = []
+        for sharding in sharding_type_to_sharding.values():
+            self._input_dists.append(sharding.create_input_dist(device=device))
+            feature_names.extend(sharding.feature_names())
+            self._feature_splits.append(len(sharding.feature_names()))
+
+        for f in feature_names:
+            self._features_order.append(input_feature_names.index(f))
+        self._features_order = (
+            []
+            if self._features_order == list(range(len(self._features_order)))
+            else self._features_order
+        )
+        self.register_buffer(
+            "_features_order_tensor",
+            torch.tensor(
+                self._features_order, device=feature_device, dtype=torch.int32
+            ),
+            persistent=False,
+        )
+
+    def forward(
+        self, features: KeyedJaggedTensor
+    ) -> Tuple[List[KJTList], List[KeyedJaggedTensor], List[Optional[torch.Tensor]]]:
+        with torch.no_grad():
+            ret: List[KJTList] = []
+            unbucketize_permute_tensor = []
+            if self._features_order:
+                features = features.permute(
+                    self._features_order,
+                    self._features_order_tensor,
+                )
+            features_by_sharding = (
+                [features]
+                if len(self._feature_splits) == 1
+                else features.split(self._feature_splits)
+            )
+
+            for i in range(len(self._input_dists)):
+                input_dist = self._input_dists[i]
+                input_dist_result = input_dist(features_by_sharding[i])
+                ret.append(input_dist_result)
+                unbucketize_permute_tensor.append(
+                    input_dist.unbucketize_permute_tensor
+                    if isinstance(input_dist, InferRwSparseFeaturesDist)
+                    or isinstance(input_dist, InferCPURwSparseFeaturesDist)
+                    else None
+                )
+
+            return (
+                ret,
+                features_by_sharding,
+                unbucketize_permute_tensor,
+            )
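
Usage note (editorial, not part of the commit): factoring the input dist into its own module lets the permute/split/dist stage run separately from the embedding lookups, which appears to be what the summary's "input dist split" refers to. A condensed sketch of that two-stage flow, assuming sqec_input_dist and the per-sharding lookups list are already constructed:

    # Stage 1: input dist only, producing one KJTList per sharding type.
    kjt_lists, features_by_sharding, unbucketize_permutes = sqec_input_dist(features)

    # Stage 2: feed each per-sharding KJTList to its lookup, mirroring compute() above.
    embeddings_per_sharding = [
        lookup.forward(kjt_lists[i]) for i, lookup in enumerate(lookups)
    ]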
