diff --git a/requirements/base.txt b/requirements/base.txt index fc9b17e751..e1c4766f6a 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -2,7 +2,7 @@ torch >=2.3.0 looseversion ==1.3.0 lightning-utilities >=0.7.0 numpy >=1.23.0,<2 # not yet ready for numpy 2 -igraph >=0.10.4 +networkx >= 3.3 optree >=0.12.1 opt_einsum >= 3.3.0 mpmath <1.4.0 # todo: teporarl pin for `NameError: name '_C' is not defined` diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index 3c64737ee6..53f60ea5ae 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -7,7 +7,7 @@ from collections import defaultdict import time -from igraph import Graph +import networkx as nx from thunder.core import prims, utils from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface @@ -317,12 +317,9 @@ def find_cut( # Create a graph edges = [] - name_to_id = {} - capacities = [] def add_edge(src, dst, capacity): - edges.append((name_to_id.setdefault(src, len(name_to_id)), name_to_id.setdefault(dst, len(name_to_id)))) - capacities.append(capacity) + edges.append((src, dst, {"capacity": capacity})) utils.check( len(required_consumer_vars) > 0, @@ -374,23 +371,17 @@ def add_edges(var): for var in symbol.flat_proxy_outs: add_edges(var) - g = Graph( - n=len(name_to_id), - edges=edges, - directed=True, - edge_attrs={"capacity": capacities}, - ) - source = name_to_id["source"] - sink = name_to_id["sink"] + g = nx.DiGraph() + g.add_edges_from(edges) + + _, (reachable, non_reachable) = nx.minimum_cut(g, "source", "sink") - id_to_name = dict(map(reversed, name_to_id.items())) + cut_edges = set() + for u, nbrs in ((n, g[n]) for n in reachable): + cut_edges.update((u, v) for v in nbrs if v in non_reachable) - g_edges = g.get_edgelist() - cut = g.mincut(source, sink, "capacity").cut cut_nodes = set() - for cut_edge_id in cut: - u, v = g_edges[cut_edge_id] - node_in, node_out = id_to_name[u], id_to_name[v] + for node_in, node_out in cut_edges: if node_out == "sink": continue assert node_in.endswith("_in"), node_in diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index cb4ece6aa0..f6e68f0e23 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -281,16 +281,13 @@ def func(x): # There are two nvfuser fusion groups separated by the matmul operation. assert len(fusion_bsyms) == 2 - nvf_0, nvf_1 = fusion_bsyms # CSE removes the redundant (t0 + 5) operation - assert len(nvf_0.subsymbols) == 5 - # Return t0 and t1 from the first fusion - assert [t.name for t in tree_flatten(nvf_0.output)[0]] == ["t1", "t4"] + nvf_0, nvf_1 = fusion_bsyms + assert len(nvf_0.subsymbols) + len(nvf_1.subsymbols) == 7 - # CSE does not change the second fusion - assert len(nvf_1.subsymbols) == 2 - assert [t.name for t in tree_flatten(nvf_1.output)[0]] == ["t10"] + outside_fusion_syms = ["unpack_trivial", "matmul", "python_return", "python_del"] + assert {el.sym.name for el in fw_trace.bound_symbols if not el.sym.is_fusion} == set(outside_fusion_syms) @instantiate(dtypes=NOTHING, devicetypes=(devices.DeviceType.CUDA,), executors=(nvFuserExecutor,))