Skip to content

Commit

Permalink
Support config based bound check version via extended modes (#3454)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3454

X-link: facebookresearch/FBGEMM#538

2/2 of enabling bounds check V2 for APS FM, following APS principles, we would like to surface the V2 switch up to the APS user config, hence in this diff we are extending existing BoundsCheckMode with V2 counterparts, and pass the version flag into the operator.

this diff enabled v2 via backward compatible modes update with V2 prefix which is intuitive for user to switch

More context can be found in https://docs.google.com/document/d/1hEhk2isMOXuWPyQJxiOzNq0ivfECsZUT7kT_IBmou_I/edit?tab=t.0#heading=h.q89rllowo3eb

Reviewed By: sryap

Differential Revision: D66512098

fbshipit-source-id: d2181a82462ca1c2c93360d4108766edeb38d000
  • Loading branch information
Fei Yu authored and facebook-github-bot committed Dec 18, 2024
1 parent 62f9db7 commit 0b1739c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ class BoundsCheckMode(enum.IntEnum):
IGNORE = 2
# No bounds checks.
NONE = 3
# IGNORE with V2 enabled
V2_IGNORE = 4
# WARNING with V2 enabled
V2_WARNING = 5
# FATAL with V2 enabled
V2_FATAL = 6


class EmbeddingSpecInfo(enum.IntEnum):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,20 @@ def __init__( # noqa C901
self.pooling_mode = pooling_mode
self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE
# If environment variable is set, it overwrites the default bounds check mode.
self.bounds_check_version: int = 1
if bounds_check_mode.name.startswith("V2_"):
self.bounds_check_version = 2
if bounds_check_mode == BoundsCheckMode.V2_IGNORE:
bounds_check_mode = BoundsCheckMode.IGNORE
elif bounds_check_mode == BoundsCheckMode.V2_WARNING:
bounds_check_mode = BoundsCheckMode.WARNING
elif bounds_check_mode == BoundsCheckMode.V2_FATAL:
bounds_check_mode = BoundsCheckMode.FATAL
else:
raise NotImplementedError(
f"Did not recognize V2 bounds check mode: {bounds_check_mode}"
)

self.bounds_check_mode_int: int = int(
os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value)
)
Expand Down Expand Up @@ -3352,6 +3366,7 @@ def prepare_inputs(
b_t_map=b_t_map,
info_B_num_bits=info_B_num_bits,
info_B_mask=info_B_mask,
bounds_check_version=self.bounds_check_version,
)

return indices, offsets, per_sample_weights, vbe_metadata
Expand Down

0 comments on commit 0b1739c

Please sign in to comment.