@@ -574,34 +574,6 @@ def _generate_permute_coordinates_per_feature_per_sharding(
                 torch.tensor(permuted_coordinates)
             )
 
-    def _create_input_dist(
-        self,
-        input_feature_names: List[str],
-        device: torch.device,
-        input_dist_device: Optional[torch.device] = None,
-    ) -> None:
-        feature_names: List[str] = []
-        self._feature_splits: List[int] = []
-        for sharding in self._sharding_type_to_sharding.values():
-            self._input_dists.append(
-                sharding.create_input_dist(device=input_dist_device)
-            )
-            feature_names.extend(sharding.feature_names())
-            self._feature_splits.append(len(sharding.feature_names()))
-        self._features_order: List[int] = []
-        for f in feature_names:
-            self._features_order.append(input_feature_names.index(f))
-        self._features_order = (
-            []
-            if self._features_order == list(range(len(self._features_order)))
-            else self._features_order
-        )
-        self.register_buffer(
-            "_features_order_tensor",
-            torch.tensor(self._features_order, device=device, dtype=torch.int32),
-            persistent=False,
-        )
-
     def _create_lookups(
         self,
         fused_params: Optional[Dict[str, Any]],
@@ -627,46 +599,34 @@ def input_dist(
         features: KeyedJaggedTensor,
     ) -> ListOfKJTList:
         if self._has_uninitialized_input_dist:
-            self._create_input_dist(
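+            # Lazily construct the input dist module the first time features are seen.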
+            self._input_dist = ShardedQuantEcInputDist(
                 input_feature_names=features.keys() if features is not None else [],
-                device=features.device(),
-                input_dist_device=self._device,
+                sharding_type_to_sharding=self._sharding_type_to_sharding,
+                device=self._device,
+                feature_device=features.device(),
             )
             self._has_uninitialized_input_dist = False
         if self._has_uninitialized_output_dist:
             self._create_output_dist(features.device())
             self._has_uninitialized_output_dist = False
-        ret: List[KJTList] = []
+
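+        # Permute/split the features per sharding and run each sharding's input dist;
+        # also returns the pre-dist features and any row-wise unbucketize permute tensors.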
+        (
+            input_dist_result_list,
+            features_by_sharding,
+            unbucketize_permute_tensor_list,
+        ) = self._input_dist(features)
+
         with torch.no_grad():
-            features_by_sharding = []
-            if self._features_order:
-                features = features.permute(
-                    self._features_order,
-                    self._features_order_tensor,
-                )
-            features_by_sharding = (
-                [features]
-                if len(self._feature_splits) == 1
-                else features.split(self._feature_splits)
-            )
+            for i in range(len(self._sharding_type_to_sharding)):
 
-            for i in range(len(self._input_dists)):
-                input_dist = self._input_dists[i]
-                input_dist_result = input_dist.forward(features_by_sharding[i])
-                ret.append(input_dist_result)
                 ctx.sharding_contexts.append(
                     InferSequenceShardingContext(
-                        features=input_dist_result,
+                        features=input_dist_result_list[i],
                         features_before_input_dist=features_by_sharding[i],
-                        unbucketize_permute_tensor=(
-                            input_dist.unbucketize_permute_tensor
-                            if isinstance(input_dist, InferRwSparseFeaturesDist)
-                            or isinstance(input_dist, InferCPURwSparseFeaturesDist)
-                            else None
-                        ),
+                        unbucketize_permute_tensor=unbucketize_permute_tensor_list[i],
                     )
                 )
-        return ListOfKJTList(ret)
+        return input_dist_result_list
 
     def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
         return (
@@ -680,7 +640,10 @@ def compute(
     ) -> List[List[torch.Tensor]]:
         ret: List[List[torch.Tensor]] = []
 
-        for lookup, features in zip(self._lookups, dist_input):
+        # Equivalent to: for lookup, features in zip(self._lookups, dist_input)
+        for i in range(len(self._lookups)):
+            lookup = self._lookups[i]
+            features = dist_input[i]
             ret.append(lookup.forward(features))
         return ret
@@ -848,3 +811,126 @@ def shard(
     @property
     def module_type(self) -> Type[QuantEmbeddingCollection]:
         return QuantEmbeddingCollection
+
+
+class ShardedQuantEcInputDist(torch.nn.Module):
+    """
+    This module implements distributed inputs of a ShardedQuantEmbeddingCollection.
+
+    Args:
+        input_feature_names (List[str]): EmbeddingCollection feature names.
+        sharding_type_to_sharding (Dict[
+            str,
+            EmbeddingSharding[
+                InferSequenceShardingContext,
+                KJTList,
+                List[torch.Tensor],
+                List[torch.Tensor],
+            ],
+        ]): map from sharding type to EmbeddingSharding.
+        device (Optional[torch.device]): default compute device.
+        feature_device (Optional[torch.device]): runtime feature device.
+
+    Example::
+
+        sqec_input_dist = ShardedQuantEcInputDist(
+            input_feature_names=["f1", "f2"],
+            sharding_type_to_sharding={
+                ShardingType.TABLE_WISE.value: InferTwSequenceEmbeddingSharding(
+                    [],
+                    ShardingEnv(
+                        world_size=2,
+                        rank=0,
+                        pg=0,
+                    ),
+                    torch.device("cpu"),
+                )
+            },
+            device=torch.device("cpu"),
+        )
+
+        features = KeyedJaggedTensor(
+            keys=["f1", "f2"],
+            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
+        )
+
+        sqec_input_dist(features)
+    """
+
+    def __init__(
+        self,
+        input_feature_names: List[str],
+        sharding_type_to_sharding: Dict[
+            str,
+            EmbeddingSharding[
+                InferSequenceShardingContext,
+                KJTList,
+                List[torch.Tensor],
+                List[torch.Tensor],
+            ],
+        ],
+        device: Optional[torch.device] = None,
+        feature_device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__()
+        self._sharding_type_to_sharding = sharding_type_to_sharding
+        self._input_dists = torch.nn.ModuleList([])
+        self._feature_splits: List[int] = []
+        self._features_order: List[int] = []
+
+        self._has_features_permute: bool = True
+
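+        # One input dist per sharding; record each sharding's feature names and split sizes.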
+        feature_names: List[str] = []
+        for sharding in sharding_type_to_sharding.values():
+            self._input_dists.append(sharding.create_input_dist(device=device))
+            feature_names.extend(sharding.feature_names())
+            self._feature_splits.append(len(sharding.feature_names()))
+
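+        # Map the order the shardings expect back to positions in input_feature_names;
+        # an identity order collapses to [] so no permute is done at runtime.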
+        for f in feature_names:
+            self._features_order.append(input_feature_names.index(f))
+        self._features_order = (
+            []
+            if self._features_order == list(range(len(self._features_order)))
+            else self._features_order
+        )
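+        # Non-persistent buffer: follows module.to(...) but is excluded from state_dict.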
+        self.register_buffer(
+            "_features_order_tensor",
+            torch.tensor(
+                self._features_order, device=feature_device, dtype=torch.int32
+            ),
+            persistent=False,
+        )
+
+    def forward(
+        self, features: KeyedJaggedTensor
+    ) -> Tuple[List[KJTList], List[KeyedJaggedTensor], List[Optional[torch.Tensor]]]:
+        with torch.no_grad():
+            ret: List[KJTList] = []
+            unbucketize_permute_tensor = []
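+            # Permute features into the shardings' expected order, then split them per sharding.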
+            if self._features_order:
+                features = features.permute(
+                    self._features_order,
+                    self._features_order_tensor,
+                )
+            features_by_sharding = (
+                [features]
+                if len(self._feature_splits) == 1
+                else features.split(self._feature_splits)
+            )
+
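+            # Run each sharding's input dist; row-wise dists expose an unbucketize permute
+            # tensor, all others record None.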
+            for i in range(len(self._input_dists)):
+                input_dist = self._input_dists[i]
+                input_dist_result = input_dist(features_by_sharding[i])
+                ret.append(input_dist_result)
+                unbucketize_permute_tensor.append(
+                    input_dist.unbucketize_permute_tensor
+                    if isinstance(input_dist, InferRwSparseFeaturesDist)
+                    or isinstance(input_dist, InferCPURwSparseFeaturesDist)
+                    else None
+                )
+
+            return (
+                ret,
+                features_by_sharding,
+                unbucketize_permute_tensor,
+            )