Skip to content

Commit 9012e7a

Browse files
Revert "[dynamo][pytree][1/N] make CXX pytree traceable: tree_iter / tree_leaves (pytorch#137397)"
This reverts commit 07850bb. Reverted pytorch#137397 on behalf of https://github.com/atalman due to Failing internal test ([comment](pytorch#137397 (comment)))
1 parent eb7deb2 commit 9012e7a

File tree

7 files changed

+57
-139
lines changed

7 files changed

+57
-139
lines changed

test/dynamo/test_misc.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import torch._dynamo.testing
3333
import torch._inductor.test_case
3434
import torch.onnx.operators
35-
import torch.utils._pytree as python_pytree
35+
import torch.utils._pytree as pytree
3636
import torch.utils.cpp_extension
3737
from torch import Tensor
3838
from torch._C import FileCheck
@@ -89,11 +89,9 @@
8989
from torch.testing._internal.logging_utils import logs_to_string
9090

9191

92-
HAS_OPTREE = python_pytree._cxx_pytree_exists
92+
HAS_OPTREE = importlib.util.find_spec("optree")
9393
if HAS_OPTREE:
94-
import torch.utils._cxx_pytree as cxx_pytree
95-
else:
96-
cxx_pytree = None
94+
import optree
9795

9896
MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"])
9997
T = typing.TypeVar("T")
@@ -295,9 +293,9 @@ def fn(x):
295293

296294
@unittest.skipIf(not HAS_OPTREE, "missing optree package")
297295
def test_optree_graph_break_message(self):
298-
import optree
299-
300-
@torch.compile(backend="eager")
296+
@torch.compile(
297+
backend="eager",
298+
)
301299
def fn(x):
302300
d = {"a": 1}
303301
optree.tree_flatten(d)
@@ -8678,9 +8676,9 @@ def fn():
86788676

86798677
def test_tracing_py_tree(self):
86808678
def fn(xs):
8681-
flat_xs, spec = python_pytree.tree_flatten(xs)
8679+
flat_xs, spec = pytree.tree_flatten(xs)
86828680
res = [x.clone() for x in flat_xs]
8683-
return python_pytree.tree_unflatten(res, spec)
8681+
return pytree.tree_unflatten(res, spec)
86848682

86858683
xs = [torch.tensor(i) for i in range(3)]
86868684

@@ -8690,10 +8688,12 @@ def fn(xs):
86908688
self.assertEqual(counter.op_count, 3)
86918689

86928690
def test_tracing_nested_py_tree(self):
8691+
import torch.utils._pytree as pytree
8692+
86938693
def fn(xs):
8694-
flat_xs, spec = python_pytree.tree_flatten(xs)
8694+
flat_xs, spec = pytree.tree_flatten(xs)
86958695
res = [x.clone() for x in flat_xs]
8696-
return python_pytree.tree_unflatten(res, spec)
8696+
return pytree.tree_unflatten(res, spec)
86978697

86988698
xs = [torch.tensor(i) for i in range(3)]
86998699
xsl = [xs, xs, xs, xs]
@@ -8706,10 +8706,12 @@ def fn(xs):
87068706
self.assertEqual(counter.op_count, 12)
87078707

87088708
def test_tracing_nested_py_tree_tuples(self):
8709+
import torch.utils._pytree as pytree
8710+
87098711
def fn(xs):
8710-
flat_xs, spec = python_pytree.tree_flatten(xs)
8712+
flat_xs, spec = pytree.tree_flatten(xs)
87118713
res = [x.clone() for x in flat_xs]
8712-
return python_pytree.tree_unflatten(res, spec)
8714+
return pytree.tree_unflatten(res, spec)
87138715

87148716
xs = [torch.tensor(i) for i in range(3)]
87158717
xsl = (xs, xs, xs, xs)
@@ -8722,10 +8724,12 @@ def fn(xs):
87228724
self.assertEqual(counter.op_count, 12)
87238725

87248726
def test_tracing_nested_py_tree_dicts(self):
8727+
import torch.utils._pytree as pytree
8728+
87258729
def fn(xs):
8726-
flat_xs, spec = python_pytree.tree_flatten(xs)
8730+
flat_xs, spec = pytree.tree_flatten(xs)
87278731
res = [x.clone() for x in flat_xs]
8728-
return python_pytree.tree_unflatten(res, spec)
8732+
return pytree.tree_unflatten(res, spec)
87298733

87308734
xs = [torch.tensor(i) for i in range(3)]
87318735
xsl = {
@@ -8758,10 +8762,12 @@ def fn(x):
87588762
self.assertEqual(counter.op_count, 2)
87598763

87608764
def test_tracing_nested_py_tree_mixed_all(self):
8765+
import torch.utils._pytree as pytree
8766+
87618767
def fn(xs):
8762-
flat_xs, spec = python_pytree.tree_flatten(xs)
8768+
flat_xs, spec = pytree.tree_flatten(xs)
87638769
res = [x.clone() for x in flat_xs]
8764-
return python_pytree.tree_unflatten(res, spec)
8770+
return pytree.tree_unflatten(res, spec)
87658771

87668772
xs = [torch.tensor(i) for i in range(3)]
87678773
xsa = (xs, xs)
@@ -8806,12 +8812,13 @@ def fn(x):
88068812
self.assertEqual(cnt.frame_count, 2)
88078813

88088814
def test_tracing_py_tree_tensor_subclass(self):
8815+
import torch.utils._pytree as pytree
88098816
from torch.testing._internal.two_tensor import TwoTensor
88108817
from torch.utils.checkpoint import checkpoint
88118818

88128819
def fn(xs):
88138820
nested_xs = [[xs]]
8814-
flat_xs, spec = python_pytree.tree_flatten(xs)
8821+
flat_xs, spec = pytree.tree_flatten(xs)
88158822
return flat_xs[0].clone()
88168823

88178824
# use checkpoint to trigger a "sourceless" tensor subclass
@@ -8826,11 +8833,13 @@ def checkpoint_fn(xs):
88268833
self.assertEqual(counter.op_count, 2)
88278834

88288835
def test_tracing_tree_map_only(self):
8836+
import torch.utils._pytree as pytree
8837+
88298838
def fn(xs):
88308839
def mapper(x):
88318840
return x.clone()
88328841

8833-
y = python_pytree.tree_map_only(torch.Tensor, mapper, xs)
8842+
y = pytree.tree_map_only(torch.Tensor, mapper, xs)
88348843
return y
88358844

88368845
xs = [torch.tensor(i) for i in range(3)] + ["hi"]
@@ -10184,9 +10193,7 @@ def fn(x, y):
1018410193
self.assertEqual(actual, expected)
1018510194

1018610195
def test_pytree_tree_leaves(self):
10187-
implemtations = [("python", python_pytree)]
10188-
if cxx_pytree is not None:
10189-
implemtations.append(("cxx", cxx_pytree))
10196+
implemtations = [("python", pytree)]
1019010197

1019110198
for name, module in implemtations:
1019210199
with self.subTest(f"pytree implement: {name}"):
@@ -10218,7 +10225,7 @@ def fn(x):
1021810225
self.assertEqual(actual, expected)
1021910226

1022010227
def test_pytree_tree_flatten_unflatten(self):
10221-
implemtations = [("python", python_pytree)]
10228+
implemtations = [("python", pytree)]
1022210229

1022310230
for name, module in implemtations:
1022410231
with self.subTest(f"pytree implement: {name}"):
@@ -10267,7 +10274,7 @@ def fn(x, y):
1026710274
self.assertEqual(actual, expected)
1026810275

1026910276
def test_pytree_tree_map(self):
10270-
implemtations = [("python", python_pytree)]
10277+
implemtations = [("python", pytree)]
1027110278

1027210279
for name, module in implemtations:
1027310280
with self.subTest(f"pytree implement: {name}"):

torch/_dynamo/guards.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2080,11 +2080,10 @@ def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None)
20802080
obj_ref = None
20812081
# Not necessary to have weakref for Enum type, but there is a bug that
20822082
# makes hasattr(guarded_object.__class__, "__weakref__") return True.
2083-
supports_weakref = (
2084-
getattr(guarded_object.__class__, "__weakrefoffset__", 0) != 0
2085-
)
20862083
# See D64140537 for why we are checking for tuple.
2087-
if supports_weakref and not isinstance(guarded_object, (enum.Enum, tuple)):
2084+
if hasattr(guarded_object.__class__, "__weakref__") and not isinstance(
2085+
guarded_object, (enum.Enum, tuple)
2086+
):
20882087
obj_ref = weakref.ref(guarded_object)
20892088

20902089
guard.set_export_info(

torch/_dynamo/polyfills/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
itertools as itertools,
2424
operator as operator,
2525
os as os,
26-
pytree as pytree,
2726
sys as sys,
2827
)
2928

torch/_dynamo/polyfills/loader.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
"itertools",
1919
"operator",
2020
"os",
21-
"pytree",
2221
"sys",
2322
)
2423
POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(

torch/_dynamo/polyfills/pytree.py

Lines changed: 0 additions & 89 deletions
This file was deleted.

torch/_dynamo/trace_rules.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3310,7 +3310,6 @@ def _module_dir(m: types.ModuleType):
33103310
"torch.testing",
33113311
"torch.utils._content_store",
33123312
"torch.utils._contextlib",
3313-
"torch.utils._cxx_pytree",
33143313
"torch.utils._device",
33153314
"torch.utils._foreach_utils",
33163315
"torch.utils._python_dispatch",

torch/utils/_cxx_pytree.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@
3030
from typing_extensions import deprecated
3131

3232
import optree
33-
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
33+
from optree import PyTreeSpec # direct import for type annotations
3434

35-
import torch.utils._pytree as python_pytree
36-
from torch.utils._pytree import KeyEntry as KeyEntry
35+
import torch.utils._pytree as _pytree
36+
from torch.utils._pytree import KeyEntry
3737

3838

3939
__all__ = [
@@ -79,6 +79,7 @@
7979

8080
Context = Any
8181
PyTree = Any
82+
TreeSpec = PyTreeSpec
8283
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
8384
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
8485
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
@@ -150,7 +151,9 @@ def register_pytree_node(
150151
from_dumpable_context=from_dumpable_context,
151152
)
152153

153-
python_pytree._private_register_pytree_node(
154+
from . import _pytree as python
155+
156+
python._private_register_pytree_node(
154157
cls,
155158
flatten_fn,
156159
unflatten_fn,
@@ -868,19 +871,24 @@ def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
868871
f"treespec_dumps(spec): Expected `spec` to be instance of "
869872
f"TreeSpec but got item of type {type(treespec)}."
870873
)
874+
from ._pytree import (
875+
tree_structure as _tree_structure,
876+
treespec_dumps as _treespec_dumps,
877+
)
871878

872-
dummy_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
873-
orig_treespec = python_pytree.tree_structure(dummy_tree)
874-
return python_pytree.treespec_dumps(orig_treespec, protocol=protocol)
879+
orig_treespec = _tree_structure(tree_unflatten([0] * treespec.num_leaves, treespec))
880+
return _treespec_dumps(orig_treespec, protocol=protocol)
875881

876882

877883
def treespec_loads(serialized: str) -> TreeSpec:
878884
"""Deserialize a treespec from a JSON string."""
879-
orig_treespec = python_pytree.treespec_loads(serialized)
880-
dummy_tree = python_pytree.tree_unflatten(
881-
[0] * orig_treespec.num_leaves,
882-
orig_treespec,
885+
from ._pytree import (
886+
tree_unflatten as _tree_unflatten,
887+
treespec_loads as _treespec_loads,
883888
)
889+
890+
orig_treespec = _treespec_loads(serialized)
891+
dummy_tree = _tree_unflatten([0] * orig_treespec.num_leaves, orig_treespec)
884892
treespec = tree_structure(dummy_tree)
885893
return treespec
886894

@@ -994,10 +1002,6 @@ def key_get(obj: Any, kp: KeyPath) -> Any:
9941002
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
9951003

9961004

997-
with python_pytree._NODE_REGISTRY_LOCK:
998-
python_pytree._cxx_pytree_imported = True
999-
args, kwargs = (), {} # type: ignore[var-annotated]
1000-
for args, kwargs in python_pytree._cxx_pytree_pending_imports:
1001-
_private_register_pytree_node(*args, **kwargs)
1002-
python_pytree._cxx_pytree_pending_imports.clear()
1003-
del args, kwargs
1005+
_pytree._cxx_pytree_imported = True
1006+
for args, kwargs in _pytree._cxx_pytree_pending_imports:
1007+
_private_register_pytree_node(*args, **kwargs)

0 commit comments

Comments
 (0)