Skip to content

Commit

Permalink
move functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Ines Elmufti authored and Ines Elmufti committed Feb 9, 2025
1 parent 76c425c commit d5293d9
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 121 deletions.
97 changes: 2 additions & 95 deletions instageo/apps/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import matplotlib.cm
from pathlib import Path
from pyproj import CRS, Transformer
from instageo.apps.viz import create_map_with_geotiff_tiles

# Initialize Dash App
app = dash.Dash(__name__)
server = app.server # Needed for deployment
Expand Down Expand Up @@ -41,101 +43,7 @@ def load_tile_metadata(json_path: str) -> list:
with open(json_path, "r") as f:
return json.load(f)

def is_tile_in_viewport(tile_bounds: dict, viewport: dict, zoom: float) -> bool:
"""Check if a tile is within the current viewport."""
lat_min, lat_max = viewport['latitude']['min'], viewport['latitude']['max']
lon_min, lon_max = viewport['longitude']['min'], viewport['longitude']['max']
lat_min -= 1 - math.exp(-0.1 * zoom)
lat_max += 1 - math.exp(-0.1 * zoom)
lon_min -= 1 - math.exp(-0.1 * zoom)
lon_max += 1 - math.exp(-0.1 * zoom)
tile_lat_min, tile_lat_max = tile_bounds['lat_min'], tile_bounds['lat_max']
tile_lon_min, tile_lon_max = tile_bounds['lon_min'], tile_bounds['lon_max']
return not (tile_lat_max < lat_min or tile_lat_min > lat_max or
tile_lon_max < lon_min or tile_lon_min > lon_max)

@lru_cache(maxsize = 8)
def read_geotiff_to_xarray(filepath: str) -> tuple[xr.Dataset, CRS]:
"""Read GeoTIFF file into an xarray Dataset."""
xarr_dataset = xr.open_dataset(filepath).sel(band=1)
crs = rasterio.open(filepath).crs
return xarr_dataset, crs

def zoom_to_scale(zoom: float):
zoom_dict = {1:0.1,2:0.1,3:0.1,4:0.25,5:0.5,6:0.6,7:0.1,8:0.1}
zoom_ceiled = math.ceil(zoom)
print("zoom ciel",zoom_ceiled)
if zoom_ceiled in zoom_dict.keys():
scale = zoom_dict[zoom_ceiled]
else:
scale = 1.0
return scale

def create_map_with_geotiff_tiles(tile_metadata: list, viewport: dict, zoom: float, base_dir: str) -> go.Figure:
"""Create a map with multiple GeoTIFF tiles overlaid."""

fig = go.Figure(go.Scattermapbox())
fig.update_layout(
mapbox_style=MAP_STYLE if MAP_STYLE else "open-street-map" ,
mapbox=dict(
center=go.layout.mapbox.Center(
lat=(viewport["latitude"]["min"] + viewport["latitude"]["max"]) / 2,
lon=(viewport["longitude"]["min"] + viewport["longitude"]["max"]) / 2,
),
zoom=zoom,
),
margin={"r": 0, "t": 40, "l": 0, "b": 0},
)
mapbox_layers = []
for tile in tile_metadata:
if len(mapbox_layers) > 15:
break
if is_tile_in_viewport(tile['bounds'], viewport, zoom=zoom):
tile_path = os.path.join(base_dir, tile['name'])
xarr_dataset, crs = read_geotiff_to_xarray(tile_path)
scale = zoom_to_scale(zoom)
print("--zoom--", zoom)
print("----sclale",scale)
img, coordinates = add_raster_to_plotly_figure(xarr_dataset, crs, scale=scale)
mapbox_layers.append({"sourcetype": "image", "source": img, "coordinates": coordinates})
fig.update_layout(mapbox_layers=mapbox_layers)
return fig

def add_raster_to_plotly_figure(xarr_dataset: xr.Dataset, from_crs: CRS, scale: float = 1.0) -> tuple:
"""Convert raster data to an image and coordinates for Plotly."""
# Ensure the raster has the correct CRS
xarr_dataset = xarr_dataset.rio.write_crs(from_crs).rio.reproject("EPSG:3857")
xarr_dataset = xarr_dataset.where(xarr_dataset <= 1, 0) # Mask values <= 1

# Extract the variable containing raster data ('band_data' in this case)
band_data = xarr_dataset['band_data']

numpy_data = band_data.squeeze().to_numpy() # Ensure the array is 2D
plot_height, plot_width = numpy_data.shape

canvas = ds.Canvas(plot_width=int(plot_width * scale), plot_height=int(plot_height * scale))

# Use 'band_data' to aggregate
agg = canvas.raster(band_data, interpolate="linear") # Specify the variable to aggregate

# Calculate coordinates for the image
coords_lat_min, coords_lat_max = agg.coords["y"].values.min(), agg.coords["y"].values.max()
coords_lon_min, coords_lon_max = agg.coords["x"].values.min(), agg.coords["x"].values.max()

(coords_lon_min, coords_lon_max), (coords_lat_min, coords_lat_max) = epsg3857_to_epsg4326.transform(
[coords_lon_min, coords_lon_max], [coords_lat_min, coords_lat_max]
)

coordinates = [
[coords_lon_min, coords_lat_max],
[coords_lon_max, coords_lat_max],
[coords_lon_max, coords_lat_min],
[coords_lon_min, coords_lat_min],
]

# Generate the image using Datashader
img = tf.shade(agg, cmap=matplotlib.colormaps["Reds"], alpha=100, how="linear")[::-1].to_pil()
return img, coordinates


# Layout with Sidebar
Expand Down Expand Up @@ -231,7 +139,6 @@ def add_raster_to_plotly_figure(xarr_dataset: xr.Dataset, from_crs: CRS, scale:
)
def update_map(relayout_data, directory, year, month, current_viewport,dimensions):
"""Update the map based on viewport, zoom, and directory selection."""
print(dimensions)
if not directory or not Path(directory).is_dir():
fig = go.Figure(go.Scattermapbox())
fig.update_layout(
Expand Down
117 changes: 91 additions & 26 deletions instageo/apps/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
# ------------------------------------------------------------------------------

"""Utils for Raster Visualisation."""

from functools import lru_cache
import math
import os
import datashader as ds
import datashader.transfer_functions as tf
import matplotlib.cm
Expand All @@ -29,6 +31,7 @@

epsg3857_to_epsg4326 = Transformer.from_crs(3857, 4326, always_xy=True)

MAP_STYLE="https://tiles.stadiamaps.com/styles/alidade_smooth.json"

def get_crs(filepath: str) -> CRS:
"""Retrieves the CRS of a GeoTiff data.
Expand Down Expand Up @@ -116,6 +119,72 @@ def add_raster_to_plotly_figure(
)[::-1].to_pil()
return img, coordinates

def is_tile_in_viewport(tile_bounds: dict, viewport: dict, zoom: float) -> bool:
"""Check if a tile is within the current viewport."""
lat_min, lat_max = viewport['latitude']['min'], viewport['latitude']['max']
lon_min, lon_max = viewport['longitude']['min'], viewport['longitude']['max']
lat_min -= 1 - math.exp(-0.1 * zoom)
lat_max += 1 - math.exp(-0.1 * zoom)
lon_min -= 1 - math.exp(-0.1 * zoom)
lon_max += 1 - math.exp(-0.1 * zoom)
tile_lat_min, tile_lat_max = tile_bounds['lat_min'], tile_bounds['lat_max']
tile_lon_min, tile_lon_max = tile_bounds['lon_min'], tile_bounds['lon_max']
return not (tile_lat_max < lat_min or tile_lat_min > lat_max or
tile_lon_max < lon_min or tile_lon_min > lon_max)

@lru_cache(maxsize = 8)
def read_geotiff_to_xarray(filepath: str) -> tuple[xr.Dataset, CRS]:
"""Read GeoTIFF file into an xarray Dataset."""
xarr_dataset = xr.open_dataset(filepath).sel(band=1)
crs = rasterio.open(filepath).crs
return xarr_dataset, crs

def zoom_to_scale(zoom: float):
zoom_dict = {1:0.1,2:0.1,3:0.1,4:0.25,5:0.5,6:0.6,7:0.1,8:0.1}
zoom_ceiled = math.ceil(zoom)
print("zoom ciel",zoom_ceiled)
if zoom_ceiled in zoom_dict.keys():
scale = zoom_dict[zoom_ceiled]
else:
scale = 1.0
return scale


def add_raster_to_plotly_figure(xarr_dataset: xr.Dataset, from_crs: CRS, scale: float = 1.0) -> tuple:
"""Convert raster data to an image and coordinates for Plotly."""
# Ensure the raster has the correct CRS
xarr_dataset = xarr_dataset.rio.write_crs(from_crs).rio.reproject("EPSG:3857")
xarr_dataset = xarr_dataset.where(xarr_dataset <= 1, 0) # Mask values <= 1

# Extract the variable containing raster data ('band_data' in this case)
band_data = xarr_dataset['band_data']

numpy_data = band_data.squeeze().to_numpy() # Ensure the array is 2D
plot_height, plot_width = numpy_data.shape

canvas = ds.Canvas(plot_width=int(plot_width * scale), plot_height=int(plot_height * scale))

# Use 'band_data' to aggregate
agg = canvas.raster(band_data, interpolate="linear") # Specify the variable to aggregate

# Calculate coordinates for the image
coords_lat_min, coords_lat_max = agg.coords["y"].values.min(), agg.coords["y"].values.max()
coords_lon_min, coords_lon_max = agg.coords["x"].values.min(), agg.coords["x"].values.max()

(coords_lon_min, coords_lon_max), (coords_lat_min, coords_lat_max) = epsg3857_to_epsg4326.transform(
[coords_lon_min, coords_lon_max], [coords_lat_min, coords_lat_max]
)

coordinates = [
[coords_lon_min, coords_lat_max],
[coords_lon_max, coords_lat_max],
[coords_lon_max, coords_lat_min],
[coords_lon_min, coords_lat_min],
]

# Generate the image using Datashader
img = tf.shade(agg, cmap=matplotlib.colormaps["Reds"], alpha=100, how="linear")[::-1].to_pil()
return img, coordinates

def read_geotiff_to_xarray(filepath: str) -> tuple[xr.Dataset, CRS]:
"""Read a GeoTIFF file into an xarray Dataset.
Expand All @@ -129,34 +198,30 @@ def read_geotiff_to_xarray(filepath: str) -> tuple[xr.Dataset, CRS]:
return xr.open_dataset(filepath).sel(band=1), get_crs(filepath)


def create_map_with_geotiff_tiles(tiles_to_overlay: list[str]) -> go.Figure:
"""Create a map with multiple GeoTIFF tiles overlaid.
This function reads GeoTIFF files from a specified directory and overlays them on a
Plotly map.
Args:
tiles_to_overlay (list[str]): Path to tiles to overlay on map.
Returns:
Figure: A Plotly figure with overlaid GeoTIFF tiles.
"""
def create_map_with_geotiff_tiles(tile_metadata: list, viewport: dict, zoom: float, base_dir: str) -> go.Figure:
"""Create a map with multiple GeoTIFF tiles overlaid."""

fig = go.Figure(go.Scattermapbox())
fig.update_layout(
mapbox_style="open-street-map",
mapbox=dict(center=go.layout.mapbox.Center(lat=0, lon=20), zoom=2.0),
mapbox_style=MAP_STYLE if MAP_STYLE else "open-street-map" ,
mapbox=dict(
center=go.layout.mapbox.Center(
lat=(viewport["latitude"]["min"] + viewport["latitude"]["max"]) / 2,
lon=(viewport["longitude"]["min"] + viewport["longitude"]["max"]) / 2,
),
zoom=zoom,
),
margin={"r": 0, "t": 40, "l": 0, "b": 0},
)
fig.update_layout(margin={"r": 0, "t": 40, "l": 0, "b": 0})
mapbox_layers = []
for tile in tiles_to_overlay:
if tile.endswith(".tif") or tile.endswith(".tiff"):
xarr_dataset, crs = read_geotiff_to_xarray(tile)
img, coordinates = add_raster_to_plotly_figure(
xarr_dataset, crs, "band_data", scale=1.0
)
mapbox_layers.append(
{"sourcetype": "image", "source": img, "coordinates": coordinates}
)
# Overlay the resulting image
for tile in tile_metadata:
if len(mapbox_layers) > 15:
break
if is_tile_in_viewport(tile['bounds'], viewport, zoom=zoom):
tile_path = os.path.join(base_dir, tile['name'])
xarr_dataset, crs = read_geotiff_to_xarray(tile_path)
scale = zoom_to_scale(zoom)
img, coordinates = add_raster_to_plotly_figure(xarr_dataset, crs, scale=scale)
mapbox_layers.append({"sourcetype": "image", "source": img, "coordinates": coordinates})
fig.update_layout(mapbox_layers=mapbox_layers)
return fig

0 comments on commit d5293d9

Please sign in to comment.