Skip to content

Commit 97fe2f6

Browse files
author
nicolasK
committed
perf: zonal stats
1 parent e66f940 commit 97fe2f6

File tree

2 files changed

+106
-168
lines changed

2 files changed

+106
-168
lines changed

earthdaily/earthdatastore/cube_utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from rasterio.enums import Resampling
88
from shapely.geometry import box
99
from .geometry_manager import GeometryManager
10-
from ._zonal import zonal_stats, zonal_stats_numpy
10+
from ._zonal import zonal_stats
1111
from .harmonizer import Harmonizer
1212
from .asset_mapper import AssetMapper
1313
import rioxarray

earthdaily/earthdatastore/cube_utils/_zonal.py

Lines changed: 105 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
import numpy as np
1010
import xarray as xr
1111
import tqdm
12-
12+
import logging
13+
import time
1314
from . import custom_reducers
1415
from .preprocessing import rasterize
1516
from scipy.sparse import csr_matrix
@@ -47,197 +48,134 @@ def _rasterize(gdf, dataset, all_touched=False):
4748
yx_pos = _indices_sparse(feats)
4849
return feats, yx_pos
4950

50-
51-
def zonal_stats_numpy(
52-
dataset,
51+
def _memory_time_chunks(dataset, memory=None):
    """
    Estimate how many dates can be processed per batch within a memory budget.

    Parameters
    ----------
    dataset : xr.Dataset
        Cube exposing ``nbytes`` and a ``time`` dimension.
    memory : float, optional
        Memory budget in Mo (1e6 bytes). When None, the memory currently
        reported as available by ``psutil`` is used. The default is None.

    Returns
    -------
    time_chunks : int
        Number of dates per batch, always at least 1.

    """
    n_dates = int(dataset.time.size)
    if n_dates == 0:
        # No date to batch: return the smallest valid chunk size instead of
        # dividing by zero below.
        return 1

    if memory is None:
        # Imported lazily so psutil is only needed when no budget is given.
        import psutil

        memory = psutil.virtual_memory().available / 1e6
    logging.debug(f"Hoping to use a maximum memory {memory}Mo.")

    # Size of one date in Mo, multiplied by 3 (presumably headroom for the
    # intermediate copies made while computing -- TODO confirm).
    nbytes_per_date = int(dataset.nbytes / 1e6) / n_dates * 3

    # Same count as np.arange(0, memory, step).size, without the allocation.
    # Clamped to 1 so a null or negative budget still yields a usable batch.
    max_time_chunks = max(1, int(np.ceil(memory / (nbytes_per_date + 0.1))))

    # Spread the dates evenly over the number of batches required.
    n_batches = int(np.ceil(n_dates / max_time_chunks))
    time_chunks = max(1, n_dates // n_batches)

    logging.debug(
        f"Mo per date : {nbytes_per_date:0.2f}, total : {(nbytes_per_date*dataset.time.size):0.2f}."
    )
    logging.debug(f"Time chunks : {time_chunks} (on {dataset.time.size} time).")
    return time_chunks
62+
63+
def _zonal_stats_numpy(
    dataset: xr.Dataset,
    positions,
    reducers: list = None,
    all_touched=False,
):
    """
    Compute per-feature statistics by fancy-indexing pixels with numpy.

    Parameters
    ----------
    dataset : xr.Dataset
        Cube with ``y`` and ``x`` dimensions. Both are consumed by the
        reduction; every other dimension is kept.
    positions : list
        One entry per feature. Each entry unpacks into the index arrays used
        on the last two axes (``y`` then ``x``) to select the feature pixels.
    reducers : list, optional
        Names of numpy reducers. The nan-aware variant is used when numpy has
        one ("mean" resolves to ``np.nanmean``), otherwise ``np.<name>``.
        The default is ["mean"].
    all_touched : bool, optional
        Not used by this function; kept so the signature stays unchanged.

    Returns
    -------
    zs
        Result of ``xr.apply_ufunc``, with ``y``/``x`` replaced by the
        trailing dimensions ``feature`` and ``zonal_statistics``.

    """
    if reducers is None:
        reducers = ["mean"]

    def _zonal_stats_ufunc(dataset, positions, reducers):
        # Resolve the reducer names once rather than for every feature.
        funcs = [
            getattr(np, f"nan{name}" if hasattr(np, f"nan{name}") else name)
            for name in reducers
        ]
        zs = []
        for pos in positions:
            # Explicit tuple index: same as dataset[..., *pos], but also valid
            # before Python 3.11. Pixels are extracted once per feature.
            pixels = dataset[(..., *pos)]
            zs.append([func(pixels, axis=-1) for func in funcs])
        zs = np.asarray(zs)
        # (feature, stat, *leading) -> (*leading, feature, stat), whatever the
        # number of leading (non-core) dimensions.
        return np.moveaxis(zs, (0, 1), (-2, -1))

    return xr.apply_ufunc(
        _zonal_stats_ufunc,
        dataset,
        vectorize=False,
        dask="parallelized",
        input_core_dims=[["y", "x"]],
        output_core_dims=[["feature", "zonal_statistics"]],
        exclude_dims=set(["x", "y"]),
        output_dtypes=[float],
        kwargs=dict(reducers=reducers, positions=positions),
        dask_gufunc_kwargs={
            "allow_rechunk": True,
            # Keys must match output_core_dims, otherwise the dask path
            # cannot size the new dimensions.
            "output_sizes": dict(
                feature=len(positions), zonal_statistics=len(reducers)
            ),
        },
    )
104+
105+
def zonal_stats(
    dataset: xr.Dataset,
    geoms,
    method: str = "numpy",
    smart_load: bool = False,
    memory: int = None,
    reducers: list = None,
    all_touched=True,
):
    """
    Xr Zonal stats using np.nan functions.

    Parameters
    ----------
    dataset : xr.Dataset
        Cube to summarise. It is clipped to the bounding box of ``geoms``
        before any computation.
    geoms : gpd.GeoDataFrame
        Geometries to compute the statistics on (reprojected to the dataset
        CRS for clipping).
    method : str
        "xvec" or "numpy". The default is "numpy".
    smart_load : bool
        Only for the "numpy" method: load in memory as many dates as the
        memory budget allows and loop over these batches.
        The default is False.
    memory : int, optional
        Only for the "numpy" method with ``smart_load``: memory budget in Mo
        used to size the batches. By default the available memory is used,
        but in some cases it can be too much or too little.
        The default is None.
    reducers : list, optional
        Any np.nan function ("mean" is "np.nanmean"). The default is ['mean'].
    all_touched : bool, optional
        Forwarded to the rasterization step. The default is True.

    Raises
    ------
    NotImplementedError
        If ``method`` is neither "numpy" nor "xvec".

    Returns
    -------
    zs
        Zonal statistics. For the "numpy" method the ``zonal_statistics``
        coordinate holds the reducer names.

    """
    if reducers is None:
        reducers = ["mean"]
    if method not in ("numpy", "xvec"):
        # Fail early instead of reaching the return with `zs` unbound.
        raise NotImplementedError('method available are : "numpy" or "xvec"')

    def _loop_time_chunks(dataset, positions, time_chunks):
        # Yield the zonal stats of successive batches of `time_chunks` dates.
        n_dates = dataset.time.size
        logging.debug(f"Batching every {time_chunks} dates ({np.ceil(n_dates/time_chunks).astype(int)} loops).")
        for time_idx in tqdm.trange(0, n_dates, time_chunks):
            # The slice is clipped to the axis length, so the last batch may
            # be shorter.
            ds = dataset.isel(time=slice(time_idx, time_idx + time_chunks))
            t0 = time.time()
            ds = ds.load()
            logging.debug(f'Subdataset of {ds.time.size} dates loaded in memory in {(time.time()-t0):0.2f}s.')
            t0 = time.time()
            zs = _zonal_stats_numpy(ds, positions, reducers)
            zs = zs.load()
            del ds  # release the loaded batch before the next one
            logging.debug(f'Zonal stats computed in {(time.time()-t0):0.2f}s.')
            yield zs

    t_start = time.time()
    dataset = dataset.rio.clip_box(*geoms.to_crs(dataset.rio.crs).total_bounds)

    if method == "numpy":
        feats, yx_pos = _rasterize(geoms, dataset, all_touched=all_touched)
        # Feature ids start at 1 in the rasterized grid. Ids beyond
        # len(yx_pos) have no entry at all, so they are skipped rather than
        # indexed.
        positions = [
            np.asarray(yx_pos[i + 1])
            for i in range(geoms.shape[0])
            if i + 1 < len(yx_pos)
        ]
        # NOTE: features without any pixel are dropped, so the `feature` axis
        # only indexes the remaining geometries.
        positions = [position for position in positions if position.size > 0]
        del feats, yx_pos

        if smart_load:
            # The memory budget is only needed when batching over time.
            time_chunks = _memory_time_chunks(dataset, memory)
            zs = xr.concat(
                list(_loop_time_chunks(dataset, positions, time_chunks)),
                dim="time",
            )
        else:
            zs = _zonal_stats_numpy(dataset, positions, reducers)
        zs = zs.assign_coords(zonal_statistics=reducers)
    else:
        # Imported for its side effect: the `.xvec` accessor used below.
        import xvec  # noqa: F401

        zs = dataset.xvec.zonal_stats(
            geoms.to_crs(dataset.rio.crs).geometry,
            y_coords="y",
            x_coords="x",
            stats=reducers,
            method="rasterize",
            all_touched=all_touched,
        )

    logging.info(f"Zonal stats method {method} tooks {time.time()-t_start}s.")
    return zs

0 commit comments

Comments
 (0)