diff --git a/.github/scripts/fbgemm_gpu_build.bash b/.github/scripts/fbgemm_gpu_build.bash index 02e3f6e3cc..26f5330589 100644 --- a/.github/scripts/fbgemm_gpu_build.bash +++ b/.github/scripts/fbgemm_gpu_build.bash @@ -64,7 +64,7 @@ __configure_fbgemm_gpu_build_clang () { local conda_prefix=$(conda run ${env_prefix} printenv CONDA_PREFIX) # shellcheck disable=SC2206 build_args+=( - --cxxprefix ${conda_prefix} + --cxxprefix=${conda_prefix} ) } @@ -258,6 +258,11 @@ __configure_fbgemm_gpu_build () { __configure_fbgemm_gpu_build_clang fi + # Set verbosity + build_args+=( + --verbose + ) + # shellcheck disable=SC2145 echo "[BUILD] FBGEMM_GPU build arguments have been set: ${build_args[@]}" } @@ -307,7 +312,7 @@ __build_fbgemm_gpu_set_run_multicore () { export run_multicore="" if [[ $core =~ $re && $sockets =~ $re ]] ; then local n_core=$((core * sockets)) - export run_multicore=" -j ${n_core}" + export run_multicore="-j ${n_core}" fi echo "[BUILD] Set multicore run option for setup.py: ${run_multicore}" @@ -443,15 +448,26 @@ build_fbgemm_gpu_package () { echo "################################################################################" echo "" - # Distribute Python extensions as wheels on Linux + # Set packaging options + build_args+=( + --package_channel="${fbgemm_release_channel}" + --python-tag="${python_tag}" + --plat-name="${python_plat_name}" + ) + + # Prepend build options correctly for `python -m build` + # https://build.pypa.io/en/stable/index.html + # https://gregoryszorc.com/blog/2023/10/30/my-user-experience-porting-off-setup.py/ + for i in "${!build_args[@]}"; do + build_args[i]="--config-setting=--build-option=${build_args[i]}" + done + + # Build the wheel. Invoke using `python -m build` + # https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html echo "[BUILD] Building FBGEMM-GPU wheel (VARIANT=${fbgemm_variant}) ..." 
# shellcheck disable=SC2086 print_exec conda run --no-capture-output ${env_prefix} \ - python setup.py "${run_multicore}" bdist_wheel \ - --package_channel="${fbgemm_release_channel}" \ - --python-tag="${python_tag}" \ - --plat-name="${python_plat_name}" \ - --verbose \ + python -m build --wheel --no-isolation \ "${build_args[@]}" # Run checks on the built libraries @@ -503,7 +519,6 @@ build_fbgemm_gpu_install () { # shellcheck disable=SC2086 print_exec conda run --no-capture-output ${env_prefix} \ python setup.py "${run_multicore}" install \ - --verbose \ "${build_args[@]}" # Run checks on the built libraries @@ -519,47 +534,3 @@ build_fbgemm_gpu_install () { echo "[BUILD] FBGEMM-GPU build + install completed" } - -build_fbgemm_gpu_develop () { - env_name="$1" - fbgemm_variant="$2" - fbgemm_variant_targets="$3" - if [ "$fbgemm_variant" == "" ]; then - echo "Usage: ${FUNCNAME[0]} ENV_NAME VARIANT [TARGETS]" - echo "Example(s):" - echo " ${FUNCNAME[0]} build_env cpu # CPU-only variant" - echo " ${FUNCNAME[0]} build_env cuda # CUDA variant for default target(s)" - echo " ${FUNCNAME[0]} build_env cuda '7.0;8.0' # CUDA variant for custom target(s)" - echo " ${FUNCNAME[0]} build_env rocm # ROCm variant for default target(s)" - echo " ${FUNCNAME[0]} build_env rocm 'gfx906;gfx908;gfx90a' # ROCm variant for custom target(s)" - return 1 - fi - - # shellcheck disable=SC2155 - local env_prefix=$(env_name_or_prefix "${env_name}") - - # Set up and configure the build - __build_fbgemm_gpu_common_pre_steps || return 1 - __configure_fbgemm_gpu_build "${fbgemm_variant}" "${fbgemm_variant_targets}" || return 1 - - echo "################################################################################" - echo "# Build + Install FBGEMM-GPU Package (Develop)" - echo "#" - echo "# [$(date --utc +%FT%T.%3NZ)] + ${FUNCNAME[0]} ${*}" - echo "################################################################################" - echo "" - - # Parallelism may need to be limited to prevent the build from being - # canceled for going over ulimits - echo "[BUILD] Building (develop) FBGEMM-GPU (VARIANT=${fbgemm_variant}) ..." 
- # shellcheck disable=SC2086 - print_exec conda run --no-capture-output ${env_prefix} \ - python setup.py "${run_multicore}" build develop \ - --verbose \ - "${build_args[@]}" - - # Run checks on the built libraries - (run_fbgemm_gpu_postbuild_checks "${fbgemm_variant}") || return 1 - - echo "[BUILD] FBGEMM-GPU build + develop completed" -} diff --git a/.github/scripts/fbgemm_gpu_test.bash b/.github/scripts/fbgemm_gpu_test.bash index 5769cc3934..06c0e4e5a5 100644 --- a/.github/scripts/fbgemm_gpu_test.bash +++ b/.github/scripts/fbgemm_gpu_test.bash @@ -205,6 +205,22 @@ run_fbgemm_gpu_tests () { done } +test_all_fbgemm_gpu_modules () { + local env_name="$1" + local fbgemm_variant="$2" + + local target_directories=( + fbgemm_gpu/test + fbgemm_gpu/experimental/example/test + ) + + for test_dir in "${target_directories[@]}"; do + cd "${test_dir}" || return 1 + run_fbgemm_gpu_tests "${env_name}" "${fbgemm_variant}" || return 1 + cd - || return 1 + done +} + ################################################################################ # FBGEMM_GPU Test Bulk-Combination Functions @@ -292,9 +308,8 @@ test_fbgemm_gpu_build_and_install () { cd ~/FBGEMM/ || return 1 install_fbgemm_gpu_wheel "${env_name}" fbgemm_gpu/dist/*.whl || return 1 - cd ~/FBGEMM/fbgemm_gpu/test || return 1 - run_fbgemm_gpu_tests "${env_name}" "${pytorch_variant_type}" || return 1 - cd - || return 1 + cd ~/FBGEMM/ || return 1 + test_all_fbgemm_gpu_modules "${env_name}" "${pytorch_variant_type}" || return 1 } test_fbgemm_gpu_setup_and_pip_install () { @@ -323,11 +338,11 @@ test_fbgemm_gpu_setup_and_pip_install () { local env_name="test_py${py_version}_pytorch_${pytorch_channel_version}_fbgemm_${fbgemm_gpu_channel_version}_${variant_type}/${variant_version}" local env_name="${env_name//\//_}" - test_setup_conda_environment "${env_name}" 'no-compiler' "${py_version}" pip "${pytorch_channel_version}" "${variant_type}" "${variant_version}" || return 1 - install_fbgemm_gpu_pip "${env_name}" "${fbgemm_gpu_channel_version}" "${variant_type}/${variant_version}" || return 1 - cd ~/FBGEMM/fbgemm_gpu/test || return 1 + test_setup_conda_environment "${env_name}" 'no-compiler' "${py_version}" pip "${pytorch_channel_version}" "${variant_type}" "${variant_version}" || return 1 + install_fbgemm_gpu_pip "${env_name}" "${fbgemm_gpu_channel_version}" "${variant_type}/${variant_version}" || return 1 + cd ~/FBGEMM || return 1 - run_fbgemm_gpu_tests "${env_name}" "${variant_type}"; + test_all_fbgemm_gpu_modules "${env_name}" "${variant_type}"; local retcode=$? 
echo "################################################################################" diff --git a/.github/scripts/nova_postscript.bash b/.github/scripts/nova_postscript.bash index a9f2ad9927..dc3871ca70 100644 --- a/.github/scripts/nova_postscript.bash +++ b/.github/scripts/nova_postscript.bash @@ -42,8 +42,8 @@ else fi $CONDA_RUN python3 -c "import torch; print('cuda.is_available() ', torch.cuda.is_available()); print ('device_count() ',torch.cuda.device_count());" -cd "${FBGEMM_REPO}/fbgemm_gpu/test" || { echo "[NOVA] Failed to cd to fbgemm_gpu/test from $(pwd)"; }; -run_fbgemm_gpu_tests "${BUILD_ENV_NAME}" "${fbgemm_variant}" +cd "${FBGEMM_REPO}" || { echo "[NOVA] Failed to cd to ${FBGEMM_REPO} from $(pwd)"; }; +test_all_fbgemm_gpu_modules "${BUILD_ENV_NAME}" "${fbgemm_variant}" # Workaround EACCES: permission denied error at checkout step chown -R 1000:1000 /__w/FBGEMM/FBGEMM/ || echo "Unable to chown 1000:1000 from $USER, uid: $(id -u)" diff --git a/.github/scripts/utils_base.bash b/.github/scripts/utils_base.bash index 7ea56f816c..bb814617f0 100644 --- a/.github/scripts/utils_base.bash +++ b/.github/scripts/utils_base.bash @@ -88,7 +88,7 @@ env_name_or_prefix () { } test_network_connection () { - wget -q --timeout 1 pypi.org -O /dev/null + exec_with_retries 3 wget -q --timeout 1 pypi.org -O /dev/null local exit_status=$? # https://man7.org/linux/man-pages/man1/wget.1.html @@ -96,7 +96,8 @@ test_network_connection () { echo "[CHECK] Network does not appear to be blocked." else echo "[CHECK] Network check exit status: ${exit_status}" - echo "[CHECK] Network appears to be blocked; please proxy the network connetions, i.e. re-run the command prefixed with 'with-proxy'." + echo "[CHECK] Network appears to be blocked or suffering from poor connection." + echo "[CHECK] Please remember to proxy the network connetions if needed, i.e. re-run the command prefixed with 'with-proxy'." return 1 fi } diff --git a/.github/scripts/utils_build.bash b/.github/scripts/utils_build.bash index 9198a25b57..deb5d50fb1 100644 --- a/.github/scripts/utils_build.bash +++ b/.github/scripts/utils_build.bash @@ -246,6 +246,7 @@ install_build_tools () { # shellcheck disable=SC2086 (exec_with_retries 3 conda install ${env_prefix} -c conda-forge -y \ bazel \ + build \ click \ cmake \ hypothesis \ diff --git a/.github/workflows/fbgemm_gpu_ci_cpu.yml b/.github/workflows/fbgemm_gpu_ci_cpu.yml index 9d19b06f93..e5fd8d0ada 100644 --- a/.github/workflows/fbgemm_gpu_ci_cpu.yml +++ b/.github/workflows/fbgemm_gpu_ci_cpu.yml @@ -182,7 +182,7 @@ jobs: - name: Test with PyTest timeout-minutes: ${{ matrix.host-machine.timeout }} - run: . $PRELUDE; cd fbgemm_gpu/test; run_fbgemm_gpu_tests $BUILD_ENV cpu + run: . $PRELUDE; test_all_fbgemm_gpu_modules $BUILD_ENV cpu - name: Push Wheel to PyPI if: ${{ (github.event_name == 'schedule' || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish_to_pypi == 'true')) && matrix.compiler == 'gcc' }} diff --git a/.github/workflows/fbgemm_gpu_ci_cuda.yml b/.github/workflows/fbgemm_gpu_ci_cuda.yml index b76870245e..fd68558f21 100644 --- a/.github/workflows/fbgemm_gpu_ci_cuda.yml +++ b/.github/workflows/fbgemm_gpu_ci_cuda.yml @@ -202,7 +202,7 @@ jobs: - name: Test with PyTest timeout-minutes: 20 - run: . $PRELUDE; cd fbgemm_gpu/test; run_fbgemm_gpu_tests $BUILD_ENV cuda + run: . 
$PRELUDE; test_all_fbgemm_gpu_modules $BUILD_ENV cuda - name: Push Wheel to PyPI if: ${{ (github.event_name == 'schedule' && matrix.cuda-version == matrix.cuda-version-publish) || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish_to_pypi == 'true' && matrix.cuda-version == matrix.cuda-version-publish) }} diff --git a/.github/workflows/fbgemm_gpu_ci_rocm.yml b/.github/workflows/fbgemm_gpu_ci_rocm.yml index f3fca6f5bf..4e35f8cd56 100644 --- a/.github/workflows/fbgemm_gpu_ci_rocm.yml +++ b/.github/workflows/fbgemm_gpu_ci_rocm.yml @@ -191,4 +191,4 @@ jobs: - name: Test with PyTest timeout-minutes: 20 - run: . $PRELUDE; cd fbgemm_gpu/test; run_fbgemm_gpu_tests $BUILD_ENV rocm + run: . $PRELUDE; test_all_fbgemm_gpu_modules $BUILD_ENV rocm diff --git a/.github/workflows/fbgemm_gpu_pip.yml b/.github/workflows/fbgemm_gpu_pip.yml index 8ef3f1d85c..342f562946 100644 --- a/.github/workflows/fbgemm_gpu_pip.yml +++ b/.github/workflows/fbgemm_gpu_pip.yml @@ -99,7 +99,7 @@ jobs: - name: Test with PyTest timeout-minutes: ${{ matrix.host-machine.timeout }} - run: . $PRELUDE; cd fbgemm_gpu/test; run_fbgemm_gpu_tests $BUILD_ENV cpu + run: . $PRELUDE; test_all_fbgemm_gpu_modules $BUILD_ENV cpu test_pypi_install_cuda: @@ -159,7 +159,7 @@ jobs: - name: Test with PyTest timeout-minutes: 20 - run: . $PRELUDE; cd fbgemm_gpu/test; run_fbgemm_gpu_tests $BUILD_ENV cuda + run: . $PRELUDE; test_all_fbgemm_gpu_modules $BUILD_ENV cuda test_pypi_install_rocm: @@ -225,4 +225,4 @@ jobs: - name: Test with PyTest timeout-minutes: 20 - run: . $PRELUDE; cd fbgemm_gpu/test; run_fbgemm_gpu_tests $BUILD_ENV rocm + run: . $PRELUDE; test_all_fbgemm_gpu_modules $BUILD_ENV rocm diff --git a/.github/workflows/fbgemm_gpu_release_cpu.yml b/.github/workflows/fbgemm_gpu_release_cpu.yml index 4261438145..a21a90eb0e 100644 --- a/.github/workflows/fbgemm_gpu_release_cpu.yml +++ b/.github/workflows/fbgemm_gpu_release_cpu.yml @@ -174,7 +174,7 @@ jobs: - name: Test with PyTest timeout-minutes: ${{ matrix.host-machine.timeout }} - run: . $PRELUDE; cd fbgemm_gpu/test; run_fbgemm_gpu_tests $BUILD_ENV cpu + run: . $PRELUDE; test_all_fbgemm_gpu_modules $BUILD_ENV cpu - name: Push FBGEMM_GPU (CPU version) Binary to PYPI if: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.publish_to_pypi == 'true' }} diff --git a/.github/workflows/fbgemm_gpu_release_cuda.yml b/.github/workflows/fbgemm_gpu_release_cuda.yml index ea806f3573..c640826608 100644 --- a/.github/workflows/fbgemm_gpu_release_cuda.yml +++ b/.github/workflows/fbgemm_gpu_release_cuda.yml @@ -184,7 +184,7 @@ jobs: - name: Test with PyTest timeout-minutes: 20 - run: . $PRELUDE; cd fbgemm_gpu/test; run_fbgemm_gpu_tests $BUILD_ENV cuda + run: . $PRELUDE; test_all_fbgemm_gpu_modules $BUILD_ENV cuda - name: Push FBGEMM_GPU Binary to PYPI if: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.publish_to_pypi == 'true' && matrix.cuda-version == github.event.inputs.cuda_version }} diff --git a/bench/ConvUnifiedBenchmark.cc b/bench/ConvUnifiedBenchmark.cc index 5ae11655df..87c5bf656c 100644 --- a/bench/ConvUnifiedBenchmark.cc +++ b/bench/ConvUnifiedBenchmark.cc @@ -281,12 +281,8 @@ void performance_test( #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN cout << "WARNING: the timer may be inaccurate when used by multiple threads." 
<< endl; - cout << header << "Im2Col (ms), " - << "Packing (ms), " - << "Kernel (ms), " - << "Postprocessing (ms), " - << "fbgemmPacked (ms), " - << "Total (ms), " + cout << header << "Im2Col (ms), " << "Packing (ms), " << "Kernel (ms), " + << "Postprocessing (ms), " << "fbgemmPacked (ms), " << "Total (ms), " << "GOPS" << endl; #else cout << setw(6) << header << setw(5) << "GOPS" << endl; diff --git a/bench/ConvertBenchmark.cc b/bench/ConvertBenchmark.cc index ac362dfe7f..8f4b468c9c 100644 --- a/bench/ConvertBenchmark.cc +++ b/bench/ConvertBenchmark.cc @@ -28,9 +28,8 @@ void performance_test() { normal_distribution dist; default_random_engine engine; - cout << setw(4) << "M" - << " elements_per_sec_ref" - << " elements_per_sec_simd" << endl; + cout << setw(4) << "M" << " elements_per_sec_ref" << " elements_per_sec_simd" + << endl; array dims{1, 10, 32, 40, 129, 256, 1024, 8000}; diff --git a/bench/EmbeddingQuantizeBenchmark.cc b/bench/EmbeddingQuantizeBenchmark.cc index e0266f9744..cd255cbf78 100644 --- a/bench/EmbeddingQuantizeBenchmark.cc +++ b/bench/EmbeddingQuantizeBenchmark.cc @@ -34,11 +34,9 @@ void performance_test() { } else { cout << "With scale and bias as float" << endl; } - cout << setw(8) << "bit_rate" - << ", " << setw(6) << "rows" - << "," << setw(6) << "cols" - << "," << setw(16) << "elems_per_usec" - << "," << setw(10) << "GB/Sec" << endl; + cout << setw(8) << "bit_rate" << ", " << setw(6) << "rows" << "," << setw(6) + << "cols" << "," << setw(16) << "elems_per_usec" << "," << setw(10) + << "GB/Sec" << endl; std::vector bit_rates; if (is_same::value) { bit_rates = {2, 4, 8}; diff --git a/bench/EmbeddingSpMDMNBitBenchmark.cc b/bench/EmbeddingSpMDMNBitBenchmark.cc index b29a05247c..5b12e91096 100644 --- a/bench/EmbeddingSpMDMNBitBenchmark.cc +++ b/bench/EmbeddingSpMDMNBitBenchmark.cc @@ -352,17 +352,15 @@ int run_benchmark( cout << "prefetch off, "; } - cout << "b/w, " << bytes / 1e9 / t << ", GB/s, " - << "effective b/w, " << bytes_padded / 1e9 / t << ", GB/s, " - << "time, " << t << ", autovec b/w, " << bytes / 1e9 / t_autovec - << ", GB/s, " + cout << "b/w, " << bytes / 1e9 / t << ", GB/s, " << "effective b/w, " + << bytes_padded / 1e9 / t << ", GB/s, " << "time, " << t + << ", autovec b/w, " << bytes / 1e9 / t_autovec << ", GB/s, " << "autovec eff. b/w, " << bytes_padded / 1e9 / t_autovec - << ", GB/s, " - << "autovec time, " << t_autovec << ", ref b/w, " - << bytes / 1e9 / t_ref << ", GB/s, " - << "ref eff. b/w, " << bytes_padded / 1e9 / t_ref << ", GB/s, " - << "ref time, " << t_ref << ", autovec speedup, " - << t_ref / t_autovec << ", asmjit speedup, " << t_ref / t << endl; + << ", GB/s, " << "autovec time, " << t_autovec << ", ref b/w, " + << bytes / 1e9 / t_ref << ", GB/s, " << "ref eff. b/w, " + << bytes_padded / 1e9 / t_ref << ", GB/s, " << "ref time, " << t_ref + << ", autovec speedup, " << t_ref / t_autovec << ", asmjit speedup, " + << t_ref / t << endl; } // flush_cache } // has_weight return 0; diff --git a/bench/GroupwiseConvRequantizeBenchmark.cc b/bench/GroupwiseConvRequantizeBenchmark.cc index a378f82b82..06f488cbad 100644 --- a/bench/GroupwiseConvRequantizeBenchmark.cc +++ b/bench/GroupwiseConvRequantizeBenchmark.cc @@ -90,44 +90,15 @@ void performance_test() { #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN cout << "WARNING: the timer may be inaccurate when used by multiple threads." 
<< endl; - cout << "MB, " - << "IC, " - << "OC, " - << "IH, " - << "IW, " - << "KH, " - << "KW, " - << "stride_h, " - << "stride_w, " - << "pad_h, " - << "pad_w, " - << "Type, " - << "M, " - << "N, " - << "K, " - << "Im2Col (ms), " - << "Packing (ms), " - << "Kernel (ms), " - << "Postprocessing (ms), " - << "fbgemmPacked (ms), " - << "Total (ms), " - << "GOPS" << endl; + cout << "MB, " << "IC, " << "OC, " << "IH, " << "IW, " << "KH, " << "KW, " + << "stride_h, " << "stride_w, " << "pad_h, " << "pad_w, " << "Type, " + << "M, " << "N, " << "K, " << "Im2Col (ms), " << "Packing (ms), " + << "Kernel (ms), " << "Postprocessing (ms), " << "fbgemmPacked (ms), " + << "Total (ms), " << "GOPS" << endl; #else - cout << setw(8) << "MB, " - << "IC, " - << "OC, " - << "IH, " - << "IW, " - << "KH, " - << "KW, " - << "stride_h, " - << "stride_w, " - << "pad_h, " - << "pad_w, " - << "Type, " - << "M, " - << "N, " - << "K, " << setw(5) << "GOPS" << endl; + cout << setw(8) << "MB, " << "IC, " << "OC, " << "IH, " << "IW, " << "KH, " + << "KW, " << "stride_h, " << "stride_w, " << "pad_h, " << "pad_w, " + << "Type, " << "M, " << "N, " << "K, " << setw(5) << "GOPS" << endl; #endif chrono::time_point begin, end; @@ -369,8 +340,8 @@ void performance_test() { k]; if (expected != actual) { cout << "Im2Col fused results differ at (" << n << ", " << h - << ", " << w << ", " << k << ")." - << " expected:" << expected << " actual:" << actual << endl; + << ", " << w << ", " << k << ")." << " expected:" << expected + << " actual:" << actual << endl; } } } @@ -527,8 +498,8 @@ void performance_test() { k]; if (expected != actual) { cout << "direct conv results differ at (" << n << ", " << h - << ", " << w << ", " << k << ")." - << " expected:" << expected << " actual:" << actual << endl; + << ", " << w << ", " << k << ")." << " expected:" << expected + << " actual:" << actual << endl; } } } diff --git a/bench/I8SpmdmBenchmark.cc b/bench/I8SpmdmBenchmark.cc index 4ffc78c4dd..53ed219c71 100644 --- a/bench/I8SpmdmBenchmark.cc +++ b/bench/I8SpmdmBenchmark.cc @@ -54,26 +54,12 @@ int main() { #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN cout << "WARNING: the timer may be inaccurate when used by multiple threads." << endl; - cout << "M, " - << "N, " - << "K, " - << "Density, " - << "Accumulation, " - << "Initialize (ms), " - << "Transpose uint8 (ms), " - << "Transpose 32xN (ms), " - << "Compute (ms), " - << "Transpose 32xN (ms), " - << "Total (ms), " - << "GB/s, " - << "GOPs" << endl; + cout << "M, " << "N, " << "K, " << "Density, " << "Accumulation, " + << "Initialize (ms), " << "Transpose uint8 (ms), " + << "Transpose 32xN (ms), " << "Compute (ms), " << "Transpose 32xN (ms), " + << "Total (ms), " << "GB/s, " << "GOPs" << endl; #else - cout << "M, " - << "N, " - << "K, " - << "Density, " - << "Accumulation, " - << "GB/s, " + cout << "M, " << "N, " << "K, " << "Density, " << "Accumulation, " << "GB/s, " << "GOPs" << endl; #endif diff --git a/bench/Im2ColFusedRequantizeBenchmark.cc b/bench/Im2ColFusedRequantizeBenchmark.cc index 09420bf337..29ad02f03c 100644 --- a/bench/Im2ColFusedRequantizeBenchmark.cc +++ b/bench/Im2ColFusedRequantizeBenchmark.cc @@ -73,46 +73,16 @@ void performance_test() { #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN cout << "WARNING: the timer may be inaccurate when used by multiple threads." 
<< endl; - cout << "MB, " - << "IC, " - << "OC, " - << "IH, " - << "IW, " - << "G, " - << "KH, " - << "KW, " - << "stride_h, " - << "stride_w, " - << "pad_h, " - << "pad_w, " - << "Type, " - << "M, " - << "N, " - << "K, " - << "Im2Col (ms), " - << "Packing (ms), " - << "Kernel (ms), " - << "Postprocessing (ms), " - << "fbgemmPacked (ms), " - << "Total (ms), " - << "GOPS" << endl; + cout << "MB, " << "IC, " << "OC, " << "IH, " << "IW, " << "G, " << "KH, " + << "KW, " << "stride_h, " << "stride_w, " << "pad_h, " << "pad_w, " + << "Type, " << "M, " << "N, " << "K, " << "Im2Col (ms), " + << "Packing (ms), " << "Kernel (ms), " << "Postprocessing (ms), " + << "fbgemmPacked (ms), " << "Total (ms), " << "GOPS" << endl; #else - cout << setw(8) << "MB, " - << "IC, " - << "OC, " - << "IH, " - << "IW, " - << "G, " - << "KH, " - << "KW, " - << "stride_h, " - << "stride_w, " - << "pad_h, " - << "pad_w, " - << "Type, " - << "M, " - << "N, " - << "K, " << setw(5) << "GOPS" << endl; + cout << setw(8) << "MB, " << "IC, " << "OC, " << "IH, " << "IW, " << "G, " + << "KH, " << "KW, " << "stride_h, " << "stride_w, " << "pad_h, " + << "pad_w, " << "Type, " << "M, " << "N, " << "K, " << setw(5) << "GOPS" + << endl; #endif chrono::time_point begin, end; diff --git a/bench/PackedRequantizeAcc16Benchmark.cc b/bench/PackedRequantizeAcc16Benchmark.cc index f86aa01470..17e7f216c6 100644 --- a/bench/PackedRequantizeAcc16Benchmark.cc +++ b/bench/PackedRequantizeAcc16Benchmark.cc @@ -84,15 +84,9 @@ void performance_test() { #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN cout << "WARNING: the timer may be inaccurate when used by multiple threads." << endl; - cout << "M, " - << "N, " - << "K, " - << "Output Processing, " - << "Packing (ms), " - << "Kernel (ms), " - << "Postprocessing (ms), " - << "Total (ms), " - << "GOPS" << endl; + cout << "M, " << "N, " << "K, " << "Output Processing, " << "Packing (ms), " + << "Kernel (ms), " << "Postprocessing (ms), " << "Total (ms), " << "GOPS" + << endl; #else cout << setw(7) << "M, " << setw(7) << "N, " << setw(7) << "K, " << setw(32) << "Output Processing, " << setw(18) << "Type, " << setw(5) << "GOPS" diff --git a/bench/RequantizeBenchmark.cc b/bench/RequantizeBenchmark.cc index 77eec169bf..9b2a209fcd 100644 --- a/bench/RequantizeBenchmark.cc +++ b/bench/RequantizeBenchmark.cc @@ -33,8 +33,7 @@ void performance_test() { constexpr int NWARMUP = 4; constexpr int NITER = 256; - cout << setw(4) << "len" - << ", " << setw(10) << "Type" + cout << setw(4) << "len" << ", " << setw(10) << "Type" << ", B_elements_per_sec" << endl; for (int len : {1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 17, diff --git a/bench/RowOffsetBenchmark.cc b/bench/RowOffsetBenchmark.cc index 17cd9aebe2..495c9b1271 100644 --- a/bench/RowOffsetBenchmark.cc +++ b/bench/RowOffsetBenchmark.cc @@ -26,8 +26,7 @@ void performance_test() { constexpr int NWARMUP = 4; constexpr int NITER = 256; - cout << setw(4) << "len" - << ", B_elements_per_sec" << endl; + cout << setw(4) << "len" << ", B_elements_per_sec" << endl; for (int len : {1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129, 255, 256}) { diff --git a/cmake/modules/CudaSetup.cmake b/cmake/modules/CudaSetup.cmake new file mode 100644 index 0000000000..d86963109d --- /dev/null +++ b/cmake/modules/CudaSetup.cmake @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +include(${CMAKE_CURRENT_SOURCE_DIR}/../cmake/modules/Utilities.cmake) + + +################################################################################ +# CUDA Setup +################################################################################ + +# Set NVML_LIB_PATH if provided, or detect the default lib path +if(NOT NVML_LIB_PATH) + set(DEFAULT_NVML_LIB_PATH + "${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so") + + if(EXISTS ${DEFAULT_NVML_LIB_PATH}) + message(STATUS "Setting NVML_LIB_PATH: \ + ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so") + set(NVML_LIB_PATH "${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so") + endif() +endif() + +if(NVML_LIB_PATH) + message(STATUS "Found NVML_LIB_PATH: ${NVML_LIB_PATH}") +endif() diff --git a/cmake/modules/CxxCompilerSetup.cmake b/cmake/modules/CxxCompilerSetup.cmake new file mode 100644 index 0000000000..11fb3f8916 --- /dev/null +++ b/cmake/modules/CxxCompilerSetup.cmake @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +include(${CMAKE_CURRENT_SOURCE_DIR}/../cmake/modules/Utilities.cmake) + + +################################################################################ +# CMake C++ Setup +################################################################################ + +# SET THE C AND C++ VERSIONS HERE +set(C_VERSION 17) +set(CXX_VERSION 20) + +# Set the default C++ standard to CXX_VERSION if CMAKE_CXX_STANDARD is not +# supplied by CMake command invocation. +# Individual targets can have this value overridden; see +# https://cmake.org/cmake/help/latest/variable/CMAKE_CXX_STANDARD.html +# https://cmake.org/cmake/help/latest/prop_tgt/CXX_STANDARD.html +# https://cmake.org/cmake/help/latest/prop_tgt/HIP_STANDARD.html +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD ${CXX_VERSION}) + set(CMAKE_HIP_STANDARD ${CXX_VERSION}) + set(CXX_STANDARD ${CXX_VERSION}) + set(HIP_STANDARD ${CXX_VERSION}) +endif() +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(HIP_STANDARD_REQUIRED ON) + +# Set the default C standard to C_VERSION if CMAKE_C_STANDARD is not supplied +# by CMake command invocation. 
+# Individual targets can have this value overridden; see +# https://cmake.org/cmake/help/latest/variable/CMAKE_C_STANDARD.html +# https://cmake.org/cmake/help/latest/prop_tgt/C_STANDARD.html +if(NOT CMAKE_C_STANDARD) + set(C_STANDARD ${C_VERSION}) + set(CMAKE_C_STANDARD ${C_VERSION}) +endif() +set(CMAKE_C_EXTENSIONS OFF) +set(CMAKE_C_STANDARD_REQUIRED ON) + +if(DEFINED GLIBCXX_USE_CXX11_ABI) + if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") + endif() +endif() + +BLOCK_PRINT( + "Default C compiler flags" + "(values may be overridden by CMAKE_CXX_STANDARD and CXX_STANDARD):" + "" + "${CMAKE_C_FLAGS}" +) + +BLOCK_PRINT( + "Default C++ compiler flags" + "(values may be overridden by CMAKE_CXX_STANDARD and CXX_STANDARD):" + "" + "${CMAKE_CXX_FLAGS}" +) + +# Strip all symbols from the .SO file after building +add_link_options($<$:-s>) + +# Set flags for AVX2 +set(AVX2_FLAGS "-mavx2;-mf16c;-mfma;-fopenmp") +if(NOT FBGEMM_CPU_ONLY AND WSL_MODE) + # NVCC in WSL complains about unknown -mavx options + # https://github.com/pytorch/FBGEMM/issues/2135 + set(AVX2_FLAGS "-Xcompiler;-mavx;-Xcompiler;-mavx2;-Xcompiler;-mf16c;-Xcompiler;-mfma;-fopenmp") +endif() + +# Set flags for AVX512 +set(AVX512_FLAGS "-mavx2;-mf16c;-mfma;-mavx512f;-mavx512bw;-mavx512dq;-mavx512vl;-fopenmp") +if(NOT FBGEMM_CPU_ONLY AND WSL_MODE) + set(AVX512_FLAGS "-Xcompiler;-mavx2;-Xcompiler;-mf16c;-Xcompiler;-mfma;-Xcompiler;-mavx512f;-Xcompiler;-mavx512bw;-Xcompiler;-mavx512dq;-Xcompiler;-mavx512vl;-fopenmp") +endif() diff --git a/cmake/modules/FindAVX.cmake b/cmake/modules/FindAVX.cmake index 0cf20f5a4d..5bd8cffd61 100644 --- a/cmake/modules/FindAVX.cmake +++ b/cmake/modules/FindAVX.cmake @@ -82,7 +82,6 @@ MACRO(CHECK_SSE lang type flags) ENDIF() MARK_AS_ADVANCED(${lang}_${type}_FOUND ${lang}_${type}_FLAGS) - ENDMACRO() CHECK_SSE(C "AVX" " ;-mavx;/arch:AVX") diff --git a/cmake/modules/PyTorchSetup.cmake b/cmake/modules/PyTorchSetup.cmake new file mode 100644 index 0000000000..a5b73eb6f3 --- /dev/null +++ b/cmake/modules/PyTorchSetup.cmake @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +include(${CMAKE_CURRENT_SOURCE_DIR}/../cmake/modules/Utilities.cmake) + + +################################################################################ +# PyTorch Dependencies Setup +################################################################################ + +find_package(Torch REQUIRED) + +# +# Torch Cuda Extensions are normally compiled with the flags below. However we +# disabled -D__CUDA_NO_HALF_CONVERSIONS__ here as it caused "error: no suitable +# constructor exists to convert from "int" to "__half" errors in +# gen_embedding_forward_quantized_split_[un]weighted_codegen_cuda.cu +# + +set(TORCH_CUDA_OPTIONS + --expt-relaxed-constexpr -D__CUDA_NO_HALF_OPERATORS__ + # -D__CUDA_NO_HALF_CONVERSIONS__ + -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__) diff --git a/cmake/modules/RocmSetup.cmake b/cmake/modules/RocmSetup.cmake new file mode 100644 index 0000000000..7e37893bf9 --- /dev/null +++ b/cmake/modules/RocmSetup.cmake @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +include(${CMAKE_CURRENT_SOURCE_DIR}/../cmake/modules/Utilities.cmake) + + +################################################################################ +# ROCm and HIPify Setup +################################################################################ + +if(USE_ROCM) + # Load CMake modules + list(APPEND CMAKE_MODULE_PATH + "${PROJECT_SOURCE_DIR}/cmake" + "${THIRDPARTY}/hipify_torch/cmake") + include(Hip) + include(Hipify) + + # Configure compiler for HIP + list(APPEND HIP_HCC_FLAGS + " \"-Wno-#pragma-messages\" " + " \"-Wno-#warnings\" " + -Wno-cuda-compat + -Wno-deprecated-declarations + -Wno-format + -Wno-ignored-attributes + -Wno-unused-result) + + BLOCK_PRINT( + "HIP found: ${HIP_FOUND}" + "HIPCC compiler flags:" + "" + "${HIP_HCC_FLAGS}" + ) +endif() diff --git a/cmake/modules/Utilities.cmake b/cmake/modules/Utilities.cmake new file mode 100644 index 0000000000..2630a22dfd --- /dev/null +++ b/cmake/modules/Utilities.cmake @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +################################################################################ +# Utility Functions +################################################################################ + +function(BLOCK_PRINT) + message("") + message("") + message("================================================================================") + foreach(ARG IN LISTS ARGN) + message("${ARG}") + endforeach() + message("================================================================================") + message("") +endfunction() diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 0d454ff580..7b3ba7ecd7 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -10,18 +10,12 @@ cmake_minimum_required(VERSION 3.25.0 FATAL_ERROR) -function(BLOCK_PRINT) - message("================================================================================") - foreach(ARG IN LISTS ARGN) - message("${ARG}") - endforeach() - message("================================================================================") - message("") -endfunction() - set(CMAKEMODULES ${CMAKE_CURRENT_SOURCE_DIR}/../cmake/modules) set(FBGEMM ${CMAKE_CURRENT_SOURCE_DIR}/..) set(THIRDPARTY ${FBGEMM}/third_party) +set(CMAKE_CODEGEN_DIR ${CMAKE_CURRENT_SOURCE_DIR}/codegen) + +include(${CMAKEMODULES}/Utilities.cmake) ################################################################################ @@ -51,81 +45,13 @@ else() endif() -################################################################################ -# FBGEMM_GPU C++ Setup -################################################################################ - -# Set the default C++ standard to C++20 if CMAKE_CXX_STANDARD is not supplied -# by CMake command invocation. 
-# Individual targets can have this value overridden; see -# https://cmake.org/cmake/help/latest/variable/CMAKE_CXX_STANDARD.html -# https://cmake.org/cmake/help/latest/prop_tgt/CXX_STANDARD.html -# https://cmake.org/cmake/help/latest/prop_tgt/HIP_STANDARD.html -if(NOT CMAKE_CXX_STANDARD) - set(CMAKE_CXX_STANDARD 20) - set(CMAKE_HIP_STANDARD 20) - set(CXX_STANDARD 20) - set(HIP_STANDARD 20) -endif() -set(CMAKE_CXX_EXTENSIONS OFF) -set(CMAKE_CXX_STANDARD_REQUIRED ON) -set(HIP_STANDARD_REQUIRED ON) - -# Set the default C standard to C17 -# Individual targets can have this value overridden; see -# https://cmake.org/cmake/help/latest/variable/CMAKE_C_STANDARD.html -# https://cmake.org/cmake/help/latest/prop_tgt/C_STANDARD.html -set(C_STANDARD 20) -set(CMAKE_C_STANDARD 17) -set(CMAKE_C_EXTENSIONS OFF) -set(CMAKE_C_STANDARD_REQUIRED ON) - -if(DEFINED GLIBCXX_USE_CXX11_ABI) - if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1") - else() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") - endif() -endif() - -BLOCK_PRINT( - "Default C compiler flags" - "(values may be overridden by CMAKE_CXX_STANDARD and CXX_STANDARD):" - "" - "${CMAKE_C_FLAGS}" -) - -BLOCK_PRINT( - "Default C++ compiler flags" - "(values may be overridden by CMAKE_CXX_STANDARD and CXX_STANDARD):" - "" - "${CMAKE_CXX_FLAGS}" -) - -# Strip all symbols from the .SO file after building -add_link_options($<$:-s>) - -# Set flags for AVX2 -set(AVX2_FLAGS "-mavx2;-mf16c;-mfma;-fopenmp") -if(NOT FBGEMM_CPU_ONLY AND WSL_MODE) - # NVCC in WSL complains about unknown -mavx options - # https://github.com/pytorch/FBGEMM/issues/2135 - set(AVX2_FLAGS "-Xcompiler;-mavx;-Xcompiler;-mavx2;-Xcompiler;-mf16c;-Xcompiler;-mfma;-fopenmp") -endif() - -# Set flags for AVX512 -set(AVX512_FLAGS "-mavx2;-mf16c;-mfma;-mavx512f;-mavx512bw;-mavx512dq;-mavx512vl;-fopenmp") -if(NOT FBGEMM_CPU_ONLY AND WSL_MODE) - set(AVX512_FLAGS "-Xcompiler;-mavx2;-Xcompiler;-mf16c;-Xcompiler;-mfma;-Xcompiler;-mavx512f;-Xcompiler;-mavx512bw;-Xcompiler;-mavx512dq;-Xcompiler;-mavx512vl;-fopenmp") -endif() - -set(CMAKE_CODEGEN_DIR ${CMAKE_CURRENT_SOURCE_DIR}/codegen) - - ################################################################################ # FBGEMM_GPU Build Kickstart ################################################################################ +# FBGEMM_GPU C++ Setup - must be set BEFORE project declaration +include(${CMAKEMODULES}/CxxCompilerSetup.cmake) + if(SKBUILD) BLOCK_PRINT("The project is built using scikit-build") endif() @@ -133,87 +59,26 @@ endif() if(FBGEMM_CPU_ONLY OR USE_ROCM) project( fbgemm_gpu - VERSION 0.3.1 + VERSION 0.7.0 LANGUAGES CXX C) else() project( fbgemm_gpu - VERSION 0.3.1 + VERSION 0.7.0 LANGUAGES CXX C CUDA) endif() +# AVX Flags Setup - must be set AFTER project declaration include(${CMAKEMODULES}/FindAVX.cmake) - -################################################################################ # PyTorch Dependencies Setup -################################################################################ - -find_package(Torch REQUIRED) - -# -# Toch Cuda Extensions are normally compiled with the flags below. 
However we -# disabled -D__CUDA_NO_HALF_CONVERSIONS__ here as it caused "error: no suitable -# constructor exists to convert from "int" to "__half" errors in -# gen_embedding_forward_quantized_split_[un]weighted_codegen_cuda.cu -# +include(${CMAKEMODULES}/PyTorchSetup.cmake) -set(TORCH_CUDA_OPTIONS - --expt-relaxed-constexpr -D__CUDA_NO_HALF_OPERATORS__ - # -D__CUDA_NO_HALF_CONVERSIONS__ - -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__) - - -################################################################################ # CUDA Setup -################################################################################ - -# Set NVML_LIB_PATH if provided, or detect the default lib path -if(NOT NVML_LIB_PATH) - set(DEFAULT_NVML_LIB_PATH - "${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so") - - if(EXISTS ${DEFAULT_NVML_LIB_PATH}) - message(STATUS "Setting NVML_LIB_PATH: \ - ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so") - set(NVML_LIB_PATH "${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libnvidia-ml.so") - endif() -endif() +include(${CMAKEMODULES}/CudaSetup.cmake) -if(NVML_LIB_PATH) - message(STATUS "Found NVML_LIB_PATH: ${NVML_LIB_PATH}") -endif() - - -################################################################################ # ROCm and HIPify Setup -################################################################################ - -if(USE_ROCM) - # Load CMake modules - list(APPEND CMAKE_MODULE_PATH - "${PROJECT_SOURCE_DIR}/cmake" - "${THIRDPARTY}/hipify_torch/cmake") - include(Hip) - include(Hipify) - - # Configure compiler for HIP - list(APPEND HIP_HCC_FLAGS - " \"-Wno-#pragma-messages\" " - " \"-Wno-#warnings\" " - -Wno-cuda-compat - -Wno-deprecated-declarations - -Wno-format - -Wno-ignored-attributes - -Wno-unused-result) - - BLOCK_PRINT( - "HIP found: ${HIP_FOUND}" - "HIPCC compiler flags:" - "" - "${HIP_HCC_FLAGS}" - ) -endif() +include(${CMAKEMODULES}/RocmSetup.cmake) ################################################################################ @@ -223,7 +88,6 @@ endif() file(GLOB_RECURSE asmjit_sources "${CMAKE_CURRENT_SOURCE_DIR}/../third_party/asmjit/src/asmjit/*/*.cpp") - ################################################################################ # Optimizer Group Definitions ################################################################################ @@ -299,9 +163,11 @@ macro(RUN_GEN_SCRIPT SCRIPT) endmacro() foreach(script - "${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py" + "${CMAKE_CODEGEN_DIR}/genscript/generate_backward_split.py" "${CMAKE_CODEGEN_DIR}/genscript/generate_embedding_optimizer.py" - "${CMAKE_CODEGEN_DIR}/genscript/generate_forward_quantized.py") + "${CMAKE_CODEGEN_DIR}/genscript/generate_forward_quantized.py" + "${CMAKE_CODEGEN_DIR}/genscript/generate_forward_split.py" + "${CMAKE_CODEGEN_DIR}/genscript/generate_index_select.py") RUN_GEN_SCRIPT(${script}) endforeach() @@ -370,7 +236,9 @@ set(gen_cpu_source_files "gen_embedding_forward_quantized_weighted_codegen_cpu.cpp" "gen_embedding_backward_dense_split_cpu.cpp") -set(gen_python_source_files ${CMAKE_BINARY_DIR}/__init__.py) +set(gen_python_source_files + ${CMAKE_BINARY_DIR}/__init__.py + ${CMAKE_BINARY_DIR}/lookup_args.py) # For each of the optimizers, generate the backward split variant by adding # the Python, CPU-only, GPU host, and GPU kernel source files @@ -438,6 +306,9 @@ foreach(optimizer ${DEFUSED_OPTIMIZERS}) "${CMAKE_BINARY_DIR}/split_embedding_optimizer_${optimizer}.py") endforeach() +list(APPEND gen_defused_optim_py_files + 
${CMAKE_BINARY_DIR}/optimizer_args.py) + ################################################################################ # FBGEMM_GPU Generated Sources @@ -554,10 +425,10 @@ set_source_files_properties(${fbgemm_sources} ################################################################################ set(fbgemm_gpu_sources_static_cpu - codegen/embedding_forward_split_cpu.cpp + codegen/training/forward/embedding_forward_split_cpu.cpp codegen/inference/embedding_forward_quantized_host_cpu.cpp - codegen/embedding_backward_dense_host_cpu.cpp - codegen/embedding_bounds_check_host_cpu.cpp + codegen/training/backward/embedding_backward_dense_host_cpu.cpp + codegen/utils/embedding_bounds_check_host_cpu.cpp src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp @@ -577,13 +448,13 @@ set(fbgemm_gpu_sources_static_cpu src/split_embeddings_cache/lru_cache_populate_byte.cpp src/split_embeddings_cache/lxu_cache.cpp src/split_embeddings_cache/split_embeddings_cache_ops.cpp - codegen/batch_index_select_dim0_cpu_host.cpp) + codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp) if(NOT FBGEMM_CPU_ONLY) list(APPEND fbgemm_gpu_sources_static_cpu codegen/inference/embedding_forward_quantized_host.cpp - codegen/embedding_backward_dense_host.cpp - codegen/embedding_bounds_check_host.cpp + codegen/training/backward/embedding_backward_dense_host.cpp + codegen/utils/embedding_bounds_check_host.cpp src/memory_utils/memory_utils.cpp src/memory_utils/memory_utils_ops.cpp src/memory_utils/memory_utils_ops_cpu.cpp @@ -597,7 +468,7 @@ if(NOT FBGEMM_CPU_ONLY) src/metric_ops/metric_ops_host.cpp src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp src/input_combine_ops/input_combine_gpu.cpp - codegen/batch_index_select_dim0_host.cpp) + codegen/training/index_select/batch_index_select_dim0_host.cpp) if(NVML_LIB_PATH OR USE_ROCM) message(STATUS "Adding merge_pooled_embeddings sources") @@ -621,7 +492,7 @@ endif() if(NOT FBGEMM_CPU_ONLY) set(fbgemm_gpu_sources_static_gpu - codegen/embedding_bounds_check.cu + codegen/utils/embedding_bounds_check.cu codegen/inference/embedding_forward_quantized_split_lookup.cu src/memory_utils/memory_utils.cu src/memory_utils/memory_utils_ops.cu @@ -796,6 +667,12 @@ if(NVML_LIB_PATH) target_link_libraries(fbgemm_gpu_py ${NVML_LIB_PATH}) endif() +# Silence warnings in asmjit +target_compile_options(fbgemm_gpu_py PRIVATE + -Wno-deprecated-anon-enum-enum-conversion) +target_compile_options(fbgemm_gpu_py PRIVATE + -Wno-deprecated-declarations) + ################################################################################ # FBGEMM_GPU Install @@ -807,11 +684,13 @@ install(TARGETS fbgemm_gpu_py install(FILES ${gen_python_source_files} DESTINATION fbgemm_gpu/split_embedding_codegen_lookup_invokers) -install(FILES ${CMAKE_CODEGEN_DIR}/lookup_args.py - DESTINATION fbgemm_gpu/split_embedding_codegen_lookup_invokers) - install(FILES ${gen_defused_optim_py_files} DESTINATION fbgemm_gpu/split_embedding_optimizer_codegen) -install(FILES ${CMAKE_CODEGEN_DIR}/optimizer_args.py - DESTINATION fbgemm_gpu/split_embedding_optimizer_codegen) + + +################################################################################ +# Build Experimental Modules +################################################################################ + +add_subdirectory(experimental/example) diff --git 
a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index 269b45e49c..2a8b1e653b 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -1583,6 +1583,7 @@ def nbit_device_with_spec( # noqa C901 @click.option("--enforce-hbm", is_flag=True, default=False) @click.option("--fp8-exponent-bits", type=int, default=None) @click.option("--fp8-exponent-bias", type=int, default=None) +@click.option("--uvm-host-mapped", is_flag=True, default=False) def nbit_uvm( alpha: bool, bag_size: int, @@ -1607,6 +1608,7 @@ def nbit_uvm( enforce_hbm: bool, fp8_exponent_bits: Optional[int], fp8_exponent_bias: Optional[int], + uvm_host_mapped: bool, ) -> None: np.random.seed(42) torch.manual_seed(42) @@ -1655,6 +1657,7 @@ def nbit_uvm( enforce_hbm=enforce_hbm, fp8_exponent_bits=fp8_exponent_bits, fp8_exponent_bias=fp8_exponent_bias, + uvm_host_mapped=uvm_host_mapped, ).cuda() emb_uvm.fill_random_weights() @@ -1673,6 +1676,7 @@ def nbit_uvm( output_dtype=output_dtype, fp8_exponent_bits=fp8_exponent_bits, fp8_exponent_bias=fp8_exponent_bias, + uvm_host_mapped=uvm_host_mapped, ).cuda() emb_gpu.fill_random_weights() @@ -1697,6 +1701,7 @@ def nbit_uvm( enforce_hbm=enforce_hbm, fp8_exponent_bits=fp8_exponent_bits, fp8_exponent_bias=fp8_exponent_bias, + uvm_host_mapped=uvm_host_mapped, ).cuda() emb_mixed.fill_random_weights() @@ -2113,6 +2118,7 @@ def bench_uvm_cls( @click.option("--gather-uvm-cache-stats", is_flag=True, default=False) @click.option("--fp8-exponent-bits", type=int, default=None) @click.option("--fp8-exponent-bias", type=int, default=None) +@click.option("--uvm-host-mapped", is_flag=True, default=False) def nbit_cache( # noqa C901 alpha: float, bag_size: int, @@ -2137,6 +2143,7 @@ def nbit_cache( # noqa C901 gather_uvm_cache_stats: bool, fp8_exponent_bits: Optional[int], fp8_exponent_bias: Optional[int], + uvm_host_mapped: bool, ) -> None: np.random.seed(42) torch.manual_seed(42) @@ -2171,6 +2178,7 @@ def nbit_cache( # noqa C901 fp8_exponent_bits=fp8_exponent_bits, fp8_exponent_bias=fp8_exponent_bias, cache_assoc=cache_assoc, + uvm_host_mapped=uvm_host_mapped, ).cuda() emb_nc.fill_random_weights() fill_random_scale_bias(emb_nc, T, weights_precision) @@ -2197,6 +2205,7 @@ def nbit_cache( # noqa C901 fp8_exponent_bits=fp8_exponent_bits, fp8_exponent_bias=fp8_exponent_bias, cache_assoc=cache_assoc, + uvm_host_mapped=uvm_host_mapped, ).cuda() emb.fill_random_weights() fill_random_scale_bias(emb, T, weights_precision) diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py deleted file mode 100644 index baf373e7d2..0000000000 --- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py +++ /dev/null @@ -1,526 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-strict -# flake8: noqa F401 - -import re -import sys -from typing import Optional - -try: - # Internal - from .embedding_common_code_generator import * -except ImportError: - # OSS - from embedding_common_code_generator import * - -import re - - -def generate_backward_embedding_cuda( - template_filepath: str, - optimizer: str, - filename_format: str, - kwargs: Dict[str, Any], -) -> None: - if not kwargs.get("has_gpu_support"): - return - template = env.get_template(template_filepath) - vbe_options = [True, False] if kwargs.get("has_vbe_support") else [False] - for weighted in [True, False]: - for nobag in [True, False]: - for vbe in vbe_options: - if (not nobag or (not weighted and not vbe)) and ( - not kwargs.get("dense") or not vbe - ): - wdesc = f"{ 'weighted' if weighted else 'unweighted' }{ '_nobag' if nobag else '' }{ '_vbe' if vbe else '' }" - filename = filename_format.format(optimizer, wdesc) - write( - filename, - template.render( - weighted=weighted, - nobag=nobag, - vbe=vbe, - is_index_select=False, - kdesc=wdesc, - **kwargs, - ), - ) - print(f"[Backward Split] [{optimizer}]: {filename}") - - -def generate(**kwargs: Any) -> None: - optimizer = kwargs.get("optimizer") - gen_args = kwargs["args"] - - # - # Generate GPU variants of the operators - # - kwargs["args"] = gen_args["cuda"] - kwargs["args_pt2"] = gen_args["any_device"] - - # Generate the backward splits - generate_backward_embedding_cuda( - "embedding_backward_split_template.cu", - optimizer, - "gen_embedding_backward_{}_split_{}_cuda.cu", - kwargs, - ) - - generate_backward_embedding_cuda( - "embedding_backward_split_meta_template.cpp", - optimizer, - "gen_embedding_backward_{}_split_{}_meta.cpp", - kwargs, - ) - - # Generate the cta_per_row kernels for the backward splits - generate_backward_embedding_cuda( - "embedding_backward_split_kernel_cta_template.cu", - optimizer, - "gen_embedding_backward_{}_split_{}_kernel_cta.cu", - kwargs, - ) - - # Generate the warp_per_row kernels for the backward splits - generate_backward_embedding_cuda( - "embedding_backward_split_kernel_warp_template.cu", - optimizer, - "gen_embedding_backward_{}_split_{}_kernel_warp.cu", - kwargs, - ) - - # Generate optimizer kernel - template = env.get_template("embedding_optimizer_split_device_kernel_template.cuh") - filename = f"gen_embedding_optimizer_{optimizer}_split_device_kernel.cuh" - write(filename, template.render(**kwargs)) - - # Generate the backward splits (non-dense) - # We generate only the API to preserve the backward compatibility if - # has_gpu_support=True - if not kwargs.get("dense"): - # TO DO: deprecate - # Generate CUDA Autograd - template = env.get_template("embedding_backward_split_host_template.cpp") - filename = f"gen_embedding_backward_split_{optimizer}.cpp" - write(filename, template.render(**kwargs)) - print(f"[Backward Split] [{optimizer}]: {filename}") - - if kwargs.get("has_cpu_support") or kwargs.get("has_gpu_support"): - # Generates Python invoker for CUDA + CPU - template = env.get_template( - "split_embedding_codegen_lookup_invoker.template" - ) - filename = f"lookup_{optimizer}.py" - write(filename, template.render(is_fbcode=args.is_fbcode, **kwargs)) - print(f"[Backward Split] [{optimizer}]: {filename}") - - # Generate PT2 unified Autograd - template_pt2 = env.get_template( - "embedding_split_host_pt2_autograd_template.cpp" - ) - filename_pt2 = f"gen_embedding_split_{optimizer}_pt2_autograd.cpp" - write(filename_pt2, template_pt2.render(**kwargs)) - print(f"[Backward Split] [{optimizer}]: 
{filename_pt2}") - - # Generate PT2 backward wrapper - template_pt2 = env.get_template( - "embedding_split_host_pt2_cuda_wrapper_template.cpp" - ) - filename_pt2 = f"gen_embedding_backward_split_{optimizer}_pt2_cuda_wrapper.cpp" - write(filename_pt2, template_pt2.render(is_forward=False, **kwargs)) - print(f"[Backward Split] [{optimizer}]: {filename_pt2}") - - if kwargs.get("has_cpu_support") or kwargs.get("has_gpu_support"): - # Generate Python invoker for CUDA + CPU PT2 - template_pt2 = env.get_template( - "split_embedding_codegen_lookup_invoker.template" - ) - filename_pt2 = f"lookup_{optimizer}_pt2.py" - write(filename_pt2, template_pt2.render(is_fbcode=args.is_fbcode, **kwargs)) - print(f"[Backward Split] [{optimizer}]: {filename_pt2}") - - # - # Generate CPU variants of the operators - # - kwargs["args"] = gen_args["cpu"] - kwargs["args_pt2"] = gen_args["any_device"] - - # Generate the backward splits - if kwargs.get("has_cpu_support"): - is_approx = "approx" in optimizer - template = ( - env.get_template("embedding_backward_split_cpu_approx_template.cpp") - if is_approx - else env.get_template("embedding_backward_split_cpu_template.cpp") - ) - filename = f"gen_embedding_backward_{optimizer}_split_cpu.cpp" - write(filename, template.render(**kwargs)) - print(f"[Backward Split] [{optimizer}]: {filename}") - - # Generate the backward splits (non-dense) - if not kwargs.get("dense"): - template = env.get_template("embedding_backward_split_host_cpu_template.cpp") - filename = f"gen_embedding_backward_split_{optimizer}_cpu.cpp" - write(filename, template.render(**kwargs)) - print(f"[Backward Split] [{optimizer}]: {filename}") - - # Generate PT2 backward wrapper functions - template_pt2 = env.get_template( - "embedding_split_host_pt2_cpu_wrapper_template.cpp" - ) - filename_pt2 = f"gen_embedding_backward_split_{optimizer}_pt2_cpu_wrapper.cpp" - write(filename_pt2, template_pt2.render(is_forward=False, **kwargs)) - print(f"[Backward Split] [{optimizer}]: {filename_pt2}") - - -# Format the way to generate PackedTensorAccessors -def make_pta_acc_format(pta_str_list: List[str], func_name: str) -> List[str]: - new_str_list = [] - for pta_str in pta_str_list: - if "packed_accessor" in pta_str: - match = re.search( - r"([a-zA-z0-9_]*)[.]packed_accessor([3|6][2|4])<(.*)>\(\)", pta_str - ) - assert match is not None and len(match.groups()) == 3 - tensor, acc_nbits, args = match.groups() - if "acc_type" in args: - match = re.search("at::acc_type<([a-zA-Z_]*), true>", args) - assert match is not None and len(match.groups()) == 1 - new_type = match.group(1) - args = re.sub("at::acc_type<[a-zA-Z_]*, true>", new_type, args) - func_name_suffix = "_ACC_TYPE" - else: - func_name_suffix = "" - new_str_list.append( - f"{func_name}{func_name_suffix}({tensor}, {args}, {acc_nbits})" - ) - else: - new_str_list.append(pta_str) - return new_str_list - - -def replace_pta_namespace(pta_str_list: List[str]) -> List[str]: - return [ - pta_str.replace("at::PackedTensorAccessor", "pta::PackedTensorAccessor") - for pta_str in pta_str_list - ] - - -def backward_indices() -> None: - template = env.get_template("embedding_backward_split_indice_weights_template.cu") - src_cu = template.render() - write("gen_embedding_backward_split_indice_weights_codegen_cuda.cu", src_cu) - src_cu = template.render(dense=True) - write("gen_embedding_backward_dense_indice_weights_codegen_cuda.cu", src_cu) - - -def backward_dense() -> None: - generate( - optimizer="dense", - dense=True, - args=make_args( - [ - (FLOAT, "unused"), - ] - ), - 
split_precomputation=split_precomputation, - split_weight_update=split_weight_update, - split_post_update="", - split_weight_update_cpu=split_weight_update_cpu, - has_cpu_support=False, - has_gpu_support=True, - has_vbe_support=False, - ) - - -def generate_forward_embedding_cuda( - template_filepath: str, - filename_format: str, - dense_options: List[bool], - nobag_options: List[bool], - vbe_options: List[bool], -) -> None: - template = env.get_template(template_filepath) - for dense in dense_options: - for weighted in [True, False]: - for nobag in nobag_options: - for vbe in vbe_options: - if (not nobag or (not weighted and not vbe)) and ( - not dense or not vbe - ): - dense_desc = f"{ 'dense' if dense else 'split'}" - weight_desc = f"{ 'weighted' if weighted else 'unweighted' }" - nobag_desc = f"{ '_nobag' if nobag else '' }" - vbe_desc = f"{ '_vbe' if vbe else '' }" - desc = ( - f"{ dense_desc }_{ weight_desc }{ nobag_desc }{ vbe_desc }" - ) - filename = filename_format.format(desc) - write( - filename, - template.render( - dense=dense, - weighted=weighted, - nobag=nobag, - vbe=vbe, - is_index_select=False, - ), - ) - print(f"[Forward Split]: {filename}") - - -def forward_split() -> None: - # Generate the forward splits - generate_forward_embedding_cuda( - "embedding_forward_split_template.cu", - "gen_embedding_forward_{}_codegen_cuda.cu", - dense_options=[True, False], - nobag_options=[False], # nobag is not used - vbe_options=[True, False], - ) - - generate_forward_embedding_cuda( - "embedding_forward_split_meta_template.cpp", - "gen_embedding_forward_{}_codegen_meta.cpp", - dense_options=[True, False], - nobag_options=[False], # nobag is not used - vbe_options=[True, False], - ) - - # Generate the kernels for the forward splits - generate_forward_embedding_cuda( - "embedding_forward_split_kernel_template.cu", - "gen_embedding_forward_{}_kernel.cu", - dense_options=[True, False], - nobag_options=[True, False], - vbe_options=[True, False], - ) - - # Generate the kernels for the forward splits v2 - generate_forward_embedding_cuda( - "embedding_forward_split_kernel_v2_template.cu", - "gen_embedding_forward_{}_v2_kernel.cu", - dense_options=[False], # dense is not supported - nobag_options=[False], # nobag is not supported - vbe_options=[False], # vbe is not supported - ) - - # Generate the small kernels (for nobag only) for the forward splits - template = env.get_template( - "embedding_forward_split_kernel_nobag_small_template.cu" - ) - for dense in [True, False]: - wdesc = f"{ 'dense' if dense else 'split' }" - filename = f"gen_embedding_forward_{wdesc}_unweighted_nobag_kernel_small.cu" - write(filename, template.render(dense=dense, is_index_select=False)) - print(f"[Forward Split]: {filename}") - - # Generate PT2 forward wrapper cuda - template_pt2 = env.get_template( - "embedding_split_host_pt2_cuda_wrapper_template.cpp", - ) - filename_pt2 = f"gen_embedding_forward_split_pt2_cuda_wrapper.cpp" - write( - filename_pt2, - template_pt2.render( - has_gpu_support=True, - is_forward=True, - has_vbe_support=True, - ), - ) - print(f"[Forward Split]: {filename_pt2}") - - # Generate PT2 forward wrapper cpu - template_pt2 = env.get_template( - "embedding_split_host_pt2_cpu_wrapper_template.cpp", - ) - filename_pt2 = f"gen_embedding_forward_split_pt2_cpu_wrapper.cpp" - write( - filename_pt2, - template_pt2.render( - has_cpu_support=True, - is_forward=True, - ), - ) - print(f"[Forward Split]: {filename_pt2}") - - -def backward_device_kernel() -> None: - # Generate backward device kernels based on 
weighted (True/False), VBE - # (True/False), no bag (True/False) - template_filepath = "embedding_backward_split_device_kernel_template.cuh" - generate_backward_embedding_cuda( - template_filepath=template_filepath, - optimizer="", - filename_format="{}gen_embedding_backward_{}_split_device_kernel.cuh", - kwargs={ - "has_gpu_support": True, - "has_vbe_support": True, - "dense": False, - "gen_once": False, - }, - ) - - # Generate common backward device kernels (generate only once) - template = env.get_template(template_filepath) - write( - "gen_embedding_backward_common_split_device_kernel.cuh", - template.render(gen_once=True), - ) - - -# TODO: Separate this function into another codegen script -def index_select() -> None: - kwargs = make_args([(FLOAT, "unused")]) - kwargs["args"] = kwargs["cuda"] - for templ_file, gen_file in [ - ( - "embedding_forward_split_template.cu", - "gen_batch_index_select_dim0_forward_codegen_cuda.cu", - ), - ( - "embedding_forward_split_kernel_template.cu", - "gen_batch_index_select_dim0_forward_kernel.cu", - ), - ( - "embedding_forward_split_kernel_nobag_small_template.cu", - "gen_batch_index_select_dim0_forward_kernel_small.cu", - ), - ( - "embedding_backward_split_template.cu", - "gen_batch_index_select_dim0_backward_codegen_cuda.cu", - ), - ( - "embedding_backward_split_kernel_cta_template.cu", - "gen_batch_index_select_dim0_backward_kernel_cta.cu", - ), - ( - "embedding_backward_split_kernel_warp_template.cu", - "gen_batch_index_select_dim0_backward_kernel_warp.cu", - ), - ( - "embedding_backward_split_device_kernel_template.cuh", - "gen_embedding_backward_batch_index_select_split_device_kernel.cuh", - ), - ]: - template = env.get_template(templ_file) - write( - gen_file, - template.render( - weighted=False, - dense=True, - vbe=False, - nobag=True, - is_index_select=True, - gen_once=False, - kdesc="batch_index_select", - **kwargs, - ), - ) - - template = env.get_template("embedding_backward_split_grad_template.cu") - write( - "gen_embedding_backward_split_grad_index_select.cu", - template.render(is_index_select=True), - ) - - # Generate common backward device kernels (generate only once) - template = env.get_template("embedding_backward_split_device_kernel_template.cuh") - write( - "gen_embedding_backward_common_split_device_kernel.cuh", - template.render(gen_once=True), - ) - - -def backward_grad() -> None: - # Generate the common grad functions - template = env.get_template("embedding_backward_split_grad_template.cu") - write( - "gen_embedding_backward_split_grad_embedding_ops.cu", - template.render(is_index_select=False), - ) - - -def backward_indices() -> None: - template = env.get_template("embedding_backward_split_indice_weights_template.cu") - src_cu = template.render() - write("gen_embedding_backward_split_indice_weights_codegen_cuda.cu", src_cu) - src_cu = template.render(dense=True) - write("gen_embedding_backward_dense_indice_weights_codegen_cuda.cu", src_cu) - - -def backward_dense() -> None: - generate( - optimizer="dense", - dense=True, - args=make_args( - [ - (FLOAT, "unused"), - ] - ), - has_cpu_support=True, - has_gpu_support=True, - has_vbe_support=False, - ) - - -def gen__init__py() -> None: - template = env.get_template("__init__.template") - src_py = template.render() - write("__init__.py", src_py) - - -def emb_codegen( - install_dir: Optional[str] = None, is_fbcode: Optional[bool] = None -) -> None: - if install_dir is not None and len(install_dir) != 0: - args.install_dir = install_dir - if is_fbcode is not None: - args.is_fbcode = 
is_fbcode - backward_grad() - - # Generate forwards and specialized backwards - backward_indices() - backward_dense() - # forward_quantized() - forward_split() - - # Generate common device kernels for backwards - backward_device_kernel() - - # Generate backwards and optimizers - generate(**(adagrad())) - generate(**(adam())) - generate(**(lamb())) - generate(**(lars_sgd())) - generate(**(partial_rowwise_adam())) - generate(**(partial_rowwise_lamb())) - generate(**(rowwise_adagrad())) - generate(**(approx_rowwise_adagrad())) - generate(**(rowwise_adagrad_with_weight_decay())) - generate(**(approx_rowwise_adagrad_with_weight_decay())) - generate(**(rowwise_adagrad_with_counter())) - generate(**(approx_rowwise_adagrad_with_counter())) - generate(**(rowwise_weighted_adagrad())) - generate(**(sgd())) - generate(**(approx_sgd())) - generate(**(none_optimizer())) - - # Generate index_select ops using TBE backend - index_select() - gen__init__py() - - -def main() -> None: - emb_codegen() - - -if __name__ == "__main__": - print(f"[EMBEDDING BACKWARD CODE GENERATOR] {sys.argv}") - main() diff --git a/fbgemm_gpu/codegen/embedding_common_code_generator.py b/fbgemm_gpu/codegen/embedding_common_code_generator.py deleted file mode 100644 index c81b680a26..0000000000 --- a/fbgemm_gpu/codegen/embedding_common_code_generator.py +++ /dev/null @@ -1,1738 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -# flake8: noqa F401 - -import argparse -import os -import re -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import jinja2 - -args: argparse.Namespace -_: List[str] -TENSOR: int -INT_TENSOR: int -LONG_TENSOR: int -INT: int -FLOAT: int - - -parser = argparse.ArgumentParser() -# By default the source template files are in the same folder as -# embedding_backward_code_generator.py; -# The install dir is by default the same as the current folder. -parser.add_argument("--install_dir", default=".", help="where to put generated file") -parser.add_argument("--opensource", action="store_false", dest="is_fbcode") -parser.add_argument("--is_rocm", action="store_true") -args, _ = parser.parse_known_args() - - -env = jinja2.Environment( - loader=jinja2.FileSystemLoader(os.path.dirname(os.path.abspath(__file__))) -) -# Upper Limit of "max_embedding_dim (max_D)": -# BT_block_size * sizeof(float) * 4 * kWarpSize * {{ kMaxVecsPerThread }} -# needs to be smaller than the allocated shared memory size (2/3 of 96 KB -# on V100 and 160 KB on A100. -# BT_block_size * 4 * 4 * 32 * (max_D // 128) <= 64 * 1024 (V100) or 96 * 1024 (A100) -# Since BT_block_size >= 1, max_D <= 16K (V100) or 24K (A100). -# Note that if we increase max_D, it will increase the compilation time significantly. -env.globals["max_embedding_dim"] = 2048 -# Max embedding dimension for legacy embedding kernels. TBE v2 can support -# larger max embedding dimension. -env.globals["legacy_max_embedding_dim"] = 1024 -# An optimization for ROCm -env.globals["items_per_warp"] = 128 if args.is_rocm is False else 256 -env.globals["dense"] = False -# The fixed max vectors per thread for different kernels. 
The numbers were -# derived from empirical studies -env.globals["fixed_max_vecs_per_thread"] = {"backward": 2, "backward_indice_weights": 6} -env.globals["is_rocm"] = args.is_rocm - -###################################################################### -## Helper functions in Jinja's env.globals ## -###################################################################### - - -def prepare_string_for_formatting(blob: str, format_keywords: List[str]) -> str: - """ - Replace curly brackets ('{' or '}') with escape characters ('{{' or '}}') - to prepare the string to be formatted by `str.format()`. `str.format()` - searches curly brackets to find keywords to format. It will run into an - error if the string contains curly brackets. - """ - blob = blob.replace("{", "{{").replace("}", "}}") - for kw in format_keywords: - blob = blob.replace("{{" + kw + "}}", "{" + kw + "}") - return blob - - -def generate_optimized_grad_sum_loop_access( - blob: str, other_formats: Optional[Dict[str, str]] = None -) -> str: - """ - Generate an optimized code for grad_sum accessing - - The indices of `grad_sum` when `kUseVecBlocking` is true and false are - different. When `kUseVecBlocking` is true, `d_vec` is the index. - Otherwise, `vec` is the index. - - When `kUseVecBlocking` is false, the number times that the for-loop is - executed is known at compile time. Thus, we can add the `#pragma unroll` - hint to tell the compiler to optimize the for-loop. - """ - blob = prepare_string_for_formatting(blob, ["grad_vec"]) - - smem_blob = blob.format(grad_vec="smem_grad_sum[d_vec]") - reg_blob = blob.format(grad_vec="grad_sum[vec]") - gen_blob = """ - if (kUseVecBlocking) { - // max_vecs is not known at compile time - for (int32_t vec = 0; - vec < max_vecs && - (kThreadGroupSize * vec + threadIdx.x) * VEC_WIDTH < D; - ++vec) { - const int32_t d_vec = vec * kThreadGroupSize + threadIdx.x; - [[maybe_unused]] const int32_t d = d_vec * VEC_WIDTH; - {smem_blob} - } - } - else { - // kFixedMaxVecsPerThread is known at compile time - #pragma unroll kFixedMaxVecsPerThread - for (int32_t vec = 0; - vec < kFixedMaxVecsPerThread - && (kThreadGroupSize * vec + threadIdx.x) * VEC_WIDTH < D; - ++vec) { - const int32_t d_vec = vec * kThreadGroupSize + threadIdx.x; - [[maybe_unused]] const int32_t d = d_vec * VEC_WIDTH; - {reg_blob} - } - } - """ - gen_blob = prepare_string_for_formatting(gen_blob, ["smem_blob", "reg_blob"]) - gen_blob = gen_blob.format(smem_blob=smem_blob, reg_blob=reg_blob) - if other_formats is not None: - gen_blob = prepare_string_for_formatting(gen_blob, list(other_formats.keys())) - gen_blob = gen_blob.format(**other_formats) - return gen_blob - - -def get_max_vecs_template_configs( - items_per_warp: int, - fixed_max_vecs_per_thread: int, - use_subwarp_shuffle: bool, - use_vec_blocking: bool, -) -> List[Tuple[int, int, str]]: - """ - Generate the template configs for each kFixedMaxVecsPerThread, - kThreadGroupSize, and kUseVecBlocking - """ - warp_size = items_per_warp // 4 - configs: List[Tuple[int, int, str]] = [] - - if use_vec_blocking: - # kFixedMaxVecsPerThread = fixed_max_vecs_per_thread - # kThreadGroupSize = kWarpSize - # kUseVecBlocking = true - configs.append((fixed_max_vecs_per_thread, warp_size, "true")) - - # Generate the cases where an entire embedding row can fit in the - # thread-local buffer (i.e., shared memory is not need for grad_sum) - if use_subwarp_shuffle: - # Generate configs for sub-warp templates - group_size = 8 # Smallest group size that TBE supports - while group_size < warp_size: - # 
kFixedMaxVecsPerThread = 1 - # kThreadGroupSize = group_size - # kUseVecBlocking = false - configs.append((1, group_size, "false")) - group_size *= 2 - - # Generate configs for the full-warp templates - for v in range(1, fixed_max_vecs_per_thread + 1): - configs.append((v, warp_size, "false")) - - return configs - - -def dispatch_non_vec_blocking_kernel( - items_per_warp: int, - fixed_max_vecs_per_thread: int, - use_subwarp_shuffle: bool, -) -> str: - """ - Generate code for kernel dispatching for kernels that do not use vector - blocking (i.e., an entire embedding row can fit in the allocated Vec4T - buffer) - """ - blob = "" - for ( - kFixedMaxVecsPerThread, - kThreadGroupSize, - kUseVecBlocking, - ) in get_max_vecs_template_configs( - items_per_warp, - fixed_max_vecs_per_thread, - use_subwarp_shuffle, - use_vec_blocking=False, - ): - formats = { - "max_D_val": kFixedMaxVecsPerThread * kThreadGroupSize * 4, - "kFixedMaxVecsPerThread": kFixedMaxVecsPerThread, - "kThreadGroupSize": kThreadGroupSize, - "kUseVecBlocking": kUseVecBlocking, - } - max_D_val = kFixedMaxVecsPerThread * kThreadGroupSize * 4 - d_blob = """if (MAX_D <= {max_D_val}) { \\ - [[ maybe_unused ]] const int max_vecs_per_thread = \\ - {kFixedMaxVecsPerThread}; \\ - constexpr int kFixedMaxVecsPerThread = {kFixedMaxVecsPerThread}; \\ - [[ maybe_unused ]] constexpr int kThreadGroupSize = \\ - {kThreadGroupSize}; \\ - [[ maybe_unused ]] constexpr bool kUseVecBlocking = \\ - {kUseVecBlocking}; \\ - return __VA_ARGS__(); \\ - } \\ - """ - d_blob = prepare_string_for_formatting(d_blob, list(formats.keys())) - blob += d_blob.format(**formats) - return blob - - -def dispatch_vec_blocking_kernel( - items_per_warp: int, - fixed_max_vecs_per_thread: int, -) -> str: - """ - Generate code for kernel dispatching for kernels that use vector blocking - (i.e., an entire embedding row cannot fit in the allocated Vec4T buffer) - """ - formats = { - "max_D_val": fixed_max_vecs_per_thread * items_per_warp, - "items_per_warp": items_per_warp, - "fixed_max_vecs_per_thread": fixed_max_vecs_per_thread, - } - blob = """if (MAX_D > {max_D_val}) { \\ - [[ maybe_unused ]] const int max_vecs_per_thread = \\ - (MAX_D + {items_per_warp} - 1) / {items_per_warp}; \\ - constexpr int kFixedMaxVecsPerThread = {fixed_max_vecs_per_thread}; \\ - [[ maybe_unused ]] constexpr int kThreadGroupSize = kWarpSize; \\ - [[ maybe_unused ]] constexpr bool kUseVecBlocking = true; \\ - return __VA_ARGS__(); \\ - } \\ - """ - blob = prepare_string_for_formatting(blob, list(formats.keys())) - return blob.format(**formats) - - -def dispatch_optimal_kernel( - items_per_warp: int, - fixed_max_vecs_per_thread: int, - use_subwarp_shuffle: bool, -) -> str: - """ - Generate code for kernel dispatching for both kernels that use/do not use - vector blocking - """ - blob = dispatch_non_vec_blocking_kernel( - items_per_warp, - fixed_max_vecs_per_thread, - use_subwarp_shuffle, - ) - blob += dispatch_vec_blocking_kernel( - items_per_warp, - fixed_max_vecs_per_thread, - ) - return blob - - -def is_valid_forward_config( - nobag: bool, - weighted: bool, - vbe: bool, - is_index_select: bool, -) -> bool: - """ - Check if the given combination of configs is valid for forward - - nobag does not have weighted or vbe supports - - is_index_select is nobag - """ - return (not nobag or (not weighted and not vbe)) and ( - nobag or (not is_index_select) - ) - - -def has_experimental_support( - dense: bool, nobag: bool, vbe: bool, is_index_select: bool, is_rocm: bool -) -> bool: - """ - Check if the given 
combination of configs has TBE v2 support - - TBE v2 does not support dense, nobag, vbe, is_index_select, and is_rocm - """ - return not dense and not nobag and not vbe and not is_index_select and not is_rocm - - -# Make helper functions visible to code gen -env.globals["generate_optimized_grad_sum_loop_access"] = ( - generate_optimized_grad_sum_loop_access -) -env.globals["get_max_vecs_template_configs"] = get_max_vecs_template_configs -env.globals["dispatch_optimal_kernel"] = dispatch_optimal_kernel -env.globals["dispatch_non_vec_blocking_kernel"] = dispatch_non_vec_blocking_kernel -env.globals["dispatch_vec_blocking_kernel"] = dispatch_vec_blocking_kernel -env.globals["is_valid_forward_config"] = is_valid_forward_config -env.globals["has_experimental_support"] = has_experimental_support - - -###################################################################### -## Helper functions for the code generator script ## -###################################################################### - - -def write(filename: str, s: str) -> None: - with open(os.path.join(args.install_dir, filename), "w") as f: - f.write(s) - - -def _arg_constructor( - type: str, name: str, gpu: bool = True, precision: int = 32 -) -> str: - return ( - f"{name}.packed_accessor{precision}<{type}, 1, at::RestrictPtrTraits>()" - if gpu - else f"{name}.accessor<{type}, 1>()" - ) - - -def _arg( - type: str, - name: str, - gpu: bool = True, - precision: int = 32, - pass_by_ref: bool = False, -) -> str: - ref = "&" if pass_by_ref else "" - return ( - f"at::PackedTensorAccessor{precision}<{type}, 1, at::RestrictPtrTraits>{ref} {name}" - if gpu - else f"at::TensorAccessor<{type}, 1>{ref} {name}" - ) - - -def acc_cache_tensor_arg_constructor(name: str, gpu: bool = True) -> str: - return _arg_constructor( - "at::acc_type<" + ("cache_t" if gpu else "scalar_t") + ", true>", - name, - gpu=gpu, - precision=64, - ) - - -def acc_cache_tensor_arg(name: str, gpu: bool = True, pass_by_ref: bool = False) -> str: - return _arg( - "at::acc_type<" + ("cache_t" if gpu else "scalar_t") + ", true>", - name, - gpu=gpu, - precision=64, - pass_by_ref=pass_by_ref, - ) - - -def long_tensor_arg_constructor(name: str, gpu: bool = True) -> str: - return _arg_constructor("int64_t", name, gpu=gpu) - - -def long_tensor_arg(name: str, gpu: bool = True, pass_by_ref: bool = False) -> str: - return _arg("int64_t", name, gpu=gpu, pass_by_ref=pass_by_ref) - - -def int_tensor_arg_constructor(name: str, gpu: bool = True) -> str: - return _arg_constructor("int32_t", name, gpu=gpu) - - -def int_tensor_arg(name: str, gpu: bool = True, pass_by_ref: bool = False) -> str: - return _arg("int32_t", name, gpu=gpu, pass_by_ref=pass_by_ref) - - -def tensor_arg(name: str) -> str: - return f"Tensor {name}" - - -def double_arg(name: str, default: float = 0.0) -> str: - return f"double {name} = {default}" - - -def double_arg_no_default(name: str) -> str: - return f"double {name}" - - -def float_arg(name: str, default: float = 0.0) -> str: - return f"float {name} = {default}" - - -def float_arg_no_default(name: str) -> str: - return f"float {name}" - - -def int64_arg(name: str, default: int = 0) -> str: - return f"int64_t {name} = {default}" - - -def int64_arg_no_default(name: str) -> str: - return f"int64_t {name}" - - -def int_arg(name: str, default: int = 0) -> str: - return f"int {name} = {default}" - - -# Format the macro call to generate pta::PackedTensorAccessors -def make_pta_acc_format(pta_str_list: List[str], func_name: str) -> List[str]: - new_str_list = [] - for pta_str 
in pta_str_list: - if "packed_accessor" in pta_str: - match = re.search( - r"([a-zA-z0-9_]*)[.]packed_accessor([3|6][2|4])<(.*)>\(\)", pta_str - ) - assert match is not None and len(match.groups()) == 3 - tensor, acc_nbits, args = match.groups() - if "acc_type" in args: - match = re.search("at::acc_type<([a-zA-Z_]*), true>", args) - assert match is not None and len(match.groups()) == 1 - new_type = match.group(1) - args = re.sub("at::acc_type<[a-zA-Z_]*, true>", new_type, args) - macro_name = "MAKE_PTA_ACC_WITH_NAME" - else: - macro_name = "MAKE_PTA_WITH_NAME" - args = args.replace(", at::RestrictPtrTraits", "") - new_str_list.append( - f"{macro_name}({func_name}, {tensor}, {args}, {acc_nbits})" - ) - else: - new_str_list.append(pta_str) - return new_str_list - - -def replace_pta_namespace(pta_str_list: List[str]) -> List[str]: - return [ - pta_str.replace("at::PackedTensorAccessor", "pta::PackedTensorAccessor") - for pta_str in pta_str_list - ] - - -env.filters["make_pta_acc_format"] = make_pta_acc_format -env.filters["replace_pta_namespace"] = replace_pta_namespace - - -@dataclass -class Args: - split_kernel_args: List[str] - split_kernel_args_no_defaults: List[str] - split_kernel_arg_constructors: List[str] - split_cpu_kernel_args: List[str] - split_cpu_kernel_arg_constructors: List[str] - split_function_args: List[str] - split_function_args_no_defaults: List[str] - split_saved_tensors: List[str] - split_tensors: List[str] - saved_data: List[Tuple[str, str]] - split_function_arg_names: List[str] - split_function_schemas: List[str] - split_variables: List[str] - split_ref_kernel_args: List[str] - - -TENSOR, INT_TENSOR, LONG_TENSOR, INT, FLOAT = range(5) - - -def make_args( - arg_spec: List[Union[Tuple[int, str], Tuple[int, str, Union[float, int]]]] -) -> Dict[str, Any]: - def make_kernel_arg( - ty: int, name: str, default: Union[int, float, None], pass_by_ref: bool = False - ) -> str: - return { - TENSOR: lambda x: acc_cache_tensor_arg(x, pass_by_ref=pass_by_ref), - INT_TENSOR: lambda x: int_tensor_arg(x, pass_by_ref=pass_by_ref), - LONG_TENSOR: lambda x: long_tensor_arg(x, pass_by_ref=pass_by_ref), - INT: ( - (lambda x: int64_arg(x, default=int(default))) - if default is not None - else int64_arg_no_default - ), - FLOAT: ( - (lambda x: float_arg(x, default=default)) - if default is not None - else float_arg_no_default - ), - }[ty](name) - - def make_kernel_arg_constructor(ty: int, name: str) -> str: - return { - TENSOR: acc_cache_tensor_arg_constructor, - INT_TENSOR: int_tensor_arg_constructor, - LONG_TENSOR: long_tensor_arg_constructor, - INT: lambda x: x, - FLOAT: lambda x: x, - }[ty](name) - - def make_cpu_kernel_arg(ty: int, name: str, default: Union[int, float]) -> str: - return { - TENSOR: lambda x: acc_cache_tensor_arg(x, gpu=False), - INT_TENSOR: lambda x: int_tensor_arg(x, gpu=False), - LONG_TENSOR: lambda x: long_tensor_arg(x, gpu=False), - INT: lambda x: int64_arg(x, default=int(default)), - FLOAT: lambda x: float_arg(x, default=default), - }[ty](name) - - def make_cpu_kernel_arg_constructor(ty: int, name: str) -> str: - return { - TENSOR: lambda x: acc_cache_tensor_arg_constructor(x, gpu=False), - INT_TENSOR: lambda x: int_tensor_arg_constructor(x, gpu=False), - LONG_TENSOR: lambda x: long_tensor_arg_constructor(x, gpu=False), - INT: lambda x: x, - FLOAT: lambda x: x, - }[ty](name) - - def make_function_arg( - ty: int, name: str, default: Optional[Union[int, float]] - ) -> str: - return { - TENSOR: tensor_arg, - INT_TENSOR: tensor_arg, - LONG_TENSOR: tensor_arg, - INT: ( - 
(lambda x: int64_arg(x, default=int(default))) - if default is not None - else int64_arg_no_default - ), - FLOAT: ( - (lambda x: double_arg(x, default=default)) - if default is not None - else double_arg_no_default - ), - }[ty](name) - - def make_function_schema_arg(ty: int, name: str, default: Union[int, float]) -> str: - return { - TENSOR: tensor_arg, - INT_TENSOR: tensor_arg, - LONG_TENSOR: tensor_arg, - INT: lambda x: int_arg(x, default=int(default)), - FLOAT: lambda x: float_arg(x, default=default), - }[ty](name) - - def make_ivalue_cast(ty: int) -> str: - return {INT: "toInt", FLOAT: "toDouble"}[ty] - - def make_args_for_compute_device( - split_arg_spec: List[Tuple[int, str, Union[int, float]]] - ) -> Args: - return Args( - split_kernel_args=[ - make_kernel_arg(ty, name, default) - for (ty, name, default) in split_arg_spec - ], - split_kernel_args_no_defaults=[ - make_kernel_arg(ty, name, None) for (ty, name, _) in split_arg_spec - ], - split_kernel_arg_constructors=[ - make_kernel_arg_constructor(ty, name) - for (ty, name, default) in split_arg_spec - ], - split_cpu_kernel_args=[ - make_cpu_kernel_arg(ty, name, default) - for (ty, name, default) in split_arg_spec - ], - split_cpu_kernel_arg_constructors=[ - make_cpu_kernel_arg_constructor(ty, name) - for (ty, name, default) in split_arg_spec - ], - split_function_args=[ - make_function_arg(ty, name, default) - for (ty, name, default) in split_arg_spec - ], - split_function_args_no_defaults=[ - make_function_arg(ty, name, None) - for (ty, name, default) in split_arg_spec - ], - split_tensors=[ - name for (ty, name, default) in augmented_arg_spec if ty == TENSOR - ], - split_saved_tensors=[ - name - for (ty, name, default) in split_arg_spec - if ty in (TENSOR, INT_TENSOR, LONG_TENSOR) - ], - saved_data=[ - (name, make_ivalue_cast(ty)) - for (ty, name, default) in augmented_arg_spec - if ty != TENSOR - ], - split_function_arg_names=[name for (ty, name, default) in split_arg_spec], - split_function_schemas=[ - make_function_schema_arg(ty, name, default) - for (ty, name, default) in split_arg_spec - ], - split_variables=["Variable()" for _ in split_arg_spec], - split_ref_kernel_args=[ - make_kernel_arg(ty, name, default, pass_by_ref=True) - for (ty, name, default) in split_arg_spec - ], - ) - - DEFAULT_ARG_VAL = 0 - augmented_arg_spec = [ - item if len(item) == 3 else (*item, DEFAULT_ARG_VAL) for item in arg_spec - ] - - split_arg_spec = [] - for ty, arg, default in augmented_arg_spec: - if ty in (FLOAT, INT): - split_arg_spec.append((ty, arg, default)) - else: - assert ty == TENSOR - split_arg_spec.extend( - [ - (TENSOR, f"{arg}_host", default), - (INT_TENSOR, f"{arg}_placements", default), - (LONG_TENSOR, f"{arg}_offsets", default), - ] - ) - cpu = make_args_for_compute_device(split_arg_spec) - - split_arg_spec = [] - for ty, arg, default in augmented_arg_spec: - if ty in (FLOAT, INT): - split_arg_spec.append((ty, arg, default)) - else: - assert ty == TENSOR - split_arg_spec.extend( - [ - (TENSOR, f"{arg}_dev", default), - (TENSOR, f"{arg}_uvm", default), - (INT_TENSOR, f"{arg}_placements", default), - (LONG_TENSOR, f"{arg}_offsets", default), - ] - ) - cuda = make_args_for_compute_device(split_arg_spec) - - split_arg_spec = [] - for ty, arg, default in augmented_arg_spec: - if ty in (FLOAT, INT): - split_arg_spec.append((ty, arg, default)) - else: - assert ty == TENSOR - split_arg_spec.extend( - [ - (TENSOR, f"{arg}_host", default), - (TENSOR, f"{arg}_dev", default), - (TENSOR, f"{arg}_uvm", default), - (INT_TENSOR, 
f"{arg}_placements", default), - (LONG_TENSOR, f"{arg}_offsets", default), - ] - ) - any_device = make_args_for_compute_device(split_arg_spec) - - return {"cpu": cpu, "cuda": cuda, "any_device": any_device} - - -###################################################################### -## Optimizer templates ## -###################################################################### - - -def adagrad() -> Dict[str, Any]: - split_weight_update = """ - Vec4T m_t(&momentum1[idx * D + d]); - m_t.acc.x += grad.acc.x * grad.acc.x; - m_t.acc.y += grad.acc.y * grad.acc.y; - m_t.acc.z += grad.acc.z * grad.acc.z; - m_t.acc.w += grad.acc.w * grad.acc.w; - m_t.store(&momentum1[idx * D + d]); - - weight_new.acc.x -= learning_rate * grad.acc.x / (sqrtf(m_t.acc.x) + eps); - weight_new.acc.y -= learning_rate * grad.acc.y / (sqrtf(m_t.acc.y) + eps); - weight_new.acc.z -= learning_rate * grad.acc.z / (sqrtf(m_t.acc.z) + eps); - weight_new.acc.w -= learning_rate * grad.acc.w / (sqrtf(m_t.acc.w) + eps); - """ - split_weight_update_cpu = """ - for (int64_t d = 0; d < D; ++d) { - momentum1_host[embedding_begin + d] += - grad_buffer[d] * grad_buffer[d]; - host_weights_data[embedding_begin + d] -= - learning_rate * grad_buffer[d] / - (sqrt(momentum1_host[embedding_begin + d]) + eps); - } - """ - - return { - "optimizer": "adagrad", - "args": make_args( - [(TENSOR, "momentum1"), (FLOAT, "eps"), (FLOAT, "learning_rate")] - ), - "split_precomputation": "", - "split_weight_update": split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": split_weight_update_cpu, - "has_cpu_support": True, - "has_gpu_support": True, - "has_vbe_support": False, - } - - -def table_info_precomputation(momentum_prefix: str = "momentum1") -> str: - template = """ - // table_begin -> (E, D, {momentum_prefix}_row_begin). - std::map> table_info_map; - for (int64_t t = 0; t < T; ++t) { - const auto D = D_offsets_data[t + 1] - D_offsets_data[t]; - const auto table_begin = weights_offsets_data[t]; - const auto {momentum_prefix}_row_begin = {momentum_prefix}_offsets_data[t]; - table_info_map[table_begin] = std::make_tuple(0, D, {momentum_prefix}_row_begin); - } - int64_t previous_table_begin = host_weights.numel(); - // NOTE: table_info_map is sorted by table_begin! - for (auto it = table_info_map.rbegin(); it != table_info_map.rend(); ++it) { - const auto D = std::get<1>(it->second); - // Calculates number of rows of each table. 
- std::get<0>(it->second) = (previous_table_begin - it->first) / D; - previous_table_begin = it->first; - } - """ - return template.replace("{momentum_prefix}", momentum_prefix) - - -def rowwise_adagrad() -> Dict[str, Any]: - split_weight_update = """ - weight_new.acc.x = correction * weight_new.acc.x - multiplier * grad.acc.x; - weight_new.acc.y = correction * weight_new.acc.y - multiplier * grad.acc.y; - weight_new.acc.z = correction * weight_new.acc.z - multiplier * grad.acc.z; - weight_new.acc.w = correction * weight_new.acc.w - multiplier * grad.acc.w; - """ - split_post_update = """ - if (max_norm > 0.0) { - CUDA_KERNEL_ASSERT(!(std::is_same::value && !cache_weights)); // not supported for uint8 yet - - // compute weight norm - at::acc_type weight_sum_square = 0.0; - for (int32_t vec = 0; - vec < max_vecs && (kThreadGroupSize * vec + threadIdx.x) * VEC_WIDTH < D; - ++vec) { - const int32_t d = (kThreadGroupSize * vec + threadIdx.x) * VEC_WIDTH; - Vec4TAcc weight_new = weight_row_template.load(d, qparams_template); - weight_sum_square - += weight_new.acc.x * weight_new.acc.x - + weight_new.acc.y * weight_new.acc.y - + weight_new.acc.z * weight_new.acc.z - + weight_new.acc.w * weight_new.acc.w; - } - const at::acc_type weight_norm = - sqrtf(warpReduceAllSum, kThreadGroupSize>(weight_sum_square, shfl_sync_mask)); - - // scale by max_norm if weight_norm exceeds max_norm - if (threadIdx.x == 0) { - multiplier = weight_norm > max_norm ? max_norm / weight_norm : 1.0f; - } - multiplier = SHFL_SYNC(multiplier, 0); - if (weight_norm > max_norm) { - for (int32_t vec = 0; - vec < max_vecs && (kThreadGroupSize * vec + threadIdx.x) * VEC_WIDTH < D; - ++vec) { - const int32_t d = (kThreadGroupSize * vec + threadIdx.x) * VEC_WIDTH; - Vec4TAcc weight_new = weight_row_template.load(d, qparams_template); - - weight_new.acc.x *= multiplier; - weight_new.acc.y *= multiplier; - weight_new.acc.z *= multiplier; - weight_new.acc.w *= multiplier; - weight_row_template.store(weight_new, d, qparams_new); // qparams_new not used if embedding is not int8 - } - } - } - """ - split_precomputation = """ - at::acc_type g_local_sum_square = 0.0; - """ - split_precomputation += generate_optimized_grad_sum_loop_access( - """ - const float4* grad = &{grad_vec}.acc; - auto gx = grad->x; - auto gy = grad->y; - auto gz = grad->z; - auto gw = grad->w; - if (weight_decay_mode == 1) { - // L2 regularization - Vec4TAcc weight = weight_row_template.load(d, qparams_template); - gx += weight_decay * weight.acc.x; - gy += weight_decay * weight.acc.y; - gz += weight_decay * weight.acc.z; - gw += weight_decay * weight.acc.w; - } - g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw; - """ - ) - split_precomputation += """ - const at::acc_type g_avg_square = - warpReduceAllSum, kThreadGroupSize>(g_local_sum_square, shfl_sync_mask) / D; - - at::acc_type multiplier; - at::acc_type correction; - if (threadIdx.x == 0) { - at::acc_type new_sum_square_grads = momentum1[idx] + g_avg_square; - momentum1[idx] = new_sum_square_grads; - multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); - if (weight_decay_mode == 1) { - // L2 regularization - correction = 1.0 - multiplier * weight_decay; - } else if (weight_decay_mode == 2) { - // Decoupled weight decay - correction = 1.0 - learning_rate * weight_decay; - } else { - // default value - correction = 1.0; - } - } - multiplier = SHFL_SYNC(multiplier, 0); - correction = SHFL_SYNC(correction, 0); - """ - split_weight_update_cpu = """ - at::acc_type g_local_sum_square = 0.0; - 
for (int64_t d = 0; d < D; ++d) { - auto grad = grad_buffer[d]; - if (weight_decay_mode == 1) { - // L2 regularization - grad += weight_decay * host_weights_data[embedding_begin + d]; - } - g_local_sum_square += grad * grad; - } - auto g_avg_square = g_local_sum_square / D; - at::acc_type new_sum_square_grads = momentum1_host[momentum1_offsets_data[feature_begin] + idx] + g_avg_square; - momentum1_host[momentum1_offsets_data[feature_begin] + idx] = new_sum_square_grads; - at::acc_type multiplier; - multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); - at::acc_type correction; - if (weight_decay_mode == 1) { - // L2 regularization - correction = 1.0 - multiplier * weight_decay; - } else if (weight_decay_mode == 2) { - // Decoupled weight decay - correction = 1.0 - learning_rate * weight_decay; - } else { - // default value - correction = 1.0; - } - for (int64_t d = 0; d < D; ++d) { - host_weights_data[embedding_begin + d] = correction * host_weights_data[embedding_begin + d] - grad_buffer[d] * multiplier; - } - """ - - return { - "optimizer": "rowwise_adagrad", - "args": make_args( - [ - (TENSOR, "momentum1"), - (FLOAT, "eps"), - (FLOAT, "learning_rate"), - (FLOAT, "weight_decay", 0.0), - (INT, "weight_decay_mode", 0), - (FLOAT, "max_norm", 0.0), - ] - ), - "split_precomputation": split_precomputation, - "split_weight_update": split_weight_update, - "split_post_update": split_post_update, - "split_weight_update_cpu": split_weight_update_cpu, - "has_cpu_support": True, - "has_gpu_support": True, - "has_vbe_support": True, - } - - -def approx_rowwise_adagrad() -> Dict[str, Any]: - rowwise_adagrad_args = rowwise_adagrad() - - approx_split_weight_update = """ - // dummy computation to avoid unused variable warning - weight_new.fma_(grad, -multiplier); - assert(false); // approx rowwise AdaGrad is not supported on GPU - """ - - return { - "optimizer": "approx_rowwise_adagrad", - "args": make_args( - [ - (TENSOR, "momentum1"), - (FLOAT, "eps"), - (FLOAT, "learning_rate"), - (FLOAT, "weight_decay", 0.0), - (INT, "weight_decay_mode", 0), - ] - ), - "split_precomputation": rowwise_adagrad_args["split_precomputation"], - "split_weight_update": approx_split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": rowwise_adagrad_args["split_weight_update_cpu"], - "has_cpu_support": False, - "has_gpu_support": False, - "has_vbe_support": False, - } - - -# Deprecated, to be cleaned up -def rowwise_adagrad_with_weight_decay() -> Dict[str, Any]: - split_weight_update = """ - weight_new.acc.x = correction * weight_new.acc.x - multiplier * grad.acc.x; - weight_new.acc.y = correction * weight_new.acc.y - multiplier * grad.acc.y; - weight_new.acc.z = correction * weight_new.acc.z - multiplier * grad.acc.z; - weight_new.acc.w = correction * weight_new.acc.w - multiplier * grad.acc.w; - """ - split_precomputation = """ - at::acc_type g_local_sum_square = 0.0; - """ - split_precomputation += generate_optimized_grad_sum_loop_access( - """ - const float4* grad = &{grad_vec}.acc; - auto gx = grad->x; - auto gy = grad->y; - auto gz = grad->z; - auto gw = grad->w; - if (weight_decay_mode == 1) { - // L2 regularization - Vec4TAcc weight = weight_row_template.load(d, qparams_template); - gx += weight_decay * weight.acc.x; - gy += weight_decay * weight.acc.y; - gz += weight_decay * weight.acc.z; - gw += weight_decay * weight.acc.w; - } - g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw; - """ - ) - split_precomputation += """ - const at::acc_type g_avg_square = - warpReduceAllSum, 
kThreadGroupSize>(g_local_sum_square, shfl_sync_mask) / D; - - at::acc_type multiplier; - at::acc_type correction; - if (threadIdx.x == 0) { - at::acc_type new_sum_square_grads = momentum1[idx] + g_avg_square; - momentum1[idx] = new_sum_square_grads; - multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); - if (weight_decay_mode == 1) { - // L2 regularization - correction = 1.0 - multiplier * weight_decay; - } else if (weight_decay_mode == 2) { - // Decoupled weight decay - correction = 1.0 - learning_rate * weight_decay; - } else { - // default value - correction = 1.0; - } - } - multiplier = SHFL_SYNC(multiplier, 0); - correction = SHFL_SYNC(correction, 0); - """ - split_weight_update_cpu = """ - at::acc_type g_local_sum_square = 0.0; - for (int64_t d = 0; d < D; ++d) { - auto grad = grad_buffer[d]; - if (weight_decay_mode == 1) { - // L2 regularization - grad += weight_decay * host_weights_data[embedding_begin + d]; - } - g_local_sum_square += grad * grad; - } - auto g_avg_square = g_local_sum_square / D; - at::acc_type new_sum_square_grads = momentum1_host[momentum1_offsets_data[feature_begin] + idx] + g_avg_square; - momentum1_host[momentum1_offsets_data[feature_begin] + idx] = new_sum_square_grads; - at::acc_type multiplier; - multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); - at::acc_type correction; - if (weight_decay_mode == 1) { - // L2 regularization - correction = 1.0 - multiplier * weight_decay; - } else if (weight_decay_mode == 2) { - // Decoupled weight decay - correction = 1.0 - learning_rate * weight_decay; - } else { - // default value - correction = 1.0; - } - for (int64_t d = 0; d < D; ++d) { - host_weights_data[embedding_begin + d] = correction * host_weights_data[embedding_begin + d] - grad_buffer[d] * multiplier; - } - """ - - return { - "optimizer": "rowwise_adagrad_with_weight_decay", - "args": make_args( - [ - (TENSOR, "momentum1"), - (FLOAT, "eps"), - (FLOAT, "learning_rate"), - (FLOAT, "weight_decay", 0.0), - (INT, "weight_decay_mode", 0), - ] - ), - "split_precomputation": split_precomputation, - "split_weight_update": split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": split_weight_update_cpu, - "has_cpu_support": False, - "has_gpu_support": False, - "has_vbe_support": False, - } - - -# Deprecated, to be cleaned up -def approx_rowwise_adagrad_with_weight_decay() -> Dict[str, Any]: - rowwise_adagrad_with_weight_decay_args = rowwise_adagrad_with_weight_decay() - - approx_split_weight_update = """ - // dummy computation to avoid unused variable warning - weight_new.fma_(grad, -multiplier); - assert(false); // approx rowwise AdaGrad is not supported on GPU - """ - - return { - "optimizer": "approx_rowwise_adagrad_with_weight_decay", - "args": make_args( - [ - (TENSOR, "momentum1"), - (FLOAT, "eps"), - (FLOAT, "learning_rate"), - (FLOAT, "weight_decay", 0.0), - (INT, "weight_decay_mode", 0), - ] - ), - "split_precomputation": rowwise_adagrad_with_weight_decay_args[ - "split_precomputation" - ], - "split_weight_update": approx_split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": rowwise_adagrad_with_weight_decay_args[ - "split_weight_update_cpu" - ], - "has_cpu_support": False, - "has_gpu_support": False, - "has_vbe_support": False, - } - - -def rowwise_adagrad_with_counter() -> Dict[str, Any]: - split_weight_update = """ - weight_new.acc.x = (exp_reg_correction * weight_new.acc.x - adjusted_multiplier * grad.acc.x); - weight_new.acc.y = (exp_reg_correction * weight_new.acc.y - 
adjusted_multiplier * grad.acc.y); - weight_new.acc.z = (exp_reg_correction * weight_new.acc.z - adjusted_multiplier * grad.acc.z); - weight_new.acc.w = (exp_reg_correction * weight_new.acc.w - adjusted_multiplier * grad.acc.w); - """ - split_precomputation = """ - at::acc_type freq = 1.0; - at::acc_type tail_id_threshold_val = tail_id_threshold; - CUDA_KERNEL_ASSERT(max_counter != 0.0); // avoid divide by zero error - if (is_tail_id_thresh_ratio == 1){ - tail_id_threshold_val = floorf(tail_id_threshold * max_counter); - } - if (threadIdx.x == 0) { - if (counter_halflife > 0) { // decay based on counter_halflife - // if id occurs multiple times in a batch, iter_delta=1 - const auto iter_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx]; - prev_iter[idx] = iter * 1.0; - const auto counter_log_rho = logf(2.0) / counter_halflife; - row_counter[idx] = 1.0 + expf(-iter_delta * counter_log_rho) * row_counter[idx]; - } else if (counter_halflife == 0) { // count only 1 (appear or not) - row_counter[idx] = 1.0; - } else { // count raw appearance without decaying - row_counter[idx] += 1.0; - } - freq = counter_halflife / row_counter[idx]; - } - freq = SHFL_SYNC(freq, 0); - tail_id_threshold_val = SHFL_SYNC(tail_id_threshold_val, 0); - - at::acc_type g_local_sum_square = 0.0; - at::acc_type w_local_sum_square = 0.0; - """ - split_precomputation += generate_optimized_grad_sum_loop_access( - """ - const float4* grad = &{grad_vec}.acc; - auto gx = grad->x; - auto gy = grad->y; - auto gz = grad->z; - auto gw = grad->w; - - Vec4TAcc weight = weight_row_template.load(d, qparams_template); - - // for L2 regularization (weight_decay_mode=1) - // add weight_decay to gradient before other computation - if (weight_decay_mode == 1) { - gx += weight_decay * weight.acc.x; - gy += weight_decay * weight.acc.y; - gz += weight_decay * weight.acc.z; - gw += weight_decay * weight.acc.w; - } - g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw; - - // cow_clip (regularization_mode=4) requires weight norm - if (regularization_mode == 4) { - w_local_sum_square += weight.acc.x * weight.acc.x + weight.acc.y * weight.acc.y + weight.acc.z * weight.acc.z + weight.acc.w * weight.acc.w; - } - """ - ) - split_precomputation += """ - const at::acc_type g_sum_square = - warpReduceAllSum, kThreadGroupSize>(g_local_sum_square, shfl_sync_mask); - const at::acc_type g_avg_square = g_sum_square / D; - const at::acc_type w_sum_square = - warpReduceAllSum, kThreadGroupSize>(w_local_sum_square, shfl_sync_mask); - - at::acc_type adjusted_multiplier; - at::acc_type exp_reg_correction; - - if (threadIdx.x == 0) { - at::acc_type new_sum_square_grads = momentum1[idx] + g_avg_square; - momentum1[idx] = new_sum_square_grads; - const auto multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); - const auto adjustment_enabled = adjustment_iter <= 0 || (adjustment_iter > 0 && iter > adjustment_iter); - - if (regularization_mode == 3) { // counter-based regularization (regularization_mode=3) - adjusted_multiplier = multiplier; - if (learning_rate_mode >=0 && adjustment_enabled) { - if (row_counter[idx] > tail_id_threshold_val) { - if ( learning_rate_mode == 0 ) { - adjusted_multiplier = multiplier * max(min(powf(max_counter/(row_counter[idx] + 1.0), adjustment_ub), 10.0), 1.0); - } else if ( learning_rate_mode == 1 ) { - adjusted_multiplier = multiplier * min(max(powf((row_counter[idx] + 1.0)/max_counter, adjustment_ub), 0.1), 1.0); - } else if (learning_rate_mode == 2) { - adjusted_multiplier = learning_rate / 
(sqrtf(adjustment_ub*row_counter[idx]) + eps); - } - } - } - } else if (regularization_mode == 4) { // cow-clip (regularization_mode=4) - const auto clip_thresh = row_counter[idx] * max(weight_norm_coefficient * sqrtf(w_sum_square), lower_bound); - adjusted_multiplier = min(1.0f, clip_thresh / sqrtf(g_sum_square)) * multiplier; - } - - exp_reg_correction = 1.0; - if (regularization_mode == 3) { // counter-based regularization (regularization_mode=3) - if (adjustment_enabled) { - if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2) - exp_reg_correction = 1.0 - freq * weight_decay * learning_rate; - } else if (weight_decay_mode == 1) { // L2 regularization (coupled wd) - exp_reg_correction = 1.0 - freq * weight_decay * multiplier; - } - } - } else if (regularization_mode == 4) { // cow-clip (regularization_mode=4) - if (weight_decay_mode == 2) { // Decoupled weight decay (weight_decay_mode=2) - exp_reg_correction = 1.0 - weight_decay * learning_rate; - } else if (weight_decay_mode == 1) { // L2 regularization (coupled wd) - exp_reg_correction = 1.0 - weight_decay * adjusted_multiplier; - } - } - } - adjusted_multiplier = SHFL_SYNC(adjusted_multiplier, 0); - exp_reg_correction = SHFL_SYNC(exp_reg_correction, 0); - """ - split_weight_update_cpu = """ - at::acc_type g_local_sum_square = 0.0; - for (int64_t d = 0; d < D; ++d) { - g_local_sum_square += grad_buffer[d] * grad_buffer[d]; - } - auto g_avg_square = g_local_sum_square / D; - auto offset_idx = momentum1_offsets_data[feature_begin] + idx; - at::acc_type new_sum_square_grads = momentum1_host[offset_idx] + g_avg_square; - momentum1_host[offset_idx] = new_sum_square_grads; - at::acc_type multiplier; - multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); - const auto iter_delta = iter * 1.0 - prev_iter_host[offset_idx]; - prev_iter_host[offset_idx] = iter * 1.0; - const auto exp_reg = 1.0 / (weight_decay * multiplier + 1.0); - const auto exp_reg_correction = powf(exp_reg, iter_delta); - for (int64_t d = 0; d < D; ++d) { - const auto weight = host_weights_data[embedding_begin + d]; - host_weights_data[embedding_begin + d] = exp_reg_correction * weight - exp_reg * multiplier * grad_buffer[d]; - } - """ - - return { - "optimizer": "rowwise_adagrad_with_counter", - "args": make_args( - [ - (TENSOR, "momentum1"), - (TENSOR, "prev_iter"), - (TENSOR, "row_counter"), - (FLOAT, "eps"), - (FLOAT, "learning_rate"), - (FLOAT, "weight_decay", 0.0), - (INT, "iter"), - (INT, "counter_halflife", -1), - (INT, "adjustment_iter", -1), - (FLOAT, "adjustment_ub", 1.0), - (INT, "learning_rate_mode", -1), - (INT, "weight_decay_mode", 1), - (INT, "grad_sum_decay", -1), - (FLOAT, "max_counter"), - (FLOAT, "tail_id_threshold", 0.0), - (INT, "is_tail_id_thresh_ratio", 0), - (INT, "regularization_mode", 0), - (FLOAT, "weight_norm_coefficient", 0.0), - (FLOAT, "lower_bound", 0.0), - ] - ), - "split_precomputation": split_precomputation, - "split_weight_update": split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": split_weight_update_cpu, - "has_cpu_support": False, - "has_gpu_support": True, - "has_vbe_support": True, - } - - -def approx_rowwise_adagrad_with_counter() -> Dict[str, Any]: - rowwise_adagrad_with_counter_args = rowwise_adagrad_with_counter() - - approx_split_weight_update = """ - // dummy computation to avoid unused variable warning - weight_new.fma_(grad, -learning_rate); - assert(false); // approx rowwise AdaGrad is not supported on GPU - """ - - return { - "optimizer": 
"approx_rowwise_adagrad_with_counter", - "args": make_args( - [ - (TENSOR, "momentum1"), - (TENSOR, "prev_iter"), - (TENSOR, "row_counter"), - (FLOAT, "eps"), - (FLOAT, "learning_rate"), - (FLOAT, "weight_decay", 0.0), - (INT, "iter"), - (INT, "counter_halflife", -1), - (INT, "adjustment_iter", -1), - (FLOAT, "adjustment_ub", 1.0), - (INT, "learning_rate_mode", -1), - (INT, "weight_decay_mode", 1), - (INT, "grad_sum_decay", -1), - (FLOAT, "max_counter"), - (FLOAT, "tail_id_threshold", 0.0), - (INT, "is_tail_id_thresh_ratio", 0), - (INT, "regularization_mode", 0), - (FLOAT, "weight_norm_coefficient", 0.0), - (FLOAT, "lower_bound", 0.0), - ] - ), - "split_precomputation": rowwise_adagrad_with_counter_args[ - "split_precomputation" - ], - "split_weight_update": approx_split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": rowwise_adagrad_with_counter_args[ - "split_weight_update_cpu" - ], - "has_cpu_support": False, - "has_gpu_support": False, - "has_vbe_support": False, - } - - -# Deprecated, to be cleaned up -def rowwise_weighted_adagrad() -> Dict[str, Any]: - split_weight_update = """ - weight_new.acc.x = correction * weight_new.acc.x - multiplier * grad.acc.x; - weight_new.acc.y = correction * weight_new.acc.y - multiplier * grad.acc.y; - weight_new.acc.z = correction * weight_new.acc.z - multiplier * grad.acc.z; - weight_new.acc.w = correction * weight_new.acc.w - multiplier * grad.acc.w; - """ - split_precomputation = """ - at::acc_type g_local_sum_square = 0.0; - """ - split_precomputation += generate_optimized_grad_sum_loop_access( - """ - const float4* grad = &{grad_vec}.acc; - Vec4TAcc weight = weight_row_template.load(d, qparams_template); - auto gx = grad->x + weight_decay * weight.acc.x; - auto gy = grad->y + weight_decay * weight.acc.y; - auto gz = grad->z + weight_decay * weight.acc.z; - auto gw = grad->w + weight_decay * weight.acc.w; - g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw; - """ - ) - split_precomputation += """ - const at::acc_type g_avg_square = - warpReduceAllSum, kThreadGroupSize>(g_local_sum_square, shfl_sync_mask) / D; - - at::acc_type multiplier; - at::acc_type correction; - if (threadIdx.x == 0) { - at::acc_type lambda = sqrtf(iter + 1); - at::acc_type new_sum_square_grads = momentum1[idx] + lambda * g_avg_square; - momentum1[idx] = new_sum_square_grads; - multiplier = learning_rate * lambda / (cbrtf(new_sum_square_grads) + eps); - correction = 1.0 - multiplier * weight_decay; - } - multiplier = SHFL_SYNC(multiplier, 0); - correction = SHFL_SYNC(correction, 0); - """ - split_weight_update_cpu = """ - // weight_decay not supported for cpu version - at::acc_type g_local_sum_square = 0.0; - for (int64_t d = 0; d < D; ++d) { - g_local_sum_square += grad_buffer[d] * grad_buffer[d]; - } - auto g_avg_square = g_local_sum_square / D; - at::acc_type lambda = sqrtf(iter + 1); - at::acc_type new_sum_square_grads = momentum1_host[momentum1_offsets_data[feature_begin] + idx] + lambda * g_avg_square; - momentum1_host[momentum1_offsets_data[feature_begin] + idx] = new_sum_square_grads; - at::acc_type multiplier; - multiplier = learning_rate * lambda / (cbrtf(new_sum_square_grads) + eps); - for (int64_t d = 0; d < D; ++d) { - host_weights_data[embedding_begin + d] -= grad_buffer[d] * multiplier; - } - """ - - return { - "optimizer": "rowwise_weighted_adagrad", - "is_experimental_optimizer": True, - "args": make_args( - [ - (TENSOR, "momentum1"), - (FLOAT, "eps"), - (FLOAT, "learning_rate"), - (FLOAT, "weight_decay"), - (INT, "iter"), - ] - 
), - "split_precomputation": split_precomputation, - "split_weight_update": split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": split_weight_update_cpu, - "has_cpu_support": False, - "has_gpu_support": False, - "has_vbe_support": False, - } - - -def sgd() -> Dict[str, Any]: - split_weight_update = """ - weight_new.fma_(grad, -learning_rate); - """ - split_weight_update_cpu = """ - for (int64_t d = 0; d < D; ++d) { - host_weights_data[embedding_begin + d] -= learning_rate * grad_buffer[d]; - } - """ - - return { - "optimizer": "sgd", - "args": make_args([(FLOAT, "learning_rate")]), - "split_precomputation": "", - "split_weight_update": split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": split_weight_update_cpu, - "has_cpu_support": True, - "has_gpu_support": True, - "has_vbe_support": True, - } - - -def approx_sgd() -> Dict[str, Any]: - sgd_args = sgd() - - approx_split_weight_update = """ - // approx_sgd not supported for GPU. - // Just do the same thing as exact sgd to avoid unused variable warning. - weight_new.fma_(grad, -learning_rate); - assert(false); // approx SGD is not supported on GPU - """ - - return { - "optimizer": "approx_sgd", - "args": make_args([(FLOAT, "learning_rate")]), - "split_precomputation": "", - "split_weight_update": approx_split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": sgd_args["split_weight_update_cpu"], - "has_cpu_support": False, - "has_gpu_support": False, - "has_vbe_support": False, - } - - -def lamb() -> Dict[str, Any]: - split_precomputation = """ - at::acc_type weight_sum_sq = 0.0; - at::acc_type rtw_sum_sq = 0.0; - auto weight_row = WeightRow>(weights, cache_weights, D); - float2 qparams; - if (std::is_same::value && !cache_weights) { - qparams = weight_row.load_qparams(); - } - """ - split_precomputation += generate_optimized_grad_sum_loop_access( - """ - float4* grad = &{grad_vec}.acc; - - Vec4TAcc weight = weight_row.load(d, qparams); - Vec4TAcc m1(&momentum1[idx * D + d]); - - m1.acc.x = beta1 * m1.acc.x + (1.0 - beta1) * grad->x; - m1.acc.y = beta1 * m1.acc.y + (1.0 - beta1) * grad->y; - m1.acc.z = beta1 * m1.acc.z + (1.0 - beta1) * grad->z; - m1.acc.w = beta1 * m1.acc.w + (1.0 - beta1) * grad->w; - m1.store(&momentum1[idx * D + d]); - - Vec4TAcc m2(&momentum2[idx * D + d]); - m2.acc.x = beta2 * m2.acc.x + (1.0 - beta2) * grad->x * grad->x; - m2.acc.y = beta2 * m2.acc.y + (1.0 - beta2) * grad->y * grad->y; - m2.acc.z = beta2 * m2.acc.z + (1.0 - beta2) * grad->z * grad->z; - m2.acc.w = beta2 * m2.acc.w + (1.0 - beta2) * grad->w * grad->w; - m2.store(&momentum2[idx * D + d]); - - // now, we are finished with grad_sum. 
We can *reuse* grad_sum to store r_t + weight_decay * weight; - grad->x = (m1.acc.x / (1.0 - powf(beta1, iter))) / (sqrtf((m2.acc.x / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight.acc.x; - grad->y = (m1.acc.y / (1.0 - powf(beta1, iter))) / (sqrtf((m2.acc.y / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight.acc.y; - grad->z = (m1.acc.z / (1.0 - powf(beta1, iter))) / (sqrtf((m2.acc.z / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight.acc.z; - grad->w = (m1.acc.w / (1.0 - powf(beta1, iter))) / (sqrtf((m2.acc.w / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight.acc.w; - - weight_sum_sq += weight.acc.x * weight.acc.x + weight.acc.y * weight.acc.y + weight.acc.z * weight.acc.z + weight.acc.w * weight.acc.w; - rtw_sum_sq += grad->x * grad->x + grad->y * grad->y + grad->z * grad->z + grad->w * grad->w; - """ - ) - split_precomputation += """ - const auto weight_norm = - sqrtf(warpReduceAllSum, kThreadGroupSize>(weight_sum_sq, shfl_sync_mask)); - const auto rtw_norm = - sqrtf(warpReduceAllSum, kThreadGroupSize>(rtw_sum_sq, shfl_sync_mask)); - const auto true_ratio = weight_norm / rtw_norm; - """ - split_weight_update = """ - weight_new.fma_(grad, -learning_rate * true_ratio); - """ - split_weight_update_cpu = "" - - return { - "optimizer": "lamb", - "is_experimental_optimizer": True, - "args": make_args( - [ - (TENSOR, "momentum1"), - (TENSOR, "momentum2"), - (FLOAT, "learning_rate"), - (FLOAT, "eps"), - (FLOAT, "beta1"), - (FLOAT, "beta2"), - (FLOAT, "weight_decay"), - (INT, "iter"), - ] - ), - "split_precomputation": split_precomputation, - "split_weight_update": split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": split_weight_update_cpu, - "has_cpu_support": False, - "has_gpu_support": True, - "has_vbe_support": False, - } - - -def partial_rowwise_lamb() -> Dict[str, Any]: - split_precomputation = """ - at::acc_type g_local_sum_square = 0.0; - """ - split_precomputation += generate_optimized_grad_sum_loop_access( - """ - const float4* grad = &{grad_vec}.acc; - g_local_sum_square += grad->x * grad->x + - grad->y * grad->y + - grad->z * grad->z + - grad->w * grad->w; - """ - ) - split_precomputation += """ - const at::acc_type g_avg_square = - warpReduceAllSum, kThreadGroupSize>(g_local_sum_square, shfl_sync_mask) / D; - - at::acc_type m2; - if (threadIdx.x == 0) { - m2 = beta2 * momentum2[idx] + (1.0 - beta2) * g_avg_square; - momentum2[idx] = m2; - } - m2 = SHFL_SYNC(m2, 0); - at::acc_type m2_hat = 1.0 / (sqrtf((m2 / (1.0 - powf(beta2, iter)))) + eps); - - at::acc_type weight_sum_sq = 0.0; - at::acc_type rtw_sum_sq = 0.0; - auto weight_row = WeightRow>(weights, cache_weights, D); - float2 qparams; - if (std::is_same::value && !cache_weights) { - qparams = weight_row.load_qparams(); - } - """ - split_precomputation += generate_optimized_grad_sum_loop_access( - """ - float4* grad = &{grad_vec}.acc; - Vec4TAcc m1(&momentum1[idx * D + d]); - m1.acc.x = beta1 * m1.acc.x + (1.0 - beta1) * grad->x; - m1.acc.y = beta1 * m1.acc.y + (1.0 - beta1) * grad->y; - m1.acc.z = beta1 * m1.acc.z + (1.0 - beta1) * grad->z; - m1.acc.w = beta1 * m1.acc.w + (1.0 - beta1) * grad->w; - m1.store(&momentum1[idx * D + d]); - - // now, we are finished with grad_sum. 
We can *reuse* grad_sum to store r_t + weight_decay * weight; - Vec4TAcc weight = weight_row.load(d, qparams); - grad->x = (m1.acc.x / (1.0 - powf(beta1, iter))) * m2_hat + weight_decay * weight.acc.x; - grad->y = (m1.acc.y / (1.0 - powf(beta1, iter))) * m2_hat + weight_decay * weight.acc.y; - grad->z = (m1.acc.z / (1.0 - powf(beta1, iter))) * m2_hat + weight_decay * weight.acc.z; - grad->w = (m1.acc.w / (1.0 - powf(beta1, iter))) * m2_hat + weight_decay * weight.acc.w; - - weight_sum_sq += weight.acc.x * weight.acc.x + weight.acc.y * weight.acc.y + weight.acc.z * weight.acc.z + weight.acc.w * weight.acc.w; - rtw_sum_sq += grad->x * grad->x + grad->y * grad->y + grad->z * grad->z + grad->w * grad->w; - """ - ) - split_precomputation += """ - const auto weight_norm = - sqrtf(warpReduceAllSum, kThreadGroupSize>(weight_sum_sq)); - const auto rtw_norm = - sqrtf(warpReduceAllSum, kThreadGroupSize>(rtw_sum_sq)); - const auto true_ratio = weight_norm / rtw_norm; - """ - - split_weight_update = """ - weight_new.fma_(grad, -learning_rate * true_ratio); - """ - split_weight_update_cpu = "" # TODO - - return { - "optimizer": "partial_rowwise_lamb", - "args": make_args( - [ - (TENSOR, "momentum1"), - (TENSOR, "momentum2"), - (FLOAT, "learning_rate"), - (FLOAT, "eps"), - (FLOAT, "beta1"), - (FLOAT, "beta2"), - (FLOAT, "weight_decay"), - (INT, "iter"), - ] - ), - "split_precomputation": split_precomputation, - "split_weight_update": split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": split_weight_update_cpu, - "has_cpu_support": False, - "has_gpu_support": True, - "has_vbe_support": False, - } - - -def adam() -> Dict[str, Any]: - split_weight_update = """ - Vec4T m_t(&momentum1[idx * D + d]); - m_t.acc.x *= beta1; - m_t.acc.y *= beta1; - m_t.acc.z *= beta1; - m_t.acc.w *= beta1; - m_t.fma_(grad, 1.0 - beta1); - m_t.store(&momentum1[idx * D + d]); - - Vec4T v_t(&momentum2[idx * D + d]); - v_t.acc.x *= beta2; - v_t.acc.y *= beta2; - v_t.acc.z *= beta2; - v_t.acc.w *= beta2; - - grad.acc.x *= grad.acc.x; - grad.acc.y *= grad.acc.y; - grad.acc.z *= grad.acc.z; - grad.acc.w *= grad.acc.w; - v_t.fma_(grad, 1.0 - beta2); - v_t.store(&momentum2[idx * D + d]); - - weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.x / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.x); - weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.y / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.y); - weight_new.acc.z -= learning_rate * (m_t.acc.z / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.z / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.z); - weight_new.acc.w -= learning_rate * (m_t.acc.w / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.w / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.w); - """ - split_weight_update_cpu = "" # TODO - - return { - "optimizer": "adam", - "is_experimental_optimizer": True, - "args": make_args( - [ - (TENSOR, "momentum1"), - (TENSOR, "momentum2"), - (FLOAT, "learning_rate"), - (FLOAT, "eps"), - (FLOAT, "beta1"), - (FLOAT, "beta2"), - (FLOAT, "weight_decay"), - (INT, "iter"), - ] - ), - "split_precomputation": "", - "split_weight_update": split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": split_weight_update_cpu, - "has_cpu_support": False, - "has_gpu_support": True, - "has_vbe_support": False, - } - - -def partial_rowwise_adam() -> Dict[str, Any]: - split_precomputation = """ - at::acc_type 
g_local_sum_square = 0.0; - """ - split_precomputation += generate_optimized_grad_sum_loop_access( - """ - const float4* grad = &{grad_vec}.acc; - g_local_sum_square += grad->x * grad->x + - grad->y * grad->y + - grad->z * grad->z + - grad->w * grad->w; - """ - ) - split_precomputation += """ - const at::acc_type g_avg_square = - warpReduceAllSum, kThreadGroupSize>(g_local_sum_square) / D; - - at::acc_type v_hat_t; - if (threadIdx.x == 0) { - at::acc_type v_t = momentum2[idx] * beta2 + g_avg_square * (1.0 - beta2); - momentum2[idx] = v_t; - v_hat_t = v_t / (1.0 - powf(beta2, iter)); - } - v_hat_t = SHFL_SYNC(v_hat_t, 0); - """ - - split_weight_update = """ - Vec4T m_t(&momentum1[idx * D + d]); - m_t.acc.x *= beta1; - m_t.acc.y *= beta1; - m_t.acc.z *= beta1; - m_t.acc.w *= beta1; - m_t.fma_(grad, 1.0 - beta1); - m_t.store(&momentum1[idx * D + d]); - - weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.x); - weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.y); - weight_new.acc.z -= learning_rate * (m_t.acc.z / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.z); - weight_new.acc.w -= learning_rate * (m_t.acc.w / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.w); - """ - split_weight_update_cpu = "" # TODO - - return { - "optimizer": "partial_rowwise_adam", - "args": make_args( - [ - (TENSOR, "momentum1"), - (TENSOR, "momentum2"), - (FLOAT, "learning_rate"), - (FLOAT, "eps"), - (FLOAT, "beta1"), - (FLOAT, "beta2"), - (FLOAT, "weight_decay"), - (INT, "iter"), - ] - ), - "split_precomputation": split_precomputation, - "split_weight_update": split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": split_weight_update_cpu, - "has_cpu_support": False, - "has_gpu_support": True, - "has_vbe_support": False, - } - - -def lars_sgd() -> Dict[str, Any]: - split_precomputation = """ - at::acc_type weight_sum_sq = 0.0; - at::acc_type grad_sum_sq = 0.0; - - auto weight_row = WeightRow>(weights, cache_weights, D); - float2 qparams; - if (std::is_same::value && !cache_weights) { - qparams = weight_row.load_qparams(); - } - """ - split_precomputation += generate_optimized_grad_sum_loop_access( - """ - const float4* grad = &{grad_vec}.acc; - Vec4TAcc weight = weight_row.load(d, qparams); - weight_sum_sq += weight.acc.x * weight.acc.x + weight.acc.y * weight.acc.y + weight.acc.z * weight.acc.z + weight.acc.w * weight.acc.w; - grad_sum_sq += grad->x * grad->x + grad->y * grad->y + grad->z * grad->z + grad->w * grad->w; - """ - ) - split_precomputation += """ - const auto weight_norm = - sqrtf(warpReduceAllSum, kThreadGroupSize>(weight_sum_sq)); - const auto grad_norm = - sqrtf(warpReduceAllSum, kThreadGroupSize>(grad_sum_sq)); - const at::acc_type adjusted_lr = learning_rate * eta * weight_norm / (grad_norm + weight_decay * weight_norm); - """ - - split_weight_update = """ - Vec4T m1(&momentum1[idx * D + d]); - m1.acc.x = momentum * m1.acc.x + adjusted_lr * (grad.acc.x + weight_decay * weight_new.acc.x); - m1.acc.y = momentum * m1.acc.y + adjusted_lr * (grad.acc.y + weight_decay * weight_new.acc.y); - m1.acc.z = momentum * m1.acc.z + adjusted_lr * (grad.acc.z + weight_decay * weight_new.acc.z); - m1.acc.w = momentum * m1.acc.w + adjusted_lr * (grad.acc.w + weight_decay * weight_new.acc.w); - m1.store(&momentum1[idx * D + d]); - - weight_new.acc.x -= m1.acc.x; - 
weight_new.acc.y -= m1.acc.y; - weight_new.acc.z -= m1.acc.z; - weight_new.acc.w -= m1.acc.w; - """ - split_weight_update_cpu = "" # TODO - - return { - "optimizer": "lars_sgd", - "is_experimental_optimizer": True, - "args": make_args( - [ - (TENSOR, "momentum1"), - (FLOAT, "learning_rate"), - (FLOAT, "eta"), - (FLOAT, "momentum"), - (FLOAT, "weight_decay"), - ] - ), - "split_precomputation": split_precomputation, - "split_weight_update": split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": split_weight_update_cpu, - "has_cpu_support": False, - "has_gpu_support": True, - "has_vbe_support": False, - } - - -def none_optimizer() -> Dict[str, Any]: - return { - "optimizer": "none", - "dense": False, - "args": make_args( - [ - (INT, "total_hash_size"), - (INT, "total_unique_indices"), - ] - ), - # Generate only GPU code - "has_cpu_support": False, - "has_gpu_support": True, - "has_vbe_support": False, - } diff --git a/fbgemm_gpu/codegen/genscript/__init__.py b/fbgemm_gpu/codegen/genscript/__init__.py index 2e41cd717f..581f84e46a 100644 --- a/fbgemm_gpu/codegen/genscript/__init__.py +++ b/fbgemm_gpu/codegen/genscript/__init__.py @@ -3,3 +3,5 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +# pyre-strict diff --git a/fbgemm_gpu/codegen/genscript/common.py b/fbgemm_gpu/codegen/genscript/common.py index c67062833d..4f58d1e1ec 100644 --- a/fbgemm_gpu/codegen/genscript/common.py +++ b/fbgemm_gpu/codegen/genscript/common.py @@ -42,3 +42,9 @@ def write(self, filename: str, **kwargs: Any) -> None: with open(os.path.join(args.install_dir, filename), "w") as f: f.write(output) print(f"Written: {filename}") + + @staticmethod + def copy_to_root(relative_path: str) -> None: + # Copy template from its relative path to root of the output directory + # e.g. sub/directory/foo.py -> foo.py + CodeTemplate.load(relative_path).write(relative_path.split("/")[-1]) diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py new file mode 100644 index 0000000000..4d66cc9f03 --- /dev/null +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +# flake8: noqa F401 + +import sys + +try: + from .optimizers import * + from .common import CodeTemplate + from .optimizer_args import OptimizerArgsSet + from .scripts_argsparse import args +except ImportError: + from optimizers import * + + # pyre-ignore[21] + from common import CodeTemplate + + # pyre-ignore[21] + from optimizer_args import OptimizerArgsSet + + # pyre-ignore[21] + from scripts_argsparse import args + + +class BackwardSplitGenerator: + @staticmethod + def render_backward_templates( + template_filepath: str, + optimizer: str, + filename_format: str, + kwargs: Dict[str, Any], + ) -> None: + if not kwargs.get("has_gpu_support"): + return + vbe_options = [True, False] if kwargs.get("has_vbe_support") else [False] + template = CodeTemplate.load(template_filepath) + + for weighted in [True, False]: + for nobag in [True, False]: + for vbe in vbe_options: + if (not nobag or (not weighted and not vbe)) and ( + not kwargs.get("dense") or not vbe + ): + wdesc = f"{ 'weighted' if weighted else 'unweighted' }{ '_nobag' if nobag else '' }{ '_vbe' if vbe else '' }" + template.write( + filename_format.format(optimizer, wdesc), + weighted=weighted, + nobag=nobag, + vbe=vbe, + is_index_select=False, + kdesc=wdesc, + **kwargs, + ) + + @staticmethod + def generate_backward_split_gpu(**kwargs: Any) -> None: + """ + Generate CUDA variants of the TBE backward split operators + """ + + optimizer = kwargs.get("optimizer") + # Generate the backward split kernels + for template_filepath, filename_format in [ + ( + "training/backward/embedding_backward_split_template.cu", + "gen_embedding_backward_{}_split_{}_cuda.cu", + ), + ( + "training/backward/embedding_backward_split_meta_template.cpp", + "gen_embedding_backward_{}_split_{}_meta.cpp", + ), + ( + "training/backward/embedding_backward_split_kernel_cta_template.cu", + "gen_embedding_backward_{}_split_{}_kernel_cta.cu", + ), + ( + "training/backward/embedding_backward_split_kernel_warp_template.cu", + "gen_embedding_backward_{}_split_{}_kernel_warp.cu", + ), + ]: + BackwardSplitGenerator.render_backward_templates( + template_filepath, + optimizer, + filename_format, + kwargs, + ) + + # Generate optimizer kernel + CodeTemplate.load( + "training/optimizer/embedding_optimizer_split_device_kernel_template.cuh" + ).write( + f"gen_embedding_optimizer_{optimizer}_split_device_kernel.cuh", **kwargs + ) + + # Generate the backward splits (non-dense) + # We generate only the API to preserve the backward compatibility if + # has_gpu_support=True + if not kwargs.get("dense"): + # Generate CUDA autograd, PT2 unified autograd, and PT2 backward wrapper + for template_filepath, filename in [ + ( + "training/backward/embedding_backward_split_host_template.cpp", + f"gen_embedding_backward_split_{optimizer}.cpp", + ), + ( + "training/pt2/embedding_split_host_pt2_autograd_template.cpp", + f"gen_embedding_split_{optimizer}_pt2_autograd.cpp", + ), + ( + "training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp", + f"gen_embedding_backward_split_{optimizer}_pt2_cuda_wrapper.cpp", + ), + ]: + CodeTemplate.load(template_filepath).write( + filename, is_forward=False, **kwargs + ) + + if kwargs.get("has_cpu_support") or kwargs.get("has_gpu_support"): + # Generates Python invoker for CUDA + CPU, and PT2 + template = CodeTemplate.load( + "training/python/split_embedding_codegen_lookup_invoker.template" + ) + for filename in [ + f"lookup_{optimizer}.py", + f"lookup_{optimizer}_pt2.py", + ]: + template.write(filename, is_fbcode=args.is_fbcode, **kwargs) 
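The guard in render_backward_templates above prunes the weighted/nobag/vbe grid: nobag kernels are only emitted for the unweighted, non-VBE case, and dense kernels never get a VBE variant. A minimal sketch of that filter (not part of the patch; it only mirrors the condition above, assuming a non-dense optimizer with VBE support):

from itertools import product
from typing import List

def enumerate_backward_variants(dense: bool = False, has_vbe: bool = True) -> List[str]:
    # Same pruning rule as render_backward_templates
    vbe_options = [True, False] if has_vbe else [False]
    variants = []
    for weighted, nobag, vbe in product([True, False], [True, False], vbe_options):
        if (not nobag or (not weighted and not vbe)) and (not dense or not vbe):
            variants.append(
                f"{'weighted' if weighted else 'unweighted'}"
                f"{'_nobag' if nobag else ''}"
                f"{'_vbe' if vbe else ''}"
            )
    return variants

# ['weighted_vbe', 'weighted', 'unweighted_nobag', 'unweighted_vbe', 'unweighted']
print(enumerate_backward_variants())

Each descriptor is then substituted into filename_format, e.g. "gen_embedding_backward_{}_split_{}_cuda.cu".format("rowwise_adagrad", "weighted_vbe") yields gen_embedding_backward_rowwise_adagrad_split_weighted_vbe_cuda.cu.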
+ + @staticmethod + def generate_backward_split_cpu(**kwargs: Any) -> None: + """ + Generate CPU variants of the TBE backward split operators + """ + + optimizer = kwargs.get("optimizer") + + # Generate the backward splits + if kwargs.get("has_cpu_support"): + CodeTemplate.load( + "training/backward/embedding_backward_split_cpu_approx_template.cpp" + if "approx" in optimizer + else "training/backward/embedding_backward_split_cpu_template.cpp" + ).write(f"gen_embedding_backward_{optimizer}_split_cpu.cpp", **kwargs) + + # Generate the backward splits (non-dense) + if not kwargs.get("dense"): + for template_filepath, filename in [ + ( + "training/backward/embedding_backward_split_host_cpu_template.cpp", + f"gen_embedding_backward_split_{optimizer}_cpu.cpp", + ), + ( + "training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp", + f"gen_embedding_backward_split_{optimizer}_pt2_cpu_wrapper.cpp", + ), + ]: + CodeTemplate.load(template_filepath).write( + filename, is_forward=False, **kwargs + ) + + @staticmethod + def generate_backward_split(**kwargs: Any) -> None: + gen_args = kwargs["args"] + kwargs["args_pt2"] = gen_args.any + + kwargs["args"] = gen_args.cuda + BackwardSplitGenerator.generate_backward_split_gpu(**kwargs) + + kwargs["args"] = gen_args.cpu + BackwardSplitGenerator.generate_backward_split_cpu(**kwargs) + + @staticmethod + def generate_backward_device() -> None: + # Generate backward device kernels based on weighted (True/False), VBE + # (True/False), no bag (True/False) + template_filepath = ( + "training/backward/embedding_backward_split_device_kernel_template.cuh" + ) + + BackwardSplitGenerator.render_backward_templates( + template_filepath, + "", + "{}gen_embedding_backward_{}_split_device_kernel.cuh", + { + "has_gpu_support": True, + "has_vbe_support": True, + "dense": False, + "gen_once": False, + }, + ) + + # Generate common backward device kernels (generate only once) + CodeTemplate.load(template_filepath).write( + "gen_embedding_backward_common_split_device_kernel.cuh", + gen_once=True, + ) + + @staticmethod + def generate_backward_grad() -> None: + # Generate the common grad functions + CodeTemplate.load( + "training/backward/embedding_backward_split_grad_template.cu" + ).write( + "gen_embedding_backward_split_grad_embedding_ops.cu", is_index_select=False + ) + + @staticmethod + def generate_backward_indices() -> None: + template = CodeTemplate.load( + "training/backward/embedding_backward_split_indice_weights_template.cu" + ) + for dense in [True, False]: + template.write( + f"gen_embedding_backward_{'dense' if dense else 'split'}_indice_weights_codegen_cuda.cu", + dense=dense, + ) + + @staticmethod + def generate_python_sources() -> None: + CodeTemplate.load("training/python/__init__.template").write("__init__.py") + CodeTemplate.copy_to_root("training/python/lookup_args.py") + + @staticmethod + def generate() -> None: + # Generate backwards and optimizers + optimizers = [ + dense(), + adagrad(), + adam(), + lamb(), + lars_sgd(), + partial_rowwise_adam(), + partial_rowwise_lamb(), + rowwise_adagrad(), + approx_rowwise_adagrad(), + rowwise_adagrad_with_weight_decay(), + approx_rowwise_adagrad_with_weight_decay(), + rowwise_adagrad_with_counter(), + approx_rowwise_adagrad_with_counter(), + rowwise_weighted_adagrad(), + sgd(), + approx_sgd(), + none_optimizer(), + ] + + for optimizer in optimizers: + BackwardSplitGenerator.generate_backward_split(**optimizer) + + # Generate common device kernels for backwards + BackwardSplitGenerator.generate_backward_device() + + # 
Generate forwards and specialized backwards + BackwardSplitGenerator.generate_backward_grad() + BackwardSplitGenerator.generate_backward_indices() + + BackwardSplitGenerator.generate_python_sources() + + +def main() -> None: + BackwardSplitGenerator.generate() + + +if __name__ == "__main__": + print(f"[GENERATE BACKWARD SPLIT]: {sys.argv}") + main() diff --git a/fbgemm_gpu/codegen/genscript/generate_embedding_optimizer.py b/fbgemm_gpu/codegen/genscript/generate_embedding_optimizer.py index 28c942db43..8b4ef09d03 100644 --- a/fbgemm_gpu/codegen/genscript/generate_embedding_optimizer.py +++ b/fbgemm_gpu/codegen/genscript/generate_embedding_optimizer.py @@ -28,7 +28,7 @@ class EmbeddingOptimizerGenerator: @staticmethod - def generate(**kwargs: Any) -> None: + def generate_embedding_optimizer(**kwargs: Any) -> None: """ Generate embedding optimizer code blocks (host, CUDA host, CUDA kernel, and header files) given the optimizer's parameters. @@ -40,41 +40,48 @@ def generate(**kwargs: Any) -> None: ) kwargs["args"] = kwargs["args"].cuda - # Generate CUDA host code - CodeTemplate.load("embedding_optimizer_split_template.cu").write( - f"gen_embedding_optimizer_{optimizer}_split_cuda.cu", **kwargs - ) - - # Generate CUDA kernel code - CodeTemplate.load("embedding_optimizer_split_kernel_template.cu").write( - f"gen_embedding_optimizer_{optimizer}_split_kernel.cu", **kwargs - ) + PREFIX = "training/optimizer" + + for template_filepath, filename in [ + ( # CUDA host code + f"{PREFIX}/embedding_optimizer_split_template.cu", + f"gen_embedding_optimizer_{optimizer}_split_cuda.cu", + ), + ( # CUDA kernel code + f"{PREFIX}/embedding_optimizer_split_kernel_template.cu", + f"gen_embedding_optimizer_{optimizer}_split_kernel.cu", + ), + ( # CPU code + f"{PREFIX}/embedding_optimizer_split_host_template.cpp", + f"gen_embedding_optimizer_{optimizer}_split.cpp", + ), + ( # Optimizer kernel headers + f"{PREFIX}/embedding_optimizer_split_device_kernel_template.cuh", + f"gen_embedding_optimizer_{optimizer}_split_device_kernel.cuh", + ), + ( # Python kernel invokers + "training/python/split_embedding_optimizer_codegen.template", + f"split_embedding_optimizer_{optimizer}.py", + ), + ]: + CodeTemplate.load(template_filepath).write( + filename, is_fbcode=args.is_fbcode, **kwargs + ) - # Generate host code - CodeTemplate.load("embedding_optimizer_split_host_template.cpp").write( - f"gen_embedding_optimizer_{optimizer}_split.cpp", **kwargs - ) + @staticmethod + def generate() -> None: + optimizers = [rowwise_adagrad()] - # Generates Python invoker for CUDA - CodeTemplate.load("split_embedding_optimizer_codegen.template").write( - f"split_embedding_optimizer_{optimizer}.py", - is_fbcode=args.is_fbcode, - **kwargs, - ) + for optimizer in optimizers: + EmbeddingOptimizerGenerator.generate_embedding_optimizer(**optimizer) - # Generate optimizer kernel headers - CodeTemplate.load("embedding_optimizer_split_device_kernel_template.cuh").write( - f"gen_embedding_optimizer_{optimizer}_split_device_kernel.cuh", **kwargs - ) + CodeTemplate.copy_to_root("training/python/optimizer_args.py") def main() -> None: - optimizers = [rowwise_adagrad()] - - for optimizer in optimizers: - EmbeddingOptimizerGenerator.generate(**optimizer) + EmbeddingOptimizerGenerator.generate() if __name__ == "__main__": - print(f"[GENERATE OPTIMIZERS] {sys.argv}") + print(f"[GENERATE OPTIMIZERS]: {sys.argv}") main() diff --git a/fbgemm_gpu/codegen/genscript/generate_forward_quantized.py b/fbgemm_gpu/codegen/genscript/generate_forward_quantized.py index
a9f15386af..788e83240b 100644 --- a/fbgemm_gpu/codegen/genscript/generate_forward_quantized.py +++ b/fbgemm_gpu/codegen/genscript/generate_forward_quantized.py @@ -174,5 +174,5 @@ def main() -> None: if __name__ == "__main__": - print(f"[GENERATE FORWARD QUANTIZED] {sys.argv}") + print(f"[GENERATE FORWARD QUANTIZED]: {sys.argv}") main() diff --git a/fbgemm_gpu/codegen/genscript/generate_forward_split.py b/fbgemm_gpu/codegen/genscript/generate_forward_split.py new file mode 100644 index 0000000000..7abaed7afc --- /dev/null +++ b/fbgemm_gpu/codegen/genscript/generate_forward_split.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +# flake8: noqa F401 + +import sys +from typing import List + +try: + from .common import CodeTemplate +except ImportError: + # pyre-ignore[21] + from common import CodeTemplate + + +class ForwardSplitGenerator: + @staticmethod + def render_forward_templates( + template_filepath: str, + filename_format: str, + dense_options: List[bool], + nobag_options: List[bool], + vbe_options: List[bool], + ) -> None: + template = CodeTemplate.load(template_filepath) + for dense in dense_options: + for weighted in [True, False]: + for nobag in nobag_options: + for vbe in vbe_options: + if (not nobag or (not weighted and not vbe)) and ( + not dense or not vbe + ): + dense_desc = f"{ 'dense' if dense else 'split'}" + weight_desc = ( + f"{ 'weighted' if weighted else 'unweighted' }" + ) + nobag_desc = f"{ '_nobag' if nobag else '' }" + vbe_desc = f"{ '_vbe' if vbe else '' }" + + template.write( + filename_format.format( + f"{ dense_desc }_{ weight_desc }{ nobag_desc }{ vbe_desc }" + ), + dense=dense, + weighted=weighted, + nobag=nobag, + vbe=vbe, + is_index_select=False, + ) + + @staticmethod + def generate_pt2_wrappers() -> None: + # Generate PT2 forward wrapper (CUDA) + CodeTemplate.load( + "training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp", + ).write( + f"gen_embedding_forward_split_pt2_cuda_wrapper.cpp", + has_gpu_support=True, + is_forward=True, + has_vbe_support=True, + ) + + # Generate PT2 forward wrapper (CPU) + CodeTemplate.load( + "training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp", + ).write( + f"gen_embedding_forward_split_pt2_cpu_wrapper.cpp", + has_cpu_support=True, + is_forward=True, + ) + + @staticmethod + def generate_small_kernels() -> None: + # Generate the small kernels (for nobag only) for the forward splits + template = CodeTemplate.load( + "training/forward/embedding_forward_split_kernel_nobag_small_template.cu" + ) + for dense in [True, False]: + wdesc = f"{ 'dense' if dense else 'split' }" + template.write( + f"gen_embedding_forward_{wdesc}_unweighted_nobag_kernel_small.cu", + dense=dense, + is_index_select=False, + ) + + @staticmethod + def generate_kernels() -> None: + # Generate the CUDA host code + ForwardSplitGenerator.render_forward_templates( + "training/forward/embedding_forward_split_template.cu", + "gen_embedding_forward_{}_codegen_cuda.cu", + dense_options=[True, False], + nobag_options=[False], # nobag is not used + vbe_options=[True, False], + ) + + # Generate the meta kernels + ForwardSplitGenerator.render_forward_templates( + "training/forward/embedding_forward_split_meta_template.cpp", + "gen_embedding_forward_{}_codegen_meta.cpp", + dense_options=[True, False], + nobag_options=[False], # nobag 
is not used + vbe_options=[True, False], + ) + + # Generate the CUDA kernels + ForwardSplitGenerator.render_forward_templates( + "training/forward/embedding_forward_split_kernel_template.cu", + "gen_embedding_forward_{}_kernel.cu", + dense_options=[True, False], + nobag_options=[True, False], + vbe_options=[True, False], + ) + + # Generate the v2 CUDA kernels + ForwardSplitGenerator.render_forward_templates( + "training/forward/embedding_forward_split_kernel_v2_template.cu", + "gen_embedding_forward_{}_v2_kernel.cu", + dense_options=[False], # dense is not supported + nobag_options=[False], # nobag is not supported + vbe_options=[False], # vbe is not supported + ) + + @staticmethod + def generate() -> None: + ForwardSplitGenerator.generate_kernels() + ForwardSplitGenerator.generate_small_kernels() + ForwardSplitGenerator.generate_pt2_wrappers() + + +def main() -> None: + ForwardSplitGenerator.generate() + + +if __name__ == "__main__": + print(f"[GENERATE FORWARD SPLIT]: {sys.argv}") + main() diff --git a/fbgemm_gpu/codegen/genscript/generate_index_select.py b/fbgemm_gpu/codegen/genscript/generate_index_select.py new file mode 100644 index 0000000000..d070fb3fad --- /dev/null +++ b/fbgemm_gpu/codegen/genscript/generate_index_select.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +# flake8: noqa F401 + +import re +import sys +from typing import Optional + +try: + from .common import CodeTemplate + from .optimizer_args import FLOAT, OptimizerArgsSet +except ImportError: + # pyre-ignore[21] + from common import CodeTemplate + + # pyre-ignore[21] + from optimizer_args import FLOAT, OptimizerArgsSet + + +class IndexSelectGenerator: + @staticmethod + def generate() -> None: + optargs = OptimizerArgsSet.create([(FLOAT, "unused")]) + for template_file, generated_file in [ + ( + "training/forward/embedding_forward_split_template.cu", + "gen_batch_index_select_dim0_forward_codegen_cuda.cu", + ), + ( + "training/forward/embedding_forward_split_kernel_template.cu", + "gen_batch_index_select_dim0_forward_kernel.cu", + ), + ( + "training/forward/embedding_forward_split_kernel_nobag_small_template.cu", + "gen_batch_index_select_dim0_forward_kernel_small.cu", + ), + ( + "training/backward/embedding_backward_split_template.cu", + "gen_batch_index_select_dim0_backward_codegen_cuda.cu", + ), + ( + "training/backward/embedding_backward_split_kernel_cta_template.cu", + "gen_batch_index_select_dim0_backward_kernel_cta.cu", + ), + ( + "training/backward/embedding_backward_split_kernel_warp_template.cu", + "gen_batch_index_select_dim0_backward_kernel_warp.cu", + ), + ( + "training/backward/embedding_backward_split_device_kernel_template.cuh", + "gen_embedding_backward_batch_index_select_split_device_kernel.cuh", + ), + ]: + CodeTemplate.load(template_file).write( + generated_file, + weighted=False, + dense=True, + vbe=False, + nobag=True, + is_index_select=True, + gen_once=False, + kdesc="batch_index_select", + args=optargs.cuda, + ) + + CodeTemplate.load( + "training/backward/embedding_backward_split_grad_template.cu" + ).write( + "gen_embedding_backward_split_grad_index_select.cu", + is_index_select=True, + ) + + # Generate common backward device kernels (generate only once) + CodeTemplate.load( + "training/backward/embedding_backward_split_device_kernel_template.cuh" + ).write( + 
"gen_embedding_backward_common_split_device_kernel.cuh", + gen_once=True, + ) + + +def main() -> None: + IndexSelectGenerator.generate() + + +if __name__ == "__main__": + print(f"[INDEX SELECT GENERATOR]: {sys.argv}") + main() diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index b056c3c1ba..c31bedb076 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -39,6 +39,21 @@ ###################################################################### +def dense() -> Dict[str, Any]: + return { + "optimizer": "dense", + "dense": True, + "args": OptimizerArgsSet.create( + [ + (FLOAT, "unused"), + ] + ), + "has_cpu_support": True, + "has_gpu_support": True, + "has_vbe_support": False, + } + + def adagrad() -> Dict[str, Any]: split_weight_update = """ Vec4T m_t(&momentum1[idx * D + d]); @@ -125,7 +140,7 @@ def rowwise_adagrad() -> Dict[str, Any]: + weight_new.acc.w * weight_new.acc.w; } const at::acc_type weight_norm = - sqrtf(warpReduceAllSum, kThreadGroupSize>(weight_sum_square, shfl_sync_mask)); + sqrtf(GROUP_REDUCE_ALL_SUM(weight_sum_square, at::acc_type)); // scale by max_norm if weight_norm exceeds max_norm if (threadIdx.x == 0) { @@ -171,7 +186,7 @@ def rowwise_adagrad() -> Dict[str, Any]: ) split_precomputation += """ const at::acc_type g_avg_square = - warpReduceAllSum, kThreadGroupSize>(g_local_sum_square, shfl_sync_mask) / D; + GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type) / D; at::acc_type multiplier; at::acc_type correction; @@ -307,7 +322,7 @@ def rowwise_adagrad_with_weight_decay() -> Dict[str, Any]: ) split_precomputation += """ const at::acc_type g_avg_square = - warpReduceAllSum, kThreadGroupSize>(g_local_sum_square, shfl_sync_mask) / D; + GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type) / D; at::acc_type multiplier; at::acc_type correction; @@ -478,10 +493,10 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]: ) split_precomputation += """ const at::acc_type g_sum_square = - warpReduceAllSum, kThreadGroupSize>(g_local_sum_square, shfl_sync_mask); + GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type); const at::acc_type g_avg_square = g_sum_square / D; const at::acc_type w_sum_square = - warpReduceAllSum, kThreadGroupSize>(w_local_sum_square, shfl_sync_mask); + GROUP_REDUCE_ALL_SUM(w_local_sum_square, at::acc_type); at::acc_type adjusted_multiplier; at::acc_type exp_reg_correction; @@ -658,7 +673,7 @@ def rowwise_weighted_adagrad() -> Dict[str, Any]: ) split_precomputation += """ const at::acc_type g_avg_square = - warpReduceAllSum, kThreadGroupSize>(g_local_sum_square, shfl_sync_mask) / D; + GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type) / D; at::acc_type multiplier; at::acc_type correction; @@ -799,9 +814,9 @@ def lamb() -> Dict[str, Any]: ) split_precomputation += """ const auto weight_norm = - sqrtf(warpReduceAllSum, kThreadGroupSize>(weight_sum_sq, shfl_sync_mask)); + sqrtf(GROUP_REDUCE_ALL_SUM(weight_sum_sq, at::acc_type)); const auto rtw_norm = - sqrtf(warpReduceAllSum, kThreadGroupSize>(rtw_sum_sq, shfl_sync_mask)); + sqrtf(GROUP_REDUCE_ALL_SUM(rtw_sum_sq, at::acc_type)); const auto true_ratio = weight_norm / rtw_norm; """ split_weight_update = """ @@ -849,7 +864,7 @@ def partial_rowwise_lamb() -> Dict[str, Any]: ) split_precomputation += """ const at::acc_type g_avg_square = - warpReduceAllSum, kThreadGroupSize>(g_local_sum_square, shfl_sync_mask) / D; + GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type) / D; at::acc_type m2; if 
(threadIdx.x == 0) { @@ -890,9 +905,9 @@ def partial_rowwise_lamb() -> Dict[str, Any]: ) split_precomputation += """ const auto weight_norm = - sqrtf(warpReduceAllSum, kThreadGroupSize>(weight_sum_sq)); + sqrtf(GROUP_REDUCE_ALL_SUM(weight_sum_sq, at::acc_type)); const auto rtw_norm = - sqrtf(warpReduceAllSum, kThreadGroupSize>(rtw_sum_sq)); + sqrtf(GROUP_REDUCE_ALL_SUM(rtw_sum_sq, at::acc_type)); const auto true_ratio = weight_norm / rtw_norm; """ @@ -995,7 +1010,7 @@ def partial_rowwise_adam() -> Dict[str, Any]: ) split_precomputation += """ const at::acc_type g_avg_square = - warpReduceAllSum, kThreadGroupSize>(g_local_sum_square) / D; + GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type) / D; at::acc_type v_hat_t; if (threadIdx.x == 0) { @@ -1067,9 +1082,9 @@ def lars_sgd() -> Dict[str, Any]: ) split_precomputation += """ const auto weight_norm = - sqrtf(warpReduceAllSum, kThreadGroupSize>(weight_sum_sq)); + sqrtf(GROUP_REDUCE_ALL_SUM(weight_sum_sq, at::acc_type)); const auto grad_norm = - sqrtf(warpReduceAllSum, kThreadGroupSize>(grad_sum_sq)); + sqrtf(GROUP_REDUCE_ALL_SUM(grad_sum_sq, at::acc_type)); const at::acc_type adjusted_lr = learning_rate * eta * weight_norm / (grad_norm + weight_decay * weight_norm); """ diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu index 00f1cc1f5c..47189f79e6 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu @@ -6,7 +6,7 @@ * LICENSE file in the root directory of this source tree. */ -#include "codegen/embedding_forward_template_helpers.cuh" +#include "fbgemm_gpu/embedding_forward_template_helpers.cuh" #include "fbgemm_gpu/fbgemm_tensor_accessor.h" using namespace fbgemm_gpu; diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu index e556cc241c..bb04e13b10 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu @@ -8,7 +8,7 @@ // clang-format off {% set wdesc = "weighted" if weighted else "unweighted" %} -#include "codegen/embedding_forward_template_helpers.cuh" +#include "fbgemm_gpu/embedding_forward_template_helpers.cuh" #include "fbgemm_gpu/fbgemm_tensor_accessor.h" using namespace fbgemm_gpu; diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu index 9e087ac37f..3c2ce5e435 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu @@ -8,7 +8,7 @@ // clang-format off {% set wdesc = "weighted" if weighted else "unweighted" %} -#include "codegen/embedding_forward_template_helpers.cuh" +#include "fbgemm_gpu/embedding_forward_template_helpers.cuh" #include "fbgemm_gpu/fbgemm_tensor_accessor.h" using namespace fbgemm_gpu; diff --git a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host.cpp similarity index 100% rename from fbgemm_gpu/codegen/embedding_backward_dense_host.cpp rename to 
fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host.cpp diff --git a/fbgemm_gpu/codegen/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp similarity index 99% rename from fbgemm_gpu/codegen/embedding_backward_dense_host_cpu.cpp rename to fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp index ef6d2f6f0f..f114ab203d 100644 --- a/fbgemm_gpu/codegen/embedding_backward_dense_host_cpu.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp @@ -10,8 +10,8 @@ #include #include -#include "codegen/embedding_forward_split_cpu.h" #include "fbgemm_gpu/embedding_common.h" +#include "fbgemm_gpu/embedding_forward_split_cpu.h" #include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; diff --git a/fbgemm_gpu/codegen/embedding_backward_split_cpu_approx_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_approx_template.cpp similarity index 99% rename from fbgemm_gpu/codegen/embedding_backward_split_cpu_approx_template.cpp rename to fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_approx_template.cpp index 5dedfd0f57..27bb491b05 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_cpu_approx_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_approx_template.cpp @@ -13,7 +13,7 @@ #include #include -#include "codegen/embedding_forward_split_cpu.h" +#include "fbgemm_gpu/embedding_forward_split_cpu.h" #include "fbgemm/FbgemmEmbedding.h" #include "fbgemm_gpu/cpu_utils.h" #include "fbgemm_gpu/embedding_common.h" diff --git a/fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp similarity index 99% rename from fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp rename to fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp index 5f0f445f2f..d2dc7f290d 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp @@ -16,7 +16,7 @@ #include #include -#include "codegen/embedding_forward_split_cpu.h" +#include "fbgemm_gpu/embedding_forward_split_cpu.h" #include "fbgemm/FbgemmEmbedding.h" #include "fbgemm/Types.h" #include "fbgemm_gpu/embedding_common.h" diff --git a/fbgemm_gpu/codegen/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh similarity index 100% rename from fbgemm_gpu/codegen/embedding_backward_split_device_kernel_template.cuh rename to fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh diff --git a/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu similarity index 100% rename from fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu rename to fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp similarity index 99% rename from fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp rename to fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp index a5e892ed1f..18dfea9217 100644 --- 
a/fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp @@ -11,7 +11,7 @@ #include #include -#include "codegen/embedding_forward_split_cpu.h" +#include "fbgemm_gpu/embedding_forward_split_cpu.h" #include "fbgemm_gpu/embedding_common.h" #include "fbgemm_gpu/sparse_ops_utils.h" diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp similarity index 100% rename from fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp rename to fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp diff --git a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu similarity index 99% rename from fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu rename to fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 1f7e1f1b64..501137b61d 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -14,7 +14,7 @@ // Required for op registrations #include "fbgemm_gpu/embedding_op_registration.h" //////////////////////////////////////////////////////////////////////////////// -#include "codegen/embedding_forward_template_helpers.cuh" +#include "fbgemm_gpu/embedding_forward_template_helpers.cuh" #include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; @@ -167,7 +167,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void Vec4TAcc go(grad_output_ + d); grad_out[vec] = go; } - + for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) { int32_t l = l_start + threadIdx.x; int64_t idx = l < L ? 
indices[indices_start + l] : 0; @@ -496,4 +496,4 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { } {%- endif %} {#-/* if not dense or not vbe */#} {%- endfor %} {#-/* for vbe */#} -// clang-format on + // clang-format on diff --git a/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu similarity index 100% rename from fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu rename to fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu diff --git a/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu similarity index 100% rename from fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu rename to fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu diff --git a/fbgemm_gpu/codegen/embedding_backward_split_meta_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp similarity index 100% rename from fbgemm_gpu/codegen/embedding_backward_split_meta_template.cpp rename to fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu similarity index 100% rename from fbgemm_gpu/codegen/embedding_backward_split_template.cu rename to fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu diff --git a/fbgemm_gpu/codegen/embedding_ops_placeholder.cpp b/fbgemm_gpu/codegen/training/embedding_ops_placeholder.cpp similarity index 100% rename from fbgemm_gpu/codegen/embedding_ops_placeholder.cpp rename to fbgemm_gpu/codegen/training/embedding_ops_placeholder.cpp diff --git a/fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp similarity index 99% rename from fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp rename to fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp index 665425bf58..6e68a1811d 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp @@ -6,7 +6,7 @@ * LICENSE file in the root directory of this source tree. 
*/ -#include "codegen/embedding_forward_split_cpu.h" +#include "fbgemm_gpu/embedding_forward_split_cpu.h" #include "fbgemm/FbgemmEmbedding.h" #include "fbgemm/Types.h" #include "fbgemm/Utils.h" diff --git a/fbgemm_gpu/codegen/embedding_forward_split_kernel_nobag_small_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_nobag_small_template.cu similarity index 99% rename from fbgemm_gpu/codegen/embedding_forward_split_kernel_nobag_small_template.cu rename to fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_nobag_small_template.cu index a60fe16ae1..c1b876a0b6 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_kernel_nobag_small_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_nobag_small_template.cu @@ -15,7 +15,7 @@ // See https://fburl.com/dw9ljh4h #} -#include "codegen/embedding_forward_template_helpers.cuh" +#include "fbgemm_gpu/embedding_forward_template_helpers.cuh" using Tensor = at::Tensor; using namespace fbgemm_gpu; diff --git a/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu similarity index 99% rename from fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu rename to fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index 9e1bf4febb..86c60aad6b 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -20,7 +20,7 @@ {%- set ndesc = "_nobag" if nobag else "" %} {%- set vdesc = "_vbe" if vbe else "" %} -#include "codegen/embedding_forward_template_helpers.cuh" +#include "fbgemm_gpu/embedding_forward_template_helpers.cuh" #include "fbgemm_gpu/split_embeddings_cache_cuda.cuh" using Tensor = at::Tensor; diff --git a/fbgemm_gpu/codegen/embedding_forward_split_kernel_v2_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu similarity index 99% rename from fbgemm_gpu/codegen/embedding_forward_split_kernel_v2_template.cu rename to fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu index c878b8d024..48c1d5e9b9 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_kernel_v2_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu @@ -16,7 +16,7 @@ #} {%- set wdesc = "weighted" if weighted else "unweighted" %} -#include "codegen/embedding_forward_template_helpers.cuh" +#include "fbgemm_gpu/embedding_forward_template_helpers.cuh" using namespace fbgemm_gpu; diff --git a/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp similarity index 100% rename from fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp rename to fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp diff --git a/fbgemm_gpu/codegen/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu similarity index 99% rename from fbgemm_gpu/codegen/embedding_forward_split_template.cu rename to fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu index 0995c05e89..0cdeb12701 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -29,7 +29,7 @@ #include "fbgemm_gpu/embedding_op_registration.h" 
//////////////////////////////////////////////////////////////////////////////// {%- endif %} -#include "codegen/embedding_forward_template_helpers.cuh" +#include "fbgemm_gpu/embedding_forward_template_helpers.cuh" #include "fbgemm_gpu/split_embeddings_cache_cuda.cuh" using Tensor = at::Tensor; diff --git a/fbgemm_gpu/codegen/batch_index_select_dim0_cpu_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp similarity index 97% rename from fbgemm_gpu/codegen/batch_index_select_dim0_cpu_host.cpp rename to fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp index 31f08fc87d..51470bd17d 100644 --- a/fbgemm_gpu/codegen/batch_index_select_dim0_cpu_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp @@ -217,9 +217,9 @@ TORCH_LIBRARY_FRAGMENT(fb, m) { "batch_index_select_dim0(" " Tensor inputs," " Tensor indices," - " int[] input_num_indices," - " int[] input_rows," - " int[] input_columns," + " SymInt[] input_num_indices," + " SymInt[] input_rows," + " SymInt[] input_columns," " bool permute_output_dim_0_1=False) -> Tensor"); DISPATCH_TO_CPU("batch_index_select_dim0", batch_index_select_dim0_cpu); } @@ -232,9 +232,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "batch_index_select_dim0(" " Tensor inputs," " Tensor indices," - " int[] input_num_indices," - " int[] input_rows," - " int[] input_columns," + " SymInt[] input_num_indices," + " SymInt[] input_rows," + " SymInt[] input_columns," " bool permute_output_dim_0_1=False) -> Tensor"); DISPATCH_TO_CPU("batch_index_select_dim0", batch_index_select_dim0_cpu); } diff --git a/fbgemm_gpu/codegen/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp similarity index 100% rename from fbgemm_gpu/codegen/batch_index_select_dim0_host.cpp rename to fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp diff --git a/fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh similarity index 97% rename from fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh rename to fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh index 66efc3c44e..2407ba3ced 100644 --- a/fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh @@ -11,6 +11,9 @@ #include "fbgemm_gpu/fbgemm_tensor_accessor.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" +#define GROUP_REDUCE_ALL_SUM(val, ...) 
\ + warpReduceAllSum<__VA_ARGS__, kThreadGroupSize>(val, shfl_sync_mask) + using namespace fbgemm_gpu; template < diff --git a/fbgemm_gpu/codegen/embedding_optimizer_split_host_template.cpp b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_host_template.cpp similarity index 100% rename from fbgemm_gpu/codegen/embedding_optimizer_split_host_template.cpp rename to fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_host_template.cpp diff --git a/fbgemm_gpu/codegen/embedding_optimizer_split_kernel_template.cu b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_kernel_template.cu similarity index 100% rename from fbgemm_gpu/codegen/embedding_optimizer_split_kernel_template.cu rename to fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_kernel_template.cu diff --git a/fbgemm_gpu/codegen/embedding_optimizer_split_template.cu b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_template.cu similarity index 100% rename from fbgemm_gpu/codegen/embedding_optimizer_split_template.cu rename to fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_template.cu diff --git a/fbgemm_gpu/codegen/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp similarity index 100% rename from fbgemm_gpu/codegen/embedding_split_host_pt2_autograd_template.cpp rename to fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp diff --git a/fbgemm_gpu/codegen/embedding_split_host_pt2_cpu_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp similarity index 100% rename from fbgemm_gpu/codegen/embedding_split_host_pt2_cpu_wrapper_template.cpp rename to fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp diff --git a/fbgemm_gpu/codegen/embedding_split_host_pt2_cuda_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp similarity index 100% rename from fbgemm_gpu/codegen/embedding_split_host_pt2_cuda_wrapper_template.cpp rename to fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp diff --git a/fbgemm_gpu/codegen/__init__.template b/fbgemm_gpu/codegen/training/python/__init__.template similarity index 99% rename from fbgemm_gpu/codegen/__init__.template rename to fbgemm_gpu/codegen/training/python/__init__.template index 28d111930f..42f49ee3c1 100644 --- a/fbgemm_gpu/codegen/__init__.template +++ b/fbgemm_gpu/codegen/training/python/__init__.template @@ -2,6 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
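The GROUP_REDUCE_ALL_SUM macro added above is what the optimizer templates in optimizers.py now call instead of spelling out warpReduceAllSum with its template arguments and shuffle mask. A rough CPU-side sketch of the reduction it stands for (illustrative only, not the CUDA implementation; the variable names below are made up):

from typing import List, Sequence

def group_reduce_all_sum(per_lane: Sequence[float]) -> List[float]:
    # All-reduce across a thread group: every lane contributes a partial value
    # and every lane receives the full sum, so follow-up math (e.g.
    # g_avg_square = total / D in the rowwise optimizers) can run on any lane.
    total = sum(per_lane)
    return [total] * len(per_lane)

# Illustrative per-lane partial sums of squared gradient elements
partials = [0.50, 1.25, 0.00, 2.25]
D = 128
g_avg_square = group_reduce_all_sum(partials)[0] / D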
diff --git a/fbgemm_gpu/codegen/lookup_args.py b/fbgemm_gpu/codegen/training/python/lookup_args.py similarity index 100% rename from fbgemm_gpu/codegen/lookup_args.py rename to fbgemm_gpu/codegen/training/python/lookup_args.py diff --git a/fbgemm_gpu/codegen/optimizer_args.py b/fbgemm_gpu/codegen/training/python/optimizer_args.py similarity index 100% rename from fbgemm_gpu/codegen/optimizer_args.py rename to fbgemm_gpu/codegen/training/python/optimizer_args.py diff --git a/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template similarity index 99% rename from fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template rename to fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template index 0a2427f659..fec9238337 100644 --- a/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template +++ b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-ignore-all-errors + import torch {%- if is_experimental_optimizer %} import warnings diff --git a/fbgemm_gpu/codegen/split_embedding_optimizer_codegen.template b/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template similarity index 99% rename from fbgemm_gpu/codegen/split_embedding_optimizer_codegen.template rename to fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template index 4405c3abcc..51385c53fc 100644 --- a/fbgemm_gpu/codegen/split_embedding_optimizer_codegen.template +++ b/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-ignore-all-errors + import torch from .optimizer_args import * from typing import Optional, List, Tuple diff --git a/fbgemm_gpu/codegen/embedding_bounds_check.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu similarity index 100% rename from fbgemm_gpu/codegen/embedding_bounds_check.cu rename to fbgemm_gpu/codegen/utils/embedding_bounds_check.cu diff --git a/fbgemm_gpu/codegen/embedding_bounds_check_host.cpp b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp similarity index 100% rename from fbgemm_gpu/codegen/embedding_bounds_check_host.cpp rename to fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp diff --git a/fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp similarity index 98% rename from fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp rename to fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp index ae684aa767..cc03d9ec2f 100644 --- a/fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp @@ -178,6 +178,9 @@ TORCH_LIBRARY_FRAGMENT(fb, m) { TORCH_LIBRARY_FRAGMENT(fbgemm, m) { // The (a!) tells PyTorch this is an impure operation and so cannot be CSE'd // or DCE'd, etc. + m.impl_abstract_pystub( + "fbgemm_gpu.sparse_ops", + "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py"); m.def( "bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(b!) offsets, int bounds_check_mode, Tensor(c!) warning, Tensor(d!)? weights=None, Tensor? 
B_offsets=None, SymInt max_B=-1) -> ()", {PT2_COMPLIANT_TAG}); diff --git a/fbgemm_gpu/experimental/example/CMakeLists.txt b/fbgemm_gpu/experimental/example/CMakeLists.txt new file mode 100644 index 0000000000..d6d4b55aae --- /dev/null +++ b/fbgemm_gpu/experimental/example/CMakeLists.txt @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +include(${CMAKEMODULES}/Utilities.cmake) + +################################################################################ +# Target Sources +################################################################################ + +set(experimental_example_cpp_source_files + src/example_ops.cpp) + +set(experimental_example_python_source_files + example/__init__.py + example/utils.py) + +################################################################################ +# Build Shared Library +################################################################################ + +add_library(fbgemm_gpu_experimental_example_py MODULE + ${experimental_example_cpp_source_files}) + +target_include_directories(fbgemm_gpu_experimental_example_py PRIVATE ${TORCH_INCLUDE_DIRS}) +target_link_libraries(fbgemm_gpu_experimental_example_py ${TORCH_LIBRARIES}) + +# Remove `lib` from the output artifact name `libfbgemm_gpu_py.so` +set_target_properties(fbgemm_gpu_experimental_example_py PROPERTIES PREFIX "") + +################################################################################ +# Install Shared Library and Python Files +################################################################################ + +install(TARGETS fbgemm_gpu_experimental_example_py + DESTINATION fbgemm_gpu/experimental/example) + +install(FILES ${experimental_example_python_source_files} + DESTINATION fbgemm_gpu/experimental/example) diff --git a/fbgemm_gpu/experimental/example/example/__init__.py b/fbgemm_gpu/experimental/example/example/__init__.py new file mode 100644 index 0000000000..d4bea7d448 --- /dev/null +++ b/fbgemm_gpu/experimental/example/example/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import torch + +try: + torch.ops.load_library( + os.path.join(os.path.dirname(__file__), "fbgemm_gpu_experimental_example_py.so") + ) +except Exception as e: + print(e) + +# Since __init__.py is only used in OSS context, we define `open_source` here +# and use its existence to determine whether or not we are in OSS context +open_source: bool = True diff --git a/fbgemm_gpu/experimental/example/example/utils.py b/fbgemm_gpu/experimental/example/example/utils.py new file mode 100644 index 0000000000..19a98377fb --- /dev/null +++ b/fbgemm_gpu/experimental/example/example/utils.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import torch + + +def add_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return torch.ops.fbgemm.add_tensors_float(a, b) diff --git a/fbgemm_gpu/experimental/example/src/example_ops.cpp b/fbgemm_gpu/experimental/example/src/example_ops.cpp new file mode 100644 index 0000000000..585630373c --- /dev/null +++ b/fbgemm_gpu/experimental/example/src/example_ops.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace fbgemm_gpu::experimental { + +at::Tensor add_tensors_float(const at::Tensor& a, const at::Tensor& b) { + return a.to(at::kFloat) + b.to(at::kFloat); +} + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def("add_tensors_float(Tensor a, Tensor b) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { + m.impl( + "add_tensors_float", + torch::dispatch( + c10::DispatchKey::CPU, + TORCH_FN(fbgemm_gpu::experimental::add_tensors_float))); +} + +} // namespace fbgemm_gpu::experimental diff --git a/fbgemm_gpu/experimental/example/test/add_tensors_float_test.py b/fbgemm_gpu/experimental/example/test/add_tensors_float_test.py new file mode 100644 index 0000000000..5d0cd40e21 --- /dev/null +++ b/fbgemm_gpu/experimental/example/test/add_tensors_float_test.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +import torch + +from fbgemm_gpu.experimental.example import utils + + +class ExampleTest(unittest.TestCase): + def test_add_tensors_float(self) -> None: + a = torch.tensor([1, 2, 3]) + b = torch.tensor([4, 5, 6]) + expected = torch.tensor([5, 7, 9], dtype=torch.float) + c = utils.add_tensors(a, b) + torch.testing.assert_close(c.cpu(), expected.cpu()) + + +if __name__ == "__main__": + unittest.main() diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/attention.cpp b/fbgemm_gpu/experimental/gen_ai/src/attention/attention.cpp new file mode 100644 index 0000000000..ecf9f88f98 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/attention.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +namespace fbgemm_gpu::gen_ai::attention { + +std::tuple gqa_attn_splitk_cuda( + const at::Tensor& XQ, + const at::Tensor& cache_K, + const at::Tensor& cache_V, + const at::Tensor& seq_positions, + const double qk_scale, + const int64_t num_split_ks, + const int64_t num_groups); + +} // namespace fbgemm_gpu::gen_ai::attention + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def( + "gqa_attn_splitk(" + " Tensor XQ, " + " Tensor cache_K, " + " Tensor cache_V, " + " Tensor seq_positions, " + " float qk_scale, " + " int num_split_ks, " + " int num_int4_kv_groups=1" + ") -> (Tensor, Tensor, Tensor)"); +} + +TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { + m.impl( + "gqa_attn_splitk", + torch::dispatch( + c10::DispatchKey::CUDA, + TORCH_FN(fbgemm_gpu::gen_ai::attention::gqa_attn_splitk_cuda))); +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/gqa_attn_splitk.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/gqa_attn_splitk.cu new file mode 100644 index 0000000000..b1eab26cc5 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/gqa_attn_splitk.cu @@ -0,0 +1,1068 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#if !( \ + defined(USE_ROCM) || \ + ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#elif (defined(USE_ROCM)) +#include // @manual +#endif + +#ifndef USE_ROCM +#include +#endif + +#if ( \ + defined(__CUDA_ARCH__) && \ + ((__CUDA_ARCH__ == 800) || (__CUDA_ARCH__ == 900))) +#define USE_WMMA_FRAG +#endif + +#ifdef USE_ROCM +constexpr int32_t kThreadsPerWarp = 64; +constexpr int32_t kWarpsPerBlock = 16; +#else +constexpr int32_t kThreadsPerWarp = 32; +constexpr int32_t kWarpsPerBlock = 32; +#endif + +#define DEVICE_INLINE __device__ inline __attribute__((always_inline)) +#define FINAL_MASK 0xffffffff + +namespace fbgemm_gpu::gen_ai::attention { + +constexpr int32_t D_H = 128; +constexpr int32_t MAX_T = 16384; +constexpr int SMEM_ADJUST_THRESHOLD = 48 * 1024; + +constexpr int kMaxHeads = 8; +// Fragments shapes used for wmma tensor core operations +constexpr int F_M = 8, F_N = 32, F_K = 16; +constexpr int SMEM_K_PAD = 2; +constexpr int SMEM_V_PAD = 2; +constexpr int SMEM_K_STRIDE = F_K + SMEM_K_PAD; +constexpr int SMEM_V_STRIDE = F_N + SMEM_V_PAD; + +// Use fewer warps for gqa_attn_splitk_wmma_kernel +constexpr int32_t kSplitKWarpsPerBlock = 4; + +namespace { + +static __host__ DEVICE_INLINE int32_t div_up(int32_t a, int32_t b) { + return (a + b - 1) / b; +}; + +static __host__ DEVICE_INLINE int32_t round_up(int32_t a, int32_t b) { + return ((a + b - 1) / b) * b; +} + +template +void set_gpu_max_dynamic_shared_memory( + func_t kernel, + const int smem_bytes, + const int device) { + // V100: 96 KB; A100: 160 KB; H100: 228 KB. 
+ int max_shared_bytes = 0; + cudaDeviceGetAttribute( + &max_shared_bytes, +#ifndef __HIP_PLATFORM_AMD__ + cudaDevAttrMaxSharedMemoryPerBlockOptin, +#else + hipDeviceAttributeMaxSharedMemoryPerBlock, +#endif + device); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + TORCH_CHECK( + smem_bytes <= max_shared_bytes, + "Try to allocate ", + smem_bytes / 1024, + " KB of shared memory but only ", + max_shared_bytes / 1024, + " KB is available"); + + C10_CUDA_CHECK(cudaFuncSetAttribute( + (void*)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +// TODO: Include the following code from fbgemm_gpu header +struct __align__(16) bfx8 { + __nv_bfloat162 vals[4]; +}; + +struct __align__(16) halfx8 { + __half2 vals[4]; +}; + +// Reinterpret a pair of uint16_t (packed into a uint32_t) as half2, and +// multiply by rhs. +DEVICE_INLINE __half2 hmul_short2(uint32_t lhs, __half rhs) { +#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 +#ifndef __HALF2_TO_UI +// cuda_fp16.hpp +#define __HALF2_TO_UI(var) *(reinterpret_cast(&(var))) +#endif +#ifndef __HALF2_TO_CUI +// cuda_fp16.hpp +#define __HALF2_TO_CUI(var) *(reinterpret_cast(&(var))) +#endif + __half2 ret; + __half2 rhsp = make_half2(rhs, rhs); + asm("mul.f16x2 %0, %1, %2;" + : "=r"(__HALF2_TO_UI(ret)) + : "r"(__HALF2_TO_CUI(lhs)), "r"(__HALF2_TO_CUI(rhsp))); + return ret; +#else +#ifndef __HALF2_TO_UI +// cuda_fp16.hpp +#define __HALF2_TO_UI(var) *(reinterpret_cast(&(var))) +#endif + __half2 lhs_h2; + __HALF2_TO_UI(lhs_h2) = lhs; + float2 fx = __half22float2(lhs_h2); + float2 fy = __half22float2(make_half2(rhs, rhs)); + float2 fr; + fr.x = fx.x * fy.x; + fr.y = fx.y * fy.y; + return __float22half2_rn(fr); +#endif +} + +__forceinline__ __device__ bfx8 +dequantize_permuted_int4(uint32_t packedVals, __half2 shift_scale) { + halfx8 res; + uint32_t v = packedVals; + // What's going on here, you might ask? We extra out 4-bit pairs of integers + // as 2xuint16 packed into an int32 via the mask operation, and then we + // convert them to half precision values. As these are all integers in [0, + // 15], we can actually just interpret the 4-bit integer values as + // half-precision values. We multiply by 4096 x 4096 to go from the 4-bit + // representation to the equivalent fp16 value, or alternatively 32768 * 512 + // (or 32 when we have shifted the 4-bit value up). See e.g. + // https://gist.github.com/ajtulloch/021254a291a95966bc509db4e34ffeff for a + // NumPy implementation. We do this dance because: a) doing bitwise operations + // on each 4-bit value is expensive on the ALU, and 4-bit to half is expensive + // on the XU. b) doing a 256-entry shared memory LUT on 8-bit pairs is + // expensive on SMEM throughput. Credit to @jhj. + res.vals[0] = hmul_short2(v & 0x000F000F, __float2half(32768)); + res.vals[1] = hmul_short2(v & 0x00F000F0, __float2half(32768)); + v >>= 8; + res.vals[2] = hmul_short2(v & 0x000F000F, __float2half(32768)); + res.vals[3] = hmul_short2(v & 0x00F000F0, __float2half(32768)); + + // ~5% perf gain is observed with the explicit type conversions using + // __float2half on Nvidia A100 GPUs (https://fburl.com/diff/ss8372zw) using + // NVCC 11.0. Additionally, HIP compiler requires these explicit type + // conversions. 
+ half shift_scale_x = __low2half(shift_scale); + half shift_scale_y = __high2half(shift_scale); + + // now, dequantize + auto shifts = __half2(shift_scale_y, shift_scale_y); + auto scales_lower_temp = __hmul(shift_scale_x, __float2half(512)); + auto scales_lower = __half2(scales_lower_temp, scales_lower_temp); + auto scales_upper_temp = __hmul(shift_scale_x, __float2half(32)); + auto scales_upper = __half2(scales_upper_temp, scales_upper_temp); + + auto r0 = __half22float2(__hfma2(res.vals[0], scales_lower, shifts)); + auto r1 = __half22float2(__hfma2(res.vals[1], scales_upper, shifts)); + auto r2 = __half22float2(__hfma2(res.vals[2], scales_lower, shifts)); + auto r3 = __half22float2(__hfma2(res.vals[3], scales_upper, shifts)); + + bfx8 result; + result.vals[0] = __floats2bfloat162_rn(r0.x, r1.x); + result.vals[1] = __floats2bfloat162_rn(r2.x, r3.x); + result.vals[2] = __floats2bfloat162_rn(r0.y, r1.y); + result.vals[3] = __floats2bfloat162_rn(r2.y, r3.y); + + return result; +} + +template < + typename kv_t, + int KVQuantNumGroups = 1, + typename kv_load_t = uint32_t> +__global__ void __launch_bounds__(kThreadsPerWarp* kSplitKWarpsPerBlock, 1) + gqa_attn_splitk_wmma_kernel( + const at::PackedTensorAccessor32 + XQ, + const at::PackedTensorAccessor64 + cache_K, + const at::PackedTensorAccessor64 + cache_V, + at::PackedTensorAccessor32 out_splitK, + const at::PackedTensorAccessor32 + seq_positions, + at::PackedTensorAccessor32 metadata, + float qk_scale) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + // Need kWarpsPerBlock == blockDim.y; + // Need D_H == 128 + static_assert(kWarpsPerBlock <= kThreadsPerWarp, ""); + + extern __shared__ __align__(16) float smem[]; + + // Each block handles a single query, split-K partition, and max of 8 query + // heads + const int32_t b = blockIdx.x; + // Head block + const int32_t h_block = blockIdx.y; + // Split-K block + const int32_t s_block = blockIdx.z; + + const int32_t H_max = XQ.size(2); + const int32_t num_split_ks = gridDim.z; + const int32_t warp_idx = threadIdx.y; + + // Note: this is decoding case where we attent to current and all previous + // tokens. + const auto t_max = seq_positions[b] + 1; + + // Assume cache_K/cache_V is contiguous + const auto* cache_K_base = &cache_K[b][0][0][0]; + const auto* cache_V_base = &cache_V[b][0][0][0]; + constexpr bool USE_INT4 = std::is_same::value; + + // Only used for int4 + constexpr int32_t INT4_PARAM_BYTES = 4 * KVQuantNumGroups; + constexpr int32_t D_H_bytes = D_H / 2 + INT4_PARAM_BYTES; + constexpr int32_t INT4_GROUP_SIZE = D_H / KVQuantNumGroups; + + // Compute S[MAX_T] = for i in range(T): S[t] = sum(Q[d] * K[t, d]) + // Split T across warps in a block. 
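Before the Q @ K^T loop, each split-K block claims a contiguous slice of the context, as computed just below. The partitioning (round up to a multiple of num_split_ks, clamp to t_max, early-return on empty slices) can be mirrored in a few lines of Python as a sanity check; this is a reference sketch, not kernel code:

    def round_up(a, b):
        return (a + b - 1) // b * b

    def splitk_ranges(t_max, num_split_ks):
        # Contiguous [start, end) token range handled by each split-K block.
        t_total = round_up(t_max, num_split_ks)
        t_per_block = t_total // num_split_ks
        return [(s * t_per_block, min((s + 1) * t_per_block, t_max))
                for s in range(num_split_ks)]

    print(splitk_ranges(100, 8))  # [(0, 13), (13, 26), ..., (91, 100)]
    print(splitk_ranges(5, 8))    # the last three blocks have no work and return early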
+ const int32_t t_total = round_up(t_max, num_split_ks); + const int32_t t_per_block = t_total / num_split_ks; + const int32_t t_per_block_start = t_per_block * s_block; + const int32_t t_per_block_end = min(t_per_block * (s_block + 1), t_max); + const int32_t t_total_per_block = t_per_block_end - t_per_block_start; + + // Compute start and end heads + const int32_t h_per_block_start = kMaxHeads * h_block; + const int32_t h_per_block_end = min(kMaxHeads * (h_block + 1), H_max); + const int32_t h_total_per_block = h_per_block_end - h_per_block_start; + + // Early return if there is no work to do + if (t_total_per_block <= 0) { + return; + } + + using namespace nvcuda; + // Number of vectors for K and V (vector type in this case is uint32_t) + constexpr int KV_NUM_VECS = 2; + // Number of elements to load when using the kv_load_t type (kv_load_t is 32 + // bits for KVQuantNumGroups = 1 and 64 bits for KVQuantNumGroups = 4) + constexpr int KV_LD_NUM_ELS = + (KV_NUM_VECS * sizeof(uint32_t)) / sizeof(kv_load_t); + + wmma::fragment + q_frag; + wmma::fragment + k_frag; + wmma::fragment c_frag; + + // Get shared memory pointers + static_assert( + F_N >= F_K, "F_N must be >= F_K because we allocate smem based on F_N"); + const int ldc = round_up(t_total_per_block, F_N); + auto smem_max = smem + max(h_total_per_block, F_M) * ldc; + __nv_bfloat16* smem_staging = reinterpret_cast<__nv_bfloat16*>( + smem_max + max(h_total_per_block, F_M) * kSplitKWarpsPerBlock); + float* smem_out = reinterpret_cast( + smem_staging + + kSplitKWarpsPerBlock * max(F_N * SMEM_K_STRIDE, F_K * SMEM_V_STRIDE)); + constexpr float NEG_FINF = -std::numeric_limits::infinity(); + +#ifdef USE_WMMA_FRAG + // The kernel can compute max_qk directly from the WMMA fragment on A100/H100 + // Each thread handles 2 heads according to the tensor core layout + constexpr int HEADS_PER_THREAD_QK = 2; + float max_qk[HEADS_PER_THREAD_QK]; + max_qk[0] = NEG_FINF; + max_qk[1] = NEG_FINF; +#else + // TODO: Support computing max_qk from the WMMA fragment on other GPUs + float max_qk = NEG_FINF; +#endif + + // Compute Q @ K^T + for (auto t_start = t_per_block_start + warp_idx * F_N; + t_start < t_per_block_end; + t_start += kSplitKWarpsPerBlock * F_N) { + constexpr int32_t K_UNROLLS = 4; + __half2 k_scales; + kv_load_t k_vals[KV_LD_NUM_ELS * K_UNROLLS]; + + // Init the accumulator with zeros + wmma::fill_fragment(c_frag, 0.0f); + + // Intra-warp reduction within across D_H + for (auto d_start = 0; d_start < D_H; d_start += F_K) { + if (USE_INT4 && d_start % INT4_GROUP_SIZE == 0) { + // Load K scales for INT4 K + // Each thread operates on a single row (T dim). Columns are split into + // KVQuantNumGroups groups and each group has the same K scales + if (t_start + threadIdx.x < min(t_start + F_N, t_per_block_end)) { + auto* k_ = cache_K_base + (t_start + threadIdx.x) * D_H_bytes; + const int group_id = d_start / INT4_GROUP_SIZE; + k_scales = reinterpret_cast(k_)[group_id]; + } + } + + // Load Q fragment + wmma::load_matrix_sync( + q_frag, + reinterpret_cast( + &XQ[b][0][h_per_block_start][d_start]), + D_H); + + // Load K fragment + if (USE_INT4) { + // Load and dequantize INT4 K + // Each thread loads 16 columns (D dim) from one row (T dim). + // Each row is handled by one thread. + const auto t = t_start + threadIdx.x; + const auto t_scope = min(t_start + F_N, t_per_block_end); + + // Prefetch 4 sets of Ks (load every 4 d_starts) + if (d_start % (K_UNROLLS * F_K) == 0) { + // Since F_N = 32, each thread handles only one row (T dim). 
Thus a + // for-loop is not required + if (t < t_scope) { + // Ratio between the INT4 bytes and kv_load_t bytes + constexpr int KV_LOAD_T_INT4_RATIO = 2 * sizeof(kv_load_t); + const auto k_offset = + t * D_H_bytes + INT4_PARAM_BYTES + d_start / 2; + const auto* cache_k_ = + reinterpret_cast(cache_K_base + k_offset); +#pragma unroll K_UNROLLS + for (int k_unroll = 0; k_unroll < K_UNROLLS; k_unroll++) { + auto* k_vals_ = k_vals + k_unroll * KV_LD_NUM_ELS; + const auto* k_ = + cache_k_ + ((k_unroll * F_K) / KV_LOAD_T_INT4_RATIO); +#pragma unroll KV_LD_NUM_ELS + for (auto k_i = 0; k_i < KV_LD_NUM_ELS; k_i++) { + k_vals_[k_i] = k_[(k_i * 8) / KV_LOAD_T_INT4_RATIO]; + } + } + } + } + + if (t < t_scope) { + // Shift pointers + const auto k_offset = + ((d_start % (K_UNROLLS * F_K)) / F_K) * KV_NUM_VECS; + const auto smem_offset = + (warp_idx * F_N + t - t_start) * SMEM_K_STRIDE; + const auto* k_vals_ = reinterpret_cast(k_vals) + k_offset; + auto* smem_staging_ = smem_staging + smem_offset; +#pragma unroll KV_NUM_VECS + for (int vec = 0; vec < KV_NUM_VECS; ++vec) { + // Dequantize 8 INT4s to 8 BF16s and store the results in shared + // memory + const auto k_deq = dequantize_permuted_int4(k_vals_[vec], k_scales); + auto* smem_s = + reinterpret_cast<__nv_bfloat162*>(smem_staging_ + vec * 8); +#pragma unroll + for (int i = 0; i < 4; i++) { + smem_s[i] = k_deq.vals[i]; + } + } + } + // Load BF16 values to K fragment + wmma::load_matrix_sync( + k_frag, + smem_staging + warp_idx * F_N * SMEM_K_STRIDE, + SMEM_K_STRIDE); + } else if (t_start + F_N <= MAX_T) { + // Load BF16 K to K fragment + wmma::load_matrix_sync( + k_frag, + reinterpret_cast(cache_K_base) + + t_start * D_H + d_start, + D_H); + } else { + // Handle the remainder of T to avoid load_matrix_sync to K will OOB + // Load 8 bfloat16s at a time for 16B loads + constexpr int kThreadsPerF_K = F_K / 8; + for (int t = t_start + threadIdx.x / kThreadsPerF_K; + t < min(t_start + F_N, t_per_block_end); + t += kThreadsPerWarp / kThreadsPerF_K) { + const int d = d_start + threadIdx.x % kThreadsPerF_K * 8; + const auto smem_offset = + (warp_idx * F_N + t - t_start) * F_K + d - d_start; + *(reinterpret_cast(smem_staging + smem_offset)) = + *(reinterpret_cast(cache_K_base + t * D_H + d)); + } + // Load BF16 values to K fragment + wmma::load_matrix_sync( + k_frag, smem_staging + warp_idx * F_N * F_K, F_K); + } + // Compute matrix multiplication + wmma::mma_sync(c_frag, q_frag, k_frag, c_frag); + } + +#ifdef USE_WMMA_FRAG + // The following fragment (tensor core) layout is specific to the A100/H100 + // GPU Compute max_qk directly from the fragment + constexpr int C_FRAG_SIZE = F_M * F_N; + // A quadrant has 64 elements + constexpr int C_QUAD_SIZE = (F_M * F_N) / 4; + // A quadrant of a quadrant has 16 elements + constexpr int C_DOUBLE_QUAD_SIZE = C_QUAD_SIZE / 4; + // Half of a quadrant of a quadrant has 8 elements + constexpr int C_HALF_DOUBLE_QUAD_SIZE = C_DOUBLE_QUAD_SIZE / 2; + if (t_start < t_per_block_end) { + const auto max_col = min(t_start + F_N, t_per_block_end) - t_start; + // The column stride that each thread processes is 8 + // The number of threads processing each column is 4 + const int col_group = max_col / 8; + const int cols_in_group = max_col % 8; + const int max_elements = + threadIdx.x < cols_in_group * 4 ? 
(col_group + 1) * 2 : col_group * 2; + const int h_start = (threadIdx.x % 4) * 2; + + const int frag_offset = + static_cast((t_start - t_per_block_start) / F_N) * C_FRAG_SIZE; + const int doub_quad_offset = threadIdx.x % 4 * C_DOUBLE_QUAD_SIZE; + const int pos = threadIdx.x >> 2; + auto* smem_ = smem + frag_offset + doub_quad_offset + pos; + + for (auto i = 0; i < max_elements && i < c_frag.num_elements; i++) { + const int h_i = i % 2; + if (h_i < h_total_per_block - h_start) { + const auto qk = c_frag.x[i]; + const auto qk_acc = qk * qk_scale; + max_qk[h_i] = max(max_qk[h_i], qk_acc); + + const int quad_offset = (i >> 1) * C_QUAD_SIZE; + const int half_doub_quad_offset = (i % 2) * C_HALF_DOUBLE_QUAD_SIZE; + smem_[quad_offset + half_doub_quad_offset] = qk_acc; + } + } + } +#else + // Store matrix multiplication results to shared memory + wmma::store_matrix_sync( + smem + t_start - t_per_block_start, c_frag, ldc, wmma::mem_row_major); + + // Scale the results and compute max for each head from shared memory + const int nThreadsPerH = kThreadsPerWarp / h_total_per_block; + // Each thread computes only one head + const int h = threadIdx.x / nThreadsPerH; + if (h < h_total_per_block) { + for (int t = t_start + (threadIdx.x % nThreadsPerH); + t < min(t_start + F_N, t_per_block_end); + t += nThreadsPerH) { + const float qk_acc = smem[h * ldc + t - t_per_block_start] * qk_scale; + max_qk = max(max_qk, qk_acc); + } + } + + // Compute max within a warp + // threadIdx.x % nThreadsPerH == 0 are master threads + for (int offset = nThreadsPerH >> 1; offset >= 1; offset >>= 1) { + max_qk = max(max_qk, __shfl_down_sync(FINAL_MASK, max_qk, offset)); + } +#endif + } + +#ifdef USE_WMMA_FRAG + // At this point, every thread has their local max_qk's + // Compute max_qk within a warp +#pragma unroll HEADS_PER_THREAD_QK + for (auto h_i = 0; h_i < HEADS_PER_THREAD_QK; h_i++) { + for (auto offset = 4; offset < kThreadsPerWarp; offset <<= 1) { + max_qk[h_i] = + max(max_qk[h_i], + __shfl_sync(FINAL_MASK, max_qk[h_i], threadIdx.x + offset)); + } + } + + // Use shared reduction to compute max and compute softmax on shared memory. + // write max acc + if (threadIdx.x < 4) { + const auto h = threadIdx.x * 2; + if (t_per_block_start + warp_idx * F_N < t_per_block_end) { + smem_max[warp_idx * h_total_per_block + h] = max_qk[0]; + smem_max[warp_idx * h_total_per_block + h + 1] = max_qk[1]; + } else { + smem_max[warp_idx * h_total_per_block + h] = NEG_FINF; + smem_max[warp_idx * h_total_per_block + h + 1] = NEG_FINF; + } + } +#else + // Use shared reduction to compute max and compute softmax on shared memory. + // write max acc + const int max_qk_threads_per_h = kThreadsPerWarp / h_total_per_block; + if (threadIdx.x % max_qk_threads_per_h == 0) { + const auto h = threadIdx.x / max_qk_threads_per_h; + smem_max[warp_idx * h_total_per_block + h] = + (t_per_block_start + warp_idx * F_N < t_per_block_end) ? 
max_qk + : NEG_FINF; + } +#endif + + __syncthreads(); + + const auto h = threadIdx.x; + for (int w = kSplitKWarpsPerBlock >> 1; w >= 1; w >>= 1) { + if (warp_idx < w && h < h_total_per_block) { + smem_max[warp_idx * h_total_per_block + h] = + max(smem_max[warp_idx * h_total_per_block + h], + smem_max[(warp_idx + w) * h_total_per_block + h]); + } + __syncthreads(); + } + + const int hPerWarp = div_up(h_total_per_block, kSplitKWarpsPerBlock); + const int h_begin = warp_idx * hPerWarp; + const int h_end = min(h_begin + hPerWarp, h_total_per_block); + + // Complete max computation for each head at this point + const int threads_per_h = kThreadsPerWarp / (h_end - h_begin); + float head_sum = 0; +#ifdef USE_WMMA_FRAG + // A100/H100 GPU will only store max here and compute head sum later + // Only master thread sets the max metadata + if (threadIdx.x % threads_per_h == 0 && h_end > h_begin) { + const int h = h_begin + (threadIdx.x / threads_per_h); + const auto max_qk_ = smem_max[h]; + metadata[b][0][s_block][h_per_block_start + h] = max_qk_; + smem_max[h] = max_qk_; + } + __syncthreads(); + const auto max_qk_ = smem_max[threadIdx.x / 4]; +#else + // Non-A100/H100 GPUs will store both max and head sum here + if (h_begin + threadIdx.x / threads_per_h < h_end) { + const int h = h_begin + threadIdx.x / threads_per_h; + const auto max_qk_ = smem_max[h]; + auto* smem_ = smem + h * ldc; + for (int t = threadIdx.x % threads_per_h; t < t_total_per_block; + t += threads_per_h) { + const float p = __expf(smem_[t] * qk_scale - max_qk_); + // Compute the sum value for each head + head_sum += p; + smem_[t] = p; + } + } + // Compute sum within a warp + for (int offset = threads_per_h >> 1; offset >= 1; offset >>= 1) { + head_sum += __shfl_down_sync(FINAL_MASK, head_sum, offset); + } + + // Store max and sum to global memory + if (threadIdx.x % threads_per_h == 0 && h_end > h_begin) { + const int h = h_begin + (threadIdx.x / threads_per_h); + metadata[b][0][s_block][h_per_block_start + h] = smem_max[h]; + metadata[b][1][s_block][h_per_block_start + h] = head_sum; + } +#endif + + // Each thread loads two uint32_t's in each iteration + kv_load_t v_vals[KV_LD_NUM_ELS]; + __half2 v_scales; + + // Prefetch V + if (USE_INT4) { + const auto d_start = warp_idx * F_N; + const int t_chunk_id = threadIdx.x % 2; + const int group_id = d_start / INT4_GROUP_SIZE; + int t = t_per_block_start + threadIdx.x / 2; + if (t < min(t_per_block_start + F_K, t_per_block_end)) { + const auto* v_ = cache_V_base + t * D_H_bytes; + v_scales = reinterpret_cast(v_)[group_id]; +#pragma unroll KV_LD_NUM_ELS + for (int vec = 0; vec < KV_LD_NUM_ELS; vec++) { + int d = d_start + (vec + t_chunk_id * KV_NUM_VECS) * 8; + v_vals[vec] = + *reinterpret_cast(&v_[d / 2 + INT4_PARAM_BYTES]); + } + } + } + +#ifndef USE_WMMA_FRAG + // Non-A100/H100 GPUs convert P from FP32 to BF16 inplace (i.e., using the + // same shared memory space) here + constexpr int32_t CONV_UNROLLS = 4; + __nv_bfloat16* smem_bf16 = reinterpret_cast<__nv_bfloat16*>(smem); + float2 p[CONV_UNROLLS]; + const int t_stride = blockDim.x * blockDim.y * 2; + const int t_rounds = div_up(t_total_per_block, t_stride); + const int global_tid = warp_idx * blockDim.x + threadIdx.x; + + // Ensure that all threads finish writing to smem before modifying it in the + // loop below + __syncthreads(); + + // All threads work on the same head in every iteration + for (int t_i = 0; t_i < t_rounds; t_i++) { + const int t_start = t_i * t_stride + global_tid * 2; + const int global_t_start = 
t_per_block_start + t_start; + auto* smem_fp32_ = smem + t_start; + auto* smem_bf16_ = smem_bf16 + t_start; + + for (int h_i = 0; h_i < div_up(h_total_per_block, CONV_UNROLLS); h_i++) { + // Read FP32 +#pragma unroll + for (int h_j = 0; h_j < CONV_UNROLLS; h_j++) { + const int h = h_i * CONV_UNROLLS + h_j; + const int smem_idx = h * ldc; + + p[h_j].x = global_t_start < t_per_block_end ? smem_fp32_[smem_idx] : 0; + p[h_j].y = + global_t_start + 1 < t_per_block_end ? smem_fp32_[smem_idx + 1] : 0; + } + + // Sync threads to make sure that all threads finish reading data before + // overwriting the memory with new values + __syncthreads(); + + // Convert and write BF16 +#pragma unroll + for (int h_j = 0; h_j < CONV_UNROLLS; h_j++) { + // It is safe to use nv_bfloat162 because smem was float + if (global_t_start < t_per_block_end) { + const int h = h_i * CONV_UNROLLS + h_j; + const int smem_idx = h * ldc; + + *reinterpret_cast(&smem_bf16_[smem_idx]) = + __float22bfloat162_rn(p[h_j]); + } + } + } + } + __syncthreads(); + + // Fill smem with zeros for t_total <= t < F_K to avoid nan + // Round up to the nv_bfloat162 granularity because if t_total_per_block is + // an odd number, the FP32->BF16 conversion should already take care of + // writing zero to .y + const int t_zero_start = round_up(t_total_per_block, 2); + const nv_bfloat162 zero_bf162 = {0, 0}; + const int t_mul_F_K = round_up(t_total_per_block, F_K); + + for (auto h = warp_idx; h < h_total_per_block; h += blockDim.y) { + // Each thread operates on two BF16 values to avoid smem bank conflict + for (auto t = t_zero_start + threadIdx.x * 2; t < t_mul_F_K; + t += kThreadsPerWarp * 2) { + *reinterpret_cast(&smem_bf16[h * ldc + t]) = zero_bf162; + } + } + + __syncthreads(); +#endif + + // Split D_H across warps in a block + // each warp compute sum(t_subset) P[H, t] * V[t_subset, d_subset] + // outputs are of size float[H, D] + + wmma::fragment + v_frag; + + // Compute P @ V + // Parallelize D_H among warps. Note only 4 warps will do the work here. + for (auto d_start = warp_idx * F_N; d_start < D_H; + d_start += kSplitKWarpsPerBlock * F_N) { + // Init the accumulator with zeros + wmma::fill_fragment(c_frag, 0.0f); + + // Intra-warp reduction across T + for (auto t_start = t_per_block_start; t_start < t_per_block_end; + t_start += F_K) { +#ifdef USE_WMMA_FRAG + // A100/H100 GPU reads FP32 from shared memory, convert it into BF16, and + // writes data directly to the WMMA fragment. + const int head = threadIdx.x / 4; + constexpr int NUM_COL_VECS = 2; + constexpr int P_FRAG_SIZE = F_M * F_K; + constexpr int P_HALF_SIZE = P_FRAG_SIZE / 2; + const int frag_offset = + static_cast((t_start - t_per_block_start) / F_K) * P_FRAG_SIZE; + const int pos = threadIdx.x * 2; + const auto* smem_ = smem + frag_offset + pos; + const auto t_start_ = t_start + (threadIdx.x % 4) * 2; + const auto t_scope = min(t_start + F_K, t_per_block_end); + + for (auto vec = 0; vec < NUM_COL_VECS; vec++) { + float2 p; + const int t = t_start_ + 8 * vec; + if (head < h_total_per_block && t < t_scope) { + p = *(reinterpret_cast(&smem_[vec * P_HALF_SIZE])); + p.x = __expf(p.x - max_qk_); + p.y = t + 1 < t_scope ? __expf(p.y - max_qk_) : 0; + + // Compute head sum here + if (d_start == 0) { + head_sum += p.x + p.y; + } + } else { + p.x = 0; + p.y = 0; + } + + // TODO: store BF16 results in smem for D_H > 128 or F_N < 32 + // TODO: use vector store? 
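The zero-fill above (and the matching zero-fill on the V side further down) exists because the WMMA tiles always cover a full F_K = 16 steps of the contraction dimension; padding P columns and V rows with zeros beyond t_total leaves the product unchanged. A small NumPy illustration of that invariant, not kernel code:

    import numpy as np

    P = np.random.rand(8, 13).astype(np.float32)   # 8 heads, 13 valid tokens
    V = np.random.rand(13, 32).astype(np.float32)  # 13 tokens, one 32-wide D slice
    P_pad = np.pad(P, ((0, 0), (0, 3)))            # pad T up to a multiple of F_K
    V_pad = np.pad(V, ((0, 3), (0, 0)))            # padded rows must be zero, not garbage
    assert np.allclose(P @ V, P_pad @ V_pad)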
+ // FP32->BF16 conversion is implicit + q_frag.x[vec * 2] = p.x; + q_frag.x[vec * 2 + 1] = p.y; + } + __syncwarp(); +#else + // Non-A100/H100 GPUs already did the FP32->BF16 conversion before + // entering this loop. Thus, data can be loaded from shared memory + wmma::load_matrix_sync( + q_frag, smem_bf16 + t_start - t_per_block_start, ldc); +#endif + + // Load V fragment + if (USE_INT4) { + // Load and dequantize INT4 V + // Each thread loads 16 columns (D dim) from one row (T dim). + // Each row is handled by two threads + const auto t_scope = min(t_start + F_K, t_per_block_end); + const int t = t_start + threadIdx.x / 2; + const int t_chunk_id = threadIdx.x % 2; + if (t < t_scope) { + const auto smem_offset = + (warp_idx * F_K + t - t_start) * SMEM_V_STRIDE; + auto* smem_staging_ = smem_staging + smem_offset; +#pragma unroll KV_NUM_VECS + for (int vec = 0; vec < KV_NUM_VECS; ++vec) { + const int smem_d = vec + t_chunk_id * KV_NUM_VECS; + // Dequantize 8 INT4s to 8 BF16s and store the results in shared + // memory + const auto v_vals_ = reinterpret_cast(v_vals)[vec]; + const auto v_deq = dequantize_permuted_int4(v_vals_, v_scales); + auto* smem_s = + reinterpret_cast<__nv_bfloat162*>(smem_staging_ + smem_d * 8); +#pragma unroll + for (int i = 0; i < 4; i++) { + smem_s[i] = v_deq.vals[i]; + } + } + } else { + // Need to fill zeros to avoid nan + if (t < t_start + F_K) { + const auto smem_offset = + (warp_idx * F_K + t - t_start) * SMEM_V_STRIDE; + auto* smem_staging_ = smem_staging + smem_offset; +#pragma unroll KV_NUM_VECS + for (int vec = 0; vec < KV_NUM_VECS; ++vec) { + const int smem_d = vec + t_chunk_id * KV_NUM_VECS; + auto* smem_s = + reinterpret_cast(smem_staging_ + smem_d * 8); +#pragma unroll + for (int i = 0; i < 4; i++) { + smem_s[i] = 0; + } + } + } + } + + int t_start_next = t_start + F_K; + t_start_next = + t_start_next < t_per_block_end ? t_start_next : t_per_block_start; + const int d_start_next = t_start_next < t_per_block_end + ? 
d_start + : d_start + kSplitKWarpsPerBlock * F_N; + const int t_next = t_start_next + threadIdx.x / 2; + + if (t_next < min(t_start_next + F_K, t_per_block_end) && + d_start_next < D_H) { + auto* v_ = cache_V_base + t_next * D_H_bytes; + const auto group_id = d_start_next / INT4_GROUP_SIZE; + v_scales = reinterpret_cast(v_)[group_id]; +#pragma unroll KV_LD_NUM_ELS + for (int vec = 0; vec < KV_LD_NUM_ELS; vec++) { + const int d = d_start_next + (vec + t_chunk_id * KV_NUM_VECS) * 8; + v_vals[vec] = *reinterpret_cast( + &v_[d / 2 + INT4_PARAM_BYTES]); + } + } + // Load BF16 values to V fragment + wmma::load_matrix_sync( + v_frag, + smem_staging + warp_idx * SMEM_V_STRIDE * F_K, + SMEM_V_STRIDE); + } else if (t_start + F_K <= t_per_block_end) { + // Load BF16 V to V fragment + wmma::load_matrix_sync( + v_frag, + reinterpret_cast(cache_V_base) + + t_start * D_H + d_start, + D_H); + } else { + // Handle the remainder of T to avoid load_matrix_sync to V will OOB + int t = t_start; + const auto smem_offset = (warp_idx * F_K - t_start) * F_N - d_start; + auto* smem_staging_ = smem_staging + smem_offset; + for (; t < min(t_start + F_K, t_per_block_end); ++t) { + auto* smem_staging_t_ = smem_staging_ + t * F_N; + auto* v_ = + reinterpret_cast(cache_V_base) + t * D_H; + for (int d = d_start + threadIdx.x; d < d_start + F_N; + d += kThreadsPerWarp) { + smem_staging_t_[d] = v_[d]; + } + } + // Need to fill zeros to avoid nan + for (; t < t_start + F_K; ++t) { + auto* smem_staging_t_ = smem_staging_ + t * F_N; + for (int d = d_start + threadIdx.x; d < d_start + F_N; + d += kThreadsPerWarp) { + smem_staging_t_[d] = 0; + } + } + // Load BF16 values to V fragment + wmma::load_matrix_sync( + v_frag, smem_staging + warp_idx * F_N * F_K, F_N); + } + // Compute matrix multiplication + wmma::mma_sync(c_frag, q_frag, v_frag, c_frag); + } + + // Store final results in global memory + if (h_total_per_block == F_M) { + // For this common case, no need to worry about OOB. 
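For reference, the byte layout assumed by the INT4 loads above applies to both cache_K and cache_V: each token row stores 4 bytes of quantization parameters per group (an fp16 scale followed by an fp16 shift) and then D_H/2 bytes of packed nibbles. The small helper below reproduces the offsets used by the kernel; the function name is illustrative only:

    D_H = 128

    def int4_row_offsets(d, num_groups=1):
        # Byte offsets within one token row of the INT4 KV cache for dimension d.
        param_bytes = 4 * num_groups             # fp16 (scale, shift) per group
        group_size = D_H // num_groups
        qparam_offset = (d // group_size) * 4    # this dimension's group parameters
        data_offset = param_bytes + d // 2       # packed nibble pair holding element d
        row_bytes = param_bytes + D_H // 2
        return qparam_offset, data_offset, row_bytes

    print(int4_row_offsets(96, num_groups=4))  # (12, 64, 80); 80 matches D_H/2 + 4*num_groups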
+ auto* o_ = &out_splitK[b][s_block][h_per_block_start][d_start]; + wmma::store_matrix_sync(o_, c_frag, D_H, wmma::mem_row_major); + } else { + wmma::store_matrix_sync( + smem_out + F_M * d_start, c_frag, F_N, wmma::mem_row_major); + + for (int h = 0; h < h_total_per_block; ++h) { + // [B, H, num_split_ks, 1, D_H] + auto* o_ = &out_splitK[b][s_block][h_per_block_start + h][d_start]; + for (int d = threadIdx.x; d < F_N; d += kThreadsPerWarp) { + o_[d] = smem_out[F_M * d_start + h * F_N + d]; + } + } + } + } // d_start + +#ifdef USE_WMMA_FRAG + // A100/H100 GPU has to store head sum in global memory here because it + // computes this value during the P @ V computation + if (warp_idx == 0) { + for (int offset = 2; offset >= 1; offset >>= 1) { + head_sum += __shfl_sync(FINAL_MASK, head_sum, threadIdx.x + offset); + } + + const int head = threadIdx.x / 4; + if (threadIdx.x % 4 == 0 && head < h_total_per_block) { + metadata[b][1][s_block][h_per_block_start + head] = head_sum; + } + } +#endif +#endif +} + +__global__ void gqa_attn_splitk_reduce_wmma_kernel( + // {B, H, num_split_ks, D_H} + const at::PackedTensorAccessor32 + out_splitK, + // {B, H, 2, num_split_ks, 1}, + const at::PackedTensorAccessor32 metadata, + const at::PackedTensorAccessor32 + seq_positions, + // [B, 1, H, D] + at::PackedTensorAccessor32 O) { + const int32_t b = blockIdx.x; + const int32_t h = blockIdx.y; + const auto num_split_ks = out_splitK.size(1); + const auto d = threadIdx.y * kThreadsPerWarp + threadIdx.x; + + float m = metadata[b][0][0][h]; + float l_sum = metadata[b][1][0][h]; + float acc = out_splitK[b][0][h][d]; + + const int32_t t_max = seq_positions[b] + 1; + const int32_t t_total = round_up(t_max, num_split_ks); + const int32_t t_per_block = t_total / num_split_ks; + const int32_t num_split_ks_max = div_up(t_max, t_per_block); + + for (int k = 1; k < num_split_ks_max; ++k) { + float m_k = metadata[b][0][k][h]; + float l_k = metadata[b][1][k][h]; + float acc_k = out_splitK[b][k][h][d]; + + float m_new = max(m, m_k); + float alpha; + if (m_k < m) { + alpha = __expf(m_k - m_new); + acc_k *= alpha; + l_k *= alpha; + } else { + alpha = __expf(m - m_new); + acc *= alpha; + l_sum *= alpha; + } + + m = m_new; + l_sum += l_k; + acc += acc_k; + } + + O[b][0][h][d] = acc / l_sum; +} +} // namespace + +/// @ingroup experimental-gen-ai +/// +/// @brief Decoding Grouped Query Attention Split-K w/ BF16/INT4 KV +/// +/// The CUDA implementation of decoding Grouped Query Attention (GQA) +/// that supports BF16 and INT4 KV cache and BF16 input query. It +/// currently only supports the max context length of 16384, the fixed +/// head dimension of 128, and only one KV cache head. It supports an +/// arbitrary number of query heads. 
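gqa_attn_splitk_reduce_wmma_kernel above performs the usual flash-decoding style merge of per-partition partial results. A NumPy reference for one (batch, head, d) element, written as a sketch rather than a drop-in equivalent:

    import numpy as np

    def splitk_combine(accs, maxes, sums):
        # accs[k], maxes[k], sums[k] come from split-K partition k: the unnormalized
        # output, the running max of qk*scale, and sum(exp(qk*scale - max)).
        m, l, acc = maxes[0], sums[0], accs[0]
        for m_k, l_k, acc_k in zip(maxes[1:], sums[1:], accs[1:]):
            m_new = max(m, m_k)
            if m_k < m:
                alpha = np.exp(m_k - m_new)
                acc_k, l_k = acc_k * alpha, l_k * alpha
            else:
                alpha = np.exp(m - m_new)
                acc, l = acc * alpha, l * alpha
            m, l, acc = m_new, l + l_k, acc + acc_k
        return acc / l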
+/// +/// @param XQ Input query; shape = (B, 1, H_Q, D), where B = batch +/// size, H_Q = num query heads, D = head dimension (fixed +/// to 128) +/// @param cache_K K cache; shape = (B, MAX_T, H_KV, D), where MAX_T = +/// max context length (fixed to 16384), and H_KV = num +/// KV cache heads (fixed to 1) +/// @param cache_V V cache; shape = (B, MAX_T, H_KV, D) +/// @param seq_positions Sequence position (contains the actual +/// length of each token); shape = (B) +/// @param qk_scale The scale that is applied after QK^T +/// @param num_split_ks The number of split Ks (controlling the +/// amount of parallelism in the context length +/// dimension (MAX_T)) +/// @param num_int4_kv_groups The number of groups for group-wise INT4 +/// quantization for each KV token (each +/// group uses the same scale and bias for +/// quantization) +/// +/// @return A tuple of the combined split-K output, the +/// non-combined split-K output, and the split-K metadata +/// (containing max QK^T, and softmax(QK^T) head sum) +std::tuple gqa_attn_splitk_cuda( + const at::Tensor& XQ, + const at::Tensor& cache_K, + const at::Tensor& cache_V, + const at::Tensor& seq_positions, + const double qk_scale, + const int64_t num_split_ks, + const int64_t num_int4_kv_groups) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + TORCH_CHECK( + dprops->major >= 8, + "Too old compute capability major version to run gqa_attn_splitk_wmma ", + dprops->major); + + at::OptionalDeviceGuard guard(XQ.device()); + TORCH_CHECK(XQ.is_cuda()); + TORCH_CHECK(XQ.is_contiguous()); + TORCH_CHECK(cache_K.is_cuda()); + TORCH_CHECK(cache_V.is_cuda()); + TORCH_CHECK(cache_K.is_contiguous()); + TORCH_CHECK(cache_V.is_contiguous()); + TORCH_CHECK(seq_positions.is_cuda()); + + // Check input shapes + TORCH_CHECK(cache_K.size(1) <= MAX_T); + TORCH_CHECK( + cache_K.size(2) == 1, + "Currently gqa_attn_splitk only support num KV heads = 1"); + TORCH_CHECK( + cache_V.size(2) == 1, + "Currently gqa_attn_splitk only support num KV heads = 1"); + + if (cache_K.dtype() == at::kBFloat16) { + TORCH_CHECK(cache_K.size(3) == D_H); + } else { + TORCH_CHECK( + num_int4_kv_groups == 1 || num_int4_kv_groups == 4, + "Invalid num_int4_kv_groups ", + num_int4_kv_groups); + auto qparam_offset = 4 * num_int4_kv_groups; + TORCH_CHECK(cache_K.size(3) == D_H / 2 + qparam_offset); + } + + const auto B = XQ.size(0); + const auto H = XQ.size(2); + + auto out_splitK = + at::empty({B, num_split_ks, H, D_H}, XQ.options().dtype(at::kFloat)); + auto O = at::empty_like(XQ); + auto metadata = at::empty({B, 2, num_split_ks, H}, out_splitK.options()); + + // TODO: Check if the grid size is valid + const int32_t H_blocks = div_up(H, kMaxHeads); + dim3 blocks(B, H_blocks, num_split_ks); + dim3 threads(kThreadsPerWarp, kSplitKWarpsPerBlock); + + if (B == 0) { + return {O, out_splitK, metadata}; + } + + const int32_t t_per_block = div_up(cache_K.size(1), num_split_ks); + // This is called ldc inside gqa_attn_splitk_wmma_kernel kernel + const int32_t t_per_block_round_up = round_up(t_per_block, F_N); + + // QK^T and P smem: max(kMaxHeads, F_M) * t_per_block_round_up floats + // Max QK^T smem: max(kMaxHeads, F_M) * kSplitKWarpsPerBlock floats + // Stagging smem: smem_staging_size bfloat16s + // Output smem: max(kMaxHeads, F_M) * D_H floats + const int32_t smem_staging_size = kSplitKWarpsPerBlock * + max(F_N * SMEM_K_STRIDE, F_K * SMEM_V_STRIDE) * sizeof(at::BFloat16); + int32_t smem = max(kMaxHeads, F_M) * + (t_per_block_round_up + kSplitKWarpsPerBlock + D_H) * sizeof(float) + + 
smem_staging_size; + +#define CALL_GQA_ATTN_SPLITK_WMMA(CACHE_TYPE, NUM_GROUPS, KV_LOAD_T) \ + const auto gqa_fn = \ + gqa_attn_splitk_wmma_kernel; \ + if (smem > SMEM_ADJUST_THRESHOLD) { \ + set_gpu_max_dynamic_shared_memory(gqa_fn, smem, XQ.get_device()); \ + } \ + gqa_fn<<>>( \ + XQ.packed_accessor32(), \ + cache_K.packed_accessor64(), \ + cache_V.packed_accessor64(), \ + out_splitK.packed_accessor32(), \ + seq_positions.packed_accessor32(), \ + metadata.packed_accessor32(), \ + qk_scale); \ + C10_CUDA_KERNEL_LAUNCH_CHECK() + + if (cache_K.dtype() == at::kBFloat16) { + CALL_GQA_ATTN_SPLITK_WMMA(at::BFloat16, 1, uint32_t); + } else { + if (num_int4_kv_groups == 1) { + CALL_GQA_ATTN_SPLITK_WMMA(uint8_t, 1, uint32_t); + } else { + CALL_GQA_ATTN_SPLITK_WMMA(uint8_t, 4, uint2); + } + } + +#undef CALL_GQA_ATTN_SPLITK_WMMA + + gqa_attn_splitk_reduce_wmma_kernel<<< + dim3(B, H), + dim3(kThreadsPerWarp, D_H / kThreadsPerWarp), + 0, + at::cuda::getCurrentCUDAStream()>>>( + out_splitK.packed_accessor32(), + metadata.packed_accessor32(), + seq_positions.packed_accessor32(), + O.packed_accessor32()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return {O, out_splitK, metadata}; +} + +} // namespace fbgemm_gpu::gen_ai::attention diff --git a/fbgemm_gpu/experimental/gen_ai/tests/attention/gqa_test.py b/fbgemm_gpu/experimental/gen_ai/tests/attention/gqa_test.py new file mode 100755 index 0000000000..4a77594ffc --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/tests/attention/gqa_test.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import List, Tuple + +import hypothesis.strategies as st +import numpy as np +import torch + +from hypothesis import given, settings, Verbosity + +VERBOSITY: Verbosity = Verbosity.verbose + + +def quant_int4_dequant_bf16( + in_tensor: torch.Tensor, num_groups: int = 1 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + A util function for quantizing a tensor from from a float type (including + FP32, FP16, BF16) to INT4 and then dequantize the INT4 result to BF16 + (i.e., fake quantization) + """ + in_shape = in_tensor.shape + in_tensor = in_tensor.reshape( + *in_shape[:-1], num_groups, in_shape[-1] // num_groups + ) + + # Find max and min for each group + max_vals = torch.max(in_tensor, dim=-1, keepdim=True).values + min_vals = torch.min(in_tensor, dim=-1, keepdim=True).values + + # Compute scale and shift + scale: torch.Tensor = (max_vals - min_vals) / 15 + shift = torch.min(in_tensor, dim=-1, keepdim=True).values + scale = scale.to(torch.float16) + shift = shift.to(torch.float16) + shift_expand = shift.expand(in_tensor.shape) + scale_expand = scale.expand(in_tensor.shape) + + # Scale and shift + in_bytes = ((in_tensor - shift_expand) / scale_expand).to(torch.uint8) + + # Get only 4 bits + in_int4 = in_bytes & 0xF + + # Pack int4 in uint8 + in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) + + # Concat data + scale_shift = torch.concat( + [scale.view(torch.uint8), shift.view(torch.uint8)], dim=-1 + ) + in_quant = torch.concat( + [ + scale_shift.flatten(start_dim=-2), + in_int4_packed.flatten(start_dim=-2), + ], + dim=-1, + ) + + # Dequantize tensor for reference + in_fp16 = in_int4.to(torch.float16) + + # Convert type based on the CUDA implementation + in_quant_dequant_fp16 = (in_fp16 * scale_expand) + shift_expand + 
in_quant_dequant_fp32 = in_quant_dequant_fp16.to(torch.float) + in_quant_dequant_bf16 = in_quant_dequant_fp32.to(torch.bfloat16) + + return in_quant, in_quant_dequant_bf16.view(*in_shape) + + +def gqa_reference( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + seq_lens: List[int], + qk_scale: float, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + The reference GQA implementation + """ + (B, T, H, D) = Q.shape + (_, MAX_T, Hk, D) = K.shape + (_, MAX_T, Hv, D) = V.shape + assert T == 1 + assert Hk == Hv + Y = torch.zeros_like(Q) + attn_out = torch.zeros(B, H, MAX_T) + for b in range(B): + max_t = seq_lens[b] + for h in range(H): + # now, compute fused attention + Q_ = Q[b, 0, h, :] + assert Q_.shape == (D,) + K_ = K[b, :max_t, 0, :] + assert K_.shape == (max_t, D) + S = (Q_.view(1, D) @ K_.T) * qk_scale # 1.0 / np.sqrt(D) + # max_qk_acc = torch.max(S) + # softmax_denominator = torch.exp(S - max_qk_acc).sum() + assert S.shape == (1, max_t) + P = torch.nn.functional.softmax(S, dim=-1) + + assert P.shape == (1, max_t) + + V_ = V[b, :max_t, 0, :] + assert V_.shape == (max_t, D) + O_ = P.view(1, max_t) @ V_ + assert O_.shape == (1, D) + Y[b, 0, h, :] = O_ + attn_out[b, h, :max_t] = P + return Y, attn_out + + +class Int4GQATest(unittest.TestCase): + @unittest.skipIf( + not torch.version.cuda, + "Skip when CUDA is not available", + ) + @settings(verbosity=VERBOSITY, max_examples=40, deadline=None) + # pyre-ignore + @given( + int4_kv=st.booleans(), + num_groups=st.sampled_from([1, 4]), + B=st.integers(min_value=1, max_value=128), + MAX_T=st.integers(min_value=4, max_value=128), + N_H_L=st.integers(min_value=1, max_value=128), + ) + def test_gqa( + self, + int4_kv: bool, + num_groups: int, + B: int, + MAX_T: int, + N_H_L: int, + ) -> None: + """ + Test correctness of torch.ops.fbgemm.gqa_attn_splitk against the + reference GQA implementation (testing both BF16 and INT4 KV caches) + """ + + # Constants + D_H = 128 + N_KVH_L = 1 # gqa_attn_splitk only supports 1 currently + SEQ_POSITION = MAX_T - 2 + + seq_positions = torch.tensor( + [SEQ_POSITION for _ in range(B)], device="cuda" + ).int() + kv_seqlens = [seq_position + 1 for seq_position in seq_positions] + q = torch.randn((B, 1, N_H_L, D_H), dtype=torch.bfloat16, device="cuda") + + # Generate KV cache + cache_k = torch.randn( + (B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda" + ) + cache_v = torch.randn_like(cache_k) + if int4_kv: + cache_k, cache_k_ref = quant_int4_dequant_bf16(cache_k, num_groups) + cache_v, cache_v_ref = quant_int4_dequant_bf16(cache_v, num_groups) + cache_k_ref = cache_k_ref.cpu().float() + cache_v_ref = cache_v_ref.cpu().float() + else: + cache_k_ref = cache_k.cpu().float() + cache_v_ref = cache_v.cpu().float() + + # Compute qk_scale + qk_scale = 1.0 / np.sqrt(D_H) + + # Run reference + z_ref, attn_out_ref = gqa_reference( + q.cpu().float(), + cache_k_ref, + cache_v_ref, + kv_seqlens, + qk_scale=qk_scale, + ) + + # Run test + for split_k in [1, 2, 4, 8, 13, 16]: + z, _, _ = torch.ops.fbgemm.gqa_attn_splitk( + q, + cache_k, + cache_v, + seq_positions, + qk_scale=qk_scale, + num_split_ks=split_k, + num_int4_kv_groups=num_groups, + ) + torch.testing.assert_close( + z.cpu().bfloat16(), + z_ref.cpu().bfloat16(), + atol=2.0e-2, + rtol=6.0e-3, + ) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index d9e4decc35..49d48a5c60 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -46,7 +46,7 @@ 
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") import torch.utils._pytree as pytree -from torch import Tensor +from torch import SymInt, Tensor if hasattr(torch.library, "impl_abstract"): @@ -273,6 +273,8 @@ def check_all_same_device(*tensors: Optional[Tensor]) -> None: continue if first_tensor is None: first_tensor = tensor + if first_tensor.device.type == "cpu" and tensor.device.type == "cpu": + return torch._check(tensor.device == first_tensor.device) @@ -556,3 +558,21 @@ def keyed_jagged_index_select_dim1_abstract( ret.append(weights.new_empty([selected_lengths_sum])) return ret + + +@impl_abstract("fbgemm::bounds_check_indices") +def bounds_check_indices_abstract( + rows_per_table: torch.Tensor, + indices: torch.Tensor, + offsets: torch.Tensor, + bounds_check_mode_int: int, + bounds_check_warning: torch.Tensor, + per_sample_weights: Optional[torch.Tensor] = None, + B_offsets: Optional[torch.Tensor] = None, + max_B: Optional[SymInt] = None, +) -> None: + """ + This meta function is used to fake the bounds checking + from the original function `fbgemm::bounds_check_indices` + """ + return diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py index 21f94fab61..ae102d8e78 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py +++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py @@ -334,6 +334,8 @@ def generate_indices_zipf( assert E >= L, "num-embeddings must be greater than equal to bag-size" # oversample and then remove duplicates to obtain sampling without # replacement + if L == 0: + return torch.empty(iters, 0, dtype=torch.int).to(get_device()) total_B = sum(Bs) zipf_shape = (iters, total_B, zipf_oversample_ratio * L) if torch.cuda.is_available(): @@ -390,7 +392,7 @@ def update_indices_with_random_reuse( ] reused_indices += B_offset indices[it + 1, reused_indices] = indices[it, reused_indices] - B_offset += B + B_offset += B * L return indices diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index 936c3261bd..4d29b01a54 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -123,14 +123,28 @@ def nbit_construct_split_state( ) -def random_quant_scaled_tensor(shape: torch.Size, device: torch.device) -> torch.Tensor: - return torch.randint( - 0, - 255, - size=shape, - dtype=torch.uint8, - device=device, - ) +def random_quant_scaled_tensor( + shape: torch.Size, + device: torch.device, + output_tensor: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if output_tensor is not None: + return torch.randint( + 0, + 255, + size=shape, + out=output_tensor, + dtype=torch.uint8, + device=device, + ) + else: + return torch.randint( + 0, + 255, + size=shape, + dtype=torch.uint8, + device=device, + ) # pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized. @@ -1399,15 +1413,14 @@ def initialize_weights(self) -> None: def fill_random_weights(self) -> None: """ Fill the buffer with random weights, table by table - FIXME: make it in-place fill. 
""" self.initialize_weights() weights = self.split_embedding_weights() for dest_weight in weights: - dest_weight[0].copy_( - random_quant_scaled_tensor( - shape=dest_weight[0].shape, device=self.current_device - ) + random_quant_scaled_tensor( + shape=dest_weight[0].shape, + device=self.current_device, + output_tensor=dest_weight[0], ) def assign_embedding_weights( diff --git a/fbgemm_gpu/codegen/embedding_forward_split_cpu.h b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_split_cpu.h similarity index 100% rename from fbgemm_gpu/codegen/embedding_forward_split_cpu.h rename to fbgemm_gpu/include/fbgemm_gpu/embedding_forward_split_cpu.h diff --git a/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh similarity index 100% rename from fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh rename to fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_op_registration.h b/fbgemm_gpu/include/fbgemm_gpu/embedding_op_registration.h index 1f28843ef4..27b4bb2f95 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_op_registration.h +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_op_registration.h @@ -17,7 +17,8 @@ static __inline void __attribute__(( __gnu_inline__, __always_inline__, __artificial__, - __target__("serialize"))) __builtin_ia32_serialize(void) { + __target__("serialize"))) +__builtin_ia32_serialize(void) { abort(); } #endif diff --git a/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h b/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h index 80d7bcc3f0..9a70b95d7c 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h @@ -42,8 +42,8 @@ namespace fbgemm_gpu { #define FBGEMM_GPU_ENUM_REGISTER_END ); #define FBGEMM_GPU_ENUM_OP(module_name, op_name) \ -#op_name "() -> ((str, (str, int)[])[])", \ - TORCH_FN(enum_query ) + #op_name "() -> ((str, (str, int)[])[])", \ + TORCH_FN(enum_query) // To work around (escape from) hipify_torch, the names of the idendifiers // are decoposed to `x` and `y`. `z` is supposed to be hipified. 
#define FBGEMM_GPU_ENUM_ITEM(x, y, z) \ diff --git a/fbgemm_gpu/include/fbgemm_gpu/ops_utils.h b/fbgemm_gpu/include/fbgemm_gpu/ops_utils.h index 45b6e71172..7812bb2627 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/ops_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/ops_utils.h @@ -17,7 +17,8 @@ static __inline void __attribute__(( __gnu_inline__, __always_inline__, __artificial__, - __target__("serialize"))) __builtin_ia32_serialize(void) { + __target__("serialize"))) +__builtin_ia32_serialize(void) { abort(); } #endif diff --git a/fbgemm_gpu/requirements.txt b/fbgemm_gpu/requirements.txt index f85b60da58..29b7f9d5e8 100644 --- a/fbgemm_gpu/requirements.txt +++ b/fbgemm_gpu/requirements.txt @@ -10,6 +10,7 @@ # * https://github.com/nod-ai/SHARK/issues/2095 # * https://github.com/jianyicheng/mase-docker/pull/9 +build cmake hypothesis jinja2 @@ -17,5 +18,6 @@ mpmath==1.3.0 ninja numpy scikit-build +setuptools setuptools_git_versioning tabulate diff --git a/fbgemm_gpu/setup.py b/fbgemm_gpu/setup.py index 9ea33db7ef..792c6d1cca 100644 --- a/fbgemm_gpu/setup.py +++ b/fbgemm_gpu/setup.py @@ -118,7 +118,17 @@ def package_name(self) -> str: def variant_version(self) -> str: pkg_vver: str = "" - if self.nova_flag() is None: + if "egg_info" in self.other_args: + # If build is invoked through `python -m build` instead of + # `python setup.py`, this script is invoked twice, once as + # `setup.py egg_info`, and once as `setup.py bdist_wheel`. + # Ignore determining the variant_version for the first case. + print( + "[SETUP.PY] Script was invoked as `setup.py egg_info`, ignoring variant_version" + ) + return pkg_vver + + elif self.nova_flag() is None: # If not running in a Nova workflow, then use the # `fbgemm_gpu-` naming convention for the package, since # PyPI does not accept version+xx in the naming convention. @@ -134,7 +144,7 @@ def variant_version(self) -> str: pkg_vver = f"+cu{cuda_version[0]}{cuda_version[1]}" else: sys.exit( - "[SETUP.PY] Installed PyTorch variant is not CUDA; cannot determine the CUDA version!" + "[SETUP.PY] The installed PyTorch variant is not CUDA; cannot determine the CUDA version!" ) elif self.args.package_variant == "rocm": @@ -143,7 +153,7 @@ def variant_version(self) -> str: pkg_vver = f"+rocm{rocm_version[0]}.{rocm_version[1]}" else: sys.exit( - "[SETUP.PY] Installed PyTorch variant is not ROCm; cannot determine the ROCm version!" + "[SETUP.PY] The installed PyTorch variant is not ROCm; cannot determine the ROCm version!" 
) else: @@ -485,5 +495,5 @@ def main(argv: List[str]) -> None: if __name__ == "__main__": - print(f"[SETUP.PY] {sys.argv}") + print(f"[SETUP.PY] ARGV: {sys.argv}") main(sys.argv[1:]) diff --git a/fbgemm_gpu/src/input_combine_ops/input_combine.cu b/fbgemm_gpu/src/input_combine_ops/input_combine.cu index fce2ed41e0..5c57b94560 100644 --- a/fbgemm_gpu/src/input_combine_ops/input_combine.cu +++ b/fbgemm_gpu/src/input_combine_ops/input_combine.cu @@ -71,7 +71,11 @@ __launch_bounds__(kMaxThreads) void tbe_input_combine_with_length_kernel( : vec_copy_with_implicit_type_cast< int32_t, int32_t, - VEC_WIDTH>)(combined_indices, indices_addrs[list_id], src_idx, indices_start + src_idx, indices_end - indices_start); + VEC_WIDTH>)(combined_indices, + indices_addrs[list_id], + src_idx, + indices_start + src_idx, + indices_end - indices_start); // Invoke a function based on the lengths type ((lengths_is_long[is_long_idx] & is_long_mask) @@ -79,7 +83,11 @@ __launch_bounds__(kMaxThreads) void tbe_input_combine_with_length_kernel( : vec_copy_with_implicit_type_cast< int32_t, int32_t, - VEC_WIDTH>)(combined_lengths, lengths_addrs[list_id], src_idx, lengths_start + src_idx, lengths_end - lengths_start); + VEC_WIDTH>)(combined_lengths, + lengths_addrs[list_id], + src_idx, + lengths_start + src_idx, + lengths_end - lengths_start); if (per_sample_weights_addrs) { vec_copy_with_implicit_type_cast( diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp index 57a0d43761..2e08efb4d0 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp @@ -604,7 +604,7 @@ class JaggedIndexSelect2dOp const Tensor& values, const Tensor& lengths, const Tensor& indices, - const c10::optional num_dense_output_rows) { + const c10::optional optional_num_dense_output_rows) { TORCH_CHECK( values.dim() == 2, "jagged_index_select supports only 2D inputs") TENSORS_ON_SAME_DEVICE(lengths, indices); @@ -616,7 +616,6 @@ class JaggedIndexSelect2dOp ctx->save_for_backward({indices, output_offsets, input_offsets}); ctx->saved_data["num_input_rows"] = values.sym_size(0); - ctx->saved_data["num_dense_output_rows"] = num_dense_output_rows; static auto op = c10::Dispatcher::singleton() @@ -628,14 +627,17 @@ class JaggedIndexSelect2dOp const Tensor& output_offsets, const c10::optional)>(); - return { - op.call( - values, - indices, - input_offsets, - output_offsets, - num_dense_output_rows), - output_lengths}; + auto out = op.call( + values, + indices, + input_offsets, + output_offsets, + optional_num_dense_output_rows); + + // Always save output size to avoid triggering D2H sync in backward + ctx->saved_data["num_dense_output_rows"] = out.sym_size(0); + + return {out, output_lengths}; } static torch::autograd::variable_list backward( @@ -654,7 +656,7 @@ class JaggedIndexSelect2dOp auto num_output_rows = ctx->saved_data["num_input_rows"].toSymInt(); auto num_dense_input_rows = - ctx->saved_data["num_dense_output_rows"].toOptional(); + ctx->saved_data["num_dense_output_rows"].toSymInt(); static auto op = c10::Dispatcher::singleton() @@ -665,7 +667,7 @@ class JaggedIndexSelect2dOp const Tensor& input_offsets, const Tensor& output_offsets, c10::SymInt num_output_rows, - const c10::optional optional_num_dense_input_rows)>(); + c10::SymInt num_dense_input_rows)>(); return { op.call( diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp 
b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp index 4ba75765ab..3b0a41180e 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp @@ -1161,13 +1161,7 @@ Tensor jagged_index_add_2d_forward_v2_impl( const Tensor& input_offsets, const Tensor& output_offsets, const int64_t num_output_rows, - const c10::optional optional_num_dense_input_rows) { - // Intentionally not using optional::value_or here to avoid materializing - // .item() call when possible. - int64_t num_dense_input_rows = optional_num_dense_input_rows.has_value() - ? optional_num_dense_input_rows.value() - : input_offsets[input_offsets.numel() - 1].item(); - + const int64_t num_dense_input_rows) { static auto v1_op = c10::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::jagged_index_add_2d_forward", "") @@ -1681,7 +1675,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "jagged_index_add_2d_forward(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, int num_dense_input_rows, int num_output_rows) -> Tensor"); m.def( - "jagged_index_add_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, SymInt num_output_rows, int? num_dense_input_rows) -> Tensor", + "jagged_index_add_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, SymInt num_output_rows, SymInt num_dense_input_rows) -> Tensor", {PT2_COMPLIANT_TAG}); m.def( "jagged_1d_to_truncated_values(Tensor values, Tensor lengths, int max_truncated_length) -> Tensor"); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu b/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu index 2881dc7447..a4652a353a 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu @@ -34,6 +34,33 @@ __launch_bounds__(kMaxThreads) void _populate_length_to_feature_id_inplace_kerne length_to_feature_idx[batch_size_offsets[t] + b] = t; } +template +void increase_gpu_max_dynamic_shared_memory(func_t kernel, const int device) { + // V100: 96 KB; A100: 160 KB; H100: 228 KB. + int max_shared_bytes = 0; + C10_CUDA_CHECK(cudaDeviceGetAttribute( + &max_shared_bytes, +#ifndef __HIP_PLATFORM_AMD__ + cudaDevAttrMaxSharedMemoryPerBlockOptin, +#else + hipDeviceAttributeMaxSharedMemoryPerBlock, +#endif + device)); + + int shared_kb = max_shared_bytes >> 10; + // Use 2/3 of the available GPU shared mem; leave rooms for L1$. + // V100: 64 KB; A100: 96 KB; H100: 144 KB + int used_shared_kb = round_down(shared_kb * 2 / 3, 16); + TORCH_CHECK(used_shared_kb > 0); + + int used_shared_bytes = used_shared_kb << 10; + C10_CUDA_CHECK(cudaFuncSetAttribute( + (void*)kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + used_shared_bytes)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + // Kernel for bucketize lengths, with the Block distribution (vs. cyclic, // block-cyclic distribution). 
Used for bucketize sparse feature, especially for // checkpointing with row-wise partition (sparse_feature is partitioned @@ -365,6 +392,8 @@ block_bucketize_sparse_features_cuda( static_assert(kMaxThreads % kWarpSize == 0); const dim3 block_dims(kWarpSize, kMaxThreads / kWarpSize); const dim3 grid_dims(cuda_calc_xblock_count(lengths_size, block_dims.y)); + const auto smem_adjust_threshold = + at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock; AT_DISPATCH_INDEX_TYPES( offsets_contig.scalar_type(), "_block_bucketize_sparse_features_cuda_kernel1", @@ -621,6 +650,7 @@ block_bucketize_sparse_features_cuda( }); } } else { + auto smem_size = my_size * block_dims.y * sizeof(uint64_t); if (weights.has_value() & bucketize_pos) { Tensor weights_value = weights.value(); auto weights_value_contig = weights_value.contiguous(); @@ -639,41 +669,47 @@ block_bucketize_sparse_features_cuda( weights_value.scalar_type(), "_block_bucketize_pooled_sparse_features_cuda_kernel2_3", [&] { - _block_bucketize_pooled_sparse_features_cuda_kernel2< - true, - true, - offset_t, - index_t, - scalar_t> - <<>>( - lengths_size, - B, - block_sizes.data_ptr(), - my_size, - offsets_contig.data_ptr(), - indices_contig.data_ptr(), - weights_value_contig.data_ptr(), - new_offsets.data_ptr(), - new_indices.data_ptr(), - new_weights.data_ptr(), - new_pos.data_ptr(), - batch_size_per_feature.has_value() - ? length_to_feature_idx.data_ptr() - : static_cast(nullptr), - block_bucketize_pos.has_value() - ? block_bucketize_pos_concat - .data_ptr() - : static_cast(nullptr), - block_bucketize_pos.has_value() - ? block_bucketize_pos_offsets - .data_ptr() - : static_cast(nullptr), - block_bucketize_pos.has_value() - ? indices_to_lb.data_ptr() - : static_cast(nullptr)); + const auto block_bucketize_kernel = + _block_bucketize_pooled_sparse_features_cuda_kernel2< + true, + true, + offset_t, + index_t, + scalar_t>; + if (smem_size > smem_adjust_threshold) { + increase_gpu_max_dynamic_shared_memory( + block_bucketize_kernel, lengths.get_device()); + } + block_bucketize_kernel<<< + grid_dims, + block_dims, + smem_size, + at::cuda::getCurrentCUDAStream()>>>( + lengths_size, + B, + block_sizes.data_ptr(), + my_size, + offsets_contig.data_ptr(), + indices_contig.data_ptr(), + weights_value_contig.data_ptr(), + new_offsets.data_ptr(), + new_indices.data_ptr(), + new_weights.data_ptr(), + new_pos.data_ptr(), + batch_size_per_feature.has_value() + ? length_to_feature_idx.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_concat + .data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_offsets + .data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? indices_to_lb.data_ptr() + : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); @@ -695,41 +731,47 @@ block_bucketize_sparse_features_cuda( weights_value.scalar_type(), "_block_bucketize_pooled_sparse_features_cuda_kernel2_3", [&] { - _block_bucketize_pooled_sparse_features_cuda_kernel2< - true, - false, - offset_t, - index_t, - scalar_t> - <<>>( - lengths_size, - B, - block_sizes.data_ptr(), - my_size, - offsets_contig.data_ptr(), - indices_contig.data_ptr(), - weights_value_contig.data_ptr(), - new_offsets.data_ptr(), - new_indices.data_ptr(), - new_weights.data_ptr(), - nullptr, - batch_size_per_feature.has_value() - ? length_to_feature_idx.data_ptr() - : static_cast(nullptr), - block_bucketize_pos.has_value() - ? 
block_bucketize_pos_concat - .data_ptr() - : static_cast(nullptr), - block_bucketize_pos.has_value() - ? block_bucketize_pos_offsets - .data_ptr() - : static_cast(nullptr), - block_bucketize_pos.has_value() - ? indices_to_lb.data_ptr() - : static_cast(nullptr)); + const auto block_bucketize_kernel = + _block_bucketize_pooled_sparse_features_cuda_kernel2< + true, + false, + offset_t, + index_t, + scalar_t>; + if (smem_size > smem_adjust_threshold) { + increase_gpu_max_dynamic_shared_memory( + block_bucketize_kernel, lengths.get_device()); + } + block_bucketize_kernel<<< + grid_dims, + block_dims, + smem_size, + at::cuda::getCurrentCUDAStream()>>>( + lengths_size, + B, + block_sizes.data_ptr(), + my_size, + offsets_contig.data_ptr(), + indices_contig.data_ptr(), + weights_value_contig.data_ptr(), + new_offsets.data_ptr(), + new_indices.data_ptr(), + new_weights.data_ptr(), + nullptr, + batch_size_per_feature.has_value() + ? length_to_feature_idx.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_concat + .data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_offsets + .data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? indices_to_lb.data_ptr() + : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); @@ -745,39 +787,45 @@ block_bucketize_sparse_features_cuda( indices_contig.scalar_type(), "_block_bucketize_pooled_sparse_features_cuda_kernel2_2", [&] { - _block_bucketize_pooled_sparse_features_cuda_kernel2< - false, - true, - offset_t, - index_t, - std::nullptr_t> - <<>>( - lengths_size, - B, - block_sizes.data_ptr(), - my_size, - offsets_contig.data_ptr(), - indices_contig.data_ptr(), - nullptr, - new_offsets.data_ptr(), - new_indices.data_ptr(), - nullptr, - new_pos.data_ptr(), - batch_size_per_feature.has_value() - ? length_to_feature_idx.data_ptr() - : static_cast(nullptr), - block_bucketize_pos.has_value() - ? block_bucketize_pos_concat.data_ptr() - : static_cast(nullptr), - block_bucketize_pos.has_value() - ? block_bucketize_pos_offsets.data_ptr() - : static_cast(nullptr), - block_bucketize_pos.has_value() - ? indices_to_lb.data_ptr() - : static_cast(nullptr)); + const auto block_bucketize_kernel = + _block_bucketize_pooled_sparse_features_cuda_kernel2< + false, + true, + offset_t, + index_t, + std::nullptr_t>; + if (smem_size > smem_adjust_threshold) { + increase_gpu_max_dynamic_shared_memory( + block_bucketize_kernel, lengths.get_device()); + } + block_bucketize_kernel<<< + grid_dims, + block_dims, + smem_size, + at::cuda::getCurrentCUDAStream()>>>( + lengths_size, + B, + block_sizes.data_ptr(), + my_size, + offsets_contig.data_ptr(), + indices_contig.data_ptr(), + nullptr, + new_offsets.data_ptr(), + new_indices.data_ptr(), + nullptr, + new_pos.data_ptr(), + batch_size_per_feature.has_value() + ? length_to_feature_idx.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_concat.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_offsets.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? 
indices_to_lb.data_ptr() + : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); @@ -791,39 +839,45 @@ block_bucketize_sparse_features_cuda( indices_contig.scalar_type(), "_block_bucketize_pooled_sparse_features_cuda_kernel2_2", [&] { - _block_bucketize_pooled_sparse_features_cuda_kernel2< - false, - false, - offset_t, - index_t, - std::nullptr_t> - <<>>( - lengths_size, - B, - block_sizes.data_ptr(), - my_size, - offsets_contig.data_ptr(), - indices_contig.data_ptr(), - nullptr, - new_offsets.data_ptr(), - new_indices.data_ptr(), - nullptr, - nullptr, - batch_size_per_feature.has_value() - ? length_to_feature_idx.data_ptr() - : static_cast(nullptr), - block_bucketize_pos.has_value() - ? block_bucketize_pos_concat.data_ptr() - : static_cast(nullptr), - block_bucketize_pos.has_value() - ? block_bucketize_pos_offsets.data_ptr() - : static_cast(nullptr), - block_bucketize_pos.has_value() - ? indices_to_lb.data_ptr() - : static_cast(nullptr)); + const auto block_bucketize_kernel = + _block_bucketize_pooled_sparse_features_cuda_kernel2< + false, + false, + offset_t, + index_t, + std::nullptr_t>; + if (smem_size > smem_adjust_threshold) { + increase_gpu_max_dynamic_shared_memory( + block_bucketize_kernel, lengths.get_device()); + } + block_bucketize_kernel<<< + grid_dims, + block_dims, + smem_size, + at::cuda::getCurrentCUDAStream()>>>( + lengths_size, + B, + block_sizes.data_ptr(), + my_size, + offsets_contig.data_ptr(), + indices_contig.data_ptr(), + nullptr, + new_offsets.data_ptr(), + new_indices.data_ptr(), + nullptr, + nullptr, + batch_size_per_feature.has_value() + ? length_to_feature_idx.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_concat.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_offsets.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? indices_to_lb.data_ptr() + : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu index a665e0ee46..ec284aebd9 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu @@ -135,6 +135,7 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel( evicted_indices, at::PackedTensorAccessor32 actions_count, TORCH_DSA_KERNEL_ARGS) { + // Number of cache sets const int32_t C = lxu_cache_state.size(0); const int32_t N = sorted_cache_sets.size(0); @@ -144,16 +145,33 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel( } const int32_t cache_set = sorted_cache_sets[n]; + + // Set actions_count. It is basically the sum of all SLs. 
Since cache sets
+  // are sorted in sorted_cache_sets, we can count the number of elements that
+  // are not C by finding the position of the last cache set that is not C
+  if (threadIdx.x == 0) {
+    // Zero cache misses (the first sorted_cache_sets is C) or
+    // some cache misses (some sorted_cache_sets are C)
+    if (cache_set == C && (n == 0 || sorted_cache_sets[n - 1] != C)) {
+      actions_count[0] = n;
+    }
+    // All cache misses (none of sorted_cache_sets is C)
+    else if (n == N - 1 && cache_set != C) {
+      actions_count[0] = N;
+    }
+  }
+
   if (cache_set >= C) {
-    // ignore the already-existing elements
-    evicted_indices[n] = -1;
-    assigned_cache_slots[n] = -1;
+    if (threadIdx.x == 0) {
+      // ignore the already-existing elements
+      evicted_indices[n] = -1;
+      assigned_cache_slots[n] = -1;
+    }
     return;
   }
 
   // check if this warp is responsible for this whole segment.
-  const bool segment_start =
-      (n == 0 || sorted_cache_sets[n - 1] != sorted_cache_sets[n]);
+  const bool segment_start = (n == 0 || sorted_cache_sets[n - 1] != cache_set);
 
   if (!segment_start) {
     // don't have *warp* divergence since we launch full warps in blockDim.x,
@@ -207,10 +225,6 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel(
     assigned_cache_slots[n + l] = cache_set * kWarpSize + insert_slot;
     lxu_cache_state[cache_set][insert_slot] = insert_idx;
     lru_state[cache_set][insert_slot] = time_stamp;
-
-    if (threadIdx.x == 0) {
-      gpuAtomicAdd(&actions_count[0], SL);
-    }
 }
 
 std::tuple ssd_cache_populate_actions_cuda(
@@ -236,11 +250,11 @@ std::tuple ssd_cache_populate_actions_cuda(
   const int32_t N = unique_indices.numel();
 
   auto evicted_indices = empty_like(unique_indices);
-  auto assigned_cache_slots =
-      empty_like(unique_indices, unique_indices.options().dtype(at::kInt));
-  auto actions_count = at::zeros({1}, unique_indices.options().dtype(at::kInt));
+  const auto int_options = unique_indices.options().dtype(at::kInt);
+  auto assigned_cache_slots = empty_like(unique_indices, int_options);
 
   if (unique_indices.numel() == 0) {
+    auto actions_count = at::zeros({1}, int_options);
     // these are all of length zero
     return std::make_tuple(
         empty_like(unique_indices),
@@ -248,11 +262,11 @@
         assigned_cache_slots,
         actions_count);
   }
+
+  auto actions_count = at::empty({1}, int_options);
   // Find uncached indices
-  Tensor uvm_cache_stats =
-      at::empty({0}, linear_indices.options().dtype(at::kInt));
-  Tensor lxu_cache_locking_counter =
-      at::empty({0, 0}, lxu_cache_state.options().dtype(at::kInt));
+  Tensor uvm_cache_stats = at::empty({0}, int_options);
+  Tensor lxu_cache_locking_counter = at::empty({0, 0}, int_options);
   auto cache_sets_and_unique_indices = lru_cache_find_uncached_cuda(
       unique_indices,
       unique_indices_length,
diff --git a/fbgemm_gpu/test/sparse/block_bucketize_test.py b/fbgemm_gpu/test/sparse/block_bucketize_test.py
index 8f0c2d0022..3832ceb0ca 100644
--- a/fbgemm_gpu/test/sparse/block_bucketize_test.py
+++ b/fbgemm_gpu/test/sparse/block_bucketize_test.py
@@ -604,16 +604,17 @@ def test_block_bucketize_sparse_features_with_block_bucketize_pos(
         has_weight=st.booleans(),
         bucketize_pos=st.booleans(),
         sequence=st.booleans(),
+        my_size=st.sampled_from([3, 194, 256]),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=32, deadline=None)
     def test_block_bucketize_sparse_features_large(
         self,
         index_type: Type[torch.dtype],
         has_weight: bool,
         bucketize_pos: bool,
         sequence: bool,
+        my_size: int,
     ) -> None:
-        my_size = 3
         bucket_size = 5
         warp_size = 32
         max_num_thread_in_a_block = 1024
diff --git a/fbgemm_gpu/test/sparse/utils_test.cpp b/fbgemm_gpu/test/sparse/utils_test.cpp
index 7fc087a6e0..942e26c631 100644
--- a/fbgemm_gpu/test/sparse/utils_test.cpp
+++ b/fbgemm_gpu/test/sparse/utils_test.cpp
@@ -12,7 +12,7 @@
 #include
 #include
 
-#include "deeplearning/fbgemm/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h"
+#include "fbgemm_gpu/sparse_ops_utils.h"
 
 using namespace testing;
diff --git a/fbgemm_gpu/test/tbe/training/backward_adagrad_common.py b/fbgemm_gpu/test/tbe/training/backward_adagrad_common.py
index 53f69dcfea..608676b62d 100755
--- a/fbgemm_gpu/test/tbe/training/backward_adagrad_common.py
+++ b/fbgemm_gpu/test/tbe/training/backward_adagrad_common.py
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import sys
 from typing import Any, Dict
diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
index 3939aa7911..cc79229c5e 100644
--- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
+++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
@@ -760,9 +760,6 @@ def _get_wts_from_counter_adagrad_using_cowclip(
         suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
     )
     @unittest.skipIf(*gpu_unavailable)
-    @unittest.skip(
-        "is flaky, see https://www.internalfb.com/intern/test/281475047227145?ref_report_id=0"
-    )
     def test_backward_optimizers_adam(  # noqa C901
         self,
         T: int,
diff --git a/fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp b/fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp
index f2f012d7c2..fbc5dc8b2c 100644
--- a/fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp
+++ b/fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp
@@ -11,8 +11,8 @@
 #include
 #include
 
-#include "deeplearning/fbgemm/fbgemm_gpu/codegen/embedding_forward_split_cpu.h"
 #include "fbgemm_gpu/embedding_common.h"
+#include "fbgemm_gpu/embedding_forward_split_cpu.h"
 #include "torch/types.h" // @manual=//caffe2:torch-cpp-cpu
 
 TEST(CpuKernelTest, csr2csc_test) {
diff --git a/src/EmbeddingSpMDMAutovec.h b/src/EmbeddingSpMDMAutovec.h
index 0d3b16f17b..797696de2c 100644
--- a/src/EmbeddingSpMDMAutovec.h
+++ b/src/EmbeddingSpMDMAutovec.h
@@ -54,11 +54,11 @@ FBGEMM_API bool EmbeddingSpMDMNBit_autovec(
 
 #include "RefImplementations.h"
 
-#define ALIAS_TEMPLATE_FUNCTION(highLevelF, lowLevelF) \
-  template \
-  inline auto highLevelF(Args&&... args) \
-      ->decltype(lowLevelF(std::forward(args)...)) { \
-    return lowLevelF(std::forward(args)...); \
+#define ALIAS_TEMPLATE_FUNCTION(highLevelF, lowLevelF) \
+  template \
+  inline auto highLevelF( \
+      Args&&... args) -> decltype(lowLevelF(std::forward(args)...)) { \
+    return lowLevelF(std::forward(args)...); \
   }
 
 namespace fbgemm {
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
index 2243ea2b42..15e6331eeb 100644
--- a/src/GenerateKernel.h
+++ b/src/GenerateKernel.h
@@ -17,7 +17,7 @@
 #include "./CodeCache.h"
 #include "fbgemm/Fbgemm.h"
 #include "fbgemm/SimdUtils.h"
-//#define FBGEMM_LOG_CODE 1
+// #define FBGEMM_LOG_CODE 1
 
 namespace fbgemm {
diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc
index 259b206e4b..9e955371c2 100644
--- a/src/PackAMatrix.cc
+++ b/src/PackAMatrix.cc
@@ -189,8 +189,7 @@ int32_t PackAMatrix::addr(int32_t r, int32_t c) const {
 
 template
 void PackAMatrix::printPackedMatrix(std::string name) {
-  std::cout << name << ":"
-            << "[" << BaseType::numPackedRows() << ", "
+  std::cout << name << ":" << "[" << BaseType::numPackedRows() << ", "
             << BaseType::numPackedCols() << "]" << std::endl;
 
   T* out = BaseType::getBuf();
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
index 63b7f069cf..c202ed2e33 100644
--- a/src/PackAWithIm2Col.cc
+++ b/src/PackAWithIm2Col.cc
@@ -703,8 +703,7 @@ void PackAWithIm2Col::pack(const block_type_t& block) {
 template
 void PackAWithIm2Col::printPackedMatrix(
     std::string name) {
-  std::cout << name << ":"
-            << "[" << BaseType::numPackedRows() << ", "
+  std::cout << name << ":" << "[" << BaseType::numPackedRows() << ", "
             << BaseType::numPackedCols() << "]" << std::endl;
 
   T* out = BaseType::getBuf();
diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc
index 00192ed497..9299dd24fe 100644
--- a/src/PackAWithQuantRowOffset.cc
+++ b/src/PackAWithQuantRowOffset.cc
@@ -210,8 +210,7 @@ int32_t PackAWithQuantRowOffset::addr(int32_t r, int32_t c) const {
 
 template
 void PackAWithQuantRowOffset::printPackedMatrix(std::string name) {
-  std::cout << name << ":"
-            << "[" << BaseType::numPackedRows() << ", "
+  std::cout << name << ":" << "[" << BaseType::numPackedRows() << ", "
             << BaseType::numPackedCols() << "]" << std::endl;
 
   T* out = BaseType::getBuf();
diff --git a/src/PackAWithRowOffset.cc b/src/PackAWithRowOffset.cc
index 1f22d4df12..07a72c0669 100644
--- a/src/PackAWithRowOffset.cc
+++ b/src/PackAWithRowOffset.cc
@@ -189,8 +189,7 @@ int32_t PackAWithRowOffset::addr(int32_t r, int32_t c) const {
 
 template
 void PackAWithRowOffset::printPackedMatrix(std::string name) {
-  std::cout << name << ":"
-            << "[" << BaseType::numPackedRows() << ", "
+  std::cout << name << ":" << "[" << BaseType::numPackedRows() << ", "
             << BaseType::numPackedCols() << "]" << std::endl;
 
   T* out = BaseType::getBuf();
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc
index 80463c56f1..256412118f 100644
--- a/src/PackBMatrix.cc
+++ b/src/PackBMatrix.cc
@@ -390,11 +390,9 @@ template
 void PackBMatrix::printPackedMatrix(
     std::string name,
     const BlockingFactors* params) {
-  std::cout << name << ":"
-            << "[" << BaseType::numPackedRows() << ", "
+  std::cout << name << ":" << "[" << BaseType::numPackedRows() << ", "
             << BaseType::numPackedCols() << "]" << std::endl;
-  std::cout << "block size:"
-            << "[" << BaseType::blockRowSize() << ", "
+  std::cout << "block size:" << "[" << BaseType::blockRowSize() << ", "
             << BaseType::blockColSize() << "]" << std::endl;
 
   for (int g = 0; g < BaseType::numGroups(); ++g) {
diff --git a/src/Utils.cc b/src/Utils.cc
index 1ae55d6460..d2086d472b 100644
--- a/src/Utils.cc
+++ b/src/Utils.cc
@@ -102,8 +102,7 @@ void printMatrix(
   // R: number of rows in op(inp)
   // C: number of cols in op(inp)
   // ld: leading dimension in inp
-  std::cout << name << ":"
-            << "[" << R << ", " << C << "]" << std::endl;
+  std::cout << name << ":" << "[" << R << ", " << C << "]" << std::endl;
   bool tr = (op == matrix_op_t::Transpose);
   for (size_t r = 0; r < R; ++r) {
     for (size_t c = 0; c < C; ++c) {
diff --git a/test/SparsePackUnpackTest.cc b/test/SparsePackUnpackTest.cc
index a1d5c7a763..cd3e85b863 100644
--- a/test/SparsePackUnpackTest.cc
+++ b/test/SparsePackUnpackTest.cc
@@ -57,8 +57,8 @@ TEST_P(packUnpackTest, sparseUnpackTest) {
     for (int k = 0; k < K; ++k) {
       ASSERT_EQ(wData[j * K + k], wUnpackedData[j * K + k])
          << "Original and unpacked data elements are not the same at idx ["
-          << j << ", " << k << "]: "
-          << "original: " << static_cast(wData[j * K + k])
+          << j << ", " << k
+          << "]: " << "original: " << static_cast(wData[j * K + k])
          << " , unpacked: " << static_cast(wUnpackedData[j * K + k]);
     }
   }