Commit
adding v0.1.0 code
sid_terrafloww committed Jan 5, 2025
1 parent 6bb8b28 commit 84f913e
Showing 19 changed files with 3,027 additions and 0 deletions.
70 changes: 70 additions & 0 deletions examples/basic_workflow_gdf.py
@@ -0,0 +1,70 @@
# examples/basic_workflow_gdf.py
from pathlib import Path
from shapely.geometry import Polygon

from rasteret import Rasteret


def main():
    """Demonstrate core workflows with Rasteret."""
    # 1. Define parameters
    custom_name = "bangalore3"
    date_range = ("2024-01-01", "2024-01-31")
    data_source = "landsat-c2l2-sr"

    workspace_dir = Path.home() / "rasteret_workspace"
    workspace_dir.mkdir(exist_ok=True)

    print("1. Defining Area of Interest")
    print("--------------------------")

    # Define areas and time of interest
    aoi_polygon = Polygon(
        [(77.55, 13.01), (77.58, 13.01), (77.58, 13.08), (77.55, 13.08), (77.55, 13.01)]
    )

    aoi_polygon2 = Polygon(
        [(77.56, 13.02), (77.59, 13.02), (77.59, 13.09), (77.56, 13.09), (77.56, 13.02)]
    )

    # Get the total bounds of all polygons above
    bbox = aoi_polygon.union(aoi_polygon2).bounds

    print("\n2. Creating and Loading Collection")
    print("--------------------------")

    # 2. Initialize processor - the collection name is derived from custom_name
    processor = Rasteret(
        custom_name=custom_name,
        data_source=data_source,
        output_dir=workspace_dir,
        date_range=date_range,
    )

    # Create the STAC index if it does not exist yet
    if processor._collection is None:
        processor.create_index(
            bbox=bbox, date_range=date_range, query={"cloud_cover_lt": 20}
        )

    # List existing collections
    collections = Rasteret.list_collections(dir=workspace_dir)
    print("Available collections:")
    for c in collections:
        print(f"- {c['name']}: {c['size']} scenes")

    print("\n3. Processing Data")
    print("----------------")

    df = processor.get_gdf(
        geometries=[aoi_polygon, aoi_polygon2], bands=["B4", "B5"], cloud_cover_lt=20
    )

    print(f"Columns: {df.columns}")
    print(f"Unique dates: {df.datetime.dt.date.unique()}")
    print(f"Unique geometries: {df.geometry.unique()}")


if __name__ == "__main__":
    main()
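A hedged aside on the example above: get_gdf returns a GeoDataFrame, and the sketch below shows one way to derive NDVI from it. It assumes the frame carries per-band reflectance columns named after the requested bands ("B4", "B5") - that column layout is inferred from the print statements above, not a documented contract.

# Sketch only, not part of the library: derive NDVI per row from the
# GeoDataFrame returned above. Assumes "B4" (red) and "B5" (NIR) columns
# exist because those bands were requested; the layout is an assumption.
df["ndvi"] = (df["B5"] - df["B4"]) / (df["B5"] + df["B4"])

# Average NDVI per acquisition date (the datetime column is shown above)
daily_ndvi = df.groupby(df.datetime.dt.date)["ndvi"].mean()
print(daily_ndvi)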
104 changes: 104 additions & 0 deletions examples/basic_workflow_xarray.py
@@ -0,0 +1,104 @@
# examples/basic_workflow_xarray.py
from pathlib import Path
from shapely.geometry import Polygon
import xarray as xr

from rasteret import Rasteret
from rasteret.constants import DataSources
from rasteret.core.utils import save_per_geometry


def main():
    """Demonstrate an NDVI workflow with Rasteret and xarray."""
    # 1. Define parameters
    custom_name = "bangalore"
    date_range = ("2024-01-01", "2024-01-31")
    data_source = DataSources.LANDSAT  # or DataSources.SENTINEL2

    workspace_dir = Path.home() / "rasteret_workspace"
    workspace_dir.mkdir(exist_ok=True)

    print("1. Defining Area of Interest")
    print("--------------------------")

    # Define areas and time of interest
    aoi1_polygon = Polygon(
        [(77.55, 13.01), (77.58, 13.01), (77.58, 13.08), (77.55, 13.08), (77.55, 13.01)]
    )

    aoi2_polygon = Polygon(
        [(77.56, 13.02), (77.59, 13.02), (77.59, 13.09), (77.56, 13.09), (77.56, 13.02)]
    )

    # Get the total bounds of all polygons above for STAC search and index creation
    bbox = aoi1_polygon.union(aoi2_polygon).bounds

    print("\n2. Creating and Loading Collection")
    print("--------------------------")

    # 2. Initialize processor - the collection name is derived from custom_name
    processor = Rasteret(
        custom_name=custom_name,
        data_source=data_source,
        output_dir=workspace_dir,
        date_range=date_range,
    )

    # Create the STAC index if it does not exist yet
    if processor._collection is None:
        processor.create_index(
            bbox=bbox,
            date_range=date_range,
            cloud_cover_lt=20,
            # Platform filter (e.g. Landsat 9, 8, 7, 5, 4) - remove it to
            # include all platforms. This is unique to the Landsat STAC endpoint.
            platform={"in": ["LANDSAT_8"]},
        )

    # List existing collections
    collections = Rasteret.list_collections(dir=workspace_dir)
    print("Available collections:")
    for c in collections:
        print(f"- {c['name']}: {c['size']} scenes")

    print("\n3. Processing Data")
    print("----------------")

    # Calculate NDVI using xarray operations
    ds = processor.get_xarray(
        # Pass the individual geometries (not their union bounds)
        # so that each geometry is processed separately
        geometries=[aoi1_polygon, aoi2_polygon],
        bands=["B4", "B5"],
        cloud_cover_lt=20,
    )

    print("\nInput dataset:")
    print(ds)

    # Calculate NDVI and preserve metadata
    ndvi = (ds.B5 - ds.B4) / (ds.B5 + ds.B4)
    ndvi_ds = xr.Dataset(
        {"NDVI": ndvi},
        coords=ds.coords,  # Preserve coordinates, including CRS
        attrs=ds.attrs,  # Preserve metadata
    )

    print("\nNDVI dataset:")
    print(ndvi_ds)

    # Create output directory
    output_dir = Path("ndvi_results")
    output_dir.mkdir(exist_ok=True)

    # Save one file per geometry, using "ndvi" as the output filename prefix
    output_files = save_per_geometry(ndvi_ds, output_dir, prefix="ndvi")

    print("\nProcessed NDVI files:")
    for geom_id, filepath in output_files.items():
        print(f"Geometry {geom_id}: {filepath}")


if __name__ == "__main__":
    main()
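As a quick sanity check on the NDVI formula used above, (B5 - B4) / (B5 + B4), here is a standalone sketch with made-up sample values - numpy only, independent of rasteret:

# Standalone sketch of the NDVI arithmetic above; sample values are made up.
import numpy as np

nir = np.array([0.6, 0.3])   # B5 (near-infrared): vegetation, bare soil
red = np.array([0.1, 0.25])  # B4 (red)
ndvi = (nir - red) / (nir + red)
print(ndvi)  # ~[0.71, 0.09] - vegetation scores high, bare soil near zero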
21 changes: 21 additions & 0 deletions src/rasteret/__init__.py
@@ -0,0 +1,21 @@
"""Rasteret package."""

from importlib.metadata import version as get_version

from rasteret.core.processor import Rasteret
from rasteret.core.collection import Collection
from rasteret.cloud import CloudConfig, AWSProvider
from rasteret.constants import DataSources
from rasteret.logging import setup_logger

# Set up logging
setup_logger("INFO")


def version():
"""Return the version of the rasteret package."""
return get_version("rasteret")

__version__ = version()

__all__ = ["Collection", "Rasteret", "CloudConfig", "AWSProvider", "DataSources"]
118 changes: 118 additions & 0 deletions src/rasteret/cloud.py
@@ -0,0 +1,118 @@
""" Utilities for cloud storage """

from dataclasses import dataclass
from typing import Optional, Dict
import boto3
from rasteret.logging import setup_logger

logger = setup_logger()


@dataclass
class CloudConfig:
"""Storage configuration for data source"""

provider: str
requester_pays: bool = False
region: str = "us-west-2"
url_patterns: Dict[str, str] = None # Map HTTPS patterns to cloud URLs


# Configuration for supported data sources
CLOUD_CONFIG = {
"landsat-c2l2-sr": CloudConfig(
provider="aws",
requester_pays=True,
region="us-west-2",
url_patterns={"https://landsatlook.usgs.gov/data/": "s3://usgs-landsat/"},
),
"sentinel-2-l2a": CloudConfig(
provider="aws", requester_pays=False, region="us-west-2"
),
}


class CloudProvider:
"""Base class for cloud providers"""

@staticmethod
def check_aws_credentials() -> bool:
"""Check AWS credentials before any operations"""
try:
session = boto3.Session()
credentials = session.get_credentials()
if credentials is None:
logger.error(
"\nAWS credentials not found. To configure:\n"
"1. Create ~/.aws/credentials with:\n"
"[default]\n"
"aws_access_key_id = YOUR_ACCESS_KEY\n"
"aws_secret_access_key = YOUR_SECRET_KEY\n"
"OR\n"
"2. Set environment variables:\n"
"export AWS_ACCESS_KEY_ID='your_key'\n"
"export AWS_SECRET_ACCESS_KEY='your_secret'"
)
return False
return True
except Exception:
return False

def get_url(self, url: str, config: CloudConfig) -> str:
"""Central URL resolution and signing method"""
raise NotImplementedError


class AWSProvider(CloudProvider):
def __init__(self, profile: Optional[str] = None, region: str = "us-west-2"):
if not self.check_aws_credentials():
raise ValueError("AWS credentials not configured")

try:
session = (
boto3.Session(profile_name=profile) if profile else boto3.Session()
)
self.client = session.client("s3", region_name=region)
except Exception as e:
logger.error(f"Failed to initialize AWS client: {str(e)}")
raise ValueError("AWS provider initialization failed")

def get_url(self, url: str, config: CloudConfig) -> Optional[str]:
"""Resolve and sign URL based on configuration"""
# First check for alternate S3 URL in STAC metadata
if isinstance(url, dict) and "alternate" in url and "s3" in url["alternate"]:
s3_url = url["alternate"]["s3"]["href"]
logger.debug(f"Using alternate S3 URL: {s3_url}")
url = s3_url
# Then check URL patterns if defined
elif config.url_patterns:
for http_pattern, s3_pattern in config.url_patterns.items():
if url.startswith(http_pattern):
url = url.replace(http_pattern, s3_pattern)
logger.debug(f"Converted to S3 URL: {url}")
break

# Sign URL if it's an S3 URL
if url.startswith("s3://"):
try:
bucket = url.split("/")[2]
key = "/".join(url.split("/")[3:])

params = {
"Bucket": bucket,
"Key": key,
}
if config.requester_pays:
params["RequestPayer"] = "requester"

return self.client.generate_presigned_url(
"get_object", Params=params, ExpiresIn=3600
)
except Exception as e:
logger.error(f"Failed to sign URL {url}: {str(e)}")
return None

return url


__all__ = ["CloudConfig", "AWSProvider"]