Commit

Ensure TORCH_TRACE is run for Dynamo/Distributed tests (pytorch#139786)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: pytorch#139786
Approved by: https://github.com/bobrenjc93, https://github.com/c00w, https://github.com/anijain2305
ghstack dependencies: pytorch#139716
ezyang authored and pytorchmergebot committed Nov 7, 2024
1 parent 47446cb commit 4e64787
Showing 5 changed files with 28 additions and 2 deletions.
17 changes: 17 additions & 0 deletions torch/_dynamo/debug_utils.py
@@ -680,6 +680,23 @@ def tensor(self, name, t) -> None:
             + f") # {name}"
         )
 
+    def unsupported(self, name, arg):
+        # NB: Try hard not to /print/ a tensor, that will be very slow
+        self._lines.append(f"# {name} was unsupported type for dumping: {type(arg)}")
+        # Best effort dump as much useful stuff we can lol, in case you want
+        # to repair the repro
+        if isinstance(arg, (list, tuple)):
+            self._lines.append('"""')
+            for i, a in enumerate(arg):
+                name_i = f"{name}[{i}]"
+                if isinstance(a, torch.Tensor):
+                    self.tensor(name_i, a)
+                elif isinstance(a, (int, torch.SymInt)):
+                    self.symint(name_i, a)
+                else:
+                    self.unsupported(name_i, a)
+            self._lines.append('"""')
+
     # write out that the arg was filtered out as it is constant
     def const(self, name) -> None:
         self._lines.append(
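To see what this fallback yields, here is a minimal, self-contained sketch (not part of the commit): the dump_* functions are illustrative stand-ins for the InputWriter.tensor/symint/unsupported methods above.

# Illustrative sketch only; dump_tensor/dump_int/dump_unsupported stand in
# for the InputWriter methods in debug_utils.py.
import torch

lines = []

def dump_tensor(name, t):
    # Stand-in for InputWriter.tensor: record the shape, never print values.
    lines.append(f"{name} = torch.randn({tuple(t.shape)})")

def dump_int(name, i):
    lines.append(f"{name} = {i}")

def dump_unsupported(name, arg):
    # Note the type without printing the value (printing a tensor is slow).
    lines.append(f"# {name} was unsupported type for dumping: {type(arg)}")
    # Best-effort recursion into containers so the repro can be hand-repaired.
    if isinstance(arg, (list, tuple)):
        lines.append('"""')
        for i, a in enumerate(arg):
            name_i = f"{name}[{i}]"
            if isinstance(a, torch.Tensor):
                dump_tensor(name_i, a)
            elif isinstance(a, int):
                dump_int(name_i, a)
            else:
                dump_unsupported(name_i, a)
        lines.append('"""')

dump_unsupported("args[0]", (torch.randn(2, 3), 7, "label"))
print("\n".join(lines))
# prints:
# # args[0] was unsupported type for dumping: <class 'tuple'>
# """
# args[0][0] = torch.randn((2, 3))
# args[0][1] = 7
# # args[0][2] was unsupported type for dumping: <class 'str'>
# """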
4 changes: 3 additions & 1 deletion torch/_dynamo/repro/after_aot.py
@@ -304,7 +304,9 @@ def hint_if_symint(x):
         elif arg is None:
             writer.const(placeholder)
         else:
-            raise TypeError(f"arg is neither SymInt/int nor torch.Tensor, {arg}")
+            # It's better to produce a slightly wrong repro string than none
+            # at all
+            writer.unsupported(placeholder, arg)
 
     model_str += "\n".join(writer.lines()) + "\n"

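Design note: swallowing the TypeError trades exactness for availability. A repro script containing a commented-out placeholder (as emitted by unsupported above) can still be repaired by hand, whereas raising during repro generation loses the script entirely.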
4 changes: 4 additions & 0 deletions torch/_dynamo/test_case.py
@@ -5,6 +5,7 @@
 
 import torch
 import torch.testing
+from torch._logging._internal import trace_log
 from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
     IS_WINDOWS,
     TEST_WITH_CROSSREF,
@@ -63,8 +64,11 @@ def setUp(self) -> None:
         super().setUp()
         reset()
         utils.counters.clear()
+        self.handler = logging.NullHandler()
+        trace_log.addHandler(self.handler)
 
     def tearDown(self) -> None:
+        trace_log.removeHandler(self.handler)
         for k, v in utils.counters.items():
             print(k, v.most_common())
         reset()
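Why a NullHandler helps: a plausible reading (sketched below with illustrative names, not the actual torch._logging source) is that structured-trace emission is gated on trace_log having at least one handler. Attaching a do-nothing handler in setUp therefore forces every trace_structured call during the test to build its metadata and payload, surfacing TORCH_TRACE bugs in CI without writing logs anywhere.

# Minimal sketch of the handler-gated pattern (assumption: the structured
# trace helper short-circuits when trace_log has no handlers).
import logging

trace_log = logging.getLogger("example.trace")
trace_log.setLevel(logging.DEBUG)
trace_log.propagate = False

def trace_structured(name, metadata_fn):
    # Only build the (potentially expensive) record if someone is listening.
    if trace_log.handlers:
        trace_log.debug("%s %s", name, metadata_fn())

# No handler attached: metadata_fn never runs, so a bug in it goes unnoticed.
trace_structured("dynamo_start", lambda: {"frame_id": 0})

# A NullHandler discards output but makes the logger "listened to", so the
# metadata/payload code paths are exercised on every test.
trace_log.addHandler(logging.NullHandler())
trace_structured("dynamo_start", lambda: {"frame_id": 0})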
2 changes: 1 addition & 1 deletion torch/csrc/dynamo/guards.cpp
@@ -2965,7 +2965,7 @@ class DictGetItemGuardAccessor : public GuardAccessor {
   }
 
   std::string repr() const override {
-    return "DictGetItemGuardAccessor(" + py::str(_key).cast<std::string>() +
+    return "DictGetItemGuardAccessor(" + py::repr(_key).cast<std::string>() +
         ")";
   }
 
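The py::str to py::repr switch matters when _key is itself a string: str yields the bare characters while repr quotes them, keeping the guard's debug output unambiguous. pybind11's py::str/py::repr mirror the Python built-ins, so the distinction is easy to see from Python:

key = "weight"
print("DictGetItemGuardAccessor(" + str(key) + ")")
# DictGetItemGuardAccessor(weight)
print("DictGetItemGuardAccessor(" + repr(key) + ")")
# DictGetItemGuardAccessor('weight')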
3 changes: 3 additions & 0 deletions torch/testing/_internal/common_distributed.py
@@ -24,6 +24,7 @@
 from typing import Dict, NamedTuple, Optional, Union, List, Any, Callable, Tuple
 from unittest.mock import patch
 
+from torch._logging._internal import trace_log
 import torch
 import torch._dynamo.test_case
 import torch.cuda.nccl
@@ -1348,6 +1349,8 @@ def world_size(self) -> int:
 
     @classmethod
     def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs) -> None:
+        trace_log.addHandler(logging.NullHandler())
+
         # The rest is copypasta from MultiProcessTestCase._run
         self = cls(test_name)
         self.rank = rank
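Note: unlike the setUp hook above, multiprocess distributed tests execute each test in a freshly spawned worker process, where a handler installed in the parent is not visible; attaching the NullHandler at the top of _run (the worker entry point) covers that case. (This is a reading of the test harness, not stated in the commit message.)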
