Skip to content

Commit

Permalink
Merge branch 'main' into croptype
Browse files Browse the repository at this point in the history
  • Loading branch information
kvantricht committed Aug 23, 2024
2 parents df8dcb8 + 84a2dd1 commit d72d819
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 37 deletions.
45 changes: 9 additions & 36 deletions presto/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,6 @@ def __init__(self, model: Presto, batch_size: int = 8192):

_NODATAVALUE = 65535

BAND_MAPPING = {
"B02": "B2",
"B03": "B3",
"B04": "B4",
"B05": "B5",
"B06": "B6",
"B07": "B7",
"B08": "B8",
"B8A": "B8A",
"B11": "B11",
"B12": "B12",
"VH": "VH",
"VV": "VV",
"precipitation-flux": "total_precipitation",
"temperature-mean": "temperature_2m",
}

STATIC_BAND_MAPPING = {"DEM-alt-20m": "elevation", "DEM-slo-20m": "slope"}

@classmethod
def _preprocess_band_values(cls, values: np.ndarray, presto_band: str) -> np.ndarray:
"""
Expand Down Expand Up @@ -105,30 +86,22 @@ def _extract_eo_data(cls, inarr: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]:
eo_data = np.zeros((num_pixels, num_timesteps, len(BANDS)))
mask = np.zeros((num_pixels, num_timesteps, len(BANDS_GROUPS_IDX)))

for org_band, presto_band in cls.BAND_MAPPING.items():
if org_band in inarr.coords["bands"]:
for presto_band in NORMED_BANDS:
if presto_band in inarr.coords["bands"]:
values = np.swapaxes(
inarr.sel(bands=org_band).values.reshape((num_timesteps, -1)), 0, 1
inarr.sel(bands=presto_band).values.reshape((num_timesteps, -1)),
0,
1,
)
idx_valid = values != cls._NODATAVALUE
values = cls._preprocess_band_values(values, presto_band)
eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid
mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid
elif presto_band == "NDVI":
# Band NDVI will be computed by Presto
continue
else:
logger.warning(f"Band {org_band} not found in input data.")
eo_data[:, :, BANDS.index(presto_band)] = 0
mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1

for org_band, presto_band in cls.STATIC_BAND_MAPPING.items():
if org_band in inarr.coords["bands"]:
values = np.swapaxes(
inarr.sel(bands=org_band).values.reshape((num_timesteps, -1)), 0, 1
)
idx_valid = values != cls._NODATAVALUE
eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid
mask[:, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid
else:
logger.warning(f"Band {org_band} not found in input data.")
logger.warning(f"Band {presto_band} not found in input data.")
eo_data[:, :, BANDS.index(presto_band)] = 0
mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def load_dependencies(tag: str) -> List[str]:
long_description_content_type="text/markdown",
author="Gabriel Tseng",
author_email="gabrieltseng95@gmail.com",
version="0.1.2",
version="0.1.4",
classifiers=[
"Programming Language :: Python :: 3",
"License :: Other/Proprietary License",
Expand Down

0 comments on commit d72d819

Please sign in to comment.