Skip to content

Commit

Permalink
optimize_tile_loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Ines Elmufti authored and Ines Elmufti committed Feb 9, 2025
1 parent fb475e7 commit dbc3ca9
Show file tree
Hide file tree
Showing 2 changed files with 19,226 additions and 103 deletions.
283 changes: 180 additions & 103 deletions instageo/apps/app.py
Original file line number Diff line number Diff line change
@@ -1,110 +1,187 @@
# ------------------------------------------------------------------------------
# This code is licensed under the Attribution-NonCommercial-ShareAlike 4.0
# International (CC BY-NC-SA 4.0) License.
#
# You are free to:
# - Share: Copy and redistribute the material in any medium or format
# - Adapt: Remix, transform, and build upon the material
#
# Under the following terms:
# - Attribution: You must give appropriate credit, provide a link to the license,
# and indicate if changes were made. You may do so in any reasonable manner,
# but not in any way that suggests the licensor endorses you or your use.
# - NonCommercial: You may not use the material for commercial purposes.
# - ShareAlike: If you remix, transform, or build upon the material, you must
# distribute your contributions under the same license as the original.
#
# For more details, see https://creativecommons.org/licenses/by-nc-sa/4.0/
# ------------------------------------------------------------------------------

"""InstaGeo Serve Module.
InstaGeo Serve is a web application that enables the visualisation of GeoTIFF files in an
interactive map.
"""

import glob
import json
import os
import json
import dash
import plotly.graph_objects as go
from dash import dcc, html
from dash.dependencies import Input, Output, State
import rasterio
import xarray as xr
import datashader as ds
import datashader.transfer_functions as tf
import matplotlib.cm
from pathlib import Path

import streamlit as st

from instageo import INSTAGEO_APPS_PATH
from instageo.apps.viz import create_map_with_geotiff_tiles


def generate_map(
directory: str, year: int, month: int, country_tiles: list[str]
) -> None:
"""Generate the plotly map.
Arguments:
directory (str): Directory containing GeoTiff files.
year (int): Selected year.
month (int): Selected month formatted as an integer in the range 1-12.
country_tiles (list[str]): List of MGRS tiles for the selected country.
Returns:
None.
"""
try:
if not directory or not Path(directory).is_dir():
raise ValueError("Invalid directory path.")

prediction_tiles = glob.glob(os.path.join(directory, f"{year}/{month}/*.tif"))
tiles_to_consider = [
tile
for tile in prediction_tiles
if os.path.basename(tile).split("_")[1].strip("T") in country_tiles
]

if not tiles_to_consider:
raise FileNotFoundError(
"No GeoTIFF files found for the given year, month, and country."
)

fig = create_map_with_geotiff_tiles(tiles_to_consider)
st.plotly_chart(fig, use_container_width=True)
except (ValueError, FileNotFoundError, Exception) as e:
st.error(f"An error occurred: {str(e)}")


def main() -> None:
"""Instageo Serve Main Entry Point."""
st.set_page_config(layout="wide")
st.title("InstaGeo Serve")

st.sidebar.subheader(
"This application enables the visualisation of GeoTIFF files on an interactive map.",
divider="rainbow",
from pyproj import CRS, Transformer
# Initialize Dash App
app = dash.Dash(__name__)
server = app.server # Needed for deployment

# Transformer for coordinate conversion
epsg3857_to_epsg4326 = Transformer.from_crs(3857, 4326, always_xy=True)
relayoutData = 10
# Default viewport
default_viewport = {
"latitude": {"min": -2.91, "max": -1.13},
"longitude": {"min": 29.02, "max": 30.81},
}
default_zoom = 5.0
MAP_STYLE="https://tiles.stadiamaps.com/styles/alidade_smooth.json"
def load_tile_metadata(json_path: str) -> list:
"""Load tile metadata from a JSON file."""
with open(json_path, "r") as f:
return json.load(f)

def is_tile_in_viewport(tile_bounds: dict, viewport: dict, zoom: float) -> bool:
"""Check if a tile is within the current viewport."""
lat_min, lat_max = viewport['latitude']['min'], viewport['latitude']['max']
lon_min, lon_max = viewport['longitude']['min'], viewport['longitude']['max']
lat_min -= 1/zoom
lat_max += 1/zoom
lon_min -= 1/zoom
lon_max += 1/zoom
tile_lat_min, tile_lat_max = tile_bounds['lat_min'], tile_bounds['lat_max']
tile_lon_min, tile_lon_max = tile_bounds['lon_min'], tile_bounds['lon_max']
return not (tile_lat_max < lat_min or tile_lat_min > lat_max or
tile_lon_max < lon_min or tile_lon_min > lon_max)

def read_geotiff_to_xarray(filepath: str) -> tuple[xr.Dataset, CRS]:
"""Read GeoTIFF file into an xarray Dataset."""
xarr_dataset = xr.open_dataset(filepath).sel(band=1)
crs = rasterio.open(filepath).crs
return xarr_dataset, crs

def create_map_with_geotiff_tiles(tile_metadata: list, viewport: dict, zoom: float, base_dir: str) -> go.Figure:
"""Create a map with multiple GeoTIFF tiles overlaid."""

fig = go.Figure(go.Scattermapbox())
fig.update_layout(
mapbox_style=MAP_STYLE if MAP_STYLE else "open-street-map" ,
mapbox=dict(
center=go.layout.mapbox.Center(
lat=(viewport["latitude"]["min"] + viewport["latitude"]["max"]) / 2,
lon=(viewport["longitude"]["min"] + viewport["longitude"]["max"]) / 2,
),
zoom=zoom,
),
margin={"r": 0, "t": 40, "l": 0, "b": 0},
)
print(viewport)
mapbox_layers = []
for tile in tile_metadata:
if is_tile_in_viewport(tile['bounds'], viewport, zoom=zoom):
tile_path = os.path.join(base_dir, tile['name'])
print(tile_path)
xarr_dataset, crs = read_geotiff_to_xarray(tile_path)
img, coordinates = add_raster_to_plotly_figure(xarr_dataset, crs, scale=1.0 if zoom > 8 else 0.5)
mapbox_layers.append({"sourcetype": "image", "source": img, "coordinates": coordinates})
fig.update_layout(mapbox_layers=mapbox_layers)
return fig

def add_raster_to_plotly_figure(xarr_dataset: xr.Dataset, from_crs: CRS, scale: float = 1.0) -> tuple:
"""Convert raster data to an image and coordinates for Plotly."""
# Ensure the raster has the correct CRS
xarr_dataset = xarr_dataset.rio.write_crs(from_crs).rio.reproject("EPSG:3857")
xarr_dataset = xarr_dataset.where(xarr_dataset <= 1, 0) # Mask values <= 1

# Extract the variable containing raster data ('band_data' in this case)
band_data = xarr_dataset['band_data']

numpy_data = band_data.squeeze().to_numpy() # Ensure the array is 2D
plot_height, plot_width = numpy_data.shape

canvas = ds.Canvas(plot_width=int(plot_width * scale), plot_height=int(plot_height * scale))

# Use 'band_data' to aggregate
agg = canvas.raster(band_data, interpolate="linear") # Specify the variable to aggregate

# Calculate coordinates for the image
coords_lat_min, coords_lat_max = agg.coords["y"].values.min(), agg.coords["y"].values.max()
coords_lon_min, coords_lon_max = agg.coords["x"].values.min(), agg.coords["x"].values.max()

(coords_lon_min, coords_lon_max), (coords_lat_min, coords_lat_max) = epsg3857_to_epsg4326.transform(
[coords_lon_min, coords_lon_max], [coords_lat_min, coords_lat_max]
)
st.sidebar.header("Settings")
with open(
INSTAGEO_APPS_PATH / "utils/country_code_to_mgrs_tiles.json"
) as json_file:
countries_to_tiles_map = json.load(json_file)

with st.sidebar.container():
directory = st.sidebar.text_input(
"GeoTiff Directory:",
help="Write the path to the directory containing your GeoTIFF files",
coordinates = [
[coords_lon_min, coords_lat_max],
[coords_lon_max, coords_lat_max],
[coords_lon_max, coords_lat_min],
[coords_lon_min, coords_lat_min],
]

# Generate the image using Datashader
img = tf.shade(agg, cmap=matplotlib.colormaps["Reds"], alpha=100, how="linear")[::-1].to_pil()
return img, coordinates


# Layout with Sidebar
app.layout = html.Div([
dcc.Store(id="stored-viewport", data=default_viewport),

html.Div([
html.H2("Settings", style={"textAlign": "center", "padding": "10px", "color": "white"}),
html.Label("GeoTIFF Directory:", style={"color": "white"}),
dcc.Input(id="directory", type="text", placeholder="Enter directory path",
style={"width": "97%", "marginBottom": "10px", "padding": "5px", "borderRadius": "5px"}),
html.Label("Year:", style={"color": "white"}),
dcc.Dropdown(id="year", options=[{"label": str(y), "value": y} for y in range(2023, 2025)],
value=2023, style={"width": "100%","marginBottom": "10px", "borderRadius": "5px"}),
html.Label("Month:", style={"color": "white"}),
dcc.Dropdown(id="month", options=[{"label": str(m), "value": m} for m in range(1, 13)],
value=6, style={"marginBottom": "10px"}),
], style={"width": "20%", "backgroundColor": "#343a40", "position": "fixed", "height": "100vh",
"padding": "20px", "top": 0, "left": 0, "boxShadow": "2px 0px 10px rgba(0,0,0,0.2)", "borderRadius": "0px 10px 10px 0px"}),

html.Div([
html.H1("Dash GeoTIFF Viewer", style={"textAlign": "center", "marginTop": "20px", "color": "#343a40"}),
dcc.Graph(id="map", config={"scrollZoom": True}, style={"width": "78vw", "height": "90vh", "borderRadius": "10px", "boxShadow": "0px 4px 10px rgba(0,0,0,0.2)"}),
], style={"marginLeft": "22%", "padding": "20px", "backgroundColor": "#f8f9fa", "height": "100vh"}),
])

tile_metadata_path = "instageo/apps/tile_metadata.json"
tile_metadata = load_tile_metadata(tile_metadata_path)
# Callback to update the map
@app.callback(
Output("map", "figure"),
Input("map", "relayoutData"),
Input("directory", "value"),
Input("year", "value"),
Input("month", "value"),
State("stored-viewport", "data"),
)
def update_map(relayout_data, directory, year, month, current_viewport):
"""Update the map based on viewport, zoom, and directory selection."""
if not directory or not Path(directory).is_dir():
fig = go.Figure(go.Scattermapbox())
fig.update_layout(
mapbox_style=MAP_STYLE if MAP_STYLE else "open-street-map" ,
mapbox=dict(
center=go.layout.mapbox.Center(
lat=(current_viewport["latitude"]["min"] + current_viewport["latitude"]["max"]) / 2,
lon=(current_viewport["longitude"]["min"] + current_viewport["longitude"]["max"]) / 2,
),
zoom=default_zoom,
),
margin={"r": 0, "t": 40, "l": 0, "b": 0},
)
country_code = st.sidebar.selectbox(
"ISO 3166-1 Alpha-2 Country Code:",
options=list(countries_to_tiles_map.keys()),
)
year = st.sidebar.number_input("Select Year", 2023, 2024)
month = st.sidebar.number_input("Select Month", 1, 12)

if st.sidebar.button("Generate Map"):
country_tiles = countries_to_tiles_map[country_code]
generate_map(directory, year, month, country_tiles)
return fig

zoom = relayout_data.get("mapbox.zoom", default_zoom)
base_dir = os.path.join(directory, f"{year}/{month}")

if relayout_data and "mapbox.center" in relayout_data:
relayout_data
new_viewport = {
"latitude": {
"min": relayout_data["mapbox.center"]["lat"] - 0.1,
"max": relayout_data["mapbox.center"]["lat"] + 0.1,
},
"longitude": {
"min": relayout_data["mapbox.center"]["lon"] - 0.1,
"max": relayout_data["mapbox.center"]["lon"] + 0.1,
},
}
else:
st.plotly_chart(
create_map_with_geotiff_tiles(tiles_to_overlay=[]), use_container_width=True
)


new_viewport = current_viewport
return create_map_with_geotiff_tiles(tile_metadata, new_viewport, zoom, base_dir)
if __name__ == "__main__":
main()
app.run_server(debug=True)
Loading

0 comments on commit dbc3ca9

Please sign in to comment.