Skip to content

Commit

Permalink
Introduce async_smart_free
Browse files Browse the repository at this point in the history
Thhis function intends to replace use of host_task submissions
to manage USM temporary deallocations.

Signature

sycl::event
async_smart_free( sycl::queue &,
   const std::vector<sycl::event> &,
   std::unique_ptr<T, USMDeleter>, ...);
  • Loading branch information
oleksandr-pavlyk committed Dec 24, 2024
1 parent 535e471 commit bbb55f1
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 59 deletions.
36 changes: 6 additions & 30 deletions dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1605,14 +1605,8 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
n_counts, count_ptr, proj_op,
is_ascending, depends);

sort_ev = exec_q.submit([=](sycl::handler &cgh) {
cgh.depends_on(sort_ev);
const sycl::context &ctx = exec_q.get_context();

using dpctl::tensor::alloc_utils::sycl_free_noexcept;
cgh.host_task(
[ctx, count_ptr]() { sycl_free_noexcept(count_ptr, ctx); });
});
sort_ev = dpctl::tensor::alloc_utils::async_smart_free(
exec_q, {sort_ev}, count_owner);

return sort_ev;
}
Expand Down Expand Up @@ -1655,19 +1649,8 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
}
}

sort_ev = exec_q.submit([=](sycl::handler &cgh) {
cgh.depends_on(sort_ev);

const sycl::context &ctx = exec_q.get_context();

using dpctl::tensor::alloc_utils::sycl_free_noexcept;
cgh.host_task([ctx, count_ptr, tmp_arr]() {
sycl_free_noexcept(tmp_arr, ctx);
sycl_free_noexcept(count_ptr, ctx);
});
});
count_owner.release();
tmp_arr_owner.release();
sort_ev = dpctl::tensor::alloc_utils::async_smart_free(
exec_q, {sort_ev}, tmp_arr_owner, count_owner);
}

return sort_ev;
Expand Down Expand Up @@ -1819,16 +1802,9 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
});
});

sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(map_back_ev);

const sycl::context &ctx = exec_q.get_context();

using dpctl::tensor::alloc_utils::sycl_free_noexcept;
cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); });
});
sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
exec_q, {map_back_ev}, workspace_owner);

workspace_owner.release();
return cleanup_ev;
}

Expand Down
33 changes: 6 additions & 27 deletions dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,9 @@ topk_full_merge_sort_impl(sycl::queue &exec_q,
});

sycl::event cleanup_host_task_event =
exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(write_out_ev);
const sycl::context &ctx = exec_q.get_context();

using dpctl::tensor::alloc_utils::sycl_free_noexcept;
cgh.host_task(
[ctx, index_data] { sycl_free_noexcept(index_data, ctx); });
});
dpctl::tensor::alloc_utils::async_smart_free(exec_q, {write_out_ev},
index_data_owner);

index_data_owner.release();
return cleanup_host_task_event;
};

Expand Down Expand Up @@ -429,16 +422,9 @@ sycl::event topk_merge_impl(
});

sycl::event cleanup_host_task_event =
exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(write_topk_ev);
const sycl::context &ctx = exec_q.get_context();

using dpctl::tensor::alloc_utils::sycl_free_noexcept;
cgh.host_task(
[ctx, index_data] { sycl_free_noexcept(index_data, ctx); });
});
dpctl::tensor::alloc_utils::async_smart_free(
exec_q, {write_topk_ev}, index_data_owner);

index_data_owner.release();
return cleanup_host_task_event;
}
}
Expand Down Expand Up @@ -537,16 +523,9 @@ sycl::event topk_radix_impl(sycl::queue &exec_q,
});
});

sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(write_topk_ev);

const sycl::context &ctx = exec_q.get_context();

using dpctl::tensor::alloc_utils::sycl_free_noexcept;
cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); });
});
sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
exec_q, {write_topk_ev}, workspace_owner);

workspace_owner.release();
return cleanup_ev;
}

Expand Down
35 changes: 33 additions & 2 deletions dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <iostream>
#include <memory>
#include <stdexcept>
#include <vector>

#include "sycl/sycl.hpp"

Expand Down Expand Up @@ -75,7 +76,8 @@ void sycl_free_noexcept(T *ptr, const sycl::context &ctx) noexcept
}
}

template <typename T> void sycl_free_noexcept(T *ptr, sycl::queue &q) noexcept
template <typename T>
void sycl_free_noexcept(T *ptr, const sycl::queue &q) noexcept
{
sycl_free_noexcept(ptr, q.get_context());
}
Expand All @@ -89,7 +91,7 @@ class USMDeleter
USMDeleter(const sycl::queue &q) : ctx_(q.get_context()) {}
USMDeleter(const sycl::context &ctx) : ctx_(ctx) {}

template <typename T> void operator()(T *ptr)
template <typename T> void operator()(T *ptr) const
{
sycl_free_noexcept(ptr, ctx_);
}
Expand Down Expand Up @@ -138,6 +140,35 @@ smart_malloc_jost(std::size_t count,
return smart_malloc<T>(count, q, sycl::usm::alloc::host, propList);
}

template <typename... Args>
sycl::event async_smart_free(sycl::queue &exec_q,
const std::vector<sycl::event> &depends,
Args &&...args)
{
constexpr std::size_t n = sizeof...(Args);

std::vector<void *> ptrs;
ptrs.reserve(n);
(ptrs.push_back(reinterpret_cast<void *>(args.get())), ...);

std::vector<USMDeleter> dels;
dels.reserve(n);
(dels.push_back(args.get_deleter()), ...);

sycl::event ht_e = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

cgh.host_task([ptrs, dels]() {
for (size_t i = 0; i < ptrs.size(); ++i) {
dels[i](ptrs[i]);
}
});
});
(args.release(), ...);

return ht_e;
}

} // end of namespace alloc_utils
} // end of namespace tensor
} // end of namespace dpctl

0 comments on commit bbb55f1

Please sign in to comment.