Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dask for coregistration #525

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from

Conversation

ameliefroessl
Copy link

@ameliefroessl ameliefroessl commented May 21, 2024

Use dask to improve the memory usage of the BiasCorr methods.

Remaining todo item:

  • Map remaining logic from _postprocess_coreg_apply_rst() to _postprocess_coreg_apply_xarray
  • bias_vars mask that you mentioned Dask for coregistration #525 (comment)
  • Adapt other methods
    • Method ...
  • Generic method for map_blocks in the fit function
  • Create unit tests for the introduced functions
  • Fix unit tests
  • Save output to file that works out of memory
  • Update geoutils dependency to 0.1.6
  • Rebase to xDEM main
  • Add documentation on how to load the inputs such that the data is processed with dask. (rioxarray + chunks)

sigma = None
elif isinstance(diff, da.Array):
ydata = diff.vindex[subsample_mask].flatten().compute() # type:ignore [assignment]
xdata = [var.vindex[subsample_mask].flatten() for var in bias_vars.values()] # type:ignore [assignment]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you need a .compute() here as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, even if the bias_vars are not delayed (no operation since their creation/passing in the _fit_rst_rst) function, the compute seems needed, otherwise you get a dask array output that won't be understood by the following steps:

da.ones((10000, 10000)).vindex[[0, 1], [0, 1]]
Out[12]: dask.array<vindex-merge, shape=(2,), dtype=float64, chunksize=(2,), chunktype=numpy.ndarray>
da.ones((10000, 10000)).vindex[[0, 1], [0, 1]].compute()
Out[13]: array([1., 1.])

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you! nice catch :) - changed it!

@@ -1077,7 +1107,10 @@ def _fit_rst_rst( # type: ignore
p0 = np.ones(shape=((self._meta["poly_order"] + 1) ** 2))

# Coordinates (we don't need the actual ones, just array coordinates)
xx, yy = np.meshgrid(np.arange(0, ref_elev.shape[1]), np.arange(0, ref_elev.shape[0]))
if type(ref_elev) == da.Array:
xx, yy = da.meshgrid(da.arange(0, ref_elev.shape[1]), da.arange(0, ref_elev.shape[0]))
Copy link
Member

@rhugonnet rhugonnet May 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should define the chunks here to ensure the auto ones are not completely irrelevant and don't result in a lot of memory usage. Maybe simply copying the chunks of the reference DEM?

Copy link
Author

@ameliefroessl ameliefroessl May 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it seems like chunks is not a supported argument to da.meshgrid unfortunately. i tried rechunking it but it doesn't seem to make that big of a difference computation-wise 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strange indeed, it seems important... We can check if that matters later with a memory usage-specific test!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes i agree! I'll add a note so we remember to have a closer look at it.

Copy link
Author

@ameliefroessl ameliefroessl May 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok so i might have found a work around on how to calculate the meshgrid for a dask array.

Code
import dask.array as da
import numpy as np

data = np.array(
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16],
    ]
)

m_grid = np.meshgrid(
    np.arange(0, data.shape[0]),
    np.arange(0, data.shape[1]),
)


def meshgrid(_, axis="x", block_info=None):
    """A bit of a hack to create a meshgrid for a dask array."""
    loc = block_info[0]["array-location"]
    mesh = np.meshgrid(np.arange(*loc[1]), np.arange(*loc[0]))
    if axis == "x":
        return mesh[0]
    return mesh[1]


da_data = da.from_array(data, chunks=(2, 2))

m_grid_x = da.map_blocks(meshgrid, da_data, chunks=da_data.chunks, dtype=da_data.dtype).compute()
m_grid_y = da.map_blocks(meshgrid, da_data, axis="y", chunks=da_data.chunks, dtype=da_data.dtype).compute()


assert np.array_equal(m_grid[0], m_grid_x)
assert np.array_equal(m_grid[1], m_grid_y)

Copy link
Member

@rhugonnet rhugonnet May 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing!

I still found it weird that they didn't include a chunk argument directly in Dask, and found this interesting discussion:
dask/dask#2943 (comment) (and dask/dask#2943 (comment) is also important).

If I understood correctly, doing da.meshgrid(da.arange(, chunks=), da.arange(, chunks=) should have the same behaviour as chunking classically! (or even more optimized thanks to what they do in broadcast_to)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes! Adding the chunks to the input da.arange() makes it possible to chunk the output array. However, I didn't find a way to combine it with delayed. Meaning even if it's chunked I think it loads the entire array into memory. Instead of "creating"? / "reading" ? that part only when needed in the processing. But I'd love to be proven wrong 😃

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to combine it with delayed, because it's only sampled later by vindex (no in-between calculations on the chunks), and that works well out-of-memory without delaying!

import dask.array as da
xx, yy = da.meshgrid(da.arange(0, 20000, chunks=500), da.arange(0, 20000, chunks=500))
xx.vindex[[0, 1], [0, 1]]
Out[60]: array([0, 1])

Which has no memory usage despite being a ~5GB array! 😄

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know why exactly but when I use this version of meshgrid the entire application just hangs without doing anything... maybe i'll investigate further but currently I can't seem to figure out what's the issue 🤔

)
if all(isinstance(dem, (da.Array, Delayed)) for dem in (ref_elev, tba_elev)):
diff = da.subtract(ref_elev, tba_elev)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need to compute the valid_mask here for the delayed inputs to ensure we don't introduce outliers from the ref_elev or tba_elev, same as below:

            valid_mask = np.logical_and.reduce(
                (inlier_mask, np.isfinite(diff), *(np.isfinite(var) for var in bias_vars.values()))
            )

Not sure if this is supported directly, or if we need to use map_blocks for it...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess map_blocks the same way you used it further up would work perfectly :)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the reason I have not calculated this is because I calculate the inlier_mask from all of the inputs here . Taking also the non-valid values into account from the ref_elev and tba_elev. I think this would cover that? Or did I miss something?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's for the bias_vars (that can have different non-finite values that will propagate during the fit), not covered by _get_valid_data_mask!
And actually this will also be relevant in affine functions (for instance for slope/aspect in NuthKaab that propagate NaNs of the DEM)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahhh ok ok. I missed that thank you! 👍 ill make sure to add it 🙂

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this shouldn't have an effect on the Deramp algorithm though since the bias_vars is created from the input coordinates of the reference_elevation input correct? I'm just asking because I'm running into some unexpected outputs when I have nans in the original input data.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it shouldn't have any effect in this case, the meshgrid should be always a full array 🤔

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok yes that makes sense. I did end up figuring out what the issue was with the nan data. I was calculating the mask on the output from the delayed_reproject (here is the change). I'm not 100% sure, but my assumption is that there is some smoothing/ resampling happening and i guess the nodata values accidentally get smeared into neighboring pixels?... Meaning it was potentially sampling some very high values and the estimated function was pretty off. This was my assumption after looking at the input y values going into the fit function.

@@ -230,11 +251,24 @@ def _fit_biascorr( # type: ignore
"with function {}.".format(", ".join(list(bias_vars.keys())), self._meta["fit_func"].__name__)
)

if isinstance(diff, np.ndarray):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the data subsampling that can be moved out of the if loop "Option 1: Run fit...", then xdata/ydata/etc can be passed to every option similarly 🙂

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah you're right! that makes sense :) 👍

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is what you meant? 🙂

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which of the unit test would you expect to pass if passed an xarray instead of the normal input?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes exactly!
I guess for tests with only fit(), everything should pass exactly the same with just the Xarray input.
For tests that also have apply(), you'd have to write a check on the input type and call .compute() to be able to run the final checks on the output arrays.

Copy link
Member

@rhugonnet rhugonnet Jun 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I think you should be able to pass all tests by just adding an Xarray input in the current list: https://github.com/ameliefroessl/xdem/blob/1c57a389b43bb252eedbf7f92f4164d6b9e443cc/tests/test_coreg/test_biascorr.py#L97
And adding a compute() only for tests using the output of apply() (the ones just calling apply() without using the output shouldn't need to, I think?), like here: https://github.com/ameliefroessl/xdem/blob/1c57a389b43bb252eedbf7f92f4164d6b9e443cc/tests/test_coreg/test_biascorr.py#L520, and other synthetic tests further below.
Except if I missed something...

@@ -460,7 +494,12 @@ def _apply_rst( # type: ignore

# Apply function to get correction (including if binning was done before)
if self._fit_or_bin in ["fit", "bin_and_fit"]:
corr = self._meta["fit_func"](tuple(bias_vars.values()), *self._meta["fit_params"])
if isinstance(list(bias_vars.values())[0], da.Array):
corr = fit_chunked(
Copy link
Member

@rhugonnet rhugonnet Jun 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to do almost exactly the same map_blocks wrapping right below as you did forself._meta["fit_func"], this time the line corr = bin_interpolator(*args), and the line corr = get_perbin_nd_binning(*args)!

@rhugonnet rhugonnet changed the title Dask for XDEM Dask for coregistrations Oct 26, 2024
@rhugonnet rhugonnet changed the title Dask for coregistrations Dask for coregistration Oct 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants