From c75aa07004bae558b490e80efb635e3b4f57ca1f Mon Sep 17 00:00:00 2001 From: Romain Hugonnet Date: Thu, 14 Nov 2024 14:02:02 -0900 Subject: [PATCH] Incremental commit on Xarray accessor --- geoutils/__init__.py | 5 +- geoutils/raster/base.py | 244 ++++++++++-------- geoutils/raster/geotransformations.py | 2 +- geoutils/raster/raster.py | 18 +- .../raster/{accessor.py => rst_accessor.py} | 10 +- tests/test_raster/test_accessor.py | 38 --- tests/test_raster/test_base.py | 162 ++++++++++++ tests/test_raster/test_rst_accessor.py | 15 ++ 8 files changed, 325 insertions(+), 169 deletions(-) rename geoutils/raster/{accessor.py => rst_accessor.py} (82%) delete mode 100644 tests/test_raster/test_accessor.py create mode 100644 tests/test_raster/test_base.py create mode 100644 tests/test_raster/test_rst_accessor.py diff --git a/geoutils/__init__.py b/geoutils/__init__.py index 2b883507..ecac15c5 100644 --- a/geoutils/__init__.py +++ b/geoutils/__init__.py @@ -4,9 +4,8 @@ from geoutils import examples, projtools, raster, vector # noqa from geoutils._config import config # noqa -from geoutils.raster import accessor # noqa -from geoutils.raster import Mask, Raster, SatelliteImage # noqa -from geoutils.raster.accessor import open_raster # noqa +from geoutils.raster import Mask, Raster, SatelliteImage, rst_accessor # noqa +from geoutils.raster.rst_accessor import open_raster # noqa from geoutils.vector import Vector # noqa try: diff --git a/geoutils/raster/base.py b/geoutils/raster/base.py index 3db38bb3..1825e145 100644 --- a/geoutils/raster/base.py +++ b/geoutils/raster/base.py @@ -11,6 +11,7 @@ import numpy as np import rasterio as rio import xarray as xr +import xarray.core.indexing from packaging.version import Version from rasterio.crs import CRS from rasterio.enums import Resampling @@ -56,22 +57,35 @@ class RasterBase: It gathers all the functions shared by the Raster class and the 'rst' Xarray accessor.
""" + def __init__(self): + """Initialize all raster metadata as None, for it to be overridden in subclasses.""" - _obj: None | xr.DataArray + # Attribute for Xarray accessor: will stay None in Raster class + self._obj: None | xr.DataArray = None - def __init__(self): # Main attributes of a raster - self._data: MArrayNum | xr.DataArray + self._data: MArrayNum | xr.DataArray | None = None self._transform: affine.Affine | None = None self._crs: CRS | None = None - self._nodata: int | float | None + self._nodata: int | float | None = None self._area_or_point: Literal["Area", "Point"] | None = None # Other non-derivatives attributes + self._bands: int | list[int] | None = None self._driver: str | None = None self._name: str | None = None self.filename: str | None = None self.tags: dict[str, Any] = {} + self._bands_loaded: int | tuple[int, ...] | None = None + self._disk_shape: tuple[int, int, int] | None = None + self._disk_bands: tuple[int] | None = None + self._disk_dtype: str | None = None + self._disk_transform: affine.Affine | None = None + self._out_count: int | None = None + self._out_shape: tuple[int, int] | None = None + self._disk_hash: int | None = None + self._is_modified = True + self._downsample: int | float = 1 @property def is_xr(self) -> bool: @@ -80,7 +94,7 @@ def is_xr(self) -> bool: @property def data(self) -> MArrayNum | xr.DataArray: if self.is_xr: - return self._obj.rio.data + return self._obj.data else: return self._data @@ -105,10 +119,117 @@ def nodata(self) -> int | float | None: else: return self._nodata + def set_area_or_point( + self, new_area_or_point: Literal["Area", "Point"] | None, shift_area_or_point: bool | None = None + ) -> None: + """ + Set new pixel interpretation of the raster. + + Overwrites the `area_or_point` attribute and updates "AREA_OR_POINT" in raster metadata tags.
+ + Optionally, shifts the raster to correct value coordinates in relation to interpretation: + + - By half a pixel (right and downwards) if old interpretation was "Area" and new is "Point", + - By half a pixel (left and upwards) if old interpretation was "Point" and new is "Area", + - No shift for all other cases. + + :param new_area_or_point: New pixel interpretation "Area", "Point" or None. + :param shift_area_or_point: Whether to shift with pixel interpretation, which shifts to center of pixel + indexes if self.area_or_point is "Point" and maintains corner pixel indexes if it is "Area" or None. + Defaults to True. Can be configured with the global setting geoutils.config["shift_area_or_point"]. + + :return: None. + """ + + # If undefined, default to the global system config + if shift_area_or_point is None: + shift_area_or_point = config["shift_area_or_point"] + + # Check input + if new_area_or_point is not None and not ( + isinstance(new_area_or_point, str) and new_area_or_point.lower() in ["area", "point"] + ): + raise ValueError("New pixel interpretation must be 'Area', 'Point' or None.") + + # Update string input as exactly "Area" or "Point" + if new_area_or_point is not None: + if new_area_or_point.lower() == "area": + new_area_or_point = "Area" + else: + new_area_or_point = "Point" + + # Save old area or point + old_area_or_point = self.area_or_point + + # Set new interpretation + self._area_or_point = new_area_or_point + # Update tag only if not None + if new_area_or_point is not None: + self.tags.update({"AREA_OR_POINT": new_area_or_point}) + else: + if "AREA_OR_POINT" in self.tags: + self.tags.pop("AREA_OR_POINT") + + # If shift is True, and both interpretations were different strings, a change is needed + if ( + shift_area_or_point + and isinstance(old_area_or_point, str) + and isinstance(new_area_or_point, str) + and old_area_or_point != new_area_or_point + ): + # The shift below represents +0.5/+0.5 or opposite in indexes (as done in xy2ij), but
because + # the Y axis is inverted, a minus sign is added to shift the coordinate (even if the unit is in pixel) + + # If the new one is Point, we shift back by half a pixel + if new_area_or_point == "Point": + xoff = 0.5 + yoff = -0.5 + # Otherwise we shift forward half a pixel + else: + xoff = -0.5 + yoff = 0.5 + # We perform the shift in place + self.translate(xoff=xoff, yoff=yoff, distance_unit="pixel", inplace=True) + @property - def area_or_point(self): + def area_or_point(self) -> Literal["Area", "Point"] | None: + """ + Pixel interpretation of the raster. + + Based on the "AREA_OR_POINT" raster metadata: + + - If pixel interpretation is "Area", the value of the pixel is associated with its upper left corner. + - If pixel interpretation is "Point", the value of the pixel is associated with its center. + + When setting with self.area_or_point = new_area_or_point, uses the default arguments of + self.set_area_or_point(). + """ return self._area_or_point + @area_or_point.setter + def area_or_point(self, new_area_or_point: Literal["Area", "Point"] | None) -> None: + """ + Setter for pixel interpretation. + + Uses default arguments of self.set_area_or_point(): shifts by half a pixel going from "Area" to "Point", + or the opposite. + + :param new_area_or_point: New pixel interpretation "Area", "Point" or None. + + :return: None.
+ """ + self.set_area_or_point(new_area_or_point=new_area_or_point) + + @property + def is_loaded(self) -> bool: + """Whether the raster array is loaded.""" + if self.is_xr: + # TODO: Activating this requires to have _disk_shape defined for RasterAccessor + return True + # return isinstance(self._obj.variable._data, np.ndarray) + else: + return self._data is not None + @property def res(self) -> tuple[float | int, float | int]: """Resolution (X, Y) of the raster in georeferenced units.""" @@ -183,11 +304,6 @@ def shape(self) -> tuple[int, int]: # If data loaded or not, pass the disk/data shape through height and width return self.height, self.width - @property - def is_loaded(self) -> bool: - """Whether the raster array is loaded.""" - return self._data is not None - @property def dtype(self) -> str: """Data type of the raster (string representation).""" @@ -235,107 +351,6 @@ def name(self) -> str | None: """Name of the file on disk, if it exists.""" return self._name - def set_area_or_point( - self, new_area_or_point: Literal["Area", "Point"] | None, shift_area_or_point: bool | None = None - ) -> None: - """ - Set new pixel interpretation of the raster. - - Overwrites the `area_or_point` attribute and updates "AREA_OR_POINT" in raster metadata tags. - - Optionally, shifts the raster to correct value coordinates in relation to interpretation: - - - By half a pixel (right and downwards) if old interpretation was "Area" and new is "Point", - - By half a pixel (left and upwards) if old interpretration was "Point" and new is "Area", - - No shift for all other cases. - - :param new_area_or_point: New pixel interpretation "Area", "Point" or None. - :param shift_area_or_point: Whether to shift with pixel interpretation, which shifts to center of pixel - indexes if self.area_or_point is "Point" and maintains corner pixel indexes if it is "Area" or None. - Defaults to True. Can be configured with the global setting geoutils.config["shift_area_or_point"]. - - :return: None. 
- """ - - # If undefined, default to the global system config - if shift_area_or_point is None: - shift_area_or_point = config["shift_area_or_point"] - - # Check input - if new_area_or_point is not None and not ( - isinstance(new_area_or_point, str) and new_area_or_point.lower() in ["area", "point"] - ): - raise ValueError("New pixel interpretation must be 'Area', 'Point' or None.") - - # Update string input as exactly "Area" or "Point" - if new_area_or_point is not None: - if new_area_or_point.lower() == "area": - new_area_or_point = "Area" - else: - new_area_or_point = "Point" - - # Save old area or point - old_area_or_point = self.area_or_point - - # Set new interpretation - self._area_or_point = new_area_or_point - # Update tag only if not None - if new_area_or_point is not None: - self.tags.update({"AREA_OR_POINT": new_area_or_point}) - else: - if "AREA_OR_POINT" in self.tags: - self.tags.pop("AREA_OR_POINT") - - # If shift is True, and both interpretation were different strings, a change is needed - if ( - shift_area_or_point - and isinstance(old_area_or_point, str) - and isinstance(new_area_or_point, str) - and old_area_or_point != new_area_or_point - ): - # The shift below represents +0.5/+0.5 or opposite in indexes (as done in xy2ij), but because - # the Y axis is inverted, a minus signs is added to shift the coordinate (even if the unit is in pixel) - - # If the new one is Point, we shift back by half a pixel - if new_area_or_point == "Point": - xoff = 0.5 - yoff = -0.5 - # Otherwise we shift forward half a pixel - else: - xoff = -0.5 - yoff = 0.5 - # We perform the shift in place - self.translate(xoff=xoff, yoff=yoff, distance_unit="pixel", inplace=True) - - @property - def area_or_point(self) -> Literal["Area", "Point"] | None: - """ - Pixel interpretation of the raster. - - Based on the "AREA_OR_POINT" raster metadata: - - - If pixel interpretation is "Area", the value of the pixel is associated with its upper left corner. 
- - If pixel interpretation is "Point", the value of the pixel is associated with its center. - - When setting with self.area_or_point = new_area_or_point, uses the default arguments of - self.set_area_or_point(). - """ - return self._area_or_point - - @area_or_point.setter - def area_or_point(self, new_area_or_point: Literal["Area", "Point"] | None) -> None: - """ - Setter for pixel interpretation. - - Uses default arguments of self.set_area_or_point(): shifts by half a pixel going from "Area" to "Point", - or the opposite. - - :param new_area_or_point: New pixel interpretation "Area", "Point" or None. - - :return: None. - """ - self.set_area_or_point(new_area_or_point=new_area_or_point) - @overload def info(self, stats: bool = False, *, verbose: Literal[True] = ...) -> None: ... @@ -1105,9 +1120,10 @@ def outside_image(self, xi: ArrayLike, yj: ArrayLike, index: bool = True) -> boo :param xi: Indices (or coordinates) of x direction to check. :param yj: Indices (or coordinates) of y direction to check. - :param index: Interpret ij as raster indices (default is ``True``). If False, assumes ij is coordinates. + :param index: Interpret xi and yj as raster indices (default is ``True``). If False, assumes xi and yj are + coordinates. - :returns is_outside: ``True`` if ij is outside the image. + :returns is_outside: ``True`` if xi/yj is outside the raster extent. 
""" return _outside_image( diff --git a/geoutils/raster/geotransformations.py b/geoutils/raster/geotransformations.py index e43e51a5..aa77fde4 100644 --- a/geoutils/raster/geotransformations.py +++ b/geoutils/raster/geotransformations.py @@ -444,7 +444,7 @@ def _crop( elif isinstance(crop_geom, (list, tuple)): xmin, ymin, xmax, ymax = crop_geom else: - raise ValueError("cropGeom must be a Raster, Vector, or list of coordinates.") + raise ValueError("'crop_geom' must be a Raster, Vector, or list of coordinates.") if mode == "match_pixel": # Finding the intersection of requested bounds and original bounds, cropped to image shape diff --git a/geoutils/raster/raster.py b/geoutils/raster/raster.py index 543fb4d9..3382207a 100644 --- a/geoutils/raster/raster.py +++ b/geoutils/raster/raster.py @@ -320,7 +320,7 @@ def __init__( bands: int | list[int] | None = None, load_data: bool = False, downsample: Number = 1, - nodata: int | float | None = None, + force_nodata: int | float | None = None, ) -> None: """ Instantiate a raster from a filename or rasterio dataset. @@ -333,21 +333,15 @@ def __init__( :param downsample: Downsample the array once loaded by a round factor. Default is no downsampling. - :param nodata: Nodata value to be used (overwrites the metadata). Default reads from metadata. + :param force_nodata: Force nodata value to be used (overwrites the metadata). Default reads from metadata. """ + + super().__init__() + self._data: MArrayNum | None = None + self._nodata = force_nodata self._bands = bands - self._bands_loaded: int | tuple[int, ...] 
| None = None self._masked = True - self._out_count: int | None = None - self._out_shape: tuple[int, int] | None = None - self._disk_hash: int | None = None - self._is_modified = True - self._disk_shape: tuple[int, int, int] | None = None - self._disk_bands: tuple[int] | None = None - self._disk_dtype: str | None = None - self._disk_transform: affine.Affine | None = None - self._downsample: int | float = 1 # This is for Raster.from_array to work. if isinstance(filename_or_dataset, dict): diff --git a/geoutils/raster/accessor.py b/geoutils/raster/rst_accessor.py similarity index 82% rename from geoutils/raster/accessor.py rename to geoutils/raster/rst_accessor.py index a51feb97..4ed12990 100644 --- a/geoutils/raster/accessor.py +++ b/geoutils/raster/rst_accessor.py @@ -30,11 +30,19 @@ def open_raster(filename: str, **kwargs): @xr.register_dataarray_accessor("rst") class RasterAccessor(RasterBase): def __init__(self, xarray_obj: xr.DataArray): + + super().__init__() + self._obj = xarray_obj + self._area_or_point = self._obj.attrs.get("AREA_OR_POINT", None) + + def copy(self, new_array: NDArrayNum | None = None) -> xr.DataArray: + + return self._obj.copy(data=new_array) def to_raster(self) -> RasterBase: """ - Convert to geoutils.Raster object. + Convert to Raster object. 
:return: """ diff --git a/tests/test_raster/test_accessor.py b/tests/test_raster/test_accessor.py deleted file mode 100644 index 94c4194f..00000000 --- a/tests/test_raster/test_accessor.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Tests on Xarray accessor mirroring Raster API.""" - -import warnings - -import rioxarray as rioxr - -from geoutils import Raster, examples, open_raster - - -class TestAccessor: - - def test_open_raster(self): - pass - - -class TestConsistencyRasterAccessor: - - # Test over many different rasters - landsat_b4_path = examples.get_path("everest_landsat_b4") - - @pytest.mark.parametrize("path_raster", [landsat_b4_path]) # type: ignore - @pytest.mark.parametrize("method", nongeo_properties) # type: ignore - def test_properties(self, path_raster: str, method: str) -> None: - """Check non-geometric properties are consistent with GeoPandas.""" - - # Open - ds = open_raster(path_raster) - raster = Raster(path_raster) - - # Remove warnings about operations in a non-projected system, and future changes - warnings.simplefilter("ignore", category=UserWarning) - warnings.simplefilter("ignore", category=FutureWarning) - - # Get method for each class - output_raster = getattr(raster, method) - output_ds = getattr(ds, method) - - # Assert equality diff --git a/tests/test_raster/test_base.py b/tests/test_raster/test_base.py new file mode 100644 index 00000000..f096f752 --- /dev/null +++ b/tests/test_raster/test_base.py @@ -0,0 +1,162 @@ +"""Test RasterBase class, parent of Raster class and 'rst' Xarray accessor.""" +from __future__ import annotations + +import warnings +from typing import Any + +import pytest +from pyproj import CRS +import numpy as np +import xarray as xr + +from geoutils import Vector, Raster, open_raster +from geoutils import examples + +class TestRasterBase: + + pass + + +def equal_xr_raster(ds: xr.DataArray, rast: Raster) -> bool: + """Check equality of a Raster object and Xarray object""" + # TODO: Move to raster_equal? 
+ return all([ + np.array_equal(ds.data.values, rast.get_nanarray(), equal_nan=True), + ds.rst.transform == rast.transform, + ds.rst.crs == rast.crs, + ds.rst.nodata == rast.nodata, + ]) + +def output_equal(output1: Any, output2: Any) -> bool: + """Return equality of different output types.""" + + # For two vectors + if isinstance(output1, Vector) and isinstance(output2, Vector): + return output1.vector_equal(output2) + + # For two rasters: Xarray or Raster objects + elif isinstance(output1, Raster) and isinstance(output2, Raster): + return output1.raster_equal(output2) + elif isinstance(output1, Raster) and isinstance(output2, xr.DataArray): + return equal_xr_raster(ds=output2, rast=output1) + elif isinstance(output1, xr.DataArray) and isinstance(output2, Raster): + return equal_xr_raster(ds=output1, rast=output2) + + # For arrays + elif isinstance(output1, np.ndarray): + return np.array_equal(output1, output2, equal_nan=True) + + # For tuple of arrays + elif isinstance(output1, tuple) and isinstance(output1[0], np.ndarray): + return np.array_equal(np.array(output1), np.array(output2), equal_nan=True) + + # For any other object type + else: + return output1 == output2 + +class TestClassVsAccessorConsistency: + """ + Test class to check the consistency between the outputs of the Raster class and Xarray accessor for the same + attributes or methods. + + All shared attributes should be the same. + All operations manipulating the array should yield comparable results, accounting for the fact that Raster class + relies on masked-arrays and the Xarray accessor on NaN arrays.
+ """ + + # Run tests for different rasters + landsat_b4_path = examples.get_path("everest_landsat_b4") + aster_dem_path = examples.get_path("exploradores_aster_dem") + landsat_rgb_path = examples.get_path("everest_landsat_rgb") + + # Test common attributes + attributes = ["crs", "transform", "nodata", "area_or_point", "res", "count", "height", "width", "footprint", + "shape", "bands", "indexes", "is_xr", "is_loaded"] + + @pytest.mark.parametrize("path_raster", [landsat_b4_path, aster_dem_path, landsat_rgb_path]) # type: ignore + @pytest.mark.parametrize("attr", attributes) # type: ignore + def test_attributes(self, path_raster: str, attr: str) -> None: + """Test that attributes of the two objects are exactly the same.""" + + # Open + ds = open_raster(path_raster) + raster = Raster(path_raster) + + # Remove warnings about operations in a non-projected system, and future changes + warnings.simplefilter("ignore", category=UserWarning) + warnings.simplefilter("ignore", category=FutureWarning) + + # Get attribute for each object + output_raster = getattr(raster, attr) + output_ds = getattr(getattr(ds, "rst"), attr) + + # Assert equality + if attr is not "is_xr": # Only attribute that is (purposely) not the same, but the opposite + assert output_equal(output_raster, output_ds) + else: + assert output_raster != output_ds + + + # Test common methods + methods_and_args = { + "reproject": {"crs": CRS.from_epsg(32610), "res": 10}, + "crop": {"crop_geom": "random"}, + "translate": {"xoff": 10.5, "yoff": 5}, + "xy2ij": {"x": "random", "y": "random"}, # This will be derived during the test to work on all inputs + "ij2xy": {"i": [0, 1, 2, 3], "j": [4, 5, 6, 7]}, + "coords": {"grid": True}, + "get_metric_crs": {"local_crs_type": "universal"}, + "reduce_points": {"points": "random"}, # This will be derived during the test to work on all inputs + "interp_points": {"points": "random"}, # This will be derived during the test to work on all inputs + "proximity": {"target_values": 
[100]}, + "outside_image": {"xi": [-2, 10000, 10], "yj": [10, 50, 20]}, + "to_pointcloud": {"subsample": 1000, "random_state": 42}, + "polygonize": {"target_values": "all"}, + "subsample": {"subsample": 1000, "random_state": 42}, + } + + @pytest.mark.parametrize("path_raster", [landsat_b4_path]) # type: ignore + @pytest.mark.parametrize("method", list(methods_and_args.keys())) # type: ignore + def test_methods(self, path_raster: str, method: str) -> None: + """ + Test that the outputs of the two objects are exactly the same + (converted for the case of a raster/vector output, as it can be a Xarray/GeoPandas object or Raster/Vector). + """ + + # Open both objects + ds = open_raster(path_raster) + raster = Raster(path_raster) + + # Remove warnings about operations in a non-projected system, and future changes + warnings.simplefilter("ignore", category=UserWarning) + warnings.simplefilter("ignore", category=FutureWarning) + + # Loop for specific inputs that require knowledge of the data + if "points" in self.methods_and_args[method].keys() or "x" in self.methods_and_args[method].keys(): + rng = np.random.default_rng(seed=42) + ninterp = 10 + res = raster.res + interp_x = (rng.choice(raster.shape[0], ninterp) + rng.random(ninterp)) * res[0] + interp_y = (rng.choice(raster.shape[1], ninterp) + rng.random(ninterp)) * res[1] + args = self.methods_and_args[method].copy() + if "points" in self.methods_and_args[method].keys(): + args.update({"points": (interp_x, interp_y)}) + elif "x" in self.methods_and_args[method].keys(): + args.update({"x": interp_x, "y": interp_y}) + + elif "crop_geom" in self.methods_and_args[method].keys(): + crop_geom = raster.bounds.left + 100, raster.bounds.bottom + 200, \ + raster.bounds.left + 320, raster.bounds.bottom + 411 + args = self.methods_and_args[method].copy() + args.update({"crop_geom": crop_geom}) + + else: + args = self.methods_and_args[method].copy() + + # Apply method for each class + output_raster = getattr(raster, method)(**args) 
+ output_ds = getattr(getattr(ds, "rst"), method)(**args) + + # Assert equality of output + assert output_equal(output_raster, output_ds) + diff --git a/tests/test_raster/test_rst_accessor.py b/tests/test_raster/test_rst_accessor.py new file mode 100644 index 00000000..d024547a --- /dev/null +++ b/tests/test_raster/test_rst_accessor.py @@ -0,0 +1,15 @@ +"""Tests on Xarray accessor mirroring Raster API.""" + +import warnings + +import pytest + +from geoutils import Raster, examples, open_raster + + +class TestAccessor: + + def test_open_raster(self): + pass + +