Skip to content

Commit

Permalink
Add inplace argument to all potential functions (Raster to modifi…
Browse files Browse the repository at this point in the history
…ed `Raster`, same for `Vector`) (#455)
  • Loading branch information
rhugonnet authored Jan 30, 2024
1 parent fe8795c commit a2a66c9
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 26 deletions.
214 changes: 200 additions & 14 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,7 +1343,7 @@ def astype(self, dtype: DTypeLike, inplace: bool = False) -> Raster | None:
:param dtype: Any numpy dtype or string accepted by numpy.astype.
:param inplace: Whether to modify the raster in-place.
:returns: Raster with updated dtype.
:returns: Raster with updated dtype (or None if inplace).
"""
# Check that dtype is supported by rasterio
if not rio.dtypes.check_dtype(dtype):
Expand Down Expand Up @@ -1979,7 +1979,7 @@ def crop(
crop_geom: RasterType | Vector | list[float] | tuple[float, ...],
mode: Literal["match_pixel"] | Literal["match_extent"] = "match_pixel",
*,
inplace: Literal[False] = ...,
inplace: Literal[False] = False,
) -> RasterType:
...

Expand All @@ -1999,7 +1999,7 @@ def crop(
crop_geom: RasterType | Vector | list[float] | tuple[float, ...],
mode: Literal["match_pixel"] | Literal["match_extent"] = "match_pixel",
*,
inplace: bool = ...,
inplace: bool = False,
) -> RasterType | None:
...

Expand All @@ -2025,7 +2025,7 @@ def crop(
will match the extent exactly, adjusting the pixel resolution to fit the extent.
:param inplace: Whether to update the raster in-place.
:returns: A new raster, or None if cropping in-place.
:returns: A new raster (or None if inplace).
"""
assert mode in [
"match_extent",
Expand Down Expand Up @@ -2109,6 +2109,7 @@ def crop(
newraster.tags["AREA_OR_POINT"] = "Area"
return newraster

@overload
def reproject(
self: RasterType,
ref: RasterType | str | None = None,
Expand All @@ -2120,10 +2121,70 @@ def reproject(
dtype: DTypeLike | None = None,
resampling: Resampling | str = Resampling.bilinear,
force_source_nodata: int | float | None = None,
*,
inplace: Literal[False] = False,
silent: bool = False,
n_threads: int = 0,
memory_limit: int = 64,
) -> RasterType:
...

@overload
def reproject(
self: RasterType,
ref: RasterType | str | None = None,
crs: CRS | str | int | None = None,
res: float | abc.Iterable[float] | None = None,
grid_size: tuple[int, int] | None = None,
bounds: dict[str, float] | rio.coords.BoundingBox | None = None,
nodata: int | float | None = None,
dtype: DTypeLike | None = None,
resampling: Resampling | str = Resampling.bilinear,
force_source_nodata: int | float | None = None,
*,
inplace: Literal[True],
silent: bool = False,
n_threads: int = 0,
memory_limit: int = 64,
) -> None:
...

@overload
def reproject(
self: RasterType,
ref: RasterType | str | None = None,
crs: CRS | str | int | None = None,
res: float | abc.Iterable[float] | None = None,
grid_size: tuple[int, int] | None = None,
bounds: dict[str, float] | rio.coords.BoundingBox | None = None,
nodata: int | float | None = None,
dtype: DTypeLike | None = None,
resampling: Resampling | str = Resampling.bilinear,
force_source_nodata: int | float | None = None,
*,
inplace: bool = False,
silent: bool = False,
n_threads: int = 0,
memory_limit: int = 64,
) -> RasterType | None:
...

def reproject(
self: RasterType,
ref: RasterType | str | None = None,
crs: CRS | str | int | None = None,
res: float | abc.Iterable[float] | None = None,
grid_size: tuple[int, int] | None = None,
bounds: dict[str, float] | rio.coords.BoundingBox | None = None,
nodata: int | float | None = None,
dtype: DTypeLike | None = None,
resampling: Resampling | str = Resampling.bilinear,
force_source_nodata: int | float | None = None,
inplace: bool = False,
silent: bool = False,
n_threads: int = 0,
memory_limit: int = 64,
) -> RasterType | None:
"""
Reproject raster to a different geotransform (resolution, bounds) and/or coordinate reference system (CRS).
Expand All @@ -2148,12 +2209,13 @@ def reproject(
:param resampling: A Rasterio resampling method, can be passed as a string.
See https://rasterio.readthedocs.io/en/stable/api/rasterio.enums.html#rasterio.enums.Resampling
for the full list.
:param inplace: Whether to update the raster in-place.
:param force_source_nodata: Force a source nodata value (read from the metadata by default).
:param silent: Whether to print warning statements.
:param n_threads: Number of threads. Defaults to (os.cpu_count() - 1).
:param memory_limit: Memory limit in MB for warp operations. Larger values may perform better.
:returns: Reprojected raster.
:returns: Reprojected raster (or None if inplace).
"""
# --- Sanity checks on inputs and defaults -- #
Expand Down Expand Up @@ -2328,21 +2390,69 @@ def reproject(
assert transform == transformed

# Write results to a new Raster.
r = self.from_array(data, transformed, crs, nodata)
if inplace:
# Order is important here, because calling self.data will use nodata to mask the array properly
self._crs = crs
self._nodata = nodata
self._transform = transform
# A little trick to force the right shape of data in, then update the mask properly through the data setter
self._data = data.squeeze()
self.data = data
return None
else:
return self.from_array(data, transformed, crs, nodata)

return r
@overload
def shift(
self: RasterType,
xoff: float,
yoff: float,
distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced",
*,
inplace: Literal[False] = False,
) -> RasterType:
...

@overload
def shift(
self, xoff: float, yoff: float, distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced"
self: RasterType,
xoff: float,
yoff: float,
distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced",
*,
inplace: Literal[True],
) -> None:
...

@overload
def shift(
self: RasterType,
xoff: float,
yoff: float,
distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced",
*,
inplace: bool = False,
) -> RasterType | None:
...

def shift(
self: RasterType,
xoff: float,
yoff: float,
distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced",
inplace: bool = False,
) -> RasterType | None:
"""
Shift the raster by a (x,y) offset.
Shift a raster by a (x,y) offset.
The shifting only updates the geotransform (no resampling is performed).
:param xoff: Translation x offset.
:param yoff: Translation y offset.
:param distance_unit: Distance unit, either 'georeferenced' (default) or 'pixel'.
:param inplace: Whether to modify the raster in-place.
:returns: Shifted raster (or None if inplace).
"""
if distance_unit not in ["georeferenced", "pixel"]:
raise ValueError("Argument 'distance_unit' should be either 'pixel' or 'georeferenced'.")
Expand All @@ -2355,8 +2465,16 @@ def shift(
xoff *= self.res[0]
yoff *= self.res[1]

# Overwrite transform by shifted transform
self.transform = rio.transform.Affine(dx, b, xmin + xoff, d, dy, ymax + yoff)
shifted_transform = rio.transform.Affine(dx, b, xmin + xoff, d, dy, ymax + yoff)

if inplace:
# Overwrite transform by shifted transform
self.transform = shifted_transform
return None
else:
raster_copy = self.copy()
raster_copy.transform = shifted_transform
return raster_copy

def save(
self,
Expand Down Expand Up @@ -3523,6 +3641,7 @@ def _repr_html_(self) -> str:

return str(s)

@overload
def reproject(
self: Mask,
ref: RasterType | str | None = None,
Expand All @@ -3534,10 +3653,70 @@ def reproject(
dtype: DTypeLike | None = None,
resampling: Resampling | str = Resampling.nearest,
force_source_nodata: int | float | None = None,
*,
inplace: Literal[False] = False,
silent: bool = False,
n_threads: int = 0,
memory_limit: int = 64,
) -> Mask:
...

@overload
def reproject(
self: Mask,
ref: RasterType | str | None = None,
crs: CRS | str | int | None = None,
res: float | abc.Iterable[float] | None = None,
grid_size: tuple[int, int] | None = None,
bounds: dict[str, float] | rio.coords.BoundingBox | None = None,
nodata: int | float | None = None,
dtype: DTypeLike | None = None,
resampling: Resampling | str = Resampling.nearest,
force_source_nodata: int | float | None = None,
*,
inplace: Literal[True],
silent: bool = False,
n_threads: int = 0,
memory_limit: int = 64,
) -> None:
...

@overload
def reproject(
self: Mask,
ref: RasterType | str | None = None,
crs: CRS | str | int | None = None,
res: float | abc.Iterable[float] | None = None,
grid_size: tuple[int, int] | None = None,
bounds: dict[str, float] | rio.coords.BoundingBox | None = None,
nodata: int | float | None = None,
dtype: DTypeLike | None = None,
resampling: Resampling | str = Resampling.nearest,
force_source_nodata: int | float | None = None,
*,
inplace: bool = False,
silent: bool = False,
n_threads: int = 0,
memory_limit: int = 64,
) -> Mask | None:
...

def reproject(
self: Mask,
ref: RasterType | str | None = None,
crs: CRS | str | int | None = None,
res: float | abc.Iterable[float] | None = None,
grid_size: tuple[int, int] | None = None,
bounds: dict[str, float] | rio.coords.BoundingBox | None = None,
nodata: int | float | None = None,
dtype: DTypeLike | None = None,
resampling: Resampling | str = Resampling.nearest,
force_source_nodata: int | float | None = None,
inplace: bool = False,
silent: bool = False,
n_threads: int = 0,
memory_limit: int = 64,
) -> Mask | None:
# Depending on resampling, adjust to rasterio supported types
if resampling in [Resampling.nearest, "nearest"]:
self._data = self.data.astype("uint8") # type: ignore
Expand All @@ -3558,6 +3737,7 @@ def reproject(
nodata=nodata,
dtype=dtype,
resampling=resampling,
inplace=False,
force_source_nodata=force_source_nodata,
silent=silent,
n_threads=n_threads,
Expand All @@ -3567,7 +3747,13 @@ def reproject(
# Transform back to a boolean array
output._data = output.data.astype(bool) # type: ignore

return output
if inplace:
self._transform = output._transform # type: ignore
self._crs = output._crs # type: ignore
self.data = output._data # type: ignore
return None
else:
return output

# Note the star is needed because of the default argument 'mode' preceding non default arg 'inplace'
# Then the final overload must be duplicated
Expand All @@ -3577,7 +3763,7 @@ def crop(
crop_geom: Mask | Vector | list[float] | tuple[float, ...],
mode: Literal["match_pixel"] | Literal["match_extent"] = "match_pixel",
*,
inplace: Literal[False] = ...,
inplace: Literal[False] = False,
) -> Mask:
...

Expand All @@ -3597,7 +3783,7 @@ def crop(
crop_geom: Mask | Vector | list[float] | tuple[float, ...],
mode: Literal["match_pixel"] | Literal["match_extent"] = "match_pixel",
*,
inplace: bool = ...,
inplace: bool = False,
) -> Mask | None:
...

Expand Down
Loading

0 comments on commit a2a66c9

Please sign in to comment.