From db228c51d6590683688c91353b4eff199f428712 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Thu, 15 Aug 2024 12:17:05 -0700 Subject: [PATCH] Add configuration knob for ENSEMBLE_ROWWISE_ADAGRAD, frontend (#2955) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2955 X-link: https://github.com/facebookresearch/FBGEMM/pull/55 - Add configuration knob for ENSEMBLE_ROWWISE_ADAGRAD, frontend Reviewed By: spcyppt Differential Revision: D60986449 --- fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py | 1 + .../split_table_batched_embeddings_ops_training.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py index 15ce4fe114..365aebbfce 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py +++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py @@ -33,6 +33,7 @@ class EmbOptimType(enum.Enum): SHAMPOO_V2 = "shampoo_v2" # not currently supported for sparse embedding tables MADGRAD = "madgrad" EXACT_ROWWISE_WEIGHTED_ADAGRAD = "exact_row_wise_weighted_adagrad" # deprecated + ENSEMBLE_ROWWISE_ADAGRAD = "ensemble_row_wise_adagrad" NONE = "none" def __str__(self) -> str: diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index b2b0fc2591..987ff990b8 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -25,7 +25,7 @@ import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers -# from fbgemm_gpu.config import FeatureGateName +from fbgemm_gpu.config import FeatureGate, FeatureGateName from fbgemm_gpu.runtime_monitor import ( AsyncSeriesTimer, TBEStatsReporter, @@ -1549,6 +1549,17 @@ def forward( # noqa: C901 offsets=self.row_counter_offsets, placements=self.row_counter_placements, ) + + if self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD: + if FeatureGate.is_enabled(FeatureGateName.TBE_ENSEMBLE_ROWWISE_ADAGRAD): + raise AssertionError( + "ENSEMBLE_ROWWISE_ADAGRAD feature has not landed yet (see D60189486 stack)" + ) + else: + raise AssertionError( + "ENSEMBLE_ROWWISE_ADAGRAD is an inactive or deprecated feature!" + ) + if self._used_rowwise_adagrad_with_counter: if ( self._max_counter_update_freq > 0