Skip to content

Commit f89443a

Browse files
q10 authored and facebook-github-bot committed
Add configuration knob for ENSEMBLE_ROWWISE_ADAGRAD, frontend (#2955)
Summary: Pull Request resolved: #2955 X-link: facebookresearch/FBGEMM#55 - Add configuration knob for ENSEMBLE_ROWWISE_ADAGRAD, frontend Reviewed By: spcyppt Differential Revision: D60986449
1 parent 6d3c2fe commit f89443a

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -33,6 +33,7 @@ class EmbOptimType(enum.Enum):
3333
SHAMPOO_V2 = "shampoo_v2" # not currently supported for sparse embedding tables
3434
MADGRAD = "madgrad"
3535
EXACT_ROWWISE_WEIGHTED_ADAGRAD = "exact_row_wise_weighted_adagrad" # deprecated
36+
ENSEMBLE_ROWWISE_ADAGRAD = "ensemble_row_wise_adagrad"
3637
NONE = "none"
3738

3839
def __str__(self) -> str:

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 18 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -25,7 +25,7 @@
2525

2626
import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
2727

28-
# from fbgemm_gpu.config import FeatureGateName
28+
from fbgemm_gpu.config import FeatureGate, FeatureGateName
2929
from fbgemm_gpu.runtime_monitor import (
3030
AsyncSeriesTimer,
3131
TBEStatsReporter,
@@ -1331,6 +1331,12 @@ def _generate_vbe_metadata(
13311331
self.current_device,
13321332
)
13331333

1334+
@torch.jit.ignore
1335+
def _feature_is_enabled(self, feature: FeatureGateName) -> bool:
1336+
# Define proxy method so that it can be marked with @torch.jit.ignore
1337+
# This allows models using this class to compile correctly
1338+
return FeatureGate.is_enabled(feature)
1339+
13341340
def forward( # noqa: C901
13351341
self,
13361342
indices: Tensor,
@@ -1549,6 +1555,17 @@ def forward( # noqa: C901
15491555
offsets=self.row_counter_offsets,
15501556
placements=self.row_counter_placements,
15511557
)
1558+
1559+
if self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD:
1560+
if self._feature_is_enabled(FeatureGateName.TBE_ENSEMBLE_ROWWISE_ADAGRAD):
1561+
raise AssertionError(
1562+
"ENSEMBLE_ROWWISE_ADAGRAD feature has not landed yet (see D60189486 stack)"
1563+
)
1564+
else:
1565+
raise AssertionError(
1566+
"ENSEMBLE_ROWWISE_ADAGRAD is an inactive or deprecated feature!"
1567+
)
1568+
15521569
if self._used_rowwise_adagrad_with_counter:
15531570
if (
15541571
self._max_counter_update_freq > 0

0 commit comments

Comments (0)