Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: copy_behaviors to make sub-classing easy #3137

Merged
merged 4 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/awkward/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import struct
import sys
import typing
from collections.abc import Collection

import numpy as np # noqa: TID251
Expand Down Expand Up @@ -102,3 +103,18 @@ def unique_list(items: Collection[T]) -> list[T]:
seen.add(item)
result.append(item)
return result


def copy_behaviors(existing_class: typing.Any, new_class: typing.Any, behavior: dict):
output = {}

oldname = existing_class.__name__
newname = new_class.__name__

for key, value in behavior.items():
if oldname in key:
if not isinstance(key, str) and "*" not in key:
new_tuple = tuple(newname if k == oldname else k for k in key)
output[new_tuple] = value

return output
125 changes: 125 additions & 0 deletions tests/test_2433_copy_behaviors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import numpy
import pytest

import awkward as ak


def test():
class SuperVector:
def add(self, other):
"""Add two vectors together elementwise using `x` and `y` components"""
return ak.zip(
{"x": self.x + other.x, "y": self.y + other.y},
with_name="VectorTwoD",
behavior=self.behavior,
)

# first sub-class
@ak.mixin_class(ak.behavior)
class VectorTwoD(SuperVector):
def __eq__(self, other):
return ak.all(self.x == other.x) and ak.all(self.y == other.y)

v = ak.Array(
[
[{"x": 1, "y": 1.1}, {"x": 2, "y": 2.2}],
[],
[{"x": 3, "y": 3.3}],
[
{"x": 4, "y": 4.4},
{"x": 5, "y": 5.5},
{"x": 6, "y": 6.6},
],
],
with_name="VectorTwoD",
behavior=ak.behavior,
)
v_added = ak.Array(
[
[{"x": 2, "y": 2.2}, {"x": 4, "y": 4.4}],
[],
[{"x": 6, "y": 6.6}],
[
{"x": 8, "y": 8.8},
{"x": 10, "y": 11},
{"x": 12, "y": 13.2},
],
],
with_name="VectorTwoD",
behavior=ak.behavior,
)

# add method works but the binary operator does not
assert v.add(v) == v_added
with pytest.raises(TypeError):
v + v

# registering the operator makes everything work
ak.behavior[numpy.add, "VectorTwoD", "VectorTwoD"] = lambda v1, v2: v1.add(v2)
assert v + v == v_added

# second sub-class
@ak.mixin_class(ak.behavior)
class VectorTwoDAgain(VectorTwoD):
pass

v = ak.Array(
[
[{"x": 1, "y": 1.1}, {"x": 2, "y": 2.2}],
[],
[{"x": 3, "y": 3.3}],
[
{"x": 4, "y": 4.4},
{"x": 5, "y": 5.5},
{"x": 6, "y": 6.6},
],
],
with_name="VectorTwoDAgain",
behavior=ak.behavior,
)
# add method works but the binary operator does not
assert v.add(v) == v_added
with pytest.raises(TypeError):
v + v

# instead of registering every operator again, just copy the behaviors of
# another class to this class
ak.behavior.update(
ak._util.copy_behaviors(VectorTwoD, VectorTwoDAgain, ak.behavior)
)
assert v + v == v_added

# third sub-class
@ak.mixin_class(ak.behavior)
class VectorTwoDAgainAgain(VectorTwoDAgain):
pass

v = ak.Array(
[
[{"x": 1, "y": 1.1}, {"x": 2, "y": 2.2}],
[],
[{"x": 3, "y": 3.3}],
[
{"x": 4, "y": 4.4},
{"x": 5, "y": 5.5},
{"x": 6, "y": 6.6},
],
],
with_name="VectorTwoDAgainAgain",
behavior=ak.behavior,
)
# add method works but the binary operator does not
assert v.add(v) == v_added
with pytest.raises(TypeError):
v + v

# instead of registering every operator again, just copy the behaviors of
# another class to this class
ak.behavior.update(
ak._util.copy_behaviors(VectorTwoDAgain, VectorTwoDAgainAgain, ak.behavior)
)
assert v + v == v_added
Loading