From ded63bc8652ab1ae4aaef9b18526682d62fb5b35 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 25 Apr 2024 12:46:11 -0700 Subject: [PATCH] Fix FeatureProcessor device only to meta when exporting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: D56021085 - This fixed copying FP parameters to meta device when sharding model on meta Weird thing is though that FP parameters are not sparse parameters, they are dense. Therefore, they shouldn’t be moved to meta device as a result of sharding. https://fburl.com/code/xuv9s5k2 - AIMP assumes only sparse params are on meta device. However, the FP parameters should be properly moved when using .to(). That is not the current case, as FeatureProcessorCollections use ParameterDict but also has a mirroring Dict[str, nn.Parameter] for bypassing TorchScript issues with ParameterDict. Therefore, when .to() is used on a model, only the registered ParameterDict will bem oved but not the mirroring Dict. This diff overrides nn.Module _apply to make sure the ParameterDict and Dict are in sync. Differential Revision: D56492970 --- torchrec/distributed/quant_embeddingbag.py | 2 +- torchrec/distributed/tests/test_infer_shardings.py | 2 ++ torchrec/modules/feature_processor_.py | 9 +++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/torchrec/distributed/quant_embeddingbag.py b/torchrec/distributed/quant_embeddingbag.py index d98b6b121..7b7166e1b 100644 --- a/torchrec/distributed/quant_embeddingbag.py +++ b/torchrec/distributed/quant_embeddingbag.py @@ -411,7 +411,7 @@ def __init__( # Generic copy, for example initailized on cpu but -> sharding as meta self.feature_processors_per_rank.append( copy.deepcopy(feature_processor) - if device_type == feature_processor_device + if device_type == "meta" else copy_to_device( feature_processor, feature_processor_device, diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py index c112dcda5..1ec5fc4d3 100755 --- a/torchrec/distributed/tests/test_infer_shardings.py +++ b/torchrec/distributed/tests/test_infer_shardings.py @@ -2019,6 +2019,8 @@ def test_sharded_quant_fp_ebc_tw_meta(self) -> None: if isinstance(input, torch.Tensor): inputs[i] = input.to(torch.device("meta")) + # move dense params also to meta + sharded_model.to("meta") sharded_model(*inputs) # Don't care about the output since we are sharding on meta diff --git a/torchrec/modules/feature_processor_.py b/torchrec/modules/feature_processor_.py index 4c953e6c9..c88d3b45a 100644 --- a/torchrec/modules/feature_processor_.py +++ b/torchrec/modules/feature_processor_.py @@ -201,3 +201,12 @@ def copy(self, device: torch.device) -> nn.Module: self.position_weights_dict[key] = self.position_weights[key] return self + + # Override to make sure position_weights and position_weights_dict are in sync + # pyre-ignore [2] + def _apply(self, *args, **kwargs) -> nn.Module: + super()._apply(*args, **kwargs) + for k, param in self.position_weights.items(): + self.position_weights_dict[k] = param + + return self