From a2a66c95c60ea868fe29d2d06a9167cce8fe4ed9 Mon Sep 17 00:00:00 2001 From: Romain Hugonnet Date: Tue, 30 Jan 2024 03:28:42 -0900 Subject: [PATCH] Add `inplace` argument to all potential functions (`Raster` to modified `Raster`, same for `Vector`) (#455) --- geoutils/raster/raster.py | 214 +++++++++++++++++++++++++++++++++++--- geoutils/vector.py | 66 ++++++++++-- tests/test_raster.py | 31 +++++- tests/test_vector.py | 5 + 4 files changed, 290 insertions(+), 26 deletions(-) diff --git a/geoutils/raster/raster.py b/geoutils/raster/raster.py index 3b2453ca..3d93f680 100644 --- a/geoutils/raster/raster.py +++ b/geoutils/raster/raster.py @@ -1343,7 +1343,7 @@ def astype(self, dtype: DTypeLike, inplace: bool = False) -> Raster | None: :param dtype: Any numpy dtype or string accepted by numpy.astype. :param inplace: Whether to modify the raster in-place. - :returns: Raster with updated dtype. + :returns: Raster with updated dtype (or None if inplace). """ # Check that dtype is supported by rasterio if not rio.dtypes.check_dtype(dtype): @@ -1979,7 +1979,7 @@ def crop( crop_geom: RasterType | Vector | list[float] | tuple[float, ...], mode: Literal["match_pixel"] | Literal["match_extent"] = "match_pixel", *, - inplace: Literal[False] = ..., + inplace: Literal[False] = False, ) -> RasterType: ... @@ -1999,7 +1999,7 @@ def crop( crop_geom: RasterType | Vector | list[float] | tuple[float, ...], mode: Literal["match_pixel"] | Literal["match_extent"] = "match_pixel", *, - inplace: bool = ..., + inplace: bool = False, ) -> RasterType | None: ... @@ -2025,7 +2025,7 @@ def crop( will match the extent exactly, adjusting the pixel resolution to fit the extent. :param inplace: Whether to update the raster in-place. - :returns: A new raster, or None if cropping in-place. + :returns: A new raster (or None if inplace). """ assert mode in [ "match_extent", @@ -2109,6 +2109,7 @@ def crop( newraster.tags["AREA_OR_POINT"] = "Area" return newraster + @overload def reproject( self: RasterType, ref: RasterType | str | None = None, @@ -2120,10 +2121,70 @@ def reproject( dtype: DTypeLike | None = None, resampling: Resampling | str = Resampling.bilinear, force_source_nodata: int | float | None = None, + *, + inplace: Literal[False] = False, silent: bool = False, n_threads: int = 0, memory_limit: int = 64, ) -> RasterType: + ... + + @overload + def reproject( + self: RasterType, + ref: RasterType | str | None = None, + crs: CRS | str | int | None = None, + res: float | abc.Iterable[float] | None = None, + grid_size: tuple[int, int] | None = None, + bounds: dict[str, float] | rio.coords.BoundingBox | None = None, + nodata: int | float | None = None, + dtype: DTypeLike | None = None, + resampling: Resampling | str = Resampling.bilinear, + force_source_nodata: int | float | None = None, + *, + inplace: Literal[True], + silent: bool = False, + n_threads: int = 0, + memory_limit: int = 64, + ) -> None: + ... + + @overload + def reproject( + self: RasterType, + ref: RasterType | str | None = None, + crs: CRS | str | int | None = None, + res: float | abc.Iterable[float] | None = None, + grid_size: tuple[int, int] | None = None, + bounds: dict[str, float] | rio.coords.BoundingBox | None = None, + nodata: int | float | None = None, + dtype: DTypeLike | None = None, + resampling: Resampling | str = Resampling.bilinear, + force_source_nodata: int | float | None = None, + *, + inplace: bool = False, + silent: bool = False, + n_threads: int = 0, + memory_limit: int = 64, + ) -> RasterType | None: + ... + + def reproject( + self: RasterType, + ref: RasterType | str | None = None, + crs: CRS | str | int | None = None, + res: float | abc.Iterable[float] | None = None, + grid_size: tuple[int, int] | None = None, + bounds: dict[str, float] | rio.coords.BoundingBox | None = None, + nodata: int | float | None = None, + dtype: DTypeLike | None = None, + resampling: Resampling | str = Resampling.bilinear, + force_source_nodata: int | float | None = None, + inplace: bool = False, + silent: bool = False, + n_threads: int = 0, + memory_limit: int = 64, + ) -> RasterType | None: """ Reproject raster to a different geotransform (resolution, bounds) and/or coordinate reference system (CRS). @@ -2148,12 +2209,13 @@ def reproject( :param resampling: A Rasterio resampling method, can be passed as a string. See https://rasterio.readthedocs.io/en/stable/api/rasterio.enums.html#rasterio.enums.Resampling for the full list. + :param inplace: Whether to update the raster in-place. :param force_source_nodata: Force a source nodata value (read from the metadata by default). :param silent: Whether to print warning statements. :param n_threads: Number of threads. Defaults to (os.cpu_count() - 1). :param memory_limit: Memory limit in MB for warp operations. Larger values may perform better. - :returns: Reprojected raster. + :returns: Reprojected raster (or None if inplace). """ # --- Sanity checks on inputs and defaults -- # @@ -2328,21 +2390,69 @@ def reproject( assert transform == transformed # Write results to a new Raster. - r = self.from_array(data, transformed, crs, nodata) + if inplace: + # Order is important here, because calling self.data will use nodata to mask the array properly + self._crs = crs + self._nodata = nodata + self._transform = transform + # A little trick to force the right shape of data in, then update the mask properly through the data setter + self._data = data.squeeze() + self.data = data + return None + else: + return self.from_array(data, transformed, crs, nodata) - return r + @overload + def shift( + self: RasterType, + xoff: float, + yoff: float, + distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced", + *, + inplace: Literal[False] = False, + ) -> RasterType: + ... + @overload def shift( - self, xoff: float, yoff: float, distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced" + self: RasterType, + xoff: float, + yoff: float, + distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced", + *, + inplace: Literal[True], ) -> None: + ... + + @overload + def shift( + self: RasterType, + xoff: float, + yoff: float, + distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced", + *, + inplace: bool = False, + ) -> RasterType | None: + ... + + def shift( + self: RasterType, + xoff: float, + yoff: float, + distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced", + inplace: bool = False, + ) -> RasterType | None: """ - Shift the raster by a (x,y) offset. + Shift a raster by a (x,y) offset. The shifting only updates the geotransform (no resampling is performed). :param xoff: Translation x offset. :param yoff: Translation y offset. :param distance_unit: Distance unit, either 'georeferenced' (default) or 'pixel'. + :param inplace: Whether to modify the raster in-place. + + :returns: Shifted raster (or None if inplace). """ if distance_unit not in ["georeferenced", "pixel"]: raise ValueError("Argument 'distance_unit' should be either 'pixel' or 'georeferenced'.") @@ -2355,8 +2465,16 @@ def shift( xoff *= self.res[0] yoff *= self.res[1] - # Overwrite transform by shifted transform - self.transform = rio.transform.Affine(dx, b, xmin + xoff, d, dy, ymax + yoff) + shifted_transform = rio.transform.Affine(dx, b, xmin + xoff, d, dy, ymax + yoff) + + if inplace: + # Overwrite transform by shifted transform + self.transform = shifted_transform + return None + else: + raster_copy = self.copy() + raster_copy.transform = shifted_transform + return raster_copy def save( self, @@ -3523,6 +3641,7 @@ def _repr_html_(self) -> str: return str(s) + @overload def reproject( self: Mask, ref: RasterType | str | None = None, @@ -3534,10 +3653,70 @@ def reproject( dtype: DTypeLike | None = None, resampling: Resampling | str = Resampling.nearest, force_source_nodata: int | float | None = None, + *, + inplace: Literal[False] = False, silent: bool = False, n_threads: int = 0, memory_limit: int = 64, ) -> Mask: + ... + + @overload + def reproject( + self: Mask, + ref: RasterType | str | None = None, + crs: CRS | str | int | None = None, + res: float | abc.Iterable[float] | None = None, + grid_size: tuple[int, int] | None = None, + bounds: dict[str, float] | rio.coords.BoundingBox | None = None, + nodata: int | float | None = None, + dtype: DTypeLike | None = None, + resampling: Resampling | str = Resampling.nearest, + force_source_nodata: int | float | None = None, + *, + inplace: Literal[True], + silent: bool = False, + n_threads: int = 0, + memory_limit: int = 64, + ) -> None: + ... + + @overload + def reproject( + self: Mask, + ref: RasterType | str | None = None, + crs: CRS | str | int | None = None, + res: float | abc.Iterable[float] | None = None, + grid_size: tuple[int, int] | None = None, + bounds: dict[str, float] | rio.coords.BoundingBox | None = None, + nodata: int | float | None = None, + dtype: DTypeLike | None = None, + resampling: Resampling | str = Resampling.nearest, + force_source_nodata: int | float | None = None, + *, + inplace: bool = False, + silent: bool = False, + n_threads: int = 0, + memory_limit: int = 64, + ) -> Mask | None: + ... + + def reproject( + self: Mask, + ref: RasterType | str | None = None, + crs: CRS | str | int | None = None, + res: float | abc.Iterable[float] | None = None, + grid_size: tuple[int, int] | None = None, + bounds: dict[str, float] | rio.coords.BoundingBox | None = None, + nodata: int | float | None = None, + dtype: DTypeLike | None = None, + resampling: Resampling | str = Resampling.nearest, + force_source_nodata: int | float | None = None, + inplace: bool = False, + silent: bool = False, + n_threads: int = 0, + memory_limit: int = 64, + ) -> Mask | None: # Depending on resampling, adjust to rasterio supported types if resampling in [Resampling.nearest, "nearest"]: self._data = self.data.astype("uint8") # type: ignore @@ -3558,6 +3737,7 @@ def reproject( nodata=nodata, dtype=dtype, resampling=resampling, + inplace=False, force_source_nodata=force_source_nodata, silent=silent, n_threads=n_threads, @@ -3567,7 +3747,13 @@ def reproject( # Transform back to a boolean array output._data = output.data.astype(bool) # type: ignore - return output + if inplace: + self._transform = output._transform # type: ignore + self._crs = output._crs # type: ignore + self.data = output._data # type: ignore + return None + else: + return output # Note the star is needed because of the default argument 'mode' preceding non default arg 'inplace' # Then the final overload must be duplicated @@ -3577,7 +3763,7 @@ def crop( crop_geom: Mask | Vector | list[float] | tuple[float, ...], mode: Literal["match_pixel"] | Literal["match_extent"] = "match_pixel", *, - inplace: Literal[False] = ..., + inplace: Literal[False] = False, ) -> Mask: ... @@ -3597,7 +3783,7 @@ def crop( crop_geom: Mask | Vector | list[float] | tuple[float, ...], mode: Literal["match_pixel"] | Literal["match_extent"] = "match_pixel", *, - inplace: bool = ..., + inplace: bool = False, ) -> Mask | None: ... diff --git a/geoutils/vector.py b/geoutils/vector.py index a4a331fd..67e3547c 100644 --- a/geoutils/vector.py +++ b/geoutils/vector.py @@ -934,10 +934,20 @@ def ds(self, new_ds: gpd.GeoDataFrame | gpd.GeoSeries) -> None: else: raise ValueError("The dataset of a vector must be set with a GeoSeries or a GeoDataFrame.") - def vector_equal(self, other: gu.Vector) -> bool: - """Check if two vectors are equal.""" + def vector_equal(self, other: gu.Vector, **kwargs: Any) -> bool: + """ + Check if two vectors are equal. + + Keyword arguments are passed to geopandas.assert_geodataframe_equal. + """ - return assert_geodataframe_equal(self.ds, other.ds) + try: + assert_geodataframe_equal(self.ds, other.ds, **kwargs) + vector_eq = True + except AssertionError: + vector_eq = False + + return vector_eq @property def name(self) -> str | None: @@ -965,7 +975,7 @@ def crop( crop_geom: gu.Raster | Vector | list[float] | tuple[float, ...], clip: bool, *, - inplace: Literal[False] = ..., + inplace: Literal[False] = False, ) -> VectorType: ... @@ -985,7 +995,7 @@ def crop( crop_geom: gu.Raster | Vector | list[float] | tuple[float, ...], clip: bool, *, - inplace: bool = ..., + inplace: bool = False, ) -> VectorType | None: ... @@ -1009,7 +1019,9 @@ def crop( coordinates. If ``crop_geom`` is a raster or a vector, will crop to the bounds. If ``crop_geom`` is a list of coordinates, the order is assumed to be [xmin, ymin, xmax, ymax]. :param clip: Whether to clip the geometry to the given extent (by default keeps all intersecting). - :param inplace: Update the vector in-place or return copy. + :param inplace: Whether to update the vector in-place. + + :returns: Cropped vector (or None if inplace). """ if isinstance(crop_geom, (gu.Raster, Vector)): # For another Vector or Raster, we reproject the bounding box in the same CRS as self @@ -1033,11 +1045,42 @@ def crop( new_vector._ds = new_vector.ds.clip(mask=(xmin, ymin, xmax, ymax)) return new_vector + @overload def reproject( self: Vector, ref: gu.Raster | rio.io.DatasetReader | VectorType | gpd.GeoDataFrame | str | None = None, crs: CRS | str | int | None = None, + *, + inplace: Literal[False] = False, ) -> Vector: + ... + + @overload + def reproject( + self: Vector, + ref: gu.Raster | rio.io.DatasetReader | VectorType | gpd.GeoDataFrame | str | None = None, + crs: CRS | str | int | None = None, + *, + inplace: Literal[True], + ) -> None: + ... + + @overload + def reproject( + self: Vector, + ref: gu.Raster | rio.io.DatasetReader | VectorType | gpd.GeoDataFrame | str | None = None, + crs: CRS | str | int | None = None, + *, + inplace: bool = False, + ) -> Vector | None: + ... + + def reproject( + self: Vector, + ref: gu.Raster | rio.io.DatasetReader | VectorType | gpd.GeoDataFrame | str | None = None, + crs: CRS | str | int | None = None, + inplace: bool = False, + ) -> Vector | None: """ Reproject vector to a specified coordinate reference system. @@ -1051,8 +1094,9 @@ def reproject( Can be provided as a raster, vector, Rasterio dataset, GeoPandas dataframe, or path to the file. :param crs: Specify the Coordinate Reference System or EPSG to reproject to. If dst_ref not set, defaults to self.crs. + :param inplace: Whether to update the vector in-place. - :returns: Reprojected vector. + :returns: Reprojected vector (or None if inplace). """ # Check that either ref or crs is provided @@ -1086,7 +1130,13 @@ def reproject( # Determine user-input target CRS crs = CRS.from_user_input(crs) - return Vector(self.ds.to_crs(crs=crs)) + new_ds = self.ds.to_crs(crs=crs) + + if inplace: + self.ds = new_ds + return None + else: + return Vector(new_ds) @overload def create_mask( diff --git a/tests/test_raster.py b/tests/test_raster.py index 10b31dd5..859cd1bd 100644 --- a/tests/test_raster.py +++ b/tests/test_raster.py @@ -985,7 +985,7 @@ def test_getitem_setitem(self, example: str) -> None: rst[arr[:-1, :-1]] # An error when the georeferencing of the Mask does not match - mask.shift(1, 1) + mask.shift(1, 1, inplace=True) with pytest.raises(ValueError, match=re.escape(message_raster.format(op_name_index))): rst[mask] @@ -1215,7 +1215,14 @@ def test_shift(self, example: str) -> None: orig_bounds = r.bounds # Shift raster by georeferenced units (default) - r.shift(xoff=1, yoff=1) + # Check the default behaviour is not inplace + r_notinplace = r.shift(xoff=1, yoff=1) + assert isinstance(r_notinplace, gu.Raster) + + # Check inplace + r.shift(xoff=1, yoff=1, inplace=True) + # Both shifts should have yielded the same transform + assert r.transform == r_notinplace.transform # Only bounds should change assert orig_transform.c + 1 == r.transform.c @@ -1232,7 +1239,7 @@ def test_shift(self, example: str) -> None: orig_transform = r.transform orig_bounds = r.bounds orig_res = r.res - r.shift(xoff=1, yoff=1, distance_unit="pixel") + r.shift(xoff=1, yoff=1, distance_unit="pixel", inplace=True) # Only bounds should change assert orig_transform.c + 1 * orig_res[0] == r.transform.c @@ -1502,7 +1509,7 @@ def test_reproject(self, example: str) -> None: plt.show() - # - Check that if mask is modified afterwards, it is taken into account during reproject - # + # -- Check that if mask is modified afterwards, it is taken into account during reproject -- # # Create a raster with (additional) random gaps r_gaps = r.copy() nsamples = 200 @@ -1529,6 +1536,22 @@ def test_reproject(self, example: str) -> None: r3 = r_nodata.reproject(r2) assert r_nodata.nodata == r3.nodata + # -- Check inplace behaviour works -- # + + # Check when transform is updated (via res) + r_tmp_res = r.copy() + r_res = r_tmp_res.reproject(res=r.res[0] / 2) + r_tmp_res.reproject(res=r.res[0] / 2, inplace=True) + + assert r_res.raster_equal(r_tmp_res) + + # Check when CRS is updated + r_tmp_crs = r.copy() + r_crs = r_tmp_crs.reproject(crs=out_crs) + r_tmp_crs.reproject(crs=out_crs, inplace=True) + + assert r_crs.raster_equal(r_tmp_crs) + # -- Test additional errors raised for argument combinations -- # # If both ref and crs are set diff --git a/tests/test_vector.py b/tests/test_vector.py index 79ce97d3..51b18ab9 100644 --- a/tests/test_vector.py +++ b/tests/test_vector.py @@ -118,6 +118,11 @@ def test_reproject(self) -> None: assert isinstance(v1, gu.Vector) assert v1.crs.to_epsg() == 32617 + # Check the inplace behaviour matches the not-inplace one + v2 = v0.copy() + v2.reproject(crs=32617, inplace=True) + v2.vector_equal(v1) + # Check that the reprojection is the same as with geopandas gpd1 = v0.ds.to_crs(epsg=32617) assert_geodataframe_equal(gpd1, v1.ds)