From 629d04e9b1e053208a69ddc9f817ecc31fec2fe2 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Tue, 6 Jun 2023 23:28:46 -0400 Subject: [PATCH] Fix pickle/deepcopy node hole handling (#888) * Fix pickle/deepcopy node hole handling This commit fixes an issue introduced by #589 where in certain cases node holes in a graph would result in a panic being raised. This was caused by a logic bug in trying to recreate the holes. Additionally, there were several places where graph methods removed nodes that the flag to indicate there were removals would no be set. This commit fixes all of these issues so that deepcopy/pickle works as expected. * Fix test failures * Fix lint * Update src/digraph.rs --- ...x-removed-nodes-attr-d1829e1f4462d96a.yaml | 7 ++ src/digraph.rs | 83 +++++-------------- src/graph.rs | 78 +++++------------ .../rustworkx_tests/digraph/test_deepcopy.py | 26 ++++++ tests/rustworkx_tests/graph/test_deepcopy.py | 26 ++++++ 5 files changed, 98 insertions(+), 122 deletions(-) create mode 100644 releasenotes/notes/0.13/fix-removed-nodes-attr-d1829e1f4462d96a.yaml diff --git a/releasenotes/notes/0.13/fix-removed-nodes-attr-d1829e1f4462d96a.yaml b/releasenotes/notes/0.13/fix-removed-nodes-attr-d1829e1f4462d96a.yaml new file mode 100644 index 0000000000..7900c2139f --- /dev/null +++ b/releasenotes/notes/0.13/fix-removed-nodes-attr-d1829e1f4462d96a.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - | + Fixed an issue with several :class:`~.PyDiGraph` and :class:`~.PyGraph` + methods that removed nodes where previously when calling + these methods the :attr:`.PyDiGraph.node_removed` attribute would not be + updated to reflect that nodes were removed. diff --git a/src/digraph.rs b/src/digraph.rs index a48f2d8b65..6c177e775f 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -29,7 +29,7 @@ use rustworkx_core::dictmap::*; use pyo3::exceptions::PyIndexError; use pyo3::gc::PyVisit; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyDict, PyList, PyLong, PyString, PyTuple}; +use pyo3::types::{PyBool, PyDict, PyList, PyString, PyTuple}; use pyo3::PyTraverseError; use pyo3::Python; @@ -44,7 +44,7 @@ use petgraph::prelude::*; use petgraph::visit::{ EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, - Visitable, + NodeIndexable, Visitable, }; use super::dot_utils::build_dot; @@ -318,7 +318,6 @@ impl PyDiGraph { }; edges.push(edge); } - let out_dict = PyDict::new(py); let nodes_lst: PyObject = PyList::new(py, nodes).into(); let edges_lst: PyObject = PyList::new(py, edges).into(); @@ -398,55 +397,22 @@ impl PyDiGraph { .downcast::() .unwrap(); - // use a pointer to iter the node list - let mut pointer = 0; - let mut next_node_idx: usize = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap() - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - // list of temporary nodes that will be removed later to re-create holes let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap(); let mut tmp_nodes: Vec = Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len()); - for i in 0..nodes_lst.len() + 1 { - if i < next_node_idx { + for item in nodes_lst { + let item = item.downcast::().unwrap(); + let next_index: usize = item.get_item(0).unwrap().extract().unwrap(); + let weight: PyObject = item.get_item(1).unwrap().extract().unwrap(); + while next_index > self.graph.node_bound() { // node does not exist let tmp_node = self.graph.add_node(py.None()); tmp_nodes.push(tmp_node); - } else { - // add node to the graph, and update the next available node index - let item = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap(); - - let node_w = item.get_item(1).unwrap().extract().unwrap(); - self.graph.add_node(node_w); - pointer += 1; - if pointer < nodes_lst.len() { - next_node_idx = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap() - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - } } + // add node to the graph, and update the next available node index + self.graph.add_node(weight); } // Remove any temporary nodes we added for tmp_node in tmp_nodes { @@ -463,20 +429,8 @@ impl PyDiGraph { self.graph.add_edge(tmp_node, tmp_node, py.None()); } else { let triple = item.downcast::().unwrap(); - let edge_p: usize = triple - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - let edge_c: usize = triple - .get_item(1) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); + let edge_p: usize = triple.get_item(0).unwrap().extract().unwrap(); + let edge_c: usize = triple.get_item(1).unwrap().extract().unwrap(); let edge_w = triple.get_item(2).unwrap().extract().unwrap(); self.graph .add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w); @@ -1760,8 +1714,8 @@ impl PyDiGraph { /// the graph. #[pyo3(text_signature = "(self, index_list, /)")] pub fn remove_nodes_from(&mut self, index_list: Vec) -> PyResult<()> { - for node in index_list.iter().map(|x| NodeIndex::new(*x)) { - self.graph.remove_node(node); + for node in index_list { + self.remove_node(node)?; } Ok(()) } @@ -2389,7 +2343,7 @@ impl PyDiGraph { // If no nodes are copied bail here since there is nothing left // to do. if out_map.is_empty() { - self.graph.remove_node(node_index); + self.remove_node(node_index.index())?; // Return a new empty map to clear allocation from out_map return Ok(NodeMap { node_map: DictMap::new(), @@ -2450,7 +2404,7 @@ impl PyDiGraph { self._add_edge(source_out, target, weight)?; } // Remove node - self.graph.remove_node(node_index); + self.remove_node(node_index.index())?; Ok(NodeMap { node_map: out_map }) } @@ -2559,7 +2513,7 @@ impl PyDiGraph { // Remove nodes that will be replaced. for index in indices_to_remove { - self.graph.remove_node(index); + self.remove_node(index.index())?; } // If `weight_combo_fn` was specified, merge edges according @@ -2912,7 +2866,10 @@ impl PyDiGraph { fn __delitem__(&mut self, idx: usize) -> PyResult<()> { match self.graph.remove_node(NodeIndex::new(idx)) { - Some(_) => Ok(()), + Some(_) => { + self.node_removed = true; + Ok(()) + } None => Err(PyIndexError::new_err("No node found for index")), } } diff --git a/src/graph.rs b/src/graph.rs index 75165cc4c1..04c90c4c79 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -26,7 +26,7 @@ use rustworkx_core::dictmap::*; use pyo3::exceptions::PyIndexError; use pyo3::gc::PyVisit; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyDict, PyList, PyLong, PyString, PyTuple}; +use pyo3::types::{PyBool, PyDict, PyList, PyString, PyTuple}; use pyo3::PyTraverseError; use pyo3::Python; @@ -47,6 +47,7 @@ use petgraph::graph::{EdgeIndex, NodeIndex}; use petgraph::prelude::*; use petgraph::visit::{ EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, + NodeIndexable, }; /// A class for creating undirected graphs @@ -284,56 +285,24 @@ impl PyGraph { .downcast::() .unwrap(); - // use a pointer to iter the node list - let mut pointer = 0; - let mut next_node_idx: usize = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap() - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - // list of temporary nodes that will be removed later to re-create holes let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap(); let mut tmp_nodes: Vec = Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len()); - for i in 0..nodes_lst.len() + 1 { - if i < next_node_idx { + for item in nodes_lst { + let item = item.downcast::().unwrap(); + let next_index: usize = item.get_item(0).unwrap().extract().unwrap(); + let weight: PyObject = item.get_item(1).unwrap().extract().unwrap(); + while next_index > self.graph.node_bound() { // node does not exist let tmp_node = self.graph.add_node(py.None()); tmp_nodes.push(tmp_node); - } else { - // add node to the graph, and update the next available node index - let item = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap(); - - let node_w = item.get_item(1).unwrap().extract().unwrap(); - self.graph.add_node(node_w); - pointer += 1; - if pointer < nodes_lst.len() { - next_node_idx = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap() - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - } } + // add node to the graph, and update the next available node index + self.graph.add_node(weight); } + // Remove any temporary nodes we added for tmp_node in tmp_nodes { self.graph.remove_node(tmp_node); } @@ -348,20 +317,8 @@ impl PyGraph { self.graph.add_edge(tmp_node, tmp_node, py.None()); } else { let triple = item.downcast::().unwrap(); - let edge_p: usize = triple - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - let edge_c: usize = triple - .get_item(1) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); + let edge_p: usize = triple.get_item(0).unwrap().extract().unwrap(); + let edge_c: usize = triple.get_item(1).unwrap().extract().unwrap(); let edge_w = triple.get_item(2).unwrap().extract().unwrap(); self.graph .add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w); @@ -1062,8 +1019,8 @@ impl PyGraph { /// the graph #[pyo3(text_signature = "(self, index_list, /)")] pub fn remove_nodes_from(&mut self, index_list: Vec) -> PyResult<()> { - for node in index_list.iter().map(|x| NodeIndex::new(*x)) { - self.graph.remove_node(node); + for node in index_list { + self.remove_node(node)?; } Ok(()) } @@ -1695,7 +1652,7 @@ impl PyGraph { // Remove nodes that will be replaced. for index in indices_to_remove { - self.graph.remove_node(index); + self.remove_node(index.index())?; } // If `weight_combo_fn` was specified, merge edges according @@ -1846,7 +1803,10 @@ impl PyGraph { fn __delitem__(&mut self, idx: usize) -> PyResult<()> { match self.graph.remove_node(NodeIndex::new(idx)) { - Some(_) => Ok(()), + Some(_) => { + self.node_removed = true; + Ok(()) + } None => Err(PyIndexError::new_err("No node found for index")), } } diff --git a/tests/rustworkx_tests/digraph/test_deepcopy.py b/tests/rustworkx_tests/digraph/test_deepcopy.py index bd296a5a5a..5422512732 100644 --- a/tests/rustworkx_tests/digraph/test_deepcopy.py +++ b/tests/rustworkx_tests/digraph/test_deepcopy.py @@ -70,3 +70,29 @@ def test_deepcopy_different_objects(self): self.assertIsNot( graph_a.get_edge_data(node_a, node_b), graph_b.get_edge_data(node_a, node_b) ) + + def test_deepcopy_multinode_hole_in_middle(self): + graph = rustworkx.PyDiGraph() + graph.add_nodes_from(range(20)) + graph.remove_nodes_from([10, 11, 12, 13, 14]) + graph.add_edges_from_no_data( + [ + (4, 5), + (16, 18), + (2, 19), + (0, 15), + (15, 16), + (16, 17), + (6, 17), + (8, 18), + (17, 1), + (17, 7), + (18, 3), + (18, 9), + (19, 16), + ] + ) + copied_graph = copy.deepcopy(graph) + self.assertEqual( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 16, 17, 18, 19], copied_graph.node_indices() + ) diff --git a/tests/rustworkx_tests/graph/test_deepcopy.py b/tests/rustworkx_tests/graph/test_deepcopy.py index 074d032118..6941d2dcb8 100644 --- a/tests/rustworkx_tests/graph/test_deepcopy.py +++ b/tests/rustworkx_tests/graph/test_deepcopy.py @@ -48,3 +48,29 @@ def test_deepcopy_attrs(self): graph = rustworkx.PyGraph(attrs="abc") graph_copy = copy.deepcopy(graph) self.assertEqual(graph.attrs, graph_copy.attrs) + + def test_deepcopy_multinode_hole_in_middle(self): + graph = rustworkx.PyGraph() + graph.add_nodes_from(range(20)) + graph.remove_nodes_from([10, 11, 12, 13, 14]) + graph.add_edges_from_no_data( + [ + (4, 5), + (16, 18), + (2, 19), + (0, 15), + (15, 16), + (16, 17), + (6, 17), + (8, 18), + (17, 1), + (17, 7), + (18, 3), + (18, 9), + (19, 16), + ] + ) + copied_graph = copy.deepcopy(graph) + self.assertEqual( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 16, 17, 18, 19], copied_graph.node_indices() + )