diff --git a/param/parameters.py b/param/parameters.py index 2aa22be5..b349477f 100644 --- a/param/parameters.py +++ b/param/parameters.py @@ -2230,7 +2230,13 @@ def __init__(self, default=Undefined, **params): class Array(ClassSelector): - """Parameter whose value is a numpy array.""" + """Parameter whose value is a numpy array. + + Accepts numpy ``ndarray`` objects as well as array-like objects that + implement the ``__array__`` or ``__array_interface__`` protocols + (e.g. pandas ``ExtensionArray`` subclasses such as + ``ArrowStringArray``). + """ @typing.overload def __init__( @@ -2246,11 +2252,32 @@ def __init__(self, default=Undefined, **params): from numpy import ndarray super().__init__(default=default, class_=ndarray, **params) + @staticmethod + def _is_array_like(val): + """Return True if *val* supports the numpy array protocol.""" + try: + return ( + callable(getattr(val, '__array__', None)) + or getattr(val, '__array_interface__', None) is not None + ) + except Exception: + return False + + def _validate_class_(self, val, class_, is_instance): + # Accept array-like objects (e.g. pandas ExtensionArray, + # ArrowStringArray) that support the numpy array protocol. + if is_instance and not isinstance(val, class_) and self._is_array_like(val): + return + super()._validate_class_(val, class_, is_instance) + @classmethod def serialize(cls, value): if value is None: return None - return value.tolist() + if hasattr(value, 'tolist'): + return value.tolist() + import numpy + return numpy.asarray(value).tolist() @classmethod def deserialize(cls, value): diff --git a/tests/testnumpy.py b/tests/testnumpy.py index c1ed5e7d..ff54f621 100644 --- a/tests/testnumpy.py +++ b/tests/testnumpy.py @@ -78,3 +78,85 @@ class MatParam(param.Parameterized): mp = MatParam() mp.param.pprint() + + def test_array_accepts_array_like_with_dunder_array(self): + """Objects implementing __array__ should be accepted by param.Array.""" + class ArrayLike: + """Minimal array-like with __array__ protocol.""" + def __init__(self, data): + self._data = numpy.asarray(data) + def __array__(self, dtype=None, copy=None): + if dtype is not None: + return self._data.astype(dtype) + return self._data + + class P(param.Parameterized): + arr = param.Array() + + p = P() + array_like = ArrayLike([1, 2, 3]) + p.arr = array_like # Should not raise + numpy.testing.assert_array_equal(numpy.asarray(p.arr), [1, 2, 3]) + # Verify serialize fallback (ArrayLike has no .tolist()) + self.assertEqual(param.Array.serialize(array_like), [1, 2, 3]) + + def test_array_accepts_array_like_with_array_interface(self): + """Objects with __array_interface__ should be accepted.""" + class ArrayInterface: + """Minimal object with __array_interface__.""" + def __init__(self, data): + self._arr = numpy.asarray(data) + + @property + def __array_interface__(self): + return self._arr.__array_interface__ + + class P(param.Parameterized): + arr = param.Array() + + p = P() + obj = ArrayInterface([4, 5, 6]) + p.arr = obj # Should not raise + numpy.testing.assert_array_equal(numpy.asarray(p.arr), [4, 5, 6]) + + def test_array_rejects_plain_list(self): + """Plain lists should still be rejected (no __array__ attribute).""" + class P(param.Parameterized): + arr = param.Array() + + p = P() + with self.assertRaises(ValueError): + p.arr = [1, 2, 3] + + def test_array_rejects_string(self): + """Strings should still be rejected.""" + class P(param.Parameterized): + arr = param.Array() + + p = P() + with self.assertRaises(ValueError): + p.arr = "not an array" + + def test_array_accepts_pandas_extension_array(self): + """pandas ExtensionArray subclasses should be accepted.""" + try: + import pandas as pd + except ImportError: + self.skipTest("pandas not available") + + class P(param.Parameterized): + arr = param.Array() + + p = P() + # Categorical array implements __array__ + cat = pd.Categorical(["a", "b", "a"]) + p.arr = cat # Should not raise + # Verify serialization works on accepted array-like + self.assertEqual(param.Array.serialize(cat), ["a", "b", "a"]) + + # ArrowStringArray (pandas >= 1.2 with pyarrow) also implements __array__ + try: + arrow_arr = pd.array(["x", "y", "z"], dtype="string[pyarrow]") + p.arr = arrow_arr # Should not raise + except (ImportError, TypeError, ValueError): + pass # pyarrow not installed or dtype unavailable