|
9 | 9 | import numpy as np
|
10 | 10 | import xarray as xr
|
11 | 11 | import tqdm
|
12 |
| - |
| 12 | +import logging |
| 13 | +import time |
13 | 14 | from . import custom_reducers
|
14 | 15 | from .preprocessing import rasterize
|
15 | 16 | from scipy.sparse import csr_matrix
|
@@ -47,197 +48,134 @@ def _rasterize(gdf, dataset, all_touched=False):
|
47 | 48 | yx_pos = _indices_sparse(feats)
|
48 | 49 | return feats, yx_pos
|
49 | 50 |
|
50 |
| - |
51 |
| -def zonal_stats_numpy( |
52 |
| - dataset, |
| 51 | +def _memory_time_chunks(dataset, memory=None): |
| 52 | + import psutil |
| 53 | + if memory is None: |
| 54 | + memory = psutil.virtual_memory().available/1e6 |
| 55 | + logging.debug(f"Hoping to use a maximum memory {memory}Mo.") |
| 56 | + nbytes_per_date = int(dataset.nbytes/1e6)/dataset.time.size*3 |
| 57 | + max_time_chunks = int(np.arange(0,memory,nbytes_per_date+0.1).size) |
| 58 | + time_chunks = int(dataset.time.size/np.arange(0,dataset.time.size,max_time_chunks).size) |
| 59 | + logging.debug(f"Mo per date : {nbytes_per_date:0.2f}, total : {(nbytes_per_date*dataset.time.size):0.2f}.") |
| 60 | + logging.debug(f"Time chunks : {time_chunks} (on {dataset.time.size} time).") |
| 61 | + return time_chunks |
| 62 | + |
| 63 | +def _zonal_stats_numpy( |
| 64 | + dataset:xr.Dataset, |
53 | 65 | positions,
|
54 |
| - gdf, |
55 |
| - operations=dict(mean=np.nanmean), |
56 |
| - all_touched=False, |
57 |
| -): |
58 |
| - def _get_field_dataset(positions, dataset): |
59 |
| - for idx, pos in enumerate(positions): |
60 |
| - if pos.size == 0: |
61 |
| - continue |
62 |
| - pos_xr = dict( |
63 |
| - x=xr.DataArray(pos[1], dims="z"), y=xr.DataArray(pos[0], dims="z") |
64 |
| - ) |
65 |
| - yield idx, dataset.isel(**pos_xr) |
66 |
| - |
67 |
| - def _zonal_stats_from_field(dc_field, operations, idx): |
68 |
| - return xr.concat( |
69 |
| - [ |
70 |
| - getattr(dc_field, reducer)("z").expand_dims( |
71 |
| - feature=[idx], zonal_stats=[reducer] |
72 |
| - ) |
73 |
| - for reducer in operations.keys() |
74 |
| - ], |
75 |
| - dim="zonal_stats", |
76 |
| - ) |
77 |
| - |
78 |
| - def compute_zonal_stats_apply_ufunc(dataset, positions, reducers): |
| 66 | + reducers:list=['mean'], |
| 67 | + all_touched=False): |
| 68 | + |
| 69 | + def _zonal_stats_ufunc(dataset, positions, reducers): |
79 | 70 | zs = []
|
80 | 71 | for idx in range(len(positions)):
|
81 | 72 | field_stats = []
|
82 | 73 | for reducer in reducers:
|
83 | 74 | field_arr = dataset[..., *positions[idx]]
|
84 |
| - field_arr = reducer(field_arr, axis=-1) |
| 75 | + func = f'nan{reducer}' if hasattr(np,f"nan{reducer}") else reducer |
| 76 | + field_arr = getattr(np, func)(field_arr, axis=-1) |
85 | 77 | field_stats.append(field_arr)
|
86 | 78 | field_stats = np.asarray(field_stats)
|
87 | 79 | zs.append(field_stats)
|
88 | 80 | zs = np.asarray(zs)
|
89 |
| - zs = zs.swapaxes(-2, 0) |
| 81 | + zs = zs.swapaxes(-1, 0).swapaxes(-1,-2) |
90 | 82 | return zs
|
91 |
| - |
92 |
| - result = xr.apply_ufunc( |
93 |
| - compute_zonal_stats_apply_ufunc, |
94 |
| - dataset.to_dataarray(dim="band"), |
| 83 | + |
| 84 | + |
| 85 | + dask_ufunc = "parallelized" |
| 86 | + |
| 87 | + zs = xr.apply_ufunc( |
| 88 | + _zonal_stats_ufunc, |
| 89 | + dataset, |
95 | 90 | vectorize=False,
|
96 |
| - dask="forbidden", |
97 |
| - input_core_dims=[["band", "y", "x"]], |
98 |
| - output_core_dims=[["zonal_stats", "feature", "band"]], |
| 91 | + dask=dask_ufunc, |
| 92 | + input_core_dims=[["y","x"]], |
| 93 | + output_core_dims=[["feature", "zonal_statistics"]], |
99 | 94 | exclude_dims=set(["x", "y"]),
|
100 | 95 | output_dtypes=[float],
|
101 |
| - output_sizes=dict(feature=len(positions), zonal_stats=len(operations.values())), |
102 |
| - kwargs=dict(reducers=operations.values(), positions=positions), |
| 96 | + kwargs=dict(reducers=reducers, positions=positions), |
| 97 | + dask_gufunc_kwargs={"allow_rechunk":True, |
| 98 | + "output_sizes":dict(geometry=len(positions), zonal_statistics=len(reducers))} |
103 | 99 | )
|
| 100 | + |
104 | 101 | del dataset
|
105 |
| - return result.to_dataset(dim="band") |
106 |
| - # zs = [] |
107 |
| - # for idx, dc_field in tqdm.tqdm(_get_field_dataset(positions, dataset),total=gdf.shape[0], mininterval=1, desc="Zonal stats"): |
108 |
| - # zs.append(_zonal_stats_from_field(dc_field, operations, idx)) |
109 |
| - # zs = xr.concat(zs, dim='feature') |
110 |
| - # return zs.transpose("feature", "time", "zonal_stats") |
111 |
| - |
112 |
| - |
113 |
| -def zonal_stats( |
114 |
| - dataset, |
115 |
| - gdf, |
116 |
| - operations: list = ["mean"], |
117 |
| - all_touched=False, |
118 |
| - method="geocube", |
119 |
| - verbose=False, |
120 |
| - raise_missing_geometry=False, |
121 |
| -): |
| 102 | + |
| 103 | + return zs |
| 104 | + |
| 105 | +def zonal_stats(dataset:xr.Dataset, |
| 106 | + geoms, |
| 107 | + method:str="numpy", |
| 108 | + smart_load:bool=False, |
| 109 | + memory:int = None, |
| 110 | + reducers:list=['mean'], |
| 111 | + all_touched = True): |
122 | 112 | """
|
123 |
| -
|
| 113 | + Xr Zonal stats using np.nan functions. |
124 | 114 |
|
125 | 115 | Parameters
|
126 | 116 | ----------
|
127 | 117 | dataset : xr.Dataset
|
128 | 118 | DESCRIPTION.
|
129 |
| - gdf : gpd.GeoDataFrame |
| 119 | + geoms : TYPE |
130 | 120 | DESCRIPTION.
|
131 |
| - operations : TYPE, list. |
132 |
| - DESCRIPTION. The default is ["mean"]. |
133 |
| - all_touched : TYPE, optional |
134 |
| - DESCRIPTION. The default is False. |
135 |
| - method : TYPE, optional |
136 |
| - DESCRIPTION. The default is "geocube". |
137 |
| - verbose : TYPE, optional |
138 |
| - DESCRIPTION. The default is False. |
139 |
| - raise_missing_geometry : TYPE, optional |
140 |
| - DESCRIPTION. The default is False. |
141 |
| -
|
142 |
| - Raises |
| 121 | + method : str |
| 122 | + "xvec" or "numpy". The default is "numpy". |
| 123 | + smart_load : bool |
| 124 | + Will load in memory the maximum of time and loop on it for "numpy" |
| 125 | + method. The default is False. |
| 126 | + memory : int, optional |
| 127 | + Only for the "numpy" method, by default it will take the maximum memory |
| 128 | + available. But in some cases it can be too much or too little. |
| 129 | + The default is None. |
| 130 | + reducers : list, optional |
| 131 | + Any np.nan function ("mean" is "np.nanmean"). The default is ['mean']. |
| 132 | +
|
| 133 | + Yields |
143 | 134 | ------
|
144 |
| - ValueError |
145 |
| - DESCRIPTION. |
146 |
| - NotImplementedError |
147 |
| - DESCRIPTION. |
148 |
| -
|
149 |
| - Returns |
150 |
| - ------- |
151 |
| - TYPE |
| 135 | + zs : TYPE |
152 | 136 | DESCRIPTION.
|
153 | 137 |
|
154 | 138 | """
|
155 |
| - |
156 |
| - if method == "geocube": |
157 |
| - from geocube.api.core import make_geocube |
158 |
| - from geocube.rasterize import rasterize_image |
159 |
| - |
160 |
| - def custom_rasterize_image(all_touched=all_touched, **kwargs): |
161 |
| - return rasterize_image(all_touched=all_touched, **kwargs) |
162 |
| - |
163 |
| - gdf["tmp_index"] = np.arange(gdf.shape[0]) |
164 |
| - out_grid = make_geocube( |
165 |
| - gdf, |
166 |
| - measurements=["tmp_index"], |
167 |
| - like=dataset, # ensure the data are on the same grid |
168 |
| - rasterize_function=custom_rasterize_image, |
169 |
| - ) |
170 |
| - cube = dataset.groupby(out_grid.tmp_index) |
171 |
| - zonal_stats = xr.concat( |
172 |
| - [getattr(cube, operation)() for operation in operations], dim="stats" |
173 |
| - ) |
174 |
| - zonal_stats["stats"] = operations |
175 |
| - |
176 |
| - if zonal_stats["tmp_index"].size != gdf.shape[0]: |
177 |
| - index_list = [ |
178 |
| - gdf.index[i] for i in zonal_stats["tmp_index"].values.astype(np.int16) |
179 |
| - ] |
180 |
| - if raise_missing_geometry: |
181 |
| - diff = gdf.shape[0] - len(index_list) |
182 |
| - raise ValueError( |
183 |
| - f'{diff} geometr{"y is" if diff==1 else "ies are"} missing in the zonal stats. This can be due to too small geometries, duplicated...' |
184 |
| - ) |
| 139 | + |
| 140 | + def _loop_time_chunks(dataset, method, smart_load, time_chunks): |
| 141 | + logging.debug(f"Batching every {time_chunks} dates ({np.ceil(dataset.time.size/time_chunks).astype(int)} loops).") |
| 142 | + for time_idx in tqdm.trange(0,dataset.time.size,time_chunks): |
| 143 | + isel_time = np.arange(time_idx,np.min((time_idx+time_chunks,dataset.time.size))) |
| 144 | + ds = dataset.copy().isel(time=isel_time) |
| 145 | + if smart_load: |
| 146 | + t0 = time.time() |
| 147 | + ds = ds.load() |
| 148 | + logging.debug(f'Subdataset of {ds.time.size} dates loaded in memory in {(time.time()-t0):0.2f}s.') |
| 149 | + t0 = time.time() |
| 150 | + # for method in tqdm.tqdm(["np"]): |
| 151 | + zs = _zonal_stats_numpy(ds, |
| 152 | + positions, |
| 153 | + reducers) |
| 154 | + zs = zs.load() |
| 155 | + del ds |
| 156 | + logging.debug(f'Zonal stats computed in {(time.time()-t0):0.2f}s.') |
| 157 | + yield zs |
| 158 | + |
| 159 | + t_start = time.time() |
| 160 | + dataset = dataset.rio.clip_box(*geoms.to_crs(dataset.rio.crs).total_bounds) |
| 161 | + if method == 'numpy': |
| 162 | + feats, yx_pos = _rasterize(geoms, dataset, all_touched=all_touched) |
| 163 | + positions = [np.asarray(yx_pos[i + 1]) for i in np.arange(geoms.shape[0])] |
| 164 | + positions = [position for position in positions if position.size>0] |
| 165 | + del feats,yx_pos |
| 166 | + time_chunks = _memory_time_chunks(dataset, memory) |
| 167 | + if smart_load: |
| 168 | + zs = xr.concat([z for z in _loop_time_chunks(dataset, method, smart_load, time_chunks)], dim="time") |
185 | 169 | else:
|
186 |
| - index_list = list(gdf.index) |
187 |
| - zonal_stats["tmp_index"] = index_list |
188 |
| - return zonal_stats.rename(dict(tmp_index="feature")) |
189 |
| - |
190 |
| - tqdm_bar = tqdm.tqdm(total=gdf.shape[0]) |
191 |
| - |
192 |
| - if dataset.rio.crs != gdf.crs: |
193 |
| - Warning( |
194 |
| - f"Different projections. Reproject vector to EPSG:{dataset.rio.crs.to_epsg()}." |
195 |
| - ) |
196 |
| - gdf = gdf.to_crs(dataset.rio.crs) |
197 |
| - |
198 |
| - zonal_ds_list = [] |
199 |
| - |
200 |
| - dataset = dataset.rio.clip_box(*gdf.to_crs(dataset.rio.crs).total_bounds) |
201 |
| - |
202 |
| - if method == "optimized": |
203 |
| - feats, yx_pos = _rasterize(gdf, dataset, all_touched=all_touched) |
204 |
| - |
205 |
| - for gdf_idx in tqdm.trange(gdf.shape[0], disable=not verbose): |
206 |
| - tqdm_bar.update(1) |
207 |
| - if gdf_idx + 1 >= len(yx_pos): |
208 |
| - continue |
209 |
| - yx_pos_idx = yx_pos[gdf_idx + 1] |
210 |
| - if np.asarray(yx_pos_idx).size == 0: |
211 |
| - continue |
212 |
| - datacube_spatial_subset = dataset.isel( |
213 |
| - x=xr.DataArray(yx_pos_idx[1], dims="xy"), |
214 |
| - y=xr.DataArray(yx_pos_idx[0], dims="xy"), |
215 |
| - ) |
216 |
| - del yx_pos_idx |
217 |
| - zonal_ds_list.append( |
218 |
| - datacube_time_stats(datacube_spatial_subset, operations).expand_dims( |
219 |
| - dim={"feature": [gdf.iloc[gdf_idx].name]} |
220 |
| - ) |
221 |
| - ) |
222 |
| - |
223 |
| - del yx_pos, feats |
224 |
| - |
225 |
| - elif method == "standard": |
226 |
| - for idx_gdb, feat in tqdm.tqdm( |
227 |
| - gdf.iterrows(), total=gdf.shape[0], disable=not verbose |
228 |
| - ): |
229 |
| - tqdm_bar.update(1) |
230 |
| - if feat.geometry.geom_type == "MultiPolygon": |
231 |
| - shapes = feat.geometry.geoms |
232 |
| - else: |
233 |
| - shapes = [feat.geometry] |
234 |
| - datacube_spatial_subset = dataset.rio.clip(shapes, all_touched=all_touched) |
235 |
| - |
236 |
| - zonal_feat = datacube_time_stats( |
237 |
| - datacube_spatial_subset, operations |
238 |
| - ).expand_dims(dim={"feature": [feat.name]}) |
239 |
| - |
240 |
| - zonal_ds_list.append(zonal_feat) |
241 |
| - else: |
242 |
| - raise NotImplementedError('method available are : "standard" or "optimized"') |
243 |
| - return xr.concat(zonal_ds_list, dim="feature") |
| 170 | + zs = _zonal_stats_numpy(dataset, |
| 171 | + positions, |
| 172 | + reducers) |
| 173 | + zs = zs.assign_coords(zonal_statistics=reducers)#,feature=geoms.to_crs('EPSG:4326').geometry) |
| 174 | + |
| 175 | + if method == "xvec": |
| 176 | + import xvec |
| 177 | + zs = dataset.xvec.zonal_stats(geoms.to_crs(dataset.rio.crs).geometry, y_coords='y',x_coords='x', stats=reducers, |
| 178 | + method="rasterize", all_touched=all_touched) |
| 179 | + logging.info(f"Zonal stats method {method} tooks {time.time()-t_start}s.") |
| 180 | + del dataset |
| 181 | + return zs |
0 commit comments