Ifu 2023 12 06 #49

Merged: 74 commits, Dec 7, 2023

Commits
b61ae74
Meta functions for dynamic_shapes for block_bucketize_sparse_features…
Nov 3, 2023
002d72b
Namespace doesn't need to be followed by semicolon
r-barnes Nov 6, 2023
6000a5c
Revert D51029740: Namespace doesn't need to be followed by semicolon
Nov 6, 2023
174d473
Revert D50941763: Multisect successfully blamed "D50941763: [fbgemm_g…
Nov 6, 2023
80990a6
Add BF16 support for reorder_batched_ad_indices (#2116)
lequytra Nov 6, 2023
21d0c95
Fix OSS issues with D50741802 (#2117)
q10 Nov 7, 2023
b6bdf04
Update impl_abstract_pystub to be less boilerplatey
zou3519 Nov 7, 2023
b8f6beb
Remove test_aot_dispatch_static* tests from opcheck tests (#2118)
zou3519 Nov 7, 2023
308ffdc
Revert D50972148: Update impl_abstract_pystub to be less boilerplatey
zou3519 Nov 7, 2023
0457bb7
Derandomize all tests in fbgemm_gpu
ezyang Nov 7, 2023
0def4d8
Update impl_abstract_pystub to be less boilerplatey
zou3519 Nov 8, 2023
9b13b5a
SymIntify {}_embedding{}_codegen_forward_{}{}_cuda meta function
ezyang Nov 8, 2023
e5236c8
SymIntify {}_embedding{}_codegen_forward_{}{}_cuda autograd function
ezyang Nov 8, 2023
bd2d5fc
Fix meta function for merge_pooled_embeddings
Nov 8, 2023
79d1729
Refactor `embedding_inplace_update` (#2112)
q10 Nov 9, 2023
5bddcac
Back out "Refactor `embedding_inplace_update`" (#2125)
q10 Nov 9, 2023
293e500
Refactor `embedding_inplace_update` (#2127)
q10 Nov 9, 2023
2117dd3
add pt2_compliant tag to some ops (#2119)
zou3519 Nov 10, 2023
09ab470
Move permute_sparse_features tests and abstract impl to fbgemm (#2129)
williamwen42 Nov 14, 2023
975cb01
Add impl_abstract to segment_sum_csr (#2132)
Microve Nov 15, 2023
abb59a3
Fix fbgemm CI for segment_sum_csr (#2137)
tissue3 Nov 15, 2023
2eb3eb4
Support variable bucket size for block_bucketize_sparse_features (#2107)
tissue3 Nov 16, 2023
528f24d
Add test for fbgemm ops. (#2136)
tissue3 Nov 16, 2023
5da5a16
Re-organize layout_transform_ops (#2133)
q10 Nov 16, 2023
f1bbb60
Add opcheck tests to parts of quantize_ops_test.py (#2139)
williamwen42 Nov 16, 2023
a142e20
Add an auto-vectorization implementation for int4 CPU TBE kernel (#2077)
excelle08 Nov 17, 2023
ad2aca2
Fix illegal memory acesss error on fp8 quantize kernel (#2131)
spcyppt Nov 18, 2023
92388c1
RW Dist change to support uneven sharding (#2142)
gnahzg Nov 18, 2023
4c0fad5
Re-organize ssd_split_embeddings_cache (#2141)
q10 Nov 19, 2023
f65d7e2
Benchmark block_bucketize_sparse_features uneven sharding (#2140)
tissue3 Nov 20, 2023
54340d4
Add variable batch per feature support to EBC (tw/cw) (#1986)
joshuadeng Nov 20, 2023
a5edc61
Back out "Add an auto-vectorization implementation for int4 CPU TBE k…
Nov 21, 2023
8a015d3
Update GitHub checkout actions (#2146)
q10 Nov 21, 2023
2b3f861
Back out "RW Dist change to support uneven sharding"
gnahzg Nov 21, 2023
e778793
Re-organize int8_ops (#2145)
q10 Nov 21, 2023
5436320
Update GitHub checkout actions, pt2 (#2149)
q10 Nov 21, 2023
37111f5
Recompute linear_cache_indices for pipeline prefetching (#2147)
sryap Nov 21, 2023
f49dea6
Disable @optests.generate_opcheck_tests in TBE unit tests (#2152)
sryap Nov 22, 2023
84c7b27
Refactor embedding_bounds_check (#2155)
q10 Nov 23, 2023
b40f419
Back out "Refactor embedding_bounds_check" (#2156)
Nov 24, 2023
934d881
Remove unused pyre-ignore in TBE tests (#2162)
sryap Nov 28, 2023
753bc10
Fix run-lint issue in OSS (#2161)
sryap Nov 28, 2023
ba2b921
Annotate unused params (#2158)
q10 Nov 28, 2023
18af2b2
Add/modify LXU cache lookup ops for pipeline prefetching (#2154)
sryap Nov 28, 2023
ca1da75
Add unit test for unique cache lookup (#2160)
sryap Nov 28, 2023
035ed1f
Use unique cache locations in backward for pipeline prefetching (#2151)
sryap Nov 28, 2023
e62a5e2
use memcpy for cpu emb inplace update (#2166)
842974287 Nov 28, 2023
de731af
Preparatory fixes & lint suppressions for c10::optional->std::optiona…
swolchok Nov 28, 2023
49e7536
Refactor embedding_bounds_check (#2165)
q10 Nov 28, 2023
91a600a
Add warmup_runs to TBE benchmarks and run at least 1 warmup iter (#2163)
sryap Nov 29, 2023
cb7357a
Stop using excess memory in generate_opcheck_tests, re-enable fbgemm …
zou3519 Nov 29, 2023
886bf42
Revert D51607319: Refactor embedding_bounds_check
Nov 29, 2023
c6e3fa2
RW Dist change to support uneven sharding [1] FBGEMM changes (#2168)
gnahzg Nov 29, 2023
98932d6
Make fbgemm::masked_select_jagged_1d pt2_compliant (#2174)
zou3519 Nov 29, 2023
2457605
Add generate_opcheck_tests to input_combine_test (#2173)
zou3519 Nov 29, 2023
71e496a
Make fbgemm::tbe_input_combine pt2_compliant (#2172)
zou3519 Nov 29, 2023
90bb32d
Mark some more ops as pt2_compliant_tag (#2171)
zou3519 Nov 29, 2023
911aec4
suppress errors in `deeplearning/fbgemm/fbgemm_gpu` (#2159)
grievejia Nov 29, 2023
6c29dbd
Initialize empty values for fp8 quantize op (#2176)
spcyppt Nov 30, 2023
48120da
set strict as default typing mode in `deeplearning/fbgemm/fbgemm_gpu`…
Nov 30, 2023
453c80e
Update AVX2 and AVX512 flags (#2167)
q10 Nov 30, 2023
63d1198
Revert D51647391: Multisect successfully blamed "D51647391: Mark some…
Nov 30, 2023
1c40928
Make fbgemm::jagged_index_select pt2_compliant (#2170)
zou3519 Nov 30, 2023
1c93072
Make fbgemm::permute_1D_sparse_data, permute_2D_sparse_data pt2_compl…
zou3519 Nov 30, 2023
3d477a0
Mark some more ops as pt2_compliant (#2181)
zou3519 Nov 30, 2023
121c20b
set strict as default typing mode in `deeplearning/fbgemm/fbgemm_gpu`…
Nov 30, 2023
327dcf9
Get fbgemm::FloatToFP8RowwiseQuantized opcheck tests passing
williamwen42 Nov 30, 2023
bf980a7
Refactor embedding_bounds_check (#2178)
q10 Nov 30, 2023
c58679a
Revert D51688407: Refactor embedding_bounds_check
Dec 1, 2023
0fc0d4e
Remove indices and offsets copying from prefetch (#2186)
sryap Dec 1, 2023
d5cdefd
Add GPU kernel to support variable bucket size in block_bucketize_spa…
gnahzg Dec 1, 2023
88fc6e7
Fix block_bucketize_features with variable bucket size when using tor…
gnahzg Dec 1, 2023
dbc3157
Benchmark block_bucketize_sparse_features uneven sharding for GPU (#2…
gnahzg Dec 4, 2023
5a2c43e
Merge remote-tracking branch 'upstream/main' into IFU-2023-12-06
liligwu Dec 6, 2023
1 change: 1 addition & 0 deletions .github/scripts/fbgemm_gpu_build.bash
@@ -347,6 +347,7 @@ build_fbgemm_gpu_package () {
--package_name="${package_name}" \
--python-tag="${python_tag}" \
--plat-name="${plat_name}" \
--verbose \
"${build_args[@]}"

# Run checks on the built libraries
8 changes: 4 additions & 4 deletions .github/workflows/fbgemm_ci.yml
@@ -48,7 +48,7 @@ jobs:
git config --global --add safe.directory '*'

- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
submodules: true

@@ -86,7 +86,7 @@ jobs:

steps:
- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
submodules: true

@@ -127,7 +127,7 @@ jobs:
git config --global --add safe.directory '*'

- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
submodules: true

@@ -159,7 +159,7 @@ jobs:

steps:
- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
submodules: true

6 changes: 3 additions & 3 deletions .github/workflows/fbgemm_gpu_ci.yml
@@ -57,7 +57,7 @@ jobs:
git config --global --add safe.directory '*'

- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
submodules: true

@@ -126,7 +126,7 @@ jobs:
git config --global --add safe.directory '*'

- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
submodules: true

@@ -191,7 +191,7 @@ jobs:
git config --global --add safe.directory '*'

- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
submodules: true

4 changes: 2 additions & 2 deletions .github/workflows/fbgemm_gpu_cpu_nightly.yml
@@ -71,7 +71,7 @@ jobs:
run: yum update -y; yum install -y binutils findutils git pciutils sudo wget which

- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
submodules: true

@@ -136,7 +136,7 @@ jobs:
run: yum update -y; yum install -y binutils findutils git pciutils sudo wget which

- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
submodules: true

4 changes: 2 additions & 2 deletions .github/workflows/fbgemm_gpu_cpu_release.yml
@@ -68,7 +68,7 @@ jobs:
run: yum update -y; yum install -y binutils findutils git pciutils sudo wget which

- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
submodules: true

@@ -133,7 +133,7 @@ jobs:
run: yum update -y; yum install -y binutils findutils git pciutils sudo wget which

- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
submodules: true

3 changes: 2 additions & 1 deletion .github/workflows/fbgemm_gpu_cuda_nightly.yml
@@ -70,7 +70,7 @@ jobs:
run: yum update -y; yum install -y binutils findutils git pciutils sudo tar wget which

- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
submodules: true

@@ -140,6 +140,7 @@ jobs:
needs: build_artifact

steps:
# Cannot upgrade to actions/checkout@v4 yet because GLIBC on the instance is too old
- name: Checkout the Repository
uses: actions/checkout@v3
with:
2 changes: 1 addition & 1 deletion .github/workflows/fbgemm_gpu_cuda_release.yml
@@ -74,7 +74,7 @@ jobs:
run: yum update -y; yum install -y binutils findutils git pciutils sudo tar wget which

- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
submodules: true

2 changes: 1 addition & 1 deletion .github/workflows/fbgemm_gpu_docs.yml
@@ -44,7 +44,7 @@ jobs:
run: yum update -y; yum install -y binutils findutils git pciutils rsync sudo tar wget which

- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
submodules: true

2 changes: 1 addition & 1 deletion .github/workflows/fbgemm_gpu_lint.yml
@@ -39,7 +39,7 @@ jobs:

steps:
- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Setup Miniconda
run: . $PRELUDE; setup_miniconda $HOME/miniconda
5 changes: 3 additions & 2 deletions .github/workflows/fbgemm_gpu_pip.yml
@@ -66,7 +66,7 @@ jobs:
run: yum update -y; yum install -y binutils findutils git pciutils sudo wget which

- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Display System Info
run: . $PRELUDE; print_system_info; print_ec2_info
@@ -116,6 +116,7 @@ jobs:
cuda-version-publish: [ "11.8.0" ]

steps:
# Cannot upgrade to actions/checkout@v4 yet because GLIBC on the instance is too old
- name: Checkout the Repository
uses: actions/checkout@v3

@@ -182,7 +183,7 @@ jobs:
git config --global --add safe.directory '*'

- name: Checkout the Repository
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Display System Info
run: . $PRELUDE; print_system_info
4 changes: 4 additions & 0 deletions .gitignore
@@ -8,6 +8,10 @@
# found in:
# https://github.com/github/gitignore/

# General
.DS_Store
*~

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
53 changes: 34 additions & 19 deletions fbgemm_gpu/CMakeLists.txt
@@ -432,10 +432,22 @@ else()
DEPENDS "${optimizer_codegen_dependencies}")
endif()

set(AVX2_FLAGS "-mavx2;-mf16c;-mfma;-fopenmp")
if(NOT FBGEMM_CPU_ONLY AND WSL_MODE)
# NVCC in WSL complains about unknown -mavx options
# https://github.com/pytorch/FBGEMM/issues/2135
set(AVX2_FLAGS "-Xcompiler;-mavx;-Xcompiler;-mavx2;-Xcompiler;-mf16c;-Xcompiler;-mfma;-fopenmp")
endif()

set(AVX512_FLAGS "-mavx2;-mf16c;-mfma;-mavx512f;-mavx512bw;-mavx512dq;-mavx512vl;-fopenmp")
if(NOT FBGEMM_CPU_ONLY AND WSL_MODE)
set(AVX512_FLAGS "-Xcompiler;-mavx2;-Xcompiler;-mf16c;-Xcompiler;-mfma;-Xcompiler;-mavx512f;-Xcompiler;-mavx512bw;-Xcompiler;-mavx512dq;-Xcompiler;-mavx512vl;-fopenmp")
endif()

if(CXX_AVX2_FOUND)
set_source_files_properties(${gen_cpu_source_files}
PROPERTIES COMPILE_OPTIONS
"-mavx2;-mf16c;-mfma;-fopenmp")
"${AVX2_FLAGS}")
else()
set_source_files_properties(${gen_cpu_source_files}
PROPERTIES COMPILE_OPTIONS
@@ -504,13 +516,13 @@ set(fbgemm_sources_avx512
if(CXX_AVX2_FOUND)
set_source_files_properties(${fbgemm_sources_avx2}
PROPERTIES COMPILE_OPTIONS
"-mavx2;-mf16c;-mfma")
"${AVX2_FLAGS}")
endif()

if(CXX_AVX512_FOUND)
set_source_files_properties(${fbgemm_sources_avx512}
PROPERTIES COMPILE_OPTIONS
"-mavx2;-mf16c;-mfma;-mavx512f;-mavx512bw;-mavx512dq;-mavx512vl")
"${AVX512_FLAGS}")
endif()

set(fbgemm_sources ${fbgemm_sources_normal})
@@ -561,19 +573,20 @@ set(fbgemm_gpu_sources_static_cpu
codegen/embedding_forward_quantized_host_cpu.cpp
codegen/embedding_backward_dense_host_cpu.cpp
codegen/embedding_bounds_check_host_cpu.cpp
src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
src/input_combine_cpu.cpp
src/layout_transform_ops_cpu.cpp
src/input_combine_ops/input_combine_cpu.cpp
src/layout_transform_ops/layout_transform_ops_cpu.cpp
src/quantize_ops/quantize_ops_cpu.cpp
src/quantize_ops/quantize_ops_meta.cpp
src/sparse_ops/sparse_ops_cpu.cpp
src/sparse_ops/sparse_ops_meta.cpp
src/embedding_inplace_update_cpu.cpp
src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
src/split_embeddings_cache/linearize_cache_indices.cpp
src/split_embeddings_cache/lfu_cache_populate_byte.cpp
src/split_embeddings_cache/lru_cache_populate_byte.cpp
@@ -588,16 +601,16 @@ if(NOT FBGEMM_CPU_ONLY)
codegen/embedding_bounds_check_host.cpp
src/memory_utils/memory_utils.cpp
src/memory_utils/memory_utils_ops.cpp
src/layout_transform_ops_gpu.cpp
src/layout_transform_ops/layout_transform_ops_gpu.cpp
src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
src/quantize_ops/quantize_ops_gpu.cpp
src/sparse_ops/sparse_ops_gpu.cpp
src/split_embeddings_utils.cpp
src/split_embeddings_utils/split_embeddings_utils.cpp
src/split_embeddings_cache/split_embeddings_cache_ops.cu
src/metric_ops_host.cpp
src/embedding_inplace_update_gpu.cpp
src/input_combine_gpu.cpp
src/metric_ops/metric_ops_host.cpp
src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
src/input_combine_ops/input_combine_gpu.cpp
codegen/batch_index_select_dim0_host.cpp)

if(NVML_LIB_PATH)
@@ -607,8 +620,7 @@ if(NOT FBGEMM_CPU_ONLY)
if(NVML_LIB_PATH OR USE_ROCM)
message(STATUS "Adding merge_pooled_embeddings sources")
list(APPEND fbgemm_gpu_sources_static_cpu
src/merge_pooled_embeddings_cpu.cpp
src/merge_pooled_embeddings_gpu.cpp
src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_gpu.cpp
src/topology_utils.cpp)
else()
message(STATUS "Skipping merge_pooled_embeddings sources")
@@ -618,7 +630,7 @@ endif()
if(CXX_AVX2_FOUND)
set_source_files_properties(${fbgemm_gpu_sources_static_cpu}
PROPERTIES COMPILE_OPTIONS
"-mavx;-mf16c;-mfma;-mavx2;-fopenmp")
"${AVX2_FLAGS}")
else()
set_source_files_properties(${fbgemm_gpu_sources_static_cpu}
PROPERTIES COMPILE_OPTIONS
@@ -631,9 +643,9 @@ if(NOT FBGEMM_CPU_ONLY)
codegen/embedding_forward_quantized_split_lookup.cu
src/memory_utils/memory_utils.cu
src/memory_utils/memory_utils_ops.cu
src/embedding_inplace_update.cu
src/embedding_inplace_ops/embedding_inplace_update.cu
src/histogram_binning_calibration_ops.cu
src/input_combine.cu
src/input_combine_ops/input_combine.cu
src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
src/jagged_tensor_ops/dense_to_jagged_forward.cu
@@ -651,8 +663,8 @@ if(NOT FBGEMM_CPU_ONLY)
src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
src/jagged_tensor_ops/jagged_unique_indices.cu
src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
src/layout_transform_ops.cu
src/metric_ops.cu
src/layout_transform_ops/layout_transform_ops.cu
src/metric_ops/metric_ops.cu
src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
src/quantize_ops/quantize_bfloat16.cu
@@ -691,7 +703,10 @@ if(NOT FBGEMM_CPU_ONLY)
src/split_embeddings_cache/lxu_cache.cu
src/split_embeddings_cache/linearize_cache_indices.cu
src/split_embeddings_cache/reset_weight_momentum.cu
src/split_embeddings_utils.cu)
src/split_embeddings_utils/generate_vbe_metadata.cu
src/split_embeddings_utils/get_infos_metadata.cu
src/split_embeddings_utils/radix_sort_pairs.cu
src/split_embeddings_utils/transpose_embedding_input.cu)

set_source_files_properties(${fbgemm_gpu_sources_static_gpu}
PROPERTIES COMPILE_OPTIONS
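A note on the AVX2_FLAGS/AVX512_FLAGS hunk near the top of this file: under WSL with the CUDA build enabled, nvcc rejects host-compiler options such as -mavx2 unless each one is forwarded with -Xcompiler, so the flag lists are hard-coded in their wrapped form. The sketch below is purely illustrative (plain Python, with wrap_for_nvcc as a made-up helper name) and shows the same transformation programmatically:

# Illustrative only: wrap each host-compiler flag with -Xcompiler so nvcc
# forwards it to the underlying C++ compiler instead of rejecting it.
def wrap_for_nvcc(host_flags):
    wrapped = []
    for flag in host_flags:
        wrapped += ["-Xcompiler", flag]
    return wrapped

avx2_flags = ["-mavx", "-mavx2", "-mf16c", "-mfma"]
print(";".join(wrap_for_nvcc(avx2_flags) + ["-fopenmp"]))
# -Xcompiler;-mavx;-Xcompiler;-mavx2;-Xcompiler;-mf16c;-Xcompiler;-mfma;-fopenmp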
7 changes: 5 additions & 2 deletions fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py
@@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import functools
from math import sqrt
from typing import List, Tuple
@@ -29,7 +28,10 @@


def generate_unary_feature(
batch_size: int, num_embeddings: int
batch_size: int,
num_embeddings: int
# pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
# `typing.List[<element type>]` to avoid runtime subscripting errors.
) -> Tuple[List, List, List]:
lengths = []
offsets = []
@@ -90,6 +92,7 @@ def forward(
@click.option("--num-tables", default=2)
@click.option("--num-tasks", default=3)
@click.option("--repeats", default=100)
# pyre-fixme[2]: Parameter must be annotated.
def main(batch_size, num_tables, num_tasks, repeats) -> None:
device = torch.device("cuda", 0)
torch.cuda.set_device(device)
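The pyre-fixme above flags the bare List types in the return annotation of generate_unary_feature. A minimal sketch of a fully parameterized signature, assuming the three lists hold plain integer lengths, offsets, and indices (the element types are an assumption inferred from the benchmark, not stated in this diff):

from typing import List, Tuple

def generate_unary_feature(
    batch_size: int,
    num_embeddings: int,
) -> Tuple[List[int], List[int], List[int]]:
    # Hypothetical refinement: concrete element types resolve pyre-fixme[24]
    # without changing runtime behavior.
    lengths: List[int] = []
    offsets: List[int] = []
    indices: List[int] = []
    return lengths, offsets, indices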
10 changes: 7 additions & 3 deletions fbgemm_gpu/bench/bench_utils.py
@@ -41,13 +41,13 @@ def benchmark_torch_function( # noqa: C901
copy_f_for_multi_thread_test: bool = False,
) -> Tuple[float, torch.Tensor]:
logging.info(f"Start to benchmark {name}...")
if device != "" and device != "cuda":
if device != "cpu" and device != "" and device != "cuda":
torch.cuda.set_device(device)
for _ in range(num_warmups):
output = f(*args)

assert num_threads > 0
if torch.cuda.is_available() and (num_threads == 1):
if device != "cpu" and torch.cuda.is_available() and (num_threads == 1):
cache = torch.empty(
int(flush_gpu_cache_size_mb * 1024 * 1024 // 4),
dtype=torch.float,
@@ -69,7 +69,7 @@ def benchmark_torch_function( # noqa: C901
[s.elapsed_time(e) for s, e in zip(start_event, end_event)]
)
elapsed_time = torch.mean(times).item() * 1.0e-3
elif torch.cuda.is_available() and (num_threads > 1):
elif device != "cpu" and torch.cuda.is_available() and (num_threads > 1):
cache = torch.empty(
int(flush_gpu_cache_size_mb * 1024 * 1024 // 4),
dtype=torch.float,
@@ -156,6 +156,10 @@ def benchmark_requests(
) -> float:
times = []

# Run at least one warmup iteration to avoid the long cudaLaunchKernel time
# for the first kernel
num_warmups = num_warmups + 1 if num_warmups >= 0 else 1

if num_warmups > 0:
indices, offsets, weights = requests[0]
for _ in range(num_warmups):
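Taken together, the bench_utils.py hunks above make the benchmark helpers usable on CPU-only hosts (the torch.cuda calls are now gated on device != "cpu") and force at least one warmup iteration so the one-time cost of the first kernel launch is excluded from the measured time. A minimal sketch of the same pattern, where bench, f, args, and device are illustrative stand-ins rather than the actual fbgemm_gpu API:

import time
import torch

def bench(f, args, device: str = "cpu", num_warmups: int = 0) -> float:
    # Only touch the CUDA runtime when we are not benchmarking on CPU.
    use_cuda = device != "cpu" and torch.cuda.is_available()
    # Always run at least one warmup so one-time launch costs are excluded.
    num_warmups = num_warmups + 1 if num_warmups >= 0 else 1
    for _ in range(num_warmups):
        f(*args)
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    f(*args)
    if use_cuda:
        torch.cuda.synchronize()
    return time.perf_counter() - start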