From 81017132896aaaa6cd081de4f9261f1d6c447b50 Mon Sep 17 00:00:00 2001 From: print-sid8 <sidsub94@gmail.com> Date: Thu, 9 Jan 2025 01:46:17 +0530 Subject: [PATCH 1/4] minor change to example code, showing example of 1 polygon --- examples/basic_workflow_xarray.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/examples/basic_workflow_xarray.py b/examples/basic_workflow_xarray.py index bf18c27..231d44c 100644 --- a/examples/basic_workflow_xarray.py +++ b/examples/basic_workflow_xarray.py @@ -23,14 +23,8 @@ def main(): [(77.55, 13.01), (77.58, 13.01), (77.58, 13.08), (77.55, 13.08), (77.55, 13.01)] ) - aoi2_polygon = Polygon( - [(77.56, 13.02), (77.59, 13.02), (77.59, 13.09), (77.56, 13.09), (77.56, 13.02)] - ) - - # get total bounds of all polygons above for stac search and stac index creation - bbox = aoi1_polygon.union(aoi2_polygon).bounds + bbox = aoi1_polygon.bounds - # 2. List existing collections print("1. Available Collections") print("----------------------") collections = Rasteret.list_collections(workspace_dir=workspace_dir) @@ -39,9 +33,8 @@ def main(): f"- {c['name']}: {c['data_source']}, {c['date_range']}, {c['size']} scenes" ) - # 3. Try loading existing collection or create new try: - processor = Rasteret.load_collection(f"{custom_name}_202401-03_landsat") + processor = Rasteret.load_collection(f"{custom_name}_202403-03_landsat") except ValueError: print("\n2. Creating New Collection") print("-------------------------") @@ -63,9 +56,7 @@ def main(): # Calculate NDVI using xarray operations ds = processor.get_xarray( - # pass multiple geometries not its union bounds - # for separate processing of each geometry - geometries=[aoi1_polygon, aoi2_polygon], + geometries=[aoi1_polygon], bands=["B4", "B5"], cloud_cover_lt=20, ) @@ -77,18 +68,16 @@ def main(): ndvi = (ds.B5 - ds.B4) / (ds.B5 + ds.B4) ndvi_ds = xr.Dataset( {"NDVI": ndvi}, - coords=ds.coords, # Preserve coordinates including CRS - attrs=ds.attrs, # Preserve metadata + coords=ds.coords, + attrs=ds.attrs, ) print("\nNDVI dataset:") print(ndvi_ds) - # Create output directory output_dir = Path("ndvi_results") output_dir.mkdir(exist_ok=True) - # Save per geometry, give prefix for output files in this case "ndvi" output_files = save_per_geometry( ndvi_ds, output_dir, file_prefix="ndvi", data_var="NDVI" ) From 72cb4523a97c47af0be1638401b98a532d3b2fab Mon Sep 17 00:00:00 2001 From: print-sid8 <sidsub94@gmail.com> Date: Thu, 9 Jan 2025 02:11:03 +0530 Subject: [PATCH 2/4] revamp code with more asynchronous methods, and keep user facing code unaffected --- src/rasteret/__init__.py | 8 +- src/rasteret/core/collection.py | 1 - src/rasteret/core/processor.py | 447 +++++++++++++++++++------------- src/rasteret/core/scene.py | 16 +- src/rasteret/fetch/cog.py | 266 +++++++++++-------- src/rasteret/stac/indexer.py | 6 + 6 files changed, 459 insertions(+), 285 deletions(-) diff --git a/src/rasteret/__init__.py b/src/rasteret/__init__.py index 8d99dd2..995344e 100644 --- a/src/rasteret/__init__.py +++ b/src/rasteret/__init__.py @@ -20,4 +20,10 @@ def version(): __version__ = version() -__all__ = ["Collection", "Rasteret", "CloudConfig", "AWSProvider", "DataSources"] +__all__ = [ + "Collection", + "Rasteret", + "CloudConfig", + "AWSProvider", + "DataSources", +] diff --git a/src/rasteret/core/collection.py b/src/rasteret/core/collection.py index c6a3386..e09e21e 100644 --- a/src/rasteret/core/collection.py +++ b/src/rasteret/core/collection.py @@ -7,7 +7,6 @@ import pyarrow.parquet as pq import 
pyarrow.dataset as ds import pandas as pd -import geopandas as gpd from shapely.geometry import Polygon from pathlib import Path diff --git a/src/rasteret/core/processor.py b/src/rasteret/core/processor.py index 0d683d9..397ad21 100644 --- a/src/rasteret/core/processor.py +++ b/src/rasteret/core/processor.py @@ -3,7 +3,7 @@ ========================================================== Core Components: --------------- +--------------- - Rasteret: Main interface for querying and processing scenes - Collection: Manages indexed satellite data - Scene: Handles individual scene processing @@ -21,9 +21,11 @@ ... ) """ +from __future__ import annotations import asyncio from pathlib import Path from typing import Dict, List, Optional, Union, Tuple, Any + import xarray as xr import geopandas as gpd import pandas as pd @@ -37,28 +39,67 @@ from rasteret.stac.indexer import StacToGeoParquetIndexer from rasteret.cloud import AWSProvider, CloudProvider, CloudConfig from rasteret.logging import setup_logger +from rasteret.fetch.cog import COGReader logger = setup_logger("INFO", customname="rasteret.processor") +class URLSigningCache: + """Efficient URL signing cache with thread-safe caching.""" + + def __init__( + self, + cloud_provider: Optional[CloudProvider] = None, + cloud_config: Optional[CloudConfig] = None, + max_size: int = 1024, + ): + """ + Initialize URL signing cache. + + Args: + cloud_provider: Cloud provider for URL signing + cloud_config: Cloud configuration + max_size: Maximum number of cached signed URLs + """ + self._cloud_provider = cloud_provider + self._cloud_config = cloud_config + self._cache = {} + self._max_size = max_size + self._lock = asyncio.Lock() + + async def get_signed_url(self, url: str) -> str: + """ + Get a signed URL, using cache if possible. + + Args: + url: Original URL to be signed + + Returns: + Signed URL + """ + async with self._lock: + # Check cache first + if url in self._cache: + return self._cache[url] + + # Sign URL if provider exists + if self._cloud_provider and self._cloud_config: + signed_url = self._cloud_provider.get_url(url, self._cloud_config) + + # Manage cache size + if len(self._cache) >= self._max_size: + # Remove oldest entry + self._cache.pop(next(iter(self._cache))) + + self._cache[url] = signed_url + return signed_url + + # If no signing possible, return original URL + return url + + class Rasteret: - """Main interface for satellite data retrieval and processing. - - Attributes: - data_source (str): Source dataset identifier - workspace_dir (Path): Directory for storing collections - custom_name (str, optional): Collection name prefix - date_range (Tuple[str, str], optional): Date range for collection - aws_profile (str, optional): AWS profile for authentication - - Examples: - >>> processor = Rasteret("landsat-c2l2-sr", "workspace") - >>> processor.create_index( - ... bbox=[77.55, 13.01, 77.58, 13.04], - ... date_range=["2024-01-01", "2024-01-31"] - ... ) - >>> df = processor.query(geometries=[polygon], bands=["B4", "B5"]) - """ + """Optimized Rasteret processor with connection pooling and caching.""" def __init__( self, @@ -66,49 +107,48 @@ def __init__( workspace_dir: Optional[Union[str, Path]] = None, custom_name: Optional[str] = None, date_range: Optional[Tuple[str, str]] = None, + max_concurrent: int = 50, ): - """Initialize Rasteret processor.""" + """ + Initialize Rasteret processor with optimized async handling. 
+ + Args: + data_source: Source of satellite data + workspace_dir: Directory for storing collections + custom_name: Custom name for the collection + date_range: Date range for collection + max_concurrent: Maximum concurrent connections + """ self.data_source = data_source self.workspace_dir = Path(workspace_dir or Path.home() / "rasteret_workspace") self.custom_name = custom_name self.date_range = date_range + self.max_concurrent = max_concurrent - # Initialize cloud config early + # Initialize cloud resources self.cloud_config = CloudConfig.get_config(str(data_source)) self._cloud_provider = None if not self.cloud_config else AWSProvider() - self._collection = None - @property - def provider(self): - """Get cloud provider.""" - return self._cloud_provider - - def _get_collection_path(self) -> Path: - """Get expected collection path""" - return self.workspace_dir / f"{self.custom_name}.parquet" - - def _get_bbox_from_geometries(self, geometries: List[Polygon]) -> List[float]: - """Get combined bbox from geometries""" - bounds = [geom.bounds for geom in geometries] - return [ - min(b[0] for b in bounds), # minx - min(b[1] for b in bounds), # miny - max(b[2] for b in bounds), # maxx - max(b[3] for b in bounds), # maxy - ] - - def _ensure_collection(self) -> None: - """Ensure collection exists and is loaded with proper partitioning.""" - if not self.custom_name: - raise ValueError("custom_name is required") + # URL signing cache + self._url_cache = URLSigningCache( + cloud_provider=self._cloud_provider, cloud_config=self.cloud_config + ) - stac_path = self.workspace_dir / f"{self.custom_name}_stac" + # Persistent COG reader + self._cog_reader = None + self._collection = None - if self._collection is None: - if stac_path.exists(): - self._collection = Collection.from_local(stac_path) - else: - raise ValueError(f"Collection not found: {stac_path}") + async def __aenter__(self): + """Async context manager entry for resource management.""" + # Initialize COG reader with connection pooling + self._cog_reader = COGReader(max_concurrent=self.max_concurrent) + await self._cog_reader.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit for cleanup.""" + if self._cog_reader: + await self._cog_reader.__aexit__(exc_type, exc_val, exc_tb) def create_collection( self, @@ -117,46 +157,64 @@ def create_collection( force: bool = False, **filters, ) -> None: - """Create or load STAC index.""" + """Sync interface for collection creation""" + + def _sync_create(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete( + self._async_create_collection(bbox, date_range, force, **filters) + ) + finally: + loop.close() + + return _sync_create() + + async def _async_create_collection( + self, + bbox: List[float], + date_range: Optional[Tuple[str, str]] = None, + force: bool = False, + **filters, + ) -> None: + """Internal async implementation for collection creation""" if not self.custom_name: raise ValueError("custom_name is required") - # Create standardized collection name with date range collection_name = Collection.create_name( self.custom_name, date_range or self.date_range, str(self.data_source) ) + collection_path = self.workspace_dir / f"{collection_name}.parquet" - stac_path = self.workspace_dir / f"{collection_name}_stac" - - if stac_path.exists() and not force: - self._collection = Collection.from_local(stac_path) - logger.info(f"Loading existing collection: {collection_name}") + if 
collection_path.exists() and not force: + logger.info(f"Collection {collection_name} already exists") + self._collection = Collection.from_parquet(collection_path) return - # Create new collection + # Initialize indexer with required params indexer = StacToGeoParquetIndexer( - data_source=self.data_source, + data_source=str(self.data_source), stac_api=STAC_ENDPOINTS[self.data_source], - workspace_dir=stac_path, - cloud_provider=self.provider, - cloud_config=self.cloud_config, + workspace_dir=self.workspace_dir, name=collection_name, + cloud_provider=self._cloud_provider, + cloud_config=self.cloud_config, + max_concurrent=self.max_concurrent, ) - self._collection = asyncio.run( - indexer.build_index( - bbox=bbox, date_range=date_range or self.date_range, query=filters - ) + # Use build_index instead of create_index + self._collection = await indexer.build_index( + bbox=bbox, date_range=date_range or self.date_range, query=filters ) - if self._collection is not None: - logger.info(f"Created collection: {collection_name}") + logger.info(f"Created collection: {collection_name}") @classmethod def list_collections( cls, workspace_dir: Optional[Path] = None ) -> List[Dict[str, Any]]: - """List collections with metadata.""" + """List collections with metadata (unchanged).""" workspace_dir = workspace_dir or Path.home() / "rasteret_workspace" collections = [] @@ -202,21 +260,27 @@ def list_collections( def load_collection( cls, collection_name: str, workspace_dir: Optional[Path] = None ) -> "Rasteret": - """Load collection by name.""" + """Load collection by name with async preparation.""" workspace_dir = workspace_dir or Path.home() / "rasteret_workspace" - stac_path = workspace_dir / f"{collection_name.replace('_stac', '')}_stac" + + # Remove _stac suffix if present + clean_name = collection_name.replace("_stac", "") + stac_path = workspace_dir / f"{clean_name}_stac" if not stac_path.exists(): raise ValueError(f"Collection not found: {collection_name}") - # Get data source from name - data_source = collection_name.split("_")[-1].upper() + # Parse data source from name + try: + data_source = clean_name.split("_")[-1].upper() + except IndexError: + data_source = "UNKNOWN" # Create processor processor = cls( data_source=getattr(DataSources, data_source, data_source), workspace_dir=workspace_dir, - custom_name=collection_name, + custom_name=clean_name, ) # Load collection @@ -225,130 +289,159 @@ def load_collection( logger.info(f"Loaded existing collection: {collection_name}") return processor - def get_gdf( - self, - geometries: List[Polygon], - bands: List[str], - max_concurrent: int = 50, - **filters, - ) -> gpd.GeoDataFrame: - """Query indexed scenes matching filters for specified geometries and bands.""" - # Validate inputs - if not geometries: - raise ValueError("No geometries provided") - if not bands: - raise ValueError("No bands specified") - - # Ensure collection exists and is loaded - self._ensure_collection() - - if filters: - self._collection = self._collection.filter_scenes(**filters) - - return asyncio.run(self._get_gdf(geometries, bands, max_concurrent)) - - async def _get_gdf(self, geometries, bands, max_concurrent): - total_scenes = len(self._collection.dataset.to_table()) - logger.info(f"Processing {total_scenes} scenes for {len(bands)} bands") - results = [] - - async for scene in tqdm( - self._collection.iterate_scenes(self.data_source), - total=total_scenes, - desc="Loading scenes", - ): - scene_results = await scene.load_bands( - geometries, - bands, - max_concurrent, - 
cloud_provider=self.provider, - cloud_config=self.cloud_config, - for_xarray=False, - ) - results.append(scene_results) + async def _sign_scene_urls(self, scene): + """ + Sign URLs for a scene's assets. - return gpd.GeoDataFrame(pd.concat(results, ignore_index=True)) + Args: + scene: Scene with assets to sign - def get_xarray( + Returns: + Scene with signed URLs + """ + # Create copies to avoid modifying original + signed_assets = {} + for band, asset in scene.assets.items(): + # Get signed URL + signed_url = await self._url_cache.get_signed_url(asset["href"]) + + # Create a copy of the asset with signed URL + signed_asset = asset.copy() + signed_asset["href"] = signed_url + signed_assets[band] = signed_asset + + # Update scene assets with signed URLs + scene.assets = signed_assets + return scene + + async def _get_scene_data( self, geometries: List[Polygon], bands: List[str], - max_concurrent: int = 50, + for_xarray: bool = True, + batch_size: int = 10, **filters, - ) -> xr.Dataset: + ) -> Union[List[gpd.GeoDataFrame], List[xr.Dataset]]: """ - Query collection and return as xarray Dataset. + Optimized async scene data retrieval with URL signing and batching. Args: - geometries: List of polygons to query - bands: List of band identifiers - max_concurrent: Maximum concurrent requests - **filters: Additional filters (e.g. cloud_cover_lt=20) + geometries: List of geometries to process + bands: Bands to retrieve + for_xarray: Whether to return xarray or GeoDataFrame + batch_size: Number of scenes to process in parallel + **filters: Additional filtering parameters Returns: - xarray Dataset with data, coordinates, and CRS + List of processed datasets """ - # Same validation as query() - if not geometries: - raise ValueError("No geometries provided") - if not bands: - raise ValueError("No bands specified") - - self._ensure_collection() + if not self._collection: + raise ValueError("No collection loaded") + # Apply filters if provided if filters: self._collection = self._collection.filter_scenes(**filters) - return asyncio.run( - self._get_xarray( - geometries=geometries, bands=bands, max_concurrent=max_concurrent - ) - ) + results = [] - async def _get_xarray( - self, - geometries: List[Polygon], - bands: List[str], - max_concurrent: int, - for_xarray: bool = True, + # Prepare scene batches + scene_batches = [] + current_batch = [] + + async for scene in self._collection.iterate_scenes(self.data_source): + # Sign URLs for the scene + scene = await self._sign_scene_urls(scene) + current_batch.append(scene) + + if len(current_batch) == batch_size: + scene_batches.append(current_batch) + current_batch = [] + + # Add remaining scenes + if current_batch: + scene_batches.append(current_batch) + + # Process scene batches in parallel + for batch in tqdm(scene_batches, desc="Processing scenes"): + # Create tasks for batch processing + tasks = [ + scene.load_bands( + geometries=geometries, + band_codes=bands, + max_concurrent=self.max_concurrent, + cloud_provider=self._cloud_provider, + cloud_config=self.cloud_config, + for_xarray=for_xarray, + ) + for scene in batch + ] + + # Gather results from batch + batch_results = await asyncio.gather(*tasks) + results.extend([r for r in batch_results if r is not None]) + + return results + + def get_xarray( + self, geometries: Union[Polygon, List[Polygon]], bands: List[str], **filters ) -> xr.Dataset: - datasets = [] - total_scenes = len(self._collection.dataset.to_table()) - - logger.info(f"Processing {total_scenes} scenes for {len(bands)} bands") - - async for 
scene in tqdm( - self._collection.iterate_scenes(self.data_source), - total=total_scenes, - desc="Loading scenes", - ): - scene_ds = await scene.load_bands( - geometries=geometries, - band_codes=bands, - max_concurrent=max_concurrent, - cloud_provider=self.provider, - cloud_config=self.cloud_config, - for_xarray=for_xarray, - ) - if scene_ds is not None: - datasets.append(scene_ds) - logger.debug( - f"Loaded scene {scene.id} ({len(datasets)}/{total_scenes})" + """Sync interface for xarray retrieval""" + if isinstance(geometries, Polygon): + geometries = [geometries] + + def _sync_get(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete( + self._async_get_xarray(geometries, bands, **filters) ) + finally: + loop.close() + + return _sync_get() - if not datasets: - raise ValueError("No valid data found for query") + async def _async_get_xarray( + self, geometries: List[Polygon], bands: List[str], **filters + ) -> xr.Dataset: + """Internal async implementation""" + async with self: + scene_datasets = await self._get_scene_data( + geometries=geometries, bands=bands, for_xarray=True, **filters + ) + if not scene_datasets: + raise ValueError("No valid data found") + logger.info(f"Merging {len(scene_datasets)} datasets") + merged = xr.merge(scene_datasets) + return merged.sortby("time") - logger.info(f"Merging {len(datasets)} datasets") - merged = xr.merge(datasets) - merged = merged.sortby("time") + def get_gdf( + self, geometries: Union[Polygon, List[Polygon]], bands: List[str], **filters + ) -> gpd.GeoDataFrame: + """Sync interface for GeoDataFrame retrieval""" + if isinstance(geometries, Polygon): + geometries = [geometries] - logger.info(f"Data retrieved for {len(geometries)} geometries") - logger.info(f"Dataset shape: {merged.sizes}") + def _sync_get(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete( + self._async_get_gdf(geometries, bands, **filters) + ) + finally: + loop.close() - return merged + return _sync_get() - def __repr__(self): - return ( - f"Rasteret(data_source={self.data_source}, custom_name={self.custom_name})" - ) + async def _async_get_gdf( + self, geometries: List[Polygon], bands: List[str], **filters + ) -> gpd.GeoDataFrame: + """Internal async implementation""" + async with self: + scene_dfs = await self._get_scene_data( + geometries=geometries, bands=bands, for_xarray=False, **filters + ) + if not scene_dfs: + raise ValueError("No valid data found") + return gpd.GeoDataFrame(pd.concat(scene_dfs, ignore_index=True)) diff --git a/src/rasteret/core/scene.py b/src/rasteret/core/scene.py index 1f22418..0d0b4a1 100644 --- a/src/rasteret/core/scene.py +++ b/src/rasteret/core/scene.py @@ -157,9 +157,19 @@ async def _load_single_band( max_concurrent: int = 50, ) -> Optional[Dict]: """Load single band data for geometry.""" - cog_meta, url = self.get_band_cog_metadata( - band_code, provider=cloud_provider, cloud_config=cloud_config - ) + # Get metadata and URL only once + if not hasattr(self, "_band_meta_cache"): + self._band_meta_cache = {} + + cache_key = f"{band_code}" + if cache_key not in self._band_meta_cache: + cog_meta, url = self.get_band_cog_metadata( + band_code, provider=cloud_provider, cloud_config=cloud_config + ) + self._band_meta_cache[cache_key] = (cog_meta, url) + else: + cog_meta, url = self._band_meta_cache[cache_key] + if not cog_meta or not url: return None diff --git a/src/rasteret/fetch/cog.py b/src/rasteret/fetch/cog.py index 5d31bb1..857a3f1 
100644 --- a/src/rasteret/fetch/cog.py +++ b/src/rasteret/fetch/cog.py @@ -31,6 +31,147 @@ class COGTileRequest: metadata: CogMetadata # Full metadata including transform +class COGReader: + """Manages connection pooling and COG reading operations.""" + + def __init__(self, max_concurrent: int = 50): + self.max_concurrent = max_concurrent + self.limits = httpx.Limits( + max_keepalive_connections=max_concurrent, + max_connections=max_concurrent, + keepalive_expiry=60.0, # Shorter keepalive for HTTP/2 + ) + self.timeout = httpx.Timeout(30.0, connect=10.0) + self.client = None + self.sem = None + self.batch_size = 12 # Reduced for better HTTP/2 multiplexing + + async def __aenter__(self): + self.client = httpx.AsyncClient( + timeout=self.timeout, + limits=self.limits, + http2=True, + verify=True, + trust_env=True, + ) + self.sem = asyncio.Semaphore(self.max_concurrent) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.client: + await self.client.aclose() + + def merge_ranges( + self, requests: List[COGTileRequest], gap_threshold: int = 1024 + ) -> List[Tuple[int, int]]: + """Merge nearby byte ranges to minimize HTTP requests""" + if not requests: + return [] + + ranges = [(r.offset, r.offset + r.size) for r in requests] + ranges.sort() + merged = [ranges[0]] + + for curr in ranges[1:]: + prev = merged[-1] + if curr[0] <= prev[1] + gap_threshold: + merged[-1] = (prev[0], max(prev[1], curr[1])) + else: + merged.append(curr) + + return merged + + async def read_merged_tiles( + self, requests: List[COGTileRequest], debug: bool = False + ) -> Dict[Tuple[int, int], np.ndarray]: + """Parallel tile reading with HTTP/2 multiplexing""" + if not requests: + return {} + + # Group by URL for HTTP/2 connection reuse + url_groups = {} + for req in requests: + url_groups.setdefault(req.url, []).append(req) + + results = {} + for url, group_requests in url_groups.items(): + ranges = self.merge_ranges(group_requests) + + # Process ranges in batches + for i in range(0, len(ranges), self.batch_size): + batch = ranges[i : i + self.batch_size] + batch_tasks = [ + self._read_and_process_range( + url, + start, + end, + [r for r in group_requests if start <= r.offset < end], + ) + for start, end in batch + ] + batch_results = await asyncio.gather(*batch_tasks) + for result in batch_results: + results.update(result) + + return results + + async def _read_and_process_range( + self, url: str, start: int, end: int, requests: List[COGTileRequest] + ) -> Dict[Tuple[int, int], np.ndarray]: + """Read and process a byte range with retries""" + async with self.sem: + data = await self._read_range(url, start, end) + + # Process tiles in parallel + tasks = [] + for req in requests: + offset = req.offset - start + tile_data = data[offset : offset + req.size] + tasks.append(self._process_tile(tile_data, req.metadata)) + + tiles = await asyncio.gather(*tasks) + return {(req.row, req.col): tile for req, tile in zip(requests, tiles)} + + async def _read_range(self, url: str, start: int, end: int) -> bytes: + """HTTP/2 optimized range reading""" + headers = {"Range": f"bytes={start}-{end-1}"} + + for attempt in range(3): + try: + async with self.sem: + response = await self.client.get(url, headers=headers) + response.raise_for_status() + return response.content + except Exception: + if attempt == 2: + raise + await asyncio.sleep(1 * (2**attempt)) + + async def _process_tile(self, data: bytes, metadata: CogMetadata) -> np.ndarray: + """Process tile data in thread pool""" + loop = 
asyncio.get_running_loop() + + # Decompress in thread pool + decompressed = await loop.run_in_executor(None, imagecodecs.zlib_decode, data) + + # Process in thread pool to avoid blocking + return await loop.run_in_executor( + None, self._process_tile_sync, decompressed, metadata + ) + + def _process_tile_sync(self, data: bytes, metadata: CogMetadata) -> np.ndarray: + """Synchronous tile processing""" + tile = np.frombuffer(data, dtype=np.uint16).reshape( + (metadata.tile_height, metadata.tile_width) + ) + + if metadata.predictor == 2: + tile = tile.astype(np.uint16) + np.cumsum(tile, axis=1, out=tile) + + return tile.astype(np.float32) + + def compute_tile_indices( geometry: Polygon, transform: List[float], @@ -202,46 +343,6 @@ def apply_mask_and_crop( return masked_data, cropped_transform -async def read_tile( - request: COGTileRequest, - client: httpx.AsyncClient, - sem: asyncio.Semaphore, - retries: int = 3, - retry_delay: float = 1.0, -) -> Optional[np.ndarray]: - """Read a single tile using byte range request.""" - for attempt in range(retries): - try: - async with sem: - headers = { - "Range": f"bytes={request.offset}-{request.offset+request.size-1}" - } - response = await client.get(request.url, headers=headers) - if response.status_code != 206: - raise ValueError(f"Range request failed: {response.status_code}") - - # Simple, direct data flow like tiles.py - decompressed = imagecodecs.zlib_decode(response.content) - data = np.frombuffer(decompressed, dtype=np.uint16) - data = data.reshape( - (request.metadata.tile_height, request.metadata.tile_width) - ) - - # Predictor handling exactly like tiles.py - if request.metadata.predictor == 2: - data = data.astype(np.uint16) - for i in range(data.shape[0]): - data[i] = np.cumsum(data[i], dtype=np.uint16) - - return data.astype(np.float32) - - except Exception as e: - if attempt == retries - 1: - logger.error(f"Failed to read tile: {str(e)}") - return None - await asyncio.sleep(retry_delay * (2**attempt)) - - async def read_cog_tile_data( url: str, metadata: CogMetadata, @@ -249,32 +350,9 @@ async def read_cog_tile_data( max_concurrent: int = 50, debug: bool = False, ) -> Tuple[np.ndarray, Optional[Affine]]: - """Read COG data, optionally masked by geometry. 
- - Args: - url: URL of the COG file - metadata: COG metadata including transform - geometry: Optional polygon to mask/filter data - max_concurrent: Maximum concurrent requests - debug: Enable debug logging - - Returns: - Tuple of: - - np.ndarray: The masked data array - - Affine: Transform matrix for the masked data - None if no transform available - """ + """Read COG data, optionally masked by geometry.""" if debug: - logger.info( - f""" - Input Parameters: - - CRS: {metadata.crs} - - Transform: {metadata.transform} - - Image Size: {metadata.width}x{metadata.height} - - Tile Size: {metadata.tile_width}x{metadata.tile_height} - - Geometry: {geometry.wkt if geometry else None} - """ - ) + logger.info(f"Reading COG data from {url}") if metadata.transform is None: return np.array([]), None @@ -303,31 +381,19 @@ async def read_cog_tile_data( if not intersecting_tiles: return np.array([]), None - # Set up HTTP client with connection pooling - limits = httpx.Limits( - max_keepalive_connections=max_concurrent, max_connections=max_concurrent - ) - timeout = httpx.Timeout(30.0) - - async with httpx.AsyncClient(timeout=timeout, limits=limits, http2=True) as client: - sem = asyncio.Semaphore(max_concurrent) - - # Read tiles - tiles = {} - tasks = [] - tiles_x = (metadata.width + metadata.tile_width - 1) // metadata.tile_width - - # Create tasks for all tiles - for row, col in intersecting_tiles: - tile_idx = row * tiles_x + col # Linear tile index + # Create tile requests + requests = [] + tiles_x = (metadata.width + metadata.tile_width - 1) // metadata.tile_width - if tile_idx >= len(metadata.tile_offsets): - if debug: - logger.warning(f"Tile index {tile_idx} out of bounds") - continue + for row, col in intersecting_tiles: + tile_idx = row * tiles_x + col + if tile_idx >= len(metadata.tile_offsets): + if debug: + logger.warning(f"Tile index {tile_idx} out of bounds") + continue - # Create tile request - request = COGTileRequest( + requests.append( + COGTileRequest( url=url, offset=metadata.tile_offsets[tile_idx], size=metadata.tile_byte_counts[tile_idx], @@ -335,22 +401,16 @@ async def read_cog_tile_data( col=col, metadata=metadata, ) + ) - tasks.append((row, col, read_tile(request, client, sem))) + # Use COGReader for efficient tile reading + async with COGReader(max_concurrent=max_concurrent) as reader: + tiles = await reader.read_merged_tiles(requests, debug=debug) - # Gather results - for row, col, task in tasks: - try: - tile_data = await task - if tile_data is not None: - tiles[(row, col)] = tile_data - except Exception as e: - logger.error(f"Failed to read tile at ({row}, {col}): {str(e)}") - - if not tiles: - return np.array([]), None + if not tiles: + return np.array([]), None - # Merge tiles + # Merge tiles and handle transforms merged_data, bounds = merge_tiles( tiles, (metadata.tile_width, metadata.tile_height), dtype=np.float32 ) @@ -380,7 +440,7 @@ async def read_cog_tile_data( # Apply geometry mask if provided if geometry is not None: - merged_data, cropped_transform = apply_mask_and_crop( + merged_data, merged_transform = apply_mask_and_crop( merged_data, geometry, merged_transform ) @@ -394,4 +454,4 @@ async def read_cog_tile_data( """ ) - return merged_data, cropped_transform + return merged_data, merged_transform diff --git a/src/rasteret/stac/indexer.py b/src/rasteret/stac/indexer.py index a98cc7a..98122c6 100644 --- a/src/rasteret/stac/indexer.py +++ b/src/rasteret/stac/indexer.py @@ -346,3 +346,9 @@ async def _process_batch( ) return enriched_items + + async def 
__aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass From e399c11d5e00bdf459df4c13ddd1f291243195d2 Mon Sep 17 00:00:00 2001 From: print-sid8 <sidsub94@gmail.com> Date: Thu, 9 Jan 2025 02:11:24 +0530 Subject: [PATCH 3/4] minor typo --- examples/basic_workflow_xarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/basic_workflow_xarray.py b/examples/basic_workflow_xarray.py index 231d44c..622e554 100644 --- a/examples/basic_workflow_xarray.py +++ b/examples/basic_workflow_xarray.py @@ -14,7 +14,7 @@ def main(): workspace_dir = Path.home() / "rasteret_workspace" workspace_dir.mkdir(exist_ok=True) - custom_name = "bangalore-v3" + custom_name = "bangalore" date_range = ("2024-03-01", "2024-03-31") data_source = DataSources.LANDSAT From 16acb379f164f67b8c613f35cf639ae5509bef06 Mon Sep 17 00:00:00 2001 From: print-sid8 <sidsub94@gmail.com> Date: Thu, 9 Jan 2025 02:11:45 +0530 Subject: [PATCH 4/4] version bump --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1441c37..a129ccc 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name="rasteret", - version="0.1.9", + version="0.1.10", author="Sidharth Subramaniam", author_email="sid@terrafloww.com", description="Fast and efficient access to Cloud-Optimized GeoTIFFs (COGs)",
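
Note: the refactor in PATCH 2/4 keeps the user-facing, synchronous API intact while the internals move to pooled async I/O. A minimal sketch of that unchanged workflow, adapted from examples/basic_workflow_xarray.py as modified in this series (it assumes DataSources is re-exported at the package root, as the updated __all__ suggests; the collection name and dates follow the example file):

    from pathlib import Path
    from shapely.geometry import Polygon
    from rasteret import Rasteret, DataSources  # DataSources export assumed per __all__

    aoi = Polygon(
        [(77.55, 13.01), (77.58, 13.01), (77.58, 13.08), (77.55, 13.08), (77.55, 13.01)]
    )

    # Load the indexed collection if it exists, otherwise build it from the STAC API.
    try:
        processor = Rasteret.load_collection("bangalore_202403-03_landsat")
    except ValueError:
        processor = Rasteret(
            data_source=DataSources.LANDSAT,
            workspace_dir=Path.home() / "rasteret_workspace",
            custom_name="bangalore",
            date_range=("2024-03-01", "2024-03-31"),
        )
        processor.create_collection(bbox=list(aoi.bounds), cloud_cover_lt=20)

    # Sync facade; internally this drives the async COGReader with connection pooling.
    ds = processor.get_xarray(geometries=[aoi], bands=["B4", "B5"], cloud_cover_lt=20)
    ndvi = (ds.B5 - ds.B4) / (ds.B5 + ds.B4)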