diff --git a/tests/test_coreg/test_biascorr.py b/tests/test_coreg/test_biascorr.py
index 21ccb554..963327f8 100644
--- a/tests/test_coreg/test_biascorr.py
+++ b/tests/test_coreg/test_biascorr.py
@@ -1,14 +1,19 @@
 """Tests for the biascorr module (non-rigid coregistrations)."""
+
 from __future__ import annotations
 
 import re
 import warnings
 
+import dask.array as da
 import geopandas as gpd
 import geoutils as gu
 import numpy as np
 import pytest
+import rasterio
+import rioxarray
 import scipy
+from xarray.core.dataarray import DataArray
 
 import xdem.terrain
 from xdem import examples
@@ -28,6 +33,24 @@ def load_examples() -> tuple[gu.Raster, gu.Raster, gu.Vector]:
 
     return reference_raster, to_be_aligned_raster, glacier_mask
 
 
+def load_examples_xarray() -> tuple[DataArray, DataArray, DataArray]:
+    """Load the COG example files as xarrays to test delayed / dask coregistration methods."""
+    chunk_size = 256  # the rasters are COGs with a block size of 256
+    reference_raster = rioxarray.open_rasterio(
+        filename=examples.get_path("longyearbyen_ref_dem"), chunks={"x": chunk_size, "y": chunk_size}
+    ).squeeze()
+    to_be_aligned_raster = rioxarray.open_rasterio(
+        filename=examples.get_path("longyearbyen_tba_dem"), chunks={"x": chunk_size, "y": chunk_size}
+    ).squeeze()
+
+    # Create a raster mask on the fly from the vector data
+    glacier_mask_vector = gu.Vector(examples.get_path("longyearbyen_glacier_outlines"))
+    inlier_mask = glacier_mask_vector.create_mask(raster=gu.Raster(examples.get_path("longyearbyen_ref_dem")))
+    inlier_mask = DataArray(da.from_array(inlier_mask.data.data, chunks=reference_raster.chunks))
+
+    return reference_raster, to_be_aligned_raster, inlier_mask
+
+
 class TestBiasCorr:
     ref, tba, outlines = load_examples()  # Load example reference, to-be-aligned and mask.
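+    # Note: the inverted mask below is True on stable terrain (inliers) and False on the glacier outlines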
     inlier_mask = ~outlines.create_mask(ref)
@@ -41,6 +64,15 @@ class TestBiasCorr:
         verbose=True,
     )
 
+    # Load xarray - xarray example data
+    ref_xarr, tba_xarr, mask_xarr = load_examples_xarray()
+    fit_args_xarr_xarr = dict(
+        reference_elev=ref_xarr,
+        to_be_aligned_elev=tba_xarr,
+        inlier_mask=mask_xarr,
+        verbose=True,
+    )
+
     # Convert DEMs to points with a bit of subsampling for speed-up
     tba_pts = tba.to_pointcloud(data_column_name="z", subsample=50000, random_state=42).ds
 
@@ -64,6 +96,12 @@ class TestBiasCorr:
 
     all_fit_args = [fit_args_rst_rst, fit_args_rst_pts, fit_args_pts_rst]
 
+    # Used to test the methods that have already been adapted to dask.
+    # Once all methods are adapted, fit_args_xarr_xarr can be added to all_fit_args
+    # without having to define this list separately.
+    all_fit_args_xarray = all_fit_args.copy()
+    all_fit_args_xarray.append(fit_args_xarr_xarr)
+
     def test_biascorr(self) -> None:
         """Test the parent class BiasCorr instantiation."""
 
@@ -498,11 +536,14 @@ def test_deramp(self) -> None:
         # Check that variable names are defined during instantiation
         assert deramp.meta["bias_var_names"] == ["xx", "yy"]
 
-    @pytest.mark.parametrize("fit_args", all_fit_args)  # type: ignore
+    @pytest.mark.parametrize("fit_args", all_fit_args_xarray)  # type: ignore
     @pytest.mark.parametrize("order", [1, 2, 3, 4])  # type: ignore
     def test_deramp__synthetic(self, fit_args, order: int) -> None:
         """Run the deramp for varying polynomial orders using a synthetic elevation difference."""
 
+        # These warnings would cause pytest to fail, even though there is no issue with the data
+        warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)
+
         # Get coordinates
         xx, yy = np.meshgrid(np.arange(0, self.ref.shape[1]), np.arange(0, self.ref.shape[0]))
 
@@ -515,28 +556,59 @@ def test_deramp__synthetic(self, fit_args, order: int) -> None:
 
         # Create a synthetic bias and add to the DEM
         synthetic_bias = polynomial_2d((xx, yy), *params)
-        bias_dem = self.ref - synthetic_bias
+
+        elev_fit_args = fit_args.copy()
+
+        if isinstance(elev_fit_args["reference_elev"], DataArray):
+            # Unfortunately, subtracting two rioxarrays loses their geospatial properties, so we need to create
+            # a new output rioxarray DataArray
+            bias_dem = DataArray(
+                da.from_array(
+                    elev_fit_args["reference_elev"].data.compute() - synthetic_bias,
+                    chunks=elev_fit_args["reference_elev"].data.chunks,
+                )
+            )
+            # Reset properties. Order matters!
+            bias_dem = bias_dem.rio.write_transform(elev_fit_args["reference_elev"].rio.transform())
+            bias_dem = bias_dem.rio.set_crs(elev_fit_args["reference_elev"].rio.crs)
+            bias_dem = bias_dem.rio.set_nodata(input_nodata=elev_fit_args["reference_elev"].rio.nodata)
+
+        else:
+            bias_dem = self.ref - synthetic_bias
 
         # Fit
         deramp = biascorr.Deramp(poly_order=order)
-        elev_fit_args = fit_args.copy()
         if isinstance(elev_fit_args["to_be_aligned_elev"], gpd.GeoDataFrame):
             bias_elev = bias_dem.to_pointcloud(data_column_name="z", subsample=30000, random_state=42).ds
         else:
             bias_elev = bias_dem
-        deramp.fit(elev_fit_args["reference_elev"], to_be_aligned_elev=bias_elev, subsample=20000, random_state=42)
+
+        deramp.fit(
+            elev_fit_args["reference_elev"],
+            to_be_aligned_elev=bias_elev,
+            inlier_mask=elev_fit_args["inlier_mask"],
+            subsample=20000,
+            random_state=42,
+        )
 
         # Check high-order fit parameters are the same within 10%
         fit_params = deramp.meta["fit_params"]
         assert np.shape(fit_params) == np.shape(params)
         assert np.allclose(
-            params.reshape(order + 1, order + 1)[-1:, -1:], fit_params.reshape(order + 1, order + 1)[-1:, -1:], rtol=0.1
+            params.reshape(order + 1, order + 1)[-1:, -1:],
+            fit_params.reshape(order + 1, order + 1)[-1:, -1:],
+            rtol=0.1,
         )
 
         # Run apply and check that 99% of the variance was corrected
-        corrected_dem = deramp.apply(bias_dem)
-        # Need to standardize by the synthetic bias spread to avoid huge/small values close to infinity
-        assert np.nanvar((corrected_dem - self.ref) / np.nanstd(synthetic_bias)) < 0.01
+        if isinstance(bias_dem, DataArray):
+            corrected_dem, _ = deramp.apply(bias_dem)
+            corrected_dem = corrected_dem.compute()
+            assert np.nanvar((corrected_dem - elev_fit_args["reference_elev"]) / np.nanstd(synthetic_bias)) < 0.01
+        else:
+            corrected_dem = deramp.apply(bias_dem)
+            # Need to standardize by the synthetic bias spread to avoid huge/small values close to infinity
+            assert np.nanvar((corrected_dem - self.ref) / np.nanstd(synthetic_bias)) < 0.01
 
     def test_terrainbias(self) -> None:
         """Test the subclass TerrainBias."""
diff --git a/tests/test_coreg/test_delayed.py b/tests/test_coreg/test_delayed.py
new file mode 100644
index 00000000..835cefea
--- /dev/null
+++ b/tests/test_coreg/test_delayed.py
@@ -0,0 +1,137 @@
+from unittest.mock import Mock
+
+import dask.array as da
+import numpy as np
+import pytest
+import rasterio as rio
+
+from xdem._typing import NDArrayb, NDArrayf
+from xdem.coreg.base import (
+    _select_transform_crs,
+    get_valid_data,
+    mask_data,
+    valid_data_darr,
+)
+
+
+@pytest.mark.filterwarnings("ignore::UserWarning")  # type: ignore [misc]
+@pytest.mark.parametrize(  # type: ignore [misc]
+    "epsg_ref, epsg_other, epsg, expected",
+    [
+        (3246, 4326, 3005, 3246),
+        (None, 4326, 3005, 4326),
+        (None, None, 3005, 3005),
+    ],
+)
+def test__select_transform_crs_selects_correct_crs(
+    epsg_ref: int | None, epsg_other: int | None, epsg: int, expected: int
+) -> None:
+    """Test that _select_transform_crs selects the correct CRS."""
+    mock_transform = Mock(rio.transform.Affine)  # we don't care about the transform in this test
+
+    _, crs = _select_transform_crs(
+        transform=mock_transform,
+        crs=rio.crs.CRS.from_epsg(epsg),
+        transform_reference=mock_transform,
+        transform_other=mock_transform,
+        crs_reference=rio.crs.CRS.from_epsg(epsg_ref) if epsg_ref is not None else epsg_ref,
+        crs_other=rio.crs.CRS.from_epsg(epsg_other) if epsg_other is not None else epsg_other,
+    )
+    assert crs.to_epsg() == expected
+
+
+def test__select_transform_crs_selects_correct_transform() -> None:
+    """Test that _select_transform_crs selects the correct transform."""
+    # TODO
+    pass
+
+
+@pytest.mark.parametrize(  # type: ignore[misc]
+    "data,nodata,expected",
+    [
+        (
+            np.array([np.nan, 1, -100, 1]),
+            -100,
+            np.array([np.nan, 1, np.nan, 1]),
+        ),
+        (
+            np.array([1, 1, -100, 1]),
+            -100,
+            np.array([1, 1, np.nan, 1]),
+        ),
+        (
+            np.array([np.nan, 1, 1, 1]),
+            -100,
+            np.array([np.nan, 1, 1, 1]),
+        ),
+    ],
+)
+def test_mask_data(data: NDArrayf, nodata: int, expected: NDArrayf) -> None:
+    """Test that mask_data masks the correct values."""
+    output = mask_data(data=data, nodata=nodata)
+    assert np.array_equal(output, expected, equal_nan=True)
+
+
+@pytest.mark.parametrize(  # type: ignore [misc]
+    "input_arrays,nodatas,expected",
+    [
+        (
+            (np.array([np.nan, 1, -100, 1]),),
+            (-100,),
+            np.array([False, True, False, True]),
+        ),
+        (
+            (
+                np.array([np.nan, 1, -100, 1]),
+                np.array([1, -200, 1, 1]),
+            ),
+            (-100, -200),
+            np.array([False, False, False, True]),
+        ),
+        (
+            (
+                np.array([np.nan, 1, -100, 1]),
+                np.array([1, -200, 1, 1]),
+                np.array([1, 1, 1, -400]),
+            ),
+            (-100, -200, -400),
+            np.array([False, False, False, False]),
+        ),
+    ],
+)
+def test_get_valid_data(input_arrays: tuple[NDArrayf, ...], nodatas: tuple[int, ...], expected: NDArrayb) -> None:
+    """Test that get_valid_data returns the correct output."""
+    output = get_valid_data(*input_arrays, nodatas=nodatas)
+    assert np.array_equal(output, expected, equal_nan=True)
+
+
+@pytest.mark.parametrize(  # type: ignore [misc]
+    "input_arrays,mask,nodatas,expected",
+    [
+        (
+            (
+                da.from_array(np.array([1, 1, -100, 1]), chunks=2),
+                da.from_array(np.array([1, 1, -200, 1]), chunks=2),
+            ),
+            None,
+            (-100, -200),
+            np.array([True, True, False, True]),
+        ),
+        (
+            (
+                da.from_array(np.array([1, 1, -100, 1]), chunks=2),
+                da.from_array(np.array([1, 1, -200, 1]), chunks=2),
+            ),
+            da.from_array([False, True, True, True]),
+            (-100, -200),
+            np.array([False, True, False, True]),
+        ),
+    ],
+)
+def test_valid_data_darr(
+    input_arrays: tuple[da.Array, ...], mask: da.Array | None, nodatas: tuple[int, ...], expected: NDArrayb
+) -> None:
+    """Test that valid_data_darr returns the correct output."""
+    output = valid_data_darr(*input_arrays, mask=mask, nodatas=nodatas).compute()
+    assert np.array_equal(output, expected, equal_nan=True)
diff --git a/xdem/coreg/base.py b/xdem/coreg/base.py
index 365edd8b..e0e01588 100644
--- a/xdem/coreg/base.py
+++ b/xdem/coreg/base.py
@@ -47,6 +47,8 @@
     subdivide_array,
     subsample_array,
 )
+from geoutils.raster.delayed import delayed_reproject, delayed_subsample
+from rasterio.enums import Resampling
 from tqdm import tqdm
 
 from xdem._typing import MArrayf, NDArrayb, NDArrayf
@@ -61,6 +63,9 @@
     _HAS_P3D = False
 
+import dask.array as da
+from xarray.core.dataarray import DataArray
+
 
 ###########################################
 # Generic functions for preprocessing
 ###########################################
@@ -301,6 +306,140 @@ def _mask_as_array(reference_raster: gu.Raster, mask: str | gu.Vector | gu.Raste
     return mask_array
 
 
+def _select_transform_crs(
+    transform: rio.transform.Affine | None,
+    crs: rio.crs.CRS | None,
+    transform_reference: rio.transform.Affine | None,
+    transform_other: rio.transform.Affine | None,
+    crs_reference: rio.crs.CRS | None,
+    crs_other: rio.crs.CRS | None,
+) -> tuple[rio.transform.Affine, rio.crs.CRS]:
+    """Choose the correct transform and CRS."""
+    # Choose the correct transform according to order of priority: dem_reference, dem_to_be_aligned, transform
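+    # (i.e. georeferencing carried by the input DEMs wins over an explicitly passed 'transform'/'crs';
+    # a warning is raised below whenever an explicit argument is overridden)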
+    new_transform = transform
+    if transform_reference is not None:
+        if new_transform is not None:
+            warnings.warn("transform of the reference DEM overrides the given 'transform'.")
+        new_transform = transform_reference
+    elif transform_other is not None:
+        if new_transform is not None:
+            warnings.warn("transform of the DEM to be aligned overrides the given 'transform'.")
+        new_transform = transform_other
+
+    # Choose the crs with the same priority as the transform
+    new_crs = crs
+    if crs_reference is not None:
+        if new_crs is not None:
+            warnings.warn("crs of the reference DEM overrides the given 'crs'.")
+        new_crs = crs_reference
+    elif crs_other is not None:
+        if new_crs is not None:
+            warnings.warn("crs of the DEM to be aligned overrides the given 'crs'.")
+        new_crs = crs_other
+
+    # Check that we have set the transforms
+    if new_transform is None:
+        raise ValueError("'transform' must be given if both DEMs are Xarrays without a transform.")
+    if new_crs is None:
+        raise ValueError("'crs' must be given if both DEMs are Xarrays without a CRS.")
+    return new_transform, new_crs
+
+
+def get_valid_data(*datasets: NDArrayf, nodatas: tuple[int | float, ...]) -> NDArrayb:
+    """Get the valid pixels across datasets."""
+    return np.logical_and.reduce(
+        list(
+            map(
+                lambda data, nodata: (np.isfinite(data)) & (data != nodata),
+                datasets,
+                nodatas,
+            )
+        ),
+    )
+
+
+def valid_data_darr(*datasets: da.Array, mask: da.Array | None = None, nodatas: tuple[int | float, ...]) -> da.Array:
+    """Get the valid data from a set of dask arrays.
+
+    Valid data is defined as data that is neither NaN nor equal to the raster's nodata value.
+
+    :param *datasets: The dask arrays to get valid data for.
+    :param mask: Optional mask to additionally combine with.
+    :param nodatas: The nodata values of the datasets, in the same order as the datasets passed.
+    """
+    valid_data = da.map_blocks(
+        get_valid_data,
+        *datasets,
+        nodatas=nodatas,
+        dtype="bool",
+    )
+
+    # Logical operators work lazily on chunked dask arrays, so the mask can be combined directly
+    if mask is not None:
+        return mask & valid_data
+    return valid_data
+
+
+def mask_data(data: NDArrayf, nodata: int | float) -> NDArrayf:
+    """Set invalid data in an array block to NaN."""
+    return np.where(~get_valid_data(data, nodatas=(nodata,)), np.nan, data)
+
+
+def _preprocess_coreg_fit_xarray_xarray(
+    reference_dem: DataArray,
+    dem_to_be_aligned: DataArray,
+    inlier_mask: DataArray | None = None,
+    transform: rio.transform.Affine | None = None,
+    crs: rio.crs.CRS | None = None,
+) -> tuple[da.Array, da.Array, da.Array, affine.Affine, rio.crs.CRS]:
+    """Pre-processing and checks of fit() for xarray (dask) inputs. Outputs are dask arrays."""
+
+    # Validate that both inputs are valid xarrays
+    if not all(isinstance(dem, DataArray) for dem in (reference_dem, dem_to_be_aligned)):
+        raise TypeError("Both DEMs need to be xarrays.")
+    if (inlier_mask is not None) and (not isinstance(inlier_mask, DataArray)):
+        raise TypeError(f"Mask has invalid type: {type(inlier_mask)}. Expected {DataArray}.")
+
+    # TODO: make sure the underlying data is a chunked dask array, otherwise throw an error
+    # TODO: do we need to make sure that the inputs have aligned chunks, or does dask handle that?
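+    # Resolve the output georeferencing: the rioxarray accessors of the two DEMs take priority
+    # over the explicitly passed 'transform'/'crs' arguments (see _select_transform_crs above)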
+    transform, crs = _select_transform_crs(
+        transform=transform,
+        crs=crs,
+        transform_reference=reference_dem.rio.transform(),
+        transform_other=dem_to_be_aligned.rio.transform(),
+        crs_reference=reference_dem.rio.crs,
+        crs_other=dem_to_be_aligned.rio.crs,
+    )
+    # The nodata value is needed for the valid-data mask computed below
+    dem_tba_nodata = dem_to_be_aligned.rio.nodata
+
+    # Reproject the DEM to be aligned if it is not on the correct grid
+    dem_to_be_aligned_reprojected = delayed_reproject(
+        darr=dem_to_be_aligned.data,
+        src_transform=dem_to_be_aligned.rio.transform(),
+        src_crs=dem_to_be_aligned.rio.crs,
+        dst_transform=reference_dem.rio.transform(),
+        dst_shape=reference_dem.shape,
+        dst_crs=reference_dem.rio.crs,
+        resampling=Resampling.bilinear,
+        src_nodata=dem_to_be_aligned.rio.nodata,
+        dst_nodata=reference_dem.rio.nodata,
+        dst_chunksizes=None,  # reproject will use the destination chunk sizes if set to None
+    )
+
+    inlier_mask_all = valid_data_darr(
+        reference_dem.data,
+        dem_to_be_aligned.data,
+        mask=inlier_mask.data if inlier_mask is not None else None,
+        nodatas=(reference_dem.rio.nodata, dem_tba_nodata),  # type: ignore [arg-type]
+    )
+
+    # TODO: handle the case where the mask has no inliers -> np.all(~mask)
+    # Outputs are dask arrays
+    return reference_dem.data, dem_to_be_aligned_reprojected, inlier_mask_all, transform, crs
+
+
 def _preprocess_coreg_fit_raster_raster(
     reference_dem: NDArrayf | MArrayf | RasterType,
     dem_to_be_aligned: NDArrayf | MArrayf | RasterType,
@@ -442,18 +581,23 @@ def _preprocess_coreg_fit_point_point(
 
 
 def _preprocess_coreg_fit(
-    reference_elev: NDArrayf | MArrayf | RasterType | gpd.GeoDataFrame,
-    to_be_aligned_elev: NDArrayf | MArrayf | RasterType | gpd.GeoDataFrame,
-    inlier_mask: NDArrayb | Mask | None = None,
+    reference_elev: NDArrayf | MArrayf | RasterType | gpd.GeoDataFrame | DataArray,
+    to_be_aligned_elev: NDArrayf | MArrayf | RasterType | gpd.GeoDataFrame | DataArray,
+    inlier_mask: NDArrayb | Mask | DataArray | None = None,
     transform: rio.transform.Affine | None = None,
     crs: rio.crs.CRS | None = None,
 ) -> tuple[
-    NDArrayf | gpd.GeoDataFrame, NDArrayf | gpd.GeoDataFrame, NDArrayb | None, affine.Affine | None, rio.crs.CRS | None
+    NDArrayf | gpd.GeoDataFrame | da.Array,
+    NDArrayf | gpd.GeoDataFrame | da.Array,
+    NDArrayb | da.Array | None,
+    affine.Affine | None,
+    rio.crs.CRS | None,
 ]:
     """Pre-processing and checks of fit for any input."""
 
     if not all(
-        isinstance(elev, (np.ndarray, gu.Raster, gpd.GeoDataFrame)) for elev in (reference_elev, to_be_aligned_elev)
+        isinstance(elev, (np.ndarray, gu.Raster, gpd.GeoDataFrame, DataArray))
+        for elev in (reference_elev, to_be_aligned_elev)
     ):
         raise ValueError("Input elevation data should be a raster, an array or a geodataframe.")
 
@@ -489,6 +633,17 @@
             ref_elev = point_elev
             tba_elev = raster_elev
 
+    elif all(isinstance(elev, DataArray) for elev in (reference_elev, to_be_aligned_elev)):
+
+        # Outputs are now dask arrays
+        ref_elev, tba_elev, inlier_mask, transform, crs = _preprocess_coreg_fit_xarray_xarray(
+            reference_dem=reference_elev,
+            dem_to_be_aligned=to_be_aligned_elev,
+            inlier_mask=inlier_mask,
+            transform=transform,
+            crs=crs,
+        )
+
     # If both inputs are points, simply reproject to the same CRS
     else:
         ref_elev, tba_elev = _preprocess_coreg_fit_point_point(
@@ -499,13 +654,13 @@
 
 
 def _preprocess_coreg_apply(
-    elev: NDArrayf | MArrayf | RasterType | gpd.GeoDataFrame,
+    elev: NDArrayf | MArrayf | RasterType | gpd.GeoDataFrame | DataArray,
     transform: rio.transform.Affine | None = None,
     crs: rio.crs.CRS | None = None,
-) -> tuple[NDArrayf | gpd.GeoDataFrame, affine.Affine, rio.crs.CRS]:
+) -> tuple[NDArrayf | gpd.GeoDataFrame | da.Array, affine.Affine, rio.crs.CRS]:
    """Pre-processing and checks of apply for any input."""
 
-    if not isinstance(elev, (np.ndarray, gu.Raster, gpd.GeoDataFrame)):
+    if not isinstance(elev, (np.ndarray, gu.Raster, gpd.GeoDataFrame, DataArray)):
         raise ValueError("Input elevation data should be a raster, an array or a geodataframe.")
 
     # If input is geodataframe
@@ -514,6 +669,20 @@ def _preprocess_coreg_apply(
         new_transform = None
         new_crs = None
 
+    # If input is a DataArray
+    elif isinstance(elev, DataArray):
+        new_transform, new_crs = _select_transform_crs(
+            transform=transform,
+            crs=crs,
+            transform_reference=elev.rio.transform(),
+            transform_other=None,
+            crs_reference=elev.rio.crs,
+            crs_other=None,
+        )
+
+        # Get the masked elevation
+        elev_out = da.map_blocks(mask_data, elev.data, nodata=elev.rio.nodata, chunks=elev.chunks, dtype=elev.dtype)
+
     # If input is a raster or array
     else:
         # If input is raster
@@ -552,6 +721,54 @@ def _postprocess_coreg_apply_pts(
     return applied_elev
 
 
+def _postprocess_coreg_apply_xarray(
+    elev: DataArray,
+    applied_elev: da.Array,
+    transform: affine.Affine,
+    out_transform: affine.Affine,
+    crs: rio.crs.CRS,
+    resample: bool,
+    resampling: rio.warp.Resampling | None = None,
+) -> tuple[DataArray, affine.Affine]:
+    """Post-processing and checks of apply for dask inputs."""
+
+    # Make sure the datatype is correct
+    if applied_elev.dtype != np.float32:
+        applied_elev = applied_elev.astype(np.float32)
+
+    # Reproject the corrected elevation
+    # NOTE: is there a way to make this optional? It can save some compute time.
+    reprojected = delayed_reproject(
+        darr=applied_elev,
+        src_transform=out_transform,
+        src_crs=crs,
+        dst_transform=transform,
+        dst_shape=elev.shape,
+        dst_crs=crs,
+        resampling=resampling,
+        src_nodata=np.nan,
+        dst_nodata=elev.rio.nodata,
+        dst_chunksizes=None,
+    )
+
+    # Set NaNs to the nodata value
+    reprojected = da.where(da.isnan(reprojected), elev.rio.nodata, reprojected)
+
+    output_ds = DataArray(
+        da.expand_dims(reprojected, axis=0),
+        coords=elev.coords,  # TODO: is this correct?
+        dims=["band", "y", "x"],
+        name="Corrected DEM",
+        # attrs={},  # it's possible to set geotiff metadata via the attrs parameter
+    )
+
+    # Set crs and nodata value
+    output_ds = output_ds.rio.set_crs(elev.rio.crs)
+    output_ds = output_ds.rio.set_nodata(elev.rio.nodata)
+
+    return output_ds, out_transform
+
+
 def _postprocess_coreg_apply_rst(
     elev: NDArrayf | gu.Raster,
     applied_elev: NDArrayf,
@@ -608,14 +825,14 @@
 
 
 def _postprocess_coreg_apply(
-    elev: NDArrayf | gu.Raster | gpd.GeoDataFrame,
-    applied_elev: NDArrayf | gpd.GeoDataFrame,
+    elev: NDArrayf | gu.Raster | gpd.GeoDataFrame | DataArray,
+    applied_elev: NDArrayf | gpd.GeoDataFrame | da.Array,
     transform: affine.Affine,
     out_transform: affine.Affine,
     crs: rio.crs.CRS,
     resample: bool,
     resampling: rio.warp.Resampling | None = None,
-) -> tuple[NDArrayf | gpd.GeoDataFrame, affine.Affine]:
+) -> tuple[NDArrayf | gpd.GeoDataFrame | DataArray, affine.Affine]:
     """
     Post-processing and checks of apply for any input.
@@ -637,6 +854,17 @@ def _postprocess_coreg_apply(
             resample=resample,
             resampling=resampling,
         )
+    elif isinstance(applied_elev, da.Array):
+        applied_elev, out_transform = _postprocess_coreg_apply_xarray(
+            elev=elev,
+            applied_elev=applied_elev,
+            transform=transform,
+            crs=crs,
+            out_transform=out_transform,
+            resample=resample,
+            resampling=resampling,
+        )
+
     else:
         applied_elev = _postprocess_coreg_apply_pts(applied_elev)
 
@@ -1150,13 +1378,29 @@ def _get_subsample_on_valid_mask(self, valid_mask: NDArrayb, verbose: bool = Fal
 
         return subsample_mask
 
+    def _get_subsample_indices_dask(self, data: NDArrayb) -> tuple[NDArrayf, NDArrayf]:
+        """Get subsampled indices from a dask array."""
+
+        # subsample value is handled in delayed_subsample
+        indices = delayed_subsample(
+            darr=data,
+            subsample=self._meta["subsample"],
+            return_indices=True,
+            silence_max_subsample=True,
+        )
+
+        # Write final subsample to class
+        self._meta["subsample_final"] = len(indices[0])
+
+        return indices
+
     def fit(
         self: CoregType,
-        reference_elev: NDArrayf | MArrayf | RasterType | gpd.GeoDataFrame,
-        to_be_aligned_elev: NDArrayf | MArrayf | RasterType | gpd.GeoDataFrame,
-        inlier_mask: NDArrayb | Mask | None = None,
-        bias_vars: dict[str, NDArrayf | MArrayf | RasterType] | None = None,
-        weights: NDArrayf | None = None,
+        reference_elev: NDArrayf | MArrayf | RasterType | gpd.GeoDataFrame | DataArray,
+        to_be_aligned_elev: NDArrayf | MArrayf | RasterType | gpd.GeoDataFrame | DataArray,
+        inlier_mask: NDArrayb | Mask | DataArray | None = None,
+        bias_vars: dict[str, NDArrayf | MArrayf | RasterType | DataArray] | None = None,
+        weights: NDArrayf | DataArray | None = None,  # TODO is DataArray correct here?
         subsample: float | int | None = None,
         transform: rio.transform.Affine | None = None,
         crs: rio.crs.CRS | None = None,
@@ -1292,9 +1536,23 @@ def apply(
     ) -> RasterType | gpd.GeoDataFrame:
         ...
 
+    @overload
     def apply(
         self,
-        elev: MArrayf | NDArrayf | RasterType | gpd.GeoDataFrame,
+        elev: DataArray,
+        bias_vars: dict[str, DataArray] | None = None,
+        resample: bool = True,
+        resampling: str | rio.warp.Resampling = "bilinear",
+        transform: rio.transform.Affine | None = None,
+        crs: rio.crs.CRS | None = None,
+        z_name: str = "z",
+        **kwargs: Any,
+    ) -> tuple[DataArray, rio.transform.Affine]:
+        ...
+
+    def apply(
+        self,
+        elev: MArrayf | NDArrayf | RasterType | gpd.GeoDataFrame | DataArray,
         bias_vars: dict[str, NDArrayf | MArrayf | RasterType] | None = None,
         resample: bool = True,
         resampling: str | rio.warp.Resampling = "bilinear",
@@ -1302,7 +1560,13 @@ def apply(
         crs: rio.crs.CRS | None = None,
         z_name: str = "z",
         **kwargs: Any,
-    ) -> RasterType | gpd.GeoDataFrame | tuple[NDArrayf, rio.transform.Affine] | tuple[MArrayf, rio.transform.Affine]:
+    ) -> (
+        RasterType
+        | gpd.GeoDataFrame
+        | tuple[NDArrayf, rio.transform.Affine]
+        | tuple[MArrayf, rio.transform.Affine]
+        | tuple[DataArray, rio.transform.Affine]
+    ):
         """
         Apply the estimated transform to a DEM.
@@ -1331,6 +1595,7 @@ def apply(
             if self._is_affine:
                 warnings.warn("This coregistration method is affine, ignoring `bias_vars` passed to apply().")
 
+            # TODO: adapt this for dask
             for var in bias_vars.keys():
                 bias_vars[var] = gu.raster.get_array_and_mask(bias_vars[var])[0]
 
@@ -1641,7 +1906,7 @@ def _fit_func(
         """
 
         # Determine if input is raster-raster, raster-point or point-point
-        if all(isinstance(dem, np.ndarray) for dem in (kwargs["ref_elev"], kwargs["tba_elev"])):
+        if all(isinstance(dem, (np.ndarray, da.Array)) for dem in (kwargs["ref_elev"], kwargs["tba_elev"])):
             rop = "r-r"
         elif all(isinstance(dem, gpd.GeoDataFrame) for dem in (kwargs["ref_elev"], kwargs["tba_elev"])):
             rop = "p-p"
@@ -1709,11 +1974,11 @@ def _fit_func(
                 f"No point-point method found for coregistration {self.__class__.__name__}."
             )
 
-    def _apply_func(self, **kwargs: Any) -> tuple[NDArrayf | gpd.GeoDataFrame, affine.Affine]:
+    def _apply_func(self, **kwargs: Any) -> tuple[NDArrayf | gpd.GeoDataFrame | da.Array, affine.Affine]:
         """Distribute to _apply_rst and _apply_pts based on input and method availability."""
 
         # If input is a raster
-        if isinstance(kwargs["elev"], np.ndarray):
+        if isinstance(kwargs["elev"], (np.ndarray, da.Array)):
 
             # See if a _apply_rst exists
             try:
diff --git a/xdem/coreg/biascorr.py b/xdem/coreg/biascorr.py
index 5a0481be..980b45ff 100644
--- a/xdem/coreg/biascorr.py
+++ b/xdem/coreg/biascorr.py
@@ -1,9 +1,11 @@
 """Bias corrections (i.e., non-affine coregistration) classes."""
+
 from __future__ import annotations
 
 import inspect
 from typing import Any, Callable, Iterable, Literal, TypeVar
 
+import dask.array as da
 import geopandas as gpd
 import geoutils as gu
 import numpy as np
@@ -15,6 +17,7 @@
 from xdem._typing import NDArrayb, NDArrayf
 from xdem.coreg.base import Coreg
 from xdem.fit import (
+    fit_chunked,
     polynomial_1d,
     polynomial_2d,
     robust_nfreq_sumsin_fit,
@@ -41,9 +44,9 @@ class BiasCorr(Coreg):
     def __init__(
         self,
         fit_or_bin: Literal["bin_and_fit"] | Literal["fit"] | Literal["bin"] = "fit",
-        fit_func: Callable[..., NDArrayf]
-        | Literal["norder_polynomial"]
-        | Literal["nfreq_sumsin"] = "norder_polynomial",
+        fit_func: (
+            Callable[..., NDArrayf] | Literal["norder_polynomial"] | Literal["nfreq_sumsin"]
+        ) = "norder_polynomial",
         fit_optimizer: Callable[..., tuple[NDArrayf, Any]] = scipy.optimize.curve_fit,
         bin_sizes: int | dict[str, int | Iterable[float]] = 10,
         bin_statistic: Callable[[NDArrayf], np.floating[Any]] = np.nanmedian,
@@ -145,9 +148,9 @@ def __init__(
 
     def _fit_biascorr(  # type: ignore
         self,
-        ref_elev: NDArrayf,
-        tba_elev: NDArrayf,
-        inlier_mask: NDArrayb,
+        ref_elev: NDArrayf | da.Array,
+        tba_elev: NDArrayf | da.Array,
+        inlier_mask: NDArrayb | da.Array,
         transform: rio.transform.Affine,  # Never None thanks to Coreg.fit() pre-process
         crs: rio.crs.CRS,  # Never None thanks to Coreg.fit() pre-process
         z_name: str,
@@ -188,16 +191,34 @@ def _fit_biascorr(  # type: ignore
 
         # TODO: Move the check up to Coreg.fit()?
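+        # For dask inputs, the valid mask below is built lazily with map_blocks and subsampled via
+        # delayed_subsample; only the subsampled values are later pulled into memory with compute()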
         diff = ref_elev - tba_elev
-        valid_mask = np.logical_and.reduce(
-            (inlier_mask, np.isfinite(diff), *(np.isfinite(var) for var in bias_vars.values()))
-        )
 
-        # Raise errors if all values are NaN after introducing masks from the variables
-        # (Others are already checked in Coreg.fit())
-        if np.all(~valid_mask):
-            raise ValueError("Some 'bias_vars' have only NaNs in the inlier mask.")
+        if all(isinstance(dem, da.Array) for dem in (ref_elev, tba_elev, inlier_mask)):
+
+            # Calculate the valid mask to sample from
+            data = [inlier_mask, da.isfinite(diff), *(da.isfinite(var) for var in bias_vars.values())]
+            valid_mask = da.map_blocks(
+                lambda *arrays: np.logical_and.reduce(arrays),
+                *data,
+                chunks=inlier_mask.chunks,  # type: ignore [union-attr]
+                dtype="bool",
+            )
+
+            # TODO: the output is called a mask but it holds indices. Find a nicer way to handle this
+            subsample_mask = self._get_subsample_indices_dask(data=valid_mask)
+        else:
+            valid_mask = np.logical_and.reduce(
+                (inlier_mask, np.isfinite(diff), *(np.isfinite(var) for var in bias_vars.values()))
+            )
 
-        subsample_mask = self._get_subsample_on_valid_mask(valid_mask=valid_mask, verbose=verbose)
+            # Raise errors if all values are NaN after introducing masks from the variables
+            # (Others are already checked in Coreg.fit())
+            if np.all(~valid_mask):
+                raise ValueError("Some 'bias_vars' have only NaNs in the inlier mask.")
+
+            subsample_mask = self._get_subsample_on_valid_mask(  # type: ignore [assignment]
+                valid_mask=valid_mask,
+                verbose=verbose,
+            )
 
         # Get number of variables
         nd = len(bias_vars)
@@ -220,6 +241,19 @@ def _fit_biascorr(  # type: ignore
         else:
             bin_sizes = self._meta["bin_sizes"]
 
+        if isinstance(diff, np.ndarray):
+            ydata = diff[subsample_mask]
+            xdata = [var[subsample_mask] for var in bias_vars.values()]
+            sigma = weights[subsample_mask] if weights is not None else None
+
+        elif isinstance(diff, da.Array):
+            ydata = diff.vindex[subsample_mask].compute()  # type: ignore [assignment]
+            xdata = [var.vindex[subsample_mask].compute() for var in bias_vars.values()]
+            # TODO: where do the weights come from? Are they also dask arrays?
+            sigma = weights.vindex[subsample_mask].compute() if weights is not None else None
+        else:
+            raise TypeError(f"Incompatible input type for arrays {type(diff)}.")
+
         # Option 1: Run fit and save optimized function parameters
         if self._fit_or_bin == "fit":
                     "with function {}.".format(", ".join(list(bias_vars.keys())), self._meta["fit_func"].__name__)
                 )
 
+            # We don't need to call fit_chunked here, because the subsampled data going in is
+            # no longer a chunked dask array
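+            # (ydata, xdata and sigma were already materialized as in-memory numpy arrays above,
+            # so the optimizer receives plain arrays in both the numpy and dask code paths)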
             results = self._meta["fit_optimizer"](
                 f=self._meta["fit_func"],
-                xdata=np.array([var[subsample_mask].flatten() for var in bias_vars.values()]).squeeze(),
-                ydata=diff[subsample_mask].flatten(),
-                sigma=weights[subsample_mask].flatten() if weights is not None else None,
+                ydata=ydata.flatten(),
+                xdata=np.array([data.flatten() for data in xdata]),
+                sigma=sigma.flatten() if sigma is not None else None,
                 absolute_sigma=True,
                 **kwargs,
             )
@@ -249,8 +284,8 @@
             )
 
             df = xdem.spatialstats.nd_binning(
-                values=diff[subsample_mask],
-                list_var=[var[subsample_mask] for var in bias_vars.values()],
+                values=ydata,
+                list_var=xdata,
                 list_var_names=list(bias_vars.keys()),
                 list_var_bins=bin_sizes,
                 statistics=(self._meta["bin_statistic"], "count"),
@@ -271,8 +306,8 @@
             )
 
             df = xdem.spatialstats.nd_binning(
-                values=diff[subsample_mask],
-                list_var=[var[subsample_mask] for var in bias_vars.values()],
+                values=ydata,
+                list_var=xdata,
                 list_var_names=list(bias_vars.keys()),
                 list_var_bins=bin_sizes,
                 statistics=(self._meta["bin_statistic"], "count"),
@@ -335,9 +370,9 @@
 
     def _fit_rst_rst(
         self,
-        ref_elev: NDArrayf,
-        tba_elev: NDArrayf,
-        inlier_mask: NDArrayb,
+        ref_elev: NDArrayf | da.Array,
+        tba_elev: NDArrayf | da.Array,
+        inlier_mask: NDArrayb | da.Array,
         transform: rio.transform.Affine,
         crs: rio.crs.CRS,
         z_name: str,
@@ -441,12 +476,12 @@ def _fit_rst_pts(  # type: ignore
 
     def _apply_rst(  # type: ignore
         self,
-        elev: NDArrayf,
+        elev: NDArrayf | da.Array,
         transform: rio.transform.Affine,  # Never None thanks to Coreg.fit() pre-process
         crs: rio.crs.CRS,  # Never None thanks to Coreg.fit() pre-process
         bias_vars: None | dict[str, NDArrayf] = None,
         **kwargs: Any,
-    ) -> tuple[NDArrayf, rio.transform.Affine]:
+    ) -> tuple[NDArrayf, rio.transform.Affine] | tuple[da.Array, rio.transform.Affine]:
 
         if bias_vars is None:
             raise ValueError("At least one `bias_var` should be passed to the `apply` function, got None.")
@@ -460,7 +495,12 @@ def _apply_rst(  # type: ignore
 
         # Apply function to get correction (including if binning was done before)
         if self._fit_or_bin in ["fit", "bin_and_fit"]:
-            corr = self._meta["fit_func"](tuple(bias_vars.values()), *self._meta["fit_params"])
+            if isinstance(list(bias_vars.values())[0], da.Array):
+                corr = fit_chunked(
+                    tuple(bias_vars.values()), *self._meta["fit_params"], fit_func=self._meta["fit_func"]
+                )
+            else:
+                corr = self._meta["fit_func"](tuple(bias_vars.values()), *self._meta["fit_params"])
 
         # Apply binning to get correction
         else:
@@ -645,9 +685,9 @@ def __init__(
         self,
         terrain_attribute: str = "maximum_curvature",
         fit_or_bin: Literal["bin_and_fit"] | Literal["fit"] | Literal["bin"] = "bin",
-        fit_func: Callable[..., NDArrayf]
-        | Literal["norder_polynomial"]
-        | Literal["nfreq_sumsin"] = "norder_polynomial",
+        fit_func: (
+            Callable[..., NDArrayf] | Literal["norder_polynomial"] | Literal["nfreq_sumsin"]
+        ) = "norder_polynomial",
         fit_optimizer: Callable[..., tuple[NDArrayf, Any]] = scipy.optimize.curve_fit,
         bin_sizes: int | dict[str, int | Iterable[float]] = 100,
         bin_statistic: Callable[[NDArrayf], np.floating[Any]] = np.nanmedian,
@@ -796,6 +836,20 @@ def _apply_rst(
         return super()._apply_rst(elev=elev, transform=transform, crs=crs, bias_vars=bias_vars, **kwargs)
 
 
+# TODO: move this function somewhere sensible
+def meshgrid(
+    _: NDArrayf | NDArrayb,
+    block_info: dict[Any, Any],
+    axis: Literal["x", "y"] = "x",
+) -> NDArrayf:
+    """A bit of a hack to create a meshgrid for a dask array."""
+    loc = block_info[0]["array-location"]
+    mesh = np.meshgrid(np.arange(*loc[1]), np.arange(*loc[0]))
+    if axis == "x":
+        return mesh[0]
+    return mesh[1]
+
+
 class Deramp(BiasCorr):
     """
     Correct for a 2D polynomial along X/Y coordinates, for example from residual camera model deformations
@@ -858,7 +912,11 @@ def _fit_rst_rst(  # type: ignore
         p0 = np.ones(shape=((self._meta["poly_order"] + 1) ** 2))
 
         # Coordinates (we don't need the actual ones, just array coordinates)
-        xx, yy = np.meshgrid(np.arange(0, ref_elev.shape[1]), np.arange(0, ref_elev.shape[0]))
+        if isinstance(ref_elev, da.Array):
+            xx = da.map_blocks(meshgrid, ref_elev, chunks=ref_elev.chunks, dtype=ref_elev.dtype)
+            yy = da.map_blocks(meshgrid, ref_elev, axis="y", chunks=ref_elev.chunks, dtype=ref_elev.dtype)
+        else:
+            xx, yy = np.meshgrid(np.arange(0, ref_elev.shape[1]), np.arange(0, ref_elev.shape[0]))
 
         self._fit_biascorr(
             ref_elev=ref_elev,
@@ -921,6 +979,10 @@ def _apply_rst(
     ) -> tuple[NDArrayf, rio.transform.Affine]:
 
         # Define the coordinates for applying the correction
-        xx, yy = np.meshgrid(np.arange(0, elev.shape[1]), np.arange(0, elev.shape[0]))
+        if isinstance(elev, da.Array):
+            xx = da.map_blocks(meshgrid, elev, chunks=elev.chunks, dtype=elev.dtype)
+            yy = da.map_blocks(meshgrid, elev, axis="y", chunks=elev.chunks, dtype=elev.dtype)
+        else:
+            xx, yy = np.meshgrid(np.arange(0, elev.shape[1]), np.arange(0, elev.shape[0]))
 
         return super()._apply_rst(elev=elev, transform=transform, crs=crs, bias_vars={"xx": xx, "yy": yy}, **kwargs)
diff --git a/xdem/examples.py b/xdem/examples.py
index f85a74b2..69fb1517 100644
--- a/xdem/examples.py
+++ b/xdem/examples.py
@@ -1,4 +1,5 @@
 """Utility functions to download and find example data."""
+
 import os
 import tarfile
 import tempfile
diff --git a/xdem/fit.py b/xdem/fit.py
index 58690c62..7a15cb1b 100644
--- a/xdem/fit.py
+++ b/xdem/fit.py
@@ -1,12 +1,14 @@
 """
 Functions to perform normal, weighted and robust fitting.
 """
+
 from __future__ import annotations
 
 import inspect
 import warnings
 from typing import Any, Callable
 
+import dask.array as da
 import numpy as np
 import scipy
 from geoutils.raster import subsample_array
@@ -65,6 +67,33 @@ def soft_loss(z: NDArrayf, scale: float = 0.5) -> float:
 ######################################################
 
 
+def fit_chunked(arrays: tuple[da.Array, ...], *params: NDArrayf, fit_func: Callable[..., NDArrayf]) -> da.Array:
+    """Call a fit function with dask arrays to run it chunk-wise on the underlying blocks.
+
+    :param arrays: Tuple of dask arrays.
+    :param params: Parameters passed through to the fit function.
+    :param fit_func: The fit function to call. Needs to have the signature f(x, *params).
+
+    :return: The fitted values as a lazy (delayed) dask array.
+    """
+
+    def fit_chunk(*arrays: da.Array, fit_func: Callable[..., NDArrayf], other_params: NDArrayf) -> NDArrayf:
+        return fit_func(arrays, *other_params).squeeze()
+
+    # When called from apply(), the input is a tuple of arrays
+    if not isinstance(arrays, tuple):
+        raise TypeError("Inputs to fit_chunked must be a tuple of arrays.")
+
+    # If no chunks are passed, map_blocks will use the chunks of the first input array
+    return da.map_blocks(
+        fit_chunk,
+        *arrays,
+        fit_func=fit_func,
+        other_params=params,
+        dtype=np.float32,
+    )
+
+
 def sumsin_1d(xx: NDArrayf, *params: NDArrayf) -> NDArrayf:
     """
     Sum of N sinusoids in 1D.