From bbb55f1a693a0afcb2b0017ca50a04b2e470efc9 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 24 Dec 2024 09:30:49 -0600 Subject: [PATCH] Introduce async_smart_free Thhis function intends to replace use of host_task submissions to manage USM temporary deallocations. Signature sycl::event async_smart_free( sycl::queue &, const std::vector &, std::unique_ptr, ...); --- .../include/kernels/sorting/radix_sort.hpp | 36 ++++--------------- .../include/kernels/sorting/topk.hpp | 33 ++++------------- .../include/utils/sycl_alloc_utils.hpp | 35 ++++++++++++++++-- 3 files changed, 45 insertions(+), 59 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp index 043d047a6a..de0011e2ab 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp @@ -1605,14 +1605,8 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, n_counts, count_ptr, proj_op, is_ascending, depends); - sort_ev = exec_q.submit([=](sycl::handler &cgh) { - cgh.depends_on(sort_ev); - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task( - [ctx, count_ptr]() { sycl_free_noexcept(count_ptr, ctx); }); - }); + sort_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {sort_ev}, count_owner); return sort_ev; } @@ -1655,19 +1649,8 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, } } - sort_ev = exec_q.submit([=](sycl::handler &cgh) { - cgh.depends_on(sort_ev); - - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task([ctx, count_ptr, tmp_arr]() { - sycl_free_noexcept(tmp_arr, ctx); - sycl_free_noexcept(count_ptr, ctx); - }); - }); - count_owner.release(); - tmp_arr_owner.release(); + sort_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {sort_ev}, tmp_arr_owner, count_owner); } return sort_ev; @@ -1819,16 +1802,9 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q, }); }); - sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(map_back_ev); - - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); }); - }); + sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {map_back_ev}, workspace_owner); - workspace_owner.release(); return cleanup_ev; } diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp index d746c0755e..d1d686a2dd 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp @@ -154,16 +154,9 @@ topk_full_merge_sort_impl(sycl::queue &exec_q, }); sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(write_out_ev); - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task( - [ctx, index_data] { sycl_free_noexcept(index_data, ctx); }); - }); + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {write_out_ev}, + index_data_owner); - index_data_owner.release(); return cleanup_host_task_event; }; @@ -429,16 +422,9 @@ sycl::event topk_merge_impl( }); sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(write_topk_ev); - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task( - [ctx, index_data] { sycl_free_noexcept(index_data, ctx); }); - }); + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {write_topk_ev}, index_data_owner); - index_data_owner.release(); return cleanup_host_task_event; } } @@ -537,16 +523,9 @@ sycl::event topk_radix_impl(sycl::queue &exec_q, }); }); - sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(write_topk_ev); - - const sycl::context &ctx = exec_q.get_context(); - - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); }); - }); + sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {write_topk_ev}, workspace_owner); - workspace_owner.release(); return cleanup_ev; } diff --git a/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp index 79478a1c2e..87476221d8 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp @@ -30,6 +30,7 @@ #include #include #include +#include #include "sycl/sycl.hpp" @@ -75,7 +76,8 @@ void sycl_free_noexcept(T *ptr, const sycl::context &ctx) noexcept } } -template void sycl_free_noexcept(T *ptr, sycl::queue &q) noexcept +template +void sycl_free_noexcept(T *ptr, const sycl::queue &q) noexcept { sycl_free_noexcept(ptr, q.get_context()); } @@ -89,7 +91,7 @@ class USMDeleter USMDeleter(const sycl::queue &q) : ctx_(q.get_context()) {} USMDeleter(const sycl::context &ctx) : ctx_(ctx) {} - template void operator()(T *ptr) + template void operator()(T *ptr) const { sycl_free_noexcept(ptr, ctx_); } @@ -138,6 +140,35 @@ smart_malloc_jost(std::size_t count, return smart_malloc(count, q, sycl::usm::alloc::host, propList); } +template +sycl::event async_smart_free(sycl::queue &exec_q, + const std::vector &depends, + Args &&...args) +{ + constexpr std::size_t n = sizeof...(Args); + + std::vector ptrs; + ptrs.reserve(n); + (ptrs.push_back(reinterpret_cast(args.get())), ...); + + std::vector dels; + dels.reserve(n); + (dels.push_back(args.get_deleter()), ...); + + sycl::event ht_e = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.host_task([ptrs, dels]() { + for (size_t i = 0; i < ptrs.size(); ++i) { + dels[i](ptrs[i]); + } + }); + }); + (args.release(), ...); + + return ht_e; +} + } // end of namespace alloc_utils } // end of namespace tensor } // end of namespace dpctl