From 0b1739c5321a1d4406ca1642048e122952ec46fa Mon Sep 17 00:00:00 2001 From: Fei Yu Date: Wed, 18 Dec 2024 14:52:06 -0800 Subject: [PATCH] Support config based bound check version via extended modes (#3454) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3454 X-link: https://github.com/facebookresearch/FBGEMM/pull/538 Part 2/2 of enabling bounds check V2 for APS FM. Following APS principles, we would like to surface the V2 switch up to the APS user config; hence, in this diff we extend the existing BoundsCheckMode with V2 counterparts and pass the version flag into the operator. This diff enables V2 via a backward-compatible update that adds modes with a V2 prefix, which makes it intuitive for users to switch. More context can be found in https://docs.google.com/document/d/1hEhk2isMOXuWPyQJxiOzNq0ivfECsZUT7kT_IBmou_I/edit?tab=t.0#heading=h.q89rllowo3eb Reviewed By: sryap Differential Revision: D66512098 fbshipit-source-id: d2181a82462ca1c2c93360d4108766edeb38d000 --- .../split_table_batched_embeddings_ops_common.py | 6 ++++++ ...split_table_batched_embeddings_ops_training.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py index 069f66b02..82e9c9f06 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py @@ -67,6 +67,12 @@ class BoundsCheckMode(enum.IntEnum): IGNORE = 2 # No bounds checks. 
NONE = 3 + # IGNORE with V2 enabled + V2_IGNORE = 4 + # WARNING with V2 enabled + V2_WARNING = 5 + # FATAL with V2 enabled + V2_FATAL = 6 class EmbeddingSpecInfo(enum.IntEnum): diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index d8667abe0..85ebd69f2 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -638,6 +638,20 @@ def __init__( # noqa C901 self.pooling_mode = pooling_mode self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE # If environment variable is set, it overwrites the default bounds check mode. + self.bounds_check_version: int = 1 + if bounds_check_mode.name.startswith("V2_"): + self.bounds_check_version = 2 + if bounds_check_mode == BoundsCheckMode.V2_IGNORE: + bounds_check_mode = BoundsCheckMode.IGNORE + elif bounds_check_mode == BoundsCheckMode.V2_WARNING: + bounds_check_mode = BoundsCheckMode.WARNING + elif bounds_check_mode == BoundsCheckMode.V2_FATAL: + bounds_check_mode = BoundsCheckMode.FATAL + else: + raise NotImplementedError( + f"Did not recognize V2 bounds check mode: {bounds_check_mode}" + ) + self.bounds_check_mode_int: int = int( os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value) ) @@ -3352,6 +3366,7 @@ def prepare_inputs( b_t_map=b_t_map, info_B_num_bits=info_B_num_bits, info_B_mask=info_B_mask, + bounds_check_version=self.bounds_check_version, ) return indices, offsets, per_sample_weights, vbe_metadata