diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py b/fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py index 4ef4e1bd9d..3ee30b5552 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py @@ -7,4 +7,7 @@ # pyre-unsafe -from .split_embeddings_cache_ops import get_unique_indices_v2 # noqa: F401 +from .split_embeddings_cache_ops import SplitEmbeddingsCacheOpsRegistry # noqa: F401 + +# Register ops in `torch.ops.fbgemm` +SplitEmbeddingsCacheOpsRegistry.register() diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py b/fbgemm_gpu/fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py index 37d04bc24e..25832f9ed1 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py @@ -6,24 +6,12 @@ # pyre-unsafe +import logging from typing import Optional, Tuple, Union import torch -lib = torch.library.Library("fbgemm", "FRAGMENT") -lib.define( - """ - get_unique_indices_v2( - Tensor linear_indices, - int max_indices, - bool compute_count=False, - bool compute_inverse_indices=False - ) -> (Tensor, Tensor, Tensor?, Tensor?) - """ -) - -@torch.library.impl(lib, "get_unique_indices_v2", "CUDA") def get_unique_indices_v2( linear_indices: torch.Tensor, max_indices: int, @@ -60,3 +48,37 @@ def get_unique_indices_v2( return ret[0], ret[1], ret[3] # Return (unique_indices, length) return ret[:-2] + + +class SplitEmbeddingsCacheOpsRegistry: + init = False + + @staticmethod + def register(): + """ + Register ops in `torch.ops.fbgemm` + """ + if not SplitEmbeddingsCacheOpsRegistry.init: + logging.info("Register split_embeddings_cache_ops") + + for op_name, op_def, op_fn in ( + ( + "get_unique_indices_v2", + ( + "(" + " Tensor linear_indices, " + " int max_indices, " + " bool compute_count=False, " + " bool compute_inverse_indices=False" + ") -> (Tensor, Tensor, Tensor?, Tensor?)" + ), + get_unique_indices_v2, + ), + ): + fbgemm_op_name = "fbgemm::" + op_name + if fbgemm_op_name not in torch.library._defs: + # Define and register op + torch.library.define(fbgemm_op_name, op_def) + torch.library.impl(fbgemm_op_name, "CUDA", op_fn) + + SplitEmbeddingsCacheOpsRegistry.init = True diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 41c99c9675..7bc0532e46 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -15,7 +15,6 @@ import tempfile from math import log2 from typing import Any, Callable, List, Optional, Tuple, Type - import torch # usort:skip import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers @@ -34,6 +33,7 @@ UVMCacheStatsIndex, WeightDecayMode, ) +from ..cache import * # noqa F403 from torch import distributed as dist, nn, Tensor # usort:skip from torch.autograd.profiler import record_function diff --git a/fbgemm_gpu/test/tbe/common.py b/fbgemm_gpu/test/tbe/common.py index f6d4abe57f..df38e15de1 100644 --- a/fbgemm_gpu/test/tbe/common.py +++ b/fbgemm_gpu/test/tbe/common.py @@ -10,7 +10,6 @@ from typing import List, Tuple import fbgemm_gpu -import fbgemm_gpu.tbe.cache # noqa: F401 import numpy as np import torch from hypothesis import settings, Verbosity