From d9f09c6f451fed708271075b4796e2346a2b3a47 Mon Sep 17 00:00:00 2001 From: "M. Teodoro" Date: Thu, 28 Sep 2023 10:01:09 -0400 Subject: [PATCH] Fix docstring returned type issue. --- src/stcal/alignment/util.py | 48 +++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/src/stcal/alignment/util.py b/src/stcal/alignment/util.py index 4042467f..71f7d87a 100644 --- a/src/stcal/alignment/util.py +++ b/src/stcal/alignment/util.py @@ -66,7 +66,9 @@ def _calculate_fiducial_from_spatial_footprint( y_mid = (np.max(y) + np.min(y)) / 2.0 z_mid = (np.max(z) + np.min(z)) / 2.0 lon_fiducial = np.rad2deg(np.arctan2(y_mid, x_mid)) % 360.0 - lat_fiducial = np.rad2deg(np.arctan2(z_mid, np.sqrt(x_mid**2 + y_mid**2))) + lat_fiducial = np.rad2deg( + np.arctan2(z_mid, np.sqrt(x_mid**2 + y_mid**2)) + ) return lon_fiducial, lat_fiducial @@ -132,7 +134,9 @@ def _generate_tranform( calc_rotation_matrix(roll_ref, v3yangle, vparity=vparity), (2, 2) ) - rotation = astmodels.AffineTransformation2D(pc, name="pc_rotation_matrix") + rotation = astmodels.AffineTransformation2D( + pc, name="pc_rotation_matrix" + ) transform = [rotation] if sky_axes: if not pscale: @@ -175,7 +179,9 @@ def _get_axis_min_and_bounding_box(ref_model, wcs_list, ref_wcs): ((x0_lower, x0_upper), (x1_lower, x1_upper)). """ footprints = [w.footprint().T for w in wcs_list] - domain_bounds = np.hstack([ref_wcs.backward_transform(*f) for f in footprints]) + domain_bounds = np.hstack( + [ref_wcs.backward_transform(*f) for f in footprints] + ) axis_min_values = np.min(domain_bounds, axis=1) domain_bounds = (domain_bounds.T - axis_min_values).T @@ -333,7 +339,9 @@ def _calculate_new_wcs( wcs_new.bounding_box = output_bounding_box if shape is None: - shape = [int(axs[1] - axs[0] + 0.5) for axs in output_bounding_box[::-1]] + shape = [ + int(axs[1] - axs[0] + 0.5) for axs in output_bounding_box[::-1] + ] wcs_new.pixel_shape = shape[::-1] wcs_new.array_shape = shape @@ -363,7 +371,9 @@ def _validate_wcs_list(wcs_list): instance of WCS. """ if not isiterable(wcs_list): - raise ValueError("Expected 'wcs_list' to be an iterable of WCS objects.") + raise ValueError( + "Expected 'wcs_list' to be an iterable of WCS objects." + ) elif len(wcs_list): if not all(isinstance(w, gwcs.WCS) for w in wcs_list): raise TypeError( @@ -460,7 +470,9 @@ def compute_scale( spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == "SPATIAL")[0] delta[spatial_idx[0]] = 1 - crpix_with_offsets = np.vstack((crpix, crpix + delta, crpix + np.roll(delta, 1))).T + crpix_with_offsets = np.vstack( + (crpix, crpix + delta, crpix + np.roll(delta, 1)) + ).T crval_with_offsets = wcs(*crpix_with_offsets, with_bounding_box=False) coords = SkyCoord( @@ -516,7 +528,9 @@ def compute_fiducial(wcslist: list, bounding_box=None) -> np.ndarray: axes_types = wcslist[0].output_frame.axes_type spatial_axes = np.array(axes_types) == "SPATIAL" spectral_axes = np.array(axes_types) == "SPECTRAL" - footprints = np.hstack([w.footprint(bounding_box=bounding_box).T for w in wcslist]) + footprints = np.hstack( + [w.footprint(bounding_box=bounding_box).T for w in wcslist] + ) spatial_footprint = footprints[spatial_axes] spectral_footprint = footprints[spectral_axes] @@ -715,7 +729,9 @@ def update_s_region_imaging(model, center=True): ### which means we are interested in each pixel's vertice, not its center. ### By using center=True, a difference of 0.5 pixel should be accounted for ### when comparing the world coordinates of the bounding box and the footprint. - footprint = model.meta.wcs.footprint(bbox, center=center, axis_type="spatial").T + footprint = model.meta.wcs.footprint( + bbox, center=center, axis_type="spatial" + ).T # take only imaging footprint footprint = footprint[:2, :] @@ -791,7 +807,6 @@ def reproject(wcs1, wcs2): Returns ------- - _reproject : func Function to compute the transformations. It takes x, y positions in ``wcs1`` and returns x, y positions in ``wcs2``. """ @@ -809,8 +824,7 @@ def _get_forward_transform_func(wcs1): forward_transform = wcs1.forward_transform else: raise TypeError( - "Expected input to be astropy.wcs.WCS or gwcs.WCS " - "object" + "Expected input to be astropy.wcs.WCS or gwcs.WCS " "object" ) return forward_transform @@ -821,12 +835,13 @@ def _get_backward_transform_func(wcs2): backward_transform = wcs2.backward_transform else: raise TypeError( - "Expected input to be astropy.wcs.WCS or gwcs.WCS " - "object" + "Expected input to be astropy.wcs.WCS or gwcs.WCS " "object" ) return backward_transform - def _reproject(x: Union[float, np.ndarray], y: Union[float, np.ndarray]) -> tuple: + def _reproject( + x: Union[float, np.ndarray], y: Union[float, np.ndarray] + ) -> tuple: """ Reprojects the input coordinates from one WCS to another. @@ -856,9 +871,12 @@ def _reproject(x: Union[float, np.ndarray], y: Union[float, np.ndarray]) -> tupl flat_sky = [] for axis in sky: flat_sky.append(axis.flatten()) - det = np.array(_get_backward_transform_func(wcs2)(flat_sky[0], flat_sky[1], 0)) + det = np.array( + _get_backward_transform_func(wcs2)(flat_sky[0], flat_sky[1], 0) + ) det_reshaped = [] for axis in det: det_reshaped.append(axis.reshape(x.shape)) return tuple(det_reshaped) + return _reproject