Skip to content

Commit 1369689

Browse files
authored
Homogenize return type of Lattice.get_points_in_sphere to always be np.array(s) (#3797)
* homogenize return type of Lattice.get_points_in_sphere to always be np.arrays * test return type is always tuple and 4 * np.array if zip_results=False whether or not points in sphere are found
1 parent 9a3f714 commit 1369689

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

pymatgen/core/lattice.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,8 +1298,7 @@ def get_points_in_sphere(
12981298
zip_results=True,
12991299
) -> list[tuple[np.ndarray, float, int, np.ndarray]] | tuple[np.ndarray, ...] | list:
13001300
"""Find all points within a sphere from the point taking into account
1301-
periodic boundary conditions. This includes sites in other periodic
1302-
images.
1301+
periodic boundary conditions. This includes sites in other periodic images.
13031302
13041303
Algorithm:
13051304
@@ -1342,10 +1341,12 @@ def get_points_in_sphere(
13421341
all_coords=cart_coords, center_coords=center_coords, r=float(r), pbc=pbc, lattice=latt_matrix, tol=1e-8
13431342
)
13441343
if len(indices) < 1:
1345-
return [] if zip_results else [()] * 4
1344+
# return empty np.array (not list or tuple) to ensure consistent return type
1345+
# whether sphere contains points or not
1346+
return np.array([]) if zip_results else tuple(np.array([]) for _ in range(4))
13461347
frac_coords = frac_points[indices] + images
13471348
if zip_results:
1348-
return list(zip(frac_coords, distances, indices, images))
1349+
return tuple(zip(frac_coords, distances, indices, images))
13491350
return frac_coords, distances, indices, images
13501351

13511352
def get_points_in_sphere_py(

tests/core/test_lattice.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -352,26 +352,40 @@ def test_get_points_in_sphere(self):
352352
# This is a non-niggli representation of a cubic lattice
353353
lattice = Lattice([[1, 5, 0], [0, 1, 0], [5, 0, 1]])
354354
# evenly spaced points array between 0 and 1
355-
pts = np.array(list(itertools.product(range(5), repeat=3))) / 5
356-
pts = lattice.get_fractional_coords(pts)
355+
points = np.array(list(itertools.product(range(5), repeat=3))) / 5
356+
points = lattice.get_fractional_coords(points)
357357

358358
# Test getting neighbors within 1 neighbor distance of the origin
359-
fcoords, dists, inds, images = lattice.get_points_in_sphere(pts, [0, 0, 0], 0.20001, zip_results=False)
360-
assert len(fcoords) == 7 # There are 7 neighbors
359+
frac_coords, dists, indices, images = lattice.get_points_in_sphere(
360+
points, [0, 0, 0], 0.20001, zip_results=False
361+
)
362+
assert len(frac_coords) == 7 # There are 7 neighbors
361363
assert np.isclose(dists, 0.2).sum() == 6 # 6 are at 0.2
362364
assert np.isclose(dists, 0).sum() == 1 # 1 is at 0
363-
assert len(set(inds)) == 7 # They have unique indices
364-
assert_array_equal(images[np.isclose(dists, 0)], [[0, 0, 0]])
365+
assert len(set(indices)) == 7 # They have unique indices
366+
assert images[np.isclose(dists, 0)].tolist() == [[0, 0, 0]]
365367

366368
# More complicated case, using the zip output
367-
result = lattice.get_points_in_sphere(pts, [0.5, 0.5, 0.5], 1.0001)
369+
result = lattice.get_points_in_sphere(points, [0.5, 0.5, 0.5], 1.0001)
370+
assert isinstance(result, tuple)
368371
assert len(result) == 552
369-
assert len(result[0]) == 4 # coords, dists, ind, supercell
372+
assert len(result[0]) == 4 # coords, dists, indices, supercell
370373

371374
# test pbc
372375
latt_pbc = Lattice([[1, 5, 0], [0, 1, 0], [5, 0, 1]], pbc=(True, True, False))
373-
fcoords, dists, inds, images = latt_pbc.get_points_in_sphere(pts, [0, 0, 0], 0.20001, zip_results=False)
374-
assert len(fcoords) == 6
376+
frac_coords, dists, indices, images = latt_pbc.get_points_in_sphere(
377+
points, [0, 0, 0], 0.20001, zip_results=False
378+
)
379+
assert len(frac_coords) == 6
380+
381+
# ensure consistent return type if zip_results=False and no points in sphere are found
382+
# https://github.com/materialsproject/pymatgen/issues/3794
383+
result = lattice.get_points_in_sphere(points, [0.5, 0.5, 0.5], 0.0001, zip_results=False)
384+
assert isinstance(result, tuple)
385+
assert len(result) == 4
386+
assert all(len(arr) == 0 for arr in result)
387+
types = {*map(type, result)}
388+
assert types == {np.ndarray}, f"Expected only np.ndarray, got {types}"
375389

376390
def test_get_all_distances(self):
377391
fcoords = np.array(

0 commit comments

Comments
 (0)