Skip to content

Commit

Permalink
Merge branch 'geo-ai-hack' into 26-fixed-positional-embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
Alikerin committed Feb 6, 2025
2 parents a8a1a1c + 6eb5d26 commit 2f9f76e
Showing 1 changed file with 38 additions and 10 deletions.
48 changes: 38 additions & 10 deletions instageo/model/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,53 @@
import os
import random
from functools import partial
from typing import Callable, List, Tuple
from typing import Any, Callable, List, Tuple

import numpy as np
import pandas as pd
import rasterio
import torch
import xarray as xr
from absl import logging
from PIL import Image
from rasterio.crs import CRS
from torchvision import transforms

from instageo.data.hls_utils import open_mf_tiff_dataset

def open_mf_tiff_dataset(
band_files: dict[str, Any], load_masks: bool
) -> tuple[xr.Dataset, xr.Dataset | None, CRS]:
"""Open multiple TIFF files as an xarray Dataset.
Args:
band_files (Dict[str, Dict[str, str]]): A dictionary mapping band names to file paths.
load_masks (bool): Whether or not to load the masks files.
Returns:
(xr.Dataset, xr.Dataset | None, CRS): A tuple of xarray Dataset combining data from all the
provided TIFF files, (optionally) the masks, and the CRS
"""
band_paths = list(band_files["tiles"].values())
bands_dataset = xr.open_mfdataset(
band_paths,
concat_dim="band",
combine="nested",
mask_and_scale=False, # Scaling will be applied manually
)
bands_dataset.band_data.attrs["scale_factor"] = 1
mask_paths = list(band_files["fmasks"].values())
mask_dataset = (
xr.open_mfdataset(
mask_paths,
concat_dim="band",
combine="nested",
)
if load_masks
else None
)
with rasterio.open(band_paths[0]) as src:
crs = src.crs
return bands_dataset, mask_dataset, crs


def random_crop_and_flip(
Expand Down Expand Up @@ -246,14 +282,6 @@ def get_raster_data(
data = src.read()
if (not is_label) and bands:
data = data[bands, ...]
# For some reasons, some few HLS tiles are not scaled in v2.0.
# In the following lines, we find and scale them
bands = []
for band in data:
if band.max() > 10:
band *= 0.0001
bands.append(band)
data = np.stack(bands, axis=0)
return data


Expand Down

0 comments on commit 2f9f76e

Please sign in to comment.