Skip to content

Commit

Permalink
[SYCL][Graph] Add query function to recording queue that returns grap…
Browse files Browse the repository at this point in the history
…h object. (#12460)

This enables for example pause and continue recording from within a
single function with an interface that is not graph aware.

---------

Co-authored-by: Ewan Crawford <ewan@codeplay.com>
  • Loading branch information
reble and EwanC authored Jan 26, 2024
1 parent 85b7145 commit c835f82
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 2 deletions.
19 changes: 19 additions & 0 deletions sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,9 @@ public:
ext::oneapi::experimental::queue_state
ext_oneapi_get_state() const;
ext::oneapi::experimental::command_graph<graph_state::modifiable>
ext_oneapi_get_graph() const;
/* -- graph convenience shortcuts -- */
event ext_oneapi_graph(command_graph<graph_state::executable>& graph);
Expand Down Expand Up @@ -1019,6 +1022,22 @@ Returns: If the queue is in the default state where commands are scheduled
immediately for execution, `queue_state::executing` is returned. Otherwise,
`queue_state::recording` is returned where commands are redirected to a `command_graph`
object.
|
[source,c++]
----
command_graph<graph_state::modifiable>
queue::ext_oneapi_get_graph() const;
----

| Query the underlying command graph of a queue when recording.

Returns: The graph object that the queue is recording commands into.

Exceptions:

* Throws synchronously with error code `invalid` if the queue is not in `queue_state::recording`
state.

|
[source,c++]
----
Expand Down
3 changes: 3 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,9 @@ class command_graph : public detail::modifiable_command_graph {
/// @param Impl Detail implementation class to construct object with.
command_graph(const std::shared_ptr<detail::graph_impl> &Impl)
: modifiable_command_graph(Impl) {}

template <class T>
friend T sycl::detail::createSyclObjFromImpl(decltype(T::impl) ImplObj);
};

template <>
Expand Down
7 changes: 6 additions & 1 deletion sycl/include/sycl/queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
#include <sycl/ext/oneapi/bindless_images_memory.hpp> // for image_mem_handle
#include <sycl/ext/oneapi/device_global/device_global.hpp> // for device_global
#include <sycl/ext/oneapi/device_global/properties.hpp> // for device_image_s...
#include <sycl/ext/oneapi/experimental/graph.hpp> // for graph_state
#include <sycl/ext/oneapi/experimental/graph.hpp> // for command_graph...
#include <sycl/ext/oneapi/properties/properties.hpp> // for empty_properti...
#include <sycl/handler.hpp> // for handler, isDev...
#include <sycl/id.hpp> // for id
Expand Down Expand Up @@ -315,6 +315,11 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
/// \return State the queue is currently in.
ext::oneapi::experimental::queue_state ext_oneapi_get_state() const;

/// \return Graph when the queue is recording.
ext::oneapi::experimental::command_graph<
ext::oneapi::experimental::graph_state::modifiable>
ext_oneapi_get_graph() const;

/// \return true if this queue is a SYCL host queue.
__SYCL2020_DEPRECATED(
"is_host() is deprecated as the host device is no longer supported.")
Expand Down
14 changes: 14 additions & 0 deletions sycl/source/queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,20 @@ ext::oneapi::experimental::queue_state queue::ext_oneapi_get_state() const {
: ext::oneapi::experimental::queue_state::executing;
}

ext::oneapi::experimental::command_graph<
ext::oneapi::experimental::graph_state::modifiable>
queue::ext_oneapi_get_graph() const {
auto Graph = impl->getCommandGraph();
if (!Graph)
throw sycl::exception(
make_error_code(errc::invalid),
"ext_oneapi_get_graph() can only be called on recording queues.");

return sycl::detail::createSyclObjFromImpl<
ext::oneapi::experimental::command_graph<
ext::oneapi::experimental::graph_state::modifiable>>(Graph);
}

bool queue::is_host() const {
bool IsHost = impl->is_host();
assert(!IsHost && "queue::is_host should not be called in implementation.");
Expand Down
96 changes: 96 additions & 0 deletions sycl/test-e2e/Graph/RecordReplay/dotp_in_order_pause.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out
// Extra run to check for leaks in Level Zero using UR_L0_LEAKS_DEBUG
// RUN: %if level_zero %{env UR_L0_LEAKS_DEBUG=1 %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %}

// Tests a dotp operation using device USM and an in-order queue.
// Recording is paused, and command is submitted eagerly.

#include "../graph_common.hpp"

void foo(sycl::queue Queue, size_t N, int *X, int *Y, int *Z) {

auto Graph = Queue.ext_oneapi_get_graph();
Graph.end_recording();

Queue
.submit([&](handler &CGH) {
CGH.parallel_for(N, [=](id<1> it) {
X[it] = 0;
Y[it] = 0;
Z[it] = 0;
});
})
.wait();

Graph.begin_recording(Queue);
}

int main() {
property_list Properties{
property::queue::in_order{},
sycl::ext::intel::property::queue::no_immediate_command_list{}};
queue Queue{Properties};

if (!are_graphs_supported(Queue)) {
return 0;
}

exp_ext::command_graph Graph{Queue.get_context(), Queue.get_device()};

int *Dotp = malloc_device<int>(1, Queue);
Queue.memset(Dotp, 0, sizeof(int)).wait();

const size_t N = 10;
int *X = malloc_device<int>(N, Queue);
int *Y = malloc_device<int>(N, Queue);
int *Z = malloc_device<int>(N, Queue);

Graph.begin_recording(Queue);

auto InitEvent = Queue.submit([&](handler &CGH) {
CGH.parallel_for(N, [=](id<1> it) {
X[it] = 1;
Y[it] = 2;
Z[it] = 3;
});
});

foo(Queue, N, X, Y, Z);

auto EventA = Queue.submit([&](handler &CGH) {
CGH.parallel_for(range<1>{N},
[=](id<1> it) { X[it] = Alpha * X[it] + Beta * Y[it]; });
});

auto EventB = Queue.submit([&](handler &CGH) {
CGH.parallel_for(range<1>{N},
[=](id<1> it) { Z[it] = Gamma * Z[it] + Beta * Y[it]; });
});

Queue.submit([&](handler &CGH) {
CGH.single_task([=]() {
for (size_t j = 0; j < N; j++) {
Dotp[0] += X[j] * Z[j];
}
});
});

Graph.end_recording();

auto ExecGraph = Graph.finalize();

Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(ExecGraph); });

int Output;
Queue.memcpy(&Output, Dotp, sizeof(int)).wait();

assert(Output == dotp_reference_result(N));

sycl::free(Dotp, Queue);
sycl::free(X, Queue);
sycl::free(Y, Queue);
sycl::free(Z, Queue);

return 0;
}
3 changes: 2 additions & 1 deletion sycl/test/abi/sycl_symbols_linux.dump
Original file line number Diff line number Diff line change
Expand Up @@ -3935,9 +3935,9 @@ _ZN4sycl3_V16detail13MemoryManager29ext_oneapi_copyD2D_cmd_bufferESt10shared_ptr
_ZN4sycl3_V16detail13MemoryManager29ext_oneapi_copyD2H_cmd_bufferESt10shared_ptrINS1_12context_implEEP22_pi_ext_command_bufferPNS1_11SYCLMemObjIEPvjNS0_5rangeILi3EEESC_NS0_2idILi3EEEjPcjSC_SE_jSt6vectorIjSaIjEEPj
_ZN4sycl3_V16detail13MemoryManager29ext_oneapi_copyH2D_cmd_bufferESt10shared_ptrINS1_12context_implEEP22_pi_ext_command_bufferPNS1_11SYCLMemObjIEPcjNS0_5rangeILi3EEENS0_2idILi3EEEjPvjSC_SC_SE_jSt6vectorIjSaIjEEPj
_ZN4sycl3_V16detail13MemoryManager30ext_oneapi_copy_usm_cmd_bufferESt10shared_ptrINS1_12context_implEEPKvP22_pi_ext_command_buffermPvSt6vectorIjSaIjEEPj
_ZN4sycl3_V16detail13MemoryManager30ext_oneapi_fill_usm_cmd_bufferESt10shared_ptrINS1_12context_implEEP22_pi_ext_command_bufferPvmiSt6vectorIjSaIjEEPj
_ZN4sycl3_V16detail13MemoryManager32ext_oneapi_advise_usm_cmd_bufferESt10shared_ptrINS1_12context_implEEP22_pi_ext_command_bufferPKvm14_pi_mem_adviceSt6vectorIjSaIjEEPj
_ZN4sycl3_V16detail13MemoryManager34ext_oneapi_prefetch_usm_cmd_bufferESt10shared_ptrINS1_12context_implEEP22_pi_ext_command_bufferPvmSt6vectorIjSaIjEEPj
_ZN4sycl3_V16detail13MemoryManager30ext_oneapi_fill_usm_cmd_bufferESt10shared_ptrINS1_12context_implEEP22_pi_ext_command_bufferPvmiSt6vectorIjSaIjEEPj
_ZN4sycl3_V16detail13MemoryManager3mapEPNS1_11SYCLMemObjIEPvSt10shared_ptrINS1_10queue_implEENS0_6access4modeEjNS0_5rangeILi3EEESC_NS0_2idILi3EEEjSt6vectorIP9_pi_eventSaISH_EERSH_
_ZN4sycl3_V16detail13MemoryManager4copyEPNS1_11SYCLMemObjIEPvSt10shared_ptrINS1_10queue_implEEjNS0_5rangeILi3EEESA_NS0_2idILi3EEEjS5_S8_jSA_SA_SC_jSt6vectorIP9_pi_eventSaISF_EERSF_
_ZN4sycl3_V16detail13MemoryManager4copyEPNS1_11SYCLMemObjIEPvSt10shared_ptrINS1_10queue_implEEjNS0_5rangeILi3EEESA_NS0_2idILi3EEEjS5_S8_jSA_SA_SC_jSt6vectorIP9_pi_eventSaISF_EERSF_RKS6_INS1_10event_implEE
Expand Down Expand Up @@ -4270,6 +4270,7 @@ _ZNK4sycl3_V15queue12has_propertyINS0_8property5queue16enable_profilingEEEbv
_ZNK4sycl3_V15queue12has_propertyINS0_8property5queue4cuda18use_default_streamEEEbv
_ZNK4sycl3_V15queue12has_propertyINS0_8property5queue8in_orderEEEbv
_ZNK4sycl3_V15queue16ext_oneapi_emptyEv
_ZNK4sycl3_V15queue20ext_oneapi_get_graphEv
_ZNK4sycl3_V15queue20ext_oneapi_get_stateEv
_ZNK4sycl3_V15queue25ext_oneapi_get_last_eventEv
_ZNK4sycl3_V15queue28ext_codeplay_supports_fusionEv
Expand Down
1 change: 1 addition & 0 deletions sycl/test/abi/sycl_symbols_windows.dump
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,7 @@
?ext_oneapi_get_default_context@platform@_V1@sycl@@QEBA?AVcontext@23@XZ
?ext_oneapi_get_kernel@kernel_bundle_plain@detail@_V1@sycl@@QEAA?AVkernel@34@AEBV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@@Z
?ext_oneapi_get_last_event@queue@_V1@sycl@@QEBA?AVevent@23@XZ
?ext_oneapi_get_graph@queue@_V1@sycl@@QEBA?AV?$command_graph@$0A@@experimental@oneapi@ext@23@XZ
?ext_oneapi_get_state@queue@_V1@sycl@@QEBA?AW4queue_state@experimental@oneapi@ext@23@XZ
?ext_oneapi_graph@handler@_V1@sycl@@QEAAXV?$command_graph@$00@experimental@oneapi@ext@23@@Z
?ext_oneapi_graph@queue@_V1@sycl@@QEAA?AVevent@23@V?$command_graph@$00@experimental@oneapi@ext@23@AEBUcode_location@detail@23@@Z
Expand Down

0 comments on commit c835f82

Please sign in to comment.