From 5fc67dcba844554c8a2390ef8775594e61f18737 Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Mon, 1 Jul 2024 19:51:36 +0200 Subject: [PATCH] Fix proxy repr for non-tensor proxy (#692) --- thunder/core/proxies.py | 10 +++++++++- thunder/tests/test_core.py | 20 +++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 89479bc43f..b60f795670 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -134,7 +134,9 @@ def replace_name(self, name: str | None = None): return self.__class__(name=name) def __repr__(self) -> str: - return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self.shape}>' + # All subclasses of Proxy will have `self.name`, so this generic implementation relies on that. + # To have a specific repr for a subclass, override the implementation for that subclass. + return f'<{type(self).__name__}(name="{self.name}")>' def type_string(self) -> str: return "Any" @@ -1198,6 +1200,9 @@ def true_dtype(self): def requires_grad(self): return self._requires_grad + def __repr__(self): + return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self.shape})>' + def type_string(self): return f"FUTURE {self.device} {self.dtype.shortname()}{list(self.shape)}" @@ -1293,6 +1298,9 @@ def replace_name(self, name: str): """Return a copy of this proxy with the given name.""" return tensorproxy(self, name=name, history=self.history) + def __repr__(self): + return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self.shape})>' + def type_string(self): return f"{self.device} {self.dtype.shortname()}{list(self.shape)}" diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index f66318c3e2..fcf1b813fa 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -1244,7 +1244,7 @@ def foo(x, y, z): region_bsyms = trace.bound_symbols[:3] region = Region(producers, consumers, region_bsyms) assert len(region.inputs) == 0 and sorted(str(v) for v in region.outputs) == [ - '' + '' ] @@ -2853,3 +2853,21 @@ def forward(self, x) -> torch.Tensor: gradcheck(jitted, (x,)) with pytest.raises(GradcheckError): gradcheck(model, (x,)) + + +def test_proxy_repr(): + # Verify that we can call `__repr__` on different proxy subclasses. + t = thunder.core.trace.TraceCtx() + with thunder.core.trace.tracectx(t): + p = thunder.core.proxies.NumberProxy("number", 1, python_type=int) + c = thunder.core.proxies.CollectionProxy((1, 2), name="collection") + t = thunder.core.proxies.TensorProxy( + "tensor", + shape=(1,), + dtype=thunder.core.dtypes.float16, + device=thunder.core.devices.Device("cpu"), + requires_grad=True, + ) + assert p.__repr__() == '' + assert t.__repr__() == '' + assert c.__repr__() == ''