Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make lh5 types and exceptions picklable #129

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/lgdo/lh5/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def __str__(self) -> str:
+ super().__str__()
)

def __reduce__(self) -> tuple: # for pickling.
return self.__class__, (*self.args, self.file, self.obj)


class LH5EncodeError(Exception):
def __init__(
Expand All @@ -32,3 +35,6 @@ def __str__(self) -> str:
f"while writing object {self.group}/{self.name} to file {self.file}: "
+ super().__str__()
)

def __reduce__(self) -> tuple: # for pickling.
return self.__class__, (*self.args, self.file, self.group, self.name)
11 changes: 8 additions & 3 deletions src/lgdo/types/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,13 +418,18 @@ def fill(self, data, w: NDArray = None, keys: Sequence[str] = None) -> None:

def __setitem__(self, name: str, obj: LGDO) -> None:
# do not allow for new attributes on this
msg = "histogram fields cannot be mutated"
raise TypeError(msg)
known_keys = ("binning", "weights", "isdensity")
if name in known_keys and not dict.__contains__(self, name):
# but allow initialization while unpickling (after __init__() this is unreachable)
dict.__setitem__(self, name, obj)
else:
msg = "histogram fields cannot be mutated "
raise TypeError(msg)

def __getattr__(self, name: str) -> None:
# do not allow for new attributes on this
msg = "histogram fields cannot be mutated"
raise TypeError(msg)
raise AttributeError(msg)

def add_field(self, name: str | int, obj: LGDO) -> None: # noqa: ARG002
"""
Expand Down
6 changes: 6 additions & 0 deletions src/lgdo/types/lgdo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
class LGDO(ABC):
"""Abstract base class representing a LEGEND Data Object (LGDO)."""

def __new__(cls, *_args, **_kwargs):
# allow for (un-)pickling LGDO objects.
obj = super().__new__(cls)
obj.attrs = {}
return obj

@abstractmethod
def __init__(self, attrs: dict[str, Any] | None = None) -> None:
self.attrs = {} if attrs is None else dict(attrs)
Expand Down
6 changes: 6 additions & 0 deletions src/lgdo/types/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ class Table(Struct):
:meth:`__len__` to access valid data, which returns the ``size`` attribute.
"""

def __new__(cls, *args, **kwargs):
# allow for (un-)pickling LGDO objects.
obj = super().__new__(cls, *args, **kwargs)
obj.size = None
return obj

def __init__(
self,
col_dict: Mapping[str, LGDO] | pd.DataFrame | ak.Array | None = None,
Expand Down
17 changes: 17 additions & 0 deletions tests/lh5/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

import pickle

from lgdo.lh5.exceptions import LH5DecodeError, LH5EncodeError


def test_pickle():
# test (un-)pickling of LH5 exceptions; e.g. for multiprocessing use.

ex = LH5EncodeError("message", "file", "group", "name")
ex = pickle.loads(pickle.dumps(ex))
assert isinstance(ex, LH5EncodeError)

ex = LH5DecodeError("message", "file", "obj")
ex = pickle.loads(pickle.dumps(ex))
assert isinstance(ex, LH5DecodeError)
13 changes: 13 additions & 0 deletions tests/types/test_array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pickle

import awkward as ak
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -61,3 +63,14 @@ def test_view():

with pytest.raises(ValueError):
a.view_as("ak", with_units=True)


def test_pickle():
obj = Array(nda=np.array([1, 2, 3, 4]))
obj.attrs["attr1"] = 1

ex = pickle.loads(pickle.dumps(obj))
assert isinstance(ex, Array)
assert ex.attrs["attr1"] == 1
assert ex.attrs["datatype"] == obj.attrs["datatype"]
assert np.all(ex.nda == np.array([1, 2, 3, 4]))
50 changes: 50 additions & 0 deletions tests/types/test_encoded.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pickle

import awkward as ak
import awkward_pandas as akpd
import numpy as np
Expand Down Expand Up @@ -285,3 +287,51 @@ def test_aoeesa_view_as():

with pytest.raises(TypeError):
df = voev.view_as("np")


def test_aoeesa_pickle():
obj = ArrayOfEncodedEqualSizedArrays(
encoded_data=VectorOfVectors(
flattened_data=Array(nda=np.array([1, 2, 3, 4, 5, 2, 4, 8, 9, 7, 5, 3, 1])),
cumulative_length=Array(nda=np.array([2, 5, 6, 10, 13])),
),
decoded_size=99,
)

ex = pickle.loads(pickle.dumps(obj))

desired = [
[1, 2],
[3, 4, 5],
[2],
[4, 8, 9, 7],
[5, 3, 1],
]

for i, v in enumerate(ex):
assert np.array_equal(v, desired[i])


def test_voev_pickle():
obj = VectorOfEncodedVectors(
encoded_data=VectorOfVectors(
flattened_data=Array(nda=np.array([1, 2, 3, 4, 5, 2, 4, 8, 9, 7, 5, 3, 1])),
cumulative_length=Array(nda=np.array([2, 5, 6, 10, 13])),
),
decoded_size=Array(shape=5, fill_val=6),
attrs={"units": "s"},
)

ex = pickle.loads(pickle.dumps(obj))

desired = [
[1, 2],
[3, 4, 5],
[2],
[4, 8, 9, 7],
[5, 3, 1],
]

for i, (v, s) in enumerate(ex):
assert np.array_equal(v, desired[i])
assert s == 6
14 changes: 13 additions & 1 deletion tests/types/test_histogram.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import pickle

import hist
import numpy as np
Expand Down Expand Up @@ -266,7 +267,7 @@ def test_view_as_np():
def test_not_like_table():
h = Histogram(np.array([1, 1]), (np.array([0, 1, 2]),))
assert h.form_datatype() == "struct{binning,weights,isdensity}"
with pytest.raises(TypeError):
with pytest.raises(AttributeError):
x = h.x # noqa: F841
with pytest.raises(TypeError):
h["x"] = Scalar(1.0)
Expand Down Expand Up @@ -392,3 +393,14 @@ def test_histogram_fill():

with pytest.raises(ValueError, match="data must be"):
h.fill(np.ones(shape=(5, 5)))


def test_pickle():
obj = Histogram(np.array([1, 1]), (Histogram.Axis.from_range_edges([0, 1, 2]),))
obj.attrs["attr1"] = 1

ex = pickle.loads(pickle.dumps(obj))
assert isinstance(ex, Histogram)
assert ex.attrs["attr1"] == 1
assert ex.attrs["datatype"] == obj.attrs["datatype"]
assert np.all(ex.weights == obj.weights)
13 changes: 13 additions & 0 deletions tests/types/test_scalar.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pickle

import pytest

import lgdo
Expand Down Expand Up @@ -33,3 +35,14 @@ def test_getattrs():

def test_equality():
assert lgdo.Scalar(value=42) == lgdo.Scalar(value=42)


def test_pickle():
obj = lgdo.Scalar(value=10)
obj.attrs["attr1"] = 1

ex = pickle.loads(pickle.dumps(obj))
assert isinstance(ex, lgdo.Scalar)
assert ex.attrs["attr1"] == 1
assert ex.attrs["datatype"] == obj.attrs["datatype"]
assert ex.value == 10
14 changes: 14 additions & 0 deletions tests/types/test_struct.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import pickle

import pytest

import lgdo
Expand Down Expand Up @@ -78,3 +80,15 @@ def test_remove_field():

struct.remove_field("array1", delete=True)
assert list(struct.keys()) == []


def test_pickle():
obj_dict = {"scalar1": lgdo.Scalar(value=10)}
attrs = {"attr1": 1}
struct = lgdo.Struct(obj_dict=obj_dict, attrs=attrs)

ex = pickle.loads(pickle.dumps(struct))
assert isinstance(ex, lgdo.Struct)
assert ex.attrs["attr1"] == 1
assert ex.attrs["datatype"] == struct.attrs["datatype"]
assert ex["scalar1"].value == 10
18 changes: 18 additions & 0 deletions tests/types/test_table.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import pickle
import warnings

import awkward as ak
Expand Down Expand Up @@ -221,3 +222,20 @@ def test_remove_column():

tbl.remove_column("c")
assert list(tbl.keys()) == ["b"]


def test_pickle():
col_dict = {
"a": lgdo.Array(nda=np.array([1, 2, 3, 4])),
"b": lgdo.Array(nda=np.array([5, 6, 7, 8])),
"c": lgdo.Array(nda=np.array([9, 10, 11, 12])),
}
obj = Table(col_dict=col_dict)
obj.attrs["attr1"] = 1

ex = pickle.loads(pickle.dumps(obj))
assert isinstance(ex, Table)
assert ex.attrs["attr1"] == 1
assert ex.attrs["datatype"] == obj.attrs["datatype"]
for key, val in col_dict.items():
assert ex[key] == val
17 changes: 17 additions & 0 deletions tests/types/test_vectorofvectors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import pickle
from collections import namedtuple

import awkward as ak
Expand Down Expand Up @@ -441,3 +442,19 @@ def test_lh5_iterator_view_as(lgnd_test_data):

for obj, _, _ in it:
assert ak.is_valid(obj.view_as("ak"))


def test_pickle(testvov):
obj = testvov.v2d
ex = pickle.loads(pickle.dumps(obj))

desired = [
np.array([1, 2]),
np.array([3, 4, 5]),
np.array([2]),
np.array([4, 8, 9, 7]),
np.array([5, 3, 1]),
]

for i in range(len(desired)):
assert np.array_equal(desired[i], ex[i])
Loading