From 0d7ce8be918dd18b381f99208a7ced2a930ad4b8 Mon Sep 17 00:00:00 2001 From: k223kim Date: Thu, 25 Apr 2024 10:45:20 +0900 Subject: [PATCH 01/10] feat: initial implementation of type --- thunder/torch/__init__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 1a4ed0b31a..d9ac09fe08 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -187,6 +187,19 @@ def is_cuda(a: TensorLike, /) -> bool: register_method("size", size) +@torchsymbol(torch.Tensor.type, is_method=True) +def type(a: TensorLike, dtype: None | dtypeLike, non_blocking: bool = False, /) -> TensorLike: + utils.check( + not non_blocking, + lambda: f"type(): `non_blocking==True` is currently not supported.", + exception_type=NotImplementedError, + ) + if dtype is None: + return a.dtype + return clang.maybe_convert_to_dtype(a, to_dtype(dtype)) + + +register_method("type", type) # # Data movement and transformation operations # From a161881bea762b31b3daa995956540a6f632167c Mon Sep 17 00:00:00 2001 From: k223kim Date: Thu, 25 Apr 2024 10:52:27 +0900 Subject: [PATCH 02/10] feat: small fix --- thunder/torch/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index d9ac09fe08..fe99359f80 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -188,7 +188,7 @@ def is_cuda(a: TensorLike, /) -> bool: @torchsymbol(torch.Tensor.type, is_method=True) -def type(a: TensorLike, dtype: None | dtypeLike, non_blocking: bool = False, /) -> TensorLike: +def type(a: TensorLike, dtype: None | dtypeLike = None, non_blocking: bool = False, /) -> TensorLike: utils.check( not non_blocking, lambda: f"type(): `non_blocking==True` is currently not supported.", From c427ed23a380abad7079c51bf0618c34c5a28eef Mon Sep 17 00:00:00 2001 From: k223kim Date: Thu, 25 Apr 2024 11:38:45 +0900 Subject: [PATCH 03/10] feat: convert to torch dtype --- thunder/torch/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index fe99359f80..d9879aadb4 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -195,7 +195,7 @@ def type(a: TensorLike, dtype: None | dtypeLike = None, non_blocking: bool = Fal exception_type=NotImplementedError, ) if dtype is None: - return a.dtype + return to_torch_dtype(a.dtype) return clang.maybe_convert_to_dtype(a, to_dtype(dtype)) From 0b59d4b9732d9a27807961b55348a1a5a00518c0 Mon Sep 17 00:00:00 2001 From: k223kim Date: Thu, 25 Apr 2024 16:00:02 +0900 Subject: [PATCH 04/10] feat: updated torch.type implementation --- thunder/torch/__init__.py | 54 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index d9879aadb4..d1f62c2581 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -188,15 +188,63 @@ def is_cuda(a: TensorLike, /) -> bool: @torchsymbol(torch.Tensor.type, is_method=True) -def type(a: TensorLike, dtype: None | dtypeLike = None, non_blocking: bool = False, /) -> TensorLike: +def type(a: TensorLike, dtype: None | str | dtypeLike = None, non_blocking: bool = False, /) -> str | TensorLike: utils.check( not non_blocking, lambda: f"type(): `non_blocking==True` is currently not supported.", exception_type=NotImplementedError, ) + + DTYPE_STR = { + torch.float32: "torch.FloatTensor", + torch.float64: "torch.DoubleTensor", + torch.float16: "torch.HalfTensor", + torch.bfloat16: "torch.BFloat16Tensor", + torch.uint8: "torch.ByteTensor", + torch.int8: "torch.CharTensor", + torch.int16: "torch.ShortTensor", + torch.int32: "torch.IntTensor", + torch.long: "torch.LongTensor", + torch.bool: "torch.BoolTensor", + } if dtype is None: - return to_torch_dtype(a.dtype) - return clang.maybe_convert_to_dtype(a, to_dtype(dtype)) + # returns the type of the input tensor in string + torch_dtype = to_torch_dtype(a.dtype) + torch_dtype = DTYPE_STR.get(torch_dtype) + if a.device.devicetype is devices.DeviceType.CUDA: + t, _dtype = torch_dtype.split(".") + torch_dtype = f"{t}.cuda.{_dtype}" + return torch_dtype + + TORCH_DTYPES = { + "torch.FloatTensor": torch.float32, + "torch.DoubleTensor": torch.float64, + "torch.HalfTensor": torch.float16, + "torch.BFloat16Tensor": torch.bfloat16, + "torch.ByteTensor": torch.uint8, + "torch.CharTensor": torch.int8, + "torch.ShortTensor": torch.int16, + "torch.IntTensor": torch.int32, + "torch.LongTensor": torch.long, + "torch.BoolTensor": torch.bool, + } + if isinstance(dtype, str): + parse_dtype = dtype.split(".") + if len(parse_dtype) == 2: + utils.check(dtype in TORCH_DTYPES, lambda: f"type(): invalid type: {dtype}.", exception_type=ValueError) + dtype = TORCH_DTYPES.get(dtype) + else: + t, device, torch_dtype = dtype.split(".") + utils.check( + f"{t}.{torch_dtype}" in TORCH_DTYPES and device == "cuda", + lambda: f"type(): invalid type: {dtype}.", + exception_type=ValueError, + ) + dtype = TORCH_DTYPES.get(f"{t}.{torch_dtype}") + output = clang.maybe_convert_to_dtype(a, to_dtype(dtype)) + if a.device.devicetype is devices.DeviceType.CUDA: + output = prims.device_put(output, a.device.devicetype) + return output register_method("type", type) From 2430d477ed6bf70e3cfd58aa52dffdb21d89fb95 Mon Sep 17 00:00:00 2001 From: k223kim Date: Thu, 25 Apr 2024 19:06:07 +0900 Subject: [PATCH 05/10] feat: added test case for torch.type() --- thunder/tests/opinfos.py | 58 +++++++++++++++++++++++++++++++++++++++ thunder/torch/__init__.py | 1 + 2 files changed, 59 insertions(+) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 07bb925bae..e9d2648ca8 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -2749,6 +2749,64 @@ def tril_sample_generator(op, device, dtype, requires_grad, **kwargs): # data_movement_ops.append(convert_element_type_opinfo) +def type_sample_generator_tensor(op, device, dtype, requires_grad, **kwargs): + # dtype is not None + # expected to return tensor + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + DTYPE_STR = { + torch.float32: "torch.FloatTensor", + torch.float64: "torch.DoubleTensor", + torch.float16: "torch.HalfTensor", + torch.bfloat16: "torch.BFloat16Tensor", + torch.uint8: "torch.ByteTensor", + torch.int8: "torch.CharTensor", + torch.int16: "torch.ShortTensor", + torch.int32: "torch.IntTensor", + torch.long: "torch.LongTensor", + torch.bool: "torch.BoolTensor", + } + + yield SampleInput(make(4, 4), dtype) + yield SampleInput(make(4, 4), DTYPE_STR[dtype]) + + +type_opinfo_tensor = OpInfo( + ltorch.type, + sample_input_generator=type_sample_generator_tensor, + torch_reference=torch.Tensor.type, +) + +data_movement_ops.append(type_opinfo_tensor) + + +def type_sample_generator_str(op, device, dtype, requires_grad, **kwargs): + # dtype is None + # expected to return string + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + yield SampleInput(make(4, 4)) + + +# comparing strings (NOTE: assert_close does not support string comparison) +def string_compare(actual, expected, **kwargs): + assert actual == expected + + +type_opinfo_str = OpInfo( + ltorch.type, + sample_input_generator=type_sample_generator_str, + torch_reference=torch.Tensor.type, + test_directives=( + DecorateInfo( + custom_comparator(string_compare), + "test_core_vs_torch_consistency", + ), + ), +) + +data_movement_ops.append(type_opinfo_str) + + def to_sample_generator(op, device, dtype, requires_grad, **kwargs): make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index d1f62c2581..2d2f0faa6d 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -207,6 +207,7 @@ def type(a: TensorLike, dtype: None | str | dtypeLike = None, non_blocking: bool torch.long: "torch.LongTensor", torch.bool: "torch.BoolTensor", } + if dtype is None: # returns the type of the input tensor in string torch_dtype = to_torch_dtype(a.dtype) From d87cbe39e8e87c805f71db2d8821be779a232888 Mon Sep 17 00:00:00 2001 From: k223kim Date: Mon, 29 Apr 2024 12:42:15 +0900 Subject: [PATCH 06/10] feat: updated implementation of tensor.type based on suggested comment by mruberry --- thunder/tests/opinfos.py | 56 +++++++++++++----- thunder/torch/__init__.py | 120 +++++++++++++++++++++++--------------- 2 files changed, 114 insertions(+), 62 deletions(-) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index e9d2648ca8..b0dae16917 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -2752,27 +2752,46 @@ def tril_sample_generator(op, device, dtype, requires_grad, **kwargs): def type_sample_generator_tensor(op, device, dtype, requires_grad, **kwargs): # dtype is not None # expected to return tensor - make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - DTYPE_STR = { - torch.float32: "torch.FloatTensor", - torch.float64: "torch.DoubleTensor", - torch.float16: "torch.HalfTensor", - torch.bfloat16: "torch.BFloat16Tensor", - torch.uint8: "torch.ByteTensor", - torch.int8: "torch.CharTensor", - torch.int16: "torch.ShortTensor", - torch.int32: "torch.IntTensor", - torch.long: "torch.LongTensor", - torch.bool: "torch.BoolTensor", + + _torch_dtype_to_old_torch_typestring_map = { + torch.float32: "FloatTensor", + torch.float64: "DoubleTensor", + torch.float16: "HalfTensor", + torch.bfloat16: "BFloat16Tensor", + torch.uint8: "ByteTensor", + torch.int8: "CharTensor", + torch.int16: "ShortTensor", + torch.int32: "IntTensor", + torch.long: "LongTensor", + torch.bool: "BoolTensor", } - yield SampleInput(make(4, 4), dtype) - yield SampleInput(make(4, 4), DTYPE_STR[dtype]) + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + to_dtype = torch.complex128 if dtype.is_complex else torch.float64 + + yield SampleInput(make(4, 4, device=device), dtype) + yield SampleInput(make(4, 4, device=device), dtype=to_dtype) + # below can be deleted if we don't support strings + yield SampleInput(make(4, 4, device=device), f"torch.{_torch_dtype_to_old_torch_typestring_map[dtype]}") + + # Explictly pass device + if torch.device(device).type == "cuda": + yield SampleInput(make(4, 4, device=device), f"torch.cuda.{_torch_dtype_to_old_torch_typestring_map[dtype]}") + + +# kind of redundant? +def type_error_generator_tensor(op, device, dtype=torch.float32, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + err_msg = r"type\(\): `non_blocking==True` is currently not supported." + yield SampleInput(make(3, 3), dtype, non_blocking=True), RuntimeError, err_msg type_opinfo_tensor = OpInfo( ltorch.type, sample_input_generator=type_sample_generator_tensor, + error_input_generator=type_error_generator_tensor, torch_reference=torch.Tensor.type, ) @@ -2787,6 +2806,14 @@ def type_sample_generator_str(op, device, dtype, requires_grad, **kwargs): yield SampleInput(make(4, 4)) +# kind of redundant? +def type_error_generator_str(op, device, dtype=torch.float32, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) + + err_msg = r"type\(\): `non_blocking==True` is currently not supported." + yield SampleInput(make(3, 3), non_blocking=True), RuntimeError, err_msg + + # comparing strings (NOTE: assert_close does not support string comparison) def string_compare(actual, expected, **kwargs): assert actual == expected @@ -2795,6 +2822,7 @@ def string_compare(actual, expected, **kwargs): type_opinfo_str = OpInfo( ltorch.type, sample_input_generator=type_sample_generator_str, + error_input_generator=type_error_generator_str, torch_reference=torch.Tensor.type, test_directives=( DecorateInfo( diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 2d2f0faa6d..063f61407e 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -186,69 +186,93 @@ def is_cuda(a: TensorLike, /) -> bool: register_method("size", size) +_torch_dtype_to_old_torch_typestring_map = { + torch.float32: "FloatTensor", + torch.float64: "DoubleTensor", + torch.float16: "HalfTensor", + torch.bfloat16: "BFloat16Tensor", + torch.uint8: "ByteTensor", + torch.int8: "CharTensor", + torch.int16: "ShortTensor", + torch.int32: "IntTensor", + torch.long: "LongTensor", + torch.bool: "BoolTensor", +} + +_old_torch_typestring_to_torch_dtype_map = {v: k for k, v in _torch_dtype_to_old_torch_typestring_map.items()} + + +def _device_and_dtype_to_old_torch_typestring(device: DeviceLike, dtype: dtypeLike): + torch_dtype = to_torch_dtype(dtype) + dtype_str = _torch_dtype_to_old_torch_typestring_map.get(torch_dtype) + devicetype_str: str = "" + if device.devicetype is not devices.DeviceType.CPU: + devicetype_str = f"{devices.devicetype_string(device.devicetype)}." + return f"torch.{devicetype_str}{dtype_str}" + + +def _old_torch_typestring_to_devicetype_and_dtype(typestring: str) -> tuple[DeviceLike, dtypeLike]: + + # Two cases: + # - torch.DtypeTensor + # - torch.device.DtypeTensor + + _, *dev_and_dtype = typestring.split(".") + + if len(dev_and_dtype) == 1: + (dtype_str,) = dev_and_dtype + return "cpu", _old_torch_typestring_to_torch_dtype_map[dtype_str] + + if len(dev_and_dtype) == 2: + dtype_str, devicetype_str = dev_and_dtype + return dtype_str, _old_torch_typestring_to_torch_dtype_map[devicetype_str] + + # Assertion error -- expected the string to split into one or two elements + utils.check( + False, + lambda: f"type(): dtype format does not match torch.Tensor dtype nor old torch typestring format", + exception_type=ValueError, + ) + @torchsymbol(torch.Tensor.type, is_method=True) -def type(a: TensorLike, dtype: None | str | dtypeLike = None, non_blocking: bool = False, /) -> str | TensorLike: +def type( + a: TensorLike, /, dtype: None | str | dtypeLike = None, non_blocking: bool = False, **kwargs +) -> str | TensorLike: utils.check( not non_blocking, lambda: f"type(): `non_blocking==True` is currently not supported.", exception_type=NotImplementedError, ) - DTYPE_STR = { - torch.float32: "torch.FloatTensor", - torch.float64: "torch.DoubleTensor", - torch.float16: "torch.HalfTensor", - torch.bfloat16: "torch.BFloat16Tensor", - torch.uint8: "torch.ByteTensor", - torch.int8: "torch.CharTensor", - torch.int16: "torch.ShortTensor", - torch.int32: "torch.IntTensor", - torch.long: "torch.LongTensor", - torch.bool: "torch.BoolTensor", - } - if dtype is None: # returns the type of the input tensor in string - torch_dtype = to_torch_dtype(a.dtype) - torch_dtype = DTYPE_STR.get(torch_dtype) - if a.device.devicetype is devices.DeviceType.CUDA: - t, _dtype = torch_dtype.split(".") - torch_dtype = f"{t}.cuda.{_dtype}" - return torch_dtype - - TORCH_DTYPES = { - "torch.FloatTensor": torch.float32, - "torch.DoubleTensor": torch.float64, - "torch.HalfTensor": torch.float16, - "torch.BFloat16Tensor": torch.bfloat16, - "torch.ByteTensor": torch.uint8, - "torch.CharTensor": torch.int8, - "torch.ShortTensor": torch.int16, - "torch.IntTensor": torch.int32, - "torch.LongTensor": torch.long, - "torch.BoolTensor": torch.bool, - } + return _device_and_dtype_to_old_torch_typestring(a.device, a.dtype) + if isinstance(dtype, str): - parse_dtype = dtype.split(".") - if len(parse_dtype) == 2: - utils.check(dtype in TORCH_DTYPES, lambda: f"type(): invalid type: {dtype}.", exception_type=ValueError) - dtype = TORCH_DTYPES.get(dtype) + devtype, dtype = _old_torch_typestring_to_devicetype_and_dtype(dtype) + + # below if-statement handles the following case that you have mentioned: + # Follow-up question: what if the old string says "CUDA" but the tensor is on CUDA device 1, will a tensor + # on CUDA device 0 be created? + + if devtype == a.device.type: + dev = a.device + elif devtype == "cpu": + dev = devices.DeviceType.CPU else: - t, device, torch_dtype = dtype.split(".") - utils.check( - f"{t}.{torch_dtype}" in TORCH_DTYPES and device == "cuda", - lambda: f"type(): invalid type: {dtype}.", - exception_type=ValueError, - ) - dtype = TORCH_DTYPES.get(f"{t}.{torch_dtype}") - output = clang.maybe_convert_to_dtype(a, to_dtype(dtype)) - if a.device.devicetype is devices.DeviceType.CUDA: - output = prims.device_put(output, a.device.devicetype) - return output + dev = devices.DeviceType.CUDA + else: + # Question here -- if a device is not specified and the tensor is a CUDA tensor, will this create a tensor + # on a CPU device, or leave the tensor on a CUDA device? This would change how dev is set here + # PyTorch leaves the tensor on a CUDA device + dev = a.device + + return to(a, dev, dtype) register_method("type", type) + # # Data movement and transformation operations # From a005d03ec345b59d8b18a980aabd6997220e8777 Mon Sep 17 00:00:00 2001 From: k223kim Date: Tue, 30 Apr 2024 08:39:06 +0900 Subject: [PATCH 07/10] feat: updated tensor type based on mruberry's comments --- thunder/tests/opinfos.py | 4 ++-- thunder/torch/__init__.py | 19 +++++++++---------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 96fac413b6..7ea5238944 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -2773,11 +2773,11 @@ def type_sample_generator_tensor(op, device, dtype, requires_grad, **kwargs): yield SampleInput(make(4, 4, device=device), dtype) yield SampleInput(make(4, 4, device=device), dtype=to_dtype) # below can be deleted if we don't support strings - yield SampleInput(make(4, 4, device=device), f"torch.{_torch_dtype_to_old_torch_typestring_map[dtype]}") + yield SampleInput(make(4, 4, device=device), f"torch.{_torch_dtype_to_old_torch_typestring_map[to_dtype]}") # Explictly pass device if torch.device(device).type == "cuda": - yield SampleInput(make(4, 4, device=device), f"torch.cuda.{_torch_dtype_to_old_torch_typestring_map[dtype]}") + yield SampleInput(make(4, 4, device=device), f"torch.cuda.{_torch_dtype_to_old_torch_typestring_map[to_dtype]}") # kind of redundant? diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 0f93208d9d..2ec5e8c438 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -202,7 +202,7 @@ def is_cuda(a: TensorLike, /) -> bool: _old_torch_typestring_to_torch_dtype_map = {v: k for k, v in _torch_dtype_to_old_torch_typestring_map.items()} -def _device_and_dtype_to_old_torch_typestring(device: DeviceLike, dtype: dtypeLike): +def _device_and_dtype_to_old_torch_typestring(device: DeviceLike, dtype: dtypeLike) -> str: torch_dtype = to_torch_dtype(dtype) dtype_str = _torch_dtype_to_old_torch_typestring_map.get(torch_dtype) devicetype_str: str = "" @@ -230,7 +230,7 @@ def _old_torch_typestring_to_devicetype_and_dtype(typestring: str) -> tuple[Devi # Assertion error -- expected the string to split into one or two elements utils.check( False, - lambda: f"type(): dtype format does not match torch.Tensor dtype nor old torch typestring format", + lambda: f"type(): unrecognized torch typestring {typestring}", exception_type=ValueError, ) @@ -252,20 +252,19 @@ def type( if isinstance(dtype, str): devtype, dtype = _old_torch_typestring_to_devicetype_and_dtype(dtype) - # below if-statement handles the following case that you have mentioned: - # Follow-up question: what if the old string says "CUDA" but the tensor is on CUDA device 1, will a tensor - # on CUDA device 0 be created? - if devtype == a.device.type: + # This handles two cases: + # 1. When a tensor is already on a CUDA device, and the device type string is CUDA. In this case the tensor remains on its current device. + # 2. When a tensor is on a CPU device and the device type string is omitted, the tensor remains on the CPU device. dev = a.device elif devtype == "cpu": dev = devices.DeviceType.CPU - else: + elif devtype == "cuda": dev = devices.DeviceType.CUDA + else: + raise ValueError(f"type(): unrecognized torch typestring {dtype}") else: - # Question here -- if a device is not specified and the tensor is a CUDA tensor, will this create a tensor - # on a CPU device, or leave the tensor on a CUDA device? This would change how dev is set here - # PyTorch leaves the tensor on a CUDA device + # dtype is assumed to be torch.dtype (e.g. torch.int32) dev = a.device return to(a, dev, dtype) From 09950b39b807b77f7c3af75975f2e019c89a8560 Mon Sep 17 00:00:00 2001 From: k223kim Date: Tue, 30 Apr 2024 09:06:36 +0900 Subject: [PATCH 08/10] feat: added device_from_string --- thunder/torch/__init__.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 2ec5e8c438..8473e3b62f 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -17,7 +17,7 @@ import thunder.clang as clang import thunder.core.devices as devices -from thunder.core.devices import to_device +from thunder.core.devices import to_device, device_from_string import thunder.core.dtypes as dtypes from thunder.core.dtypes import to_torch_dtype, to_dtype, _thunder_to_torch_dtype_map, _torch_to_thunder_dtype_map import thunder.core.prims as prims @@ -257,10 +257,8 @@ def type( # 1. When a tensor is already on a CUDA device, and the device type string is CUDA. In this case the tensor remains on its current device. # 2. When a tensor is on a CPU device and the device type string is omitted, the tensor remains on the CPU device. dev = a.device - elif devtype == "cpu": - dev = devices.DeviceType.CPU - elif devtype == "cuda": - dev = devices.DeviceType.CUDA + elif devtype == "cpu" or devtype == "cuda": + dev = device_from_string(devtype) else: raise ValueError(f"type(): unrecognized torch typestring {dtype}") else: From e43c28fcbe37f08c5ab11d7aade2e4b89c10dc66 Mon Sep 17 00:00:00 2001 From: k223kim Date: Thu, 2 May 2024 02:26:49 +0900 Subject: [PATCH 09/10] feat: updated tensor type based on mruberry's comments --- thunder/torch/__init__.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 8473e3b62f..f8b80b3ec3 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -218,27 +218,32 @@ def _old_torch_typestring_to_devicetype_and_dtype(typestring: str) -> tuple[Devi # - torch.device.DtypeTensor _, *dev_and_dtype = typestring.split(".") + devicetype_str = "cpu" + dtype_str = "" if len(dev_and_dtype) == 1: + # when devicetype_str is omitted, device type is CPU (dtype_str,) = dev_and_dtype - return "cpu", _old_torch_typestring_to_torch_dtype_map[dtype_str] + dtype_str = _old_torch_typestring_to_torch_dtype_map[dtype_str] if len(dev_and_dtype) == 2: - dtype_str, devicetype_str = dev_and_dtype - return dtype_str, _old_torch_typestring_to_torch_dtype_map[devicetype_str] + devicetype_str, dtype_str = dev_and_dtype + dtype_str = _old_torch_typestring_to_torch_dtype_map[dtype_str] - # Assertion error -- expected the string to split into one or two elements + # Value error + # expected the string to split into one or two elements + # and devicetype_str should be either "cpu" or "cuda" utils.check( - False, + devicetype_str in ("cpu", "cuda") and 1 <= len(dev_and_dtype) <= 2, lambda: f"type(): unrecognized torch typestring {typestring}", exception_type=ValueError, ) + return devicetype_str, dtype_str + @torchsymbol(torch.Tensor.type, is_method=True) -def type( - a: TensorLike, /, dtype: None | str | dtypeLike = None, non_blocking: bool = False, **kwargs -) -> str | TensorLike: +def type(a: TensorLike, /, dtype: None | str | dtypeLike = None, non_blocking: bool = False) -> str | TensorLike: utils.check( not non_blocking, lambda: f"type(): `non_blocking==True` is currently not supported.", @@ -257,10 +262,8 @@ def type( # 1. When a tensor is already on a CUDA device, and the device type string is CUDA. In this case the tensor remains on its current device. # 2. When a tensor is on a CPU device and the device type string is omitted, the tensor remains on the CPU device. dev = a.device - elif devtype == "cpu" or devtype == "cuda": - dev = device_from_string(devtype) else: - raise ValueError(f"type(): unrecognized torch typestring {dtype}") + dev = device_from_string(devtype) else: # dtype is assumed to be torch.dtype (e.g. torch.int32) dev = a.device From 89dd7926d75b9d9fbc4b5dac89b82a9fc2fec857 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 May 2024 08:01:29 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/torch/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index aa67d1a685..182e5adc41 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -193,6 +193,7 @@ def numel(a: TensorLike, /) -> int: def is_cuda(a: TensorLike, /) -> bool: return a.device.devicetype is devices.DeviceType.CUDA + _torch_dtype_to_old_torch_typestring_map = { torch.float32: "FloatTensor", torch.float64: "DoubleTensor",