diff --git a/geemap/common.py b/geemap/common.py index baaf9188ef..2316bd1b83 100644 --- a/geemap/common.py +++ b/geemap/common.py @@ -24,6 +24,7 @@ import ee import ipywidgets as widgets from ipytree import Node, Tree +from typing import Union, List, Dict, Optional, Tuple try: from IPython.display import display, IFrame, Javascript @@ -15681,3 +15682,88 @@ def geotiff_to_image(image: str, output: str) -> None: # Save the image as a JPEG file image.save(output) + + +def xarray_to_image( + xds, + filenames: Optional[Union[str, List[str]]] = None, + out_dir: Optional[str] = None, + crs: Optional[str] = None, + driver: str = "COG", + time_unit: str = "D", + quiet: bool = False, + **kwargs, +) -> None: + """ + Convert xarray Dataset to georeferenced images. + + Args: + xds (xr.Dataset): The xarray Dataset to convert to images. + filenames (Union[str, List[str]], optional): Output filenames for the images. + If a single string is provided, it will be used as the filename for all images. + If a list of strings is provided, the filenames will be used in order. Defaults to None. + out_dir (str, optional): Output directory for the images. Defaults to current working directory. + crs (str, optional): Coordinate reference system (CRS) of the output images. + If not provided, the CRS is inferred from the Dataset's attributes ('crs' attribute) or set to 'EPSG:4326'. + driver (str, optional): Driver used for writing the output images, such as 'GTiff'. Defaults to "COG". + time_unit (str, optional): Time unit used for generating default filenames. Defaults to 'D'. + quiet (bool, optional): If True, suppresses progress messages. Defaults to False. + **kwargs: Additional keyword arguments passed to rioxarray's `rio.to_raster()` function. + + Returns: + None + + Raises: + ValueError: If the number of filenames doesn't match the number of time steps in the Dataset. + + """ + import numpy as np + + try: + import rioxarray + except ImportError: + install_package("rioxarray") + import rioxarray + + if crs is None and "crs" in xds.attrs: + crs = xds.attrs["crs"] + if crs is None: + crs = "EPSG:4326" + + if out_dir is None: + out_dir = os.getcwd() + + if not os.path.exists(out_dir): + os.makedirs(out_dir) + + if isinstance(filenames, str): + filenames = [filenames] + if isinstance(filenames, list): + if len(filenames) != len(xds.time): + raise ValueError( + "The number of filenames must match the number of time steps" + ) + + coords = [coord for coord in xds.coords] + x_dim = coords[1] + y_dim = coords[2] + + for index, time in enumerate(xds.time.values): + if not quiet: + print(f"Processing {index + 1}/{len(xds.time.values)}: {time}") + image = xds.sel(time=time) + # transform the image to suit rioxarray format + image = ( + image.rename({y_dim: "y", x_dim: "x"}) + .transpose("y", "x") + .rio.write_crs(crs) + ) + + if filenames is None: + date = np.datetime_as_string(time, unit=time_unit) + filename = f"{date}.tif" + else: + filename = filenames.pop() + + output_path = os.path.join(out_dir, filename) + image.rio.to_raster(output_path, driver=driver, **kwargs)