create input dist module for sharded ec #1923

Closed · wants to merge 1 commit
198 changes: 142 additions & 56 deletions torchrec/distributed/quant_embedding.py
@@ -574,34 +574,6 @@ def _generate_permute_coordinates_per_feature_per_sharding(
torch.tensor(permuted_coordinates)
)

def _create_input_dist(
self,
input_feature_names: List[str],
device: torch.device,
input_dist_device: Optional[torch.device] = None,
) -> None:
feature_names: List[str] = []
self._feature_splits: List[int] = []
for sharding in self._sharding_type_to_sharding.values():
self._input_dists.append(
sharding.create_input_dist(device=input_dist_device)
)
feature_names.extend(sharding.feature_names())
self._feature_splits.append(len(sharding.feature_names()))
self._features_order: List[int] = []
for f in feature_names:
self._features_order.append(input_feature_names.index(f))
self._features_order = (
[]
if self._features_order == list(range(len(self._features_order)))
else self._features_order
)
self.register_buffer(
"_features_order_tensor",
torch.tensor(self._features_order, device=device, dtype=torch.int32),
persistent=False,
)

def _create_lookups(
self,
fused_params: Optional[Dict[str, Any]],
@@ -627,46 +599,34 @@ def input_dist(
features: KeyedJaggedTensor,
) -> ListOfKJTList:
if self._has_uninitialized_input_dist:
self._create_input_dist(
self._intput_dist = ShardedQuantEcInputDist(
input_feature_names=features.keys() if features is not None else [],
device=features.device(),
input_dist_device=self._device,
sharding_type_to_sharding=self._sharding_type_to_sharding,
device=self._device,
feature_device=features.device(),
)
self._has_uninitialized_input_dist = False
if self._has_uninitialized_output_dist:
self._create_output_dist(features.device())
self._has_uninitialized_output_dist = False
ret: List[KJTList] = []

(
input_dist_result_list,
features_by_sharding,
unbucketize_permute_tensor_list,
) = self._intput_dist(features)

with torch.no_grad():
features_by_sharding = []
if self._features_order:
features = features.permute(
self._features_order,
self._features_order_tensor,
)
features_by_sharding = (
[features]
if len(self._feature_splits) == 1
else features.split(self._feature_splits)
)
for i in range(len(self._sharding_type_to_sharding)):

for i in range(len(self._input_dists)):
input_dist = self._input_dists[i]
input_dist_result = input_dist.forward(features_by_sharding[i])
ret.append(input_dist_result)
ctx.sharding_contexts.append(
InferSequenceShardingContext(
features=input_dist_result,
features=input_dist_result_list[i],
features_before_input_dist=features_by_sharding[i],
unbucketize_permute_tensor=(
input_dist.unbucketize_permute_tensor
if isinstance(input_dist, InferRwSparseFeaturesDist)
or isinstance(input_dist, InferCPURwSparseFeaturesDist)
else None
),
unbucketize_permute_tensor=unbucketize_permute_tensor_list[i],
)
)
return ListOfKJTList(ret)
return input_dist_result_list

def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
return (
@@ -680,7 +640,10 @@ def compute(
) -> List[List[torch.Tensor]]:
ret: List[List[torch.Tensor]] = []

for lookup, features in zip(self._lookups, dist_input):
# for lookup, features in zip(self._lookups, dist_input):
for i in range(len(self._lookups)):
lookup = self._lookups[i]
features = dist_input[i]
ret.append(lookup.forward(features))
return ret

@@ -848,3 +811,126 @@ def shard(
@property
def module_type(self) -> Type[QuantEmbeddingCollection]:
return QuantEmbeddingCollection


class ShardedQuantEcInputDist(torch.nn.Module):
"""
This module implements distributed inputs of a ShardedQuantEmbeddingCollection.
Args:
input_feature_names (List[str]): EmbeddingCollection feature names.
sharding_type_to_sharding (Dict[
str,
EmbeddingSharding[
InferSequenceShardingContext,
KJTList,
List[torch.Tensor],
List[torch.Tensor],
],
]): map from sharding type to EmbeddingSharding.
device (Optional[torch.device]): default compute device.
feature_device (Optional[torch.device]): runtime feature device.
Example::
sqec_input_dist = ShardedQuantEcInputDist(
sharding_type_to_sharding={
ShardingType.TABLE_WISE: InferTwSequenceEmbeddingSharding(
[],
ShardingEnv(
world_size=2,
rank=0,
pg=0,
),
torch.device("cpu")
)
},
device=torch.device("cpu"),
)
features = KeyedJaggedTensor(
keys=["f1", "f2"],
values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
)
sqec_input_dist(features)
"""

    def __init__(
        self,
        input_feature_names: List[str],
        sharding_type_to_sharding: Dict[
            str,
            EmbeddingSharding[
                InferSequenceShardingContext,
                KJTList,
                List[torch.Tensor],
                List[torch.Tensor],
            ],
        ],
        device: Optional[torch.device] = None,
        feature_device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self._sharding_type_to_sharding = sharding_type_to_sharding
        self._input_dists = torch.nn.ModuleList([])
        self._feature_splits: List[int] = []
        self._features_order: List[int] = []

        self._has_features_permute: bool = True

        feature_names: List[str] = []
        for sharding in sharding_type_to_sharding.values():
            self._input_dists.append(sharding.create_input_dist(device=device))
            feature_names.extend(sharding.feature_names())
            self._feature_splits.append(len(sharding.feature_names()))

        for f in feature_names:
            self._features_order.append(input_feature_names.index(f))
        self._features_order = (
            []
            if self._features_order == list(range(len(self._features_order)))
            else self._features_order
        )
        self.register_buffer(
            "_features_order_tensor",
            torch.tensor(
                self._features_order, device=feature_device, dtype=torch.int32
            ),
            persistent=False,
        )

    def forward(
        self, features: KeyedJaggedTensor
    ) -> Tuple[List[KJTList], List[KeyedJaggedTensor], List[Optional[torch.Tensor]]]:
        with torch.no_grad():
            ret: List[KJTList] = []
            unbucketize_permute_tensor = []
            if self._features_order:
                features = features.permute(
                    self._features_order,
                    self._features_order_tensor,
                )
            features_by_sharding = (
                [features]
                if len(self._feature_splits) == 1
                else features.split(self._feature_splits)
            )

            for i in range(len(self._input_dists)):
                input_dist = self._input_dists[i]
                input_dist_result = input_dist(features_by_sharding[i])
                ret.append(input_dist_result)
                unbucketize_permute_tensor.append(
                    input_dist.unbucketize_permute_tensor
                    if isinstance(input_dist, InferRwSparseFeaturesDist)
                    or isinstance(input_dist, InferCPURwSparseFeaturesDist)
                    else None
                )

            return (
                ret,
                features_by_sharding,
                unbucketize_permute_tensor,
            )
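
For context, a minimal sketch (not part of the diff) of how the tuple returned by ShardedQuantEcInputDist.forward is consumed. It mirrors the updated input_dist above; the sqec_input_dist and features objects are assumed to be constructed as in the docstring example, and sharding_contexts is a hypothetical local name used only for illustration.

# Sketch only: assumes `sqec_input_dist` and `features` from the docstring example.
(
    input_dist_result_list,           # List[KJTList], one entry per sharding type
    features_by_sharding,             # List[KeyedJaggedTensor], features split per sharding
    unbucketize_permute_tensor_list,  # List[Optional[torch.Tensor]], non-None for row-wise shardings
) = sqec_input_dist(features)

# One InferSequenceShardingContext per sharding type, as built in input_dist above.
sharding_contexts = [
    InferSequenceShardingContext(
        features=input_dist_result_list[i],
        features_before_input_dist=features_by_sharding[i],
        unbucketize_permute_tensor=unbucketize_permute_tensor_list[i],
    )
    for i in range(len(input_dist_result_list))
]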