Skip to content

Commit

Permalink
fix dataclass arg repr (#843)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Jul 23, 2024
1 parent d27b107 commit a16d6ff
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
7 changes: 3 additions & 4 deletions thunder/core/codeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,9 @@ def prettyprint(
name = _generate_dataclass_class_name(x)
call_repr = []
for k, v in x.__dict__.items():
try:
call_repr.append(f"{k}={v.name}")
except:
call_repr.append(f"{k}={v}")
call_repr.append(
f"{k}={prettyprint(v, with_type=False, literals_as_underscores=literals_as_underscores, _quote_markers=False)}"
)
call_repr_str = ",".join(call_repr)
return m(f"{name}({call_repr_str})")

Expand Down
6 changes: 4 additions & 2 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2789,11 +2789,12 @@ class TestDataclass:
s: torch.Tensor
i: int
f: float
g: tuple

def foo(x):
# TestDataClass as the output and part of the nested output.
return TestDataclass(x, x + 2, x.numel(), x.numel() / 2.0), (
TestDataclass(x, x + 2, x.numel(), x.numel() / 2.0),
return TestDataclass(x, x + 2, x.numel(), x.numel() / 2.0, (x,)), (
TestDataclass(x, x + 2, x.numel(), x.numel() / 2.0, (x)),
{"x": x, "y": x + 3},
)

Expand All @@ -2813,6 +2814,7 @@ def _test_container(actual_container, expected_container):
torch.testing.assert_close(actual_container.s, expected_container.s)
torch.testing.assert_close(actual_container.i, expected_container.i)
torch.testing.assert_close(actual_container.f, expected_container.f)
torch.testing.assert_close(actual_container.g[0], expected_container.g[0])

_test_container(actual_container, expected_container)
_test_container(actual_tuple[0], expected_tuple[0])
Expand Down

0 comments on commit a16d6ff

Please sign in to comment.