Fix get_unique_indices_v2 registration
Differential Revision: D61294287
sarunya authored and facebook-github-bot committed Aug 14, 2024
1 parent ada1050 commit 82d78bf
Showing 4 changed files with 40 additions and 16 deletions.
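Context for the change: previously, split_embeddings_cache_ops.py defined `fbgemm::get_unique_indices_v2` at module import time with no guard, which can fail if the module is reached through more than one import path, since PyTorch raises a RuntimeError when the same operator schema is defined twice. The new `SplitEmbeddingsCacheOpsRegistry` makes registration explicit and idempotent. A minimal sketch of the failure mode being avoided (toy op name, not from this commit):

    import torch

    lib = torch.library.Library("fbgemm", "FRAGMENT")
    lib.define("toy_unique_indices(Tensor x) -> Tensor")
    # Defining the same schema a second time (e.g., via a duplicate import
    # path) raises RuntimeError; the registry introduced below avoids this
    # by checking torch.library._defs before defining.
    lib.define("toy_unique_indices(Tensor x) -> Tensor")  # RuntimeError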
5 changes: 4 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py
@@ -7,4 +7,7 @@

# pyre-unsafe

-from .split_embeddings_cache_ops import get_unique_indices_v2 # noqa: F401
+from .split_embeddings_cache_ops import SplitEmbeddingsCacheOpsRegistry # noqa: F401
+
+# Register ops in `torch.ops.fbgemm`
+SplitEmbeddingsCacheOpsRegistry.register()
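With this change, simply importing the package performs the registration as a side effect. A minimal usage sketch (assumes a CUDA build of fbgemm_gpu; the tensor values are illustrative):

    import torch
    import fbgemm_gpu.tbe.cache  # import runs SplitEmbeddingsCacheOpsRegistry.register()

    linear_indices = torch.tensor([3, 1, 3, 7], dtype=torch.long, device="cuda")
    # With compute_count and compute_inverse_indices left False, the Python
    # implementation returns (unique_indices, length).
    unique_indices, length = torch.ops.fbgemm.get_unique_indices_v2(linear_indices, 8)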
48 changes: 35 additions & 13 deletions fbgemm_gpu/fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py
@@ -6,24 +6,12 @@

# pyre-unsafe

+import logging
from typing import Optional, Tuple, Union

import torch

-lib = torch.library.Library("fbgemm", "FRAGMENT")
-lib.define(
-    """
-    get_unique_indices_v2(
-        Tensor linear_indices,
-        int max_indices,
-        bool compute_count=False,
-        bool compute_inverse_indices=False
-    ) -> (Tensor, Tensor, Tensor?, Tensor?)
-    """
-)
-

-@torch.library.impl(lib, "get_unique_indices_v2", "CUDA")
def get_unique_indices_v2(
    linear_indices: torch.Tensor,
    max_indices: int,
@@ -60,3 +48,37 @@ def get_unique_indices_v2(
        return ret[0], ret[1], ret[3]
    # Return (unique_indices, length)
    return ret[:-2]


+class SplitEmbeddingsCacheOpsRegistry:
+    init = False
+
+    @staticmethod
+    def register():
+        """
+        Register ops in `torch.ops.fbgemm`
+        """
+        if not SplitEmbeddingsCacheOpsRegistry.init:
+            logging.info("Register split_embeddings_cache_ops")
+
+            for op_name, op_def, op_fn in (
+                (
+                    "get_unique_indices_v2",
+                    (
+                        "("
+                        " Tensor linear_indices, "
+                        " int max_indices, "
+                        " bool compute_count=False, "
+                        " bool compute_inverse_indices=False"
+                        ") -> (Tensor, Tensor, Tensor?, Tensor?)"
+                    ),
+                    get_unique_indices_v2,
+                ),
+            ):
+                fbgemm_op_name = "fbgemm::" + op_name
+                if fbgemm_op_name not in torch.library._defs:
+                    # Define and register op
+                    torch.library.define(fbgemm_op_name, op_def)
+                    torch.library.impl(fbgemm_op_name, "CUDA", op_fn)
+
+            SplitEmbeddingsCacheOpsRegistry.init = True
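The registry guards against double registration in two ways: the class-level `init` flag short-circuits repeated `register()` calls within a process, and the membership test against `torch.library._defs` skips schemas that some other import path has already defined. Note that `torch.library._defs` is a private PyTorch structure, so the check may need updating across PyTorch releases. The same idiom, reduced to a generic helper (hypothetical name, not part of this commit):

    import torch

    def define_once(qualname: str, schema: str, fn, dispatch_key: str = "CUDA") -> None:
        # Skip ops that were already defined, e.g. by an earlier import path.
        if qualname in torch.library._defs:
            return
        torch.library.define(qualname, schema)
        torch.library.impl(qualname, dispatch_key, fn)

    # e.g. define_once("fbgemm::get_unique_indices_v2", op_def, get_unique_indices_v2)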
2 changes: 1 addition & 1 deletion fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -15,7 +15,6 @@
import tempfile
from math import log2
from typing import Any, Callable, List, Optional, Tuple, Type
-
import torch # usort:skip

import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
@@ -34,6 +33,7 @@
UVMCacheStatsIndex,
WeightDecayMode,
)
+from ..cache import * # noqa F403

from torch import distributed as dist, nn, Tensor # usort:skip
from torch.autograd.profiler import record_function
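The star import above is there for its side effect: importing `..cache` executes the `register()` call in its `__init__.py`, which makes `torch.ops.fbgemm.get_unique_indices_v2` available to this module. An explicit spelling of the same effect (illustrative, not what the commit uses):

    import fbgemm_gpu.tbe.cache  # noqa: F401  # registers the cache ops on import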
1 change: 0 additions & 1 deletion fbgemm_gpu/test/tbe/common.py
@@ -10,7 +10,6 @@
from typing import List, Tuple

import fbgemm_gpu
-import fbgemm_gpu.tbe.cache # noqa: F401
import numpy as np
import torch
from hypothesis import settings, Verbosity
