Skip to content

Commit

Permalink
more pickling (#975)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Aug 16, 2024
1 parent 5517685 commit 815bff0
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 2 deletions.
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ looseversion ==1.3.0
lightning-utilities >=0.7.0
numpy >=1.23.0,<2 # not yet ready for numpy 2
igraph >=0.10.4
optree >=0.11.0
optree >=0.12.1
opt_einsum >= 3.3.0
mpmath <1.4.0 # todo: teporarl pin for `NameError: name '_C' is not defined`
dill >=0.3.8 # Support for 3.12
3 changes: 3 additions & 0 deletions thunder/core/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def device_str(self) -> str:
# note: self.devicetype == DeviceType.CPU, .META
return devicetype_string(self.devicetype)

def __reduce__(self):
return (Device, (self.device_str(),))


cpu = Device(DeviceType.CPU, None)

Expand Down
11 changes: 10 additions & 1 deletion thunder/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,15 @@ def is_weak(self):
def shortname(self):
return f"{self._shortname}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}"

@property
def full_name(self):
return (
f"{self._name}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}{'_' if self._is_weak else ''}"
)

# TODO Fix name printing
def __repr__(self):
return f"thunder.dtypes.{self._name}{8 * self._bytes}{f'_{self._variant}' if self._variant else ''}{'_' if self._is_weak else ''}"
return f"thunder.dtypes.{self.full_name}"

def __str__(self):
return self.__repr__()
Expand All @@ -104,6 +110,9 @@ def __eq__(self, other) -> bool:
and self._variant == other._variant
)

def __reduce__(self):
return self.full_name


class exact(dtype):
"""Abstract base class for the signedinteger, unsignedinteger and bool_ dtypes."""
Expand Down
6 changes: 6 additions & 0 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2802,6 +2802,12 @@ def fn(a, b, l):

assert str(pickle.loads(pickle.dumps(prologue_trace))) == str(prologue_trace)

# check that these are looked up rather than duplicated
device = thunder.devices.Device("cpu")
assert pickle.loads(pickle.dumps(device)) is device
fp32 = thunder.dtypes.float32
assert pickle.loads(pickle.dumps(fp32)) is fp32


@pytest.mark.parametrize("requires_grad", (True, False))
def test_dataclass_output(requires_grad):
Expand Down

0 comments on commit 815bff0

Please sign in to comment.