     EmbeddingTowerSharder,
 )
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.embeddingbag import (
+    EmbeddingBagCollection,
+    EmbeddingBagCollectionSharder,
+)
 from torchrec.distributed.mc_embeddingbag import (
     ManagedCollisionEmbeddingBagCollectionSharder,
 )
     [[17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [11, 80]],
 ]

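+# With total_num_buckets=10, each table splits into 10 equal buckets (10, 11, 12 and
+# 13 rows for num_embeddings of 100, 110, 120 and 130), and the 10 buckets are spread
+# over the 8 ranks as 2, 2, 1, 1, 1, 1, 1, 1, giving the per-shard row counts below.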
+EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS = [
+    [[20, 20], [20, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20]],
+    [[22, 40], [22, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40]],
+    [[24, 60], [24, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60]],
+    [[26, 80], [26, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80]],
+]
+
 EXPECTED_RW_SHARD_OFFSETS = [
     [[0, 0], [13, 0], [26, 0], [39, 0], [52, 0], [65, 0], [78, 0], [91, 0]],
     [[0, 0], [14, 0], [28, 0], [42, 0], [56, 0], [70, 0], [84, 0], [98, 0]],
     [[0, 0], [15, 0], [30, 0], [45, 0], [60, 0], [75, 0], [90, 0], [105, 0]],
     [[0, 0], [17, 0], [34, 0], [51, 0], [68, 0], [85, 0], [102, 0], [119, 0]],
 ]

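+# Bucket-aligned row offsets: running sums of the bucketed shard sizes above.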
+EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS = [
+    [[0, 0], [20, 0], [40, 0], [50, 0], [60, 0], [70, 0], [80, 0], [90, 0]],
+    [[0, 0], [22, 0], [44, 0], [55, 0], [66, 0], [77, 0], [88, 0], [99, 0]],
+    [[0, 0], [24, 0], [48, 0], [60, 0], [72, 0], [84, 0], [96, 0], [108, 0]],
+    [[0, 0], [26, 0], [52, 0], [65, 0], [78, 0], [91, 0], [104, 0], [117, 0]],
+]
+

 def get_expected_cache_aux_size(rows: int) -> int:
     # 0.2 is the hardcoded cache load factor assumed in this test
@@ -101,6 +118,48 @@ def get_expected_cache_aux_size(rows: int) -> int:
     ],
 ]

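+# Planner storage estimates for the DRAM-backed virtual-table kernel. The HBM
+# estimate is identical for every shard of a table here, i.e. it does not appear
+# to scale with the per-shard row count (20- and 10-row shards get equal values).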
+EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS = [
+    [
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+    ],
+    [
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+    ],
+    [
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+    ],
+    [
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+    ],
+]

 EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [
     [
@@ -145,6 +204,48 @@ def get_expected_cache_aux_size(rows: int) -> int:
     ],
 ]

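+# For the UVM-caching kernel the DDR component is the full fp32 shard
+# (rows * embedding_dim * 4 bytes, e.g. 20 * 20 * 4 = 1600 for the first 20-row
+# shard), while the HBM component covers the cache (0.2 load factor here) plus
+# per-shard overheads.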
+EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS = [
+    [
+        Storage(hbm=166352, ddr=1600),
+        Storage(hbm=166352, ddr=1600),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+    ],
+    [
+        Storage(hbm=1002335, ddr=3520),
+        Storage(hbm=1002335, ddr=3520),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+    ],
+    [
+        Storage(hbm=1004845, ddr=5760),
+        Storage(hbm=1004845, ddr=5760),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+    ],
+    [
+        Storage(hbm=2649916, ddr=8320),
+        Storage(hbm=2649916, ddr=8320),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+    ],
+]

 EXPECTED_TWRW_SHARD_SIZES = [
     [[25, 20], [25, 20], [25, 20], [25, 20]],
@@ -248,6 +349,16 @@ def compute_kernels(
         return [EmbeddingComputeKernel.FUSED.value]


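+# Test sharder that forces row-wise sharding with the DRAM-backed virtual-table
+# compute kernel (mirrors UVMCachingRWSharder below).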
+class VirtualTableRWSharder(EmbeddingBagCollectionSharder):
+    def sharding_types(self, compute_device_type: str) -> List[str]:
+        return [ShardingType.ROW_WISE.value]
+
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
+        return [EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value]
+
+
 class UVMCachingRWSharder(EmbeddingBagCollectionSharder):
     def sharding_types(self, compute_device_type: str) -> List[str]:
         return [ShardingType.ROW_WISE.value]
@@ -357,6 +468,27 @@ def setUp(self) -> None:
                 min_partition=40, pooling_factors=[2, 1, 3, 7]
             ),
         }
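+        # Constraints that force the DRAM-backed virtual-table compute kernel
+        # for every table; used by self.virtual_table_enumerator below.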
+        self._virtual_table_constraints = {
+            "table_0": ParameterConstraints(
+                min_partition=20,
+                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
+            ),
+            "table_1": ParameterConstraints(
+                min_partition=20,
+                pooling_factors=[1, 3, 5],
+                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
+            ),
+            "table_2": ParameterConstraints(
+                min_partition=20,
+                pooling_factors=[8, 2],
+                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
+            ),
+            "table_3": ParameterConstraints(
+                min_partition=40,
+                pooling_factors=[2, 1, 3, 7],
+                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
+            ),
+        }
         self.num_tables = 4
         tables = [
             EmbeddingBagConfig(
@@ -367,6 +499,17 @@ def setUp(self) -> None:
             )
             for i in range(self.num_tables)
         ]
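+        # Same table shapes as `tables` above, but declared as bucketed virtual
+        # tables (total_num_buckets=10, use_virtual_table=True) so that row-wise
+        # shards align with bucket boundaries (see EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS).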
+        tables_with_buckets = [
+            EmbeddingBagConfig(
+                num_embeddings=100 + i * 10,
+                embedding_dim=20 + i * 20,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+                total_num_buckets=10,
+                use_virtual_table=True,
+            )
+            for i in range(self.num_tables)
+        ]
         weighted_tables = [
             EmbeddingBagConfig(
                 num_embeddings=(i + 1) * 10,
@@ -377,6 +520,9 @@ def setUp(self) -> None:
             for i in range(4)
         ]
         self.model = TestSparseNN(tables=tables, weighted_tables=[])
+        self.model_with_buckets = EmbeddingBagCollection(
+            tables=tables_with_buckets,
+        )
         self.enumerator = EmbeddingEnumerator(
             topology=Topology(
                 world_size=self.world_size,
@@ -386,6 +532,15 @@ def setUp(self) -> None:
             batch_size=self.batch_size,
             constraints=self.constraints,
         )
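+        # Same topology as self.enumerator, but driven by the virtual-table constraints.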
+        self.virtual_table_enumerator = EmbeddingEnumerator(
+            topology=Topology(
+                world_size=self.world_size,
+                compute_device=self.compute_device,
+                local_world_size=self.local_world_size,
+            ),
+            batch_size=self.batch_size,
+            constraints=self._virtual_table_constraints,
+        )
         self.tower_model = TestTowerSparseNN(
             tables=tables, weighted_tables=weighted_tables
         )
@@ -514,6 +669,26 @@ def test_rw_sharding(self) -> None:
                 EXPECTED_RW_SHARD_STORAGE[i],
             )

+    def test_virtual_table_rw_sharding_with_buckets(self) -> None:
+        sharding_options = self.virtual_table_enumerator.enumerate(
+            self.model_with_buckets,
+            [cast(ModuleSharder[torch.nn.Module], VirtualTableRWSharder())],
+        )
+        for i, sharding_option in enumerate(sharding_options):
+            self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value)
+            self.assertEqual(
+                [shard.size for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.offset for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.storage for shard in sharding_option.shards],
+                EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS[i],
+            )
+
     def test_uvm_caching_rw_sharding(self) -> None:
         sharding_options = self.enumerator.enumerate(
             self.model,
@@ -535,6 +710,26 @@ def test_uvm_caching_rw_sharding(self) -> None:
                 EXPECTED_UVM_CACHING_RW_SHARD_STORAGE[i],
             )

+    def test_uvm_caching_rw_sharding_with_buckets(self) -> None:
+        sharding_options = self.enumerator.enumerate(
+            self.model_with_buckets,
+            [cast(ModuleSharder[torch.nn.Module], UVMCachingRWSharder())],
+        )
+        for i, sharding_option in enumerate(sharding_options):
+            self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value)
+            self.assertEqual(
+                [shard.size for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.offset for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.storage for shard in sharding_option.shards],
+                EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS[i],
+            )
+
     def test_twrw_sharding(self) -> None:
         sharding_options = self.enumerator.enumerate(
             self.model, [cast(ModuleSharder[torch.nn.Module], TWRWSharder())]