diff --git a/examples/basic_workflow_gdf.py b/examples/basic_workflow_gdf.py new file mode 100644 index 0000000..e6d1761 --- /dev/null +++ b/examples/basic_workflow_gdf.py @@ -0,0 +1,70 @@ +# examples/basic_workflow.py +from pathlib import Path +from shapely.geometry import Polygon + +from rasteret import Rasteret + + +def main(): + """Demonstrate core workflows with Rasteret.""" + # 1. Define parameters + + custom_name = "bangalore3" + date_range = ("2024-01-01", "2024-01-31") + data_source = "landsat-c2l2-sr" + + workspace_dir = Path.home() / "rasteret_workspace" + workspace_dir.mkdir(exist_ok=True) + + print("1. Defining Area of Interest") + print("--------------------------") + + # Define area and time of interest + aoi_polygon = Polygon( + [(77.55, 13.01), (77.58, 13.01), (77.58, 13.08), (77.55, 13.08), (77.55, 13.01)] + ) + + aoi_polygon2 = Polygon( + [(77.56, 13.02), (77.59, 13.02), (77.59, 13.09), (77.56, 13.09), (77.56, 13.02)] + ) + + # get total bounds of all polygons above + bbox = aoi_polygon.union(aoi_polygon2).bounds + + print("\n2. Creating and Loading Collection") + print("--------------------------") + + # 2. Initialize processor - name generated automatically + processor = Rasteret( + custom_name=custom_name, + data_source=data_source, + output_dir=workspace_dir, + date_range=date_range, + ) + + # Create index if needed + if processor._collection is None: + processor.create_index( + bbox=bbox, date_range=date_range, query={"cloud_cover_lt": 20} + ) + + # List existing collections + collections = Rasteret.list_collections(dir=workspace_dir) + print("Available collections:") + for c in collections: + print(f"- {c['name']}: {c['size']} scenes") + + print("\n3. Processing Data") + print("----------------") + + df = processor.get_gdf( + geometries=[aoi_polygon, aoi_polygon2], bands=["B4", "B5"], cloud_cover_lt=20 + ) + + print(f"Columns: {df.columns}") + print(f"Unique dates: {df.datetime.dt.date.unique()}") + print(f"Unique geometries: {df.geometry.unique()}") + + +if __name__ == "__main__": + main() diff --git a/examples/basic_workflow_xarray.py b/examples/basic_workflow_xarray.py new file mode 100644 index 0000000..a50112a --- /dev/null +++ b/examples/basic_workflow_xarray.py @@ -0,0 +1,104 @@ +# examples/basic_workflow.py +from pathlib import Path +from shapely.geometry import Polygon +import xarray as xr + +from rasteret import Rasteret +from rasteret.constants import DataSources +from rasteret.core.utils import save_per_geometry + + +def main(): + + # 1. Define parameters + custom_name = "bangalore" + date_range = ("2024-01-01", "2024-01-31") + data_source = DataSources.LANDSAT # or SENTINEL2 + + workspace_dir = Path.home() / "rasteret_workspace" + workspace_dir.mkdir(exist_ok=True) + + print("1. Defining Area of Interest") + print("--------------------------") + + # Define area and time of interest + aoi1_polygon = Polygon( + [(77.55, 13.01), (77.58, 13.01), (77.58, 13.08), (77.55, 13.08), (77.55, 13.01)] + ) + + aoi2_polygon = Polygon( + [(77.56, 13.02), (77.59, 13.02), (77.59, 13.09), (77.56, 13.09), (77.56, 13.02)] + ) + + # get total bounds of all polygons above for stac search and stac index creation + bbox = aoi1_polygon.union(aoi2_polygon).bounds + + print("\n2. Creating and Loading Collection") + print("--------------------------") + + # 2. 
Initialize processor - name generated automatically + processor = Rasteret( + custom_name=custom_name, + data_source=data_source, + output_dir=workspace_dir, + date_range=date_range, + ) + + # Create index if needed + if processor._collection is None: + processor.create_index( + bbox=bbox, + date_range=date_range, + cloud_cover_lt=20, + # add platform filter for Landsat 9, 8, 7, 5, 4 if needed, + # else remove it for all platforms + # This is unique to Landsat STAC endpoint + platform={"in": ["LANDSAT_8"]}, + ) + + # List existing collections + collections = Rasteret.list_collections(dir=workspace_dir) + print("Available collections:") + for c in collections: + print(f"- {c['name']}: {c['size']} scenes") + + print("\n3. Processing Data") + print("----------------") + + # Calculate NDVI using xarray operations + ds = processor.get_xarray( + # pass multiple geometries not its union bounds + # for separate processing of each geometry + geometries=[aoi1_polygon, aoi2_polygon], + bands=["B4", "B5"], + cloud_cover_lt=20, + ) + + print("\nInput dataset:") + print(ds) + + # Calculate NDVI and preserve metadata + ndvi = (ds.B5 - ds.B4) / (ds.B5 + ds.B4) + ndvi_ds = xr.Dataset( + {"NDVI": ndvi}, + coords=ds.coords, # Preserve coordinates including CRS + attrs=ds.attrs, # Preserve metadata + ) + + print("\nNDVI dataset:") + print(ndvi_ds) + + # Create output directory + output_dir = Path("ndvi_results") + output_dir.mkdir(exist_ok=True) + + # Save per geometry, give prefix for output files in this case "ndvi" + output_files = save_per_geometry(ndvi_ds, output_dir, prefix="ndvi") + + print("\nProcessed NDVI files:") + for geom_id, filepath in output_files.items(): + print(f"Geometry {geom_id}: {filepath}") + + +if __name__ == "__main__": + main() diff --git a/src/rasteret/__init__.py b/src/rasteret/__init__.py new file mode 100644 index 0000000..9848078 --- /dev/null +++ b/src/rasteret/__init__.py @@ -0,0 +1,21 @@ +"""Rasteret package.""" + +from importlib.metadata import version as get_version + +from rasteret.core.processor import Rasteret +from rasteret.core.collection import Collection +from rasteret.cloud import CloudConfig, AWSProvider +from rasteret.constants import DataSources +from rasteret.logging import setup_logger + +# Set up logging +setup_logger("INFO") + + +def version(): + """Return the version of the rasteret package.""" + return get_version("rasteret") + +__version__ = version() + +__all__ = ["Collection", "Rasteret", "CloudConfig", "AWSProvider", "DataSources"] diff --git a/src/rasteret/cloud.py b/src/rasteret/cloud.py new file mode 100644 index 0000000..63f1c51 --- /dev/null +++ b/src/rasteret/cloud.py @@ -0,0 +1,118 @@ +""" Utilities for cloud storage """ + +from dataclasses import dataclass +from typing import Optional, Dict +import boto3 +from rasteret.logging import setup_logger + +logger = setup_logger() + + +@dataclass +class CloudConfig: + """Storage configuration for data source""" + + provider: str + requester_pays: bool = False + region: str = "us-west-2" + url_patterns: Dict[str, str] = None # Map HTTPS patterns to cloud URLs + + +# Configuration for supported data sources +CLOUD_CONFIG = { + "landsat-c2l2-sr": CloudConfig( + provider="aws", + requester_pays=True, + region="us-west-2", + url_patterns={"https://landsatlook.usgs.gov/data/": "s3://usgs-landsat/"}, + ), + "sentinel-2-l2a": CloudConfig( + provider="aws", requester_pays=False, region="us-west-2" + ), +} + + +class CloudProvider: + """Base class for cloud providers""" + + @staticmethod + def 
check_aws_credentials() -> bool: + """Check AWS credentials before any operations""" + try: + session = boto3.Session() + credentials = session.get_credentials() + if credentials is None: + logger.error( + "\nAWS credentials not found. To configure:\n" + "1. Create ~/.aws/credentials with:\n" + "[default]\n" + "aws_access_key_id = YOUR_ACCESS_KEY\n" + "aws_secret_access_key = YOUR_SECRET_KEY\n" + "OR\n" + "2. Set environment variables:\n" + "export AWS_ACCESS_KEY_ID='your_key'\n" + "export AWS_SECRET_ACCESS_KEY='your_secret'" + ) + return False + return True + except Exception: + return False + + def get_url(self, url: str, config: CloudConfig) -> str: + """Central URL resolution and signing method""" + raise NotImplementedError + + +class AWSProvider(CloudProvider): + def __init__(self, profile: Optional[str] = None, region: str = "us-west-2"): + if not self.check_aws_credentials(): + raise ValueError("AWS credentials not configured") + + try: + session = ( + boto3.Session(profile_name=profile) if profile else boto3.Session() + ) + self.client = session.client("s3", region_name=region) + except Exception as e: + logger.error(f"Failed to initialize AWS client: {str(e)}") + raise ValueError("AWS provider initialization failed") + + def get_url(self, url: str, config: CloudConfig) -> Optional[str]: + """Resolve and sign URL based on configuration""" + # First check for alternate S3 URL in STAC metadata + if isinstance(url, dict) and "alternate" in url and "s3" in url["alternate"]: + s3_url = url["alternate"]["s3"]["href"] + logger.debug(f"Using alternate S3 URL: {s3_url}") + url = s3_url + # Then check URL patterns if defined + elif config.url_patterns: + for http_pattern, s3_pattern in config.url_patterns.items(): + if url.startswith(http_pattern): + url = url.replace(http_pattern, s3_pattern) + logger.debug(f"Converted to S3 URL: {url}") + break + + # Sign URL if it's an S3 URL + if url.startswith("s3://"): + try: + bucket = url.split("/")[2] + key = "/".join(url.split("/")[3:]) + + params = { + "Bucket": bucket, + "Key": key, + } + if config.requester_pays: + params["RequestPayer"] = "requester" + + return self.client.generate_presigned_url( + "get_object", Params=params, ExpiresIn=3600 + ) + except Exception as e: + logger.error(f"Failed to sign URL {url}: {str(e)}") + return None + + return url + + +__all__ = ["CloudConfig", "AWSProvider"] diff --git a/src/rasteret/constants.py b/src/rasteret/constants.py new file mode 100644 index 0000000..da14659 --- /dev/null +++ b/src/rasteret/constants.py @@ -0,0 +1,89 @@ +"""Constants and configurations for rasteret.""" + +from typing import Dict +import pyarrow as pa + + +class DataSources: + """Registry of supported data sources based on STAC endpoint collection names.""" + + LANDSAT = "landsat-c2l2-sr" + SENTINEL2 = "sentinel-2-l2a" + + @classmethod + def list_sources(cls) -> Dict[str, str]: + """List available data sources with descriptions.""" + return { + cls.LANDSAT: "Landsat Collection 2 Level 2 Surface Reflectance", + cls.SENTINEL2: "Sentinel-2 Level 2A", + } + + +SENTINEL2_BANDS: Dict[str, str] = { + "B01": "coastal", + "B02": "blue", + "B03": "green", + "B04": "red", + "B05": "rededge1", + "B06": "rededge2", + "B07": "rededge3", + "B08": "nir", + "B8A": "nir08", + "B09": "nir09", + "B11": "swir16", + "B12": "swir22", + "SCL": "scl", +} + +LANDSAT9_BANDS: Dict[str, str] = { + "B1": "coastal", + "B2": "blue", + "B3": "green", + "B4": "red", + "B5": "nir08", + "B6": "swir16", + "B7": "swir22", + "qa_aerosol": "qa_aerosol", + "qa_pixel": 
"qa_pixel", + "qa_radsat": "qa_radsat", +} + + +STAC_ENDPOINTS = { + "sentinel-2-l2a": "https://earth-search.aws.element84.com/v1", + "landsat-c2l2-sr": "https://landsatlook.usgs.gov/stac-server", +} + +STAC_COLLECTION_BAND_MAPS = { + "sentinel-2-l2a": SENTINEL2_BANDS, + "landsat-c2l2-sr": LANDSAT9_BANDS, +} + +# Metadata struct for COG headers +COG_BAND_METADATA_STRUCT = pa.struct( + [ + ("image_width", pa.int32()), + ("image_height", pa.int32()), + ("tile_width", pa.int32()), + ("tile_height", pa.int32()), + ("dtype", pa.string()), + ("transform", pa.list_(pa.float64())), + ("predictor", pa.int32()), + ("compression", pa.int32()), + ("tile_offsets", pa.list_(pa.int64())), + ("tile_byte_counts", pa.list_(pa.int64())), + ("pixel_scale", pa.list_(pa.float64())), + ("tiepoint", pa.list_(pa.float64())), + ] +) + +# Default partition settings +DEFAULT_GEO_PARQUET_SETTINGS = { + "compression": "zstd", + "compression_level": 3, + "row_group_size": 20 * 1024 * 1024, + "write_statistics": True, + "use_dictionary": True, + "write_batch_size": 10000, + "basename_template": "part-{i}.parquet", +} diff --git a/src/rasteret/core/__init__.py b/src/rasteret/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/rasteret/core/collection.py b/src/rasteret/core/collection.py new file mode 100644 index 0000000..c193241 --- /dev/null +++ b/src/rasteret/core/collection.py @@ -0,0 +1,381 @@ +""" Collection class for managing raster data collections. """ + +from __future__ import annotations +from datetime import datetime + +import pyarrow as pa +import pyarrow.parquet as pq +import pyarrow.dataset as ds +import pandas as pd +import geopandas as gpd + +from shapely.geometry import Polygon +from pathlib import Path +from typing import AsyncIterator, Optional, Union, Dict, List, Any, Tuple + +from rasteret.types import SceneInfo +from rasteret.core.scene import Scene +from rasteret.logging import setup_logger + +logger = setup_logger("INFO", customname="rasteret.core.collection") + + +class Collection: + """ + A collection of raster data with flexible initialization. + + Collections can be created from: + - Local partitioned datasets + - Single Arrow tables + - Empty (for building gradually) + + Collections maintain efficient partitioned storage when using files. + + Examples + -------- + # From partitioned dataset + >>> collection = Collection.from_local("path/to/dataset") + + # Filter and process + >>> filtered = collection.filter_scenes(cloud_cover_lt=20) + >>> ds = filtered.get_xarray(...) + """ + + def __init__( + self, dataset: ds.Dataset, name: str, description: Optional[str] = None + ): + """Initialize collection with dataset and name.""" + self.dataset = dataset + self.name = name + self.description = description + self._storage_path = None + self._validate_parquet_dataset() + + @classmethod + def from_local(cls, path: Union[str, Path]) -> Collection: + """ + Create collection from local partitioned dataset. 
+ + Parameters + ---------- + path : str or Path + Path to dataset directory with Hive-style partitioning + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Dataset not found at {path}") + + try: + dataset = ds.dataset( + str(path), + format="parquet", + partitioning=ds.HivePartitioning( + pa.schema([("year", pa.int32()), ("month", pa.int32())]) + ), + exclude_invalid_files=True, + filesystem=pa.fs.LocalFileSystem(), + ) + except Exception as e: + raise ValueError(f"Invalid dataset at {path}: {str(e)}") + + return cls(dataset=dataset, name=path.name) + + def filter_scenes(self, **kwargs) -> Collection: + """ + Filter collection creating new view. + + Parameters + ---------- + **kwargs : + Supported filters: + - cloud_cover_lt: float + - date_range: Tuple[str, str] + - bbox: Tuple[float, float, float, float] + """ + filter_expr = None + + # Build filter expression + if len(self.dataset.to_table()) == 0: + return self + + if "cloud_cover_lt" in kwargs: + if "eo:cloud_cover" not in self.dataset.schema.names: + raise ValueError("Collection has no cloud cover data") + + if not isinstance(kwargs["cloud_cover_lt"], (int, float)): + raise ValueError("Invalid cloud cover value") + elif kwargs["cloud_cover_lt"] < 0 or kwargs["cloud_cover_lt"] > 100: + raise ValueError("Invalid cloud cover value") + + filter_expr = ds.field("eo:cloud_cover") < kwargs["cloud_cover_lt"] + + if "date_range" in kwargs: + if "datetime" not in self.dataset.schema.names: + raise ValueError("Collection has no datetime data") + + start, end = kwargs["date_range"] + + if not (start and end): + raise ValueError("Invalid date range") + elif start > end: + raise ValueError("Invalid date range") + elif start == end: + raise ValueError("Date range must be > 1 day") + elif len(start) != 10 or len(end) != 10: + raise ValueError("Date format must be 'YYYY-MM-DD'") + + start_ts = pd.Timestamp(start).tz_localize("UTC") + end_ts = pd.Timestamp(end).tz_localize("UTC") + + # Convert to Arrow timestamps + start_timestamp = pa.scalar(start_ts, type=pa.timestamp("us", tz="UTC")) + end_timestamp = pa.scalar(end_ts, type=pa.timestamp("us", tz="UTC")) + + date_filter = (ds.field("datetime") >= start_timestamp) & ( + ds.field("datetime") <= end_timestamp + ) + filter_expr = ( + date_filter if filter_expr is None else filter_expr & date_filter + ) + + if "bbox" in kwargs: + if "scene_bbox" not in self.dataset.schema.names: + raise ValueError("Collection has no bbox data") + bbox = kwargs["bbox"] + + if len(bbox) != 4: + raise ValueError("Invalid bbox format") + elif bbox[0] > bbox[2] or bbox[1] > bbox[3]: + raise ValueError("Invalid bbox coordinates") + elif any(not isinstance(coord, (int, float)) for coord in bbox): + raise ValueError("Invalid bbox coordinates") + + bbox_filter = ( + (ds.field("scene_bbox").x0 >= bbox[0]) + & (ds.field("scene_bbox").y0 >= bbox[1]) + & (ds.field("scene_bbox").x1 <= bbox[2]) + & (ds.field("scene_bbox").y1 <= bbox[3]) + ) + filter_expr = ( + bbox_filter if filter_expr is None else filter_expr & bbox_filter + ) + + if "geometries" in kwargs: + if "scene_bbox" not in self.dataset.schema.names: + raise ValueError("Collection has no bbox data") + geometries = kwargs["geometries"] + + if not all(isinstance(geom, Polygon) for geom in geometries): + raise ValueError("Invalid geometry format") + + bbox_filters = [ + (ds.field("scene_bbox").x0 >= geom.bounds[0]) + & (ds.field("scene_bbox").y0 >= geom.bounds[1]) + & (ds.field("scene_bbox").x1 <= geom.bounds[2]) + & (ds.field("scene_bbox").y1 <= 
geom.bounds[3]) + for geom in geometries + ] + filter_expr = bbox_filters[0] + for bbox_filter in bbox_filters[1:]: + filter_expr |= bbox_filter + + if filter_expr is None: + raise ValueError("No valid filters provided") + + filtered_dataset = self.dataset.filter(filter_expr) + return Collection(dataset=filtered_dataset, name=self.name) + + def to_geodataframe(self) -> gpd.GeoDataFrame: + """ + Convert collection to GeoDataFrame for analysis. + + Returns: + GeoDataFrame with scene metadata and geometries + """ + if len(self.dataset.to_table()) == 0: + return gpd.GeoDataFrame() + + df = self.dataset.to_table().to_pandas() + return gpd.GeoDataFrame(df, geometry="geometry", crs="EPSG:4326") + + @classmethod + def list_collections(cls, output_dir: Union[str, Path]) -> List[Dict[str, Any]]: + """List valid parquet collections with year/month partitioning.""" + path = Path(output_dir) + collections = [] + + # Skip these folders + IGNORE_FOLDERS = {".git", "__pycache__", "temp"} + + for collection_dir in path.iterdir(): + if not collection_dir.is_dir() or collection_dir.name in IGNORE_FOLDERS: + continue + + # Only process _stac folders + if not collection_dir.name.endswith("_stac"): + continue + + try: + # Now pass only numeric subfolders to ds.dataset + dataset = ds.dataset( + str(collection_dir), + format="parquet", + partitioning=ds.HivePartitioning( + pa.schema([("year", pa.int32()), ("month", pa.int32())]) + ), + exclude_invalid_files=True, + filesystem=pa.fs.LocalFileSystem(), + ) + + # Validate dataset has data + if dataset.files: + table = dataset.to_table() + collections.append( + { + "name": collection_dir.name, + "size": len(table), + "created": collection_dir.stat().st_ctime, + } + ) + + except Exception as e: + logger.debug(f"Skipping {collection_dir}: {str(e)}") + continue + + return collections + + def save_to_parquet( + self, path: Union[str, Path], partition_by: List[str] = ["year", "month"] + ) -> None: + """Save collection to local storage as partitioned dataset.""" + path = Path(path) + + path.mkdir(parents=True, exist_ok=True) + + if self.dataset is None: + raise ValueError("No Pyarrow dataset provided") + elif len(self.dataset.to_table()) == 0: + raise ValueError("No data to save") + elif not partition_by: + raise ValueError("Partition columns required") + elif any(col not in self.dataset.schema.names for col in partition_by): + raise ValueError("Partition columns not found in schema") + elif not path.is_dir(): + raise ValueError("Invalid directory path") + + # Get table and add metadata + table = self.dataset.to_table() + custom_metadata = { + b"description": ( + self.description.encode("utf-8") if self.description else b"" + ), + b"created": str(datetime.now()).encode("utf-8"), + } + + # Merge with existing metadata + merged_metadata = {**custom_metadata, **(table.schema.metadata or {})} + table_with_metadata = table.replace_schema_metadata(merged_metadata) + + # Write dataset + pq.write_to_dataset( + table_with_metadata, + root_path=str(path), + partition_cols=partition_by, + compression="zstd", + compression_level=3, + row_group_size=20 * 1024 * 1024, + write_statistics=True, + use_dictionary=True, + write_batch_size=10000, + basename_template="part-{i}.parquet", + ) + + async def iterate_scenes(self, data_source: str) -> AsyncIterator[Scene]: + """ + Iterate through scenes. 
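+
+        A rough usage sketch (assumes a populated collection; the data source
+        string matches rasteret.constants.DataSources):
+
+        >>> async def first_scene_id(collection):
+        ...     async for scene in collection.iterate_scenes("landsat-c2l2-sr"):
+        ...         return scene.id
+        >>> # asyncio.run(first_scene_id(collection))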
+ + Args: + data_source: Data source for the scenes + + Yields + ------ + Scene + Scene objects for processing + """ + required_fields = {"id", "datetime", "geometry", "assets"} + + if len(self.dataset.to_table()) == 0: + return + + # Check required fields + missing = required_fields - set(self.dataset.schema.names) + if missing: + raise ValueError(f"Missing required fields: {missing}") + + for batch in self.dataset.to_batches(): + for row in batch.to_pylist(): + try: + scene_info = SceneInfo( + id=row["id"], + datetime=row["datetime"], + scene_geometry=row["geometry"], + bbox=row["scene_bbox"], + crs=row.get("proj:epsg", None), + cloud_cover=row.get("eo:cloud_cover", 0), + assets=row["assets"], + metadata=self._extract_band_metadata(row), + collection=row.get( + "collection", data_source + ), # Use data_source as default collection + ) + yield Scene( + scene_info, data_source + ) # Pass data_source to Scene constructor + except Exception as e: + # Log error but continue with other scenes + logger.error(f"Error creating scene from row: {str(e)}") + continue + + async def get_first_scene(self) -> Scene: + """ + Get first scene in collection. + + Returns + ------- + Scene + Scene object for processing + """ + async for scene in self.iterate_scenes(data_source=self.name): + return scene + raise ValueError("No scenes found in collection") + + def _validate_parquet_dataset(self) -> None: + """Basic dataset validation.""" + if not isinstance(self.dataset, ds.Dataset): + raise TypeError("Expected pyarrow.dataset.Dataset") + + def _extract_band_metadata(self, row: Dict) -> Dict: + """Extract band metadata from row.""" + return {k: v for k, v in row.items() if k.endswith("_metadata")} + + @classmethod + def create_name( + cls, custom_name: str, date_range: Tuple[str, str], data_source: str + ) -> str: + """Create standardized collection name internally.""" + start_date = pd.to_datetime(date_range[0]) + name_parts = [ + custom_name.lower().replace(" ", "_"), + start_date.strftime("%Y%m"), + data_source.split("-")[0], # First part of data source + ] + return "_".join(name_parts) + + @classmethod + def parse_name(cls, name: str) -> Dict[str, str]: + """Parse collection name components internally.""" + try: + custom_name, date, source = name.split("_") + return {"custom_name": custom_name, "date": date, "data_source": source} + except ValueError: + return {"name": name} # Fallback if name doesn't match pattern diff --git a/src/rasteret/core/processor.py b/src/rasteret/core/processor.py new file mode 100644 index 0000000..0fb019b --- /dev/null +++ b/src/rasteret/core/processor.py @@ -0,0 +1,324 @@ +""" +Rasteret: Efficient satellite imagery retrieval and processing +========================================================== + +Core Components: +-------------- +- Rasteret: Main interface for querying and processing scenes +- Collection: Manages indexed satellite data +- Scene: Handles individual scene processing + +Example: +------- +>>> from rasteret import Rasteret +>>> processor = Rasteret( +... data_source="landsat-c2l2-sr", +... output_dir="workspace" +... ) +>>> processor.create_index( +... bbox=[77.55, 13.01, 77.58, 13.04], +... date_range=["2024-01-01", "2024-01-31"] +... 
) +""" + +import asyncio +from pathlib import Path +from typing import Dict, List, Optional, Union, Tuple +import xarray as xr +import geopandas as gpd +import pandas as pd +from shapely.geometry import Polygon +from tqdm.asyncio import tqdm + +from rasteret.constants import STAC_ENDPOINTS +from rasteret.core.collection import Collection +from rasteret.stac.indexer import StacToGeoParquetIndexer +from rasteret.cloud import CLOUD_CONFIG, AWSProvider, CloudProvider +from rasteret.logging import setup_logger + +logger = setup_logger("INFO", customname="rasteret.processor") + + +class Rasteret: + """Main interface for satellite data retrieval and processing. + + Attributes: + data_source (str): Source dataset identifier + output_dir (Path): Directory for storing collections + custom_name (str, optional): Collection name prefix + date_range (Tuple[str, str], optional): Date range for collection + aws_profile (str, optional): AWS profile for authentication + + Examples: + >>> processor = Rasteret("landsat-c2l2-sr", "workspace") + >>> processor.create_index( + ... bbox=[77.55, 13.01, 77.58, 13.04], + ... date_range=["2024-01-01", "2024-01-31"] + ... ) + >>> df = processor.query(geometries=[polygon], bands=["B4", "B5"]) + """ + + def __init__( + self, + data_source: str, + output_dir: Union[str, Path], + custom_name: Optional[str] = None, + date_range: Optional[Tuple[str, str]] = None, + aws_profile: Optional[str] = None, + ): + self.data_source = data_source + self.output_dir = Path(output_dir) + + # Check credentials early if needed + self.cloud_config = CLOUD_CONFIG.get(data_source) + if self.cloud_config and self.cloud_config.requester_pays: + if not CloudProvider.check_aws_credentials(): + raise ValueError( + f"Data source '{data_source}' requires valid AWS credentials" + ) + + # Generate name if not provided + if custom_name and date_range: + custom_name = Collection.create_name( + custom_name=custom_name, date_range=date_range, data_source=data_source + ) + + self.custom_name = custom_name + + # Check if collection exists + if custom_name: + collection_path = self.output_dir / f"{custom_name}_stac" + if collection_path.exists(): + logger.info(f"Loading existing collection: {custom_name}") + self._collection = Collection.from_local(collection_path) + else: + logger.warning( + f"Collection '{custom_name}' not found. " + "Use create_index() to initialize collection." 
+ ) + self._collection = None + else: + self._collection = None + + # Initialize cloud provider if needed + self.cloud_config = CLOUD_CONFIG.get(data_source) + self.provider = None + if self.cloud_config and self.cloud_config.requester_pays: + self.provider = AWSProvider( + profile=aws_profile, region=self.cloud_config.region + ) + logger.info(f"Using {self.provider} as cloud provider") + + def _get_collection_path(self) -> Path: + """Get expected collection path""" + return self.output_dir / f"{self.custom_name}.parquet" + + def _get_bbox_from_geometries(self, geometries: List[Polygon]) -> List[float]: + """Get combined bbox from geometries""" + bounds = [geom.bounds for geom in geometries] + return [ + min(b[0] for b in bounds), # minx + min(b[1] for b in bounds), # miny + max(b[2] for b in bounds), # maxx + max(b[3] for b in bounds), # maxy + ] + + def _ensure_collection(self, geometries: List[Polygon], **filters) -> None: + """Ensure collection exists and is loaded with proper partitioning.""" + stac_path = self.output_dir / f"{self.custom_name}_stac" + + if self._collection is None: + if stac_path.exists(): + try: + # Use partitioned dataset loading + self._collection = Collection.from_local(stac_path) + logger.info(f"Loaded collection from {stac_path}") + return + except Exception as e: + logger.error(f"Failed to load collection: {e}") + + # No valid collection found + bbox = self._get_bbox_from_geometries(geometries) + error_msg = ( + f"\nNo valid collection found at: {stac_path}\n" + f"\nTo create collection run:\n" + f"processor.create_index(\n" + f" bbox={bbox},\n" + f" date_range=['YYYY-MM-DD', 'YYYY-MM-DD'],\n" + f" query={filters}\n" + f")" + ) + raise ValueError(error_msg) + + def create_index( + self, bbox: List[float], date_range: List[str], force: bool = False, **filters + ) -> None: + """ + Create or load STAC index. 
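+
+        A sketch mirroring examples/basic_workflow_xarray.py; keyword filters
+        beyond bbox and date_range are forwarded to the STAC search, and the
+        platform filter is specific to the Landsat endpoint (values below are
+        placeholders):
+
+        >>> processor.create_index(
+        ...     bbox=[77.55, 13.01, 77.59, 13.09],
+        ...     date_range=["2024-01-01", "2024-01-31"],
+        ...     cloud_cover_lt=20,
+        ...     platform={"in": ["LANDSAT_8"]},
+        ... )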
+ + Args: + bbox: Bounding box + date_range: Date range + query: Optional STAC query + force: If True, recreate index even if exists + """ + stac_path = self.output_dir / f"{self.custom_name}_stac" + output_path = self.output_dir / f"{self.custom_name}_outputs" + + # Check if collection exists + if stac_path.exists() and not force: + logger.info(f"Collection {self.custom_name} exists, loading from disk") + self._collection = Collection.from_local(stac_path) + return + + # Create new collection + indexer = StacToGeoParquetIndexer( + data_source=self.data_source, + stac_api=STAC_ENDPOINTS[self.data_source], + output_dir=stac_path, + cloud_provider=self.provider, + cloud_config=self.cloud_config, + name=self.custom_name, + ) + + self._collection = asyncio.run( + indexer.build_index(bbox=bbox, date_range=date_range, query=filters) + ) + + output_path.mkdir(parents=True, exist_ok=True) + + @classmethod + def list_collections(self, dir) -> List[Dict]: + """List available collections in directory.""" + + if not Path(dir).exists(): + raise FileNotFoundError(f"Directory {dir} not found") + if dir is None: + logger.warning("No output directory provided, check default location") + dir = Path.home() / "rasteret_workspace" + + return Collection.list_collections(output_dir=dir) + + def get_gdf( + self, + geometries: List[Polygon], + bands: List[str], + max_concurrent: int = 50, + **filters, + ) -> gpd.GeoDataFrame: + """Query indexed scenes matching filters for specified geometries and bands.""" + # Validate inputs + if not geometries: + raise ValueError("No geometries provided") + if not bands: + raise ValueError("No bands specified") + + # Ensure collection exists and is loaded + self._ensure_collection(geometries, **filters) + + if filters: + self._collection = self._collection.filter_scenes(**filters) + + return asyncio.run(self._get_gdf(geometries, bands, max_concurrent)) + + async def _get_gdf(self, geometries, bands, max_concurrent): + total_scenes = len(self._collection.dataset.to_table()) + logger.info(f"Processing {total_scenes} scenes for {len(bands)} bands") + results = [] + + async for scene in tqdm( + self._collection.iterate_scenes(self.data_source), + total=total_scenes, + desc="Loading scenes", + ): + scene_results = await scene.load_bands( + geometries, + bands, + max_concurrent, + cloud_provider=self.provider, + cloud_config=self.cloud_config, + for_xarray=False, + ) + results.append(scene_results) + + return gpd.GeoDataFrame(pd.concat(results, ignore_index=True)) + + def get_xarray( + self, + geometries: List[Polygon], + bands: List[str], + max_concurrent: int = 50, + **filters, + ) -> xr.Dataset: + """ + Query collection and return as xarray Dataset. + + Args: + geometries: List of polygons to query + bands: List of band identifiers + max_concurrent: Maximum concurrent requests + **filters: Additional filters (e.g. 
cloud_cover_lt=20) + + Returns: + xarray Dataset with data, coordinates, and CRS + """ + # Same validation as query() + if not geometries: + raise ValueError("No geometries provided") + if not bands: + raise ValueError("No bands specified") + + self._ensure_collection(geometries, **filters) + + if filters: + self._collection = self._collection.filter_scenes(**filters) + + return asyncio.run( + self._get_xarray( + geometries=geometries, bands=bands, max_concurrent=max_concurrent + ) + ) + + async def _get_xarray( + self, + geometries: List[Polygon], + bands: List[str], + max_concurrent: int, + for_xarray: bool = True, + ) -> xr.Dataset: + datasets = [] + total_scenes = len(self._collection.dataset.to_table()) + + logger.info(f"Processing {total_scenes} scenes for {len(bands)} bands") + + async for scene in tqdm( + self._collection.iterate_scenes(self.data_source), + total=total_scenes, + desc="Loading scenes", + ): + scene_ds = await scene.load_bands( + geometries=geometries, + band_codes=bands, + max_concurrent=max_concurrent, + cloud_provider=self.provider, + cloud_config=self.cloud_config, + for_xarray=for_xarray, + ) + if scene_ds is not None: + datasets.append(scene_ds) + logger.debug( + f"Loaded scene {scene.id} ({len(datasets)}/{total_scenes})" + ) + + if not datasets: + raise ValueError("No valid data found for query") + + logger.info(f"Merging {len(datasets)} datasets") + merged = xr.merge(datasets) + merged = merged.sortby("time") + return merged + + def __repr__(self): + return ( + f"Rasteret(data_source={self.data_source}, custom_name={self.custom_name})" + ) diff --git a/src/rasteret/core/scene.py b/src/rasteret/core/scene.py new file mode 100644 index 0000000..1f22418 --- /dev/null +++ b/src/rasteret/core/scene.py @@ -0,0 +1,305 @@ +""" Scene class for handling COG data loading and processing. """ + +from __future__ import annotations +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import xarray as xr +from shapely.geometry import Polygon +import geopandas as gpd +import rioxarray # noqa +import asyncio +from tqdm.asyncio import tqdm + +from rasteret.types import SceneInfo, CogMetadata +from rasteret.constants import STAC_COLLECTION_BAND_MAPS +from rasteret.fetch.cog import read_cog_tile_data +from rasteret.cloud import CloudProvider, CloudConfig +from rasteret.logging import setup_logger + +logger = setup_logger("INFO") + + +class Scene: + """ + A single scene with associated metadata and data access methods. + + Scene handles the actual data loading from COGs, including: + - Async data loading + - Tile management + - Geometry masking + """ + + def __init__(self, info: SceneInfo, data_source: str) -> None: + """Initialize Scene from metadata. 
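+
+        Scenes are normally produced by Collection.iterate_scenes() rather than
+        constructed by hand; a rough sketch of direct use (all inputs are
+        placeholders):
+
+        >>> # scene = Scene(scene_info, data_source="landsat-c2l2-sr")
+        >>> # gdf = await scene.load_bands([aoi_polygon], ["B4", "B5"], for_xarray=False)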
+ + Args: + info: Scene metadata including urls and COG info + """ + self.id = info.id + self.datetime = info.datetime + self.bbox = info.bbox + self.scene_geometry = info.scene_geometry + self.crs = info.crs + self.cloud_cover = info.cloud_cover + self.assets = info.assets + self.scene_metadata = info.metadata + self.collection = info.collection + self.data_source = data_source + + def _get_band_radiometric_params( + self, band_code: str + ) -> Optional[Dict[str, float]]: + """Get radiometric parameters from STAC metadata if available.""" + try: + asset = self.assets[band_code] + band_info = asset["raster:bands"][0] + + if "scale" in band_info and "offset" in band_info: + return { + "scale": float(band_info["scale"]), + "offset": float(band_info["offset"]), + } + except (KeyError, IndexError): + pass + + return None + + def _get_asset_url( + self, asset: Dict, provider: CloudProvider, cloud_config: CloudConfig + ) -> str: + """Get authenticated URL for asset""" + url = asset["href"] if isinstance(asset, dict) else asset + if provider and cloud_config: + return provider.get_url(url, cloud_config) + return url + + def get_band_cog_metadata( + self, + band_code: str, + provider: Optional[CloudProvider] = None, + cloud_config: Optional[CloudConfig] = None, + ) -> Tuple[CogMetadata, str]: + """Get COG metadata and url for a specified band.""" + + actual_band_code = STAC_COLLECTION_BAND_MAPS.get(self.data_source, {}).get( + band_code, band_code + ) + + if actual_band_code not in self.assets: + raise ValueError(f"Band {band_code} not found in assets") + + asset = self.assets[actual_band_code] + + # Prefer S3 URL for AWS assets + url = self._get_asset_url(asset, provider, cloud_config) + + # Band metadata key could be either band_code or actual_band_code + metadata_keys = [f"{band_code}_metadata", f"{actual_band_code}_metadata"] + raw_metadata = None + for key in metadata_keys: + if key in self.scene_metadata: + raw_metadata = self.scene_metadata[key] + break + + if raw_metadata is None or url is None: + logger.error( + f"Metadata not found for band {band_code} in scene {self.id}. 
Available keys: {list(self.scene_metadata.keys())}" + ) + return None, None + + try: + cog_metadata = CogMetadata( + width=raw_metadata.get("image_width", raw_metadata.get("width")), + height=raw_metadata.get("image_height", raw_metadata.get("height")), + tile_width=raw_metadata["tile_width"], + tile_height=raw_metadata["tile_height"], + dtype=np.dtype(raw_metadata["dtype"]), + transform=raw_metadata["transform"], + crs=self.crs, + tile_offsets=raw_metadata["tile_offsets"], + tile_byte_counts=raw_metadata["tile_byte_counts"], + predictor=raw_metadata.get("predictor"), + compression=raw_metadata.get("compression"), + pixel_scale=raw_metadata.get("pixel_scale"), + tiepoint=raw_metadata.get("tiepoint"), + ) + return cog_metadata, url + except KeyError as e: + logger.error( + f"Missing required metadata field {e} for band {band_code} in scene {self.id}" + ) + logger.debug(f"Available metadata: {raw_metadata}") + return None, None + + def intersects(self, geometry: Polygon) -> bool: + """Check if scene intersects with geometry.""" + return self.geometry.intersects(geometry) + + @property + def available_bands(self) -> List[str]: + """List of available bands for this scene.""" + return list(self.assets.keys()) + + def __repr__(self) -> str: + return ( + f"Scene(id='{self.id}', " + f"datetime='{self.datetime}', " + f"cloud_cover={self.cloud_cover})" + ) + + async def _load_single_band( + self, + geometry: Polygon, + band_code: str, + cloud_provider: Optional[CloudProvider], + cloud_config: Optional[CloudConfig], + max_concurrent: int = 50, + ) -> Optional[Dict]: + """Load single band data for geometry.""" + cog_meta, url = self.get_band_cog_metadata( + band_code, provider=cloud_provider, cloud_config=cloud_config + ) + if not cog_meta or not url: + return None + + data, transform = await read_cog_tile_data( + url, cog_meta, geometry, max_concurrent + ) + if data is None or transform is None: + return None + + return {"data": data, "transform": transform, "band": band_code} + + async def load_bands( + self, + geometries: List[Polygon], + band_codes: List[str], + max_concurrent: int = 50, + cloud_provider: Optional[CloudProvider] = None, + cloud_config: Optional[CloudConfig] = None, + for_xarray: bool = True, + ) -> Union[gpd.GeoDataFrame, xr.Dataset]: + """Load bands with parallel processing and progress tracking.""" + + logger.debug( + f"Loading {len(band_codes)} bands for {len(geometries)} geometries" + ) + + geom_progress = tqdm(total=len(geometries), desc=f"Scene {self.id}", position=0) + + async def process_geometry(geometry: Polygon, geom_id: int): + band_progress = tqdm( + total=len(band_codes), desc=f"Geom {geom_id}", position=1, leave=False + ) + + band_tasks = [] + for band_code in band_codes: + task = self._load_single_band( + geometry, band_code, cloud_provider, cloud_config, max_concurrent + ) + band_tasks.append(task) + + results = await asyncio.gather(*band_tasks) + band_progress.update(len(band_codes)) + band_progress.close() + geom_progress.update(1) + + return [r for r in results if r is not None], geom_id + + # Process geometries concurrently with semaphore + sem = asyncio.Semaphore(max_concurrent) + + async def bounded_process(geometry: Polygon, geom_id: int): + async with sem: + return await process_geometry(geometry, geom_id) + + tasks = [bounded_process(geom, idx + 1) for idx, geom in enumerate(geometries)] + results = await asyncio.gather(*tasks) + + geom_progress.close() + + # Process results + if for_xarray: + return self._merge_xarray_results(results) + else: + return 
self._merge_geodataframe_results(results, geometries) + + def _merge_xarray_results( + self, + results: List[Tuple[List[Dict], int]], + ) -> xr.Dataset: + """Merge results into xarray Dataset.""" + data_arrays = [] + + for band_results, geom_id in results: + if not band_results: + continue + + geom_arrays = [] + for band_result in band_results: + da = xr.DataArray( + data=band_result["data"], + dims=["y", "x"], + coords={ + "y": band_result["transform"].f + + np.arange(band_result["data"].shape[0]) + * band_result["transform"].e, + "x": band_result["transform"].c + + np.arange(band_result["data"].shape[1]) + * band_result["transform"].a, + }, + name=band_result["band"], + ) + da.rio.write_crs(self.crs, inplace=True) + da.rio.write_transform(band_result["transform"], inplace=True) + geom_arrays.append(da) + + if geom_arrays: + ds = xr.merge(geom_arrays) + ds = ds.expand_dims({"time": [self.datetime], "geometry": [geom_id]}) + ds.rio.write_crs(self.crs, inplace=True) + ds.attrs.update( + { + "crs": self.crs, + "geometry_id": geom_id, + "scene_id": self.id, + "datetime": self.datetime, + "cloud_cover": self.cloud_cover, + "collection": self.collection, + } + ) + data_arrays.append(ds) + + if not data_arrays: + return None + + return xr.merge(data_arrays) + + def _merge_geodataframe_results( + self, results: List[Tuple[List[Dict], int]], geometries: List[Polygon] + ) -> Optional[gpd.GeoDataFrame]: + """Merge results into GeoDataFrame.""" + rows = [] + + for band_results, geom_id in results: + if not band_results: + continue + + for band_result in band_results: + rows.append( + { + "scene_id": self.id, + "datetime": self.datetime, + "cloud_cover": self.cloud_cover, + "collection": self.collection, + "geometry": geometries[geom_id - 1], + "band": band_result["band"], + "data": band_result["data"], + } + ) + + if not rows: + return None + + return gpd.GeoDataFrame(rows) diff --git a/src/rasteret/core/utils.py b/src/rasteret/core/utils.py new file mode 100644 index 0000000..10a4713 --- /dev/null +++ b/src/rasteret/core/utils.py @@ -0,0 +1,237 @@ +""" Utility functions for rasteret package. """ + +from typing import Optional +from urllib.parse import urlparse +import boto3 +from pathlib import Path +from shapely.geometry import Polygon +from pyproj import Transformer +from shapely.ops import transform +import numpy as np +import pandas as pd +import xarray as xr +from typing import Tuple, Dict, Union, List +from rasterio.warp import transform_bounds + +from rasteret.logging import setup_logger + +logger = setup_logger() + + +def wgs84_to_utm_convert_poly(geom: Polygon, epsg_code: int) -> Polygon: + """ + Convert scene geometry to UTM. + + Parameters + ---------- + geom : shapely Polygon + Scene geometry in WGS84 + epsg_code : int + UTM zone to convert to (e.g. 32643) + + Returns + ------- + shapely Polygon + Scene geometry in UTM zone + """ + wgs84_to_utm = Transformer.from_crs( + "EPSG:4326", f"EPSG:{epsg_code}", always_xy=True # eg. 
32643 + ) + utm_poly = transform(wgs84_to_utm.transform, geom) + + return utm_poly + + +class S3URLSigner: + """Handle S3 URL signing with AWS credential chain.""" + + def __init__(self, aws_profile: Optional[str] = None, region: str = "us-west-2"): + self.region = region + self.aws_profile = aws_profile + self._session = None + self._client = None + + @property + def session(self): + """Lazily create boto3 session.""" + if self._session is None: + if self.aws_profile: + self._session = boto3.Session(profile_name=self.aws_profile) + else: + self._session = boto3.Session() + return self._session + + @property + def client(self): + """Lazily create S3 client.""" + if self._client is None: + self._client = self.session.client("s3", region_name=self.region) + return self._client + + def has_valid_credentials(self) -> bool: + """Check if we have valid AWS credentials.""" + try: + self.session.get_credentials() + return True + except Exception: + raise ValueError("Missing AWS credentials") + + def get_signed_url(self, s3_uri: str) -> Optional[str]: + """Get signed URL if credentials available.""" + try: + parsed = urlparse(s3_uri) + bucket = parsed.netloc + key = parsed.path.lstrip("/") + + url = self.client.generate_presigned_url( + "get_object", + Params={"Bucket": bucket, "Key": key, "RequestPayer": "requester"}, + ExpiresIn=3600, + ) + return url + + except Exception as e: + logger.debug(f"Failed to sign S3 URL {s3_uri}: {str(e)}") + return None + + +class CloudStorageURLHandler: + """Handle URL signing for different cloud storage providers.""" + + def __init__( + self, + storage_platform: str, + aws_profile: Optional[str] = None, + aws_region: str = "us-west-2", + ): + self.storage_platform = storage_platform.upper() + self._s3_signer = None + self.aws_profile = aws_profile + self.aws_region = aws_region + + def get_signed_url(self, url: str) -> Optional[str]: + """Get signed URL based on storage platform.""" + if self.storage_platform == "AWS": + if self._s3_signer is None: + self._s3_signer = S3URLSigner( + aws_profile=self.aws_profile, region=self.aws_region + ) + return self._s3_signer.get_signed_url(url) + elif self.storage_platform in ["AZURE", "GCS"]: + # These platforms are not supported yet + logger.warning(f"Unsupported storage platform: {self.storage_platform}") + return url + else: + logger.warning(f"Unknown storage platform: {self.storage_platform}") + return url + + +def transform_bbox( + bbox: Union[Tuple[float, float, float, float], Polygon], + src_crs: Union[int, str], + dst_crs: Union[int, str], +) -> Tuple[float, float, float, float]: + """ + Transform bounding box between coordinate systems. + + Args: + bbox: Input bbox (minx, miny, maxx, maxy) or Polygon + src_crs: Source CRS (EPSG code or WKT string) + dst_crs: Target CRS (EPSG code or WKT string) + + Returns: + Transformed bbox + """ + if isinstance(bbox, Polygon): + minx, miny, maxx, maxy = bbox.bounds + else: + minx, miny, maxx, maxy = bbox + + return transform_bounds( + src_crs=f"EPSG:{src_crs}" if isinstance(src_crs, int) else src_crs, + dst_crs=f"EPSG:{dst_crs}" if isinstance(dst_crs, int) else dst_crs, + left=minx, + bottom=miny, + right=maxx, + top=maxy, + ) + + +def calculate_scale_offset( + arr: np.ndarray, + target_dtype: np.dtype, + valid_min: Optional[float] = None, + valid_max: Optional[float] = None, +) -> Dict[str, float]: + """ + Calculate optimal scale/offset for data type conversion. 
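+
+    A small worked example (values chosen only for illustration):
+
+    >>> import numpy as np
+    >>> arr = np.array([0.0, 0.5, 1.0])
+    >>> params = calculate_scale_offset(arr, np.uint16)
+    >>> # params["scale"] ~= 1 / 65535, params["offset"] == 0.0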
+ + Args: + arr: Input array + target_dtype: Target numpy dtype + valid_min: Optional minimum valid value + valid_max: Optional maximum valid value + + Returns: + Dict with scale and offset values + """ + # Get data range + if valid_min is None: + valid_min = np.nanmin(arr) + if valid_max is None: + valid_max = np.nanmax(arr) + + # Get target dtype info + target_info = np.iinfo(target_dtype) + + # Calculate scale and offset + scale = (valid_max - valid_min) / (target_info.max - target_info.min) + offset = valid_min - (target_info.min * scale) + + return {"scale": float(scale), "offset": float(offset)} + + +def save_per_geometry( + ds: xr.Dataset, output_dir: Path, prefix: str = "geometry" +) -> Dict[int, List[Path]]: + """Save each geometry's timeseries as separate GeoTIFFs.""" + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + outputs = {} + + # Get CRS info + crs = ds.attrs.get("crs", None) + + # Process each geometry + for g in range(ds.geometry.size): + # Extract geometry data + geom_ds = ds.isel(geometry=g) + geom_id = geom_ds.geometry.values.item() + + # Create geometry subfolder + geom_dir = output_dir / f"geometry_{geom_id}" + geom_dir.mkdir(exist_ok=True) + + # Process each timestamp + for t in range(len(geom_ds.time)): + # Extract 2D array for this timestamp + time_data = geom_ds.NDVI.isel(time=t) + timestamp = pd.Timestamp(geom_ds.time[t].values) + + # Create 2D dataset + ds_2d = xr.Dataset( + data_vars={"NDVI": (("y", "x"), time_data.values)}, + coords={"y": geom_ds.y, "x": geom_ds.x}, + ) + + # Set spatial metadata + ds_2d.rio.write_crs(crs, inplace=True) + if ds.rio.transform(): + ds_2d.rio.write_transform(ds.rio.transform(), inplace=True) + + # Save as GeoTIFF + outfile = geom_dir / f"{prefix}_{timestamp.strftime('%Y%m%d')}.tif" + ds_2d.rio.to_raster(outfile) + outputs.setdefault(geom_id, []).append(outfile) + + return outputs diff --git a/src/rasteret/fetch/__init__.py b/src/rasteret/fetch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/rasteret/fetch/cog.py b/src/rasteret/fetch/cog.py new file mode 100644 index 0000000..5d31bb1 --- /dev/null +++ b/src/rasteret/fetch/cog.py @@ -0,0 +1,397 @@ +"""Optimized COG reading using byte ranges.""" + +from __future__ import annotations +import asyncio +import httpx +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import imagecodecs +import numpy as np +from affine import Affine +from shapely.geometry import Polygon, box +from rasterio.mask import geometry_mask + +from rasteret.types import CogMetadata +from rasteret.core.utils import wgs84_to_utm_convert_poly +from rasteret.logging import setup_logger + +logger = setup_logger("INFO", customname="rasteret.fetch.cog") + + +@dataclass +class COGTileRequest: + """Single tile request details.""" + + url: str + offset: int # Byte offset in COG file + size: int # Size in bytes to read + row: int # Tile row in the grid + col: int # Tile column in the grid + metadata: CogMetadata # Full metadata including transform + + +def compute_tile_indices( + geometry: Polygon, + transform: List[float], + tile_size: Tuple[int, int], + image_size: Tuple[int, int], + debug: bool = False, +) -> List[Tuple[int, int]]: + """ + Compute tile indices that intersect with geometry. + Using simplified direct mapping approach from tiles.py. 
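+
+    A hedged sketch for a 30 m UTM grid (all numbers are placeholders); note
+    the transform layout is [scale_x, translate_x, scale_y, translate_y]:
+
+    >>> from shapely.geometry import box
+    >>> compute_tile_indices(
+    ...     geometry=box(500000, 1437000, 500600, 1437600),  # AOI in UTM metres
+    ...     transform=[30.0, 499980.0, -30.0, 1438020.0],
+    ...     tile_size=(256, 256),
+    ...     image_size=(7681, 7811),
+    ... )
+    [(0, 0)]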
+ """ + # Extract parameters + scale_x, translate_x, scale_y, translate_y = transform + tile_width, tile_height = tile_size + image_width, image_height = image_size + + # Calculate number of tiles + tiles_x = (image_width + tile_width - 1) // tile_width + tiles_y = (image_height + tile_height - 1) // tile_height + + # Get geometry bounds + minx, miny, maxx, maxy = geometry.bounds + + if debug: + logger.info( + f""" + Computing tile indices: + - Bounds: {minx}, {miny}, {maxx}, {maxy} + - Transform: {scale_x}, {translate_x}, {scale_y}, {translate_y} + - Image size: {image_width}x{image_height} + - Tile size: {tile_width}x{tile_height} + """ + ) + + # Convert to pixel coordinates, handling negative scales + col_min = max(0, int((minx - translate_x) / abs(scale_x))) + col_max = min(image_width - 1, int((maxx - translate_x) / abs(scale_x))) + + # Handle Y coordinate inversion in raster space + row_min = max(0, int((translate_y - maxy) / abs(scale_y))) + row_max = min(image_height - 1, int((translate_y - miny) / abs(scale_y))) + + if debug: + logger.info(f"Pixel bounds: x({col_min}-{col_max}), y({row_min}-{row_max})") + + # Convert to tile indices + tile_col_min = max(0, col_min // tile_width) + tile_col_max = min(tiles_x - 1, col_max // tile_width) + tile_row_min = max(0, row_min // tile_height) + tile_row_max = min(tiles_y - 1, row_max // tile_height) + + if debug: + logger.info( + f"Tile indices: x({tile_col_min}-{tile_col_max}), y({tile_row_min}-{tile_row_max})" + ) + + # Validate tile ranges + if tile_col_min > tile_col_max or tile_row_min > tile_row_max: + if debug: + logger.info("No valid tiles in range") + return [] + + # Find intersecting tiles + intersecting_tiles = [] + for row in range(tile_row_min, tile_row_max + 1): + for col in range(tile_col_min, tile_col_max + 1): + # Calculate tile bounds in UTM coordinates + tile_minx = translate_x + col * tile_width * scale_x + tile_maxx = tile_minx + tile_width * scale_x + tile_maxy = translate_y - row * tile_height * abs(scale_y) + tile_miny = tile_maxy - tile_height * abs(scale_y) + + # Create tile box and check intersection + tile_box = box( + min(tile_minx, tile_maxx), + min(tile_miny, tile_maxy), + max(tile_minx, tile_maxx), + max(tile_miny, tile_maxy), + ) + + if geometry.intersects(tile_box): + intersecting_tiles.append((row, col)) + if debug: + logger.info(f"Added intersecting tile: ({row}, {col})") + + if debug: + logger.info(f"Found {len(intersecting_tiles)} intersecting tiles") + + return intersecting_tiles + + +def merge_tiles( + tiles: Dict[Tuple[int, int], np.ndarray], + tile_size: Tuple[int, int], + dtype: np.dtype = np.float32, +) -> Tuple[np.ndarray, Tuple[int, int, int, int]]: + """ + Merge multiple tiles into a single array. + Returns merged array and bounds (min_row, min_col, max_row, max_col). 
+ """ + if not tiles: + return np.array([], dtype=dtype), (0, 0, 0, 0) + + # Find bounds + rows, cols = zip(*tiles.keys()) + min_row, max_row = min(rows), max(rows) + min_col, max_col = min(cols), max(cols) + + tile_width, tile_height = tile_size + + # Create output array + height = (max_row - min_row + 1) * tile_height + width = (max_col - min_col + 1) * tile_width + merged = np.full((height, width), np.nan, dtype=dtype) + + # Place tiles with exact positioning + for (row, col), data in tiles.items(): + if data is not None: # Handle potentially failed tiles + y_start = (row - min_row) * tile_height + x_start = (col - min_col) * tile_width + y_end = min(y_start + data.shape[0], height) + x_end = min(x_start + data.shape[1], width) + merged[y_start:y_end, x_start:x_end] = data[ + : y_end - y_start, : x_end - x_start + ] + + return merged, (min_row, min_col, max_row, max_col) + + +def apply_mask_and_crop( + data: np.ndarray, + geometry: Polygon, + transform: Affine, +) -> Tuple[np.ndarray, Affine]: + """Apply geometry mask and crop to valid data region.""" + + mask = geometry_mask( + [geometry], + out_shape=data.shape, + transform=transform, + all_touched=True, + invert=True, + ) + + # Find valid data bounds + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + + row_min, row_max = np.where(rows)[0][[0, -1]] + col_min, col_max = np.where(cols)[0][[0, -1]] + + # Crop data and mask + data_cropped = data[row_min : row_max + 1, col_min : col_max + 1] + mask_cropped = mask[row_min : row_max + 1, col_min : col_max + 1] + + # Apply mask to cropped data + masked_data = np.where(mask_cropped, data_cropped, np.nan) + + # Update transform for cropped array + cropped_transform = Affine( + transform.a, + transform.b, + transform.c + col_min * transform.a, + transform.d, + transform.e, + transform.f + row_min * transform.e, + ) + + return masked_data, cropped_transform + + +async def read_tile( + request: COGTileRequest, + client: httpx.AsyncClient, + sem: asyncio.Semaphore, + retries: int = 3, + retry_delay: float = 1.0, +) -> Optional[np.ndarray]: + """Read a single tile using byte range request.""" + for attempt in range(retries): + try: + async with sem: + headers = { + "Range": f"bytes={request.offset}-{request.offset+request.size-1}" + } + response = await client.get(request.url, headers=headers) + if response.status_code != 206: + raise ValueError(f"Range request failed: {response.status_code}") + + # Simple, direct data flow like tiles.py + decompressed = imagecodecs.zlib_decode(response.content) + data = np.frombuffer(decompressed, dtype=np.uint16) + data = data.reshape( + (request.metadata.tile_height, request.metadata.tile_width) + ) + + # Predictor handling exactly like tiles.py + if request.metadata.predictor == 2: + data = data.astype(np.uint16) + for i in range(data.shape[0]): + data[i] = np.cumsum(data[i], dtype=np.uint16) + + return data.astype(np.float32) + + except Exception as e: + if attempt == retries - 1: + logger.error(f"Failed to read tile: {str(e)}") + return None + await asyncio.sleep(retry_delay * (2**attempt)) + + +async def read_cog_tile_data( + url: str, + metadata: CogMetadata, + geometry: Optional[Polygon] = None, + max_concurrent: int = 50, + debug: bool = False, +) -> Tuple[np.ndarray, Optional[Affine]]: + """Read COG data, optionally masked by geometry. 
+ + Args: + url: URL of the COG file + metadata: COG metadata including transform + geometry: Optional polygon to mask/filter data + max_concurrent: Maximum concurrent requests + debug: Enable debug logging + + Returns: + Tuple of: + - np.ndarray: The masked data array + - Affine: Transform matrix for the masked data + None if no transform available + """ + if debug: + logger.info( + f""" + Input Parameters: + - CRS: {metadata.crs} + - Transform: {metadata.transform} + - Image Size: {metadata.width}x{metadata.height} + - Tile Size: {metadata.tile_width}x{metadata.tile_height} + - Geometry: {geometry.wkt if geometry else None} + """ + ) + + if metadata.transform is None: + return np.array([]), None + + # Convert geometry to image CRS if needed + if geometry: + if metadata.crs != 4326: + geometry = wgs84_to_utm_convert_poly(geom=geometry, epsg_code=metadata.crs) + if debug: + logger.info(f"Transformed geometry bounds: {geometry.bounds}") + + # Get tiles that intersect with geometry + intersecting_tiles = compute_tile_indices( + geometry=geometry, + transform=metadata.transform, + tile_size=(metadata.tile_width, metadata.tile_height), + image_size=(metadata.width, metadata.height), + debug=debug, + ) + else: + # Read all tiles if no geometry provided + tiles_x = (metadata.width + metadata.tile_width - 1) // metadata.tile_width + tiles_y = (metadata.height + metadata.tile_height - 1) // metadata.tile_height + intersecting_tiles = [(r, c) for r in range(tiles_y) for c in range(tiles_x)] + + if not intersecting_tiles: + return np.array([]), None + + # Set up HTTP client with connection pooling + limits = httpx.Limits( + max_keepalive_connections=max_concurrent, max_connections=max_concurrent + ) + timeout = httpx.Timeout(30.0) + + async with httpx.AsyncClient(timeout=timeout, limits=limits, http2=True) as client: + sem = asyncio.Semaphore(max_concurrent) + + # Read tiles + tiles = {} + tasks = [] + tiles_x = (metadata.width + metadata.tile_width - 1) // metadata.tile_width + + # Create tasks for all tiles + for row, col in intersecting_tiles: + tile_idx = row * tiles_x + col # Linear tile index + + if tile_idx >= len(metadata.tile_offsets): + if debug: + logger.warning(f"Tile index {tile_idx} out of bounds") + continue + + # Create tile request + request = COGTileRequest( + url=url, + offset=metadata.tile_offsets[tile_idx], + size=metadata.tile_byte_counts[tile_idx], + row=row, + col=col, + metadata=metadata, + ) + + tasks.append((row, col, read_tile(request, client, sem))) + + # Gather results + for row, col, task in tasks: + try: + tile_data = await task + if tile_data is not None: + tiles[(row, col)] = tile_data + except Exception as e: + logger.error(f"Failed to read tile at ({row}, {col}): {str(e)}") + + if not tiles: + return np.array([]), None + + # Merge tiles + merged_data, bounds = merge_tiles( + tiles, (metadata.tile_width, metadata.tile_height), dtype=np.float32 + ) + + if debug: + logger.info( + f""" + Merged Data: + - Shape: {merged_data.shape} + - Bounds: {bounds} + - Data Range: {np.nanmin(merged_data)}-{np.nanmax(merged_data)} + """ + ) + + # Calculate transform for merged data + min_row, min_col, max_row, max_col = bounds + scale_x, translate_x, scale_y, translate_y = metadata.transform + + merged_transform = Affine( + scale_x, + 0, + translate_x + min_col * metadata.tile_width * scale_x, + 0, + scale_y, + translate_y + min_row * metadata.tile_height * scale_y, + ) + + # Apply geometry mask if provided + if geometry is not None: + merged_data, cropped_transform = 
apply_mask_and_crop(
+            merged_data, geometry, merged_transform
+        )
+    else:
+        # No geometry mask requested: keep the merged mosaic's transform
+        cropped_transform = merged_transform
+
+    if debug:
+        logger.info(
+            f"""
+            Final Output:
+            - Shape: {merged_data.shape}
+            - Transform: {cropped_transform}
+            - Data Range: {np.nanmin(merged_data)}-{np.nanmax(merged_data)}
+            """
+        )
+
+    return merged_data, cropped_transform
diff --git a/src/rasteret/logging.py b/src/rasteret/logging.py
new file mode 100644
index 0000000..08d001a
--- /dev/null
+++ b/src/rasteret/logging.py
@@ -0,0 +1,41 @@
+""" Logging configuration for the rasteret package. """
+
+import logging
+import sys
+from typing import Optional
+
+
+def setup_logger(
+    level: Optional[str] = "INFO", customname: Optional[str] = "rasteret"
+) -> logging.Logger:
+    """
+    Set up library-wide logging configuration.
+
+    Args:
+        level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+        customname: Name of the logger to configure (defaults to "rasteret")
+
+    Returns:
+        The configured logger instance
+    """
+    # Create formatters
+    detailed_formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    )
+
+    # Create console handler
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setFormatter(detailed_formatter)
+
+    # Configure the package-level logger
+    root_logger = logging.getLogger(name=customname)
+    root_logger.setLevel(getattr(logging, level))
+
+    # Suppress noisy httpx/httpcore logs
+    logging.getLogger("httpx").setLevel(logging.WARNING)
+    logging.getLogger("httpcore").setLevel(logging.WARNING)
+
+    # Remove existing handlers and add our handler
+    root_logger.handlers = []
+    root_logger.addHandler(console_handler)
+
+    # Prevent propagation to the root logger
+    root_logger.propagate = False
+
+    return root_logger
diff --git a/src/rasteret/stac/__init__.py b/src/rasteret/stac/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/rasteret/stac/indexer.py b/src/rasteret/stac/indexer.py
new file mode 100644
index 0000000..957b512
--- /dev/null
+++ b/src/rasteret/stac/indexer.py
@@ -0,0 +1,348 @@
+""" Indexer for creating GeoParquet collections from STAC catalogs.
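+
+A minimal usage sketch (run from an async context; the STAC endpoint, paths
+and filter values shown here are illustrative, not defaults defined by this
+module):
+
+    indexer = StacToGeoParquetIndexer(
+        data_source="landsat-c2l2-sr",
+        stac_api="https://landsatlook.usgs.gov/stac-server",
+        output_dir=Path("workspace/landsat_index"),
+        name="landsat_index",
+    )
+    collection = await indexer.build_index(
+        bbox=(77.55, 13.01, 77.58, 13.08),
+        date_range=("2024-01-01", "2024-01-31"),
+        query={"eo:cloud_cover": {"lt": 20}},
+    )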
""" + +from __future__ import annotations +import json +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Any + +import pyarrow as pa +import pyarrow.parquet as pq +import pyarrow.dataset as ds +import pystac_client +import stac_geoparquet +from shapely.geometry import shape + +from rasteret.stac.parser import AsyncCOGHeaderParser +from rasteret.cloud import CloudProvider, CloudConfig +from rasteret.logging import setup_logger +from rasteret.types import BoundingBox, DateRange +from rasteret.core.collection import Collection +from rasteret.constants import STAC_COLLECTION_BAND_MAPS, COG_BAND_METADATA_STRUCT + +logger = setup_logger("INFO", customname="rasteret.stac.indexer") + + +class StacToGeoParquetIndexer: + """Creates searchable GeoParquet collections from STAC catalogs.""" + + def __init__( + self, + data_source: str, + stac_api: str, + output_dir: Optional[Path] = None, + name: Optional[str] = None, + cloud_provider: Optional[CloudProvider] = None, + cloud_config: Optional[CloudConfig] = None, + max_concurrent: int = 50, + ): + self.data_source = data_source + self.stac_api = stac_api + self.output_dir = output_dir + self.cloud_provider = cloud_provider + self.cloud_config = cloud_config + self.name = name + self.max_concurrent = max_concurrent + + @property + def band_map(self) -> Dict[str, str]: + """Get band mapping for current collection.""" + return STAC_COLLECTION_BAND_MAPS.get(self.data_source, {}) + + async def build_index( + self, + bbox: Optional[BoundingBox] = None, + date_range: Optional[DateRange] = None, + query: Optional[Dict[str, Any]] = None, + ) -> Collection: + """ + Build GeoParquet collection from STAC search. + + Args: + bbox: Bounding box filter + date_range: Date range filter + query: Additional query parameters + + Returns: + Created Collection + """ + logger.info("Starting STAC index creation...") + if bbox: + logger.info(f"Spatial filter: {bbox}") + if date_range: + logger.info(f"Temporal filter: {date_range[0]} to {date_range[1]}") + if query: + logger.info(f"Additional query parameters: {query}") + + # 1. Get STAC items + stac_items = await self._search_stac(bbox, date_range, query) + logger.info(f"Found {len(stac_items)} scenes in STAC catalog") + + # 2. Process in batches, adding COG metadata + processed_items = [] + batch_size = 10 + total_batches = (len(stac_items) + batch_size - 1) // batch_size + + logger.info( + f"Processing {len(stac_items)} scenes (each scene has multiple bands)..." 
+ ) + + async with AsyncCOGHeaderParser( + max_concurrent=self.max_concurrent, + cloud_provider=self.cloud_provider, + cloud_config=self.cloud_config, + ) as cog_parser: + + for i in range(0, len(stac_items), batch_size): + batch = stac_items[i : i + batch_size] + batch_records = await self._process_batch(batch, cog_parser) + if batch_records: + processed_items.extend(batch_records) + logger.info( + f"Processed scene batch {(i//batch_size)+1}/{total_batches} yielding {len(batch_records)} band assets" + ) + + total_assets = sum(len(item["assets"]) for item in stac_items) + logger.info( + f"Completed processing {len(stac_items)} scenes with {len(processed_items)}/{total_assets} band assets" + ) + + logger.info(f"Successfully processed {len(processed_items)} items") + + try: + logger.info("Creating GeoParquet table with metadata...") + # First create json file with STAC items + temp_ndjson = Path(f"/tmp/stac_items_{datetime.now().timestamp()}.ndjson") + with open(temp_ndjson, "w") as f: + for item in stac_items: # Use original STAC items + json.dump(item, f) + f.write("\n") + + # Create temporary parquet with stac-geoparquet + temp_parquet = Path(f"/tmp/temp_stac_{datetime.now().timestamp()}.parquet") + stac_geoparquet.arrow.parse_stac_ndjson_to_parquet( + temp_ndjson, temp_parquet + ) + + # Read and enrich parquet table + stac_table = pq.read_table(temp_parquet) + + if not pa.types.is_timestamp(stac_table["datetime"].type): + stac_table = stac_table.append_column( + "datetime", + pa.array( + [ + datetime.fromisoformat(str(d.as_py())) + for d in stac_table["datetime"] + ], + type=pa.timestamp("us"), + ), + ) + + logger.info("Adding time columns...") + # Add time columns + datetime_col = stac_table.column("datetime") + if not pa.types.is_timestamp(datetime_col.type): + datetime_col = pa.array( + [datetime.fromisoformat(str(d.as_py())) for d in datetime_col], + type=pa.timestamp("us"), + ) + + table = stac_table.append_column( + "year", pa.compute.year(datetime_col) + ).append_column("month", pa.compute.month(datetime_col)) + + logger.info("Adding scene bounding boxes...") + scene_bboxes = {} + for item in stac_items: + polygon = shape(item["geometry"]) + scene_bboxes[item["id"]] = list(polygon.bounds) + + bbox_list = [scene_bboxes[id_] for id_ in table.column("id").to_pylist()] + table = table.append_column( + "scene_bbox", pa.array(bbox_list, type=pa.list_(pa.float64(), 4)) + ) + + logger.info("Adding band metadata...") + # Add band metadata columns + scene_metadata = {} + for scene_id in table.column("id").to_pylist(): + scene_metadata[scene_id] = {band: None for band in self.band_map.keys()} + + for item in processed_items: + if ( + "scene_id" not in item + ): # Handle case where scene_id might be missing + continue + scene_id = item["scene_id"] + band = item["band"] + if scene_id in scene_metadata: + scene_metadata[scene_id][band] = { + "image_width": item["width"], + "image_height": item["height"], + "tile_width": item["tile_width"], + "tile_height": item["tile_height"], + "dtype": item["dtype"], + "transform": item.get("transform", []), + "predictor": item["predictor"], + "compression": item["compression"], + "tile_offsets": item["tile_offsets"], + "tile_byte_counts": item["tile_byte_counts"], + "pixel_scale": item.get("pixel_scale", []), + "tiepoint": item.get("tiepoint", []), + } + logger.debug(f"Added metadata for scene {scene_id} band {band}") + + for band in self.band_map.keys(): + metadata_list = [ + scene_metadata[id_][band] for id_ in table.column("id").to_pylist() + ] + table = 
table.append_column(
+                        f"{band}_metadata",
+                        pa.array(metadata_list, type=COG_BAND_METADATA_STRUCT),
+                    )
+
+            logger.info("Creating final collection...")
+            # Create collection
+            collection = Collection(
+                dataset=ds.dataset(table),  # Create dataset from table
+                name=self.name,
+                description=f"STAC collection indexed from {self.data_source}",
+            )
+
+            # Optionally write to disk
+            if self.output_dir:
+                logger.info(f"Saving collection to {self.output_dir}")
+                collection.save_to_parquet(self.output_dir)
+
+            logger.info("Index creation completed successfully")
+            return collection
+
+        except Exception as e:
+            logger.error(f"Failed to create index: {str(e)}")
+            raise
+        finally:
+            # Cleanup temp files
+            logger.debug("Cleaning up temporary files...")
+            if temp_ndjson.exists():
+                temp_ndjson.unlink()
+            if temp_parquet.exists():
+                temp_parquet.unlink()
+
+    async def _search_stac(
+        self,
+        bbox: Optional[BoundingBox] = None,
+        date_range: Optional[DateRange] = None,
+        query: Optional[Dict[str, Any]] = None,
+    ) -> List[dict]:
+        """
+        Search STAC API for items.
+
+        Returns:
+            List of STAC items
+        """
+
+        # Build search parameters
+        search_params = {"collections": [self.data_source], "limit": None}
+        if bbox:
+            search_params["bbox"] = bbox
+        if date_range:
+            search_params["datetime"] = f"{date_range[0]}/{date_range[1]}"
+        if query is not None:
+            search_params["query"] = query
+
+        # Initialize STAC client and search
+        client = pystac_client.Client.open(self.stac_api)
+        search = client.search(**search_params)
+
+        items = []
+        for item in search.items():
+            items.append(item.to_dict())
+
+        logger.info(f"Found {len(items)} scenes")
+        return items
+
+    def _get_asset_url(self, asset: Dict) -> str:
+        """Get an authenticated URL for an asset."""
+        url = asset["href"] if isinstance(asset, dict) else asset
+        if self.cloud_provider and self.cloud_config:
+            return self.cloud_provider.get_url(url, self.cloud_config)
+        return url
+
+    async def _process_batch(
+        self, stac_items: List[dict], cog_parser: AsyncCOGHeaderParser
+    ) -> List[dict]:
+        """
+        Add COG metadata to STAC items.
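+
+        One record is emitted per (scene, band) pair: each carries the scene
+        geometry, datetime and cloud cover together with the parsed COG header
+        fields (tile layout, offsets, byte counts, transform).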
+ """ + urls_to_process = [] + url_mapping = {} # Track which url belongs to which item/band + + for item in stac_items: + item_id = item.get("id") + if not item_id: + continue + + for band_code, asset_name in self.band_map.items(): + if asset_name not in item["assets"]: + continue + + asset = item["assets"][asset_name] + url = self._get_asset_url(asset) + if url: + urls_to_process.append(url) + url_mapping[url] = (item_id, band_code, item) + + # Get COG metadata for all URLs + metadata_results = await cog_parser.process_cog_headers_batch(urls_to_process) + + # Enrich items with metadata + processed_items = {} + + for url, metadata in zip(urls_to_process, metadata_results): + if not metadata: + continue + + item_id, band_code, item = url_mapping[url] + + if item_id not in processed_items: + processed_items[item_id] = { + "id": item_id, + "scene_id": item_id, + "geometry": item["geometry"], + "datetime": item["properties"].get("datetime"), + "cloud_cover": item["properties"].get("eo:cloud_cover"), + "bands": {}, + } + + processed_items[item_id]["bands"][band_code] = { + "width": metadata.width, + "height": metadata.height, + "tile_width": metadata.tile_width, + "tile_height": metadata.tile_height, + "dtype": str(metadata.dtype), + "transform": metadata.transform, + "predictor": metadata.predictor, + "compression": metadata.compression, + "tile_offsets": metadata.tile_offsets, + "tile_byte_counts": metadata.tile_byte_counts, + "pixel_scale": metadata.pixel_scale, + "tiepoint": metadata.tiepoint, + } + + # Convert to enriched items list with proper band metadata structure + enriched_items = [] + for item_id, item_data in processed_items.items(): + for band_code, band_metadata in item_data["bands"].items(): + enriched_items.append( + { + "scene_id": item_id, + "band": band_code, + "geometry": item_data["geometry"], + "datetime": item_data["datetime"], + "cloud_cover": item_data["cloud_cover"], + **band_metadata, + } + ) + + return enriched_items diff --git a/src/rasteret/stac/parser.py b/src/rasteret/stac/parser.py new file mode 100644 index 0000000..efff589 --- /dev/null +++ b/src/rasteret/stac/parser.py @@ -0,0 +1,331 @@ +"""Async COG header parsing with caching.""" + +from __future__ import annotations +import asyncio +import struct +import time +from typing import Dict, List, Optional, Set, Any + +import httpx +from cachetools import TTLCache, LRUCache + +from rasteret.types import CogMetadata +from rasteret.cloud import CloudProvider, CloudConfig +from rasteret.logging import setup_logger + +logger = setup_logger() + + +def get_crs_from_tiff_tags(tags: Dict[int, Any]) -> Optional[str]: + """ + Extract CRS from GeoTIFF tags using multiple methods. 
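+
+    The WKT string in the GeoAsciiParamsTag (34737) is tried first; if that
+    fails, the GeoKey directory (34735) is scanned for a ProjectedCRS or
+    GeographicCRS key.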
+ + Args: + tags: Dictionary of TIFF tags + + Returns: + Optional[str]: EPSG code string like "EPSG:32643" if found, None otherwise + """ + # Method 1: GeoTiff WKT string (tag 34737 - GeoAsciiParamsTag) + if 34737 in tags: + wkt = tags[34737] + try: + import re + + # Look for EPSG code in WKT string + epsg_match = re.search(r'ID\["EPSG",(\d+)\]', wkt) + if epsg_match: + return int(epsg_match.group(1)) + except Exception as e: + logger.debug(f"Failed to parse WKT string: {e}") + + # Method 2: GeoKey directory (tag 34735) + geokeys = tags.get(34735) + if geokeys: + try: + num_keys = geokeys[3] + for i in range(4, 4 + (4 * num_keys), 4): + key_id = geokeys[i] + tiff_tag_loc = geokeys[i + 1] + count = geokeys[i + 2] + value = geokeys[i + 3] + + if key_id in (3072, 2048): # ProjectedCRS or GeographicCRS + if tiff_tag_loc == 0 and count == 1: # Direct value + return int(value) + except Exception as e: + logger.debug(f"Failed to parse GeoKey directory: {e}") + + return None + + +class AsyncCOGHeaderParser: + """Optimized async parser for COG headers with connection pooling and caching.""" + + def __init__( + self, + max_concurrent: int = 50, + cache_ttl: int = 3600, # 1 hour + retry_attempts: int = 3, + cloud_provider: Optional[CloudProvider] = None, + cloud_config: Optional[CloudConfig] = None, + ): + self.max_concurrent = max_concurrent + self.retry_attempts = retry_attempts + self.cloud_provider = cloud_provider + self.cloud_config = cloud_config + + # Connection optimization + self.connector = httpx.Limits( + max_keepalive_connections=max_concurrent, + max_connections=max_concurrent, + keepalive_expiry=120, + ) + + # Rate limiting + self.semaphore = asyncio.Semaphore(max_concurrent) + self.active_requests: Set[str] = set() + + # Caching + self.header_cache = TTLCache(maxsize=1000, ttl=cache_ttl) + self.dns_cache = LRUCache(maxsize=500) + + self.client = None + self.dtype_map = { + (1, 8): "uint8", + (1, 16): "uint16", + (2, 8): "int8", + (2, 16): "int16", + (3, 32): "float32", + } + + async def __aenter__(self): + self.client = httpx.AsyncClient( + limits=self.connector, + timeout=30.0, + http2=True, + headers={"Connection": "keep-alive", "Keep-Alive": "timeout=120"}, + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.client: + await self.client.aclose() + + async def process_cog_headers_batch( + self, urls: List[str], batch_size: int = 10 + ) -> List[Optional[CogMetadata]]: + """Process multiple URLs in parallel with smart batching.""" + + results = [] + total = len(urls) + + logger.info( + f"Processing {total} COG headers {'(single batch)' if total <= batch_size else f'in {(total + batch_size - 1) // batch_size} batches of {batch_size}'}" + ) + + for i in range(0, total, batch_size): + batch = urls[i : min(i + batch_size, total)] + batch_start = time.time() + + tasks = [self.parse_cog_header(url) for url in batch] + batch_results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in batch_results: + if isinstance(result, Exception): + logger.error(f"Failed to process header: {result}") + results.append(None) + else: + results.append(result) + + batch_time = time.time() - batch_start + remaining = total - (i + len(batch)) + batch_msg = ( + f"Processed batch {i//batch_size + 1}/{(total + batch_size - 1) // batch_size} " + f"({len(batch)} {'header' if len(batch) == 1 else 'headers'}) " + f"in {batch_time:.2f}s. " + f"{'Completed!' 
if remaining == 0 else f'Remaining: {remaining}'}" + ) + logger.info(batch_msg) + + return results + + async def _fetch_byte_range(self, url: str, start: int, size: int) -> bytes: + """Fetch a byte range from a URL.""" + + cache_key = f"{url}:{start}:{size}" + + if cache_key in self.header_cache: + return self.header_cache[cache_key] + + while url in self.active_requests: + await asyncio.sleep(0.1) + + self.active_requests.add(url) + try: + headers = {"Range": f"bytes={start}-{start + size - 1}"} + + for attempt in range(self.retry_attempts): + try: + async with self.semaphore: + response = await self.client.get(url, headers=headers) + if response.status_code != 206: + raise IOError( + f"Range request failed: {response.status_code}" + ) + + data = response.content + self.header_cache[cache_key] = data + return data + + except Exception as e: + if attempt == self.retry_attempts - 1: + raise IOError(f"Failed to fetch bytes from {url}: {e}") + await asyncio.sleep(1 * (attempt + 1)) + finally: + self.active_requests.remove(url) + + async def parse_cog_header(self, url: str) -> Optional[CogMetadata]: + """Parse COG header from URL.""" + try: + # Read initial header bytes + header_bytes = await self._fetch_byte_range(url, 0, 16) + + # Check byte order + big_endian = header_bytes[0:2] == b"MM" + endian = ">" if big_endian else "<" + + # Parse version and IFD offset + version = struct.unpack(f"{endian}H", header_bytes[2:4])[0] + if version == 42: + ifd_offset = struct.unpack(f"{endian}L", header_bytes[4:8])[0] + entry_size = 12 + elif version == 43: + ifd_offset = struct.unpack(f"{endian}Q", header_bytes[8:16])[0] + entry_size = 20 + else: + raise ValueError(f"Unsupported TIFF version: {version}") + + # Read IFD entries + ifd_count_size = 2 if version == 42 else 8 + ifd_count_bytes = await self._fetch_byte_range( + url, ifd_offset, ifd_count_size + ) + entry_count = struct.unpack(f"{endian}H", ifd_count_bytes)[0] + + ifd_bytes = await self._fetch_byte_range( + url, ifd_offset + ifd_count_size, entry_count * entry_size + ) + + # Parse tags + tags = {} + for i in range(entry_count): + entry = ifd_bytes[i * entry_size : (i + 1) * entry_size] + tag = struct.unpack(f"{endian}H", entry[0:2])[0] + type_id = struct.unpack(f"{endian}H", entry[2:4])[0] + count = struct.unpack(f"{endian}L", entry[4:8])[0] + value_or_offset = entry[8:12] if version == 42 else entry[16:24] + + tags[tag] = await self._parse_tiff_tag_value( + url, tag, type_id, count, value_or_offset, endian + ) + + # Extract essential metadata + image_width = tags.get(256)[0] # ImageWidth + image_height = tags.get(257)[0] # ImageLength + tile_width = tags.get(322, [image_width])[0] # TileWidth + tile_height = tags.get(323, [image_height])[0] # TileLength + + compression = tags.get(259, (1,))[0] # Compression + predictor = tags.get(317, (1,))[0] # Predictor + + # Data type + sample_format = tags.get(339, (1,))[0] + bits_per_sample = tags.get(258, (8,))[0] + dtype = self.dtype_map.get((sample_format, bits_per_sample), "uint8") + + # Tile layout + tile_offsets = list(tags.get(324, [])) # TileOffsets + tile_byte_counts = list(tags.get(325, [])) # TileByteCounts + + # Geotransform + pixel_scale = tags.get(33550) # ModelPixelScaleTag + tiepoint = tags.get(33922) # ModelTiepointTag + + # Calculate transform + transform = None + if pixel_scale and tiepoint: + scale_x, scale_y = pixel_scale[0], -pixel_scale[1] + translate_x, translate_y = tiepoint[3], tiepoint[4] + transform = (scale_x, translate_x, scale_y, translate_y) + + crs = 
get_crs_from_tiff_tags(tags) + + return CogMetadata( + width=image_width, + height=image_height, + tile_width=tile_width, + tile_height=tile_height, + dtype=dtype, + transform=transform, + predictor=predictor, + compression=compression, + tile_offsets=tile_offsets, + tile_byte_counts=tile_byte_counts, + crs=crs, + pixel_scale=pixel_scale, + tiepoint=tiepoint, + ) + + except Exception as e: + logger.error(f"Failed to parse header for {url}: {str(e)}") + return None + + async def _parse_tiff_tag_value( + self, + url: str, + tag: int, + type_id: int, + count: int, + value_or_offset: bytes, + endian: str, + ) -> tuple: + """Parse a TIFF tag value based on its type.""" + # Handle single values + if count == 1: + if type_id == 3: # SHORT + return (struct.unpack(f"{endian}H", value_or_offset[:2])[0],) + elif type_id == 4: # LONG + return (struct.unpack(f"{endian}L", value_or_offset[:4])[0],) + elif type_id == 5: # RATIONAL + offset = struct.unpack(f"{endian}L", value_or_offset[:4])[0] + data = await self._fetch_byte_range(url, offset, 8) + nums = struct.unpack(f"{endian}LL", data) + return (float(nums[0]) / nums[1],) + + # Handle offset values + offset = struct.unpack(f"{endian}L", value_or_offset[:4])[0] + size = { + 1: 1, # BYTE + 2: 1, # ASCII + 3: 2, # SHORT + 4: 4, # LONG + 5: 8, # RATIONAL + 12: 8, # DOUBLE + }[type_id] * count + + data = await self._fetch_byte_range(url, offset, size) + + if type_id == 1: # BYTE + return struct.unpack(f"{endian}{count}B", data) + elif type_id == 2: # ASCII + return (data[: count - 1].decode("ascii"),) + elif type_id == 3: # SHORT + return struct.unpack(f"{endian}{count}H", data) + elif type_id == 4: # LONG + return struct.unpack(f"{endian}{count}L", data) + elif type_id == 5: # RATIONAL + vals = struct.unpack(f"{endian}{count*2}L", data) + return tuple(vals[i] / vals[i + 1] for i in range(0, len(vals), 2)) + elif type_id == 12: # DOUBLE + return struct.unpack(f"{endian}{count}d", data) diff --git a/src/rasteret/tests/test_cloud_provider.py b/src/rasteret/tests/test_cloud_provider.py new file mode 100644 index 0000000..f6ece59 --- /dev/null +++ b/src/rasteret/tests/test_cloud_provider.py @@ -0,0 +1,51 @@ +import unittest +from unittest.mock import patch, MagicMock +from rasteret.cloud import CloudConfig, AWSProvider + + +class TestCloudProvider(unittest.TestCase): + + @patch("rasteret.cloud.boto3") + def test_s3_url_signing_handler(self, mock_boto3): + # Mock AWS session and client + mock_session = MagicMock() + mock_s3_client = MagicMock() + mock_boto3.Session.return_value = mock_session + mock_session.client.return_value = mock_s3_client + mock_s3_client.generate_presigned_url.return_value = ( + "https://signed-url.example.com" + ) + + # Mock credentials check + mock_credentials = MagicMock() + mock_session.get_credentials.return_value = mock_credentials + + # Create test configuration + cloud_config = CloudConfig( + provider="aws", + requester_pays=True, + region="us-west-2", + url_patterns={"https://example.com/": "s3://example-bucket/"}, + ) + + # Initialize provider + provider = AWSProvider(region="us-west-2") + + # Test URL pattern conversion + test_url = "https://example.com/test.tif" + provider.get_url(test_url, cloud_config) + + # Verify S3 client calls + mock_s3_client.generate_presigned_url.assert_called_once_with( + "get_object", + Params={ + "Bucket": "example-bucket", + "Key": "test.tif", + "RequestPayer": "requester", + }, + ExpiresIn=3600, + ) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/src/rasteret/tests/test_stac_indexer.py b/src/rasteret/tests/test_stac_indexer.py new file mode 100644 index 0000000..1551d6b --- /dev/null +++ b/src/rasteret/tests/test_stac_indexer.py @@ -0,0 +1,134 @@ +import unittest +from unittest.mock import patch, MagicMock, AsyncMock +from pathlib import Path +import pystac +from datetime import datetime + +from rasteret.stac.indexer import StacToGeoParquetIndexer +from rasteret.cloud import CloudConfig +from rasteret.types import CogMetadata + + +class TestStacIndexer(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.mock_stac_items = [ + pystac.Item( + id="test_scene_1", + datetime=datetime(2023, 1, 1, 0, 0, 0), + geometry={ + "type": "Polygon", + "coordinates": [[[0, 0], [1, 0], [1, 1], [0, 1], [0, 0]]], + }, + bbox=[0, 0, 1, 1], + properties={ + "datetime": datetime(2023, 1, 1, 0, 0, 0), + "eo:cloud_cover": 10.5, + }, + assets={ + "B1": pystac.Asset(href="s3://test-bucket/test1_B1.tif"), + "B2": pystac.Asset(href="s3://test-bucket/test1_B2.tif"), + }, + ) + ] + + self.mock_cog_metadata = CogMetadata( + width=1000, + height=1000, + tile_width=256, + tile_height=256, + dtype="uint16", + crs=4326, + transform=[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + predictor=2, + compression="deflate", + tile_offsets=[1000], + tile_byte_counts=[10000], + pixel_scale=[1.0, 1.0, 0.0], + tiepoint=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ) + + self.cloud_config = CloudConfig( + provider="aws", + requester_pays=True, + region="us-west-2", + url_patterns={"https://test.com/": "s3://test-bucket/"}, + ) + + @patch("rasteret.stac.indexer.pystac_client") + async def test_stac_search(self, mock_pystac): + # Setup mock STAC client + mock_client = MagicMock() + mock_search = MagicMock() + + # Configure mock chain + mock_pystac.Client.open.return_value = mock_client + mock_client.search.return_value = mock_search + mock_search.items.return_value = iter(self.mock_stac_items) + + # Create indexer + indexer = StacToGeoParquetIndexer( + data_source="test-source", stac_api="https://test-stac.com" + ) + + # Test search + items = await indexer._search_stac( + bbox=[-180, -90, 180, 90], date_range=["2023-01-01", "2023-12-31"] + ) + + # Verify results + self.assertEqual(len(items), 1) + self.assertEqual(items[0]["id"], "test_scene_1") + + # Verify mock calls + mock_pystac.Client.open.assert_called_once() + mock_client.search.assert_called_once() + mock_search.items.assert_called_once() + + @patch("rasteret.stac.indexer.AsyncCOGHeaderParser") + @patch("rasteret.stac.indexer.pystac_client") + async def test_index_creation(self, mock_pystac, mock_parser): + # Setup STAC client mock chain + mock_client = MagicMock() + mock_search = MagicMock() + + mock_pystac.Client.open.return_value = mock_client + mock_client.search.return_value = mock_search + mock_search.items.return_value = iter(self.mock_stac_items) + + # Setup COG parser mock + mock_parser_instance = AsyncMock() + mock_parser.return_value.__aenter__.return_value = mock_parser_instance + mock_parser_instance.process_cog_headers_batch.return_value = [ + self.mock_cog_metadata + ] + + indexer = StacToGeoParquetIndexer( + data_source="test-source", + stac_api="https://test-stac.com", + output_dir=Path("/tmp/test_output"), + ) + + collection = await indexer.build_index( + bbox=[-180, -90, 180, 90], date_range=["2023-01-01", "2023-12-31"] + ) + + self.assertIsNotNone(collection) + mock_parser_instance.process_cog_headers_batch.assert_called_once() + + def test_url_signing(self): + mock_provider = MagicMock() + 
mock_provider.get_url.return_value = "https://signed-url.test.com" + + indexer = StacToGeoParquetIndexer( + data_source="test-source", + stac_api="https://test-stac.com", + cloud_provider=mock_provider, + cloud_config=self.cloud_config, + ) + + url = indexer._get_asset_url({"href": "https://test.com/asset.tif"}) + self.assertEqual(url, "https://signed-url.test.com") + + +if __name__ == "__main__": + unittest.main() diff --git a/src/rasteret/types.py b/src/rasteret/types.py new file mode 100644 index 0000000..b8479f5 --- /dev/null +++ b/src/rasteret/types.py @@ -0,0 +1,76 @@ +"""Type definitions used throughout Rasteret.""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Dict, List, Optional, Tuple, Any, Union + +import pyarrow as pa +import numpy as np + +# Type aliases +BoundingBox = Tuple[float, float, float, float] # minx, miny, maxx, maxy +DateRange = Tuple[str, str] # ("YYYY-MM-DD", "YYYY-MM-DD") +Transform = List[float] # Affine transform coefficients + + +@dataclass +class CogMetadata: + """ + Metadata for a Cloud-Optimized GeoTIFF. + + Attributes: + width (int): Image width in pixels + height (int): Image height in pixels + tile_width (int): Internal tile width + tile_height (int): Internal tile height + dtype (Union[np.dtype, pa.DataType]): Data type + crs (int): Coordinate reference system code + predictor (Optional[int]): Compression predictor + transform (Optional[List[float]]): Affine transform coefficients + compression (Optional[int]): Compression type + tile_offsets (Optional[List[int]]): Byte offsets to tiles + tile_byte_counts (Optional[List[int]]): Size of each tile + pixel_scale (Optional[Tuple[float, ...]]): Resolution in CRS units + tiepoint (Optional[Tuple[float, ...]]): Reference point coordinates + """ + + width: int + height: int + tile_width: int + tile_height: int + dtype: Union[np.dtype, pa.DataType] + crs: int + predictor: Optional[int] = None + transform: Optional[List[float]] = None + compression: Optional[int] = None + tile_offsets: Optional[List[int]] = None + tile_byte_counts: Optional[List[int]] = None + pixel_scale: Optional[Tuple[float, ...]] = None + tiepoint: Optional[Tuple[float, ...]] = None + + +@dataclass +class SceneInfo: + """Metadata for a single scene. + + Attributes: + id (str): Scene ID of Scene from STAC + datetime (datetime): Datetime of Scene from STAC + bbox (List[float]): Bounds of Scene from STAC + scene_geometry (Any): Geometry of the scene (shapely geometry) + crs (int): Coordinate reference system code + cloud_cover (float): Cloud cover percentage + assets (Dict[str, Any]): Assets associated with the scene + metadata (Dict[str, Any]): Additional metadata + collection (str): Collection to which the scene belongs + """ + + id: str + datetime: datetime + bbox: List[float] + scene_geometry: Any + crs: int + cloud_cover: float + assets: Dict[str, Any] + metadata: Dict[str, Any] + collection: str
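+
+
+# Illustrative example (values are invented, not real scene metadata): a
+# CogMetadata instance as the COG header parser produces it. Note that
+# ``transform`` holds (scale_x, translate_x, scale_y, translate_y) derived
+# from the GeoTIFF ModelPixelScale and ModelTiepoint tags, not a full
+# 6-element affine matrix.
+#
+#   meta = CogMetadata(
+#       width=7821, height=7951,
+#       tile_width=256, tile_height=256,
+#       dtype="uint16", crs=32643,
+#       predictor=2, compression=8,  # 8 = DEFLATE
+#       transform=[30.0, 499785.0, -30.0, 1373715.0],
+#       tile_offsets=[4096], tile_byte_counts=[65536],
+#       pixel_scale=(30.0, 30.0, 0.0),
+#       tiepoint=(0.0, 0.0, 0.0, 499785.0, 1373715.0, 0.0),
+#   )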