Skip to content

Commit

Permalink
Improve error messages and make Raster indexing/assigment NumPy com…
Browse files Browse the repository at this point in the history
…patible (#454)
  • Loading branch information
rhugonnet authored Jan 30, 2024
1 parent 5c34af7 commit fe8795c
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 98 deletions.
122 changes: 79 additions & 43 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,51 @@ def _get_reproject_params(
return dst_transform, dst_size


def _check_cast_array_raster(
input1: RasterType | NDArrayNum, input2: RasterType | NDArrayNum, operation_name: str
) -> None:
"""
Check the casting between an array and a raster, or raise an (helpful) error message.
:param input1: Raster or array.
:param input2: Raster or array.
:param operation_name: Name of operation to raise in the error message.
:return: None.
"""

if isinstance(input1, Raster) and isinstance(input2, Raster):

# Check that both rasters have the same shape and georeferences
if input1.georeferenced_grid_equal(input2):
pass
else:
raise ValueError(
"Both rasters must have the same shape, transform and CRS for " + operation_name + ". "
"For example, use raster1 = raster1.reproject(raster2) to reproject raster1 on the "
"same grid and CRS than raster2."
)

else:

# The shape compatibility should be valid even when squeezing
if isinstance(input1, np.ndarray):
input1 = input1.squeeze()
elif isinstance(input2, np.ndarray):
input2 = input2.squeeze()

if input1.shape == input2.shape:
pass
else:
raise ValueError(
"The raster and array must have the same shape for " + operation_name + ". "
"For example, if the array comes from another raster, use raster1 = "
"raster1.reproject(raster2) beforehand to reproject raster1 on the same grid and CRS "
"than raster2. Or, if the array does not come from a raster, define one with raster = "
"Raster.from_array(array, array_transform, array_crs, array_nodata) then reproject."
)


class Raster:
"""
The georeferenced raster.
Expand Down Expand Up @@ -860,27 +905,25 @@ def __str__(self) -> str:

return str(s)

def __getitem__(self, index: Raster | Vector | NDArrayNum | list[float] | tuple[float, ...]) -> NDArrayNum | Raster:
def __getitem__(self, index: Mask | NDArrayBool | Any) -> NDArrayBool | Raster:
"""
Index or subset the raster.
Index the raster.
Two cases:
- If a mask of same georeferencing or array of same shape is passed, return the indexed raster array.
- If a raster, vector, list or tuple of bounds is passed, return the cropped raster matching those objects.
In addition to all index types supported by NumPy, also supports a mask of same georeferencing or a
boolean array of the same shape as the raster.
"""

if isinstance(index, (Mask, np.ndarray)):
_check_cast_array_raster(self, index, operation_name="an indexing operation") # type: ignore

# If input is Mask with the same shape and georeferencing
if isinstance(index, Mask):
if not self.georeferenced_grid_equal(index):
raise ValueError("Indexing a raster with a mask requires the two being on the same georeferenced grid.")
if self.count == 1:
return self.data[index.data.squeeze()]
else:
return self.data[:, index.data.squeeze()]
# If input is array with the same shape
elif isinstance(index, np.ndarray):
if np.shape(index) != self.shape:
raise ValueError("Indexing a raster with an array requires the two having the same shape.")
if str(index.dtype) != "bool":
index = index.astype(bool)
warnings.warn(message="Input array was cast to boolean for indexing.", category=UserWarning)
Expand All @@ -889,49 +932,47 @@ def __getitem__(self, index: Raster | Vector | NDArrayNum | list[float] | tuple[
else:
return self.data[:, index]

# Otherwise, subset with crop
# Otherwise, use any other possible index and leave it to NumPy
else:
return self.crop(crop_geom=index)
return self.data[index]

def __setitem__(self, index: Mask | NDArrayBool, assign: NDArrayNum | Number) -> None:
def __setitem__(self, index: Mask | NDArrayBool | Any, assign: NDArrayNum | Number) -> None:
"""
Perform index assignment on the raster.
If a mask of same georeferencing or array of same shape is passed,
it is used as index to assign values to the raster array.
In addition to all index types supported by NumPy, also supports a mask of same georeferencing or a
boolean array of the same shape as the raster.
"""

# First, check index
if isinstance(index, (Mask, np.ndarray)):
_check_cast_array_raster(self, index, operation_name="an index assignment operation") # type: ignore

# If input is Mask with the same shape and georeferencing
if isinstance(index, Mask):
if not self.georeferenced_grid_equal(index):
raise ValueError("Indexing a raster with a mask requires the two being on the same georeferenced grid.")

ind = index.data.data
use_all_bands = False
# If input is array with the same shape
elif isinstance(index, np.ndarray):
if np.shape(index) != self.shape:
raise ValueError("Indexing a raster with an array requires the two having the same shape.")
if str(index.dtype) != "bool":
ind = index.astype(bool)
warnings.warn(message="Input array was cast to boolean for indexing.", category=UserWarning)
else:
ind = index
# Otherwise, raise an error
use_all_bands = False
# Otherwise, use the index, NumPy will raise appropriate errors itself
else:
raise ValueError(
"Indexing a raster requires a mask of same georeferenced grid, or a boolean array of same shape."
)
ind = index
use_all_bands = True

# Second, assign, NumPy will raise appropriate errors itself
# Second, assign the data, here also let NumPy do the job

# We need to explicitly load here, as we cannot call the data getter/setter directly
if not self.is_loaded:
self.load()
# Assign the values to the index
if self.count == 1:

# Assign the values to the index (single band raster with mask/array, or other NumPy index)
if self.count == 1 or use_all_bands:
self._data[ind] = assign # type: ignore
# For multi-band rasters with a mask/array
else:
self._data[:, ind] = assign # type: ignore
return None
Expand Down Expand Up @@ -991,18 +1032,12 @@ def _overloading_check(
nodata1 = self.nodata
dtype1 = self.data.dtype

# Raise error messages if grids don't match (CRS + transform for raster, shape for array)
if isinstance(other, (Raster, np.ndarray)):
_check_cast_array_raster(self, other, operation_name="an arithmetic operation") # type: ignore

# Case 1 - other is a Raster
if isinstance(other, Raster):
# Not necessary anymore with implicit loading
# # Check that both data are loaded
# if not (self.is_loaded & other.is_loaded):
# raise ValueError("Raster's data must be loaded with self.load().")

# Check that both rasters have the same shape and georeferences
if (self.data.shape == other.data.shape) & (self.transform == other.transform) & (self.crs == other.crs):
pass
else:
raise ValueError("Both rasters must have the same shape, transform and CRS.")

nodata2 = other.nodata
dtype2 = other.data.dtype
Expand All @@ -1018,11 +1053,6 @@ def _overloading_check(
else:
other_data = other

if self.data.shape == other_data.shape:
pass
else:
raise ValueError("The raster and array must have the same shape.")

nodata2 = None
dtype2 = other.dtype

Expand Down Expand Up @@ -1851,6 +1881,10 @@ def __array_ufunc__(

# If the universal function takes two inputs (Note: no ufunc exists that has three inputs or more)
else:

# Check the casting between Raster and array inputs, and return error messages if not consistent
_check_cast_array_raster(inputs[0], inputs[1], "an arithmetic operation") # type: ignore

if ufunc.nout == 1:
return self.from_array(
data=final_ufunc(inputs[0].data, inputs[1].data, **kwargs), # type: ignore
Expand Down Expand Up @@ -1911,6 +1945,8 @@ def __array_function__(
outputs = func(first_arg, *args[1:], **kwargs) # type: ignore
else:
second_arg = args[1].data
# Check the casting between Raster and array inputs, and return error messages if not consistent
_check_cast_array_raster(first_arg, second_arg, operation_name="an arithmetic operation")
outputs = func(first_arg, second_arg, *args[2:], **kwargs) # type: ignore

# Below, we recast to Raster if the shape was preserved, otherwise return an array
Expand Down
10 changes: 2 additions & 8 deletions geoutils/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,16 +725,10 @@ def rename_geometry(self, col: str, inplace: bool = False) -> Vector | None:

def __getitem__(self, key: gu.Raster | Vector | list[float] | tuple[float, ...] | Any) -> Vector:
"""
Index the geodataframe or crop the vector.
If a raster, vector or tuple is passed, crops to its bounds.
Otherwise, indexes the geodataframe.
Index the geodataframe.
"""

if isinstance(key, (gu.Raster, Vector)):
return self.crop(crop_geom=key, clip=False)
else:
return self._override_gdf_output(self.ds.__getitem__(key))
return self._override_gdf_output(self.ds.__getitem__(key))

@copy_doc(gpd.GeoDataFrame, "Vector")
def __setitem__(self, key: Any, value: Any) -> None:
Expand Down
Loading

0 comments on commit fe8795c

Please sign in to comment.