Fix FeatureProcessor device only to meta when exporting
Summary:
D56021085 - This fixed copying FP parameters to the meta device when sharding the model on meta.
The odd part is that FP parameters are not sparse parameters; they are dense. Therefore, they should not be moved to the meta device as a result of sharding.
https://fburl.com/code/xuv9s5k2 - AIMP assumes that only sparse params are on the meta device.

However, the FP parameters should still be moved properly when .to() is called. That is not currently the case: FeatureProcessorCollections use a ParameterDict but also keep a mirroring Dict[str, nn.Parameter] to bypass TorchScript issues with ParameterDict. As a result, calling .to() on a model moves only the registered ParameterDict, not the mirroring Dict. This diff overrides nn.Module._apply to keep the ParameterDict and the Dict in sync; a minimal sketch of the pattern follows.
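The sketch below is illustrative only, not the actual TorchRec class: a hypothetical module with a registered ParameterDict, a plain mirroring dict, and an _apply override that refreshes the mirror after nn.Module moves or converts the registered parameters. The class name and tensor sizes are made up.

import torch
import torch.nn as nn
from typing import Dict, List


class ToyFeatureProcessorCollection(nn.Module):
    # Hypothetical stand-in for a FeatureProcessorCollection.
    def __init__(self, feature_names: List[str]) -> None:
        super().__init__()
        # Registered container: visited by nn.Module._apply (e.g. via .to()).
        self.position_weights = nn.ParameterDict(
            {name: nn.Parameter(torch.empty(8)) for name in feature_names}
        )
        # Plain mirror used to sidestep TorchScript limitations with
        # ParameterDict; _apply does not visit it, so it can go stale.
        self.position_weights_dict: Dict[str, nn.Parameter] = dict(
            self.position_weights.items()
        )

    # pyre-ignore [2]
    def _apply(self, *args, **kwargs) -> nn.Module:
        # Let nn.Module move/convert the registered ParameterDict first,
        # then refresh the mirror so both views hold the same tensors.
        super()._apply(*args, **kwargs)
        for k, param in self.position_weights.items():
            self.position_weights_dict[k] = param
        return self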

Differential Revision: D56492970
PaulZhang12 authored and facebook-github-bot committed Apr 25, 2024
1 parent 76e854c commit ded63bc
Showing 3 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion torchrec/distributed/quant_embeddingbag.py
@@ -411,7 +411,7 @@ def __init__(
             # Generic copy, for example initialized on cpu but -> sharding as meta
             self.feature_processors_per_rank.append(
                 copy.deepcopy(feature_processor)
-                if device_type == feature_processor_device
+                if device_type == "meta"
                 else copy_to_device(
                     feature_processor,
                     feature_processor_device,
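For context, here is a hedged sketch of the selection logic in this hunk. The helper name is hypothetical, and a plain deepcopy plus .to() stands in for TorchRec's copy_to_device, whose exact signature is not shown here.

import copy

import torch
import torch.nn as nn


def materialize_feature_processor(
    feature_processor: nn.Module,
    device_type: str,
    feature_processor_device: torch.device,
) -> nn.Module:
    # Hypothetical helper mirroring the branch above.
    if device_type == "meta":
        # Meta sharding target: a structural deepcopy suffices, since meta
        # tensors carry no real storage to transfer.
        return copy.deepcopy(feature_processor)
    # Real device target: stand-in for TorchRec's copy_to_device.
    return copy.deepcopy(feature_processor).to(feature_processor_device)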
2 changes: 2 additions & 0 deletions torchrec/distributed/tests/test_infer_shardings.py
@@ -2019,6 +2019,8 @@ def test_sharded_quant_fp_ebc_tw_meta(self) -> None:
             if isinstance(input, torch.Tensor):
                 inputs[i] = input.to(torch.device("meta"))
 
+        # move dense params also to meta
+        sharded_model.to("meta")
         sharded_model(*inputs)
         # Don't care about the output since we are sharding on meta

9 changes: 9 additions & 0 deletions torchrec/modules/feature_processor_.py
@@ -201,3 +201,12 @@ def copy(self, device: torch.device) -> nn.Module:
             self.position_weights_dict[key] = self.position_weights[key]
 
         return self
+
+    # Override to make sure position_weights and position_weights_dict are in sync
+    # pyre-ignore [2]
+    def _apply(self, *args, **kwargs) -> nn.Module:
+        super()._apply(*args, **kwargs)
+        for k, param in self.position_weights.items():
+            self.position_weights_dict[k] = param
+
+        return self
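With this override in place, moving a model keeps both views consistent. A usage sketch against the hypothetical ToyFeatureProcessorCollection from the summary above (not a TorchRec test):

fpc = ToyFeatureProcessorCollection(["f1", "f2"])
fpc.to("meta")  # nn.Module._apply converts the registered ParameterDict
# The override then refreshes the mirror, so both report meta tensors.
assert all(p.device.type == "meta" for p in fpc.position_weights.values())
assert all(p.device.type == "meta" for p in fpc.position_weights_dict.values())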
