@@ -33,6 +33,12 @@ struct kernel_node_params {
3333 unsigned int shared_mem_bytes{};
3434
3535 std::vector<dpct::experimental::node_ptr> dependencies{};
36+ kernel_node_params () = default ;
37+ kernel_node_params (const kernel_node_params &other)
38+ : block_dim(other.block_dim), grid_dim(other.grid_dim),
39+ kernel_params (other.kernel_params), func(other.func),
40+ shared_mem_bytes(other.shared_mem_bytes),
41+ dependencies(other.dependencies) {}
3642
3743public:
3844 void set_block_dim (const dpct::dim3 &block_dim) {
@@ -142,6 +148,7 @@ class graph_mgr {
142148 dpct::experimental::node_ptr *dependencies,
143149 std::size_t numberOfDependencies,
144150 dpct::experimental::kernel_node_params *params) {
151+ node_graph_params_map[*node] = std::make_pair (graph, params);
145152 for (std::size_t i = 0 ; i < numberOfDependencies; i++) {
146153 params->add_dependency (dependencies[i]);
147154 }
@@ -156,7 +163,6 @@ class graph_mgr {
156163 for (std::size_t i = 0 ; i < kernel_params_vector.size (); i++) {
157164 auto &node_kernel_params_pair = kernel_params_vector[i];
158165 auto node_params = node_kernel_params_pair.second ;
159-
160166 const auto &dependency_ptrs = node_params->get_dependencies ();
161167 std::vector<sycl::ext::oneapi::experimental::node> dependencies;
162168 dependencies.reserve (dependency_ptrs.size ());
@@ -184,9 +190,12 @@ class graph_mgr {
184190 }
185191 node_kernel_params_pair.first = new_node;
186192 }
187- auto final_graph = graph->finalize ();
193+ execGraph = new sycl::ext::oneapi::experimental::command_graph<
194+ sycl::ext::oneapi::experimental::graph_state::executable>(
195+ graph->finalize (
196+ sycl::ext::oneapi::experimental::property::graph::updatable{}));
188197 queue->submit (
189- [&](sycl::handler &cgh) { cgh.ext_oneapi_graph (final_graph ); });
198+ [&](sycl::handler &cgh) { cgh.ext_oneapi_graph (*execGraph ); });
190199 }
191200
192201 void instantiate (dpct::experimental::command_graph_exec_ptr *execGraph,
@@ -195,7 +204,31 @@ class graph_mgr {
195204 }
196205
197206 void kernel_node_get_params (dpct::experimental::node_ptr node,
198- dpct::experimental::kernel_node_params *params) {}
207+ dpct::experimental::kernel_node_params *params) {
208+ auto it = node_graph_params_map.find (node);
209+ if (it == node_graph_params_map.end ()) {
210+ return ;
211+ }
212+ *params = *(it->second .second );
213+ }
214+
215+ void kernel_node_set_params (dpct::experimental::node_ptr node,
216+ dpct::experimental::kernel_node_params *params) {
217+ node_graph_params_map[node].second = params;
218+ }
219+
220+ void get_node_type (dpct::experimental::node_ptr node,
221+ sycl::ext::oneapi::experimental::node_type *nodeType) {
222+ if (node_graph_params_map.find (node) != node_graph_params_map.end ()) {
223+ *nodeType = sycl::ext::oneapi::experimental::node_type::kernel;
224+ } else {
225+ if (node) {
226+ *nodeType = node->get_type ();
227+ } else {
228+ *nodeType = sycl::ext::oneapi::experimental::node_type::empty;
229+ }
230+ }
231+ }
199232
200233private:
201234 std::unordered_map<sycl::queue *, command_graph_ptr> queue_graph_map;
@@ -214,8 +247,9 @@ class graph_mgr {
214247 dpct::experimental::kernel_node_params *>>>
215248 graph_kernel_node_params_map;
216249 std::unordered_map<dpct::experimental::node_ptr,
217- dpct::experimental::kernel_node_params>
218- node_params_map;
250+ std::pair<dpct::experimental::command_graph_ptr,
251+ dpct::experimental::kernel_node_params *>>
252+ node_graph_params_map;
219253};
220254} // namespace detail
221255
@@ -326,11 +360,31 @@ static void launch(dpct::experimental::command_graph_exec_ptr execGraph,
326360
327361static void
328362kernel_node_get_params (dpct::experimental::node_ptr node,
329- dpct::experimental::kernel_node_params *params) {}
363+ dpct::experimental::kernel_node_params *params) {
364+ detail::graph_mgr::instance ().kernel_node_get_params (node, params);
365+ }
330366
331367static void
332368kernel_node_set_params (dpct::experimental::node_ptr node,
333- dpct::experimental::kernel_node_params *params) {}
369+ dpct::experimental::kernel_node_params *params) {
370+ detail::graph_mgr::instance ().kernel_node_set_params (node, params);
371+ }
372+
373+ static void
374+ get_node_type (dpct::experimental::node_ptr node,
375+ sycl::ext::oneapi::experimental::node_type *nodeType) {
376+ detail::graph_mgr::instance ().get_node_type (node, nodeType);
377+ }
378+
379+ static void update (dpct::experimental::command_graph_exec_ptr graphExec,
380+ dpct::experimental::command_graph_ptr graph,
381+ int *updateResultInfo) {
382+ graphExec->update (*graph);
383+ if (!graphExec) {
384+ *updateResultInfo = 0 ;
385+ }
386+ *updateResultInfo = 1 ;
387+ }
334388
335389} // namespace experimental
336390} // namespace dpct
0 commit comments