1 change: 1 addition & 0 deletions src/digraph.rs
@@ -3007,6 +3007,7 @@ impl PyDiGraph {
}
(None, true) => self.graph.contract_nodes(nodes, obj, check_cycle)?,
};
self.node_removed = true;
Ok(res.index())
}

1 change: 1 addition & 0 deletions src/graph.rs
@@ -1899,6 +1899,7 @@ impl PyGraph {
}
(None, true) => self.graph.contract_nodes(nodes, obj),
};
self.node_removed = true;
Ok(res.index())
}

40 changes: 32 additions & 8 deletions src/json/mod.rs
@@ -74,12 +74,18 @@ pub fn from_node_link_json_file<'py>(
Ok(if graph.directed {
let mut inner_graph: StablePyGraph<Directed> =
StablePyGraph::with_capacity(graph.nodes.len(), graph.links.len());
node_link_data::parse_node_link_data(&py, graph, &mut inner_graph, node_attrs, edge_attrs)?;
let node_removed = node_link_data::parse_node_link_data(
&py,
graph,
&mut inner_graph,
node_attrs,
edge_attrs,
)?;
digraph::PyDiGraph {
graph: inner_graph,
cycle_state: algo::DfsSpace::default(),
check_cycle: false,
node_removed: false,
node_removed,
multigraph,
attrs,
}
@@ -88,11 +94,17 @@
} else {
let mut inner_graph: StablePyGraph<Undirected> =
StablePyGraph::with_capacity(graph.nodes.len(), graph.links.len());
node_link_data::parse_node_link_data(&py, graph, &mut inner_graph, node_attrs, edge_attrs)?;
let node_removed = node_link_data::parse_node_link_data(
&py,
graph,
&mut inner_graph,
node_attrs,
edge_attrs,
)?;

graph::PyGraph {
graph: inner_graph,
node_removed: false,
node_removed,
multigraph,
attrs,
}
@@ -150,12 +162,18 @@ pub fn parse_node_link_json<'py>(
Ok(if graph.directed {
let mut inner_graph: StablePyGraph<Directed> =
StablePyGraph::with_capacity(graph.nodes.len(), graph.links.len());
node_link_data::parse_node_link_data(&py, graph, &mut inner_graph, node_attrs, edge_attrs)?;
let node_removed = node_link_data::parse_node_link_data(
&py,
graph,
&mut inner_graph,
node_attrs,
edge_attrs,
)?;
digraph::PyDiGraph {
graph: inner_graph,
cycle_state: algo::DfsSpace::default(),
check_cycle: false,
node_removed: false,
node_removed,
multigraph,
attrs,
}
@@ -164,10 +182,16 @@
} else {
let mut inner_graph: StablePyGraph<Undirected> =
StablePyGraph::with_capacity(graph.nodes.len(), graph.links.len());
node_link_data::parse_node_link_data(&py, graph, &mut inner_graph, node_attrs, edge_attrs)?;
let node_removed = node_link_data::parse_node_link_data(
&py,
graph,
&mut inner_graph,
node_attrs,
edge_attrs,
)?;
graph::PyGraph {
graph: inner_graph,
node_removed: false,
node_removed,
multigraph,
attrs,
}
80 changes: 64 additions & 16 deletions src/json/node_link_data.rs
@@ -83,22 +83,70 @@ pub fn parse_node_link_data<Ty: EdgeType>(
out_graph: &mut StablePyGraph<Ty>,
node_attrs: Option<PyObject>,
edge_attrs: Option<PyObject>,
) -> PyResult<()> {
) -> PyResult<bool> {
let mut id_mapping: HashMap<usize, NodeIndex> = HashMap::with_capacity(graph.nodes.len());
for node in graph.nodes {
let payload = match node.data {
Some(data) => match node_attrs {
Some(ref callback) => callback.call1(*py, (data,))?,
None => data.into_py_any(*py)?,
},
None => py.None(),
};
let id = out_graph.add_node(payload);
match node.id {
Some(input_id) => id_mapping.insert(input_id, id),
None => id_mapping.insert(id.index(), id),
};
}

// Check if nodes have explicit IDs that need preservation
let preserve_ids = graph.nodes.iter().any(|n| n.id.is_some());

let node_removed = if preserve_ids {
// Find the maximum node ID to determine how many placeholder nodes we need
let max_id = graph.nodes.iter().filter_map(|n| n.id).max().unwrap_or(0);

// Create placeholder nodes up to max_id
let mut tmp_nodes: Vec<NodeIndex> = Vec::new();
for _ in 0..=max_id {
let idx = out_graph.add_node(py.None());
tmp_nodes.push(idx);
}
Comment on lines +97 to +101

Collaborator:

Maybe a more idiomatic way of writing this:

let tmp_nodes: Vec<NodeIndex> = (0..=max_id)
    .map(|_| out_graph.add_node(py.None()))
    .collect();

Collaborator:

This suggestion also works with HashSet by the way
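A minimal sketch of that HashSet variant, assuming hashbrown's HashSet (the crate a later comment on this file links to) and the max_id, out_graph, and py bindings from the surrounding diff:

let mut tmp_nodes: hashbrown::HashSet<NodeIndex> = (0..=max_id)
    .map(|_| out_graph.add_node(py.None()))
    .collect();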


// Replace placeholder nodes with actual data and track which to keep
for node in graph.nodes {
let payload = match node.data {
Some(data) => match node_attrs {
Some(ref callback) => callback.call1(*py, (data,))?,
None => data.into_py_any(*py)?,
},
None => py.None(),
};
let node_id = node.id.unwrap_or(0);
Collaborator:

Suggested change
let node_id = node.id.unwrap_or(0);
let node_id = node.id?;

I don't think we should default to 0 if it is not defined.
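Because parse_node_link_data returns a PyResult, the bare ? on an Option would not compile as written; a sketch of the same intent, raising an error instead of silently defaulting (PyValueError is an assumed choice of exception type):

let node_id = node.id.ok_or_else(|| {
    pyo3::exceptions::PyValueError::new_err("node entry has no 'id' field")
})?;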

let idx = NodeIndex::new(node_id);

// Replace the placeholder with actual data
if let Some(weight) = out_graph.node_weight_mut(idx) {
*weight = payload;
}

id_mapping.insert(node_id, idx);
// Mark this index as used (remove from tmp_nodes)
tmp_nodes.retain(|&n| n != idx);
Collaborator:

This can lead to quadratic behaviour, we are looping through all tmp_nodes to delete the index.

Please use https://docs.rs/hashbrown/latest/hashbrown/struct.HashSet.html instead
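A sketch of that change, assuming tmp_nodes is collected into a hashbrown::HashSet<NodeIndex> as in the earlier sketch; the names come from the surrounding diff:

// Average O(1) removal instead of a linear retain over the Vec; the later
// is_empty() check and the removal loop work on the set unchanged.
tmp_nodes.remove(&idx);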

}

// Track if we're removing any nodes (indicates gaps in indices)
let has_gaps = !tmp_nodes.is_empty();

// Remove remaining placeholder nodes
for tmp_node in tmp_nodes {
out_graph.remove_node(tmp_node);
}

has_gaps
} else {
Collaborator:

Nitpicking: can you swap the if/else order? i.e. use if !preserve_ids.

This is a style choice but I prefer to have the most common case first in if statements.

// No explicit IDs, just add nodes sequentially (legacy behavior)
for node in graph.nodes {
let payload = match node.data {
Some(data) => match node_attrs {
Some(ref callback) => callback.call1(*py, (data,))?,
None => data.into_py_any(*py)?,
},
None => py.None(),
};
let id = out_graph.add_node(payload);
id_mapping.insert(id.index(), id);
}
false
};

for edge in graph.links {
let data = match edge.data {
Some(data) => match edge_attrs {
@@ -109,7 +157,7 @@
};
out_graph.add_edge(id_mapping[&edge.source], id_mapping[&edge.target], data);
}
Ok(())
Ok(node_removed)
}

#[allow(clippy::too_many_arguments)]
42 changes: 42 additions & 0 deletions tests/digraph/test_node_link_json.py
@@ -223,3 +223,45 @@ def test_not_JSON(self):
"""
with self.assertRaises(rustworkx.JSONDeserializationError):
rustworkx.parse_node_link_json(invalid_input)

def test_node_indices_preserved_with_deletion(self):
"""Test that node indices are preserved after deletion (related to issue #1503)"""
graph = rustworkx.PyDiGraph()
graph.add_node("A") # 0
graph.add_node("B") # 1
graph.add_node("C") # 2
graph.add_edge(0, 2, "A->C")
graph.remove_node(1) # Remove middle node

# Verify original has gaps in indices
self.assertEqual([0, 2], graph.node_indices())

# Round-trip through JSON
json_str = rustworkx.node_link_json(graph)
restored = rustworkx.parse_node_link_json(json_str)

# Verify indices are preserved
self.assertEqual(graph.node_indices(), restored.node_indices())
self.assertEqual(graph.edge_list(), restored.edge_list())

def test_node_indices_preserved_with_contraction(self):
"""Test that node indices are preserved after contraction (issue #1503)"""
graph = rustworkx.PyDiGraph()
graph.add_node("A") # 0
graph.add_node("B") # 1
graph.add_node("C") # 2

# Contract nodes 0 and 1
contracted_idx = graph.contract_nodes([0, 1], "AB")
graph.add_edge(2, contracted_idx, "C->AB")

# Verify original has non-consecutive indices
self.assertEqual([2, contracted_idx], graph.node_indices())

# Round-trip through JSON
json_str = rustworkx.node_link_json(graph)
restored = rustworkx.parse_node_link_json(json_str)

# Verify indices are preserved
self.assertEqual(graph.node_indices(), restored.node_indices())
self.assertEqual(graph.edge_list(), restored.edge_list())
23 changes: 23 additions & 0 deletions tests/digraph/test_pickle.py
@@ -39,3 +39,26 @@ def test_weight_graph(self):
self.assertEqual([1, 2, 3], gprime.node_indices())
self.assertEqual(["B", "C", "D"], gprime.nodes())
self.assertEqual({1: (1, 2, "B -> C"), 3: (3, 1, "D -> B")}, dict(gprime.edge_index_map()))

def test_contracted_nodes_pickle(self):
"""Test pickle/unpickle of directed graphs with contracted nodes (issue #1503)"""
g = rx.PyDiGraph()
g.add_node("A") # Node 0
g.add_node("B") # Node 1
g.add_node("C") # Node 2

# Contract nodes 0 and 1 into a new node
contracted_idx = g.contract_nodes([0, 1], "AB")
g.add_edge(2, contracted_idx, "C -> AB")

# Verify initial state
self.assertEqual([2, contracted_idx], g.node_indices())
self.assertEqual([(2, contracted_idx)], g.edge_list())

# Test pickle/unpickle
gprime = pickle.loads(pickle.dumps(g))

# Verify the unpickled graph matches
self.assertEqual(g.node_indices(), gprime.node_indices())
self.assertEqual(g.edge_list(), gprime.edge_list())
self.assertEqual(g.nodes(), gprime.nodes())
67 changes: 67 additions & 0 deletions tests/graph/test_node_link_json.py
@@ -204,3 +204,70 @@ def test_round_trip_with_file_no_graph_attr(self):
self.assertEqual(new.nodes(), graph.nodes())
self.assertEqual(new.weighted_edge_list(), graph.weighted_edge_list())
self.assertEqual(new.attrs, {"label": graph.attrs})

def test_node_indices_preserved_with_deletion(self):
"""Test that node indices are preserved after deletion (related to issue #1503)"""
graph = rustworkx.PyGraph()
graph.add_node(None) # 0
graph.add_node(None) # 1
graph.add_node(None) # 2
graph.add_edge(0, 2, None)
graph.remove_node(1) # Remove middle node

# Verify original has gaps in indices
self.assertEqual([0, 2], graph.node_indices())

# Round-trip through JSON
json_str = rustworkx.node_link_json(graph)
restored = rustworkx.parse_node_link_json(json_str)

# Verify indices are preserved
self.assertEqual(graph.node_indices(), restored.node_indices())
self.assertEqual(graph.edge_list(), restored.edge_list())

def test_node_indices_preserved_with_contraction(self):
"""Test that node indices are preserved after contraction (issue #1503)"""
graph = rustworkx.PyGraph()
graph.add_node(None) # 0
graph.add_node(None) # 1
graph.add_node(None) # 2

# Contract nodes 0 and 1
contracted_idx = graph.contract_nodes([0, 1], None)
graph.add_edge(2, contracted_idx, None)

# Verify original has non-consecutive indices
self.assertEqual([2, contracted_idx], graph.node_indices())

# Round-trip through JSON
json_str = rustworkx.node_link_json(graph)
restored = rustworkx.parse_node_link_json(json_str)

# Verify indices are preserved
self.assertEqual(graph.node_indices(), restored.node_indices())
self.assertEqual(graph.edge_list(), restored.edge_list())

def test_node_indices_preserved_complex(self):
"""Test index preservation with multiple deletions and edges"""
graph = rustworkx.PyGraph()
for i in range(6):
graph.add_node(None)

graph.add_edge(0, 1, None)
graph.add_edge(2, 3, None)
graph.add_edge(4, 5, None)

# Remove nodes 1 and 4
graph.remove_node(1)
graph.remove_node(4)

# Verify gaps exist
self.assertEqual([0, 2, 3, 5], graph.node_indices())

# Round-trip through JSON
json_str = rustworkx.node_link_json(graph)
restored = rustworkx.parse_node_link_json(json_str)

# Verify complete state is preserved
self.assertEqual(graph.node_indices(), restored.node_indices())
self.assertEqual(graph.edge_list(), restored.edge_list())
43 changes: 43 additions & 0 deletions tests/graph/test_pickle.py
@@ -39,3 +39,46 @@ def test_weight_graph(self):
self.assertEqual([1, 2, 3], gprime.node_indices())
self.assertEqual(["B", "C", "D"], gprime.nodes())
self.assertEqual({1: (1, 2, "B -> C"), 3: (3, 1, "D -> B")}, dict(gprime.edge_index_map()))

def test_contracted_nodes_pickle(self):
"""Test pickle/unpickle of graphs with contracted nodes (issue #1503)"""
g = rx.PyGraph()
g.add_node("A") # Node 0
g.add_node("B") # Node 1
g.add_node("C") # Node 2

# Contract nodes 0 and 1 into a new node
contracted_idx = g.contract_nodes([0, 1], "AB")
g.add_edge(2, contracted_idx, "C -> AB")

# Verify initial state
self.assertEqual([2, contracted_idx], g.node_indices())
self.assertEqual([(2, contracted_idx)], g.edge_list())

# Test pickle/unpickle
gprime = pickle.loads(pickle.dumps(g))

# Verify the unpickled graph matches
self.assertEqual(g.node_indices(), gprime.node_indices())
self.assertEqual(g.edge_list(), gprime.edge_list())
self.assertEqual(g.nodes(), gprime.nodes())

def test_contracted_nodes_with_weights_pickle(self):
"""Test pickle/unpickle of graphs with contracted nodes and edge weights"""
g = rx.PyGraph()
g.add_nodes_from(["Node0", "Node1", "Node2", "Node3"])
g.add_edges_from([(0, 2, "edge_0_2"), (1, 3, "edge_1_3")])

# Contract multiple nodes
contracted_idx = g.contract_nodes([0, 1], "Contracted_0_1")
g.add_edge(contracted_idx, 2, "contracted_to_2")
g.add_edge(3, contracted_idx, "3_to_contracted")

# Test pickle/unpickle
gprime = pickle.loads(pickle.dumps(g))

# Verify complete graph state is preserved
self.assertEqual(g.node_indices(), gprime.node_indices())
self.assertEqual(g.edge_list(), gprime.edge_list())
self.assertEqual(g.nodes(), gprime.nodes())
self.assertEqual(dict(g.edge_index_map()), dict(gprime.edge_index_map()))
Collaborator:

Of course you copied examples that don't use it but: https://docs.python.org/3/library/unittest.html#unittest.TestCase.assertDictEqual

This just has a nicer error message, in case someone ever introduces a bug
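For reference, a sketch of how the final assertion of the test above could read with that method (same g and gprime names as in the test):

# assertDictEqual reports a per-key diff on failure instead of assertEqual's
# generic message.
self.assertDictEqual(dict(g.edge_index_map()), dict(gprime.edge_index_map()))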