diff --git a/presto/inference.py b/presto/inference.py index be67137..29e42fd 100644 --- a/presto/inference.py +++ b/presto/inference.py @@ -46,25 +46,6 @@ def __init__(self, model: Presto, batch_size: int = 8192): _NODATAVALUE = 65535 - BAND_MAPPING = { - "B02": "B2", - "B03": "B3", - "B04": "B4", - "B05": "B5", - "B06": "B6", - "B07": "B7", - "B08": "B8", - "B8A": "B8A", - "B11": "B11", - "B12": "B12", - "VH": "VH", - "VV": "VV", - "precipitation-flux": "total_precipitation", - "temperature-mean": "temperature_2m", - } - - STATIC_BAND_MAPPING = {"DEM-alt-20m": "elevation", "DEM-slo-20m": "slope"} - @classmethod def _preprocess_band_values(cls, values: np.ndarray, presto_band: str) -> np.ndarray: """ @@ -105,30 +86,22 @@ def _extract_eo_data(cls, inarr: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]: eo_data = np.zeros((num_pixels, num_timesteps, len(BANDS))) mask = np.zeros((num_pixels, num_timesteps, len(BANDS_GROUPS_IDX))) - for org_band, presto_band in cls.BAND_MAPPING.items(): - if org_band in inarr.coords["bands"]: + for presto_band in NORMED_BANDS: + if presto_band in inarr.coords["bands"]: values = np.swapaxes( - inarr.sel(bands=org_band).values.reshape((num_timesteps, -1)), 0, 1 + inarr.sel(bands=presto_band).values.reshape((num_timesteps, -1)), + 0, + 1, ) idx_valid = values != cls._NODATAVALUE values = cls._preprocess_band_values(values, presto_band) eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid + elif presto_band == "NDVI": + # Band NDVI will be computed by Presto + continue else: - logger.warning(f"Band {org_band} not found in input data.") - eo_data[:, :, BANDS.index(presto_band)] = 0 - mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1 - - for org_band, presto_band in cls.STATIC_BAND_MAPPING.items(): - if org_band in inarr.coords["bands"]: - values = np.swapaxes( - inarr.sel(bands=org_band).values.reshape((num_timesteps, -1)), 0, 1 - ) - idx_valid = values != cls._NODATAVALUE - eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid - mask[:, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid - else: - logger.warning(f"Band {org_band} not found in input data.") + logger.warning(f"Band {presto_band} not found in input data.") eo_data[:, :, BANDS.index(presto_band)] = 0 mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1 diff --git a/setup.py b/setup.py index 72f7c46..3fb37a3 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ def load_dependencies(tag: str) -> List[str]: long_description_content_type="text/markdown", author="Gabriel Tseng", author_email="gabrieltseng95@gmail.com", - version="0.1.2", + version="0.1.4", classifiers=[ "Programming Language :: Python :: 3", "License :: Other/Proprietary License",