From 1ed6a6af611faf6374c30549bab911afc043896e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 5 Feb 2025 11:38:56 +0800 Subject: [PATCH] add QuantManagedCollisionEmbeddingBagCollection --- .../distributed/test_utils/infer_utils.py | 9 +- .../distributed/tests/test_infer_shardings.py | 171 ++++++++++++++++- torchrec/quant/embedding_modules.py | 172 +++++++++++++++++- 3 files changed, 345 insertions(+), 7 deletions(-) diff --git a/torchrec/distributed/test_utils/infer_utils.py b/torchrec/distributed/test_utils/infer_utils.py index 478e01bb2..827f8b48e 100644 --- a/torchrec/distributed/test_utils/infer_utils.py +++ b/torchrec/distributed/test_utils/infer_utils.py @@ -81,7 +81,10 @@ from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection -from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection +from torchrec.modules.mc_embedding_modules import ( + ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, +) from torchrec.quant.embedding_modules import ( EmbeddingCollection as QuantEmbeddingCollection, FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection, @@ -89,6 +92,7 @@ MODULE_ATTR_REGISTER_TBES_BOOL, quant_prep_enable_quant_state_dict_split_scale_bias_for_types, quant_prep_enable_register_tbes, + QuantManagedCollisionEmbeddingBagCollection, QuantManagedCollisionEmbeddingCollection, ) @@ -333,6 +337,7 @@ def quantize( module_types: List[Type[torch.nn.Module]] = [ torchrec.modules.embedding_modules.EmbeddingBagCollection, torchrec.modules.embedding_modules.EmbeddingCollection, + torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingBagCollection, torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingCollection, ] if register_tbes: @@ -359,11 +364,13 @@ def quantize( qconfig_spec={ EmbeddingBagCollection: qconfig, EmbeddingCollection: qconfig, + ManagedCollisionEmbeddingBagCollection: qconfig, ManagedCollisionEmbeddingCollection: qconfig, }, mapping={ EmbeddingBagCollection: QuantEmbeddingBagCollection, EmbeddingCollection: QuantEmbeddingCollection, + ManagedCollisionEmbeddingBagCollection: QuantManagedCollisionEmbeddingBagCollection, ManagedCollisionEmbeddingCollection: QuantManagedCollisionEmbeddingCollection, }, inplace=inplace, diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py index 83b4649ee..c3e4c35f7 100755 --- a/torchrec/distributed/tests/test_infer_shardings.py +++ b/torchrec/distributed/tests/test_infer_shardings.py @@ -75,7 +75,10 @@ PositionWeightedModuleCollection, ) from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection -from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection +from torchrec.modules.mc_embedding_modules import ( + ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, +) from torchrec.modules.mc_modules import ( DistanceLFU_EvictionPolicy, ManagedCollisionCollection, @@ -2088,6 +2091,162 @@ def test_sharded_quant_fp_ebc_tw( gm_script_output = gm_script(*inputs[0]) assert_close(sharded_output, gm_script_output) + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs available", + ) + # pyre-ignore + @given( + weight_dtype=st.sampled_from([torch.qint8]), + device_type=st.sampled_from(["cpu", 
"cuda"]), + ) + @settings(max_examples=2, deadline=None) + def test_sharded_quant_mc_ebc_rw( + self, weight_dtype: torch.dtype, device_type: str + ) -> None: + num_embeddings = 10 + emb_dim = 16 + world_size = 2 + batch_size = 2 + local_device = torch.device(f"{device_type}:0") + + topology: Topology = Topology(world_size=world_size, compute_device=device_type) + mi = TestModelInfo( + dense_device=local_device, + sparse_device=local_device, + num_features=1, + num_float_features=10, + num_weighted_features=0, + topology=topology, + ) + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + mi.tables = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=emb_dim, + name=f"table_{i}", + feature_names=[f"feature_{i}"], + ) + for i in range(mi.num_features) + ] + + mi.model = KJTInputWrapper( + module_kjt_input=torch.nn.Sequential( + ManagedCollisionEmbeddingBagCollection( + EmbeddingBagCollection( + tables=mi.tables, + device=mi.sparse_device, + ), + ManagedCollisionCollection( + managed_collision_modules={ + "table_0": MCHManagedCollisionModule( + zch_size=num_embeddings, + input_hash_size=4000, + device=mi.sparse_device, + eviction_interval=2, + eviction_policy=DistanceLFU_EvictionPolicy(), + ) + }, + embedding_configs=mi.tables, + ), + ) + ) + ) + model_inputs: List[ModelInput] = prep_inputs( + mi, world_size, batch_size, long_indices=True + ) + inputs = [] + for model_input in model_inputs: + kjt = model_input.idlist_features + assert isinstance(kjt, KeyedJaggedTensor) + kjt = kjt.to(local_device) + weights = None + inputs.append( + ( + kjt._keys, + kjt._values, + weights, + kjt._lengths, + kjt._offsets, + ) + ) + + mi.model(*inputs[0]) + print(f"model:\n{mi.model}") + assert mi.model.training is True + mi.quant_model = quantize( + module=mi.model, + inplace=False, + register_tbes=False, + quant_state_dict_split_scale_bias=True, + weight_dtype=weight_dtype, + ) + quant_model = mi.quant_model + assert quant_model.training is False + non_sharded_output, _ = mi.quant_model(*inputs[0]) + + topology: Topology = Topology(world_size=world_size, compute_device=device_type) + mi.planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology, is_inference=True), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + sharder = QuantEmbeddingCollectionSharder() + # pyre-ignore + plan = mi.planner.plan( + mi.quant_model, + [sharder], + ) + + sharded_model = shard_qec( + mi, + sharding_type=ShardingType.ROW_WISE, + device=local_device, + expected_shards=None, + plan=plan, + ) + + print(f"sharded_model:\n{sharded_model}") + for n, m in sharded_model.named_modules(): + print(f"sharded_model.MODULE[{n}]:{type(m)}") + + sharded_model.load_state_dict(quant_model.state_dict()) + sharded_output, _ = sharded_model(*inputs[0]) + + assert_close(non_sharded_output, sharded_output) + gm: torch.fx.GraphModule = symbolic_trace( + sharded_model, + leaf_modules=[ + "IntNBitTableBatchedEmbeddingBagsCodegen", + "ComputeJTDictToKJT", + ], + ) + + print(f"fx.graph:\n{gm.graph}") + gm_script = torch.jit.script(gm) + print(f"gm_script:\n{gm_script}") + gm_script_output, _ = 
gm_script(*inputs[0]) + assert_close(sharded_output, gm_script_output) + @unittest.skipIf( torch.cuda.device_count() <= 1, "Not enough GPUs available", @@ -2192,7 +2351,7 @@ def test_sharded_quant_mc_ec_rw( ) quant_model = mi.quant_model assert quant_model.training is False - non_sharded_output = mi.quant_model(*inputs[0]) + non_sharded_output, _ = mi.quant_model(*inputs[0]) topology: Topology = Topology(world_size=world_size, compute_device=device_type) mi.planner = EmbeddingShardingPlanner( @@ -2227,7 +2386,7 @@ def test_sharded_quant_mc_ec_rw( print(f"sharded_model.MODULE[{n}]:{type(m)}") sharded_model.load_state_dict(quant_model.state_dict()) - sharded_output = sharded_model(*inputs[0]) + sharded_output, _ = sharded_model(*inputs[0]) assert_close(non_sharded_output, sharded_output) gm: torch.fx.GraphModule = symbolic_trace( @@ -2241,7 +2400,7 @@ def test_sharded_quant_mc_ec_rw( print(f"fx.graph:\n{gm.graph}") gm_script = torch.jit.script(gm) print(f"gm_script:\n{gm_script}") - gm_script_output = gm_script(*inputs[0]) + gm_script_output, _ = gm_script(*inputs[0]) assert_close(sharded_output, gm_script_output) @unittest.skipIf( @@ -2409,3 +2568,7 @@ def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None: gm_script = torch.jit.script(gm) print(f"gm_script:\n{gm_script}") gm_script(*inputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 06239ff7f..d9f7fc5c4 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -56,6 +56,7 @@ FeatureProcessedEmbeddingBagCollection as OriginalFeatureProcessedEmbeddingBagCollection, ) from torchrec.modules.mc_embedding_modules import ( + ManagedCollisionEmbeddingBagCollection as OriginalManagedCollisionEmbeddingBagCollection, ManagedCollisionEmbeddingCollection as OriginalManagedCollisionEmbeddingCollection, ) from torchrec.modules.mc_modules import ManagedCollisionCollection @@ -921,6 +922,170 @@ def device(self) -> torch.device: return self._device +class QuantManagedCollisionEmbeddingBagCollection(EmbeddingBagCollection): + """QuantManagedCollisionEmbeddingBagCollection represents a quantized EBC module. + + The inputs into the MC-EC/EBC will first be modified by the managed collision module + before being passed into the embedding collection. + + Args: + tables (List[EmbeddingBagConfig]): A list of EmbeddingBagConfig + objects representing the embedding tables in the collection. + is_weighted (bool): whether input `KeyedJaggedTensor` is weighted. + device (torch.device): The device on which the embedding bag collection will + be allocated. + output_dtype (torch.dtype, optional): The data type of the output embeddings. + Defaults to torch.float. + table_name_to_quantized_weights (Dict[str, Tuple[Tensor, Tensor]], optional): + A dictionary mapping table names to their corresponding quantized weights. + Defaults to None. + register_tbes (bool, optional): Whether to register the TBEs in the model. + Defaults to False. + quant_state_dict_split_scale_bias (bool, optional): Whether to split the scale + and bias parameters when saving the quantized state dict. Defaults to False. + row_alignment (int, optional): The alignment of rows in the quantized weights. + Defaults to DEFAULT_ROW_ALIGNMENT. + managed_collision_collection (ManagedCollisionCollection, optional): The managed + collision collection to use for managing collisions. Defaults to None. 
+ return_remapped_features (bool, optional): Whether to return the remapped input + features in addition to the embeddings. Defaults to False. + """ + + def __init__( + self, + tables: List[EmbeddingBagConfig], + is_weighted: bool, + device: torch.device, + output_dtype: torch.dtype = torch.float, + table_name_to_quantized_weights: Optional[ + Dict[str, Tuple[Tensor, Tensor]] + ] = None, + register_tbes: bool = False, + quant_state_dict_split_scale_bias: bool = False, + row_alignment: int = DEFAULT_ROW_ALIGNMENT, + managed_collision_collection: Optional[ManagedCollisionCollection] = None, + return_remapped_features: bool = False, + ) -> None: + super().__init__( + tables, + is_weighted, + device, + output_dtype, + table_name_to_quantized_weights, + register_tbes, + quant_state_dict_split_scale_bias, + row_alignment, + ) + assert ( + managed_collision_collection + ), "Managed collision collection cannot be None" + self._managed_collision_collection: ManagedCollisionCollection = ( + managed_collision_collection + ) + self._return_remapped_features = return_remapped_features + + assert str(self.embedding_bag_configs()) == str( + self._managed_collision_collection.embedding_configs() + ), ( + "EmbeddingBagCollection and Managed Collision Collection must contain the " + "same Embedding Configs" + ) + + # Assuming quantized MC-EBC is used in inference only + for ( + managed_collision_module + ) in self._managed_collision_collection._managed_collision_modules.values(): + managed_collision_module.reset_inference_mode() + + def to( + self, *args: List[Any], **kwargs: Dict[str, Any] + ) -> "QuantManagedCollisionEmbeddingBagCollection": + device, dtype, non_blocking, _ = torch._C._nn._parse_to( + *args, # pyre-ignore + **kwargs, # pyre-ignore + ) + for param in self.parameters(): + if param.device.type != "meta": + param.to(device) + + for buffer in self.buffers(): + if buffer.device.type != "meta": + buffer.to(device) + # Skip device movement and continue with other args + super().to( + dtype=dtype, + non_blocking=non_blocking, + ) + return self + + # pyre-ignore + def forward( + self, + features: KeyedJaggedTensor, + ) -> Tuple[KeyedTensor, Optional[KeyedJaggedTensor]]: + features = self._managed_collision_collection(features) + embedding_res = super().forward(features) + + if not self._return_remapped_features: + return embedding_res, None + return embedding_res, features + + def _get_name(self) -> str: + return "QuantManagedCollisionEmbeddingBagCollection" + + @classmethod + # pyre-ignore + def from_float( + cls, + module: OriginalManagedCollisionEmbeddingBagCollection, + return_remapped_features: bool = False, + ) -> "QuantManagedCollisionEmbeddingBagCollection": + mc_ebc = module + ebc = module._embedding_module + + # pyre-ignore[9] + qconfig: torch.quantization.QConfig = module.qconfig + assert hasattr(module, "qconfig"), ( + "QuantManagedCollisionEmbeddingBagCollection input float module must " + "have qconfig defined" + ) + + # pyre-ignore[29] + embedding_configs = copy.deepcopy(ebc.embedding_bag_configs()) + _update_embedding_configs( + cast(List[BaseEmbeddingConfig], embedding_configs), + qconfig, + ) + _update_embedding_configs( + mc_ebc._managed_collision_collection._embedding_configs, + qconfig, + ) + + # pyre-ignore[9] + table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]] | None = ( + ebc._table_name_to_quantized_weights + if hasattr(ebc, "_table_name_to_quantized_weights") + else None + ) + device = _get_device(ebc) + return cls( + embedding_configs, + ebc.is_weighted(), 
+ device=device, + output_dtype=qconfig.activation().dtype, + table_name_to_quantized_weights=table_name_to_quantized_weights, + register_tbes=getattr(module, MODULE_ATTR_REGISTER_TBES_BOOL, False), + quant_state_dict_split_scale_bias=getattr( + ebc, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False + ), + row_alignment=getattr( + ebc, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT + ), + managed_collision_collection=mc_ebc._managed_collision_collection, + return_remapped_features=mc_ebc._return_remapped_features, + ) + + class QuantManagedCollisionEmbeddingCollection(EmbeddingCollection): """ QuantManagedCollisionEmbeddingCollection represents a quantized EC module and a set of managed collision modules. @@ -1006,10 +1171,13 @@ def to( def forward( self, features: KeyedJaggedTensor, - ) -> Dict[str, JaggedTensor]: + ) -> Tuple[Dict[str, JaggedTensor], Optional[KeyedJaggedTensor]]: features = self._managed_collision_collection(features) + embedding_res = super().forward(features) - return super().forward(features) + if not self._return_remapped_features: + return embedding_res, None + return embedding_res, features def _get_name(self) -> str: return "QuantManagedCollisionEmbeddingCollection"
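
Reviewer note (not part of the patch): below is a minimal usage sketch of the new mapping, mirroring the quantize() helper updated in torchrec/distributed/test_utils/infer_utils.py and the new test_sharded_quant_mc_ebc_rw test. The table sizes, feature names, qconfig dtypes, and input values are illustrative assumptions, not values taken from this change.

    import torch
    import torch.quantization as quant

    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
    from torchrec.modules.mc_modules import (
        DistanceLFU_EvictionPolicy,
        ManagedCollisionCollection,
        MCHManagedCollisionModule,
    )
    from torchrec.quant.embedding_modules import QuantManagedCollisionEmbeddingBagCollection
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Illustrative single-table MC-EBC (names and sizes are assumptions).
    tables = [
        EmbeddingBagConfig(
            num_embeddings=10,
            embedding_dim=16,
            name="table_0",
            feature_names=["feature_0"],
        )
    ]
    device = torch.device("cpu")

    mc_ebc = ManagedCollisionEmbeddingBagCollection(
        EmbeddingBagCollection(tables=tables, device=device),
        ManagedCollisionCollection(
            managed_collision_modules={
                "table_0": MCHManagedCollisionModule(
                    zch_size=10,
                    input_hash_size=4000,
                    device=device,
                    eviction_interval=2,
                    eviction_policy=DistanceLFU_EvictionPolicy(),
                )
            },
            embedding_configs=tables,
        ),
    )

    # Float activations / qint8 weights, matching the shape of the qconfig used
    # by the test-utils quantize() helper in this patch.
    qconfig = quant.QConfig(
        activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
        weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8),
    )
    quant_mc_ebc = quant.quantize_dynamic(
        mc_ebc,
        qconfig_spec={ManagedCollisionEmbeddingBagCollection: qconfig},
        mapping={
            ManagedCollisionEmbeddingBagCollection: QuantManagedCollisionEmbeddingBagCollection
        },
        inplace=False,
    )

    # The quantized forward returns a (pooled embeddings, optional remapped KJT) tuple;
    # the second element is None unless return_remapped_features=True on the float module.
    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=["feature_0"],
        values=torch.tensor([1, 2, 3, 4], dtype=torch.int64),
        lengths=torch.tensor([2, 2], dtype=torch.int64),
    )
    pooled, remapped = quant_mc_ebc(kjt)

QuantManagedCollisionEmbeddingCollection.forward is changed in this patch to return the same two-tuple shape, which is why the existing test_sharded_quant_mc_ec_rw call sites are updated to unpack "output, _ = model(...)".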