File tree 2 files changed +13
-1
lines changed 2 files changed +13
-1
lines changed Original file line number Diff line number Diff line change @@ -33,6 +33,7 @@ class EmbOptimType(enum.Enum):
33
33
SHAMPOO_V2 = "shampoo_v2" # not currently supported for sparse embedding tables
34
34
MADGRAD = "madgrad"
35
35
EXACT_ROWWISE_WEIGHTED_ADAGRAD = "exact_row_wise_weighted_adagrad" # deprecated
36
+ ENSEMBLE_ROWWISE_ADAGRAD = "ensemble_row_wise_adagrad"
36
37
NONE = "none"
37
38
38
39
def __str__ (self ) -> str :
Original file line number Diff line number Diff line change 25
25
26
26
import fbgemm_gpu .split_embedding_codegen_lookup_invokers as invokers
27
27
28
- # from fbgemm_gpu.config import FeatureGateName
28
+ from fbgemm_gpu .config import FeatureGate , FeatureGateName
29
29
from fbgemm_gpu .runtime_monitor import (
30
30
AsyncSeriesTimer ,
31
31
TBEStatsReporter ,
@@ -1549,6 +1549,17 @@ def forward( # noqa: C901
1549
1549
offsets = self .row_counter_offsets ,
1550
1550
placements = self .row_counter_placements ,
1551
1551
)
1552
+
1553
+ if self .optimizer == OptimType .ENSEMBLE_ROWWISE_ADAGRAD :
1554
+ if FeatureGate .is_enabled (FeatureGateName .TBE_ENSEMBLE_ROWWISE_ADAGRAD ):
1555
+ raise AssertionError (
1556
+ "ENSEMBLE_ROWWISE_ADAGRAD feature has not landed yet (see D60189486 stack)"
1557
+ )
1558
+ else :
1559
+ logging .warnning (
1560
+ "ENSEMBLE_ROWWISE_ADAGRAD is an inactive or deprecated feature!"
1561
+ )
1562
+
1552
1563
if self ._used_rowwise_adagrad_with_counter :
1553
1564
if (
1554
1565
self ._max_counter_update_freq > 0
You can’t perform that action at this time.
0 commit comments