From 9414ade317034a6485231c7b6f4bb9dab68cebbf Mon Sep 17 00:00:00 2001 From: Richard Hayes Date: Mon, 29 Jan 2024 08:46:29 +0000 Subject: [PATCH] feature/array jax (#248) Adapt to use the updated autoarray for JAX --- autolens/point/point_solver.py | 8 ++++---- test_autolens/analysis/test_result.py | 9 ++++++--- test_autolens/point/test_point_solver.py | 2 +- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/autolens/point/point_solver.py b/autolens/point/point_solver.py index 6ae73fbc3..f8b14fda7 100644 --- a/autolens/point/point_solver.py +++ b/autolens/point/point_solver.py @@ -75,8 +75,8 @@ def grid_with_coordinates_to_mass_profile_centre_removed_from( ) grid = grid_outside_distance_mask_from( - distances_1d=distances_1d, - grid_slim=grid, + distances_1d=np.array(distances_1d), + grid_slim=np.array(grid), outside_distance=self.distance_to_mass_profile_centre, ) @@ -180,8 +180,8 @@ def grid_peaks_from(self, deflections_func, grid, source_plane_coordinate): ) grid_peaks = grid_peaks_from( - distance_1d=source_plane_distances, - grid_slim=grid, + distance_1d=np.array(source_plane_distances), + grid_slim=np.array(grid), neighbors=neighbors.astype("int"), has_neighbors=has_neighbors, ) diff --git a/test_autolens/analysis/test_result.py b/test_autolens/analysis/test_result.py index be860d8dd..f5c0046ed 100644 --- a/test_autolens/analysis/test_result.py +++ b/test_autolens/analysis/test_result.py @@ -4,6 +4,7 @@ import autofit as af import autolens as al +from autoarray import Array2D from autolens.analysis import result as res from autolens.imaging.model.result import ResultImaging @@ -366,11 +367,13 @@ def test___image_dict(analysis_imaging_7x7): ) image_dict = result.image_galaxy_dict - assert isinstance(image_dict[str(("galaxies", "lens"))], np.ndarray) - assert isinstance(image_dict[str(("galaxies", "source"))], np.ndarray) + + assert isinstance(image_dict[str(("galaxies", "lens"))], Array2D) + assert isinstance(image_dict[str(("galaxies", "source"))], Array2D) result.instance.galaxies.lens = al.Galaxy(redshift=0.5) image_dict = result.image_galaxy_dict + assert (image_dict[str(("galaxies", "lens"))].native == np.zeros((7, 7))).all() - assert isinstance(image_dict[str(("galaxies", "source"))], np.ndarray) + assert isinstance(image_dict[str(("galaxies", "source"))], Array2D) diff --git a/test_autolens/point/test_point_solver.py b/test_autolens/point/test_point_solver.py index 9d0ab7fe3..ce5f85d94 100644 --- a/test_autolens/point/test_point_solver.py +++ b/test_autolens/point/test_point_solver.py @@ -666,7 +666,7 @@ def test__simple_arrays(self): peaks_coordinates = pos.grid_peaks_from( distance_1d=distance_1d, - grid_slim=grid_slim, + grid_slim=np.array(grid_slim), neighbors=neighbors_1d.astype("int"), has_neighbors=has_neighbors, )