Skip to content

Commit

Permalink
Merge pull request #21 from jcmgray/mimicfix
Browse files Browse the repository at this point in the history
fix NumpyMimic special attribute access (#20)
  • Loading branch information
jcmgray authored May 10, 2024
2 parents 2c77049 + 428fefe commit c906d17
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 18 deletions.
42 changes: 24 additions & 18 deletions autoray/autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,25 @@ def binary_dispatcher(args, kwargs):

# --------------- object to act as drop-in replace for numpy ---------------- #

_partial_functions = {}

def _get_mimic_function_or_attribute(self, fn):
# respect all 'dunder' special methods and attributes
if (fn[:2] == "__") and (fn[-2:] == "__"):
return object.__getattribute__(self, fn)

# look out for certain submodules which are not functions
if fn == "linalg":
return NumpyMimic("linalg")

if fn == "random":
return NumpyMimic("random")

# if this is the e.g. linalg mimic, preprend 'linalg.'
submod = object.__getattribute__(self, "submodule")
if submod is not None:
fn = ".".join((submod, fn))

return functools.partial(do, fn)


class NumpyMimic:
Expand All @@ -1457,23 +1475,13 @@ class NumpyMimic:
def __init__(self, submodule=None):
self.submodule = submodule

def __getattribute__(self, fn):
# look out for certain submodules which are not functions
if fn == "linalg":
return numpy_linalg
if fn == "random":
return numpy_random

# if this is the e.g. linalg mimic, preprend 'linalg.'
submod = object.__getattribute__(self, "submodule")
if submod is not None:
fn = ".".join((submod, fn))

# cache the correct partial function
def __getattribute__(self, attr):
# cache the correct partial function (or special method/attribute)
d = object.__getattribute__(self, "__dict__")
try:
pfn = _partial_functions[fn]
pfn = d[attr]
except KeyError:
pfn = _partial_functions[fn] = functools.partial(do, fn)
pfn = d[attr] = _get_mimic_function_or_attribute(self, attr)

return pfn

Expand All @@ -1483,8 +1491,6 @@ def __repr__():


numpy = NumpyMimic()
numpy_linalg = NumpyMimic("linalg")
numpy_random = NumpyMimic("random")


# --------------------------------------------------------------------------- #
Expand Down
13 changes: 13 additions & 0 deletions tests/test_autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,19 @@ def modified_gram_schmidt_np_mimic(X):
return np.stack(Q, axis=0)


def test_numpy_mimic_dunder_methods():
from abc import ABC
from autoray import numpy as np

class Base(ABC):
pass

assert isinstance(np, object)
assert not isinstance(np, Base)
print(np)
dir(np)


@pytest.mark.parametrize("backend", BACKENDS)
def test_mgs_np_mimic(backend):
if backend == "sparse":
Expand Down

0 comments on commit c906d17

Please sign in to comment.