Skip to content

Commit

Permalink
feature/array jax (#248)
Browse files Browse the repository at this point in the history
Adapt to use the updated autoarray for JAX
  • Loading branch information
rhayes777 authored Jan 29, 2024
1 parent e3d3380 commit 9414ade
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
8 changes: 4 additions & 4 deletions autolens/point/point_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def grid_with_coordinates_to_mass_profile_centre_removed_from(
)

grid = grid_outside_distance_mask_from(
distances_1d=distances_1d,
grid_slim=grid,
distances_1d=np.array(distances_1d),
grid_slim=np.array(grid),
outside_distance=self.distance_to_mass_profile_centre,
)

Expand Down Expand Up @@ -180,8 +180,8 @@ def grid_peaks_from(self, deflections_func, grid, source_plane_coordinate):
)

grid_peaks = grid_peaks_from(
distance_1d=source_plane_distances,
grid_slim=grid,
distance_1d=np.array(source_plane_distances),
grid_slim=np.array(grid),
neighbors=neighbors.astype("int"),
has_neighbors=has_neighbors,
)
Expand Down
9 changes: 6 additions & 3 deletions test_autolens/analysis/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import autofit as af
import autolens as al
from autoarray import Array2D

from autolens.analysis import result as res
from autolens.imaging.model.result import ResultImaging
Expand Down Expand Up @@ -366,11 +367,13 @@ def test___image_dict(analysis_imaging_7x7):
)

image_dict = result.image_galaxy_dict
assert isinstance(image_dict[str(("galaxies", "lens"))], np.ndarray)
assert isinstance(image_dict[str(("galaxies", "source"))], np.ndarray)

assert isinstance(image_dict[str(("galaxies", "lens"))], Array2D)
assert isinstance(image_dict[str(("galaxies", "source"))], Array2D)

result.instance.galaxies.lens = al.Galaxy(redshift=0.5)

image_dict = result.image_galaxy_dict

assert (image_dict[str(("galaxies", "lens"))].native == np.zeros((7, 7))).all()
assert isinstance(image_dict[str(("galaxies", "source"))], np.ndarray)
assert isinstance(image_dict[str(("galaxies", "source"))], Array2D)
2 changes: 1 addition & 1 deletion test_autolens/point/test_point_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def test__simple_arrays(self):

peaks_coordinates = pos.grid_peaks_from(
distance_1d=distance_1d,
grid_slim=grid_slim,
grid_slim=np.array(grid_slim),
neighbors=neighbors_1d.astype("int"),
has_neighbors=has_neighbors,
)
Expand Down

0 comments on commit 9414ade

Please sign in to comment.