
Commit cd493f1

nipung90 authored and facebook-github-bot committed
Allow the ability for uneven row wise sharding based on number of buckets for zch (#3341)
Summary: Pull Request resolved: #3341 X-link: #3341

This diff enables the use of the num_buckets ParameterConstraint in the planner. The presence of this constraint indicates that ZCH bucketing should be used when building row-wise sharding plans.

## Without num_buckets present

The current row-wise sharding strategy is used.

## With num_buckets present

* When devices have the same amount of memory available: buckets are divided evenly across hosts, and one additional bucket is given to as many hosts as needed to absorb the remainder. For example, with hash_size = 100, num_devices = 4, num_buckets = 10, each bucket has 10 rows and buckets are distributed as [3, 3, 2, 2], so rows are distributed as [30, 30, 20, 20].
* When devices have uneven amounts of memory available: buckets are distributed in proportion to each device's memory relative to the total memory of all devices, and any buckets left over after the proportional split are placed on the last device. For example, with hash_size = 45, num_buckets = 9, bucket_size = 5 and a memory ratio of 2:1:1, buckets are distributed as [4, 2, 3], so rows are distributed as [20, 10, 15].

Reviewed By: emlin

Differential Revision: D79659949

fbshipit-source-id: 5d7e0d55bb8371cf2516ba389a9575a42eb4906e
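The uneven-memory case above can be sanity-checked with a small standalone sketch. This is illustrative only and is not the planner's implementation; the helper name and signature are made up for the example.

```python
from typing import List


# Hedged sketch of the uneven-memory bucket split described in the summary
# (illustrative only; not the TorchRec planner code).
def buckets_by_memory(total_buckets: int, memory_per_device: List[int]) -> List[int]:
    total_mem = sum(memory_per_device)
    # Proportional (floored) split; any leftover buckets go to the last device.
    counts = [total_buckets * mem // total_mem for mem in memory_per_device]
    counts[-1] += total_buckets - sum(counts)
    return counts


# hash_size = 45, num_buckets = 9, bucket_size = 5, memory ratio 2:1:1
counts = buckets_by_memory(9, [2, 1, 1])
print(counts)                   # [4, 2, 3]
print([c * 5 for c in counts])  # rows per device: [20, 10, 15]
```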
1 parent 980bb4e commit cd493f1

8 files changed: +528 additions, -28 deletions


torchrec/distributed/batched_embedding_kernel.py

Lines changed: 12 additions & 5 deletions
```diff
@@ -456,16 +456,23 @@ def _get_sharded_local_buckets_for_zero_collision(
 
     for table in embedding_tables:
         total_num_buckets = none_throws(table.total_num_buckets)
-        assert (
-            total_num_buckets % world_size == 0
-        ), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}"
         assert (
             table.total_num_buckets
             and table.num_embeddings % table.total_num_buckets == 0
         ), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'"
-        bucket_offset_start = total_num_buckets // world_size * local_rank
+        extra_local_buckets = int(local_rank < (total_num_buckets % world_size))
+        extra_bucket_padding = (
+            (total_num_buckets % world_size)
+            if local_rank >= (total_num_buckets % world_size)
+            else 0
+        )
+        bucket_offset_start = (
+            total_num_buckets // world_size + extra_local_buckets
+        ) * local_rank + extra_bucket_padding
         bucket_offset_end = min(
-            total_num_buckets, total_num_buckets // world_size * (local_rank + 1)
+            total_num_buckets,
+            (total_num_buckets // world_size + extra_local_buckets) * (local_rank + 1)
+            + extra_bucket_padding,
         )
         bucket_size = (
             table.num_embeddings + total_num_buckets - 1
```
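For the even-memory path, the offsets computed above can be traced with a standalone helper that mirrors the new logic (a sketch, not the library function): with 10 buckets over 4 ranks, the first two ranks pick up one extra bucket, reproducing the [3, 3, 2, 2] split from the commit summary.

```python
from typing import Tuple


# Hedged sketch mirroring the new bucket_offset_start / bucket_offset_end logic above.
def bucket_range(local_rank: int, total_num_buckets: int, world_size: int) -> Tuple[int, int]:
    extra_local_buckets = int(local_rank < (total_num_buckets % world_size))
    extra_bucket_padding = (
        (total_num_buckets % world_size)
        if local_rank >= (total_num_buckets % world_size)
        else 0
    )
    start = (
        total_num_buckets // world_size + extra_local_buckets
    ) * local_rank + extra_bucket_padding
    end = min(
        total_num_buckets,
        (total_num_buckets // world_size + extra_local_buckets) * (local_rank + 1)
        + extra_bucket_padding,
    )
    return start, end


# hash_size = 100, num_buckets = 10, world_size = 4 -> bucket counts [3, 3, 2, 2],
# i.e. rows [30, 30, 20, 20] with bucket_size = 10, as in the commit summary.
print([bucket_range(rank, 10, 4) for rank in range(4)])
# [(0, 3), (3, 6), (6, 8), (8, 10)]
```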

torchrec/distributed/embedding_kernel.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -99,9 +99,13 @@ def create_virtual_table_global_metadata(
     # Otherwise it will only set correct size on current rank and
     # virtual PMT will trigger recalc for the correct global size/offset.
     # NOTE this currently only works for row-wise sharding
+    my_rank_shard_size = metadata.shards_metadata[my_rank].shard_sizes[0]
     for rank, shard_metadata in enumerate(metadata.shards_metadata):
         if use_param_size_as_rows:  # respect the param size and treat it as rows
-            curr_rank_rows = param.size()[0]  # pyre-ignore[16]
+            # The param size only has the information for my_rank. In order to
+            # correctly calculate the size for other ranks, we need to use the current
+            # rank's shard size compared to the shard size of my_rank.
+            curr_rank_rows = (param.size()[0] * metadata.shards_metadata[rank].shard_sizes[0]) // my_rank_shard_size  # pyre-ignore[16]
         else:
             curr_rank_rows = (
                 weight_count_per_rank[rank] if weight_count_per_rank is not None else 1
```
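A worked example of the proportional row calculation above, with illustrative values rather than anything taken from the repo: if my_rank holds 30 rows and the planner produced row-wise shard sizes [30, 30, 20, 20], every other rank's row count is scaled relative to my_rank's shard size.

```python
# Hedged illustration of:
#   curr_rank_rows = (param.size()[0] * shards_metadata[rank].shard_sizes[0]) // my_rank_shard_size
my_rank = 0
param_rows = 30                     # param.size()[0] on my_rank
shard_row_sizes = [30, 30, 20, 20]  # shards_metadata[rank].shard_sizes[0] for each rank

my_rank_shard_size = shard_row_sizes[my_rank]
curr_rank_rows = [
    (param_rows * shard_row_sizes[rank]) // my_rank_shard_size
    for rank in range(len(shard_row_sizes))
]
print(curr_rank_rows)  # [30, 30, 20, 20] -- other ranks scaled relative to my_rank
```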

torchrec/distributed/planner/enumerators.py

Lines changed: 33 additions & 1 deletion
```diff
@@ -38,6 +38,10 @@
     ShardingType,
 )
 from torchrec.modules.embedding_configs import DataType
+from torchrec.modules.embedding_modules import (
+    EmbeddingBagCollection,
+    EmbeddingCollection,
+)
 from torchrec.modules.embedding_tower import EmbeddingTower, EmbeddingTowerCollection
 
 
@@ -178,7 +182,7 @@ def enumerate(
             # skip for other device groups
             if device_group and device_group != self._compute_device:
                 continue
-
+            num_buckets = self._get_num_buckets(name, child_module)
             sharding_options_per_table: List[ShardingOption] = []
 
             for sharding_type in self._filter_sharding_types(
@@ -200,6 +204,7 @@
                             sharding_type=sharding_type,
                             col_wise_shard_dim=col_wise_shard_dim,
                             device_memory_sizes=self._device_memory_sizes,
+                            num_buckets=num_buckets,
                         )
                     except ZeroDivisionError as e:
                         # Re-raise with additional context about the table and module
@@ -264,6 +269,33 @@ def enumerate(
         self._last_stored_search_space = copy.deepcopy(sharding_options)
         return sharding_options
 
+    def _get_num_buckets(self, parameter: str, module: nn.Module) -> Optional[int]:
+        """
+        Get the number of buckets for each embedding table.
+
+        Args:
+            parameter (str): name of the embedding table.
+            module (nn.Module): module to be sharded.
+
+        Returns:
+            Optional[int]: Number of buckets for the table, or None if module is not EmbeddingBagCollection or table not found.
+        """
+        # If module is not of type EmbeddingBagCollection, return None
+        if isinstance(module, EmbeddingBagCollection):
+            embedding_configs = module.embedding_bag_configs()
+        elif isinstance(module, EmbeddingCollection):
+            embedding_configs = module.embedding_configs()
+        else:
+            return None
+
+        # Find the embedding config for the table with the same name as parameter input
+        for config in embedding_configs:
+            if config.name == parameter and config.use_virtual_table:
+                return config.total_num_buckets
+
+        # If table with matching name not found, return None
+        return None
+
     @property
     def last_stored_search_space(self) -> Optional[List[ShardingOption]]:
         # NOTE: This is the last search space stored by enumerate(...), do not use
```
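As a usage sketch, assembled from the test fixtures further down rather than copied from the repo, _get_num_buckets returns total_num_buckets only for a matching table that has use_virtual_table=True on an EmbeddingBagCollection or EmbeddingCollection, and None otherwise.

```python
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# A virtual table with bucketing enabled, mirroring the test setup below.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            num_embeddings=100,
            embedding_dim=20,
            name="table_0",
            feature_names=["feature_0"],
            total_num_buckets=10,
            use_virtual_table=True,
        )
    ]
)

# Expected behavior of the new helper (enumerator is an EmbeddingEnumerator instance):
#   enumerator._get_num_buckets("table_0", ebc)              -> 10
#   enumerator._get_num_buckets("table_1", ebc)              -> None (no matching virtual table)
#   enumerator._get_num_buckets("table_0", torch.nn.Linear(2, 2)) -> None (unsupported module)
```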

torchrec/distributed/planner/tests/test_enumerators.py

Lines changed: 196 additions & 1 deletion
```diff
@@ -18,7 +18,10 @@
     EmbeddingTowerSharder,
 )
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.embeddingbag import (
+    EmbeddingBagCollection,
+    EmbeddingBagCollectionSharder,
+)
 from torchrec.distributed.mc_embeddingbag import (
     ManagedCollisionEmbeddingBagCollectionSharder,
 )
```
```diff
@@ -45,13 +48,27 @@
     [[17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [11, 80]],
 ]
 
+EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS = [
+    [[20, 20], [20, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20]],
+    [[22, 40], [22, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40]],
+    [[24, 60], [24, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60]],
+    [[26, 80], [26, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80]],
+]
+
 EXPECTED_RW_SHARD_OFFSETS = [
     [[0, 0], [13, 0], [26, 0], [39, 0], [52, 0], [65, 0], [78, 0], [91, 0]],
     [[0, 0], [14, 0], [28, 0], [42, 0], [56, 0], [70, 0], [84, 0], [98, 0]],
     [[0, 0], [15, 0], [30, 0], [45, 0], [60, 0], [75, 0], [90, 0], [105, 0]],
     [[0, 0], [17, 0], [34, 0], [51, 0], [68, 0], [85, 0], [102, 0], [119, 0]],
 ]
 
+EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS = [
+    [[0, 0], [20, 0], [40, 0], [50, 0], [60, 0], [70, 0], [80, 0], [90, 0]],
+    [[0, 0], [22, 0], [44, 0], [55, 0], [66, 0], [77, 0], [88, 0], [99, 0]],
+    [[0, 0], [24, 0], [48, 0], [60, 0], [72, 0], [84, 0], [96, 0], [108, 0]],
+    [[0, 0], [26, 0], [52, 0], [65, 0], [78, 0], [91, 0], [104, 0], [117, 0]],
+]
+
 
 def get_expected_cache_aux_size(rows: int) -> int:
     # 0.2 is the hardcoded cache load factor assumed in this test
```
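The WITH_BUCKETS expectations can be sanity-checked from the fixtures below (world_size = 8, total_num_buckets = 10, num_embeddings = 100 + i * 10, embedding_dim = 20 + i * 20). A hedged standalone check, not test code from the repo:

```python
# Row counts per rank for each fixture table, using the same uneven bucket split
# as the batched_embedding_kernel change above (hedged check, not repo code).
def bucket_counts(total_num_buckets: int, world_size: int) -> list:
    base, rem = divmod(total_num_buckets, world_size)
    return [base + (1 if rank < rem else 0) for rank in range(world_size)]


world_size, total_num_buckets = 8, 10
for i in range(4):
    num_embeddings = 100 + i * 10
    bucket_size = (num_embeddings + total_num_buckets - 1) // total_num_buckets
    print([n * bucket_size for n in bucket_counts(total_num_buckets, world_size)])
# [20, 20, 10, 10, 10, 10, 10, 10]   -> offsets [0, 20, 40, 50, 60, 70, 80, 90]
# [22, 22, 11, 11, 11, 11, 11, 11]   -> offsets [0, 22, 44, 55, 66, 77, 88, 99]
# [24, 24, 12, 12, 12, 12, 12, 12]   -> offsets [0, 24, 48, 60, 72, 84, 96, 108]
# [26, 26, 13, 13, 13, 13, 13, 13]   -> offsets [0, 26, 52, 65, 78, 91, 104, 117]
```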
```diff
@@ -101,6 +118,48 @@ def get_expected_cache_aux_size(rows: int) -> int:
     ],
 ]
 
+EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS = [
+    [
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+    ],
+    [
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+    ],
+    [
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+    ],
+    [
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+    ],
+]
 
 EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [
     [
@@ -145,6 +204,48 @@ def get_expected_cache_aux_size(rows: int) -> int:
     ],
 ]
 
+EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS = [
+    [
+        Storage(hbm=166352, ddr=1600),
+        Storage(hbm=166352, ddr=1600),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+    ],
+    [
+        Storage(hbm=1002335, ddr=3520),
+        Storage(hbm=1002335, ddr=3520),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+    ],
+    [
+        Storage(hbm=1004845, ddr=5760),
+        Storage(hbm=1004845, ddr=5760),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+    ],
+    [
+        Storage(hbm=2649916, ddr=8320),
+        Storage(hbm=2649916, ddr=8320),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+    ],
+]
 
 EXPECTED_TWRW_SHARD_SIZES = [
     [[25, 20], [25, 20], [25, 20], [25, 20]],
```
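The ddr values in the UVM-caching expectations appear to scale with shard rows times row bytes (fp32, so 4 bytes per element times the embedding dim). A hedged spot check against table_0, not a derivation from the estimator code:

```python
# Hedged spot check: ddr bytes per shard == rows * embedding_dim * 4 for these fixtures.
rows_per_rank = [20, 20, 10, 10, 10, 10, 10, 10]  # table_0 with buckets, world_size = 8
embedding_dim = 20
print([rows * embedding_dim * 4 for rows in rows_per_rank])
# [1600, 1600, 800, 800, 800, 800, 800, 800] -- matches the ddr fields above
```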
```diff
@@ -248,6 +349,16 @@ def compute_kernels(
         return [EmbeddingComputeKernel.FUSED.value]
 
 
+class VirtualTableRWSharder(EmbeddingBagCollectionSharder):
+    def sharding_types(self, compute_device_type: str) -> List[str]:
+        return [ShardingType.ROW_WISE.value]
+
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
+        return [EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value]
+
+
 class UVMCachingRWSharder(EmbeddingBagCollectionSharder):
     def sharding_types(self, compute_device_type: str) -> List[str]:
         return [ShardingType.ROW_WISE.value]
@@ -357,6 +468,27 @@ def setUp(self) -> None:
                 min_partition=40, pooling_factors=[2, 1, 3, 7]
             ),
         }
+        self._virtual_table_constraints = {
+            "table_0": ParameterConstraints(
+                min_partition=20,
+                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
+            ),
+            "table_1": ParameterConstraints(
+                min_partition=20,
+                pooling_factors=[1, 3, 5],
+                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
+            ),
+            "table_2": ParameterConstraints(
+                min_partition=20,
+                pooling_factors=[8, 2],
+                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
+            ),
+            "table_3": ParameterConstraints(
+                min_partition=40,
+                pooling_factors=[2, 1, 3, 7],
+                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
+            ),
+        }
         self.num_tables = 4
         tables = [
             EmbeddingBagConfig(
@@ -367,6 +499,17 @@ def setUp(self) -> None:
             )
             for i in range(self.num_tables)
         ]
+        tables_with_buckets = [
+            EmbeddingBagConfig(
+                num_embeddings=100 + i * 10,
+                embedding_dim=20 + i * 20,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+                total_num_buckets=10,
+                use_virtual_table=True,
+            )
+            for i in range(self.num_tables)
+        ]
         weighted_tables = [
             EmbeddingBagConfig(
                 num_embeddings=(i + 1) * 10,
@@ -377,6 +520,9 @@ def setUp(self) -> None:
             for i in range(4)
         ]
         self.model = TestSparseNN(tables=tables, weighted_tables=[])
+        self.model_with_buckets = EmbeddingBagCollection(
+            tables=tables_with_buckets,
+        )
         self.enumerator = EmbeddingEnumerator(
             topology=Topology(
                 world_size=self.world_size,
@@ -386,6 +532,15 @@ def setUp(self) -> None:
             batch_size=self.batch_size,
             constraints=self.constraints,
         )
+        self.virtual_table_enumerator = EmbeddingEnumerator(
+            topology=Topology(
+                world_size=self.world_size,
+                compute_device=self.compute_device,
+                local_world_size=self.local_world_size,
+            ),
+            batch_size=self.batch_size,
+            constraints=self._virtual_table_constraints,
+        )
         self.tower_model = TestTowerSparseNN(
             tables=tables, weighted_tables=weighted_tables
         )
```
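Taken together, the fixtures above are the full recipe for bucket-aware row-wise planning in these tests: table configs with use_virtual_table=True and total_num_buckets set, constraints selecting the DRAM_VIRTUAL_TABLE compute kernel, and a row-wise sharder. A condensed, hedged sketch of that recipe follows; the world size, local world size, batch size, and compute device are illustrative values, and the inline sharder mirrors VirtualTableRWSharder from the diff above.

```python
from typing import Dict, List, cast

import torch
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
from torchrec.distributed.planner.types import ParameterConstraints, Topology
from torchrec.distributed.types import ModuleSharder, ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection


class RWVirtualTableSharder(EmbeddingBagCollectionSharder):
    # Same restriction as VirtualTableRWSharder in the test diff above.
    def sharding_types(self, compute_device_type: str) -> List[str]:
        return [ShardingType.ROW_WISE.value]

    def compute_kernels(self, sharding_type: str, compute_device_type: str) -> List[str]:
        return [EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value]


tables = [
    EmbeddingBagConfig(
        num_embeddings=100 + i * 10,
        embedding_dim=20 + i * 20,
        name=f"table_{i}",
        feature_names=[f"feature_{i}"],
        total_num_buckets=10,       # enables bucket-aware (ZCH) row-wise splits
        use_virtual_table=True,
    )
    for i in range(4)
]
constraints: Dict[str, ParameterConstraints] = {
    f"table_{i}": ParameterConstraints(
        min_partition=20,
        compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
    )
    for i in range(4)
}
enumerator = EmbeddingEnumerator(
    topology=Topology(world_size=8, compute_device="cuda", local_world_size=4),
    batch_size=32,
    constraints=constraints,
)
sharding_options = enumerator.enumerate(
    EmbeddingBagCollection(tables=tables),
    [cast(ModuleSharder[torch.nn.Module], RWVirtualTableSharder())],
)
# Shard row counts follow the uneven bucket split, e.g. table_0 ->
# [20, 20, 10, 10, 10, 10, 10, 10] rather than an even 100 / 8 split.
```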
```diff
@@ -514,6 +669,26 @@ def test_rw_sharding(self) -> None:
                 EXPECTED_RW_SHARD_STORAGE[i],
             )
 
+    def test_virtual_table_rw_sharding_with_buckets(self) -> None:
+        sharding_options = self.virtual_table_enumerator.enumerate(
+            self.model_with_buckets,
+            [cast(ModuleSharder[torch.nn.Module], VirtualTableRWSharder())],
+        )
+        for i, sharding_option in enumerate(sharding_options):
+            self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value)
+            self.assertEqual(
+                [shard.size for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.offset for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.storage for shard in sharding_option.shards],
+                EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS[i],
+            )
+
     def test_uvm_caching_rw_sharding(self) -> None:
         sharding_options = self.enumerator.enumerate(
             self.model,
@@ -535,6 +710,26 @@ def test_uvm_caching_rw_sharding(self) -> None:
                 EXPECTED_UVM_CACHING_RW_SHARD_STORAGE[i],
             )
 
+    def test_uvm_caching_rw_sharding_with_buckets(self) -> None:
+        sharding_options = self.enumerator.enumerate(
+            self.model_with_buckets,
+            [cast(ModuleSharder[torch.nn.Module], UVMCachingRWSharder())],
+        )
+        for i, sharding_option in enumerate(sharding_options):
+            self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value)
+            self.assertEqual(
+                [shard.size for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.offset for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.storage for shard in sharding_option.shards],
+                EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS[i],
+            )
+
     def test_twrw_sharding(self) -> None:
         sharding_options = self.enumerator.enumerate(
             self.model, [cast(ModuleSharder[torch.nn.Module], TWRWSharder())]
```
