diff --git a/.gitignore b/.gitignore index 2aac475..045c71a 100644 --- a/.gitignore +++ b/.gitignore @@ -244,3 +244,4 @@ tags # Data *.zarr.zip *tif* +integration_test/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 49aa734..99a936f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [PEP 440](https://www.python.org/dev/peps/pep-0440/) and uses [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.5.0] + +### Changed +* Format and layout of chips to more closely match the TerraMesh dataset. + ## [0.4.0] ### Changed diff --git a/README.md b/README.md index bfca03d..b5330f6 100644 --- a/README.md +++ b/README.md @@ -6,38 +6,38 @@ A package for satellite image AI data prep. This package "chips" data labels and `SatChip` relies on a two-step process; chip your label train data inputs, then create corresponding chips for different remote sensing data sources. ### Step 1: Chip labels -The `chiplabel` CLI tool takes a GDAL-compatible image, a collection date, and an optional output directory as input using the following format: +The `chiplabel` CLI tool takes a GDAL-compatible image, a collection date, and an optional chip directory as input using the following format: ```bash -chiplabel PATH/TO/LABELS.tif DATE(UTC FORMAT) --outdir OUTPUT_DIR +chiplabel PATH/TO/LABELS.tif DATE(UTC FORMAT) --chipdir CHIP_DIR ``` For example: ```bash -chiplabel LA_damage_20250113_v0.tif 2024-01-01T01:01:01 --outdir chips +chiplabel LA_damage_20250113_v0.tif 2024-01-01T01:01:01 --chipdir chips ``` -This will produce an output zipped Zarr store label dataset with the name `{LABELS}.zarr.zip` in the specified output directory (`--outdir`). This file will be the input to the remote sensing data chipping step. 
+This will produce an output zipped Zarr store label dataset with the name `{LABEL}_{SAMPLE}.zarr.zip` (see the [Tiling Schema](#tiling_schema) section for details on the `SAMPLE` name) in the `LABEL` directory within the specified chip directory (`--chipdir`). This file will be the input to the remote sensing data chipping step. For more information on usage see `chiplabel --help` ### Step 2: Chip remote sensing data -The `chipdata` CLI tool takes a label zipped Zarr store, a dataset name, a date range and a set of optional parameters using the following format: +The `chipdata` CLI tool takes a path to a directory containing chip labels, a dataset name, a date range, and a set of optional parameters using the following format: ```bash -chipdata PATH/TO/LABELS.zarr.zip DATASET Ymd-Ymd \ +chipdata PATH/TO/LABEL DATASET Ymd-Ymd \ --maxcloudpct MAX_CLOUD_PCT --strategy STRATEGY \ - --outdir OUTPUT_DIR --scratchdir SCRATCH_DIR + --chipdir CHIP_DIR --imagedir IMAGE_DIR ``` For example: ```bash -chipdata LA_damage_20250113_v0.zarr.zip S2L2A 20250112-20250212 --maxcloudpct 20 --outdir chips --scratchdir images +chipdata LABEL S2L2A 20250112-20250212 --maxcloudpct 20 --chipdir CHIP_DIR --imagedir IMAGES ``` -Similarly to step 1, this will produce an output zipped Zarr store that contains chipped data for your chosen dataset with the name `{LABELS}_{DATASET}.zarr.zip`. The arguments are as follows: -- `PATH/TO/LABELS.zarr.zip`: the path to your training lables. +Similarly to step 1, this will produce an output zipped Zarr store that contains chipped data for your chosen dataset with the name `{LABEL}_{SAMPLE}_{DATASET}.zarr.zip`. The arguments are as follows: +- `PATH/TO/LABEL`: the path to your training labels. - `DATASET`: The satellite imagery dataset you would like to create labels for. See the list below for all current options. - `Ymd-Ymd`: The date range to select imagery from. For example, `20250112-20250212` selects imagery between January 12 and February 12, 2025. 
- `MAX_CLOUD_PCT`: For optical data, this optional parameter lets you set the maximum amount of cloud coverage allowed in a chip. Values between 0 and 100 are allowed. Cloud coverage is calculated on a per-chip basis. The default is 100, i.e., no limit. - `STRATEGY`: Lets you select which data inside your date range will be used to create chips. Specifying `BEST` (the default) will create a chip for the image closest to the beginning of your date range that has at least 95% spatial coverage. Specifying `ALL` will create chips for all images within your date range that have at least 95% spatial coverage. -- `OUTPUT_DIR`: Specifies the directory where the image chips will be saved. If not specified, this defaults to your current directory. +- `CHIP_DIR`: Specifies the directory where the image chips will be saved. If not specified, this defaults to your current directory. +- `IMAGE_DIR`: Specifies the directory where the full-size satellite images will be downloaded to. If this argument is not provided, the images will be stored in the `IMAGES` directory within `CHIP_DIR`. Currently supported datasets include: - `S2L2A`: Sentinel-2 L2A data sourced from the [Sentinel-2 AWS Open Data Archive](https://registry.opendata.aws/sentinel-2/) @@ -65,9 +65,9 @@ For instance, the bottom-left subgrid of MajorTOM tile `434U_876L` is named `434 ## Viewing Chips Assessing chips after their creation can be challenging due to the large number of small images created. 
To address this issue, SatChip includes a `chipview` CLI tool that uses Matplotlib to quickly visualize the data included within the created zipped Zarr stores: ```bash -chipview PATH/TO/CHIPS.zarr.zip BAND --idx IDX +chipview PATH/TO/CHIP.zarr.zip --band BAND ``` -Where `PATH/TO/CHIPS.zarr.zip` is the path to the chip file (labels or image data), `BAND` is the name of the band you would like to view, and `IDX` is an optional integer index of which dataset you would like to initially view. +Where `PATH/TO/CHIP.zarr.zip` is the path to the chip file (labels or image data), and `BAND` is the optional name of the band you would like to view. If no band is specified, an OPERA-style RGB decomposition will be shown for RTC data, and an RGB composite will be shown for optical data. ## License `SatChip` is licensed under the BSD-3-Clause open source license. See the LICENSE file for more details. diff --git a/pyproject.toml b/pyproject.toml index fd4a6f2..c2d5f72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,8 @@ chipview = "satchip.chip_view:main" [tool.pytest.ini_options] testpaths = ["tests"] script_launch_mode = "subprocess" +addopts = '-ra -q -m "not integration"' +markers = ["integration"] [tool.setuptools] include-package-data = true diff --git a/scripts/open_chips.py b/scripts/open_chips.py deleted file mode 100644 index ed8e2bf..0000000 --- a/scripts/open_chips.py +++ /dev/null @@ -1,22 +0,0 @@ -import argparse -from pathlib import Path - -import xarray as xr - - -def open_chips(input_path: Path) -> xr.Dataset: - ds = xr.open_zarr(input_path) - return ds - - -def main() -> None: - parser = argparse.ArgumentParser(description='Open a chip Zarr Zip Store') - parser.add_argument('input', type=str, help='Path to the input Zarr Zip Store') - args = parser.parse_args() - args.input = Path(args.input) - - open_chips(args.input) - - -if __name__ == '__main__': - main() diff --git a/src/satchip/chip_data.py b/src/satchip/chip_data.py index 55e7c70..84f9fb7 
100644 --- a/src/satchip/chip_data.py +++ b/src/satchip/chip_data.py @@ -1,18 +1,18 @@ import argparse +from collections import Counter from datetime import datetime from pathlib import Path -from tempfile import TemporaryDirectory import numpy as np import xarray as xr +from shapely.geometry import box from tqdm import tqdm -import satchip from satchip import utils from satchip.chip_hls import get_hls_data from satchip.chip_sentinel1rtc import get_rtc_paths_for_chips, get_s1rtc_chip_data from satchip.chip_sentinel2 import get_s2l2a_data -from satchip.terra_mind_grid import TerraMindGrid +from satchip.terra_mind_grid import TerraMindChip, TerraMindGrid def fill_missing_times(data_chip: xr.DataArray, times: np.ndarray) -> xr.DataArray: @@ -31,65 +31,80 @@ def fill_missing_times(data_chip: xr.DataArray, times: np.ndarray) -> xr.DataArr return xr.concat([data_chip, missing_data], dim='time').sortby('time') +def get_chip(label_path: Path) -> TerraMindChip: + label_dataset = utils.load_chip(label_path) + buffered = box(*label_dataset.bounds).buffer(0.1).bounds + grid = TerraMindGrid([buffered[1], buffered[3]], [buffered[0], buffered[2]]) # type: ignore + label_chip_name = label_dataset.sample.item() + chip = [c for c in grid.terra_mind_chips if c.name == label_chip_name] + assert len(chip) == 1, f'No TerraMind chip found for label {label_chip_name}' + return chip[0] + + def chip_data( - label_path: Path, + chip: TerraMindChip, + platform: str, + opts: utils.ChipDataOpts, + image_dir: Path, +) -> xr.Dataset: + if platform == 'S1RTC': + rtc_paths = opts['local_hyp3_paths'][chip.name] + chip_dataset = get_s1rtc_chip_data(chip, rtc_paths) + elif platform == 'S2L2A': + chip_dataset = get_s2l2a_data(chip, image_dir, opts=opts) + elif platform == 'HLS': + chip_dataset = get_hls_data(chip, image_dir, opts=opts) + else: + raise Exception(f'Unknown platform {platform}') + + return chip_dataset + + +def create_chips( + label_paths: list[Path], platform: str, date_start: datetime, 
date_end: datetime, strategy: str, max_cloud_pct: int, - output_dir: Path, - scratch_dir: Path, -) -> xr.Dataset: - labels = utils.load_chip(label_path) - date = labels.time.data[0].astype('M8[ms]').astype(datetime) - bounds = labels.attrs['bounds'] - - grid = TerraMindGrid([bounds[1] - 1, bounds[3] + 1], [bounds[0] - 1, bounds[2] + 1]) # type: ignore - terra_mind_chips = [c for c in grid.terra_mind_chips if c.name in list(labels.sample.data)] + chip_dir: Path, + image_dir: Path, +) -> list[Path]: + platform_dir = chip_dir / platform + platform_dir.mkdir(parents=True, exist_ok=True) opts: utils.ChipDataOpts = {'strategy': strategy, 'date_start': date_start, 'date_end': date_end} if platform in ['S2L2A', 'HLS']: opts['max_cloud_pct'] = max_cloud_pct + chips = [get_chip(p) for p in label_paths] + chip_names = [c.name for c in chips] + if len(chip_names) != len(set(chip_names)): + duplicates = [name for name, count in Counter(chip_names).items() if count > 1] + msg = f'Duplicate sample locations not supported. 
Duplicate chips: {", ".join(duplicates)}' + raise NotImplementedError(msg) + chip_paths = [ + platform_dir / (x.with_suffix('').with_suffix('').name + f'_{platform}.zarr.zip') for x in label_paths + ] if platform == 'S1RTC': - rtc_paths_for_chips = get_rtc_paths_for_chips(terra_mind_chips, bounds, scratch_dir, opts) - - data_chips = [] - for chip in tqdm(terra_mind_chips): - if platform == 'S1RTC': - rtc_paths = rtc_paths_for_chips[chip.name] - chip_data = get_s1rtc_chip_data(chip, rtc_paths, scratch_dir, opts=opts) - elif platform == 'S2L2A': - chip_data = get_s2l2a_data(chip, scratch_dir, opts=opts) - elif platform == 'HLS': - chip_data = get_hls_data(chip, scratch_dir, opts=opts) - else: - raise Exception(f'Unknown platform {platform}') - - data_chips.append(chip_data) - - times = np.unique(np.concatenate([dc.time.data for dc in data_chips])) - for i, data_chip in enumerate(data_chips): - if len(data_chip.time) < len(times): - data_chips[i] = fill_missing_times(data_chip, times) - attrs = {'date_created': date.isoformat(), 'satchip_version': satchip.__version__, 'bounds': labels.attrs['bounds']} - dataset = xr.Dataset(attrs=attrs) - dataset['data'] = xr.combine_by_coords(data_chips, join='override') - output_path = output_dir / (label_path.with_suffix('').with_suffix('').name + f'_{platform}.zarr.zip') - utils.save_chip(dataset, output_path) - return labels + rtc_paths_for_chips = get_rtc_paths_for_chips(chips, image_dir, opts) + opts['local_hyp3_paths'] = rtc_paths_for_chips + + for chip, chip_path in tqdm(zip(chips, chip_paths), desc='Chipping labels'): + dataset = chip_data(chip, platform, opts, image_dir) + utils.save_chip(dataset, chip_path) + return chip_paths def main() -> None: parser = argparse.ArgumentParser(description='Chip a label image') - parser.add_argument('labelpath', type=Path, help='Path to the label image') + parser.add_argument('labelpath', type=Path, help='Path to the label directory') parser.add_argument('platform', choices=['S2L2A', 
'S1RTC', 'HLS'], type=str, help='Dataset to create chips for') parser.add_argument('daterange', type=str, help='Inclusive date range to search for data in the format Ymd-Ymd') parser.add_argument('--maxcloudpct', default=100, type=int, help='Maximum percent cloud cover for a data chip') - parser.add_argument('--outdir', default='.', type=Path, help='Output directory for the chips') + parser.add_argument('--chipdir', default='.', type=Path, help='Output directory for the chips') parser.add_argument( - '--scratchdir', default=None, type=Path, help='Output directory for scratch files if you want to keep them' + '--imagedir', default=None, type=Path, help='Output directory for image files. Defaults to chipdir/IMAGES' ) parser.add_argument( '--strategy', @@ -103,23 +118,15 @@ def main() -> None: assert 0 <= args.maxcloudpct <= 100, 'maxcloudpct must be between 0 and 100' date_start, date_end = [datetime.strptime(d, '%Y%m%d') for d in args.daterange.split('-')] assert date_start < date_end, 'start date must be before end date' + label_paths = list(args.labelpath.glob('*.zarr.zip')) + assert len(label_paths) > 0, f'No label files found in {args.labelpath}' - params = ( - args.labelpath, - args.platform, - date_start, - date_end, - args.strategy, - args.maxcloudpct, - args.outdir, - ) + if args.imagedir is None: + args.imagedir = args.chipdir / 'IMAGES' - if args.scratchdir is not None: - chip_data(*params, args.scratchdir) - else: - with TemporaryDirectory() as tmp_dir: - scratch_dir = Path(tmp_dir) - chip_data(*params, scratch_dir) + create_chips( + label_paths, args.platform, date_start, date_end, args.strategy, args.maxcloudpct, args.chipdir, args.imagedir + ) if __name__ == '__main__': diff --git a/src/satchip/chip_hls.py b/src/satchip/chip_hls.py index b7d16cc..21168bf 100644 --- a/src/satchip/chip_hls.py +++ b/src/satchip/chip_hls.py @@ -1,3 +1,4 @@ +from collections import OrderedDict from datetime import datetime, timedelta from pathlib import Path @@ -9,28 +10,32 
@@ from earthaccess.results import DataGranule from satchip import utils -from satchip.chip_xr_base import create_template_da +from satchip.chip_xr_base import create_dataset_chip, create_template_da from satchip.terra_mind_grid import TerraMindChip -HLS_L_BANDS = { - 'B01': 'COASTAL', - 'B02': 'BLUE', - 'B03': 'GREEN', - 'B04': 'RED', - 'B05': 'NIR08', - 'B06': 'SWIR16', - 'B07': 'SWIR22', -} -HLS_S_BANDS = { - 'B01': 'COASTAL', - 'B02': 'BLUE', - 'B03': 'GREEN', - 'B04': 'RED', - 'B8A': 'NIR08', - 'B11': 'SWIR16', - 'B12': 'SWIR22', -} +HLS_L_BANDS = OrderedDict( + { + 'B01': 'COASTAL', + 'B02': 'BLUE', + 'B03': 'GREEN', + 'B04': 'RED', + 'B05': 'NIR08', + 'B06': 'SWIR16', + 'B07': 'SWIR22', + } +) +HLS_S_BANDS = OrderedDict( + { + 'B01': 'COASTAL', + 'B02': 'BLUE', + 'B03': 'GREEN', + 'B04': 'RED', + 'B8A': 'NIR08', + 'B11': 'SWIR16', + 'B12': 'SWIR22', + } +) BAND_SETS = {'L30': HLS_L_BANDS, 'S30': HLS_S_BANDS} @@ -52,7 +57,7 @@ def get_product_id(umm: dict) -> str: def get_scenes( - items: list[DataGranule], roi: shapely.geometry.Polygon, max_cloud_pct: int, strategy: str, scratch_dir: Path + items: list[DataGranule], roi: shapely.geometry.Polygon, max_cloud_pct: int, strategy: str, image_dir: Path ) -> list[DataGranule]: """Returns the best HLS scene from the given list of items. The best scene is defined as the earliest scene with the largest intersection with the roi and @@ -63,7 +68,7 @@ def get_scenes( roi: Region of interest polygon. max_cloud_pct: Maximum percent of bad pixels allowed in the scene. strategy: Strategy to use when selecting data. - scratch_dir: Directory to store downloaded files. + image_dir: Directory to store downloaded files. Returns: The best HLS items. 
@@ -74,10 +79,10 @@ def get_scenes( valid_scenes = [] for item in best_first: product_id = get_product_id(item['umm']) - n_products = len(list(scratch_dir.glob(f'{product_id}*'))) + n_products = len(list(image_dir.glob(f'{product_id}*'))) if n_products < 15: - earthaccess.download([item], scratch_dir, pqdm_kwargs={'disable': True}) - fmask_path = scratch_dir / f'{product_id}.v2.0.Fmask.tif' + earthaccess.download([item], image_dir, pqdm_kwargs={'disable': True}) + fmask_path = image_dir / f'{product_id}.v2.0.Fmask.tif' assert fmask_path.exists(), f'File not found: {fmask_path}' qual_da = rioxarray.open_rasterio(fmask_path).rio.clip_box(*roi.bounds, crs='EPSG:4326') # type: ignore bit_masks = np.unpackbits(qual_da.data[0][..., np.newaxis], axis=-1) @@ -94,7 +99,7 @@ def get_scenes( return valid_scenes -def get_hls_data(chip: TerraMindChip, scratch_dir: Path, opts: utils.ChipDataOpts) -> xr.DataArray: +def get_hls_data(chip: TerraMindChip, image_dir: Path, opts: utils.ChipDataOpts) -> xr.Dataset: """Returns XArray DataArray of a Harmonized Landsat Sentinel-2 image for the given bounds and closest collection after date. 
""" @@ -109,24 +114,21 @@ def get_hls_data(chip: TerraMindChip, scratch_dir: Path, opts: utils.ChipDataOpt roi_buffered = roi.buffer(0.01) max_cloud_pct = opts.get('max_cloud_pct', 100) strategy = opts.get('strategy', 'BEST').upper() - scenes = get_scenes(results, roi, max_cloud_pct, strategy, scratch_dir) - das = [] + timesteps = get_scenes(results, roi, max_cloud_pct, strategy, image_dir) template = create_template_da(chip) - for scene in scenes: + timestep_arrays = [] + for scene in timesteps: product_id = get_product_id(scene['umm']) bands = BAND_SETS[product_id.split('.')[1]] + band_arrays = [] for band in bands: - image_path = scratch_dir / f'{product_id}.v2.0.{band}.tif' + image_path = image_dir / f'{product_id}.v2.0.{band}.tif' da = rioxarray.open_rasterio(image_path).rio.clip_box(*roi_buffered.bounds, crs='EPSG:4326') # type: ignore da_reproj = da.rio.reproject_match(template) - da_reproj['band'] = [bands[band]] - da_reproj = da_reproj.expand_dims({'time': [get_date(scene['umm']).replace(tzinfo=None)]}) - da_reproj['x'] = np.arange(0, chip.ncol) - da_reproj['y'] = np.arange(0, chip.nrow) - da_reproj.attrs = {} - das.append(da_reproj) - dataarray = xr.combine_by_coords(das, join='override').drop_vars('spatial_ref') - assert isinstance(dataarray, xr.DataArray) - dataarray = dataarray.expand_dims({'sample': [chip.name], 'platform': ['HLS']}) - dataarray.attrs = {} - return dataarray + band_arrays.append(da_reproj.data.squeeze()) + band_array = np.stack(band_arrays, axis=0) + timestep_arrays.append(band_array) + data_array = np.stack(timestep_arrays, axis=0) + dates = [get_date(scene['umm']).replace(tzinfo=None) for scene in timesteps] # type: ignore + dataset = create_dataset_chip(data_array, chip, dates, list(HLS_L_BANDS.values())) + return dataset diff --git a/src/satchip/chip_label.py b/src/satchip/chip_label.py index 05208ff..2b5f048 100644 --- a/src/satchip/chip_label.py +++ b/src/satchip/chip_label.py @@ -7,33 +7,27 @@ import xarray as xr from tqdm 
import tqdm -import satchip from satchip import utils +from satchip.chip_xr_base import create_dataset_chip from satchip.terra_mind_grid import TerraMindGrid -def get_overall_bounds(bounds: list) -> list: - minx = min([b[0] for b in bounds]) - miny = min([b[1] for b in bounds]) - maxx = max([b[2] for b in bounds]) - maxy = max([b[3] for b in bounds]) - return [minx, miny, maxx, maxy] - - def is_valuable(chip: np.ndarray) -> bool: vals = list(np.unique(chip)) return not vals == [0] -def chip_labels(label_path: Path, date: datetime, output_dir: Path) -> Path: +def chip_labels(label_path: Path, date: datetime, chip_dir: Path) -> list[Path]: + label_dir = chip_dir / 'LABEL' + label_dir.mkdir(parents=True, exist_ok=True) label = xr.open_dataarray(label_path) bbox = utils.get_epsg4326_bbox(label.rio.bounds(), label.rio.crs.to_epsg()) tm_grid = TerraMindGrid(latitude_range=(bbox[1], bbox[3]), longitude_range=(bbox[0], bbox[2])) - chips = {} + chip_paths = [] for tm_chip in tqdm(tm_grid.terra_mind_chips): chip = label.rio.reproject( dst_crs=f'EPSG:{tm_chip.epsg}', - resampling=rio.enums.Resampling(1), + resampling=rio.enums.Resampling(1), # type: ignore transform=tm_chip.rio_transform, shape=(tm_chip.nrow, tm_chip.ncol), ) @@ -41,42 +35,23 @@ def chip_labels(label_path: Path, date: datetime, output_dir: Path) -> Path: chip_array[np.isnan(chip_array)] = 0 chip_array = np.round(chip_array).astype(np.int16) if is_valuable(chip_array): - chips[tm_chip.name] = [chip_array, tm_chip] - - if len(chips) == 0: - raise ValueError(f'No valid chips found for {label_path.name}') - - coords = { - 'time': np.array([date]), - 'band': np.array(['labels']), - 'sample': np.array([str(x) for x in chips.keys()]), - 'y': np.arange(0, chip_array.shape[0]), - 'x': np.arange(0, chip_array.shape[1]), - } - print(f'Found {len(chips)} valid chips for {label_path.name}') - label_np = np.expand_dims(np.stack([val[0] for val in chips.values()], axis=0), axis=[0, 1]) - lats, lons = zip(*[val[1].center for 
val in chips.values()]) - - dataset = xr.Dataset(attrs={'date_created': date.isoformat(), 'satchip_version': satchip.__version__}) - dataset.attrs['bounds'] = get_overall_bounds([val[1].bounds for val in chips.values()]) - dataset['bands'] = xr.DataArray(label_np, coords=coords, dims=list(coords.keys())) - dataset['lats'] = xr.DataArray(np.array(lats), coords={'sample': coords['sample']}, dims=['sample']) - dataset['lons'] = xr.DataArray(np.array(lons), coords={'sample': coords['sample']}, dims=['sample']) - output_path = output_dir / label_path.with_suffix('.zarr.zip').name - utils.save_chip(dataset, output_path) - return output_path + dataset = create_dataset_chip(chip_array.reshape(1, *chip_array.shape), tm_chip, date, ['LABEL']) + chip_path = label_dir / f'{label_path.stem}_{tm_chip.name}.zarr.zip' + utils.save_chip(dataset, chip_path) + chip_paths.append(chip_path) + return chip_paths def main() -> None: parser = argparse.ArgumentParser(description='Chip a label image') parser.add_argument('labelpath', type=str, help='Path to the label image') parser.add_argument('date', type=str, help='Date and time of the image in ISO format (YYYY-MM-DDTHH:MM:SS)') - parser.add_argument('--outdir', default='.', type=str, help='Output directory for the chips') + parser.add_argument('--chipdir', default='.', type=str, help='Output directory for the chips') args = parser.parse_args() args.labelpath = Path(args.labelpath) args.date = datetime.fromisoformat(args.date) - args.outdir = Path(args.outdir) - chip_labels(args.labelpath, args.date, args.outdir) + args.chipdir = Path(args.chipdir) + chip_labels(args.labelpath, args.date, args.chipdir) if __name__ == '__main__': diff --git a/src/satchip/chip_sentinel1rtc.py b/src/satchip/chip_sentinel1rtc.py index bcb4993..aa1ff38 100644 --- a/src/satchip/chip_sentinel1rtc.py +++ b/src/satchip/chip_sentinel1rtc.py @@ -9,20 +9,24 @@ import xarray as xr from satchip import utils -from satchip.chip_xr_base import create_template_da +from 
satchip.chip_xr_base import create_dataset_chip, create_template_da from satchip.terra_mind_grid import TerraMindChip +S1RTC_BANDS = ['VV', 'VH'] + + def get_rtc_paths_for_chips( - terra_mind_chips: list[TerraMindChip], bounds: list[float], scratch_dir: Path, opts: utils.ChipDataOpts -) -> dict[str, list[Path]]: + terra_mind_chips: list[TerraMindChip], image_dir: Path, opts: utils.ChipDataOpts +) -> dict[str, list[utils.RtcImageSet]]: + bounds = utils.get_overall_bounds([chip.bounds for chip in terra_mind_chips]) _check_bounds_size(bounds) granules = _get_granules(bounds, opts['date_start'], opts['date_end']) slcs_for_chips = _get_slcs_for_each_chip(terra_mind_chips, granules, opts['strategy']) assert len(slcs_for_chips) == len(terra_mind_chips) - rtc_paths_for_chips = _get_rtcs_for(slcs_for_chips, scratch_dir) - return rtc_paths_for_chips + rtc_image_sets_for_chips = _get_rtcs_for(slcs_for_chips, image_dir) + return rtc_image_sets_for_chips def _check_bounds_size(bounds: list[float]) -> None: @@ -37,7 +41,7 @@ def _check_bounds_size(bounds: list[float]) -> None: def _get_granules(bounds: list[float], date_start: datetime, date_end: datetime) -> list[asf.S1Product]: date_start = date_start date_end = date_end + timedelta(days=1) # inclusive end - roi = shapely.box(*bounds) + roi = shapely.box(*bounds) # type: ignore search_results = asf.geo_search( intersectsWith=roi.wkt, start=date_start, @@ -78,25 +82,26 @@ def _get_pct_intersect(product: asf.S1Product, roi: shapely.geometry.Polygon) -> return intersection -def _get_rtcs_for(slcs_for_chips: dict[str, list[asf.S1Product]], scratch_dir: Path) -> dict[str, list[Path]]: +def _get_rtcs_for( + slcs_for_chips: dict[str, list[asf.S1Product]], image_dir: Path +) -> dict[str, list[utils.RtcImageSet]]: flat_slcs = sum(slcs_for_chips.values(), []) slc_names = set(granule.properties['sceneName'] for granule in flat_slcs) finished_rtc_jobs = _process_rtcs(slc_names) - paths_for_slc_name: dict[str, Path] = {} + 
image_set_for_slc_name: dict[str, utils.RtcImageSet] = {} for job in finished_rtc_jobs: - rtc_path = _download_hyp3_rtc(job, scratch_dir) - slc_name = job.job_parameters['granules'][0] - - paths_for_slc_name[slc_name] = rtc_path + rtc_image_set = _download_hyp3_rtc(job, image_dir) + slc_name = job.job_parameters['granules'][0] # type: ignore + image_set_for_slc_name[slc_name] = rtc_image_set - rtc_paths_for_chips: dict[str, list[Path]] = {} + image_sets_for_chips: dict[str, list[utils.RtcImageSet]] = {} for chip_name, chip_slcs in slcs_for_chips.items(): - rtc_paths = [paths_for_slc_name[name.properties['sceneName']] for name in chip_slcs] - rtc_paths_for_chips[chip_name] = rtc_paths + image_sets = [image_set_for_slc_name[name.properties['sceneName']] for name in chip_slcs] + image_sets_for_chips[chip_name] = image_sets - return rtc_paths_for_chips + return image_sets_for_chips def _process_rtcs(slc_names: set[str]) -> hyp3_sdk.Batch: @@ -114,9 +119,9 @@ def _process_rtcs(slc_names: set[str]) -> hyp3_sdk.Batch: batch = hyp3_sdk.Batch(hyp3_jobs) batch = hyp3.watch(batch) - assert all([j.succeeded() for j in batch]), 'One or more HyP3 jobs failed' + assert all([j.succeeded() for j in batch]), 'One or more HyP3 jobs failed' # type: ignore - return batch + return batch # type: ignore def _get_rtc_jobs_by_scene_name(hyp3: hyp3_sdk.HyP3) -> dict[str, hyp3_sdk.Job]: @@ -126,7 +131,7 @@ def _get_rtc_jobs_by_scene_name(hyp3: hyp3_sdk.HyP3) -> dict[str, hyp3_sdk.Job]: if not _is_valid_rtc_job(job): continue - name = job.job_parameters['granules'][0] + name = job.job_parameters['granules'][0] # type: ignore jobs_by_scene_name[name] = job return jobs_by_scene_name @@ -136,41 +141,38 @@ def _is_valid_rtc_job(job: hyp3_sdk.Job) -> bool: return ( not job.failed() and not job.expired() - and job.job_parameters['radiometry'] == 'gamma0' - and job.job_parameters['resolution'] == 20 + and job.job_parameters['radiometry'] == 'gamma0' # type: ignore + and 
job.job_parameters['resolution'] == 20 # type: ignore ) -def _download_hyp3_rtc(job: hyp3_sdk.Job, scratch_dir: Path) -> tuple[Path, Path]: - output_path = scratch_dir / job.to_dict()['files'][0]['filename'] +def _download_hyp3_rtc(job: hyp3_sdk.Job, image_dir: Path) -> utils.RtcImageSet: + output_path = image_dir / job.to_dict()['files'][0]['filename'] output_dir = output_path.with_suffix('') output_zip = output_path.with_suffix('.zip') if not output_dir.exists(): - job.download_files(location=scratch_dir) + job.download_files(location=image_dir) hyp3_sdk.util.extract_zipped_product(output_zip) vv_path = list(output_dir.glob('*_VV.tif'))[0] vh_path = list(output_dir.glob('*_VH.tif'))[0] - return vv_path, vh_path + image_set: utils.RtcImageSet = {'VV': vv_path, 'VH': vh_path} + return image_set -def get_s1rtc_chip_data( - chip: TerraMindChip, image_sets: list[Path], scratch_dir: Path, opts: utils.ChipDataOpts -) -> xr.DataArray: +def get_s1rtc_chip_data(chip: TerraMindChip, image_sets: list[utils.RtcImageSet]) -> xr.Dataset: roi = shapely.box(*chip.bounds) - das = [] template = create_template_da(chip) + timestep_arrays = [] for image_set in image_sets: - for band_name, image_path in zip(['VV', 'VH'], image_set): + band_arrays = [] + for band in S1RTC_BANDS: + image_path = image_set[band] da = rioxarray.open_rasterio(image_path).rio.clip_box(*roi.buffer(0.1).bounds, crs='EPSG:4326') # type: ignore da_reproj = da.rio.reproject_match(template) - da_reproj['band'] = [band_name] - image_time = datetime.strptime(image_path.name.split('_')[2], '%Y%m%dT%H%M%S') - da_reproj = da_reproj.expand_dims({'time': [image_time]}) - da_reproj['x'] = np.arange(0, chip.ncol) - da_reproj['y'] = np.arange(0, chip.nrow) - da_reproj.attrs = {} - das.append(da_reproj) - dataarray = xr.combine_by_coords(das, join='override').drop_vars('spatial_ref') - assert isinstance(dataarray, xr.DataArray) - dataarray = dataarray.expand_dims({'sample': [chip.name], 'platform': ['S1RTC']}) - return 
dataarray + band_arrays.append(da_reproj.data.squeeze()) + band_array = np.stack(band_arrays, axis=0) + timestep_arrays.append(band_array) + data_array = np.stack(timestep_arrays, axis=0) + dates = [datetime.strptime(image_set['VV'].name.split('_')[2], '%Y%m%dT%H%M%S') for image_set in image_sets] + dataset = create_dataset_chip(data_array, chip, dates, S1RTC_BANDS) + return dataset diff --git a/src/satchip/chip_sentinel2.py b/src/satchip/chip_sentinel2.py index b80c3ca..f7049d0 100644 --- a/src/satchip/chip_sentinel2.py +++ b/src/satchip/chip_sentinel2.py @@ -1,3 +1,4 @@ +from collections import OrderedDict from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta from pathlib import Path @@ -12,24 +13,26 @@ from pystac_client import Client from satchip import utils -from satchip.chip_xr_base import create_template_da +from satchip.chip_xr_base import create_dataset_chip, create_template_da from satchip.terra_mind_grid import TerraMindChip -S2_BANDS = { - 'B01': 'COASTAL', - 'B02': 'BLUE', - 'B03': 'GREEN', - 'B04': 'RED', - 'B05': 'REDEDGE1', - 'B06': 'REDEDGE2', - 'B07': 'REDEDGE3', - 'B08': 'NIR', - 'B8A': 'NIR08', - 'B09': 'NIR09', - 'B11': 'SWIR16', - 'B12': 'SWIR22', -} +S2_BANDS = OrderedDict( + { + 'B01': 'COASTAL', + 'B02': 'BLUE', + 'B03': 'GREEN', + 'B04': 'RED', + 'B05': 'REDEDGE1', + 'B06': 'REDEDGE2', + 'B07': 'REDEDGE3', + 'B08': 'NIR', + 'B8A': 'NIR08', + 'B09': 'NIR09', + 'B11': 'SWIR16', + 'B12': 'SWIR22', + } +) S3_FS = s3fs.S3FileSystem(anon=True) @@ -46,28 +49,28 @@ def url_to_s3path(url: str) -> str: return f'{bucket}/{key}' -def url_to_localpath(url: str, scratch_dir: Path) -> Path: - """Converts an S3 URL to a local file path in the given scratch directory.""" +def url_to_localpath(url: str, image_dir: Path) -> Path: + """Converts an S3 URL to a local file path in the given image directory.""" parsed = urlparse(url) name = '_'.join(parsed.path.lstrip('/').split('/')[-2:]) - local_file_path = scratch_dir / 
name + local_file_path = image_dir / name return local_file_path -def fetch_s3_file(url: str, scratch_dir: Path) -> Path: - """Fetches an S3 file to the given scratch directory if it doesn't already exist.""" - local_path = url_to_localpath(url, scratch_dir) +def fetch_s3_file(url: str, image_dir: Path) -> Path: + """Fetches an S3 file to the given image directory if it doesn't already exist.""" + local_path = url_to_localpath(url, image_dir) if not local_path.exists(): s3_path = url_to_s3path(url) S3_FS.get(s3_path, str(local_path)) return local_path -def multithread_fetch_s3_file(urls: list[str], scratch_dir: Path, max_workers: int = 8) -> None: - """Fetches multiple S3 files to the given scratch directory using multithreading.""" +def multithread_fetch_s3_file(urls: list[str], image_dir: Path, max_workers: int = 8) -> None: + """Fetches multiple S3 files to the given image directory using multithreading.""" s3_paths, download_paths = [], [] for url in urls: - local_path = url_to_localpath(url, scratch_dir) + local_path = url_to_localpath(url, image_dir) if not local_path.exists(): download_paths.append(local_path) s3_paths.append(url_to_s3path(url)) @@ -86,7 +89,7 @@ def get_pct_intersect(scene_geom: dict | None, roi: shapely.geometry.Polygon) -> def get_scenes( - items: list[Item], roi: shapely.geometry.Polygon, strategy: str, max_cloud_pct: int, scratch_dir: Path + items: list[Item], roi: shapely.geometry.Polygon, strategy: str, max_cloud_pct: int, image_dir: Path ) -> list[Item]: """Returns the best Sentinel-2 L2A scene from the given list of items. The best scene is defined as the earliest scene with the largest intersection with the roi and @@ -96,7 +99,7 @@ def get_scenes( items: List of Sentinel-2 L2A items. roi: Region of interest polygon. max_cloud_pct: Maximum percent of bad pixels allowed in the scene. - scratch_dir: Directory to store downloaded files. + image_dir: Directory to store downloaded files. Returns: The best Sentinel-2 L2A item. 
@@ -109,7 +112,7 @@ def get_scenes(
     valid_scenes = []
     for item in best_first:
         scl_href = item.assets['scl'].href
-        local_path = fetch_s3_file(scl_href, scratch_dir)
+        local_path = fetch_s3_file(scl_href, image_dir)
         assert local_path.exists(), f'File not found: {local_path}'
         scl_da = rioxarray.open_rasterio(local_path).rio.clip_box(*roi.bounds, crs='EPSG:4326')  # type: ignore
         scl_array = scl_da.data[0]
@@ -128,12 +131,12 @@ def get_scenes(
     return valid_scenes
 
 
-def get_s2l2a_data(chip: TerraMindChip, scratch_dir: Path, opts: utils.ChipDataOpts) -> xr.DataArray:
+def get_s2l2a_data(chip: TerraMindChip, image_dir: Path, opts: utils.ChipDataOpts) -> xr.Dataset:
     """Get XArray DataArray of Sentinel-2 L2A image for the given bounds and best collection parameters.
 
     Args:
         chip: TerraMindChip object defining the area of interest.
-        scratch_dir: Directory to store downloaded files.
+        image_dir: Directory to store downloaded files.
         opts: Options dictionary with the following keys
             - date_start: Start date for the search.
             - date_end: End date for the search.
@@ -141,7 +144,7 @@ def get_s2l2a_data(chip: TerraMindChip, scratch_dir: Path, opts: utils.ChipDataO
             - max_cloud_pct (optional): Maximum percent of bad pixels allowed in the scene.
 
     Returns:
-        XArray DataArray containing the Sentinel-2 L2A image data.
+        XArray Dataset containing the Sentinel-2 L2A image data.
     """
     date_start = opts['date_start']
     date_end = opts['date_end'] + timedelta(days=1)  # inclusive end
@@ -158,28 +161,28 @@ def get_s2l2a_data(chip: TerraMindChip, scratch_dir: Path, opts: utils.ChipDataO
     assert len(search.item_collection()) > 0, (
         f'No Sentinel-2 L2A scenes found for chip {chip.name} between {date_start} and {date_end}.'
     )
+    assert len(search.item_collection()) < 1000, (
+        'Too many Sentinel-2 L2A scenes found for chip. Please narrow the date range.'
+    )
     items = list(search.item_collection())
     max_cloud_pct = opts.get('max_cloud_pct', 100)
     strategy = opts.get('strategy', 'BEST')
-    items = get_scenes(items, roi, strategy, max_cloud_pct, scratch_dir)
+    timesteps = get_scenes(items, roi, strategy, max_cloud_pct, image_dir)
     urls = [item.assets[S2_BANDS[band].lower()].href for item in items for band in S2_BANDS]
-    multithread_fetch_s3_file(urls, scratch_dir)
+    multithread_fetch_s3_file(urls, image_dir)
     template = create_template_da(chip)
-    das = []
-    for item in items:
+    timestep_arrays = []
+    for item in timesteps:
+        band_arrays = []
         for band in S2_BANDS:
-            local_path = url_to_localpath(item.assets[S2_BANDS[band].lower()].href, scratch_dir)
+            local_path = url_to_localpath(item.assets[S2_BANDS[band].lower()].href, image_dir)
             assert local_path.exists(), f'File not found: {local_path}'
             da = rioxarray.open_rasterio(local_path).rio.clip_box(*roi_buffered.bounds, crs='EPSG:4326')  # type: ignore
             da_reproj = da.rio.reproject_match(template)
-            da_reproj['band'] = [S2_BANDS[band]]
-            da_reproj = da_reproj.expand_dims({'time': [item.datetime.replace(tzinfo=None)]})  # type: ignore
-            da_reproj['x'] = np.arange(0, chip.ncol)
-            da_reproj['y'] = np.arange(0, chip.nrow)
-            da_reproj.attrs = {}
-            das.append(da_reproj)
-    dataarray = xr.combine_by_coords(das, join='override').drop_vars('spatial_ref')
-    assert isinstance(dataarray, xr.DataArray)
-    dataarray = dataarray.expand_dims({'sample': [chip.name], 'platform': ['S2L2A']})
-    dataarray.attrs = {}
-    return dataarray
+            band_arrays.append(da_reproj.data.squeeze())
+        band_array = np.stack(band_arrays, axis=0)
+        timestep_arrays.append(band_array)
+    data_array = np.stack(timestep_arrays, axis=0)
+    dates = [item.datetime.replace(tzinfo=None) for item in timesteps]  # type: ignore
+    dataset = create_dataset_chip(data_array, chip, dates, list(S2_BANDS.values()))
+    return dataset
diff --git a/src/satchip/chip_view.py b/src/satchip/chip_view.py
index b89c00d..c7a560d 100644
--- a/src/satchip/chip_view.py
+++ b/src/satchip/chip_view.py
@@ -1,64 +1,102 @@
 import argparse
-from collections import namedtuple
-from itertools import product
 from pathlib import Path
 
+import numpy as np
+import xarray as xr
 from matplotlib import pyplot as plt
 from matplotlib.widgets import Slider
 
 from satchip.utils import load_chip
 
-Index = namedtuple('Index', ['sample', 'time'])
+
+def normalize_image_array(input_array: np.ndarray, vmin: float, vmax: float) -> np.ndarray:
+    """Function to normalize array values to a byte value between 0 and 255
+
+    Args:
+        input_array: The array to normalize.
+        vmin: The minimum value to normalize to (mapped to 0).
+        vmax: The maximum value to normalize to (mapped to 255).
 
-def view_chip(label_path: Path, band: str, idx: int = 0) -> None:
+    Returns:
+        The normalized array.
+    """
+    input_array = input_array.astype(float)
+    scaled_array = (input_array - vmin) / (vmax - vmin)
+    scaled_array[np.isnan(input_array)] = 0
+    normalized_array = np.round(np.clip(scaled_array, 0, 1) * 255).astype(np.uint8)
+    return normalized_array
+
+
+def get_image_array(da: xr.Dataset, time_index: int, image_type: str, band: str | None) -> np.ndarray:
+    timestep = da['bands'].isel(time=time_index)
+    if image_type == 'rtc':
+        vv = normalize_image_array(np.sqrt(timestep.sel(band='VV').data), 0.14, 0.52)
+        vh = normalize_image_array(np.sqrt(timestep.sel(band='VH').data), 0.05, 0.259)
+        img = np.stack([vv, vh, vv], axis=-1)
+    elif image_type == 'optical':
+        red = normalize_image_array(timestep.sel(band='RED').data, 0, 3000)
+        green = normalize_image_array(timestep.sel(band='GREEN').data, 0, 3000)
+        blue = normalize_image_array(timestep.sel(band='BLUE').data, 0, 3000)
+        img = np.stack([red, green, blue], axis=-1)
+    elif image_type == 'label':
+        img = timestep.sel(band='LABEL').data
+    elif image_type == 'user':
+        assert band is not None, 'Band must be specified for user-defined image type'
+        img = timestep.sel(band=band).data
+    else:
+        raise ValueError(f'Unknown image type: {image_type}')
+    return img
+
+
+def view_chip(label_path: Path, band: str | None) -> None:
     chip = load_chip(label_path)
     band_names = list(chip['band'].values)
-    if band not in band_names:
-        raise ValueError(f'Band {band} not found in chip. Available bands: {", ".join(band_names)}')
-    da = chip['data'].sel(band=band, platform=chip['platform'].values[0]).drop_vars(['platform', 'band'])
-
-    indexes = [Index(s, t) for s, t in product(da.sample.values, da.time.values)]
-
-    # Initial plot
-    start_index = indexes[idx]
-    img = da.sel(sample=start_index.sample, time=start_index.time).plot.imshow(
-        add_colorbar=True, cmap='gray', figsize=(10, 10)
-    )
-    assert img.colorbar is not None
-    img.colorbar.set_label('')
-    ax = img.axes
-    date = start_index.time.astype('datetime64[ms]').astype(object).strftime('%Y-%m-%d')
-    title = f'Date: {date} | Band: {band} | Sample: {da.sample.values[idx]}'
-    ax.set_title(title)
-    ax.set_aspect('equal')
-    fig = ax.figure
-
-    # Slider axis
-    slider_ax = fig.add_axes([0.2, 0.02, 0.6, 0.03])  # type: ignore
-    slider = Slider(slider_ax, 'Index', 0, len(indexes) - 1, valinit=idx, valstep=1)
-
-    def update(val: int) -> None:
-        assert val is not None
-        index = indexes[int(slider.val)]
-        img.set_data(da.sel(sample=index.sample, time=index.time).values)
-        date = index.time.astype('datetime64[ms]').astype(object).strftime('%Y-%m-%d')
-        title = f'Date: {date} | Band: {band} | Sample: {index.sample}'
-        ax.set_title(title)
-        fig.canvas.draw_idle()
-
-    slider.on_changed(update)  # type: ignore
+    if band is not None:
+        if band not in band_names:
+            raise ValueError(f'Band {band} not found in chip. Available bands: {", ".join(band_names)}')
+        image_type = 'user'
+    elif any(b in band_names for b in ['VV', 'VH']):
+        image_type = 'rtc'
+    elif all(b in band_names for b in ['RED', 'GREEN', 'BLUE']):
+        image_type = 'optical'
+    elif 'LABEL' in band_names:
+        image_type = 'label'
+    else:
+        raise ValueError('Cannot determine image type. Please specify a band using --band.')
+
+    times = chip.time.values
+    f, ax = plt.subplots(1, 1, figsize=(10, 10))
+    time_index = 0
+    im = ax.imshow(get_image_array(chip, time_index, image_type, band))
+    title = ax.set_title(f'Date: {str(times[time_index]).split("T")[0]}')
+    if len(times) > 1:
+        ax_slider = plt.axes([0.25, 0.05, 0.5, 0.03])  # type: ignore
+        slider = Slider(
+            ax=ax_slider,
+            label='Date Index',
+            valmin=0,
+            valmax=len(times) - 1,
+            valinit=time_index,
+            valstep=1,
+        )
+
+        def update(val: int) -> None:
+            idx = int(slider.val)
+            im.set_data(get_image_array(chip, idx, image_type, band))
+            title.set_text(f'Time: {str(times[idx]).split("T")[0]}')
+            f.canvas.draw_idle()
+
+        slider.on_changed(update)  # type: ignore
+
     plt.show()
 
 
 def main() -> None:
     parser = argparse.ArgumentParser(description='Chip a label image')
     parser.add_argument('chippath', type=Path, help='Path to the label image')
-    parser.add_argument('band', type=str, help='Band to view')
-    parser.add_argument('--idx', type=int, default=0, help='Index of default sample to view')
+    parser.add_argument('--band', default=None, type=str, help='Band to view')
     args = parser.parse_args()
-    view_chip(args.chippath, args.band, args.idx)
+    view_chip(args.chippath, args.band)
 
 
 if __name__ == '__main__':
diff --git a/src/satchip/chip_xr_base.py b/src/satchip/chip_xr_base.py
index dd8ba53..360de67 100644
--- a/src/satchip/chip_xr_base.py
+++ b/src/satchip/chip_xr_base.py
@@ -1,13 +1,58 @@
+import datetime
+
 import numpy as np
 import xarray as xr
 
+import satchip
 from satchip.terra_mind_grid import TerraMindChip
 
 
+def _check_spec(dataset: xr.Dataset) -> None:
+    assert isinstance(dataset, xr.Dataset)
+    dims = ['band', 'time', 'x', 'y']
+    assert sorted(list(dataset.dims)) == dims  # type: ignore
+    coords = ['band', 'sample', 'spatial_ref', 'time', 'x', 'y']
+    assert sorted(list(dataset.coords)) == coords  # type: ignore
+    assert dataset.sample.ndim == 0
+    data_vars = ['bands', 'center_lat', 'center_lon', 'crs']
+    assert sorted(list(dataset.data_vars)) == data_vars
+    assert 'date_created' in list(dataset.attrs.keys())
+    assert 'satchip_version' in list(dataset.attrs.keys())
+    assert 'bounds' in list(dataset.attrs.keys())
+
+
+def create_dataset_chip(
+    chip_array: np.ndarray,
+    tm_chip: TerraMindChip,
+    dates: list[datetime.datetime] | datetime.datetime,
+    bands: list[str],
+) -> xr.Dataset:
+    x = tm_chip.minx + (np.arange(tm_chip.nrow) + 0.5) * tm_chip.xres
+    y = tm_chip.maxy + (np.arange(tm_chip.ncol) + 0.5) * tm_chip.yres
+    if isinstance(dates, datetime.datetime):
+        assert chip_array.ndim == 3, 'For single timestep, chip_array must have 3 dimensions (band, y, x)'
+        dates = [dates]
+        chip_array = chip_array.reshape(*(1, len(bands), tm_chip.ncol, tm_chip.nrow))
+    else:
+        assert chip_array.ndim == 4, 'For multiple timesteps, chip_array must have 4 dimensions (time, band, y, x)'
+    coords = {'time': np.array(dates), 'band': np.array(bands), 'y': y, 'x': x}
+    now = datetime.datetime.now().isoformat()
+    dataset = xr.Dataset(attrs={'date_created': now, 'satchip_version': satchip.__version__})
+    dataset.attrs['bounds'] = tm_chip.bounds
+    dataset = dataset.assign_coords(sample=tm_chip.name)
+    dataset = dataset.rio.write_crs(f'EPSG:{tm_chip.epsg}')
+    dataset['bands'] = xr.DataArray(chip_array, coords=coords, dims=['time', 'band', 'y', 'x'])
+    dataset['center_lat'] = xr.DataArray(tm_chip.center[1])
+    dataset['center_lon'] = xr.DataArray(tm_chip.center[0])
+    dataset['crs'] = xr.DataArray(tm_chip.epsg)
+    _check_spec(dataset)
+    return dataset
+
+
 def create_template_da(chip: TerraMindChip) -> xr.DataArray:
     """Create a template DataArray with the same dimensions and transform as a label chip."""
-    x = np.arange(chip.nrow) * chip.gdal_transform[1] + chip.gdal_transform[0] + chip.gdal_transform[1] / 2
-    y = np.arange(chip.ncol) * chip.gdal_transform[5] + chip.gdal_transform[3] + chip.gdal_transform[5] / 2
+    x = chip.minx + (np.arange(chip.nrow) + 0.5) * chip.xres
+    y = chip.maxy + (np.arange(chip.ncol) + 0.5) * chip.yres
     template = xr.DataArray(np.zeros((chip.ncol, chip.nrow)), dims=('y', 'x'), coords={'y': y, 'x': x})
     template.rio.write_crs(f'EPSG:{chip.epsg}', inplace=True)
     template.rio.write_transform(chip.rio_transform, inplace=True)
diff --git a/src/satchip/utils.py b/src/satchip/utils.py
index f0aad98..04d8286 100644
--- a/src/satchip/utils.py
+++ b/src/satchip/utils.py
@@ -1,4 +1,5 @@
 import datetime
+import warnings
 from pathlib import Path
 from typing import TypedDict
@@ -7,6 +8,11 @@
 from pyproj import CRS, Transformer
 
 
+class RtcImageSet(TypedDict):
+    VV: Path
+    VH: Path
+
+
 class ChipDataRequiredOpts(TypedDict):
     strategy: str
     date_start: datetime.datetime
@@ -15,6 +21,15 @@
 
 class ChipDataOpts(ChipDataRequiredOpts, total=False):
     max_cloud_pct: int
+    local_hyp3_paths: dict[str, list[RtcImageSet]]
+
+
+def get_overall_bounds(bounds: list) -> list:
+    minx = min([b[0] for b in bounds])
+    miny = min([b[1] for b in bounds])
+    maxx = max([b[2] for b in bounds])
+    maxy = max([b[3] for b in bounds])
+    return [minx, miny, maxx, maxy]
 
 
 def get_epsg4326_point(x: float, y: float, in_epsg: int) -> tuple[float, float]:
@@ -38,12 +53,15 @@
 
 def save_chip(dataset: xr.Dataset, save_path: str | Path) -> None:
     """Save a zipped zarr archive"""
-    store = zarr.storage.ZipStore(save_path, mode='w')
-    dataset.to_zarr(store)
+    store = zarr.storage.ZipStore(save_path, mode='w')  # type: ignore
+    with warnings.catch_warnings():
+        warnings.filterwarnings('ignore', message='Duplicate name:', module='zipfile')
+        dataset.to_zarr(store)
+    store.close()
 
 
 def load_chip(label_path: str | Path) -> xr.Dataset:
     """Load a zipped zarr archive"""
-    store = zarr.storage.ZipStore(label_path, read_only=True)
+    store = zarr.storage.ZipStore(label_path, read_only=True)  # type: ignore
     dataset = xr.open_zarr(store)
     return dataset
diff --git a/tests/test_integration.py b/tests/test_integration.py
new file mode 100644
index 0000000..dba034c
--- /dev/null
+++ b/tests/test_integration.py
@@ -0,0 +1,59 @@
+from datetime import datetime
+from pathlib import Path
+
+import pytest
+from osgeo import gdal
+
+from satchip.chip_data import create_chips
+from satchip.chip_label import chip_labels
+
+
+gdal.UseExceptions()
+
+
+def create_dataset(outpath: Path, start: tuple[int, int]) -> Path:
+    x, y = start
+    pixel_size = 10
+    cols, rows = 512, 512
+    driver = gdal.GetDriverByName('GTiff')
+    dataset = driver.Create(str(outpath), cols, rows, 1, gdal.GDT_UInt16)
+    dataset.SetGeoTransform((x, pixel_size, 0, y, 0, -pixel_size))
+    dataset.SetProjection('EPSG:32611')
+    array = dataset.GetRasterBand(1).ReadAsArray()
+    array[:, :] = 0
+    array[128:384, 128:384] = 1
+    dataset.GetRasterBand(1).WriteArray(array)
+    dataset.FlushCache()
+    dataset = None
+    return outpath
+
+
+def create_label_and_data(label_tif, out_dir, image_dir):
+    chip_labels(label_tif, datetime.fromisoformat('20240115'), out_dir)
+    for platform in ['S2L2A', 'HLS', 'S1RTC']:
+        create_chips(
+            list((out_dir / 'LABEL').glob('*.zarr.zip')),
+            platform,
+            datetime.fromisoformat('20240101'),
+            datetime.fromisoformat('20240215'),
+            'BEST',
+            20,
+            out_dir,
+            image_dir,
+        )
+
+
+@pytest.mark.integration
+def test_integration():
+    data_dir = Path('integration_test')
+    train_dir = data_dir / 'train'
+    train_dir.mkdir(parents=True, exist_ok=True)
+    val_dir = data_dir / 'val'
+    val_dir.mkdir(parents=True, exist_ok=True)
+    image_dir = data_dir / 'images'
+    image_dir.mkdir(parents=True, exist_ok=True)
+
+    train_tif = create_dataset(data_dir / 'train.tif', (431795, 3943142))
+    create_label_and_data(train_tif, train_dir, image_dir)
+    val_tif = create_dataset(data_dir / 'val.tif', (431795, 3943142 - 10 * 512))
+    create_label_and_data(val_tif, val_dir, image_dir)