Skip to content

Commit

Permalink
Incremental commit on accessor
Browse files Browse the repository at this point in the history
  • Loading branch information
rhugonnet committed Nov 15, 2024
1 parent c75aa07 commit f680eaa
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 240 deletions.
205 changes: 192 additions & 13 deletions geoutils/raster/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
import warnings
from typing import Any, Callable, Iterable, Literal, TypeVar, overload

import affine
from affine import Affine
import geopandas as gpd
import numpy as np
import rasterio as rio
import xarray as xr
import xarray.core.indexing
from packaging.version import Version
from rasterio.crs import CRS
from rasterio.enums import Resampling
Expand Down Expand Up @@ -65,7 +64,7 @@ def __init__(self):

# Main attributes of a raster
self._data: MArrayNum | xr.DataArray | None = None
self._transform: affine.Affine | None = None
self._transform: Affine | None = None
self._crs: CRS | None = None
self._nodata: int | float | None = None
self._area_or_point: Literal["Area", "Point"] | None = None
Expand All @@ -80,45 +79,222 @@ def __init__(self):
self._disk_shape: tuple[int, int, int] | None = None
self._disk_bands: tuple[int] | None = None
self._disk_dtype: str | None = None
self._disk_transform: affine.Affine | None = None
self._disk_transform: Affine | None = None
self._out_count: int | None = None
self._out_shape: tuple[int, int] | None = None
self._disk_hash: int | None = None
self._is_modified = True
self._downsample: int | float = 1

@property
def is_xr(self) -> bool:
def _is_xr(self) -> bool:
"""Whether the underlying object is a Xarray Dataset through accessor, or not."""
return self._obj is not None

@property
def data(self) -> MArrayNum | xr.DataArray:
if self.is_xr:
if self._is_xr:
return self._obj.data
else:
return self._data

@property
def transform(self) -> affine.Affine:
if self.is_xr:
def transform(self) -> Affine:
"""
Geotransform of the raster.
:returns: Affine matrix geotransform.
"""
if self._is_xr:
return self._obj.rio.transform(recalc=True)
else:
return self._transform

@transform.setter
def transform(self, new_transform: tuple[float, ...] | Affine | None) -> None:
    """
    Set the geotransform of the raster.

    Thin wrapper so that ``raster.transform = value`` goes through set_transform().

    :param new_transform: New geotransform, forwarded unchanged to set_transform().
    """
    self.set_transform(new_transform)

def set_transform(self, new_transform: tuple[float, ...] | Affine | None) -> None:
    """
    Set the geotransform of the raster.

    :param new_transform: New geotransform, either an Affine object or a tuple of coefficients that is
        converted with ``Affine(*new_transform)``. None is forwarded without conversion.

    :raises TypeError: If the input is neither an Affine, a tuple, nor None.
    """
    # Only inputs that are neither None nor already an Affine need converting. The two tests must be
    # combined with "and": with "or", every Affine instance entered this branch and hit the TypeError.
    if new_transform is not None and not isinstance(new_transform, Affine):
        if isinstance(new_transform, tuple):
            new_transform = Affine(*new_transform)
        else:
            raise TypeError("The transform argument needs to be Affine or tuple.")

    if self._is_xr:
        # write_transform() is not in-place by default and returns the updated object; ask for an
        # in-place write so the wrapped object is actually modified instead of discarding the result.
        self._obj.rio.write_transform(new_transform, inplace=True)
    else:
        self._transform = new_transform

@property
def crs(self) -> CRS:
if self.is_xr:
"""
Coordinate reference system of the raster.
:returns: Pyproj coordinate reference system.
"""
if self._is_xr:
return self._obj.rio.crs
else:
return self._crs

@crs.setter
def crs(self, new_crs: CRS | int | str | None) -> None:
    """
    Set the coordinate reference system of the raster.

    Thin wrapper so that ``raster.crs = value`` goes through set_crs().

    :param new_crs: New coordinate reference system, forwarded unchanged to set_crs().
    """
    self.set_crs(new_crs=new_crs)

def set_crs(self, new_crs: CRS) -> None:
    """
    Assign a coordinate reference system to the raster.

    Any input other than None is first converted with CRS.from_user_input(); None is passed on as is.
    The result is handed to the ``.rio`` accessor of the wrapped object when there is one, and stored
    on the instance otherwise.

    :param new_crs: New coordinate reference system.
    """
    parsed_crs = None if new_crs is None else CRS.from_user_input(value=new_crs)

    # Plain raster: keep the CRS on the instance itself
    if not self._is_xr:
        self._crs = parsed_crs
        return

    # Accessor case: delegate to the wrapped object
    self._obj.rio.set_crs(parsed_crs)

@property
def nodata(self) -> int | float | None:
if self.is_xr:
"""
Nodata value of the raster.
When setting with self.nodata = new_nodata, uses the default arguments of self.set_nodata().
:returns: Nodata value.
"""
if self._is_xr:
return self._obj.rio.nodata
else:
return self._nodata

@nodata.setter
def nodata(self, new_nodata: int | float | None) -> None:
    """
    Set the nodata value by calling set_nodata() with its default arguments.

    With those defaults (update_array and update_mask both True) and a masked-array raster, cells
    holding the old nodata value are rewritten to the new value in the data array, and the mask is
    extended to every cell equal to the new nodata value; cells masked for the old nodata stay masked.

    For more complex cases (e.g., redefining a wrong nodata that has a valid value in the array),
    call set_nodata() directly and choose update_array and update_mask explicitly.

    :param new_nodata: New nodata value to assign to this raster.
    """
    self.set_nodata(new_nodata)

def set_nodata(
    self,
    new_nodata: int | float | None,
    update_array: bool = True,
    update_mask: bool = True,
) -> None:
    """
    Set a new nodata value for all bands.

    The nodata value stored in the metadata is replaced. For a raster holding a masked array, the
    data and mask are additionally updated depending on the two flags:

    - ``update_array=True``: cells equal to the old nodata value are rewritten to the new value
      (skipped when the new value is None).
    - ``update_mask=True``: cells equal to the new nodata value are masked. Cells of the old nodata
      are unmasked first, unless they were just rewritten to the new value (in which case they
      remain masked).

    Careful! If the new nodata value already exists in the array, those cells get masked by default,
    and a UserWarning is emitted.

    Typical uses:

    - Nodata was never defined: pass the value that plays this role in the array, it gets masked.
    - Nodata was correctly defined and must change: pass the new value, cells having either the old
      or the new value end up masked.
    - Nodata was wrongly defined: pass ``update_array=False`` so data equal to the old value is left
      untouched, only the new value is masked.
    - Metadata-only change: pass ``update_mask=False`` to leave the mask alone.

    If None is passed, only the metadata is updated and the mask of the old nodata is unset.

    When the raster wraps an Xarray object, the new value is only forwarded to its ``.rio`` accessor;
    ``update_array`` and ``update_mask`` are not used on that path.

    :param new_nodata: New nodata value.
    :param update_array: Update the old nodata values into new nodata values in the data array.
    :param update_mask: Update the old mask by unmasking old nodata and masking new nodata (if array is
        updated, old nodata are changed to new nodata and thus stay masked).

    :raises ValueError: If the nodata value is not numeric, or cannot be cast to the raster dtype.
    """
    # Validate the input (None is always accepted)
    if new_nodata is not None:
        if not isinstance(new_nodata, (int, float, np.integer, np.floating)):
            raise ValueError("Type of nodata not understood, must be float or int.")
        if not rio.dtypes.can_cast_dtype(new_nodata, self.dtype):
            raise ValueError(f"Nodata value {new_nodata} incompatible with self.dtype {self.dtype}.")

    # Accessor case: delegate entirely to the wrapped object
    if self._is_xr:
        self._obj.rio.set_nodata(new_nodata)
        return

    if update_array or update_mask:

        # Fetch the masked array once rather than going through the property repeatedly
        arr = self.data

        # Locate cells of the old and new nodata values before anything is modified
        was_old_nodata = arr.data == self.nodata
        is_new_nodata = arr.data == new_nodata

        # Warn when the new value is already present in the array; at least one flag is True here
        if np.any(is_new_nodata):
            if update_array and update_mask:
                msg = (
                    "New nodata value cells already exist in the data array. These cells will now be "
                    "masked, and the old nodata value cells will update to the same new value. "
                    "Use set_nodata() with update_array=False or update_mask=False to change "
                    "this behaviour."
                )
            elif update_array:
                msg = (
                    "New nodata value cells already exist in the data array. The old nodata cells will update to "
                    "the same new value. Use set_nodata() with update_array=False to change this behaviour."
                )
            else:
                msg = (
                    "New nodata value cells already exist in the data array. These cells will now be masked. "
                    "Use set_nodata() with update_mask=False to change this behaviour."
                )
            warnings.warn(msg, category=UserWarning)

        # Rewrite old nodata cells to the new value, only possible if the latter is defined
        nodata_transferred = update_array and new_nodata is not None
        if nodata_transferred:
            arr.data[was_old_nodata] = new_nodata

        if update_mask:
            # When old nodata cells were transferred to the new value their mask stays valid; otherwise
            # they must be unmasked, which requires editing the mask directly
            if np.ma.is_masked(arr) and not nodata_transferred:
                arr.mask[was_old_nodata] = False

            # Assigning np.ma.masked works whether or not a mask already exists
            arr[is_new_nodata] = np.ma.masked

        # Store the modified array
        self._data = arr

    # Record the new nodata value
    self._nodata = new_nodata

    # The fill value can only be synchronized when the data is loaded
    if self.is_loaded:
        self.data.fill_value = new_nodata

def set_area_or_point(
self, new_area_or_point: Literal["Area", "Point"] | None, shift_area_or_point: bool | None = None
) -> None:
Expand Down Expand Up @@ -223,7 +399,7 @@ def area_or_point(self, new_area_or_point: Literal["Area", "Point"] | None) -> N
@property
def is_loaded(self) -> bool:
"""Whether the raster array is loaded."""
if self.is_xr:
if self._is_xr:
# TODO: Activating this requires to have _disk_shape defined for RasterAccessor
return True
# return isinstance(self._obj.variable._data, np.ndarray)
Expand Down Expand Up @@ -829,11 +1005,14 @@ def translate(

if inplace:
# Overwrite transform by translated transform
self.transform = translated_transform
self.set_transform(translated_transform)
return None
else:
raster_copy = self.copy()
raster_copy.transform = translated_transform
if self._is_xr:
raster_copy.rst.set_transform(translated_transform)
else:
raster_copy.set_transform(translated_transform)
return raster_copy

def reduce_points(
Expand Down
57 changes: 33 additions & 24 deletions geoutils/raster/geotransformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,6 @@ def _is_reproj_needed(src_shape: tuple[int, int], reproj_kwargs: dict[str, Any])
]
)


def _reproject(
source_raster: gu.Raster,
ref: gu.Raster,
Expand Down Expand Up @@ -381,40 +380,50 @@ def _reproject(
reproj_kwargs.update({"num_threads": num_threads, "warp_mem_limit": memory_limit})

# --- Run the reprojection of data --- #
# If data is loaded, reproject the numpy array directly
if source_raster.is_loaded:
# All masked values must be set to a nodata value for rasterio's reproject to work properly
# TODO: another option is to apply rio.warp.reproject to the mask to identify invalid pixels
if src_nodata is None and np.sum(source_raster.data.mask) > 0:
raise ValueError(
"No nodata set, set one for the raster with self.set_nodata() or use a temporary one "
"with `force_source_nodata`."
)

# Mask not taken into account by rasterio, need to fill with src_nodata
data, transformed = rio.warp.reproject(source_raster.data.filled(src_nodata), **reproj_kwargs)
if source_raster._is_xr:

src_data = source_raster.data
src_data[np.isnan(src_data)] = src_nodata
data, transformed = rio.warp.reproject(source_raster.data, **reproj_kwargs)
data[data == nodata] = np.nan

# If not, uses the dataset instead
else:
data = [] # type: ignore
for k in range(source_raster.count):
with rio.open(source_raster.filename) as ds:
band = rio.band(ds, k + 1)
band, transformed = rio.warp.reproject(band, **reproj_kwargs)
data.append(band.squeeze())
# If data is loaded, reproject the numpy array directly
if source_raster.is_loaded:
# All masked values must be set to a nodata value for rasterio's reproject to work properly
# TODO: another option is to apply rio.warp.reproject to the mask to identify invalid pixels
if src_nodata is None and np.sum(source_raster.data.mask) > 0:
raise ValueError(
"No nodata set, set one for the raster with self.set_nodata() or use a temporary one "
"with `force_source_nodata`."
)

data = np.array(data)
# Mask not taken into account by rasterio, need to fill with src_nodata
data, transformed = rio.warp.reproject(source_raster.data.filled(src_nodata), **reproj_kwargs)

# Enforce output type
data = np.ma.masked_array(data.astype(dtype), fill_value=nodata)
# If not, uses the dataset instead
else:
data = [] # type: ignore
for k in range(source_raster.count):
with rio.open(source_raster.filename) as ds:
band = rio.band(ds, k + 1)
band, transformed = rio.warp.reproject(band, **reproj_kwargs)
data.append(band.squeeze())

data = np.array(data)

if nodata is not None:
data.mask = data == nodata
# Enforce output type
data = np.ma.masked_array(data.astype(dtype), fill_value=nodata)

if nodata is not None:
data.mask = data == nodata

# Check for funny business.
if reproj_kwargs["dst_transform"] is not None:
assert reproj_kwargs["dst_transform"] == transformed


return False, data, transformed, crs, nodata


Expand Down
Loading

0 comments on commit f680eaa

Please sign in to comment.