Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix panic for ancestors and descendants when the source node is invalid #1389

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions releasenotes/notes/fix-ancestors-panic-3c64b9b43bb0551a.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fixes:
- |
Fixed a panic when passing an invalid source node to
:func:`~rustworkx.ancenstors` and :func:`~rustworkx.descendants`. See
`#1381 <https://github.com/Qiskit/rustworkx/issues/1381>`__ for more information.
27 changes: 20 additions & 7 deletions src/traversal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use std::convert::TryFrom;

use hashbrown::HashSet;

use pyo3::exceptions::PyTypeError;
use pyo3::exceptions::{PyIndexError, PyTypeError};
use pyo3::prelude::*;
use pyo3::Python;

Expand Down Expand Up @@ -219,11 +219,18 @@ pub fn bfs_predecessors(
/// :rtype: set
#[pyfunction]
#[pyo3(text_signature = "(graph, node, /)")]
pub fn ancestors(graph: &digraph::PyDiGraph, node: usize) -> HashSet<usize> {
core_ancestors(&graph.graph, NodeIndex::new(node))
pub fn ancestors(graph: &digraph::PyDiGraph, node: usize) -> PyResult<HashSet<usize>> {
let index = NodeIndex::new(node);
if !graph.graph.contains_node(index) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{}\" out of graph bound",
node
)));
}
Ok(core_ancestors(&graph.graph, index)
.map(|x| x.index())
.filter(|x| *x != node)
.collect()
.collect())
}

/// Return the descendants of a node in a graph.
Expand All @@ -240,12 +247,18 @@ pub fn ancestors(graph: &digraph::PyDiGraph, node: usize) -> HashSet<usize> {
/// :rtype: set
#[pyfunction]
#[pyo3(text_signature = "(graph, node, /)")]
pub fn descendants(graph: &digraph::PyDiGraph, node: usize) -> HashSet<usize> {
pub fn descendants(graph: &digraph::PyDiGraph, node: usize) -> PyResult<HashSet<usize>> {
let index = NodeIndex::new(node);
core_descendants(&graph.graph, index)
if !graph.graph.contains_node(index) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{}\" out of graph bound",
node
)));
}
Ok(core_descendants(&graph.graph, index)
.map(|x| x.index())
.filter(|x| *x != node)
.collect()
.collect())
}

/// Breadth-first traversal of a directed graph with several source vertices.
Expand Down
10 changes: 10 additions & 0 deletions tests/digraph/test_ancestors_descendants.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ def test_ancestors_no_descendants(self):
res = rustworkx.ancestors(dag, node_b)
self.assertEqual({node_a}, res)

def test_invalid_source(self):
graph = rustworkx.generators.directed_path_graph(5)
with self.assertRaises(IndexError):
rustworkx.ancestors(graph, 10)


class TestDescendants(unittest.TestCase):
def test_descendants(self):
Expand All @@ -62,3 +67,8 @@ def test_descendants_no_ancestors(self):
node_c = dag.add_child(node_b, "c", {"b": 1})
res = rustworkx.descendants(dag, node_b)
self.assertEqual({node_c}, res)

def test_invalid_source(self):
graph = rustworkx.generators.directed_path_graph(5)
with self.assertRaises(IndexError):
rustworkx.descendants(graph, 10)