Skip to content

Commit

Permalink
fix use of adapt data in hilbert.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Jammy2211 committed Dec 20, 2023
1 parent 5768145 commit 5752dcd
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 70 deletions.
88 changes: 38 additions & 50 deletions autoarray/inversion/pixelization/image_mesh/hilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,20 @@ def generate2d(x, y, ax, ay, bx, by):


def super_resolution_grid_from(img_2d, mask, mask_radius, pixel_scales, sub_scale=11):
'''
This function will create a higher resolution grid for the img_2d. The new grid and its
interpolated values will be used to generate a sparse image grid.
"""
This function will create a higher resolution grid for the img_2d. The new grid and its
interpolated values will be used to generate a sparse image grid.
img_2d: the hyper image in 2d (e.g. hyper_source_model_image.native)
mask: the mask used for the fitting.
mask_radius: the circular mask radius. Currently, the code only works with a circular mask.
sub_scale: oversampling scale for each image pixel.
'''
img_2d: the hyper image in 2d (e.g. hyper_source_model_image.native)
mask: the mask used for the fitting.
mask_radius: the circular mask radius. Currently, the code only works with a circular mask.
sub_scale: oversampling scale for each image pixel.
"""

shape_nnn = np.shape(mask)[0]

grid = Grid2D.uniform(
shape_native=(shape_nnn, shape_nnn),
pixel_scales=pixel_scales,
sub_size=1
shape_native=(shape_nnn, shape_nnn), pixel_scales=pixel_scales, sub_size=1
)

new_mask = Mask2D.circular(
Expand All @@ -112,23 +110,19 @@ def super_resolution_grid_from(img_2d, mask, mask_radius, pixel_scales, sub_scal
new_grid = Grid2D.from_mask(new_mask)

new_img = griddata(
points=grid,
values=img_2d.ravel(),
xi=new_grid,
fill_value=0.0,
method='linear'
points=grid, values=img_2d.ravel(), xi=new_grid, fill_value=0.0, method="linear"
)

return new_img, new_grid


def grid_hilbert_order_from(length, mask_radius):
'''
This function will create a grid in the Hilbert space-filling curve order.
"""
This function will create a grid in the Hilbert space-filling curve order.
length: the size of the square grid.
mask_radius: the circular mask radius. This code only works with a circular mask.
'''
length: the size of the square grid.
mask_radius: the circular mask radius. This code only works with a circular mask.
"""

xy_generator = gilbert2d(length, length)

Expand All @@ -155,55 +149,52 @@ def grid_hilbert_order_from(length, mask_radius):


def image_and_grid_from(image, mask, mask_radius, pixel_scales, hilbert_length):
'''
This code will create a grid in Hilbert space-filling curve order and an interpolated hyper
image associated to that grid.
'''
"""
This code will create a grid in Hilbert space-filling curve order and an interpolated hyper
image associated to that grid.
"""

shape_nnn = np.shape(mask)[0]

grid = Grid2D.uniform(
shape_native=(shape_nnn, shape_nnn),
pixel_scales=pixel_scales,
sub_size=1
)
shape_native=(shape_nnn, shape_nnn), pixel_scales=pixel_scales, sub_size=1
)

x1d_hb, y1d_hb = grid_hilbert_order_from(
length=hilbert_length,
mask_radius=mask_radius
)
length=hilbert_length, mask_radius=mask_radius
)

grid_hb = np.stack((y1d_hb, x1d_hb), axis=-1)
grid_hb_radius = np.sqrt(grid_hb[:, 0]**2.0 + grid_hb[:, 1]**2.0)
grid_hb_radius = np.sqrt(grid_hb[:, 0] ** 2.0 + grid_hb[:, 1] ** 2.0)
new_grid = grid_hb[grid_hb_radius <= mask_radius]

new_img = griddata(
points=grid,
values=image.ravel(),
values=image.native.ravel(),
xi=new_grid,
fill_value=0.0,
method='linear'
)
method="linear",
)

return new_img, new_grid


def inverse_transform_sampling_interpolated(probabilities, n_samples, gridx, gridy):
'''
Given a 1d cumulative probability function, this code will generate points following the
probability distribution.
"""
Given a 1d cumulative probability function, this code will generate points following the
probability distribution.
probabilities: 1D normalized cumulative probablity curve.
n_samples: the number of points to draw.
'''
probabilities: 1D normalized cumulative probablity curve.
n_samples: the number of points to draw.
"""

cdf = np.cumsum(probabilities)
npixels = len(probabilities)
id_range = np.arange(0, npixels)
cdf[0] = 0.0
intp_func = interp1d(cdf, id_range, kind='linear')
intp_func_x = interp1d(id_range, gridx, kind='linear')
intp_func_y = interp1d(id_range, gridy, kind='linear')
intp_func = interp1d(cdf, id_range, kind="linear")
intp_func_x = interp1d(id_range, gridx, kind="linear")
intp_func_y = interp1d(id_range, gridy, kind="linear")
linear_points = np.linspace(0, 0.99999999, n_samples)
output_ids = intp_func(linear_points)
output_x = intp_func_x(output_ids)
Expand Down Expand Up @@ -277,7 +268,7 @@ def weight_map_from(self, adapt_data: np.ndarray):
# np.max(adapt_data) - np.min(adapt_data)
# ) + self.weight_floor * np.max(adapt_data)

# return np.power(weight_map, self.weight_power)
# return np.power(weight_map, self.weight_power)

weight_map = (np.abs(adapt_data) + self.weight_floor) ** self.weight_power
weight_map /= np.sum(weight_map)
Expand Down Expand Up @@ -305,7 +296,6 @@ def image_plane_mesh_grid_from(
"""

if not grid.mask.is_circular:

raise exc.PixelizationException(
"""
Hilbert image-mesh has been called but the input grid does not use a circular mask.
Expand Down Expand Up @@ -335,9 +325,7 @@ def image_plane_mesh_grid_from(
gridy=grid_hb[:, 0],
)

return Grid2DIrregular(
values=np.stack((drawn_y, drawn_x), axis=-1)
)
return Grid2DIrregular(values=np.stack((drawn_y, drawn_x), axis=-1))

@property
def is_stochastic(self):
Expand Down
15 changes: 10 additions & 5 deletions autoarray/mask/mask_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,7 @@ def zoom_mask_unmasked(self) -> "Mask2D":
)

@property
def is_circular(self)-> bool:
def is_circular(self) -> bool:
"""
Returns whether the mask is circular or not.
Expand All @@ -1055,15 +1055,16 @@ def is_circular(self)-> bool:
"""

if self.pixel_scales[0] != self.pixel_scales[1]:

raise exc.MaskException(
"""
The is_circular function cannot be called for a mask with different pixel scales in each dimension
(e.g. it does not support rectangular masks.
"""
)

pixel_coordinates_2d = self.geometry.pixel_coordinates_2d_from(scaled_coordinates_2d=self.mask_centre)
pixel_coordinates_2d = self.geometry.pixel_coordinates_2d_from(
scaled_coordinates_2d=self.mask_centre
)

central_row_pixels = sum(np.invert(self[pixel_coordinates_2d[0], :]))
central_column_pixels = sum(np.invert(self[:, pixel_coordinates_2d[1]]))
Expand Down Expand Up @@ -1098,8 +1099,12 @@ def circular_radius(self) -> float:
"""
)

pixel_coordinates_2d = self.geometry.pixel_coordinates_2d_from(scaled_coordinates_2d=self.mask_centre)
# print("aaa")

pixel_coordinates_2d = self.geometry.pixel_coordinates_2d_from(
scaled_coordinates_2d=self.mask_centre
)

central_row_pixels = sum(np.invert(self[pixel_coordinates_2d[0], :]))

return central_row_pixels * self.pixel_scales[0] / 2.0
return central_row_pixels * self.pixel_scales[0] / 2.0
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,19 @@


def test__image_plane_mesh_grid_from():

mask = aa.Mask2D.circular(
shape_native=(4,4),
shape_native=(4, 4),
radius=2.0,
pixel_scales=1.0,
sub_size=1,
)

grid = aa.Grid2D.from_mask(mask=mask)

adapt_data = np.ones(shape=mask.shape_native)
adapt_data = aa.Array2D.ones(
shape_native=mask.shape_native,
pixel_scales=1.0,
)

kmeans = aa.image_mesh.Hilbert(pixels=8)
image_mesh = kmeans.image_plane_mesh_grid_from(grid=grid, adapt_data=adapt_data)
Expand Down
27 changes: 15 additions & 12 deletions test_autoarray/mask/test_mask_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,6 @@ def test__mask_centre():


def test__is_circular():

mask = np.array(
[
[True, True, True, True],
Expand All @@ -883,34 +882,38 @@ def test__is_circular():

assert mask.is_circular == False

mask = aa.Mask2D.circular(shape_native=(5,5), radius=1.0, pixel_scales=(1.0, 1.0))
mask = aa.Mask2D.circular(shape_native=(5, 5), radius=1.0, pixel_scales=(1.0, 1.0))

assert mask.is_circular == True

mask = aa.Mask2D.circular(shape_native=(10,10), radius=3.0, pixel_scales=(1.0, 1.0))
mask = aa.Mask2D.circular(
shape_native=(10, 10), radius=3.0, pixel_scales=(1.0, 1.0)
)

assert mask.is_circular == True

mask = aa.Mask2D.circular(shape_native=(10,10), radius=4.0, pixel_scales=(1.0, 1.0))
mask = aa.Mask2D.circular(
shape_native=(10, 10), radius=4.0, pixel_scales=(1.0, 1.0)
)

assert mask.is_circular == True

def test__circular_radius():

def test__circular_radius():
mask = aa.Mask2D.circular(
shape_native=(10, 10),
radius=3.0, pixel_scales=(1.0, 1.0))
shape_native=(10, 10), radius=3.0, pixel_scales=(1.0, 1.0)
)

assert mask.circular_radius == pytest.approx(3.0, 1e-4)

mask = aa.Mask2D.circular(
shape_native=(30, 30),
radius=5.5, pixel_scales=(0.5, 0.5))
shape_native=(30, 30), radius=5.5, pixel_scales=(0.5, 0.5)
)

assert mask.circular_radius == pytest.approx(5.5, 1e-4)

mask = aa.Mask2D.circular(
shape_native=(30, 30),
radius=5.75, pixel_scales=(0.5, 0.5))
shape_native=(30, 30), radius=5.75, pixel_scales=(0.5, 0.5)
)

assert mask.circular_radius == pytest.approx(5.5, 1e-4)
assert mask.circular_radius == pytest.approx(5.5, 1e-4)

0 comments on commit 5752dcd

Please sign in to comment.