From 815bff00f18825b9551e6dcf7b91f1f3c0dfb37b Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Fri, 16 Aug 2024 12:13:03 +0200 Subject: [PATCH] more pickling (#975) --- requirements/base.txt | 2 +- thunder/core/devices.py | 3 +++ thunder/core/dtypes.py | 11 ++++++++++- thunder/tests/test_core.py | 6 ++++++ 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/requirements/base.txt b/requirements/base.txt index 4a194eb0ed..fc9b17e751 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -3,7 +3,7 @@ looseversion ==1.3.0 lightning-utilities >=0.7.0 numpy >=1.23.0,<2 # not yet ready for numpy 2 igraph >=0.10.4 -optree >=0.11.0 +optree >=0.12.1 opt_einsum >= 3.3.0 mpmath <1.4.0 # todo: teporarl pin for `NameError: name '_C' is not defined` dill >=0.3.8 # Support for 3.12 diff --git a/thunder/core/devices.py b/thunder/core/devices.py index 03f4a20d45..dd82a9c6c1 100644 --- a/thunder/core/devices.py +++ b/thunder/core/devices.py @@ -133,6 +133,9 @@ def device_str(self) -> str: # note: self.devicetype == DeviceType.CPU, .META return devicetype_string(self.devicetype) + def __reduce__(self): + return (Device, (self.device_str(),)) + cpu = Device(DeviceType.CPU, None) diff --git a/thunder/core/dtypes.py b/thunder/core/dtypes.py index c021975bb6..36399e8164 100644 --- a/thunder/core/dtypes.py +++ b/thunder/core/dtypes.py @@ -83,9 +83,15 @@ def is_weak(self): def shortname(self): return f"{self._shortname}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}" + @property + def full_name(self): + return ( + f"{self._name}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}{'_' if self._is_weak else ''}" + ) + # TODO Fix name printing def __repr__(self): - return f"thunder.dtypes.{self._name}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}{'_' if self._is_weak else ''}" + return f"thunder.dtypes.{self.full_name}" def __str__(self): return self.__repr__() @@ -104,6 +110,9 @@ def __eq__(self, other) -> bool: and self._variant == other._variant ) + def __reduce__(self): + return self.full_name + class exact(dtype): """Abstract base class for the signedinteger, unsignedinteger and bool_ dtypes.""" diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index b7075a6ef2..2d3dd8d244 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2802,6 +2802,12 @@ def fn(a, b, l): assert str(pickle.loads(pickle.dumps(prologue_trace))) == str(prologue_trace) + # check that these are looked up rather than duplicated + device = thunder.devices.Device("cpu") + assert pickle.loads(pickle.dumps(device)) is device + fp32 = thunder.dtypes.float32 + assert pickle.loads(pickle.dumps(fp32)) is fp32 + @pytest.mark.parametrize("requires_grad", (True, False)) def test_dataclass_output(requires_grad):