diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index 7d1f7aaaec..bd3172c520 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -451,6 +451,8 @@ set(fbgemm_gpu_sources_static_cpu
     codegen/training/backward/embedding_backward_dense_host_cpu.cpp
     codegen/utils/embedding_bounds_check_host_cpu.cpp
     src/config/feature_gates.cpp
+    src/memory_utils/memory_utils.cpp
+    src/memory_utils/memory_utils_ops.cpp
     src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
     src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
     src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
@@ -481,9 +483,6 @@ if(NOT FBGEMM_CPU_ONLY)
     codegen/utils/embedding_bounds_check_host.cpp
     src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp
     src/layout_transform_ops/layout_transform_ops_gpu.cpp
-    src/memory_utils/memory_utils.cpp
-    src/memory_utils/memory_utils_ops.cpp
-    src/memory_utils/memory_utils_ops_cpu.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
     src/quantize_ops/quantize_ops_gpu.cpp
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
index f99ad00530..8564fa41c6 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -13,6 +13,7 @@
 from itertools import accumulate
 from typing import List, Optional, Tuple, Union
 
+import fbgemm_gpu  # noqa: F401
 import torch  # usort:skip
 from torch import nn, Tensor  # usort:skip
 
diff --git a/fbgemm_gpu/src/memory_utils/memory_utils.cpp b/fbgemm_gpu/src/memory_utils/memory_utils.cpp
index 4e8c0ca7e1..74a1e098c1 100644
--- a/fbgemm_gpu/src/memory_utils/memory_utils.cpp
+++ b/fbgemm_gpu/src/memory_utils/memory_utils.cpp
@@ -7,11 +7,18 @@
  */
 
 #include "common.h"
+#include "fbgemm_gpu/cumem_utils.h"
 
 using Tensor = at::Tensor;
 
 namespace fbgemm_gpu {
 
+Tensor new_managed_tensor_meta(
+    const Tensor& self,
+    const std::vector<std::int64_t>& sizes) {
+  return at::empty(sizes, self.options());
+}
+
 Tensor new_unified_tensor_cpu(
     const Tensor& self,
     const std::vector<std::int64_t>& sizes,
diff --git a/fbgemm_gpu/src/memory_utils/memory_utils.cu b/fbgemm_gpu/src/memory_utils/memory_utils.cu
index f12fdd015e..5d4fd0fd88 100644
--- a/fbgemm_gpu/src/memory_utils/memory_utils.cu
+++ b/fbgemm_gpu/src/memory_utils/memory_utils.cu
@@ -173,12 +173,6 @@ Tensor new_managed_tensor(
   return t;
 }
 
-Tensor new_managed_tensor_meta(
-    const Tensor& self,
-    const std::vector<std::int64_t>& sizes) {
-  return at::empty(sizes, self.options());
-}
-
 // Allocate a cuda Tensor with unified managed memory (UVM) without the
 // additional steps taken by new_managed_tensor above
 Tensor new_vanilla_managed_tensor(
diff --git a/fbgemm_gpu/src/memory_utils/memory_utils_ops.cpp b/fbgemm_gpu/src/memory_utils/memory_utils_ops.cpp
index 1080bd18b9..4864f4270a 100644
--- a/fbgemm_gpu/src/memory_utils/memory_utils_ops.cpp
+++ b/fbgemm_gpu/src/memory_utils/memory_utils_ops.cpp
@@ -7,7 +7,8 @@
  */
 
 #include <torch/library.h>
-#include "common.cuh"
+#include "common.h"
+#include "fbgemm_gpu/cumem_utils.h"
 #include "fbgemm_gpu/sparse_ops_utils.h"
 
 using Tensor = at::Tensor;
@@ -15,31 +16,14 @@ using Tensor = at::Tensor;
 namespace fbgemm_gpu {
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-  m.def("is_uvm_tensor(Tensor t) -> bool", TORCH_FN(is_uvm_tensor));
-  m.def("uvm_storage(Tensor t) -> bool", TORCH_FN(uvm_storage));
-  m.def(
-      "uvm_to_device(Tensor self, Tensor prototype) -> Tensor",
-      TORCH_FN(uvm_to_device));
-  m.def("uvm_to_cpu(Tensor t) -> Tensor");
   m.def("new_managed_tensor(Tensor self, int[] sizes) -> Tensor");
   m.def("new_host_mapped_tensor(Tensor self, int[] sizes) -> Tensor");
   m.def("new_vanilla_managed_tensor(Tensor self, int[] sizes) -> Tensor");
   m.def(
-      "cuda_mem_advise(Tensor t, int advice) -> ()",
-      TORCH_FN(uvm_cuda_mem_advise));
-  m.def(
-      "cuda_mem_prefetch_async(Tensor t, Tensor? device_t) -> ()",
-      TORCH_FN(uvm_cuda_mem_prefetch_async));
-  m.def(
-      "uvm_mem_advice_dont_fork(Tensor t) -> ()",
-      TORCH_FN(uvm_mem_advice_dont_fork));
+      "new_unified_tensor(Tensor self, int[] sizes, bool is_host_mapped) -> Tensor");
 
-  m.def("uvm_to_cpu_clone(Tensor t) -> Tensor", TORCH_FN(uvm_to_cpu_clone));
-  m.def(FBGEMM_GPU_ENUM_OP(uvm, fbgemm_gpu_uvm_enum_query));
-}
-
-TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   DISPATCH_TO_META("new_managed_tensor", new_managed_tensor_meta);
+  DISPATCH_TO_CPU("new_unified_tensor", new_unified_tensor_cpu);
 }
 
 } // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/memory_utils/memory_utils_ops.cu b/fbgemm_gpu/src/memory_utils/memory_utils_ops.cu
index d0bba091e3..8e50e1cc53 100644
--- a/fbgemm_gpu/src/memory_utils/memory_utils_ops.cu
+++ b/fbgemm_gpu/src/memory_utils/memory_utils_ops.cu
@@ -14,9 +14,28 @@
 namespace fbgemm_gpu {
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-  DISPATCH_TO_CUDA("uvm_to_cpu", uvm_to_cpu);
+  m.def("is_uvm_tensor(Tensor t) -> bool", TORCH_FN(is_uvm_tensor));
+  m.def("uvm_storage(Tensor t) -> bool", TORCH_FN(uvm_storage));
+  m.def(
+      "uvm_to_device(Tensor self, Tensor prototype) -> Tensor",
+      TORCH_FN(uvm_to_device));
+
+  m.def(
+      "cuda_mem_advise(Tensor t, int advice) -> ()",
+      TORCH_FN(uvm_cuda_mem_advise));
+  m.def(
+      "cuda_mem_prefetch_async(Tensor t, Tensor? device_t) -> ()",
+      TORCH_FN(uvm_cuda_mem_prefetch_async));
+  m.def(
+      "uvm_mem_advice_dont_fork(Tensor t) -> ()",
+      TORCH_FN(uvm_mem_advice_dont_fork));
+
+  m.def("uvm_to_cpu_clone(Tensor t) -> Tensor", TORCH_FN(uvm_to_cpu_clone));
+  m.def("uvm_to_cpu(Tensor t) -> Tensor", TORCH_FN(uvm_to_cpu));
+
+  m.def(FBGEMM_GPU_ENUM_OP(uvm, fbgemm_gpu_uvm_enum_query));
+
   DISPATCH_TO_CUDA("new_managed_tensor", new_managed_tensor);
-  DISPATCH_TO_META("new_managed_tensor", new_managed_tensor_meta);
   DISPATCH_TO_CUDA("new_host_mapped_tensor", new_host_mapped_tensor);
   DISPATCH_TO_CUDA("new_unified_tensor", new_unified_tensor);
   DISPATCH_TO_CUDA("new_vanilla_managed_tensor", new_vanilla_managed_tensor);
diff --git a/fbgemm_gpu/src/memory_utils/memory_utils_ops_cpu.cpp b/fbgemm_gpu/src/memory_utils/memory_utils_ops_cpu.cpp
deleted file mode 100644
index 5cdb8be25b..0000000000
--- a/fbgemm_gpu/src/memory_utils/memory_utils_ops_cpu.cpp
+++ /dev/null
@@ -1,26 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include <torch/library.h>
-#include "common.h"
-#include "fbgemm_gpu/sparse_ops_utils.h"
-
-using Tensor = at::Tensor;
-
-namespace fbgemm_gpu {
-
-TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-  m.def(
-      "new_unified_tensor(Tensor self, int[] sizes, bool is_host_mapped) -> Tensor");
-}
-
-TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-  DISPATCH_TO_CPU("new_unified_tensor", new_unified_tensor_cpu);
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/test/config/feature_gate_test.py b/fbgemm_gpu/test/config/feature_gate_test.py
index 2fe1409d28..3ec2dcb7ca 100644
--- a/fbgemm_gpu/test/config/feature_gate_test.py
+++ b/fbgemm_gpu/test/config/feature_gate_test.py
@@ -8,7 +8,6 @@
 # pyre-unsafe
 
 import unittest
-from contextlib import contextmanager
 
 # pyre-fixme[21]
 import fbgemm_gpu
@@ -17,21 +16,17 @@
 # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
 open_source: bool = getattr(fbgemm_gpu, "open_source", False)
 
-if not open_source:
+if open_source:
+    # pyre-ignore[21]
+    from test_utils import TestSuite
+
+else:
     # pyre-fixme[21]
     from fbgemm_gpu.fb.config import FeatureGateName as FbFeatureGateName
+    from fbgemm_gpu.test.test_utils import TestSuite
 
 
-class FeatureGateTest(unittest.TestCase):
-    @contextmanager
-    # pyre-ignore[2]
-    def assertNotRaised(self, exc_type) -> None:
-        try:
-            # pyre-ignore[7]
-            yield None
-        except exc_type as e:
-            raise self.failureException(e)
-
+class FeatureGateTest(TestSuite):  # pyre-ignore[11]
     def test_feature_gates(self) -> None:
         for feature in FeatureGateName:
             # pyre-ignore[16]
diff --git a/fbgemm_gpu/test/test_utils.py b/fbgemm_gpu/test/test_utils.py
index 0f8ddd52d7..4195db896c 100644
--- a/fbgemm_gpu/test/test_utils.py
+++ b/fbgemm_gpu/test/test_utils.py
@@ -10,6 +10,7 @@
 import os
 import subprocess
 import unittest
+from contextlib import contextmanager
 from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -164,6 +165,17 @@ def dontGenerateOpCheckTests(reason: str):
     return optests.dontGenerateOpCheckTests(reason)
 
 
+class TestSuite(unittest.TestCase):
+    @contextmanager
+    # pyre-ignore[2]
+    def assertNotRaised(self, exc_type) -> None:
+        try:
+            # pyre-ignore[7]
+            yield None
+        except exc_type as e:
+            raise self.failureException(e)
+
+
 # Version of torch.autograd.gradcheck that works with generate_opcheck_tests.
 # The problem with just torch.autograd.gradcheck is that it results in
 # very slow tests when composed with generate_opcheck_tests.
diff --git a/fbgemm_gpu/test/uvm/ops_load_test.py b/fbgemm_gpu/test/uvm/ops_load_test.py
new file mode 100644
index 0000000000..10a96ba91c
--- /dev/null
+++ b/fbgemm_gpu/test/uvm/ops_load_test.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+# pyre-ignore-all-errors[56]
+
+import unittest
+
+import fbgemm_gpu
+import hypothesis.strategies as st
+import torch
+from hypothesis import given, settings
+
+# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+open_source: bool = getattr(fbgemm_gpu, "open_source", False)
+
+if open_source:
+    # pyre-ignore[21]
+    from test_utils import cpu_and_maybe_gpu, TestSuite
+else:
+    from fbgemm_gpu.test.test_utils import cpu_and_maybe_gpu, TestSuite
+
+
+class OpsLoadTest(TestSuite):  # pyre-ignore[11]
+    @given(
+        device=cpu_and_maybe_gpu(),
+        host_mapped=st.booleans(),
+    )
+    @settings(deadline=None)
+    def test_cpu_ops(self, device: torch.device, host_mapped: bool) -> None:
+        with self.assertNotRaised(Exception):  # pyre-ignore[16]
+            torch.ops.fbgemm.new_unified_tensor(
+                torch.zeros(1, device=device, dtype=torch.float),
+                [1000],
+                host_mapped,
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
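Net effect of the registration moves above: the new_unified_tensor schema plus its CPU dispatch, and the Meta dispatch of new_managed_tensor, now live in sources built into both the CPU-only and CUDA variants, while the UVM-specific ops (is_uvm_tensor, uvm_to_cpu, cuda_mem_advise, and friends) are registered only from memory_utils_ops.cu. A minimal smoke check of the relocated registrations, assuming an installed fbgemm_gpu build (CPU-only or CUDA) whose import loads the operator library; the meta-device call relies on new_managed_tensor_meta returning at::empty(sizes, self.options()) as defined in the diff:

    import fbgemm_gpu  # noqa: F401  (importing registers the fbgemm ops)
    import torch

    # Available even in a CPU-only build now that the CPU registration is
    # part of fbgemm_gpu_sources_static_cpu.
    t = torch.ops.fbgemm.new_unified_tensor(
        torch.zeros(1, device="cpu", dtype=torch.float),
        [1000],
        False,  # is_host_mapped
    )

    # The Meta dispatch of new_managed_tensor also ships with the CPU
    # sources, so shape propagation works without a GPU present.
    m = torch.ops.fbgemm.new_managed_tensor(
        torch.zeros(1, device="meta", dtype=torch.float),
        [16, 32],
    )
    assert m.shape == (16, 32)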