Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -5917,7 +5917,8 @@ struct OrtApi {
/** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph.
*
* \note The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference
* the same underlying graph.
* the same underlying graph. "dst_graph" preserves the input order of "src_graph", and
* its output order corresponds to the outputs produced by the nodes in "nodes" with the given order.
*
* \param[in] src_graph The source OrtGraph instance.
* \param[in] nodes A subset of the nodes/OrtNodes in 'graph'.
Expand Down
41 changes: 35 additions & 6 deletions onnxruntime/core/graph/ep_api_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@
producer_info.output_index = 0;

if (graph_ == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to get producer node for OrtValueInfo '", name_,
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_FOUND, "Unable to get producer node for OrtValueInfo '", name_,
"' that is not owned by a OrtGraph.");
}

Expand All @@ -379,7 +379,13 @@

const EpNode* ep_node = graph_->GetNode(node->Index());
if (ep_node == nullptr) {
return Status::OK(); // Node is not in this GraphViewer
producer_info.node = nullptr;
producer_info.output_index = 0;
const auto& logger = graph_->GetGraphViewer().GetGraph().GetLogger();

Check failure on line 384 in onnxruntime/core/graph/ep_api_types.cc

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'logger': references must be initialized

Check failure on line 384 in onnxruntime/core/graph/ep_api_types.cc

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'GetLogger': is not a member of 'onnxruntime::Graph'
LOGS(logger, WARNING) << "Unable to get producer node for OrtValueInfo '"

Check failure on line 385 in onnxruntime/core/graph/ep_api_types.cc

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'logger': cannot be used before it is initialized
<< name_
<< "' that is not owned by an OrtGraph.";
return Status::OK();
}

size_t output_index = 0;
Expand Down Expand Up @@ -539,10 +545,14 @@
size_t num_elems = (max_node_index - min_node_index) + 1;

min_node_index_ = min_node_index;
max_node_index_ = max_node_index;
nodes_.resize(num_elems, nullptr);
}

EpNode* EpGraph::IndexToEpNodeMap::GetEpNode(NodeIndex node_index) const {
if (node_index < min_node_index_ || node_index > max_node_index_) {
return nullptr;
}
size_t i = node_index - min_node_index_;
assert(i < nodes_.size());
return nodes_[i];
Expand All @@ -566,10 +576,10 @@
owned_indexed_sub_graph_(std::move(indexed_sub_graph)) {}

// Static class function to create a std::unique_ptr<EpGraph>.
Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result) {
Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result, bool create_parent_node) {
auto ep_graph = std::make_unique<EpGraph>(graph_viewer, PrivateTag{});

return CreateImpl(std::move(ep_graph), graph_viewer, result);
return CreateImpl(std::move(ep_graph), graph_viewer, result, create_parent_node);
}

// Static class function to create a std::unique_ptr<EpGraph>.
Expand All @@ -584,7 +594,8 @@
return CreateImpl(std::move(ep_graph), graph_viewer, result);
}

Status EpGraph::CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result) {
Status EpGraph::CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer,
/*out*/ std::unique_ptr<EpGraph>& result, bool create_parent_node) {
AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance();
std::unordered_map<std::string, std::unique_ptr<EpValueInfo>> value_infos_map;

Expand Down Expand Up @@ -687,13 +698,23 @@
}
}

std::unique_ptr<EpNode> ep_parent_node = nullptr;

// If this is a subgraph, add the OrtValueInfo and OrtValue objects that come from the outer scope.
// Wait until we have already processed OrtValueInfos consumed and produced by nodes so that we only add
// outer OrtValueInfo/OrtValue if they are actually used by the nodes in this GraphViewer.
if (graph_viewer.IsSubgraph()) {
gsl::not_null<const Graph*> parent_graph = graph_viewer.GetGraph().ParentGraph();
gsl::not_null<const Node*> parent_node = graph_viewer.ParentNode();

if (create_parent_node) {
std::unique_ptr<EpNode> ep_node = nullptr;

std::unordered_map<std::string, std::unique_ptr<EpValueInfo>> value_infos_map_tmp; // won't be used
ORT_RETURN_IF_ERROR(EpNode::Create(*parent_node, ep_graph.get(), value_infos_map_tmp, ep_node));
ep_parent_node = std::move(ep_node);
}

for (gsl::not_null<const NodeArg*> implicit_node_arg : parent_node->ImplicitInputDefs()) {
const std::string& implicit_name = implicit_node_arg->Name();
auto value_info_iter = value_infos_map.find(implicit_name);
Expand Down Expand Up @@ -741,6 +762,7 @@
ep_graph->outer_scope_initializer_values_ = std::move(outer_scope_initializer_values);
ep_graph->inputs_ = std::move(graph_input_value_infos);
ep_graph->outputs_ = std::move(graph_output_value_infos);
ep_graph->parent_node_owned_ = std::move(ep_parent_node);

result = std::move(ep_graph);

Expand Down Expand Up @@ -872,7 +894,14 @@
}

Status EpGraph::GetParentNode(const OrtNode*& result) const {
result = parent_node_ != nullptr ? parent_node_->ToExternal() : nullptr;
if (parent_node_ != nullptr) {
result = parent_node_->ToExternal();
} else if (parent_node_owned_) {
result = parent_node_owned_->ToExternal();
} else {
result = nullptr;
}

return Status::OK();
}

Expand Down
22 changes: 18 additions & 4 deletions onnxruntime/core/graph/ep_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ struct EpGraph : public OrtGraph {

private:
NodeIndex min_node_index_ = 0;
NodeIndex max_node_index_ = 0;
std::vector<EpNode*> nodes_;
};

Expand All @@ -269,8 +270,16 @@ struct EpGraph : public OrtGraph {
/// </summary>
/// <param name="graph_viewer"></param>
/// <param name="result"></param>
/// <param name="create_parent_node">If the `graph_viewer` is a subgraph of a control flow op,
/// e.g. For/If/Scan op, and `create_parent_node` is set to true,
/// then `result` EpGraph will create and own parent node's EpNode
/// instance. It's mainly used in EP's GetCapability() as it's
/// a bottom-up approach where inner-most subgraph will be constructed
/// first and by the time its parent node/graph hasn't be constructed yet.</param>
/// <returns></returns>
static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result);
static Status Create(const GraphViewer& graph_viewer,
/*out*/ std::unique_ptr<EpGraph>& result,
bool create_parent_node = false);

/// <summary>
/// Creates an instance of EpGraph, which wraps a GraphViewer.
Expand Down Expand Up @@ -364,16 +373,21 @@ struct EpGraph : public OrtGraph {
private:
/// <summary>
/// The real implementation of creating an EpGraph instance.
/// Please use one of the above 'Create' functions that internally call this function, and avoid calling this function directly.
/// Please use one of the above 'Create' functions that internally call this function,
/// and avoid calling this function directly.
/// </summary>
/// <param name="ep_graph"></param>
/// <param name="graph_viewer"></param>
/// <param name="result"></param>
/// <param name="create_parent_node"></param>
/// <returns></returns>
static Status CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result);
static Status CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer,
/*out*/ std::unique_ptr<EpGraph>& result, bool create_parent_node = false);

const GraphViewer& graph_viewer_;
const EpNode* parent_node_ = nullptr;
const EpNode* parent_node_ = nullptr; // Keep the pointer to the parent node that
// is not owned by this graph
std::unique_ptr<EpNode> parent_node_owned_ = nullptr; // Hold the parent node created and owned by this graph

std::unique_ptr<GraphViewer> owned_graph_viewer_ = nullptr;
std::unique_ptr<IndexedSubGraph> owned_indexed_sub_graph_ = nullptr;
Expand Down
Loading
Loading