Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions param/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2230,7 +2230,13 @@ def __init__(self, default=Undefined, **params):


class Array(ClassSelector):
"""Parameter whose value is a numpy array."""
"""Parameter whose value is a numpy array.

Accepts numpy ``ndarray`` objects as well as array-like objects that
implement the ``__array__`` or ``__array_interface__`` protocols
(e.g. pandas ``ExtensionArray`` subclasses such as
``ArrowStringArray``).
"""

@typing.overload
def __init__(
Expand All @@ -2246,11 +2252,32 @@ def __init__(self, default=Undefined, **params):
from numpy import ndarray
super().__init__(default=default, class_=ndarray, **params)

@staticmethod
def _is_array_like(val):
"""Return True if *val* supports the numpy array protocol."""
try:
return (
callable(getattr(val, '__array__', None))
or getattr(val, '__array_interface__', None) is not None
)
except Exception:
return False

def _validate_class_(self, val, class_, is_instance):
# Accept array-like objects (e.g. pandas ExtensionArray,
# ArrowStringArray) that support the numpy array protocol.
if is_instance and not isinstance(val, class_) and self._is_array_like(val):
return
super()._validate_class_(val, class_, is_instance)
Comment on lines +2266 to +2271

@classmethod
def serialize(cls, value):
if value is None:
return None
return value.tolist()
if hasattr(value, 'tolist'):
return value.tolist()
import numpy
return numpy.asarray(value).tolist()

@classmethod
def deserialize(cls, value):
Expand Down
82 changes: 82 additions & 0 deletions tests/testnumpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,85 @@ class MatParam(param.Parameterized):

mp = MatParam()
mp.param.pprint()

def test_array_accepts_array_like_with_dunder_array(self):
"""Objects implementing __array__ should be accepted by param.Array."""
class ArrayLike:
"""Minimal array-like with __array__ protocol."""
def __init__(self, data):
self._data = numpy.asarray(data)
def __array__(self, dtype=None, copy=None):
if dtype is not None:
return self._data.astype(dtype)
return self._data

class P(param.Parameterized):
arr = param.Array()

p = P()
array_like = ArrayLike([1, 2, 3])
p.arr = array_like # Should not raise
numpy.testing.assert_array_equal(numpy.asarray(p.arr), [1, 2, 3])
# Verify serialize fallback (ArrayLike has no .tolist())
self.assertEqual(param.Array.serialize(array_like), [1, 2, 3])

def test_array_accepts_array_like_with_array_interface(self):
"""Objects with __array_interface__ should be accepted."""
class ArrayInterface:
"""Minimal object with __array_interface__."""
def __init__(self, data):
self._arr = numpy.asarray(data)

@property
def __array_interface__(self):
return self._arr.__array_interface__

class P(param.Parameterized):
arr = param.Array()

p = P()
obj = ArrayInterface([4, 5, 6])
p.arr = obj # Should not raise
numpy.testing.assert_array_equal(numpy.asarray(p.arr), [4, 5, 6])

def test_array_rejects_plain_list(self):
"""Plain lists should still be rejected (no __array__ attribute)."""
class P(param.Parameterized):
arr = param.Array()

p = P()
with self.assertRaises(ValueError):
p.arr = [1, 2, 3]

def test_array_rejects_string(self):
"""Strings should still be rejected."""
class P(param.Parameterized):
arr = param.Array()

p = P()
with self.assertRaises(ValueError):
p.arr = "not an array"

def test_array_accepts_pandas_extension_array(self):
"""pandas ExtensionArray subclasses should be accepted."""
try:
import pandas as pd
except ImportError:
self.skipTest("pandas not available")

class P(param.Parameterized):
arr = param.Array()

p = P()
# Categorical array implements __array__
cat = pd.Categorical(["a", "b", "a"])
p.arr = cat # Should not raise
# Verify serialization works on accepted array-like
self.assertEqual(param.Array.serialize(cat), ["a", "b", "a"])

# ArrowStringArray (pandas >= 1.2 with pyarrow) also implements __array__
try:
arrow_arr = pd.array(["x", "y", "z"], dtype="string[pyarrow]")
p.arr = arrow_arr # Should not raise
except (ImportError, TypeError, ValueError):
pass # pyarrow not installed or dtype unavailable