Skip to content

Commit

Permalink
Fix docstring returned type issue.
Browse files Browse the repository at this point in the history
  • Loading branch information
mairanteodoro committed Sep 28, 2023
1 parent 4751d30 commit d9f09c6
Showing 1 changed file with 33 additions and 15 deletions.
48 changes: 33 additions & 15 deletions src/stcal/alignment/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def _calculate_fiducial_from_spatial_footprint(
y_mid = (np.max(y) + np.min(y)) / 2.0
z_mid = (np.max(z) + np.min(z)) / 2.0
lon_fiducial = np.rad2deg(np.arctan2(y_mid, x_mid)) % 360.0
lat_fiducial = np.rad2deg(np.arctan2(z_mid, np.sqrt(x_mid**2 + y_mid**2)))
lat_fiducial = np.rad2deg(
np.arctan2(z_mid, np.sqrt(x_mid**2 + y_mid**2))
)
return lon_fiducial, lat_fiducial


Expand Down Expand Up @@ -132,7 +134,9 @@ def _generate_tranform(
calc_rotation_matrix(roll_ref, v3yangle, vparity=vparity), (2, 2)
)

rotation = astmodels.AffineTransformation2D(pc, name="pc_rotation_matrix")
rotation = astmodels.AffineTransformation2D(
pc, name="pc_rotation_matrix"
)
transform = [rotation]
if sky_axes:
if not pscale:
Expand Down Expand Up @@ -175,7 +179,9 @@ def _get_axis_min_and_bounding_box(ref_model, wcs_list, ref_wcs):
((x0_lower, x0_upper), (x1_lower, x1_upper)).
"""
footprints = [w.footprint().T for w in wcs_list]
domain_bounds = np.hstack([ref_wcs.backward_transform(*f) for f in footprints])
domain_bounds = np.hstack(
[ref_wcs.backward_transform(*f) for f in footprints]
)
axis_min_values = np.min(domain_bounds, axis=1)
domain_bounds = (domain_bounds.T - axis_min_values).T

Expand Down Expand Up @@ -333,7 +339,9 @@ def _calculate_new_wcs(
wcs_new.bounding_box = output_bounding_box

if shape is None:
shape = [int(axs[1] - axs[0] + 0.5) for axs in output_bounding_box[::-1]]
shape = [
int(axs[1] - axs[0] + 0.5) for axs in output_bounding_box[::-1]
]

wcs_new.pixel_shape = shape[::-1]
wcs_new.array_shape = shape
Expand Down Expand Up @@ -363,7 +371,9 @@ def _validate_wcs_list(wcs_list):
instance of WCS.
"""
if not isiterable(wcs_list):
raise ValueError("Expected 'wcs_list' to be an iterable of WCS objects.")
raise ValueError(
"Expected 'wcs_list' to be an iterable of WCS objects."
)
elif len(wcs_list):
if not all(isinstance(w, gwcs.WCS) for w in wcs_list):
raise TypeError(
Expand Down Expand Up @@ -460,7 +470,9 @@ def compute_scale(
spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == "SPATIAL")[0]
delta[spatial_idx[0]] = 1

crpix_with_offsets = np.vstack((crpix, crpix + delta, crpix + np.roll(delta, 1))).T
crpix_with_offsets = np.vstack(
(crpix, crpix + delta, crpix + np.roll(delta, 1))
).T
crval_with_offsets = wcs(*crpix_with_offsets, with_bounding_box=False)

coords = SkyCoord(
Expand Down Expand Up @@ -516,7 +528,9 @@ def compute_fiducial(wcslist: list, bounding_box=None) -> np.ndarray:
axes_types = wcslist[0].output_frame.axes_type
spatial_axes = np.array(axes_types) == "SPATIAL"
spectral_axes = np.array(axes_types) == "SPECTRAL"
footprints = np.hstack([w.footprint(bounding_box=bounding_box).T for w in wcslist])
footprints = np.hstack(
[w.footprint(bounding_box=bounding_box).T for w in wcslist]
)
spatial_footprint = footprints[spatial_axes]
spectral_footprint = footprints[spectral_axes]

Expand Down Expand Up @@ -715,7 +729,9 @@ def update_s_region_imaging(model, center=True):
### which means we are interested in each pixel's vertice, not its center.
### By using center=True, a difference of 0.5 pixel should be accounted for
### when comparing the world coordinates of the bounding box and the footprint.
footprint = model.meta.wcs.footprint(bbox, center=center, axis_type="spatial").T
footprint = model.meta.wcs.footprint(
bbox, center=center, axis_type="spatial"
).T
# take only imaging footprint
footprint = footprint[:2, :]

Expand Down Expand Up @@ -791,7 +807,6 @@ def reproject(wcs1, wcs2):
Returns
-------
_reproject : func
Function to compute the transformations. It takes x, y
positions in ``wcs1`` and returns x, y positions in ``wcs2``.
"""
Expand All @@ -809,8 +824,7 @@ def _get_forward_transform_func(wcs1):
forward_transform = wcs1.forward_transform
else:
raise TypeError(
"Expected input to be astropy.wcs.WCS or gwcs.WCS "
"object"
"Expected input to be astropy.wcs.WCS or gwcs.WCS " "object"
)
return forward_transform

Expand All @@ -821,12 +835,13 @@ def _get_backward_transform_func(wcs2):
backward_transform = wcs2.backward_transform
else:
raise TypeError(
"Expected input to be astropy.wcs.WCS or gwcs.WCS "
"object"
"Expected input to be astropy.wcs.WCS or gwcs.WCS " "object"
)
return backward_transform

def _reproject(x: Union[float, np.ndarray], y: Union[float, np.ndarray]) -> tuple:
def _reproject(
x: Union[float, np.ndarray], y: Union[float, np.ndarray]
) -> tuple:
"""
Reprojects the input coordinates from one WCS to another.
Expand Down Expand Up @@ -856,9 +871,12 @@ def _reproject(x: Union[float, np.ndarray], y: Union[float, np.ndarray]) -> tupl
flat_sky = []
for axis in sky:
flat_sky.append(axis.flatten())
det = np.array(_get_backward_transform_func(wcs2)(flat_sky[0], flat_sky[1], 0))
det = np.array(
_get_backward_transform_func(wcs2)(flat_sky[0], flat_sky[1], 0)
)
det_reshaped = []
for axis in det:
det_reshaped.append(axis.reshape(x.shape))
return tuple(det_reshaped)

return _reproject

0 comments on commit d9f09c6

Please sign in to comment.