torchrec module (pytorch#2297)
Summary:
Pull Request resolved: pytorch#2297

A torchrec module incorporating the hash ZCH (zero-collision hashing) kernel, supporting both training and inference needs.

Reviewed By: dstaay-fb, bixue2010

Differential Revision: D60942972

fbshipit-source-id: c3a7f6fa77a7edfa6881c2b55454cb1b44779832
Bin Wen authored and facebook-github-bot committed Aug 14, 2024
1 parent ea358f2 commit 5e30669
Showing 2 changed files with 96 additions and 68 deletions.
69 changes: 35 additions & 34 deletions torchrec/distributed/mc_modules.py
@@ -51,10 +51,7 @@
ShardMetadata,
)
from torchrec.distributed.utils import append_prefix
from torchrec.modules.mc_modules import (
apply_mc_method_to_jt_dict,
ManagedCollisionCollection,
)
from torchrec.modules.mc_modules import ManagedCollisionCollection
from torchrec.modules.utils import construct_jagged_tensors
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

@@ -191,25 +188,28 @@ def _initialize_torch_state(self) -> None:
if name not in shardable_buffers:
continue

sharded_sizes = list(tensor.shape)
sharded_sizes[0] = shard_size
shard_offsets = [0] * len(sharded_sizes)
shard_offsets[0] = shard_offset
global_sizes = list(tensor.shape)
global_sizes[0] = global_size
self._model_parallel_mc_buffer_name_to_sharded_tensor[name] = (
ShardedTensor._init_from_local_shards(
[
Shard(
tensor=tensor,
metadata=ShardMetadata(
# pyre-ignore [6]
shard_offsets=[shard_offset],
# pyre-ignore [6]
shard_sizes=[shard_size],
shard_offsets=shard_offsets,
shard_sizes=sharded_sizes,
placement=(
f"rank:{self._env.rank}/cuda:"
f"{get_local_rank(self._env.world_size, self._env.rank)}"
),
),
)
],
# pyre-ignore [6]
torch.Size([global_size]),
torch.Size(global_sizes),
process_group=self._env.process_group,
)
)
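The block above generalizes the shard metadata from a single dimension to the buffer's full shape, so buffers with more than one dimension (for example an identities table of shape [rows, cols]) are described correctly when sharded along dim 0; the old code passed [shard_offset] and [shard_size] directly and needed the pyre-ignore [6] suppressions. A minimal sketch of the bookkeeping, with hypothetical sizes:

import torch

# Hypothetical local shard: 50 rows of a 200 x 2 global buffer,
# starting at global row offset 100.
tensor = torch.zeros(50, 2)
shard_size, shard_offset, global_size = 50, 100, 200

sharded_sizes = list(tensor.shape)        # [50, 2]
sharded_sizes[0] = shard_size             # row count of this shard
shard_offsets = [0] * len(sharded_sizes)  # [0, 0]
shard_offsets[0] = shard_offset           # [100, 0]: offset only along dim 0
global_sizes = list(tensor.shape)
global_sizes[0] = global_size             # [200, 2]: full table size

assert shard_offsets == [100, 0] and global_sizes == [200, 2]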
@@ -256,9 +256,7 @@ def _create_managed_collision_modules(
self, module: ManagedCollisionCollection
) -> None:

self._mc_module_name_shard_metadata: DefaultDict[
str, DefaultDict[str, List[int]]
] = defaultdict(lambda: defaultdict(list))
self._mc_module_name_shard_metadata: DefaultDict[str, List[int]] = defaultdict()
self._feature_to_offset: Dict[str, int] = {}

for sharding in self._embedding_shardings:
@@ -392,15 +390,19 @@ def input_dist(
self._has_uninitialized_input_dists = False

with torch.no_grad():
features_dict = features.to_dict()
output: Dict[str, JaggedTensor] = features_dict.copy()
for table, mc_module in self._managed_collision_modules.items():
feature_list: List[str] = self._table_to_features[table]
mc_input: Dict[str, JaggedTensor] = {}
for feature in feature_list:
mc_input[feature] = features_dict[feature]
mc_input = mc_module.preprocess(mc_input)
output.update(mc_input)

# NOTE shared features not currently supported
features = KeyedJaggedTensor.from_jt_dict(
apply_mc_method_to_jt_dict(
"preprocess",
features.to_dict(),
self._table_to_features,
self._managed_collision_modules,
)
)
features = KeyedJaggedTensor.from_jt_dict(output)

if self._features_order:
features = features.permute(
self._features_order,
@@ -456,19 +458,17 @@ def compute(
-1, features.stride()
)
features_dict = features.to_dict()
features_dict = apply_mc_method_to_jt_dict(
"profile",
features_dict=features_dict,
table_to_features=self._table_to_features,
managed_collisions=self._managed_collision_modules,
)
features_dict = apply_mc_method_to_jt_dict(
"remap",
features_dict=features_dict,
table_to_features=self._table_to_features,
managed_collisions=self._managed_collision_modules,
)
remapped_kjts.append(KeyedJaggedTensor.from_jt_dict(features_dict))
output: Dict[str, JaggedTensor] = features_dict.copy()
for table, mc_module in self._managed_collision_modules.items():
feature_list: List[str] = self._table_to_features[table]
mc_input: Dict[str, JaggedTensor] = {}
for feature in feature_list:
mc_input[feature] = features_dict[feature]
mc_input = mc_module.profile(mc_input)
mc_input = mc_module.remap(mc_input)
output.update(mc_input)

remapped_kjts.append(KeyedJaggedTensor.from_jt_dict(output))

return KJTList(remapped_kjts)
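This inlined loop replaces the former apply_mc_method_to_jt_dict("profile", ...) and apply_mc_method_to_jt_dict("remap", ...) calls with explicit per-table dispatch. A stripped-down sketch of the pattern, using plain dicts and a hypothetical stand-in module (not the real ManagedCollisionModule interface):

from typing import Dict, List

class FakeMCModule:
    # Hypothetical stand-in: profile collects statistics, remap rewrites ids.
    def profile(self, feats: Dict[str, list]) -> Dict[str, list]:
        return feats  # identity here; real modules update frequency counters
    def remap(self, feats: Dict[str, list]) -> Dict[str, list]:
        return {name: [v % 10 for v in vals] for name, vals in feats.items()}

managed_collision_modules = {"t0": FakeMCModule()}
table_to_features: Dict[str, List[str]] = {"t0": ["f0", "f1"]}
features_dict = {"f0": [3, 17], "f1": [42], "f2": [5]}  # "f2" has no MC module

output = features_dict.copy()
for table, mc_module in managed_collision_modules.items():
    mc_input = {f: features_dict[f] for f in table_to_features[table]}
    mc_input = mc_module.profile(mc_input)
    mc_input = mc_module.remap(mc_input)
    output.update(mc_input)  # features without an MC module pass through

print(output)  # {'f0': [3, 7], 'f1': [2], 'f2': [5]}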

@@ -522,6 +522,7 @@ def create_context(self) -> ManagedCollisionCollectionContext:
return ManagedCollisionCollectionContext(sharding_contexts=[])

def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
# TODO (bwen): this does not include `_hash_zch_identities`
for fqn, _ in self.named_buffers():
yield append_prefix(prefix, fqn)

95 changes: 61 additions & 34 deletions torchrec/modules/mc_modules.py
@@ -10,7 +10,6 @@
#!/usr/bin/env python3

import abc
from collections import defaultdict
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import torch
@@ -30,23 +29,15 @@

@torch.fx.wrap
def apply_mc_method_to_jt_dict(
mc_module: nn.Module,
method: str,
features_dict: Dict[str, JaggedTensor],
table_to_features: Dict[str, List[str]],
managed_collisions: nn.ModuleDict,
) -> Dict[str, JaggedTensor]:
"""
Applies an MC method to a dictionary of JaggedTensors, returning the updated dictionary with same ordering
"""
mc_output: Dict[str, JaggedTensor] = features_dict.copy()
for table, features in table_to_features.items():
mc_input: Dict[str, JaggedTensor] = {}
for feature in features:
mc_input[feature] = features_dict[feature]
mc_module = managed_collisions[table]
attr = getattr(mc_module, method)
mc_output.update(attr(mc_input))
return mc_output
attr = getattr(mc_module, method)
return attr(features_dict)
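After this refactor the fx-wrapped helper no longer owns a per-table loop; it only dispatches a named method on the module it receives, while callers (the sharded module and ManagedCollisionCollection) iterate over tables themselves. Because torch.fx.wrap marks the function as a leaf for symbolic tracing, the getattr dispatch stays out of traced graphs. A hedged usage sketch with a hypothetical module:

from typing import Dict

import torch
from torchrec.modules.mc_modules import apply_mc_method_to_jt_dict
from torchrec.sparse.jagged_tensor import JaggedTensor

class TinyMC(torch.nn.Module):
    # Hypothetical module exposing a private _remap method.
    def _remap(self, feats: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
        return {
            name: JaggedTensor(values=jt.values() % 8, lengths=jt.lengths())
            for name, jt in feats.items()
        }

feats = {
    "f0": JaggedTensor(values=torch.tensor([3, 17, 42]), lengths=torch.tensor([2, 1]))
}
out = apply_mc_method_to_jt_dict(TinyMC(), "_remap", feats)
print(out["f0"].values())  # tensor([3, 1, 2])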


@torch.no_grad()
@@ -153,6 +144,14 @@ def evict(self) -> Optional[torch.Tensor]:
"""
pass

@abc.abstractmethod
def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
pass

@abc.abstractmethod
def profile(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
pass

@abc.abstractmethod
def forward(
self,
@@ -203,6 +202,8 @@ class ManagedCollisionCollection(nn.Module):
embedding_configs (List[BaseEmbeddingConfig]): List of embedding configs, for each table with a managed collision module
"""

_table_to_features: Dict[str, List[str]]

def __init__(
self,
managed_collision_modules: Dict[str, ManagedCollisionModule],
Expand All @@ -216,10 +217,13 @@ def __init__(
for config in embedding_configs
for feature in config.feature_names
}
self._table_to_features: Dict[str, List[str]] = defaultdict(list)
self._table_to_features = {}

self._compute_jt_dict_to_kjt = ComputeJTDictToKJT()
for feature, table in self._feature_to_table.items():
if table not in self._table_to_features:
self._table_to_features[table] = []

self._table_to_features[table].append(feature)

table_to_config = {config.name: config for config in embedding_configs}
@@ -243,25 +247,18 @@ def forward(
self,
features: KeyedJaggedTensor,
) -> KeyedJaggedTensor:
features_dict = apply_mc_method_to_jt_dict(
"preprocess",
features_dict=features.to_dict(),
table_to_features=self._table_to_features,
managed_collisions=self._managed_collision_modules,
)
features_dict = apply_mc_method_to_jt_dict(
"profile",
features_dict=features_dict,
table_to_features=self._table_to_features,
managed_collisions=self._managed_collision_modules,
)
features_dict = apply_mc_method_to_jt_dict(
"remap",
features_dict=features_dict,
table_to_features=self._table_to_features,
managed_collisions=self._managed_collision_modules,
)
return self._compute_jt_dict_to_kjt(features_dict)
features_dict = features.to_dict()
output: Dict[str, JaggedTensor] = features_dict.copy()
for table, mc_module in self._managed_collision_modules.items():
feature_list: List[str] = self._table_to_features[table]
mc_input: Dict[str, JaggedTensor] = {}
for feature in feature_list:
mc_input[feature] = features_dict[feature]
mc_input = mc_module.preprocess(mc_input)
mc_input = mc_module.profile(mc_input)
mc_input = mc_module.remap(mc_input)
output.update(mc_input)
return self._compute_jt_dict_to_kjt(output)
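The rewritten forward hinges on converting the incoming KeyedJaggedTensor to a per-feature Dict[str, JaggedTensor] via to_dict(), running the per-table loop on that dict, and converting back (ComputeJTDictToKJT here, KeyedJaggedTensor.from_jt_dict in the sharded module). A minimal round-trip sketch on toy data, with arbitrary feature names and values:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f0", "f1"],
    values=torch.tensor([10, 11, 12, 20]),
    lengths=torch.tensor([2, 1, 1, 0]),  # f0 lengths [2, 1], f1 lengths [1, 0]
)

jt_dict = kjt.to_dict()  # an MC module's preprocess/profile/remap would act here
rebuilt = KeyedJaggedTensor.from_jt_dict(jt_dict)

print(rebuilt.keys())    # ['f0', 'f1']
print(rebuilt.values())  # tensor([10, 11, 12, 20])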

def evict(self) -> Dict[str, Optional[torch.Tensor]]:
evictions: Dict[str, Optional[torch.Tensor]] = {}
@@ -933,7 +930,17 @@ def _init_history_buffers(self, features: Dict[str, JaggedTensor]) -> None:
self._history_metadata[metadata_name] = getattr(self, buffer_name)

@torch.no_grad()
def preprocess(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
def preprocess(
self,
features: Dict[str, JaggedTensor],
) -> Dict[str, JaggedTensor]:
return apply_mc_method_to_jt_dict(
self,
"_preprocess",
features,
)

def _preprocess(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
if self._input_hash_func is None:
return features
preprocessed_features: Dict[str, JaggedTensor] = {}
@@ -1070,6 +1077,16 @@ def _coalesce_history(self) -> None:
def profile(
self,
features: Dict[str, JaggedTensor],
) -> Dict[str, JaggedTensor]:
return apply_mc_method_to_jt_dict(
self,
"_profile",
features,
)

def _profile(
self,
features: Dict[str, JaggedTensor],
) -> Dict[str, JaggedTensor]:
if not self.training:
return features
@@ -1115,7 +1132,17 @@ def profile(
return features

@torch.no_grad()
def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
def remap(
self,
features: Dict[str, JaggedTensor],
) -> Dict[str, JaggedTensor]:
return apply_mc_method_to_jt_dict(
self,
"_remap",
features,
)

def _remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:

remapped_features: Dict[str, JaggedTensor] = {}
for name, feature in features.items():
