Detailed __repr__ #638

Merged — 61 commits from k223kim/detailed_repr into main, Jul 1, 2024
Commits
4c29111
feat: init
k223kim Jun 24, 2024
768f5a3
feat: removed unused code
k223kim Jun 24, 2024
7a4a251
feat: updated device repr
k223kim Jun 25, 2024
ce58d46
feat: type annotation fix and added repr for dtype
k223kim Jun 26, 2024
158fb7e
Merge remote-tracking branch 'upstream/main' into k223kim/detailed_repr
k223kim Jun 26, 2024
23da5d5
feat: fix test_core but not done yet
k223kim Jun 26, 2024
a99cc45
feat: updated type annotation
k223kim Jun 26, 2024
6270cf6
Merge branch 'main' into k223kim/detailed_repr
k223kim Jun 26, 2024
ffa8021
feat: updated prettyprint
k223kim Jun 26, 2024
2aafc4a
feat: removed comments
k223kim Jun 26, 2024
f51b783
feat: fixed test cases
k223kim Jun 26, 2024
77cd131
feat: removed/added comments
k223kim Jun 26, 2024
4f02925
feat: update
k223kim Jun 26, 2024
4cc4e30
feat: updated repr for executors
k223kim Jun 27, 2024
df06f57
feat: updated tensorproxy repr and prettyprint
k223kim Jun 27, 2024
b610070
feat: updated prettyprint
k223kim Jun 27, 2024
dc1a2e6
feat: first draft of torch div
k223kim Jun 27, 2024
27cccf9
feat: updated make_tensor and corresponding test cases
k223kim Jun 28, 2024
55fb16f
feat: fix
k223kim Jun 28, 2024
ee27987
feat: fix
k223kim Jun 28, 2024
7dcac06
feat: fix
k223kim Jun 28, 2024
0e94d74
feat: fix
k223kim Jun 28, 2024
9406381
feat: fix
k223kim Jun 28, 2024
331aff6
feat: fix
k223kim Jun 28, 2024
35dfca6
feat: fix
k223kim Jun 28, 2024
40ef76f
feat: fix
k223kim Jun 28, 2024
fba0714
feat: fix
k223kim Jun 28, 2024
9216dfb
feat: fix
k223kim Jun 28, 2024
f19fa28
feat: fix
k223kim Jun 28, 2024
aff9d1e
feat: fix
k223kim Jun 28, 2024
3b40fce
feat: fix
k223kim Jun 28, 2024
71a834c
feat: local working code
k223kim Jun 28, 2024
94c2f20
feat: fix
k223kim Jun 28, 2024
b356a8f
feat: fix error generator
k223kim Jun 28, 2024
04f4064
feat: fix
k223kim Jun 28, 2024
c3390ea
feat: fix to_torch_device
k223kim Jun 28, 2024
5b2532e
feat: fix
k223kim Jun 28, 2024
a85a6f9
Merge remote-tracking branch 'upstream/main' into k223kim/detailed_repr
k223kim Jun 28, 2024
9a78b69
feat: type -> device_str
k223kim Jun 30, 2024
5926d9b
feat: minimize change
k223kim Jul 1, 2024
0576b58
feat: minimize change
k223kim Jul 1, 2024
e5d5f6f
feat: fix
k223kim Jul 1, 2024
1296471
feat: fix
k223kim Jul 1, 2024
112a929
feat: fix
k223kim Jul 1, 2024
084ce1d
feat: fix
k223kim Jul 1, 2024
e08c46e
feat: fix
k223kim Jul 1, 2024
d8441ab
feat: fix
k223kim Jul 1, 2024
83661a8
feat: fix
k223kim Jul 1, 2024
4549a5f
Merge branch 'main' into k223kim/detailed_repr
k223kim Jul 1, 2024
6cc7c36
feat: fix
k223kim Jul 1, 2024
c05d103
feat: fix
k223kim Jul 1, 2024
27a52be
feat: ddp fix
k223kim Jul 1, 2024
54bedfe
feat: fix
k223kim Jul 1, 2024
8cb187a
Merge branch 'main' into k223kim/detailed_repr
k223kim Jul 1, 2024
bf51c96
feat: fix
k223kim Jul 1, 2024
423cc0f
feat: fix transforms
k223kim Jul 1, 2024
58cb421
Merge branch 'main' into k223kim/detailed_repr
k223kim Jul 1, 2024
fea8e5c
feat: fix transforms
k223kim Jul 1, 2024
01ee0f9
feat: fix distributed
k223kim Jul 1, 2024
4e15ce4
Update thunder/distributed/tensor_parallel/common.py
t-vi Jul 1, 2024
4a4892e
feat: fix
k223kim Jul 1, 2024
6 changes: 5 additions & 1 deletion thunder/clang/__init__.py
@@ -69,7 +69,11 @@ def __call__(self, fn: Callable) -> Callable:
@clangop()
def check_tensor_shape_and_metadata(t: TensorProxy, /) -> None:
return prims.check_tensor_shape_and_metadata(
t, tuple(t.shape), str(t.device), dtypes.to_torch_dtype(t.dtype), t.requires_grad
t,
tuple(t.shape),
t.device.device_str(),
dtypes.to_torch_dtype(t.dtype),
t.requires_grad,
)
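
As an illustration (editor's sketch, not part of the diff), the prologue check emitted for a hypothetical float32 CPU TensorProxy `t` of shape (2, 3) would now receive the short device string from device_str() rather than str(t.device):

# Sketch only -- the exact bound symbol depends on the trace, but the device
# argument is now the "cpu" / "cuda:0" style string returned by device_str().
prims.check_tensor_shape_and_metadata(t, (2, 3), "cpu", torch.float32, False)
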


16 changes: 9 additions & 7 deletions thunder/core/codeutils.py
@@ -229,7 +229,7 @@ def prettyprint(
if isinstance(x, dtypes.dtype):
return m(f"dtypes.{str(x)}")
if isinstance(x, devices.Device):
return m(f'devices.Device("{str(x)}")')
return m(f'devices.Device("{x.device_str()}")')
if type(x) is type:
return m(f"{baseutils.print_type(x, with_quotes=False)}")
if dataclasses.is_dataclass(x):
@@ -243,12 +243,14 @@
# NOTE: The `class` packagename1_MyContainer will present in `import_ctx` and passed to the compiled function.
# This is taken care of by function `to_printable`.
name = _generate_dataclass_class_name(x)
instance_repr = str(x)
parens_idx = instance_repr.find("(")
call_repr = instance_repr[
parens_idx:
] # only keep the construction part of the repr (as we will use our generated name)
return m(f"{name + call_repr}")
call_repr = []
for k, v in x.__dict__.items():
try:
call_repr.append(f"{k}={v.name}")
except:
call_repr.append(f"{k}={v}")
call_repr_str = ",".join(call_repr)
return m(f"{name}({call_repr_str})")

# Handles objects that this doesn't know how to serialize as a string
return m(f"(object of type {print_type(type(x), with_quotes=False)})")
15 changes: 12 additions & 3 deletions thunder/core/devices.py
@@ -118,14 +118,21 @@ def __hash__(self) -> int:
# converting Thunder devices to PyTorch devices
def __repr__(self) -> str:
if self.devicetype == DeviceType.CUDA:
return f"{devicetype_string(self.devicetype)}:{self.index}"
return f"thunder.devices.Device(type='{devicetype_string(self.devicetype)}:{self.index}')"
# note: self.devicetype == DeviceType.CPU, .META
return devicetype_string(self.devicetype)
return f"thunder.devices.Device(type='{devicetype_string(self.devicetype)}')"

# NOTE Because devices are singleton object, this has the luxury of using "is"
def __eq__(self, other: Device) -> bool:
return self is other

# NOTE this is needed when passing devices.Device to torch operators such as torch.testing.make_tensor
def device_str(self) -> str:
if self.devicetype == DeviceType.CUDA:
return f"{devicetype_string(self.devicetype)}:{self.index}"
# note: self.devicetype == DeviceType.CPU, .META
return devicetype_string(self.devicetype)


cpu = Device(DeviceType.CPU, None)

@@ -185,4 +192,6 @@ def to_torch_device(x: None | str | torch.device | Device, /) -> None | torch.device:
return x

baseutils.check_type(x, (Device, str))
return torch.device(str(x))
if isinstance(x, Device):
return torch.device(x.device_str())
return torch.device(x)
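
Taken together, a minimal sketch of the new Device behavior (editor's illustration, not part of the diff):

import thunder.core.devices as devices

repr(devices.cpu)           # "thunder.devices.Device(type='cpu')"  -- detailed, for debugging
devices.cpu.device_str()    # "cpu"                                 -- short form, for torch interop
# For a CUDA device with index 0 (assuming it is constructed as elsewhere in this module):
#   repr(d) -> "thunder.devices.Device(type='cuda:0')",  d.device_str() -> "cuda:0"
devices.to_torch_device(devices.cpu)  # torch.device("cpu"), via device_str()
devices.to_torch_device("cuda:0")     # plain strings still go straight to torch.device
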
4 changes: 1 addition & 3 deletions thunder/core/dtypes.py
@@ -85,9 +85,7 @@ def shortname(self):

# TODO Fix name printing
def __repr__(self):
return (
f"{self._name}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}{'_' if self._is_weak else ''}"
)
return f"thunder.dtypes.{self._name}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}{'_' if self._is_weak else ''}"

def __str__(self):
return self.__repr__()
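
Sketch of the effect (editor's illustration, not part of the diff): the dtype repr now carries the full module prefix, which is why the expected error strings in thunder/tests/opinfos.py below change as well.

import thunder

repr(thunder.dtypes.float32)  # "thunder.dtypes.float32"  (previously "float32")
str(thunder.dtypes.int64)     # "thunder.dtypes.int64" -- __str__ delegates to __repr__
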
4 changes: 2 additions & 2 deletions thunder/core/prims.py
@@ -330,14 +330,14 @@ def assert_tensor_metadata_impl(
if (
type(t) in (torch.Tensor, torch.nn.Parameter)
and tuple(t.shape) == shape
and str(t.device) == str(device)
and str(t.device) == device.device_str()
and t.dtype == dtype
and t.requires_grad == requires_grad
):
return

raise AssertionError(
f"Object had unexpected metadata. Expected type Tensor/nn.Parameter (without subclass), shape {shape}, device {str(device)}, dtype {dtype}, and {requires_grad=}, but found type {type(t)}, shape {tuple(t.shape)}, device {str(t.device)}, and requires_grad {t.requires_grad}"
f"Object had unexpected metadata. Expected type Tensor/nn.Parameter (without subclass), shape {shape}, device {str(device.device_str())}, dtype {dtype}, and {requires_grad=}, but found type {type(t)}, shape {tuple(t.shape)}, device {str(t.device)}, and requires_grad {t.requires_grad}"
)


6 changes: 3 additions & 3 deletions thunder/core/proxies.py
@@ -134,7 +134,7 @@ def replace_name(self, name: str | None = None):
return self.__class__(name=name)

def __repr__(self) -> str:
return f"{self.name}"
return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self.shape}>'

def type_string(self) -> str:
return "Any"
@@ -1610,7 +1610,7 @@ def real(self):


def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = None) -> TensorProxy:
device = devices.device_from_string(str(t.device))
device = devices.to_device(t.device)
dtype = dtypes.to_dtype(t.dtype)
# See Note [DistributedDataParallel and distparallel_type]
distparallel_type = getattr(t, "distparallel_type", None)
@@ -1631,7 +1631,7 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = None) -> TensorProxy:
def futuretensorproxy(
t: torch.Tensor | TensorProxy | FutureTensorProxy, /, *, name: None | str, history: None | tuple = None
) -> FutureTensorProxy:
device = devices.device_from_string(str(t.device))
device = devices.to_device(t.device)
dtype = dtypes.to_dtype(t.dtype)
# NOTE Without tuple(t.shape) then the shape would be a torch.Size object
return FutureTensorProxy(
4 changes: 2 additions & 2 deletions thunder/distributed/tensor_parallel/common.py
@@ -238,7 +238,7 @@ def transform_traces(
if c.sym is prims.check_tensor_shape_and_metadata:
# TODO have a more principled way to update this?
a0, _, _, *a2pp = c.args
c.args = (a0, tuple(new_shape), str(a0.device), *a2pp)
c.args = (a0, tuple(new_shape), a0.device.device_str(), *a2pp)

for bsym in prologue_trace.bound_symbols:
if bsym.sym is prims.check_tensor_shape_and_metadata and prologue_producers[bsym.args[0]].sym in (
@@ -249,7 +249,7 @@
assert param_thunder_module is thunder_module_proxy
if name not in self.chunked_param_name_to_layer_type:
a0, shape, _, *a2pp = bsym.args
bsym.args = (a0, shape, str(a0.device), *a2pp)
bsym.args = (a0, shape, a0.device.device_str(), *a2pp)

if len(modules_and_thunder_modules) != 1:
raise NotImplementedError("cannot deal with modules other than the compiled module")
2 changes: 1 addition & 1 deletion thunder/distributed/transforms/fsdp_v2.py
@@ -61,7 +61,7 @@ def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, **
param_name_to_comp_trc_proxy[param_name] = comp_inp_p
old_shape, new_shape, new_torch_device = self.sharded_params[param_name]
thunder_device = devices.to_device(new_torch_device)
thunder_device_str = str(thunder_device)
thunder_device_str = thunder_device.device_str()

pro_out_p._distparallel_type = DistParallelType.FULLY_SHARDED
pro_out_p._shape = tuple(new_shape)
2 changes: 1 addition & 1 deletion thunder/executors/torchex.py
@@ -132,7 +132,7 @@ def _to_transform(


def _device_put_transform(a: TensorProxy, device: devices.Device) -> TensorProxy:
torch_device: str = str(device)
torch_device: str = device.device_str()
return to(a, torch_device)


2 changes: 1 addition & 1 deletion thunder/extend/__init__.py
@@ -74,7 +74,7 @@ def implmap(self) -> dict[Hashable, ImplInfo]:
return self._implmap

def __repr__(self) -> str:
return str(self.name)
return f"thunder.extend.OperatorExecutor('{str(self.name)}')"

def __hash__(self) -> int:
return hash(self.name)
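
Sketch of the change in effect, assuming an executor registered under the name "torch" (editor's illustration, not part of the diff):

# Before: repr(executor) -> "torch"
# After:  repr(executor) -> "thunder.extend.OperatorExecutor('torch')"
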
2 changes: 1 addition & 1 deletion thunder/tests/distributed/test_ddp.py
@@ -1110,7 +1110,6 @@ def create_per_process_dataloader(
sampler = tudata.SequentialSampler(dataset)

collate_fn = None

if devicetype is not devices.DeviceType.CPU:
assert devicetype is devices.DeviceType.CUDA, f"Unknown devicetype {devicetype}"
device = torch.device("cuda", rank)
@@ -1214,6 +1213,7 @@ def _test_native_ddp_helper(input_data):
torch_dtype = ltorch.to_torch_dtype(dtype)

pg = init_per_process_distributed(init_method, devicetype, world_size, rank)

tdist.barrier(pg)

dataloader = create_per_process_dataloader(
5 changes: 3 additions & 2 deletions thunder/tests/framework.py
@@ -276,14 +276,15 @@ def _instantiate_executor_test_template(
) -> Callable:
devicetype: devices.DeviceType
device_str: str | list[str]
devicetype = device_or_devices
if isinstance(device_or_devices, devices.Device):
devicetype = device_or_devices.devicetype
device_str = str(device_or_devices)
device_str = device_or_devices.device_str()
else:
devicetype = device_or_devices[0].devicetype
device_str = []
for device in device_or_devices:
device_str.append(str(device))
device_str.append(device.device_str())

devicetype_str = devices.devicetype_string(devicetype)
template_name = as_name if as_name is not None else template.__name__
8 changes: 6 additions & 2 deletions thunder/tests/make_tensor.py
@@ -3,6 +3,7 @@
from typing import cast, List, Optional, Tuple, Union

import torch
import thunder

# adapted from https://github.com/pytorch/pytorch/blob/master/torch/testing/_creation.py
# Changes:
@@ -32,7 +33,7 @@ def _uniform_random(t: torch.Tensor, low: float, high: float):
def make_tensor(
*shape: int | torch.Size | list[int] | tuple[int, ...],
dtype: torch.dtype,
device: str | torch.device,
device: str | torch.device | thunder.devices.Device,
low: float | None = None,
high: float | None = None,
requires_grad: bool = False,
@@ -62,7 +63,7 @@
Args:
shape (Tuple[int, ...]): Single integer or a sequence of integers defining the shape of the output tensor.
dtype (:class:`torch.dtype`): The data type of the returned tensor.
device (Union[str, torch.device]): The device of the returned tensor.
device (Union[str, torch.device, thunder.devices.Device]): The device of the returned tensor.
low (Optional[Number]): Sets the lower limit (inclusive) of the given range. If a number is provided it is
clamped to the least representable finite value of the given dtype. When ``None`` (default),
this value is determined based on the :attr:`dtype` (see the table above). Default: ``None``.
@@ -112,6 +113,9 @@ def clamp(a, l, h):

return low, high

if isinstance(device, thunder.devices.Device):
device = device.device_str()

if len(shape) == 1 and isinstance(shape[0], collections.abc.Sequence):
shape = shape[0] # type: ignore[assignment]
shape = cast(tuple[int, ...], tuple(shape))
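
A minimal usage sketch under the widened signature (editor's illustration, not part of the diff; assumes thunder.devices re-exports the core devices module, as used elsewhere in this diff):

import torch
import thunder
from thunder.tests.make_tensor import make_tensor

# Thunder devices are now accepted directly and converted via device_str();
# strings and torch.device values behave exactly as before.
t = make_tensor(2, 3, dtype=torch.float32, device=thunder.devices.cpu)
u = make_tensor((2, 3), dtype=torch.float32, device="cpu")
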
24 changes: 10 additions & 14 deletions thunder/tests/opinfos.py
@@ -304,7 +304,7 @@ def is_active(
# Acquires devicetype
devicetype_: devices.DeviceType
if isinstance(device_or_devicetype, str):
devicetype_ = devices.device_from_string(device_or_devicetype).devicetype
devicetype_ = devices.to_device(device_or_devicetype).devicetype
elif isinstance(device_or_devicetype, devices.Device):
devicetype_ = device_or_devicetype.devicetype
else:
@@ -392,26 +392,22 @@ def sample_inputs(
self, device: str | devices.Device, dtype: datatypes.dtype, *, requires_grad: bool = False, **kwargs
) -> Generator:
torch_dtype = to_torch_dtype(dtype)
torch_device = str(device)
return self.sample_input_generator(self, torch_device, torch_dtype, requires_grad, **kwargs)
return self.sample_input_generator(self, device, torch_dtype, requires_grad, **kwargs)

def reference_inputs(
self, device: str | devices.Device, dtype: datatypes.dtype, *, requires_grad: bool = False, **kwargs
) -> Generator:
torch_dtype = to_torch_dtype(dtype)
torch_device = str(device)
return self.reference_input_generator(self, torch_device, torch_dtype, requires_grad, **kwargs)
return self.reference_input_generator(self, device, torch_dtype, requires_grad, **kwargs)

def error_inputs(self, device: devices.Device, **kwargs):
torch_device = str(device)
return self.error_input_generator(self, torch_device, **kwargs)
return self.error_input_generator(self, device, **kwargs)

# NOTE Today all benchmarks are generated with PyTorch, so Thunder objects,
# like dtypes, need to be translated into PyTorch objects
def benchmarks(self, device: devices.Device, dtype: datatypes.dtype, *, requires_grad: bool = False, **kwargs):
torch_dtype = to_torch_dtype(dtype)
torch_device = str(device)
return self.benchmark_generator(self, torch_device, dtype, requires_grad, **kwargs)
return self.benchmark_generator(self, device, dtype, requires_grad, **kwargs)

def devicetypes(self):
return set(self._devicetypes)
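
Because make_tensor (patched above) now accepts thunder Devices directly, the generators can forward `device` unchanged instead of converting with str(device). A hypothetical generator written against that assumption:

def _example_sample_generator(op, device, dtype, requires_grad, **kwargs):
    # `device` may be a plain string or a thunder.devices.Device; make_tensor
    # handles both, so no str(device) conversion is needed here.
    yield SampleInput(make_tensor((4, 4), device=device, dtype=dtype, requires_grad=requires_grad))
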
@@ -5565,7 +5561,7 @@ def full_sample_generator(op, device, dtype, requires_grad, **kwargs):


def full_error_generator(op, device, **kwargs):
err_msg = "Can't safely cast fill_value of numbertype <class 'complex'> to dtype float32"
err_msg = "Can't safely cast fill_value of numbertype <class 'complex'> to dtype thunder.dtypes.float32"
yield (SampleInput((1, 2), 1j, device=device, dtype=torch.float), RuntimeError, err_msg)


@@ -5744,7 +5740,7 @@ def bernoulli_sample_generator(op, device, dtype, requires_grad, **kwargs):


def bernoulli_error_generator(op, device, **kwargs):
err_msg = "bernoulli only supports floating point dtypes, got int64"
err_msg = "bernoulli only supports floating point dtypes, got thunder.dtypes.int64"
yield (SampleInput(torch.ones(3, 3, device=device, dtype=torch.long)), RuntimeError, err_msg)

err_msg = "generator is not None which is currently unsupported"
@@ -5903,13 +5899,13 @@ def tensor_constructor_error_generator(op, device, **kwargs):
err_msg = "Expected sequences of numbers, but found type <class 'list'>"
yield (SampleInput([[1], [[6, 2]]]), ValueError, err_msg)

err_msg = "Can't safely cast sequence with numbertype <class 'float'> to dtype int32"
err_msg = "Can't safely cast sequence with numbertype <class 'float'> to dtype thunder.dtypes.int32"
yield (SampleInput([[1, 2.0], [6, 2]], dtype=torch.int32), RuntimeError, err_msg)

err_msg = "Can't safely cast sequence with numbertype <class 'complex'> to dtype int32"
err_msg = "Can't safely cast sequence with numbertype <class 'complex'> to dtype thunder.dtypes.int32"
yield (SampleInput([[1, 2j], [6, 2]], dtype=torch.int32), RuntimeError, err_msg)

err_msg = "Can't safely cast sequence with numbertype <class 'complex'> to dtype float64"
err_msg = "Can't safely cast sequence with numbertype <class 'complex'> to dtype thunder.dtypes.float64"
yield (SampleInput([[1, 2j], [6, 2]], dtype=torch.float64), RuntimeError, err_msg)


4 changes: 3 additions & 1 deletion thunder/tests/test_core.py
@@ -1243,7 +1243,9 @@ def foo(x, y, z):
consumers = thunder.core.utils.consumers(trace)
region_bsyms = trace.bound_symbols[:3]
region = Region(producers, consumers, region_bsyms)
assert len(region.inputs) == 0 and sorted(str(v) for v in region.outputs) == ["t0"]
assert len(region.inputs) == 0 and sorted(str(v) for v in region.outputs) == [
'<TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(1,)>'
]


# This test ensures that calls to torch functions are recorded in the trace
4 changes: 2 additions & 2 deletions thunder/torch/__init__.py
@@ -311,7 +311,7 @@ def type(a: TensorLike, /, dtype: None | str | dtypeLike = None, non_blocking: b
# 2. When a tensor is on a CPU device and the device type string is omitted, the tensor remains on the CPU device.
dev = a.device
else:
dev = device_from_string(devtype)
dev = to_device(devtype)
else:
# dtype is assumed to be torch.dtype (e.g. torch.int32)
dev = a.device
@@ -458,7 +458,7 @@ def cuda(
device = to_device(device)
utils.check(
device.devicetype == devices.DeviceType.CUDA,
lambda: f"cuda(): Invalid device {device}, must be cuda device",
lambda: f"cuda(): Invalid device {device.device_str()}, must be cuda device",
)

return to(a, device=device, memory_format=memory_format)
2 changes: 1 addition & 1 deletion thunder/transforms/quantization.py
@@ -116,7 +116,7 @@ def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, **
# check has args: tensor, shape, device, dtype, requires_grad
proxy, _, _, _, requires_grad = check.args
thunder_device = thunder.devices.to_device(param.device)
thunder_device_str = str(thunder_device)
thunder_device_str = thunder_device.device_str()
check.args = (proxy, (*param.shape,), thunder_device_str, param.dtype, False)

output_idx = output_idxes.get(id(get_param.output))