Dask for coregistration #525
base: main
Conversation
xdem/coreg/biascorr.py
Outdated
    sigma = None
elif isinstance(diff, da.Array):
    ydata = diff.vindex[subsample_mask].flatten().compute()  # type:ignore [assignment]
    xdata = [var.vindex[subsample_mask].flatten() for var in bias_vars.values()]  # type:ignore [assignment]
Don't you need a .compute() here as well?
Yes, even if the bias_vars are not delayed (no operation since their creation/passing in the _fit_rst_rst function), the compute seems needed, otherwise you get a dask array output that won't be understood by the following steps:
da.ones((10000, 10000)).vindex[[0, 1], [0, 1]]
Out[12]: dask.array<vindex-merge, shape=(2,), dtype=float64, chunksize=(2,), chunktype=numpy.ndarray>
da.ones((10000, 10000)).vindex[[0, 1], [0, 1]].compute()
Out[13]: array([1., 1.])
thank you! nice catch :) - changed it!
xdem/coreg/biascorr.py
Outdated
@@ -1077,7 +1107,10 @@ def _fit_rst_rst(  # type: ignore
    p0 = np.ones(shape=((self._meta["poly_order"] + 1) ** 2))

    # Coordinates (we don't need the actual ones, just array coordinates)
    xx, yy = np.meshgrid(np.arange(0, ref_elev.shape[1]), np.arange(0, ref_elev.shape[0]))
    if type(ref_elev) == da.Array:
        xx, yy = da.meshgrid(da.arange(0, ref_elev.shape[1]), da.arange(0, ref_elev.shape[0]))
We should define the chunks here to ensure the auto ones are not completely irrelevant and don't result in a lot of memory usage. Maybe simply copying the chunks of the reference DEM?
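For illustration, something like this rough sketch (only an idea, assuming ref_elev is already a 2-D dask array whose chunks we'd mirror):
import dask.array as da

# Sketch: build the coordinate grids, then align their chunks with the
# reference DEM's chunks instead of keeping the automatic ones.
xx, yy = da.meshgrid(da.arange(0, ref_elev.shape[1]), da.arange(0, ref_elev.shape[0]))
xx = xx.rechunk(ref_elev.chunks)
yy = yy.rechunk(ref_elev.chunks)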
So it seems like chunks is not a supported argument to da.meshgrid, unfortunately. I tried rechunking it, but it doesn't seem to make that big of a difference computation-wise 🤔
Strange indeed, it seems important... We can check if that matters later with a memory usage-specific test!
Yes, I agree! I'll add a note so we remember to have a closer look at it.
OK, so I might have found a workaround for calculating the meshgrid for a dask array.
Code
import dask.array as da
import numpy as np

data = np.array(
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16],
    ]
)

# Reference result, computed eagerly with numpy.
m_grid = np.meshgrid(
    np.arange(0, data.shape[0]),
    np.arange(0, data.shape[1]),
)


def meshgrid(_, axis="x", block_info=None):
    """A bit of a hack to create a meshgrid for a dask array."""
    # block_info gives the location of the current block within the full array,
    # so each block can build its own piece of the meshgrid.
    loc = block_info[0]["array-location"]
    mesh = np.meshgrid(np.arange(*loc[1]), np.arange(*loc[0]))
    if axis == "x":
        return mesh[0]
    return mesh[1]


da_data = da.from_array(data, chunks=(2, 2))
m_grid_x = da.map_blocks(meshgrid, da_data, chunks=da_data.chunks, dtype=da_data.dtype).compute()
m_grid_y = da.map_blocks(meshgrid, da_data, axis="y", chunks=da_data.chunks, dtype=da_data.dtype).compute()

assert np.array_equal(m_grid[0], m_grid_x)
assert np.array_equal(m_grid[1], m_grid_y)
Amazing!
I still found it weird that they didn't include a chunk argument directly in Dask, and found this interesting discussion:
dask/dask#2943 (comment) (and dask/dask#2943 (comment) is also important).
If I understood correctly, doing da.meshgrid(da.arange(..., chunks=...), da.arange(..., chunks=...)) should have the same behaviour as chunking classically! (Or even be more optimized, thanks to what they do in broadcast_to.)
Yes! Adding the chunks to the input da.arange() makes it possible to chunk the output array. However, I didn't find a way to combine it with delayed, meaning that even if it's chunked, I think it loads the entire array into memory instead of "creating"/"reading" that part only when needed in the processing. But I'd love to be proven wrong 😃
I don't think we need to combine it with delayed, because it's only sampled later by vindex (no in-between calculations on the chunks), and that works well out-of-memory without delaying!
import dask.array as da
xx, yy = da.meshgrid(da.arange(0, 20000, chunks=500), da.arange(0, 20000, chunks=500))
xx.vindex[[0, 1], [0, 1]]
Out[60]: array([0, 1])
Which has no memory usage despite being a ~5GB array! 😄
I don't know why exactly, but when I use this version of meshgrid the entire application just hangs without doing anything... Maybe I'll investigate further, but currently I can't seem to figure out what the issue is 🤔
    )
if all(isinstance(dem, (da.Array, Delayed)) for dem in (ref_elev, tba_elev)):
    diff = da.subtract(ref_elev, tba_elev)
We also need to compute the valid_mask here for the delayed inputs to ensure we don't introduce outliers from the ref_elev or tba_elev, same as below:
valid_mask = np.logical_and.reduce(
    (inlier_mask, np.isfinite(diff), *(np.isfinite(var) for var in bias_vars.values()))
)
Not sure if this is supported directly, or if we need to use map_blocks for it...
I guess map_blocks the same way you used it further up would work perfectly :)
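For instance, a rough sketch of what I mean (assuming inlier_mask, diff and the bias_vars values are all dask arrays with matching chunks; _block_valid_mask is just a hypothetical helper name):
import dask.array as da
import numpy as np

def _block_valid_mask(inlier_block, diff_block, *bias_blocks):
    # Per-chunk version of the numpy expression above.
    return np.logical_and.reduce(
        (inlier_block, np.isfinite(diff_block), *(np.isfinite(b) for b in bias_blocks))
    )

valid_mask = da.map_blocks(
    _block_valid_mask, inlier_mask, diff, *bias_vars.values(), dtype=bool
)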
So the reason I have not calculated this is because I calculate the inlier_mask from all of the inputs here, taking the non-valid values from the ref_elev and tba_elev into account as well. I think this would cover that? Or did I miss something?
It's for the bias_vars (which can have different non-finite values that will propagate during the fit), not covered by _get_valid_data_mask!
And actually this will also be relevant in affine functions (for instance for slope/aspect in NuthKaab, which propagate NaNs of the DEM).
Ahh, OK, I missed that, thank you! 👍 I'll make sure to add it 🙂
This shouldn't have an effect on the Deramp algorithm though, since the bias_vars are created from the input coordinates of the reference_elevation input, correct? I'm just asking because I'm running into some unexpected outputs when I have NaNs in the original input data.
Yes, it shouldn't have any effect in this case, the meshgrid should always be a full array 🤔
OK, yes, that makes sense. I did end up figuring out what the issue was with the NaN data: I was calculating the mask on the output from the delayed_reproject (here is the change). I'm not 100% sure, but my assumption is that there is some smoothing/resampling happening, and I guess the nodata values accidentally get smeared into neighboring pixels?... Meaning it was potentially sampling some very high values, and the estimated function was pretty off. This was my assumption after looking at the input y values going into the fit function.
xdem/coreg/biascorr.py
Outdated
@@ -230,11 +251,24 @@ def _fit_biascorr(  # type: ignore
        "with function {}.".format(", ".join(list(bias_vars.keys())), self._meta["fit_func"].__name__)
    )

    if isinstance(diff, np.ndarray):
Here is the data subsampling that can be moved out of the if block "Option 1: Run fit...", then xdata/ydata/etc. can be passed to every option similarly 🙂
Ah you're right! that makes sense :) 👍
I guess this is what you meant? 🙂
Which of the unit tests would you expect to pass if passed an Xarray instead of the normal input?
Yes exactly!
I guess for tests with only fit(), everything should pass exactly the same with just the Xarray input.
For tests that also have apply(), you'd have to write a check on the input type and call .compute() to be able to run the final checks on the output arrays.
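Something along these lines, I guess (just a sketch; applied stands in for whatever the apply() call returns in the test):
import dask.array as da

# Sketch: only materialize the output when it is lazy, then keep the existing checks.
if isinstance(applied, da.Array):
    applied = applied.compute()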
So I think you should be able to pass all tests by just adding an Xarray input in the current list: https://github.com/ameliefroessl/xdem/blob/1c57a389b43bb252eedbf7f92f4164d6b9e443cc/tests/test_coreg/test_biascorr.py#L97
And adding a compute() only for tests using the output of apply() (the ones just calling apply() without using the output shouldn't need to, I think?), like here: https://github.com/ameliefroessl/xdem/blob/1c57a389b43bb252eedbf7f92f4164d6b9e443cc/tests/test_coreg/test_biascorr.py#L520, and other synthetic tests further below.
Except if I missed something...
@@ -460,7 +494,12 @@ def _apply_rst(  # type: ignore

    # Apply function to get correction (including if binning was done before)
    if self._fit_or_bin in ["fit", "bin_and_fit"]:
        corr = self._meta["fit_func"](tuple(bias_vars.values()), *self._meta["fit_params"])
        if isinstance(list(bias_vars.values())[0], da.Array):
            corr = fit_chunked(
You should be able to do almost exactly the same map_blocks wrapping right below as you did for self._meta["fit_func"], this time for the line corr = bin_interpolator(*args), and the line corr = get_perbin_nd_binning(*args)!
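Roughly like this, maybe (only a sketch; I'm assuming here that bin_interpolator is evaluated elementwise on the bias variables, the same way as the fit_func branch):
import dask.array as da

first_var = list(bias_vars.values())[0]
if isinstance(first_var, da.Array):
    # Sketch: evaluate the interpolator per chunk, mirroring the fit_func wrapping.
    corr = da.map_blocks(
        lambda *blocks: bin_interpolator(tuple(blocks)),
        *bias_vars.values(),
        dtype=first_var.dtype,
    )
else:
    corr = bin_interpolator(*args)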
Use dask to improve the memory usage of the BiasCorr methods.

Remaining todo items:
- _postprocess_coreg_apply_rst() to _postprocess_coreg_apply_xarray
- bias_vars mask that you mentioned: Dask for coregistration #525 (comment)
- rioxarray + chunks