diff --git a/.github/workflows/pytest-dev.yaml b/.github/workflows/pytest-dev.yaml new file mode 100644 index 00000000..29d7314b --- /dev/null +++ b/.github/workflows/pytest-dev.yaml @@ -0,0 +1,43 @@ +name: pytest-dev +on: + push: + branches: + - dev +jobs: + pytest: + name: pytest + strategy: + matrix: + os: + - ubuntu-latest + python-version: + - "3.10" + - "3.11" + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Make conda environment + uses: mamba-org/setup-micromamba@v1 + with: + micromamba-version: 1.5.1-0 + environment-file: requirements_dev.yml + environment-name: earthdaily + create-args: python=${{ matrix.python-version }} + cache-environment: true + post-cleanup: all + init-shell: >- + bash + - name: Installing the library + shell: bash -l {0} + run: | + pip install . + - name: Run tests and coverage + shell: bash -l {0} + env: + EDS_CLIENT_ID: ${{ secrets.EDS_CLIENT_ID }} + EDS_AUTH_URL: ${{ secrets.EDS_AUTH_URL }} + EDS_SECRET: ${{ secrets.EDS_SECRET }} + EDS_API_URL: ${{ secrets.EDS_API_URL }} + run: pytest -v tests --cov=earthdaily diff --git a/.github/workflows/pytest-prod.yaml b/.github/workflows/pytest-prod.yaml new file mode 100644 index 00000000..75062ed0 --- /dev/null +++ b/.github/workflows/pytest-prod.yaml @@ -0,0 +1,45 @@ +name: pytest-main +on: + push: + branches: + - main +jobs: + pytest: + name: pytest + strategy: + matrix: + os: + - ubuntu-latest + - macos-latest + - windows-latest + python-version: + - "3.10" + - "3.11" + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Make conda environment + uses: mamba-org/setup-micromamba@v1 + with: + micromamba-version: 1.5.1-0 + environment-file: requirements_dev.yml + environment-name: earthdaily + create-args: python=${{ matrix.python-version }} + cache-environment: true + post-cleanup: all + init-shell: >- + bash + - name: Installing the library + shell: bash -l {0} + run: | + pip install . + - name: Run tests and coverage + shell: bash -l {0} + env: + EDS_CLIENT_ID: ${{ secrets.EDS_CLIENT_ID }} + EDS_AUTH_URL: ${{ secrets.EDS_AUTH_URL }} + EDS_SECRET: ${{ secrets.EDS_SECRET }} + EDS_API_URL: ${{ secrets.EDS_API_URL }} + run: pytest -v tests --cov=earthdaily diff --git a/.github/workflows/ruff.yaml b/.github/workflows/ruff.yaml new file mode 100644 index 00000000..7db32d96 --- /dev/null +++ b/.github/workflows/ruff.yaml @@ -0,0 +1,26 @@ +name: ruff + +on: + push: + branches: + - dev +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: chartboost/ruff-action@v1 + with: + src: "./earthdaily" + args: --ignore F401 + - uses: chartboost/ruff-action@v1 + with: + src: "./earthdaily" + args: --ignore F401 + - uses: chartboost/ruff-action@v1 + with: + src: "./earthdaily" + args: format + - uses: stefanzweifel/git-auto-commit-action@v5 + with: + commit_message: 'gh-actions/python format using ruff' \ No newline at end of file diff --git a/.gitignore b/.gitignore index 22971554..ee71bb7a 100644 --- a/.gitignore +++ b/.gitignore @@ -64,4 +64,5 @@ docs/_* target/ # Auth -.env \ No newline at end of file +.env +.vscode/ diff --git a/CHANGELOG.md b/CHANGELOG.md index a75484d3..ca975835 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,23 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
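Note: the two pytest workflows above feed credentials to the test suite purely through environment variables built from repository secrets. A minimal sketch of how those same variables are consumed when running the suite locally (the variable names come from the workflows; loading them from a `.env` file is optional):

```python
import os

from earthdaily import earthdatastore

# The CI jobs export these from repository secrets; locally they can be
# exported in the shell or loaded from a .env file (e.g. with python-dotenv).
required = ("EDS_AUTH_URL", "EDS_SECRET", "EDS_CLIENT_ID")
missing = [name for name in required if not os.getenv(name)]
if missing:
    raise RuntimeError(f"Missing EDS credentials: {', '.join(missing)}")

eds = earthdatastore.Auth()  # picks the EDS_* variables up from the environment
```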
-## [0.0.1-rc4] +## [0.0.1-rc5] + +### Added + +- `common_band_names` default set to True in datacube creation. It uses the new Asset Mapper to select the bands best suited to the user's needs. +- `clear_cover` argument in the datacube method when using the `earthdatastore.Auth` method. + +### Changed + +- Mask statistics are no longer suffixed with the cloudmask type: the coordinates are now simply `clear_percent` and `clear_pixels`. The old suffixed names raise a DeprecationWarning. +- All queries in `post_query` must return True to keep an item. If a key doesn't exist, the result is considered False (instead of failing). + +### Fixed + +- Search `post_query` no longer fails when some properties are unavailable on some items. + +## [0.0.1-rc4] 2023-10-19 ### Changed diff --git a/README.md b/README.md index cd18709f..36814819 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Earthdaily Python Package [![PyPI version](https://badge.fury.io/py/earthdaily.png)](https://badge.fury.io/py/earthdaily) [![Documentation](https://img.shields.io/badge/Documentation-html-green.svg)](https://geosys.github.io/earthdaily-python-client/) +[![pytest](https://github.com/GEOSYS/earthdaily-python-client/actions/workflows/pytest.yaml/badge.svg?branch=main)](https://github.com/GEOSYS/earthdaily-python-client/actions/workflows/pytest.yaml) ## Your Gateway to the Stac Catalog Earth Data Store diff --git a/earthdaily/__init__.py b/earthdaily/__init__.py index 94d8f0bb..c1da8f75 100644 --- a/earthdaily/__init__.py +++ b/earthdaily/__init__.py @@ -1,3 +1,3 @@ from earthdaily import earthdatastore -__version__ = "0.0.1-rc4" +__version__ = "0.0.1-rc5" diff --git a/earthdaily/earthdatastore/__init__.py b/earthdaily/earthdatastore/__init__.py index 8dacb6dd..f9ec0e05 100644 --- a/earthdaily/earthdatastore/__init__.py +++ b/earthdaily/earthdatastore/__init__.py @@ -1,16 +1,19 @@ -import xarray as xr import json -from pystac_client import Client -from pystac.item_collection import ItemCollection -import requests -import pandas as pd -import geopandas as gpd -import os -import operator -from earthdaily.earthdatastore import mask, _scales_collections -from earthdaily.earthdatastore.cube_utils import datacube, metacube import logging +import operator +import os +import warnings +import geopandas as gpd +import pandas as pd +import requests +import xarray as xr +from pystac.item_collection import ItemCollection +from pystac_client import Client + +from earthdaily.earthdatastore import _scales_collections, mask +from earthdaily.earthdatastore import cube_utils +from earthdaily.earthdatastore.cube_utils import datacube, metacube logging.getLogger("earthdaily-earthdatastore") @@ -18,27 +21,26 @@ def post_query_items(items, query): items_ = [] for idx, item in enumerate(items): - item_already_checked = False - while not item_already_checked: - for k, v in query.items(): - for v_op, v_val in v.items(): - if isinstance(v_val, list): - for v_val_ in v_val: - operation = operator.__dict__[v_op]( - item.properties[k], v_val_ - ) - - if operation: - items_.append(item) - item_already_checked = True - else: - operation = operator.__dict__[v_op]( - item.properties[k], v_val - ) + queries_results = 0 + for k, v in query.items(): + if k not in item.properties.keys(): + continue + for v_op, v_val in v.items(): + if isinstance(v_val, list): + results = 0 + for v_val_ in v_val: + operation = operator.__dict__[v_op](item.properties[k], v_val_) + + if operation: + results += 1 + if results == len(v_val): + queries_results += 1 + else: + operation =
operator.__dict__[v_op](item.properties[k], v_val) if operation: - items_.append(item) - item_already_checked = True - break + queries_results += 1 + if queries_results == len(query.keys()): + items_.append(item) items = ItemCollection(items_) return items @@ -63,7 +65,10 @@ def _cloud_path_to_http(cloud_path): def enhance_assets( - items, alternate="s3", use_http_url=False, add_default_scale_factor=False + items, + alternate="download", + use_http_url=False, + add_default_scale_factor=False, ): if any((alternate, use_http_url, add_default_scale_factor)): for idx, item in enumerate(items): @@ -83,9 +88,7 @@ def enhance_assets( if use_http_url: href = item.assets[asset].to_dict().get("href", {}) if href: - items[idx].assets[asset].href = _cloud_path_to_http( - href - ) + items[idx].assets[asset].href = _cloud_path_to_http(href) if add_default_scale_factor: scale_factor_collection = ( _scales_collections.scale_factor_collections.get( @@ -95,8 +98,8 @@ def enhance_assets( for scales_collection in scale_factor_collection: if asset in scales_collection.get("assets", []): if ( - not "raster:bands" - in items[idx].assets[asset].extra_fields + "raster:bands" + not in items[idx].assets[asset].extra_fields ): items[idx].assets[asset].extra_fields[ "raster:bands" @@ -107,15 +110,15 @@ def enhance_assets( .extra_fields["raster:bands"][0] .get("scale") ): - items[idx].assets[asset].extra_fields[ - "raster:bands" - ][0]["scale"] = scales_collection["scale"] - items[idx].assets[asset].extra_fields[ - "raster:bands" - ][0]["offset"] = scales_collection["offset"] - items[idx].assets[asset].extra_fields[ - "raster:bands" - ][0]["nodata"] = scales_collection["nodata"] + items[idx].assets[asset].extra_fields["raster:bands"][ + 0 + ]["scale"] = scales_collection["scale"] + items[idx].assets[asset].extra_fields["raster:bands"][ + 0 + ]["offset"] = scales_collection["offset"] + items[idx].assets[asset].extra_fields["raster:bands"][ + 0 + ]["nodata"] = scales_collection["nodata"] return items @@ -130,9 +133,7 @@ def _get_client(config=None): auth_url = config("EDS_AUTH_URL") secret = config("EDS_SECRET") client_id = config("EDS_CLIENT_ID") - eds_url = config( - "EDS_API_URL", "https://api.eds.earthdaily.com/archive/v1/stac/v1" - ) + eds_url = config("EDS_API_URL", "https://api.eds.earthdaily.com/archive/v1/stac/v1") if auth_url is None or secret is None or client_id is None: raise AttributeError( "You need to have env : EDS_AUTH_URL, EDS_SECRET and EDS_CLIENT_ID" @@ -172,10 +173,7 @@ def __first_item(self): @property def item_properties(self): - return { - k: self.item.properties[k] - for k in sorted(self.item.properties.keys()) - } + return {k: self.item.properties[k] for k in sorted(self.item.properties.keys())} def assets(self, asset_name=None): if asset_name: @@ -349,12 +347,36 @@ def explore(self, collection: str = None): """ if collection: if collection not in self._staccollectionexplorer.keys(): - self._staccollectionexplorer[ - collection - ] = StacCollectionExplorer(self.client, collection) + self._staccollectionexplorer[collection] = StacCollectionExplorer( + self.client, collection + ) return self._staccollectionexplorer.get(collection) return sorted(c.id for c in self.client.get_all_collections()) + def _update_search_kwargs_for_ag_cloud_mask(self, search_kwargs, collections): + search_kwargs = search_kwargs.copy() + # to get only items that have a ag_cloud_mask + ag_query = {"eda:ag_cloud_mask_available": {"eq": True}} + target_param = "query" + + # to check if field is queryable + # 
============================================================================= + # queryables = self.client._stac_io.request( + # self.client.get_root_link().href + # + f"/queryables?collections={collections[0] if isinstance(collections,list) else collections}" + # ) + # queryables = json.loads(queryables) + # queryables = queryables["properties"] + # if "eda:ag_cloud_mask_available" not in queryables.keys(): + # target_param = "post_query" + # else: + # target_param = "query" + # ============================================================================= + query = search_kwargs.get(target_param, {}) + query.update(ag_query) + search_kwargs[target_param] = query + return search_kwargs + def datacube( self, collections: str | list, @@ -364,50 +386,60 @@ bbox=None, mask_with: None | str = None, mask_statistics: bool | int = False, - prefer_alternate: (str, False) = "s3", - prefer_http: bool = True, + clear_cover: (int, float) = None, + prefer_alternate: (str, False) = "download", search_kwargs: dict = {}, add_default_scale_factor: bool = True, + common_band_names=True, **kwargs, - ): + ) -> xr.Dataset: + if mask_with and common_band_names: + if isinstance(collections, list): + if len(collections) > 1: + raise ValueError( + "Mask_with and assets_mapping only manage one collection at a time." + ) if mask_with: if mask_with not in mask._available_masks: raise ValueError( f"Specified mask '{mask_with}' is not available.\ Currently available masks provider are : {mask._available_masks}" ) - if isinstance(collections, list): - if len(collections) > 1: - raise ValueError( - "Mask_with only manage one collection at a time." - ) collection = collections[0] else: collection = collections - if mask_with == "ag_cloud_mask": - # to get only items that have a ag_cloud_mask - query = search_kwargs.get("query", {}) - query.update({"eda:ag_cloud_mask_available": {"eq": True}}) - search_kwargs["query"] = query + if mask_with == "ag_cloud_mask": + search_kwargs = self._update_search_kwargs_for_ag_cloud_mask( + search_kwargs, collections + ) items = self.search( collections=collections, bbox=bbox, intersects=intersects, datetime=datetime, - prefer_http=prefer_http, prefer_alternate=prefer_alternate, add_default_scale_factor=add_default_scale_factor, **search_kwargs, ) + xr_datacube = datacube( - items, intersects=intersects, bbox=bbox, assets=assets, **kwargs + items, + intersects=intersects, + bbox=bbox, + assets=assets, + common_band_names=common_band_names, + **kwargs, ) if mask_with: - if mask_with == "native": - mask_with = mask._native_mask_def_mapping.get( - collection, None + if clear_cover and mask_statistics is False: + mask_statistics = True + warnings.warn( + "Forcing mask_statistics in order to filter by clear coverage.", + category=Warning, ) + if mask_with == "native": + mask_with = mask._native_mask_def_mapping.get(collection, None) if mask_with is None: raise ValueError( f"Sorry, there's no native mask available for {collection}. Only these collections have native cloudmask : {list(mask._native_mask_mapping.keys())}."
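To make the reworked `datacube` signature concrete, here is a minimal sketch of a masked request, assuming an authenticated client and a hypothetical AOI called `geometry` (matching the examples further down in this diff):

```python
from earthdaily import earthdatastore

eds = earthdatastore.Auth()
cube = eds.datacube(
    "sentinel-2-l2a",
    assets=["red", "green", "blue"],
    intersects=geometry,  # hypothetical AOI (GeoDataFrame or GeoJSON-like geometry)
    datetime=["2022-08-01", "2022-08-09"],
    mask_with="native",   # resolves to "scl" for sentinel-2-l2a
    clear_cover=50,       # forces mask_statistics, then drops scenes under 50% clear
)
```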
@@ -422,6 +454,7 @@ def datacube( groupby_date="max", epsg=xr_datacube.rio.crs.to_epsg(), resolution=xr_datacube.rio.resolution()[0], + prefer_alternate="download", ) xr_datacube["time"] = xr_datacube.time.astype("M8[s]") acm_datacube["time"] = acm_datacube.time.astype("M8[s]") @@ -431,7 +464,7 @@ mask_assets = mask._native_mask_asset_mapping[collections] if "groupby_date" in kwargs: kwargs["groupby_date"] = "max" - if not "resolution" in kwargs: + if "resolution" not in kwargs: kwargs["resolution"] = xr_datacube.rio.resolution()[0] clouds_datacube = datacube( items, @@ -447,6 +480,9 @@ Mask = mask.Mask(xr_datacube, intersects=intersects, bbox=bbox) xr_datacube = getattr(Mask, mask_with)(**mask_kwargs) + + if clear_cover: + xr_datacube = mask.filter_clear_cover(xr_datacube, clear_cover) return xr_datacube def search( @@ -456,7 +492,6 @@ collections: str | list, intersects=None, bbox=None, post_query=None, prefer_alternate=None, - prefer_http=False, add_default_scale_factor=False, **kwargs, ): @@ -475,8 +510,6 @@ DESCRIPTION. The default is None. prefer_alternate : TYPE, optional DESCRIPTION. The default is None. - prefer_http : TYPE, optional - DESCRIPTION. The default is False. **kwargs : TYPE DESCRIPTION. @@ -577,11 +610,10 @@ **kwargs, ) items_collection = items_collection.item_collection() - if any((prefer_alternate, prefer_http, add_default_scale_factor)): + if any((prefer_alternate, add_default_scale_factor)): items_collection = enhance_assets( items_collection.clone(), alternate=prefer_alternate, - use_http_url=prefer_http, add_default_scale_factor=add_default_scale_factor, ) if post_query: @@ -596,9 +628,7 @@ def ag_cloud_mask_from_items(items): for item in items: if not item.properties.get("eda:ag_cloud_mask_available"): continue - collection = item.properties[ - "eda:ag_cloud_mask_collection_id" - ] + collection = item.properties["eda:ag_cloud_mask_collection_id"] if products.get(collection, None) is None: products[collection] = [] products[collection].append( @@ -628,7 +658,7 @@ def item_property_to_df( if asset is not None and asset.to_dict() is not None: try: properties = asset.to_dict()[property_name] - except: + except KeyError: print( f'No property "{property_name}" has been found in the asset "{asset}".'
) diff --git a/earthdaily/earthdatastore/cube_utils/__init__.py b/earthdaily/earthdatastore/cube_utils/__init__.py index 53fdf929..3b706755 100644 --- a/earthdaily/earthdatastore/cube_utils/__init__.py +++ b/earthdaily/earthdatastore/cube_utils/__init__.py @@ -3,7 +3,8 @@ import pandas as pd import geopandas as gpd from shapely.geometry import box -from earthdaily.earthdatastore.cube_utils._zonal import zonal_stats, zonal_stats_numpy +from ._zonal import zonal_stats, zonal_stats_numpy +from .asset_mapper import AssetMapper from rasterio.enums import Resampling from rasterio.mask import geometry_mask import rioxarray as rxr @@ -42,7 +43,7 @@ def _cube_odc(items_collection, assets=None, times=None, **kwargs): if "resampling" in kwargs: if isinstance(kwargs["resampling"], int): kwargs["resampling"] = Resampling(kwargs["resampling"]).name - chunks = kwargs.get("chunks", dict(x=2048, y=2048, time=1)) + chunks = kwargs.get("chunks", dict(x="auto", y="auto", time="auto")) kwargs.pop("chunks", None) ds = stac.load( @@ -90,6 +91,7 @@ def datacube( engine="odc", rescale=True, groupby_date="mean", + common_band_names=True, **kwargs, ): logging.info(f"Building datacube with {len(items_collection)} items") @@ -106,6 +108,10 @@ def datacube( raise NotImplementedError( f"Engine '{engine}' not supported. Only {' and '.join(list(engines.keys()))} are currently supported." ) + if common_band_names and not isinstance(assets, dict): + aM = AssetMapper() + assets = aM.map_collection_bands(items_collection[0].collection_id, assets) + if isinstance(assets, dict): assets_keys = list(assets.keys()) ds = engines[engine]( @@ -192,7 +198,7 @@ def rescale_assets_with_items( scales = dict() if len(items_collection) > ds.time.size: unique_dt = {} - items_collection_unique_dt = [] + # items_collection_unique_dt = [] for item in items_collection: if item.datetime in unique_dt.keys(): for asset in item.assets.keys(): @@ -245,7 +251,7 @@ def rescale_assets_with_items( scale = rasterbands.get("scale", None) if offset or scale: - if not ds_asset in scales: + if ds_asset not in scales: scales[ds_asset] = {} scale = rasterbands.get("scale", 1) diff --git a/earthdaily/earthdatastore/cube_utils/_zonal.py b/earthdaily/earthdatastore/cube_utils/_zonal.py index 9ea41196..0ed284b6 100644 --- a/earthdaily/earthdatastore/cube_utils/_zonal.py +++ b/earthdaily/earthdatastore/cube_utils/_zonal.py @@ -16,9 +16,7 @@ def _compute_M(data): cols = np.arange(data.size) - return csr_matrix( - (cols, (data.ravel(), cols)), shape=(data.max() + 1, data.size) - ) + return csr_matrix((cols, (data.ravel(), cols)), shape=(data.max() + 1, data.size)) def _indices_sparse(data): @@ -45,11 +43,8 @@ def datacube_time_stats(datacube, operations): def _rasterize(gdf, dataset, all_touched=False): feats = rasterize(gdf, dataset, all_touched=all_touched) - idx_start = 0 - if 0 in feats: - idx_start = 1 yx_pos = _indices_sparse(feats) - return feats, yx_pos, idx_start + return feats, yx_pos def zonal_stats_numpy( @@ -60,11 +55,11 @@ def zonal_stats_numpy( preload_datavar=False, ): tqdm_bar = tqdm.tqdm(total=len(dataset.data_vars) * dataset.time.size) + dataset = dataset.rio.clip_box(*gdf.to_crs(dataset.rio.crs).total_bounds) - feats, yx_pos, idx_start = _rasterize( - gdf, dataset, all_touched=all_touched - ) + feats, yx_pos = _rasterize(gdf, dataset, all_touched=all_touched) ds = [] + features_idx = [] for data_var in dataset.data_vars: tqdm_bar.set_description(data_var) dataset_var = dataset[data_var] @@ -76,9 +71,19 @@ def zonal_stats_numpy( vals[t] = [] 
mem_asset = dataset_var.isel(time=t).to_numpy() for i in range(gdf.shape[0]): - pos = yx_pos[i + idx_start] - data = mem_asset[pos] - res = [operation(data) for operation in operations.values()] + features_idx.append(i) + if len(yx_pos) <= i + 1: + break + pos = np.asarray(yx_pos[i + 1]) + # mem_asset[*pos] only for python>=3.11 + if len(pos) == 2: + data = mem_asset[pos[0], pos[1]] + elif len(pos) == 1: + data = mem_asset[pos[0]] + if data.size > 0: + res = [operation(data) for operation in operations.values()] + else: + res = [np.nan for operation in operations] vals[t].append(res) arr = np.asarray([vals[v] for v in vals]) @@ -87,14 +92,14 @@ def zonal_stats_numpy( dims=["time", "feature", "stats"], coords=dict( time=dataset_var.time.values, - feature=gdf.index, + feature=gdf.index[np.nonzero(np.unique(feats))[0] - 1], stats=list(operations.keys()), ), ) del arr, mem_asset, vals, dataset_var ds.append(da.to_dataset(name=data_var)) tqdm_bar.close() - return xr.merge(ds) + return xr.merge(ds).transpose("feature", "time", "stats") def zonal_stats( @@ -115,23 +120,27 @@ def zonal_stats( zonal_ds_list = [] + dataset = dataset.rio.clip_box(*gdf.to_crs(dataset.rio.crs).total_bounds) + if method == "optimized": - feats, yx_pos, idx_start = _rasterize( - gdf, dataset, all_touched=all_touched - ) + feats, yx_pos = _rasterize(gdf, dataset, all_touched=all_touched) for gdf_idx in tqdm.trange(gdf.shape[0], disable=not verbose): tqdm_bar.update(1) - yx_pos_idx = yx_pos[gdf_idx + idx_start] + if gdf_idx + 1 >= len(yx_pos): + continue + yx_pos_idx = yx_pos[gdf_idx + 1] + if np.asarray(yx_pos_idx).size == 0: + continue datacube_spatial_subset = dataset.isel( x=xr.DataArray(yx_pos_idx[1], dims="xy"), y=xr.DataArray(yx_pos_idx[0], dims="xy"), ) del yx_pos_idx zonal_ds_list.append( - datacube_time_stats( - datacube_spatial_subset, operations - ).expand_dims(dim={"feature": [gdf.iloc[gdf_idx].name]}) + datacube_time_stats(datacube_spatial_subset, operations).expand_dims( + dim={"feature": [gdf.iloc[gdf_idx].name]} + ) ) del yx_pos, feats @@ -145,9 +154,7 @@ def zonal_stats( shapes = feat.geometry.geoms else: shapes = [feat.geometry] - datacube_spatial_subset = dataset.rio.clip( - shapes, all_touched=all_touched - ) + datacube_spatial_subset = dataset.rio.clip(shapes, all_touched=all_touched) zonal_feat = datacube_time_stats( datacube_spatial_subset, operations @@ -155,7 +162,5 @@ def zonal_stats( zonal_ds_list.append(zonal_feat) else: - raise NotImplementedError( - 'method available are : "standard" or "optimized"' - ) + raise NotImplementedError('method available are : "standard" or "optimized"') return xr.concat(zonal_ds_list, dim="feature") diff --git a/earthdaily/earthdatastore/cube_utils/asset_mapper/__init__.py b/earthdaily/earthdatastore/cube_utils/asset_mapper/__init__.py new file mode 100644 index 00000000..2ca7216d --- /dev/null +++ b/earthdaily/earthdatastore/cube_utils/asset_mapper/__init__.py @@ -0,0 +1,49 @@ +from . 
import _asset_mapper_config + + +class AssetMapper: + def __init__(self): + self.available_collections = list( + _asset_mapper_config.asset_mapper_collections.keys() + ) + + def collection_mapping(self, collection): + if self._collection_exists(collection, raise_warning=True): + return _asset_mapper_config.asset_mapper_collections[collection] + + def _collection_exists(self, collection, raise_warning=False): + exists = True if collection in self.available_collections else False + if raise_warning and not exists: + raise NotImplementedError( + f"Collection {collection} has not been implemented" + ) + return exists + + def map_collection_bands(self, collection, bands): + if isinstance(bands, dict): + return bands + if not self._collection_exists(collection): + return bands + + # HANDLE LIST TO DICT CONVERSION + if isinstance(bands, list): + bands = {band: band for band in bands} + + output_bands = {} + + config = self.collection_mapping(collection) + + # Try to map each band + for band in bands: + if band in config[0]: + output_bands[config[0][band]] = band + # No band found with specified key (common band name) + else: + # Looking for band matching the specified value (asset name) + matching_assets = [ + key for key, value in config[0].items() if value == band + ] + + if matching_assets: + output_bands[band] = band + return output_bands diff --git a/earthdaily/earthdatastore/cube_utils/asset_mapper/_asset_mapper_config.py b/earthdaily/earthdatastore/cube_utils/asset_mapper/_asset_mapper_config.py new file mode 100644 index 00000000..89110b9c --- /dev/null +++ b/earthdaily/earthdatastore/cube_utils/asset_mapper/_asset_mapper_config.py @@ -0,0 +1,158 @@ +asset_mapper_collections = { + "sentinel-2-l2a": [ + { + "coastal44": "coastal", + "coastal": "coastal", + "blue": "blue", + "green": "green", + "red": "red", + "nir": "nir", + "nir08": "nir08", + "nir09": "nir09", + "swir16": "swir16", + "swir22": "swir22", + "rededge": "rededge2", + "rededge70": "rededge1", + "rededge74": "rededge2", + "rededge78": "rededge3", + "watervapor": "wvp", + "native-cloud-mask": "scl", + } + ], + "landsat-c2l2-sr": [ + { + "coastal44": "coastal", + "coastal": "coastal", + "blue": "blue", + "green": "green", + "red": "red", + "nir": "nir08", + "nir08": "nir08", + "swir16": "swir16", + "swir22": "swir22", + "native-cloud-mask": "qa-pixel", + } + ], + "landsat-c2l2-st": [ + {"lwir": "lwir11", "lwir11": "lwir11", "native-cloud-mask": "qa-pixel"} + ], + "venus-l2a": [ + { + "coastal42": "image_file_SRE_B1", + "coastal44": "image_file_SRE_B2", + "coastal": "image_file_SRE_B2", + "blue": "image_file_SRE_B3", + "green": "image_file_SRE_B4", + "yellow": "image_file_SRE_B5", + "red": "image_file_SRE_B7", + "nir": "image_file_SRE_B11", + "nir08": "image_file_SRE_B11", + "nir09": "image_file_SRE_B12", + "rededge": "image_file_SRE_B09", + "rededge70": "image_file_SRE_B08", + "rededge74": "image_file_SRE_B09", + "rededge78": "image_file_SRE_B10", + "native-cloud-mask": "detailed_cloud_mask", + } + ], + "earthdaily-simulated-cloudless-l2a-cog-edagro": [ + { + "coastal44": "image_file_C", + "coastal": "image_file_C", + "blue": "image_file_B", + "green": "image_file_G", + "yellow": "image_file_Y", + "red": "image_file_R", + "nir": "image_file_NIR", + "nir08": "image_file_NIR", + "swir12": "image_file_SWIR1", + "swir16": "image_file_SWIR2", + "cirrus": "image_file_CI", + "rededge": "image_file_RE2", + "rededge70": "image_file_RE1", + "rededge74": "image_file_RE2", + "rededge78": "image_file_RE3", + "watervapor1": 
"image_file_RWVP1", + "watervapor2": "image_file_RWVP2", + "lwir": "image_file_ST", + "lwir11": "image_file_ST", + "tir1": "image_file_TIR", + "tir2": "image_file_TIR2", + "tir3": "image_file_TIR3", + "aqua": "image_file_A", + "wildfire": "image_file_W", + "me1": "image_file_ME1", + "me2": "image_file_ME2", + } + ], + "alsat-1b-l2a-cog-edagro": [ + { + "blue": "blue", + "green": "green", + "red": "red", + "nir": "nir", + } + ], + "cbers-4-l2a-cog-edagro": [ + { + "blue": "blue", + "green": "green", + "red": "red", + "nir": "nir", + } + ], + "deimos-l2a-cog-edagro": [ + { + "blue": "blue", + "green": "green", + "red": "red", + "nir": "nir", + } + ], + "gaofen-1-l2a-cog-edagro": [ + { + "blue": "blue", + "green": "green", + "red": "red", + "nir": "nir", + } + ], + "gaofen-6-l2a-cog-edagro": [ + { + "blue": "blue", + "green": "green", + "red": "red", + "nir": "nir", + } + ], + "kazstsat-l2a-cog-edagro": [ + { + "blue": "blue", + "green": "green", + "red": "red", + "nir": "nir", + } + ], + "huanjing-2-l2a-cog-edagro": [ + { + "blue": "blue", + "green": "green", + "red": "red", + "nir": "nir", + } + ], + "dmc-l2a-cog-edagro": [ + { + "green": "green", + "red": "red", + "nir": "nir", + } + ], + "resourcesat-2-l2a-cog-edagro": [ + { + "green": "green", + "red": "red", + "nir": "nir", + } + ], +} diff --git a/earthdaily/earthdatastore/cube_utils/custom_operations.py b/earthdaily/earthdatastore/cube_utils/custom_operations.py index e6fc4903..be8f9878 100644 --- a/earthdaily/earthdatastore/cube_utils/custom_operations.py +++ b/earthdaily/earthdatastore/cube_utils/custom_operations.py @@ -6,9 +6,12 @@ class CustomOperations: @staticmethod def _np_mode(arr, **kwargs): - if not isinstance(arr, np.ndarray): - arr = arr.compute() - # or it will output + if isinstance(arr, list): + arr = np.asarray(arr) + if isinstance(arr, xr.Dataset | xr.DataArray): + if arr.chunks is not None: + arr = arr.compute() + # or it will output # NotImplementedError: Slicing an array with unknown chunks with a dask.array of ints is not supported values, counts = np.unique(arr, return_counts=True) rm = np.isnan(values) @@ -29,6 +32,7 @@ def register_custom_operations(): # register custom methods fo DataArrayGroupBy xr.core.groupby.DataArrayGroupBy.mode = CustomOperations.mode xr.core.groupby.DatasetGroupBy.mode = CustomOperations.mode + np.mode = CustomOperations._np_mode CustomOperations.register_custom_operations() diff --git a/earthdaily/earthdatastore/mask/__init__.py b/earthdaily/earthdatastore/mask/__init__.py index 652d51ba..fb86146a 100644 --- a/earthdaily/earthdatastore/mask/__init__.py +++ b/earthdaily/earthdatastore/mask/__init__.py @@ -4,17 +4,12 @@ from earthdaily.earthdatastore.cube_utils import _bbox_to_intersects import geopandas as gpd import warnings -import json import numpy as np import tqdm from joblib import Parallel, delayed dask.config.set(**{"array.slicing.split_large_chunks": True}) -warnings.simplefilter( - "ignore", category=xr.core.extensions.AccessorRegistrationWarning -) - _available_masks = [ "native", "venus_detailed_cloud_mask", @@ -35,6 +30,17 @@ } +def _bool_or_int_to_njobs(var): + if isinstance(var, bool): + if var: + arg = 1 + else: + arg = False + else: + arg = var + return arg + + class Mask: def __init__(self, dataset: xr.Dataset, intersects=None, bbox=None): self._obj = dataset @@ -50,25 +56,19 @@ def ag_cloud_mask( add_mask_var=False, mask_statistics=False, ): - acm_datacube["time"] = acm_datacube.time.dt.round( - "s" - ) # rm nano second + acm_datacube["time"] = 
acm_datacube.time.dt.round("s") # rm nano second self._obj["time"] = self._obj.time.dt.round("s") # rm nano second # - self._obj = self._obj.where( - acm_datacube["agriculture-cloud-mask"] == 1 - ) + self._obj = self._obj.where(acm_datacube["agriculture-cloud-mask"] == 1) if add_mask_var: - self._obj["agriculture-cloud-mask"] = acm_datacube[ - "agriculture-cloud-mask" - ] + self._obj["agriculture-cloud-mask"] = acm_datacube["agriculture-cloud-mask"] if mask_statistics: self.compute_clear_coverage( acm_datacube["agriculture-cloud-mask"], "ag_cloud_mask", 1, labels_are_clouds=False, - n_jobs=1 if mask_statistics == True else mask_statistics, + n_jobs=_bool_or_int_to_njobs(mask_statistics), ) return self._obj @@ -82,9 +82,7 @@ def cloudmask_from_asset( fill_value=np.nan, ): if cloud_asset not in self._obj.data_vars: - raise ValueError( - f"Asset '{cloud_asset}' needed to compute cloudmask." - ) + raise ValueError(f"Asset '{cloud_asset}' needed to compute cloudmask.") else: cloud_layer = self._obj[cloud_asset].copy() _assets = [a for a in self._obj.data_vars if a != cloud_asset] @@ -105,7 +103,7 @@ def cloudmask_from_asset( cloud_asset, labels, labels_are_clouds=labels_are_clouds, - n_jobs=1 if mask_statistics == True else mask_statistics, + n_jobs=_bool_or_int_to_njobs(mask_statistics), ) return self._obj @@ -123,9 +121,7 @@ def scl( mask_statistics=mask_statistics, ) - def venus_detailed_cloud_mask( - self, add_mask_var=False, mask_statistics=False - ): + def venus_detailed_cloud_mask(self, add_mask_var=False, mask_statistics=False): return self.cloudmask_from_asset( "detailed_cloud_mask", 0, @@ -142,15 +138,13 @@ def compute_clear_coverage( labels_are_clouds=True, n_jobs=1, ): - def compute_clear_pixels( - cloudmask_array, labels, labels_are_clouds=False - ): + def compute_clear_pixels(cloudmask_array, labels, labels_are_clouds=False): cloudmask_array = cloudmask_array.data.compute() if labels_are_clouds: - labels_sum = np.sum( - ~np.in1d(cloudmask_array, labels) - ) - np.sum(np.isnan(cloudmask_array)) + labels_sum = np.sum(~np.in1d(cloudmask_array, labels)) - np.sum( + np.isnan(cloudmask_array) + ) else: labels_sum = np.sum(np.in1d(cloudmask_array, labels)) return labels_sum @@ -173,6 +167,10 @@ def compute_clear_pixels( self._obj = self._obj.assign_coords( {f"clear_pixels_{cloudmask_name}": ("time", n_pixels_as_labels)} ) + self._obj.coords["clear_pixels"] = self._obj.coords[ + f"clear_pixels_{cloudmask_name}" + ] + self._obj = self._obj.assign_coords( { f"clear_percent_{cloudmask_name}": ( @@ -184,6 +182,15 @@ def compute_clear_pixels( ) } ) + + self._obj.coords["clear_percent"] = self._obj.coords[ + f"clear_percent_{cloudmask_name}" + ] + warnings.warn( + f"Removed in 0.0.2 : The two coordinates 'clear_pixels_{cloudmask_name}' and 'clear_percent_{cloudmask_name}' will be deleted to keep only the prefix 'clear_percent' and 'clear_pixels'", + category=DeprecationWarning, + ) + return self._obj def compute_available_pixels(self): @@ -253,3 +260,7 @@ def QA_PIXEL_cloud_detection(arr): cloudfree_pixels = cloudfree[cloudfree != 0] cloudmask = np.isin(arr, cloudfree_pixels).astype(int) return cloudmask + + +def filter_clear_cover(dataset, clear_cover, coordinate="clear_percent"): + return dataset.sel(time=dataset.time[dataset.clear_percent >= clear_cover]) diff --git a/examples/common_band_names.py b/examples/common_band_names.py new file mode 100644 index 00000000..a09fc6b3 --- /dev/null +++ b/examples/common_band_names.py @@ -0,0 +1,27 @@ +""" +EarthDaily unique common band names 
+================================================================= + +For better interoperability between sensors.""" + +############################################################################## +# Import libraries +# ------------------------------------------- + +from earthdaily.earthdatastore.cube_utils import asset_mapper +from rich.table import Table +from rich.console import Console +console = Console(force_interactive=True) + +############################################################################## +# Show each collection with its earthdaily common band names +# -------------------------------------------------------------- +# For band names where several bands are available (rededge), the central +# wavelength is used (rededge70 is the rededge1 of sentinel-2, for example). +# + +for collection,assets in asset_mapper._asset_mapper_config.asset_mapper_collections.items(): + table = Table("asset","EarthDaily Common band name",title=f"Earthdaily common names for {collection}") + for common_name, asset in assets[0].items(): + table.add_row(asset, common_name) + console.print(table) diff --git a/examples/compare_scale_s2.py b/examples/compare_scale_s2.py index 272e4556..2b2250bc 100644 --- a/examples/compare_scale_s2.py +++ b/examples/compare_scale_s2.py @@ -36,7 +36,7 @@ def get_cube(rescale=True): datetime=["2022-01-01", "2022-03-10"], assets=["red", "green", "blue"], mask_with="native", # same as scl - mask_statistics=-1, # boolean or int (n_jobs) to compute in parallel if small + clear_cover=50, # at least 50% of the polygon must be clear rescale=rescale, ) return pivot_cube @@ -47,13 +47,12 @@ # ---------------------------------------------------- pivot_cube = get_cube(rescale=False) * 0.0001 -pivot_cube.clear_percent_scl.plot.scatter(x="time") plt.show() #####################################################################da######### # Plots cube with SCL with at least 50% of clear data # ---------------------------------------------------- -pivot_cube = pivot_cube.sel(time=pivot_cube.time[pivot_cube.clear_percent_scl > 50]) + pivot_cube.to_array(dim="band").plot.imshow(vmin=0, vmax=0.33, col="time", col_wrap=3) @@ -64,13 +63,13 @@ # ---------------------------------------------------- pivot_cube = get_cube() -pivot_cube.clear_percent_scl.plot.scatter(x="time") +pivot_cube.clear_percent.plot.scatter(x="time") plt.show() #####################################################################da######### # Plots cube with SCL with at least 50% of clear data # ---------------------------------------------------- -pivot_cube = pivot_cube.sel(time=pivot_cube.time[pivot_cube.clear_percent_scl > 50]) + pivot_cube.to_array(dim="band").plot.imshow(vmin=0, vmax=0.33, col="time", col_wrap=3) diff --git a/examples/field_evolution.py b/examples/field_evolution.py index bc1b67bc..0b6eff19 100644 --- a/examples/field_evolution.py +++ b/examples/field_evolution.py @@ -11,7 +11,6 @@ from earthdaily import earthdatastore import geopandas as gpd from matplotlib import pyplot as plt -import numpy as np ############################################################################## # Load plot # ------------------------------------------- @@ -27,43 +26,38 @@ eds = earthdatastore.Auth() ############################################################################## -# Search for collection items in august 2022 -# +# Search for collection items in August 2022 (1st to 9th) +# where at least 50% of the field is clear according to the native cloudmask.
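As context for the example that follows: passing `clear_cover=50` while building the cube is equivalent to computing the mask statistics and filtering afterwards with the `filter_clear_cover` helper added in this diff. A sketch, assuming a cube that carries the `clear_percent` coordinate (as the `pivot_cube` built just below does):

```python
from earthdaily.earthdatastore import mask

# Keep only timestamps where at least 50% of the pixels are clear.
filtered = mask.filter_clear_cover(pivot_cube, 50)
# which is shorthand for:
# pivot_cube.sel(time=pivot_cube.time[pivot_cube.clear_percent >= 50])
```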
+ pivot_cube = eds.datacube( "sentinel-2-l2a", intersects=pivot, - datetime=["2022-08-01", "2022-08-08"], + datetime=["2022-08-01", "2022-08-09"], assets=["red", "green", "blue"], mask_with="native", # same as scl - mask_statistics=True, + # mask_statistics=True, # asking for `clear_cover` forces mask_statistics to be computed + clear_cover=50, ) -pivot_cube.clear_percent_scl.plot.scatter(x="time") +pivot_cube.clear_percent.plot.scatter(x="time") -#####################################################################da######### +############################################################################## # Plots cube with SCL with at least 50% of clear data # ---------------------------------------------------- -cube_majority_clear = pivot_cube.sel( - time=pivot_cube.time[pivot_cube.clear_percent_scl > 50] -) -cube_majority_clear = cube_majority_clear.load() - -cube_majority_clear.to_array(dim="band").plot.imshow( - vmin=0, vmax=0.33, col="time", col_wrap=3 -) +pivot_cube.to_array(dim="band").plot.imshow(vmin=0, vmax=0.4, col="time", col_wrap=3) plt.title("Clear cover percent with SCL") plt.title("Pivot evolution with SCL masks") plt.show() -#####################################################################da######### +############################################################################## # Compute zonal stats for the pivot # ---------------------------------------------------- zonal_stats = earthdatastore.cube_utils.zonal_stats( - cube_majority_clear, pivot, operations=["mean", "max", "min"] + pivot_cube, pivot, operations=["mean", "max", "min"] ) zonal_stats = zonal_stats.load() diff --git a/examples/first_steps_create_datacube.py b/examples/first_steps_create_datacube.py index 9fe01a56..fb2d9353 100644 --- a/examples/first_steps_create_datacube.py +++ b/examples/first_steps_create_datacube.py @@ -28,19 +28,19 @@ eds = earthdatastore.Auth() # using config from ENV ########################################################### -# Create datacube (all in one) +# Create datacube (in one step) # -------------------------------------------------- s2_datacube = eds.datacube( "sentinel-2-l2a", assets=["blue", "green", "red", "nir"], intersects=geometry, - datetime=["2022-07"], + datetime=["2022-08-01", "2022-08-09"], mask_with="native", # equal to "scl" for sentinel-2 - mask_statistics=True, + clear_cover=50, ) -s2_datacube.clear_percent_scl.plot.scatter(x="time") +s2_datacube.clear_percent.plot.scatter(x="time") plt.title("Percentage of clear pixels on the study site") plt.show() print(s2_datacube) @@ -57,7 +57,7 @@ # Request items items = eds.search( - "sentinel-2-l2a", intersects=geometry, datetime=["2022-07"] + "sentinel-2-l2a", intersects=geometry, datetime=["2022-08-01", "2022-08-09"] ) ########################################################### @@ -77,6 +77,9 @@ mask_statistics=True ) +s2_datacube = earthdatastore.mask.filter_clear_cover( + s2_datacube, 50 +) # at least 50% clear pixels # s2_datacube[["red", "green", "blue"]].to_array(dim="band").plot.imshow( vmin=0, vmax=0.2, col="time", col_wrap=4 diff --git a/examples/multisensors_cube.py b/examples/multisensors_cube.py index ca0ec12b..59a97725 100644 --- a/examples/multisensors_cube.py +++ b/examples/multisensors_cube.py @@ -9,11 +9,11 @@ # Import librairies # ------------------------------------------- -from earthdaily import earthdatastore import geopandas as gpd -from rasterio.enums import Resampling from matplotlib import pyplot as plt +from rasterio.enums import Resampling +from earthdaily import earthdatastore
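One reason the Sentinel-2, Venus and Landsat cubes below can share a single `assets` list is the new AssetMapper. A quick sketch of the translation it performs for Venus (the expected output is taken from `_asset_mapper_config.py` above):

```python
from earthdaily.earthdatastore.cube_utils.asset_mapper import AssetMapper

aM = AssetMapper()
# Common band names resolve to collection-specific assets:
print(aM.map_collection_bands("venus-l2a", ["blue", "green", "red", "nir"]))
# {'image_file_SRE_B3': 'blue', 'image_file_SRE_B4': 'green',
#  'image_file_SRE_B7': 'red', 'image_file_SRE_B11': 'nir'}
```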
############################################################################## # Import librairies @@ -33,27 +33,19 @@ # Generate s2 cube # ------------------------------------------- +assets = ["blue", "green", "red", "nir"] s2 = eds.datacube( "sentinel-2-l2a", intersects=polygon, datetime=datetime, - assets=["blue", "green", "red", "nir"], - prefer_http=True, + assets=assets, ) ############################################################################## # Generate venus cube # ------------------------------------------- -venus_assets = dict( - image_file_SRE_B3="blue", - image_file_SRE_B4="green", - image_file_SRE_B5="yellow", - image_file_SRE_B7="red", - image_file_SRE_B11="nir", -) - venus = eds.datacube( "venus-l2a", intersects=polygon, @@ -61,21 +53,33 @@ datetime=datetime, epsg=s2.rio.crs.to_epsg(), resampling=Resampling.nearest, # cubic - assets=venus_assets, - prefer_http=True, + assets=assets, ) +############################################################################## +# Generate Landsat cube +# ------------------------------------------- + +landsat = eds.datacube( + "landsat-c2l2-sr", + intersects=polygon, + datetime=datetime, + resampling=Resampling.nearest, + epsg=s2.rio.crs.to_epsg(), + resolution=s2.rio.resolution()[0], + assets=assets, +) ############################################################################## # Create supercube # ------------------------------------------- print("create metacube") -supercube = earthdatastore.metacube(s2, venus) +supercube = earthdatastore.metacube(s2, venus, landsat) ############################################################################## -# Get the first common date for plotting -# ------------------------------------------- +# Get the first common date between S2 and Venus for plotting +# --------------------------------------------------------------- common_date = [ day @@ -88,7 +92,7 @@ # ------------------------------------------- s2.sel(time=common_date)[["red", "green", "blue"]].to_array(dim="band").plot.imshow( - vmin=0, vmax=0.15 + vmin=0, vmax=0.2 ) plt.title(f"Sentinel-2 on {common_date}") plt.show() @@ -98,7 +102,7 @@ # ------------------------------------------- venus.sel(time=common_date, method="nearest")[["red", "green", "blue"]].to_array( dim="band" -).plot.imshow(vmin=0, vmax=0.15) +).plot.imshow(vmin=0, vmax=0.2) plt.title(f"Venus on {common_date}") plt.show() @@ -110,7 +114,7 @@ supercube.sel(time=common_date)[["red", "green", "blue"]].to_array( dim="band" -).plot.imshow(vmin=0, vmax=0.15) +).plot.imshow(vmin=0, vmax=0.2) plt.title(f"Fusion of Venus/Sentinel-2 on {common_date}") plt.show() diff --git a/examples/summary_stack.py b/examples/summary_stack.py index 1d9cd93c..bf2bbd8e 100644 --- a/examples/summary_stack.py +++ b/examples/summary_stack.py @@ -7,11 +7,9 @@ # Import librairies # ------------------------------------------- -import json from earthdaily import earthdatastore from rich.table import Table from rich.console import Console -import os from dotenv import load_dotenv @@ -40,7 +38,7 @@ # Explore a specific collection # ------------------------------------------- -collection = eds.explore("sentinel-2-l2a") +collection = eds.explore("venus-l2a") console.log(collection.properties) ############################################################################## diff --git a/examples/venus_cube_mask.py b/examples/venus_cube_mask.py index 672413dc..63693320 100644 --- a/examples/venus_cube_mask.py +++ b/examples/venus_cube_mask.py @@ -68,19 +68,10 @@ venus_datacube = earthdatastore.datacube( - 
items, - assets={ - "image_file_SRE_B3": "blue", - "image_file_SRE_B4": "green", - "image_file_SRE_B7": "red", - }, - epsg=epsg, - resolution=resolution, + items, assets=["blue", "green", "red"], epsg=epsg, resolution=resolution ) print(venus_datacube) -venus_datacube.isel( - time=slice(29, 33), x=slice(4000, 4500), y=slice(4000, 4500) -)[["red", "green", "blue"]].to_array(dim="band").plot.imshow( - col="time", vmin=0, vmax=0.33 -) +venus_datacube.isel(time=slice(29, 32), x=slice(4000, 4500), y=slice(4000, 4500))[ + ["red", "green", "blue"] +].to_array(dim="band").plot.imshow(col="time", vmin=0, vmax=0.33) diff --git a/requirements.yml b/requirements.yml index e19f767a..764ad70b 100644 --- a/requirements.yml +++ b/requirements.yml @@ -4,6 +4,7 @@ channels: - defaults dependencies: - numpy + - python>=3.10 - gdal>=3.7.0 - pandas - geopandas>=0.11 diff --git a/requirements_dev.yml b/requirements_dev.yml index f320d455..b37d6fc4 100644 --- a/requirements_dev.yml +++ b/requirements_dev.yml @@ -2,7 +2,7 @@ channels: - defaults - conda-forge dependencies: - - python=3.11 + - python>=3.10 - numpy - pandas - gdal>=3.7.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_assetmapper.py b/tests/test_assetmapper.py new file mode 100644 index 00000000..5a541174 --- /dev/null +++ b/tests/test_assetmapper.py @@ -0,0 +1,45 @@ +import unittest + +from earthdaily.earthdatastore.cube_utils.asset_mapper import AssetMapper + + +class TestAssetMapper(unittest.TestCase): + def setUp(self): + self.aM = AssetMapper() + + def test_unknow_collection(self): + collection = "earthdaily-unknow-collection" + assets = ["blue", "green", "red", "lambda"] + self.assertEqual(self.aM.map_collection_bands(collection, assets), assets) + with self.assertRaises(NotImplementedError): + self.aM._collection_exists(collection, raise_warning=True) + + def test_return_same_dict(self): + collection = "sentinel-2-l2a" + assets = {"key": "value", "source": "target", "sensorasset": "myoutputband"} + self.assertEqual(self.aM.map_collection_bands(collection, assets), assets) + + def test_sentinel2(self): + collection = "sentinel-2-l2a" + assets = ["blue", "green", "red", "rededge74", "missing_band"] + assets_s2 = ["blue", "green", "red", "rededge2"] + self.assertEqual( + list(self.aM.map_collection_bands(collection, assets).keys()), assets_s2 + ) + + def test_venus_rededge(self): + collection = "venus-l2a" + rededges = { + "rededge70": "image_file_SRE_B08", + "rededge74": "image_file_SRE_B09", + "rededge78": "image_file_SRE_B10", + } + + self.assertEqual( + list(self.aM.map_collection_bands(collection, rededges.keys()).keys()), + list(rededges.values()), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_zonalstats.py b/tests/test_zonalstats.py new file mode 100644 index 00000000..ced04591 --- /dev/null +++ b/tests/test_zonalstats.py @@ -0,0 +1,66 @@ +import numpy as np +import xarray as xr +import rioxarray as rxr +import geopandas as gpd +from shapely.geometry import Polygon +import earthdaily +import unittest + + +class TestZonalStats(unittest.TestCase): + def setUp(self, constant=np.random.randint(1, 12)): + # Define time, x, and y values + times = ["1987-04-22", "2022-04-22"] + x_values = np.arange(0, 8) + y_values = np.arange(0, 3) + self.constant = constant + # Create 3D arrays for the data values + data_values = np.arange(0, 24).reshape(3, 8) + data_values = np.dstack((data_values, np.full((3, 8), constant))) + + # Create the xarray 
dataset + ds = xr.Dataset( + {"first_var": (("y", "x", "time"), data_values)}, + coords={ + "y": y_values, + "x": x_values, + "time": times, + }, + ).rio.write_crs("EPSG:4326") + + # first pixel + geometry = [ + Polygon([(0, 0), (0, 0.8), (0.8, 0.8), (0.8, 0)]), + Polygon([(1, 1), (9, 1), (9, 2.1), (1, 1)]), + ] + # out of bound geom # Polygon([(10,10), (10,11), (11,11), (11,10)])] + gdf = gpd.GeoDataFrame({"geometry": geometry}, crs="EPSG:4326") + self.gdf = gdf + self.datacube = ds + + def test_numpy(self): + zonalstats = earthdaily.earthdatastore.cube_utils.zonal_stats_numpy( + self.datacube, + self.gdf, + all_touched=True, + operations=dict(mean=np.nanmean, max=np.nanmax, min=np.nanmin, mode=np.mode), + ) + + for operation in ['min','max', 'mode']: + self._check_results(zonalstats["first_var"].sel(stats=operation).values, operation=operation) + + def test_basic(self): + zonalstats = earthdaily.earthdatastore.cube_utils.zonal_stats( + self.datacube, self.gdf, all_touched=True, operations=["min", "max", "mode"] + ) + for operation in ['min','max','mode']: + self._check_results(zonalstats["first_var"].sel(stats=operation).values, operation=operation) + + def _check_results(self, stats_values,operation='min'): + results = {"min":np.asarray([[0, self.constant], [9, self.constant]]), + "max":np.asarray([[8, self.constant], [23, self.constant]]), + "mode":np.asarray([[0, self.constant], [9, self.constant]])} + self.assertTrue(np.all(stats_values == results[operation])) + +if __name__ == "__main__": + unittest.main()
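For reference, a minimal sketch of how the zonal statistics tested above are consumed, reusing `datacube` and `gdf` as built in `TestZonalStats.setUp` (the dimension order follows the `transpose("feature", "time", "stats")` introduced in `_zonal.py`):

```python
import numpy as np

import earthdaily

stats = earthdaily.earthdatastore.cube_utils.zonal_stats_numpy(
    datacube,
    gdf,
    all_touched=True,
    operations=dict(min=np.nanmin, max=np.nanmax),
)
# Selecting one statistic leaves a (feature, time) array per variable.
per_field_min = stats["first_var"].sel(stats="min")
```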