Skip to content

Commit

Permalink
BUG: (NEP 19) fix array repr compatibility with Numpy 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Aug 3, 2023
1 parent b0ff846 commit e033864
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
11 changes: 11 additions & 0 deletions unyt/_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,3 +995,14 @@ def interp(x, xp, fp, *args, **kwargs):
np.interp(np.asarray(x), np.asarray(xp), np.asarray(fp), *args, **kwargs)
* ret_units
)


@implements(np.array_repr)
def array_repr(arr, *args, **kwargs):
rep = np.array_repr._implementation(arr.view(np.ndarray), *args, **kwargs)
rep = rep.replace("array", arr.__class__.__name__)
units_repr = arr.units.__repr__()
if "=" in rep:
return rep[:-1] + ", units='" + units_repr + "')"
else:
return rep[:-1] + ", '" + units_repr + "')"
7 changes: 1 addition & 6 deletions unyt/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,12 +636,7 @@ def __new__(
return obj

def __repr__(self):
rep = super().__repr__()
units_repr = self.units.__repr__()
if "=" in rep:
return rep[:-1] + ", units='" + units_repr + "')"
else:
return rep[:-1] + ", '" + units_repr + "')"
return np.array_repr(self)

def __str__(self):
return str(self.view(np.ndarray)) + " " + str(self.units)
Expand Down
3 changes: 1 addition & 2 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
np.argpartition, # returns pure numbers
np.argsort, # returns pure numbers
np.argwhere, # returns pure numbers
np.array_repr, # hooks into __repr__
np.array_str, # hooks into __str__
np.atleast_1d, # works out of the box (tested)
np.atleast_2d, # works out of the box (tested)
Expand Down Expand Up @@ -256,7 +255,7 @@ def test_wrapping_completeness():

def test_array_repr():
arr = [1, 2, 3] * cm
assert np.array_repr(arr) == "unyt_array([1, 2, 3], units='cm')"
assert np.array_repr(arr) == "unyt_array([1, 2, 3], 'cm')"


def test_dot_vectors():
Expand Down

0 comments on commit e033864

Please sign in to comment.