File tree 2 files changed +19
-1
lines changed 2 files changed +19
-1
lines changed Original file line number Diff line number Diff line change @@ -33,6 +33,7 @@ class EmbOptimType(enum.Enum):
33
33
SHAMPOO_V2 = "shampoo_v2" # not currently supported for sparse embedding tables
34
34
MADGRAD = "madgrad"
35
35
EXACT_ROWWISE_WEIGHTED_ADAGRAD = "exact_row_wise_weighted_adagrad" # deprecated
36
+ ENSEMBLE_ROWWISE_ADAGRAD = "ensemble_row_wise_adagrad"
36
37
NONE = "none"
37
38
38
39
def __str__ (self ) -> str :
Original file line number Diff line number Diff line change 25
25
26
26
import fbgemm_gpu .split_embedding_codegen_lookup_invokers as invokers
27
27
28
- # from fbgemm_gpu.config import FeatureGateName
28
+ from fbgemm_gpu .config import FeatureGate , FeatureGateName
29
29
from fbgemm_gpu .runtime_monitor import (
30
30
AsyncSeriesTimer ,
31
31
TBEStatsReporter ,
@@ -1331,6 +1331,12 @@ def _generate_vbe_metadata(
1331
1331
self .current_device ,
1332
1332
)
1333
1333
1334
+ @torch .jit .ignore
1335
+ def _feature_is_enabled (self , feature : FeatureGateName ) -> bool :
1336
+ # Define proxy method so that it can be marked with @torch.jit.ignore
1337
+ # This allows models using this class to compile correctly
1338
+ return FeatureGate .is_enabled (feature )
1339
+
1334
1340
def forward ( # noqa: C901
1335
1341
self ,
1336
1342
indices : Tensor ,
@@ -1549,6 +1555,17 @@ def forward( # noqa: C901
1549
1555
offsets = self .row_counter_offsets ,
1550
1556
placements = self .row_counter_placements ,
1551
1557
)
1558
+
1559
+ if self .optimizer == OptimType .ENSEMBLE_ROWWISE_ADAGRAD :
1560
+ if self ._feature_is_enabled (FeatureGateName .TBE_ENSEMBLE_ROWWISE_ADAGRAD ):
1561
+ raise AssertionError (
1562
+ "ENSEMBLE_ROWWISE_ADAGRAD feature has not landed yet (see D60189486 stack)"
1563
+ )
1564
+ else :
1565
+ raise AssertionError (
1566
+ "ENSEMBLE_ROWWISE_ADAGRAD is an inactive or deprecated feature!"
1567
+ )
1568
+
1552
1569
if self ._used_rowwise_adagrad_with_counter :
1553
1570
if (
1554
1571
self ._max_counter_update_freq > 0
You can’t perform that action at this time.
0 commit comments