Skip to content

Commit

Permalink
python: fix __array__ methods, the logic in a helper func wasn't correct
Browse files Browse the repository at this point in the history
one particular problem, for copy=False, was that astype(None) was called
resulting in casting complex values to real with warning:
ComplexWarning: Casting complex values to real discards the imaginary part
as reported in rs-station/reciprocalspaceship#284
  • Loading branch information
wojdyr committed Jan 14, 2025
1 parent 1509a97 commit 4dfefa1
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 10 deletions.
14 changes: 7 additions & 7 deletions python/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,14 @@ nb::list getitem_slice(Items& items, const nb::slice& slice) {

// for numpy __array__ method
inline nb::object handle_numpy_array_args(const nb::object& o, nb::handle dtype, nb::handle copy) {
if (dtype.is_none() && copy.is_none())
return o;
if (copy.is_none())
return o.attr("astype")(dtype);
// astype() may copy even with copy=False, so we check first
if (copy.ptr() == Py_False && !dtype.is_none() && !dtype.is(o.attr("dtype")))
if (dtype.is_none() || dtype.is(o.attr("dtype"))) {
if (copy.ptr() != Py_True)
return o;
dtype = o.attr("dtype");
}
if (copy.ptr() == Py_False) // astype() would copy even with copy=False
throw nb::value_error("Unable to avoid copy while creating an array as requested.");
return o.attr("astype")(dtype, nb::arg("copy")=copy);
return o.attr("astype")(dtype);
}

namespace nanobind { namespace detail {
Expand Down
22 changes: 22 additions & 0 deletions tests/test_hkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,5 +304,27 @@ def cif_floats(rb, tag):
self.assertAlmostEqual(d3['HLC'], 1.53099, delta=1e-5)
self.assertAlmostEqual(d3['HLD'], 4.64824, delta=1e-5)

class TestReciprocalGrid(unittest.TestCase):
@unittest.skipIf(numpy is None, 'requires NumPy')
def test_array_conversion(self):
grid = gemmi.ReciprocalComplexGrid(4, 4, 4)
expected_dtype = numpy.complex64
self.assertEqual(grid.array.dtype, expected_dtype)
new_val = 0+0j
if numpy.__version__[:2] == '1.':
possible_copy_values = [True, False]
else:
possible_copy_values = [None, True, False]
for dtype in [None, numpy.complex64]:
for copy in possible_copy_values:
arr = numpy.array(grid, dtype=dtype, copy=copy)
self.assertEqual(arr.dtype, expected_dtype,
msg=f'for {dtype=} {copy=} {new_val=}, {grid.array[0,0,0]=}')
new_val += 1+1j
arr[0,0,0] = new_val
grid_changed = (grid.array[0,0,0] == new_val)
self.assertEqual(grid_changed, not copy,
msg=f'for {dtype=} {copy=} {new_val=}, {grid.array[0,0,0]=}')

if __name__ == '__main__':
unittest.main()
5 changes: 2 additions & 3 deletions tests/test_unitcell.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,12 @@ def test_mat33_dunder_array(self):
m.fromlist([[50,50,50], [50,50,50], [50,50,50]])
for cx in [cn00, cn32, cn64, ct00, ct32, ct64]:
self.assertEqual(cx[0][0], 1)
self.assertEqual(cf00[0][0], 50)
for cx in [cf00, cf64]:
self.assertEqual(cx[0][0], 50)
if numpy.__version__ < '2.':
self.assertEqual(cf32[0][0], 1)
self.assertEqual(cf64[0][0], 1)
else:
self.assertIsNone(cf32)
self.assertEqual(cf64[0][0], 50)

class TestUnitCell(unittest.TestCase):
def test_dummy_cell(self):
Expand Down

0 comments on commit 4dfefa1

Please sign in to comment.