From fb42ed095154ec2338fdef749d92a847642e740e Mon Sep 17 00:00:00 2001
From: Benjamin Schmidt
Date: Thu, 11 Jul 2024 11:50:13 +0200
Subject: [PATCH] feat(types): Add type hints and code optimizations

Python 3.8 is almost end-of-life, and backward compatibility with
versions that predate type hints is no longer a priority. This commit
is a start and by no means complete or correct for all instances.

Furthermore, I included code optimizations. These mainly concern the
raising of exceptions, pytest-style tests, and some minor refactoring.
---
 pyproject.toml               |    2 +-
 salem/__init__.py            |   89 +--
 salem/datasets.py            |  545 ++++++++++-------
 salem/descartes.py           |   27 +-
 salem/gis.py                 |  961 ++++++++++++++++-------
 salem/graphics.py            |  748 ++++++++++++++----------
 salem/sio.py                 |  826 +++++++++++++++++---------
 salem/tests/__init__.py      |   63 +-
 salem/tests/test_datasets.py |  442 +++++++-------
 salem/tests/test_gis.py      | 1065 +++++++++++++++++++++-------------
 salem/tests/test_graphics.py |  576 ++++++++++--------
 salem/tests/test_misc.py     |  607 ++++++++++---------
 salem/utils.py               |  207 +++++--
 salem/version.py             |   10 +-
 salem/wrftools.py            |  272 +++++----
 setup.py                     |    1 +
 16 files changed, 3892 insertions(+), 2549 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 7cb0f76..ece6630 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,4 +7,4 @@ requires = [
 ]
 
 [tool.setuptools_scm]
-fallback_version = "0.3.8"
+fallback_version = "0.3.8"
\ No newline at end of file
diff --git a/salem/__init__.py b/salem/__init__.py
index e6e6d93..e701c40 100644
--- a/salem/__init__.py
+++ b/salem/__init__.py
@@ -1,26 +1,22 @@
-"""
-Salem package
-"""
-from __future__ import division
+"""Salem package."""
 
-from os import path
-from os import makedirs
 import sys
+from collections.abc import Callable
 from functools import wraps
+from pathlib import Path
 
 import pyproj
 
 from .version import __version__
 
 
-def lazy_property(fn):
-    """Decorator that makes a property lazy-evaluated."""
-
+def lazy_property(fn: Callable) -> Callable:
+    """Lazy-evaluate a property (Decorator)."""
     attr_name = '_lazy_' + fn.__name__
 
     @property
     @wraps(fn)
-    def _lazy_property(self):
+    def _lazy_property(self: object) -> object:
         if not hasattr(self, attr_name):
             setattr(self, attr_name, fn(self))
         return getattr(self, attr_name)
@@ -32,16 +28,13 @@ def _lazy_property(self):
 wgs84 = pyproj.Proj(proj='latlong', datum='WGS84')
 
 # Path to the cache directory
-cache_dir = path.join(path.expanduser('~'), '.salem_cache')
-if not path.exists(cache_dir):
-    makedirs(cache_dir)
-download_dir = path.join(cache_dir, 'downloads')
-if not path.exists(download_dir):
-    makedirs(download_dir)
+cache_dir = Path.home() / '.salem_cache'
+cache_dir.mkdir(exist_ok=True)
+download_dir = cache_dir / 'downloads'
+download_dir.mkdir(exist_ok=True)
 
 sample_data_gh_commit = '454bf696324000d198f574a1bf5bc56e3e489051'
-sample_data_dir = path.join(cache_dir, 'salem-sample-data-' +
-                            sample_data_gh_commit)
+sample_data_dir = cache_dir / f'salem-sample-data-{sample_data_gh_commit}'
 
 # python version
 python_version = 'py3'
@@ -49,27 +42,57 @@ def _lazy_property(self):
     python_version = 'py2'
 
 # API
-from salem.gis import *
-from salem.datasets import *
-from salem.sio import read_shapefile, read_shapefile_to_grid, grid_from_dataset
-from salem.sio import (open_xr_dataset, open_metum_dataset,
-                       open_wrf_dataset, open_mf_wrf_dataset)
-from salem.sio import DataArrayAccessor, DatasetAccessor
+from salem.datasets import (
+    WRF,
+    EsriITMIX,
+    GeoDataset,
+    GeoNetcdf,
+    GeoTiff,
+    GoogleCenterMap,
+    GoogleVisibleMap,
+)
+from salem.gis import (
+    Grid,
+    check_crs,
+    googlestatic_mercator_grid,
+    mercator_grid,
+    proj_is_latlong,
+    proj_is_same,
+    proj_to_cartopy,
+    transform_geometry,
+    transform_geopandas,
+    transform_proj,
+)
+from salem.sio import (
+    DataArrayAccessor,
+    DatasetAccessor,
+    grid_from_dataset,
+    open_metum_dataset,
+    open_mf_wrf_dataset,
+    open_wrf_dataset,
+    open_xr_dataset,
+    read_shapefile,
+    read_shapefile_to_grid,
+)
 from salem.utils import get_demo_file, reduce
 
 try:
-    from salem.graphics import get_cmap, DataLevels, Map
+    from salem.graphics import DataLevels, Map, get_cmap
 except ImportError as err:
-    if 'matplotlib' not in str(err):
-        raise
+    if 'matplotlib' not in str(err):
+        raise
+
+    def get_cmap() -> None:
+        msg = 'requires matplotlib'
+        raise ImportError(msg)
 
-    def get_cmap():
-        raise ImportError('requires matplotlib')
+    def DataLevels() -> None:
+        msg = 'requires matplotlib'
+        raise ImportError(msg)
 
-    def DataLevels():
-        raise ImportError('requires matplotlib')
+    def Map() -> None:
+        msg = 'requires matplotlib'
+        raise ImportError(msg)
 
-    def Map():
-        raise ImportError('requires matplotlib')
 
 from salem.wrftools import geogrid_simulator
diff --git a/salem/datasets.py b/salem/datasets.py
index b5752ce..bd9b38c 100644
--- a/salem/datasets.py
+++ b/salem/datasets.py
@@ -1,38 +1,45 @@
-"""
-This module provides a GeoDataset interface and a few implementations for
+"""This module provides a GeoDataset interface and a few implementations for
 e.g. netcdf, geotiff, WRF...
 
 This is kept for backwards compatibility reasons, but ideally everything
 should soon happen at the xarray level.
 """
-from __future__ import division
 
 # Builtins
+from __future__ import annotations
+
 import io
 import os
 import warnings
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
 from urllib.request import urlopen
 
-# External libs
-import pyproj
-import numpy as np
 import netCDF4
+import numpy as np
 import pandas as pd
-import xarray as xr
-try:
-    import rasterio
-except ImportError:
-    rasterio = None
+
+# External libs
+import pyproj
+import xarray as xr
+from numpy._typing import NDArray
+from typing_extensions import Self
 
 # Locals
-from salem import lazy_property
-from salem import Grid
-from salem import wgs84
-from salem import utils, gis, wrftools, sio, check_crs
+from salem import gis, lazy_property, sio, utils, wgs84, wrftools
+from salem.gis import Grid
+from salem.utils import import_if_exists
+
+has_rasterio = import_if_exists('rasterio')
+has_matplotlib = import_if_exists('matplotlib')
+
+if TYPE_CHECKING:
+    from datetime import datetime
+
+    from shapely.geometry import GeometryCollection
 
-class GeoDataset(object):
+
+class GeoDataset:
     """Interface for georeferenced datasets.
 
     A GeoDataset is a formalism for gridded data arrays, which are usually
@@ -45,21 +52,23 @@ class GeoDataset(object):
     properties.
     """
 
-    def __init__(self, grid, time=None):
-        """Set-up the georeferencing, time is optional.
-        Parameters:
+    def __init__(self, grid: Grid, time=None) -> None:
+        """Set up the georeferencing, time is optional.
+
+        Parameters
+        ----------
         grid: a salem.Grid object which represents the underlying data
         time: if the data has a time dimension
-        """
+
+        """
         # The original grid, for always stored
         self._ogrid = grid
         # The current grid (changes if set_subset() is called)
         self.grid = grid
         # Default indexes to get in the underlying data (BOTH inclusive,
         #   i.e [, ], not [,[ as in numpy)
-        self.sub_x = [0, grid.nx-1]
-        self.sub_y = [0, grid.ny-1]
+        self.sub_x = [0, grid.nx - 1]
+        self.sub_y = [0, grid.ny - 1]
         # Roi is a ny, nx array if set
         self.roi = None
         self.set_roi()
@@ -75,7 +84,7 @@ def __init__(self, grid, time=None):
             except AttributeError:
                 # https://github.com/pandas-dev/pandas/issues/23419
                 for t in time:
-                    setattr(t, 'nanosecond', 0)
+                    t.nanosecond = 0
                 time = pd.Series(np.arange(len(time)), index=time)
             self._time = time
 
@@ -86,23 +95,27 @@ def __init__(self, grid, time=None):
         self.set_period()
 
     @property
-    def time(self):
-        """Time array"""
+    def time(self) -> NDArray[Any] | None:
+        """Time array."""
         if self._time is None:
             return None
-        return self._time[self.t0:self.t1].index
+        return self._time[self.t0 : self.t1].index
 
-    def set_period(self, t0=0, t1=-1):
+    def set_period(
+        self, t0: str | int | datetime = 0, t1: str | int | datetime = -1
+    ) -> None:
         """Set a period of interest for the dataset.
-        This will be remembered at later calls to time() or GeoDataset's
-        getvardata implementations.
-        Parameters
-        ----------
-        t0: anything that represents a time. Could be a string (e.g
-            '2012-01-01'), a DateTime, or an index in the dataset's time
-        t1: same as t0 (inclusive)
-        """
+
+        This will be remembered at later calls to time() or GeoDataset's
+        getvardata implementations.
+
+        Parameters
+        ----------
+        t0: anything that represents a time. Could be a string (e.g
+            '2012-01-01'), a DateTime, or an index in the dataset's time
+        t1: same as t0 (inclusive)
+
+        """
         if self._time is not None:
             self.sub_t = [0, -1]
             # we dont check for what t0 or t1 is, we let Pandas do the job
@@ -117,76 +130,103 @@ def set_period(self, t0=0, t1=-1):
         self.t0 = self._time.index[self.sub_t[0]]
         self.t1 = self._time.index[self.sub_t[1]]
 
-    def set_subset(self, corners=None, crs=wgs84, toroi=False, margin=0):
+    def set_subset(
+        self,
+        corners: tuple[tuple[float, float], tuple[float, float]] | None = None,
+        crs: pyproj.Proj | Grid = wgs84,
+        margin: int = 0,
+        *,
+        toroi: bool = False,
+    ):
         """Set a subset for the dataset.
-        This will be remembered at later calls to GeoDataset's
-        getvardata implementations.
-        Parameters
-        ----------
-        corners: a ((x0, y0), (x1, y1)) tuple of the corners of the square
-            to subset the dataset to. The coordinates are not expressed in
-            wgs84, set the crs keyword
-        crs: the coordinates of the corner coordinates
-        toroi: set to true to generate the smallest possible subset arond
-            the region of interest set with set_roi()
-        margin: when doing the subset, add a margin (can be negative!). Can
-            be used alone: set_subset(margin=-5) will remove five pixels from
-            each boundary of the dataset.
-        TODO: shouldnt we make the toroi stuff easier to use?
-        """
+
+        This will be remembered at later calls to GeoDataset's
+        getvardata implementations.
+
+        Parameters
+        ----------
+        corners: a ((x0, y0), (x1, y1)) tuple of the corners of the square
+            to subset the dataset to. The coordinates are not expressed in
+            wgs84, set the crs keyword
+        crs: the coordinates of the corner coordinates
+        margin: when doing the subset, add a margin (can be negative!). Can
+            be used alone: set_subset(margin=-5) will remove five pixels from
+            each boundary of the dataset.
+        toroi: set to true to generate the smallest possible subset around
+            the region of interest set with set_roi()
+        TODO: shouldn't we make the toroi stuff easier to use?
+
+        """
         # Useful variables
-        mx = self._ogrid.nx-1
-        my = self._ogrid.ny-1
+        mx = self._ogrid.nx - 1
+        my = self._ogrid.ny - 1
         cgrid = self._ogrid.center_grid
 
         # Three possible cases
         if toroi:
             if self.roi is None or np.max(self.roi) == 0:
-                raise RuntimeError('roi is empty.')
+                msg = 'roi is empty.'
+                raise RuntimeError(msg)
             ids = np.nonzero(self.roi)
-            sub_x = [np.min(ids[1])-margin, np.max(ids[1])+margin]
-            sub_y = [np.min(ids[0])-margin, np.max(ids[0])+margin]
+            sub_x = [np.min(ids[1]) - margin, np.max(ids[1]) + margin]
+            sub_y = [np.min(ids[0]) - margin, np.max(ids[0]) + margin]
         elif corners is not None:
             xy0, xy1 = corners
             x0, y0 = cgrid.transform(*xy0, crs=crs, nearest=True)
             x1, y1 = cgrid.transform(*xy1, crs=crs, nearest=True)
-            sub_x = [np.min([x0, x1])-margin, np.max([x0, x1])+margin]
-            sub_y = [np.min([y0, y1])-margin, np.max([y0, y1])+margin]
+            sub_x = [np.min([x0, x1]) - margin, np.max([x0, x1]) + margin]
+            sub_y = [np.min([y0, y1]) - margin, np.max([y0, y1]) + margin]
         else:
             # Reset
-            sub_x = [0-margin, mx+margin]
-            sub_y = [0-margin, my+margin]
+            sub_x = [0 - margin, mx + margin]
+            sub_y = [0 - margin, my + margin]
 
         # Some necessary checks
-        if (np.max(sub_x) < 0) or (np.min(sub_x) > mx) or \
-           (np.max(sub_y) < 0) or (np.min(sub_y) > my):
-            raise RuntimeError('subset not valid')
+        if (
+            (np.max(sub_x) < 0)
+            or (np.min(sub_x) > mx)
+            or (np.max(sub_y) < 0)
+            or (np.min(sub_y) > my)
+        ):
+            msg = 'subset not valid'
+            raise RuntimeError(msg)
 
         if (sub_x[0] < 0) or (sub_x[1] > mx):
-            warnings.warn('x0 out of bounds', RuntimeWarning)
+            warnings.warn('x0 out of bounds', RuntimeWarning, stacklevel=1)
         if (sub_y[0] < 0) or (sub_y[1] > my):
-            warnings.warn('y0 out of bounds', RuntimeWarning)
+            warnings.warn('y0 out of bounds', RuntimeWarning, stacklevel=1)
 
         # Make the new grid
         sub_x = np.clip(sub_x, 0, mx)
         sub_y = np.clip(sub_y, 0, my)
         nxny = (sub_x[1] - sub_x[0] + 1, sub_y[1] - sub_y[0] + 1)
         dxdy = (self._ogrid.dx, self._ogrid.dy)
-        xy0 = (self._ogrid.x0 + sub_x[0] * self._ogrid.dx,
-               self._ogrid.y0 + sub_y[0] * self._ogrid.dy)
+        xy0 = (
+            self._ogrid.x0 + sub_x[0] * self._ogrid.dx,
+            self._ogrid.y0 + sub_y[0] * self._ogrid.dy,
+        )
         self.grid = Grid(proj=self._ogrid.proj, nxny=nxny, dxdy=dxdy, x0y0=xy0)
         # If we arrived here, we can safely set the subset
         self.sub_x = sub_x
         self.sub_y = sub_y
 
-    def set_roi(self, shape=None, geometry=None, crs=wgs84, grid=None,
-                corners=None, noerase=False):
+    def set_roi(
+        self,
+        shape: Path | str | None = None,
+        geometry: GeometryCollection | None = None,
+        crs: pyproj.Proj | Grid = wgs84,
+        grid: Grid | None = None,
+        corners: tuple[tuple[float, float], tuple[float, float]] | None = None,
+        *,
+        noerase: bool = False,
+    ) -> None:
         """Set a region of interest for the dataset.
+
         If set succesfully, a ROI is simply a mask of the same size as the
         dataset's grid, obtained with the .roi attribute. I haven't decided
         yet if the data should be masekd out when a ROI has been set.
+
         Parameters
         ----------
         shape: path to a shapefile
@@ -197,8 +237,10 @@ def set_roi(self, shape=None, geometry=None, crs=wgs84, grid=None,
             to subset the dataset to. The coordinates are not expressed in
             wgs84, set the crs keyword
         noerase: set to true in order to add the new ROI to the previous one
-        """
+
+        """
+        if isinstance(shape, str):
+            shape = Path(shape)
         # The rois are always defined on the original grids, but of course
         # we take that into account when a subset is set (see roi
         # decorator below)
@@ -211,24 +253,31 @@ def set_roi(self, shape=None, geometry=None, crs=wgs84, grid=None,
             mask = np.zeros((ogrid.ny, ogrid.nx), dtype=np.int16)
 
         # Several cases
+        msg = 'This feature needs rasterio'
         if shape is not None:
             if isinstance(shape, pd.DataFrame):
                 gdf = shape
             else:
                 gdf = sio.read_shapefile(shape)
-            gis.transform_geopandas(gdf, to_crs=ogrid.corner_grid,
-                                    inplace=True)
-            if rasterio is None:
-                raise ImportError('This feature needs rasterio')
+            gis.transform_geopandas(
+                gdf, to_crs=ogrid.corner_grid, inplace=True
+            )
+            if not has_rasterio:
+                raise ImportError(msg)
+            import rasterio
             from rasterio.features import rasterize
+
             with rasterio.Env():
                 mask = rasterize(gdf.geometry, out=mask)
         if geometry is not None:
-            geom = gis.transform_geometry(geometry, crs=crs,
-                                          to_crs=ogrid.corner_grid)
-            if rasterio is None:
-                raise ImportError('This feature needs rasterio')
+            geom = gis.transform_geometry(
+                geometry, crs=crs, to_crs=ogrid.corner_grid
+            )
+            if not has_rasterio:
+                raise ImportError(msg)
+            import rasterio
             from rasterio.features import rasterize
+
             with rasterio.Env():
                 mask = rasterize(np.atleast_1d(geom), out=mask)
         if grid is not None:
@@ -239,117 +288,148 @@ def set_roi(self, shape=None, geometry=None, crs=wgs84, grid=None,
             xy0, xy1 = corners
             x0, y0 = cgrid.transform(*xy0, crs=crs, nearest=True)
             x1, y1 = cgrid.transform(*xy1, crs=crs, nearest=True)
-            mask[np.min([y0, y1]):np.max([y0, y1])+1,
-                 np.min([x0, x1]):np.max([x0, x1])+1] = 1
+            mask[
+                np.min([y0, y1]) : np.max([y0, y1]) + 1,
+                np.min([x0, x1]) : np.max([x0, x1]) + 1,
+            ] = 1
 
         self.roi = mask
 
     @property
-    def roi(self):
+    def roi(self) -> np.ndarray:
         """Mask of the ROI (same size as subset)."""
-        return self._roi[self.sub_y[0]:self.sub_y[1]+1,
-                         self.sub_x[0]:self.sub_x[1]+1]
+        return self._roi[
+            self.sub_y[0] : self.sub_y[1] + 1,
+            self.sub_x[0] : self.sub_x[1] + 1,
+        ]
 
     @roi.setter
-    def roi(self, value):
-        """A mask is allways defined on _ogrid"""
+    def roi(self, value: np.ndarray | None) -> None:
+        """Set a roi.
+
+        A mask is always defined on _ogrid
+        """
         self._roi = value
 
-    def get_vardata(self, var_id=None):
-        """Interface to implement by subclasses, taking sub_x, sub_y and
-        sub_t into account."""
-        raise NotImplementedError()
+    def get_vardata(self, var_id: int | None = None) -> None:
+        """Implement by subclasses, taking sub_x, sub_y and sub_t into account."""
+        raise NotImplementedError
 
 
 class GeoTiff(GeoDataset):
     """Geolocalised tiff images (needs rasterio)."""
 
-    def __init__(self, file):
+    def __init__(self, file: str | Path) -> None:
         """Open the file.
 
         Parameters
         ----------
         file: path to the file
+
         """
-        if rasterio is None:
-            raise ImportError('This feature needs rasterio to be insalled')
+        if not has_rasterio:
+            msg = 'This feature needs rasterio to be installed'
+            raise ImportError(msg)
+        import rasterio
 
         # brutally efficient
-        with rasterio.Env():
-            with rasterio.open(file) as src:
-                nxny = (src.width, src.height)
-                ul_corner = (src.bounds.left, src.bounds.top)
-                proj = pyproj.Proj(src.crs)
-                dxdy = (src.res[0], -src.res[1])
-                grid = Grid(x0y0=ul_corner, nxny=nxny, dxdy=dxdy,
-                            pixel_ref='corner', proj=proj)
+        with rasterio.Env(), rasterio.open(file) as src:
+            nxny = (src.width, src.height)
+            ul_corner = (src.bounds.left, src.bounds.top)
+            proj = pyproj.Proj(src.crs)
+            dxdy = (src.res[0], -src.res[1])
+            grid = Grid(
+                x0y0=ul_corner,
+                nxny=nxny,
+                dxdy=dxdy,
+                pixel_ref='corner',
+                proj=proj,
+            )
 
         # done
         self.file = file
         GeoDataset.__init__(self, grid)
 
-    def get_vardata(self, var_id=1):
+    def get_vardata(self, var_id: int = 1) -> np.ndarray:
         """Read the geotiff band.
 
         Parameters
         ----------
         var_id: the variable name (here the band number)
+
         """
-        wx = (self.sub_x[0], self.sub_x[1]+1)
-        wy = (self.sub_y[0], self.sub_y[1]+1)
-        with rasterio.Env():
-            with rasterio.open(self.file) as src:
-                band = src.read(var_id, window=(wy, wx))
-        return band
+        wx = (self.sub_x[0], self.sub_x[1] + 1)
+        wy = (self.sub_y[0], self.sub_y[1] + 1)
+        if not has_rasterio:
+            msg = 'This feature needs rasterio to be installed'
+            raise ImportError(msg)
+        import rasterio
+
+        with rasterio.Env(), rasterio.open(self.file) as src:
+            return src.read(var_id, window=(wy, wx))
 
 
 class EsriITMIX(GeoDataset):
     """Open ITMIX geolocalised Esri ASCII images (needs rasterio)."""
 
-    def __init__(self, file):
+    def __init__(self, file: str | Path) -> None:
         """Open the file.
 
         Parameters
         ----------
         file: path to the file
-        """
-        bname = os.path.basename(file).split('.')[0]
+
+        """
+        if not has_rasterio:
+            msg = 'This feature needs rasterio to be installed'
+            raise ImportError(msg)
+        import rasterio
+
+        if isinstance(file, str):
+            file = Path(file)
+        bname = file.name.split('.')[0]
         pok = bname.find('UTM')
         if pok == -1:
-            raise ValueError(file + ' does not seem to be an ITMIX file.')
-        zone = int(bname[pok+3:])
+            raise ValueError(str(file) + ' does not seem to be an ITMIX file.')
+        zone = int(bname[pok + 3 :])
         south = False
         if zone < 0:
             south = True
             zone = -zone
-        proj = pyproj.Proj(proj='utm', zone=zone, ellps='WGS84',
-                           south=south)
+        proj = pyproj.Proj(proj='utm', zone=zone, ellps='WGS84', south=south)
 
         # brutally efficient
-        with rasterio.Env():
-            with rasterio.open(file) as src:
-                nxny = (src.width, src.height)
-                ul_corner = (src.bounds.left, src.bounds.top)
-                dxdy = (src.res[0], -src.res[1])
-                grid = Grid(x0y0=ul_corner, nxny=nxny, dxdy=dxdy,
-                            pixel_ref='corner', proj=proj)
+        with rasterio.Env(), rasterio.open(file) as src:
+            nxny = (src.width, src.height)
+            ul_corner = (src.bounds.left, src.bounds.top)
+            dxdy = (src.res[0], -src.res[1])
+            grid = Grid(
+                x0y0=ul_corner,
+                nxny=nxny,
+                dxdy=dxdy,
+                pixel_ref='corner',
+                proj=proj,
+            )
 
         # done
         self.file = file
         GeoDataset.__init__(self, grid)
 
-    def get_vardata(self, var_id=1):
+    def get_vardata(self, var_id: int = 1) -> np.ndarray:
         """Read the geotiff band.
 
         Parameters
         ----------
         var_id: the variable name (here the band number)
+
         """
-        wx = (self.sub_x[0], self.sub_x[1]+1)
-        wy = (self.sub_y[0], self.sub_y[1]+1)
-        with rasterio.Env():
-            with rasterio.open(self.file) as src:
-                band = src.read(var_id, window=(wy, wx))
-        return band
+        if not has_rasterio:
+            msg = 'This feature needs rasterio to be installed'
+            raise ImportError(msg)
+        import rasterio
+
+        wx = (self.sub_x[0], self.sub_x[1] + 1)
+        wy = (self.sub_y[0], self.sub_y[1] + 1)
+        with rasterio.Env(), rasterio.open(self.file) as src:
+            return src.read(var_id, window=(wy, wx))
 
 
 class GeoNetcdf(GeoDataset):
@@ -359,7 +439,14 @@ class GeoNetcdf(GeoDataset):
     but if it can't you can still provide the time and grid at instantiation.
     """
 
-    def __init__(self, file, grid=None, time=None, monthbegin=False):
+    def __init__(
+        self,
+        file: str | Path,
+        grid: Grid | None = None,
+        time: pd.Series | None = None,
+        *,
+        monthbegin: bool = False,
+    ) -> None:
         """Open the file and try to understand it.
 
         Parameters
@@ -372,49 +459,54 @@ def __init__(self, file, grid=None, time=None, monthbegin=False):
             monthbegin: set to true if you are sure that your data is monthly
             and that the data provider decided to tag the date as the center of
             the month (stupid)
-        """
+
+        """
         self._nc = netCDF4.Dataset(file)
         self._nc.set_auto_mask(False)
         self.variables = self._nc.variables
         if grid is None:
             grid = sio.grid_from_dataset(self._nc)
             if grid is None:
-                raise RuntimeError('File grid not understood')
+                msg = 'File grid not understood'
                raise RuntimeError(msg)
         if time is None:
-            time = sio.netcdf_time(self._nc, monthbegin=monthbegin)
+            time_idx = sio.netcdf_time(self._nc, monthbegin=monthbegin)
+        else:
+            time_idx = time
         dn = self._nc.dimensions.keys()
         try:
             self.x_dim = utils.str_in_list(dn, utils.valid_names['x_dim'])[0]
             self.y_dim = utils.str_in_list(dn, utils.valid_names['y_dim'])[0]
-        except IndexError:
-            raise RuntimeError('File coordinates not understood')
+        except IndexError as err:
+            msg = 'File coordinates not understood'
+            raise RuntimeError(msg) from err
         dim = utils.str_in_list(dn, utils.valid_names['t_dim'])
         self.t_dim = dim[0] if dim else None
         dim = utils.str_in_list(dn, utils.valid_names['z_dim'])
         self.z_dim = dim[0] if dim else None
 
-        GeoDataset.__init__(self, grid, time=time)
+        GeoDataset.__init__(self, grid, time=time_idx)
 
-    def __enter__(self):
+    def __enter__(self) -> Self:
         return self
 
-    def __exit__(self, exception_type, exception_value, traceback):
+    def __exit__(self, exception_type, exception_value, traceback) -> None:
         self.close()
 
-    def close(self):
+    def close(self) -> None:
         self._nc.close()
 
-    def get_vardata(self, var_id=0, as_xarray=False):
-        """Reads the data out of the netCDF file while taking into account
-        time and spatial subsets.
+    def get_vardata(
+        self, var_id: int | str = 0, *, as_xarray: bool = False
+    ) -> np.ndarray:
+        """Read a netCDF file while taking into account time and spatial subsets.
Parameters ---------- var_id: the name of the variable (must be available in self.variables) as_xarray: returns a DataArray object - """ + """ v = self.variables[var_id] # Make the slices @@ -422,11 +514,11 @@ def get_vardata(self, var_id=0, as_xarray=False): for d in v.dimensions: it = slice(None) if d == self.t_dim and self.sub_t is not None: - it = slice(self.sub_t[0], self.sub_t[1]+1) + it = slice(self.sub_t[0], self.sub_t[1] + 1) elif d == self.y_dim: - it = slice(self.sub_y[0], self.sub_y[1]+1) + it = slice(self.sub_y[0], self.sub_y[1] + 1) elif d == self.x_dim: - it = slice(self.sub_x[0], self.sub_x[1]+1) + it = slice(self.sub_x[0], self.sub_x[1] + 1) item.append(it) with np.errstate(invalid='ignore'): @@ -436,7 +528,7 @@ def get_vardata(self, var_id=0, as_xarray=False): if as_xarray: # convert to xarray dims = v.dimensions - coords = dict() + coords = {} x, y = self.grid.x_coord, self.grid.y_coord for d in dims: if d == self.t_dim: @@ -446,8 +538,13 @@ def get_vardata(self, var_id=0, as_xarray=False): elif d == self.x_dim: coords[d] = x attrs = v.__dict__.copy() - bad_keys = ['scale_factor', 'add_offset', - '_FillValue', 'missing_value', 'ncvars'] + bad_keys = [ + 'scale_factor', + 'add_offset', + '_FillValue', + 'missing_value', + 'ncvars', + ] _ = [attrs.pop(b, None) for b in bad_keys] out = xr.DataArray(out, dims=dims, coords=coords, attrs=attrs) @@ -460,8 +557,12 @@ class WRF(GeoNetcdf): Adds unstaggered and diagnostic variables. """ - def __init__(self, file, grid=None, time=None): - + def __init__( + self, + file: str | Path, + grid: Grid | None = None, + time: pd.Series | None = None, + ) -> None: GeoNetcdf.__init__(self, file, grid=grid, time=time) # Change staggered variables to unstaggered ones @@ -485,10 +586,20 @@ class GoogleCenterMap(GeoDataset): for pricing. """ - def __init__(self, center_ll=(11.38, 47.26), size_x=640, size_y=640, - scale=1, zoom=12, maptype='satellite', use_cache=True, - key=None, **kwargs): - """Initialize + def __init__( + self, + center_ll: tuple[float, float] = (11.38, 47.26), + size_x: int = 640, + size_y: int = 640, + scale: int = 1, + zoom: int = 12, + maptype: str = 'satellite', + key: str | None = None, + *, + use_cache: bool = True, + **kwargs, + ) -> None: + """Initialize a Google map instance. Parameters ---------- @@ -506,35 +617,45 @@ def __init__(self, center_ll=(11.38, 47.26), size_x=640, size_y=640, static-maps/intro#Zoomlevels). 1 (world) - 20 (buildings) maptype : str, default: 'satellite' 'roadmap', 'satellite', 'hybrid', 'terrain' - use_cache : bool, default: True - store the downloaded image in the cache to avoid future downloads key : str, default: None Google API key. If None, it will try to read it from the environment variable STATIC_MAP_API_KEY + use_cache : bool, default: True + store the downloaded image in the cache to avoid future downloads kwargs : ** any keyword accepted by motionless.CenterMap - """ + """ # Google grid - grid = gis.googlestatic_mercator_grid(center_ll=center_ll, - nx=size_x, ny=size_y, - zoom=zoom, scale=scale) + grid = gis.googlestatic_mercator_grid( + center_ll=center_ll, nx=size_x, ny=size_y, zoom=zoom, scale=scale + ) if key is None: try: key = os.environ['STATIC_MAP_API_KEY'] - except KeyError: - raise ValueError('You need to provide a Google API key' - ' or set the STATIC_MAP_API_KEY environment' - ' variable.') + except KeyError as err: + msg = ( + 'You need to provide a Google API key' + ' or set the STATIC_MAP_API_KEY environment' + ' variable.' 
+                )
+                raise ValueError(msg) from err
 
         # Motionless
         import motionless
-        googleurl = motionless.CenterMap(lon=center_ll[0], lat=center_ll[1],
-                                         size_x=size_x, size_y=size_y,
-                                         maptype=maptype, zoom=zoom,
-                                         scale=scale, key=key,
-                                         **kwargs)
+
+        googleurl = motionless.CenterMap(
+            lon=center_ll[0],
+            lat=center_ll[1],
+            size_x=size_x,
+            size_y=size_y,
+            maptype=maptype,
+            zoom=zoom,
+            scale=scale,
+            key=key,
+            **kwargs,
+        )
 
         # done
         self.googleurl = googleurl
@@ -542,19 +663,29 @@ def __init__(self, center_ll=(11.38, 47.26), size_x=640, size_y=640,
         GeoDataset.__init__(self, grid)
 
     @lazy_property
-    def _img(self):
+    def _img(self) -> np.ndarray:
         """Download the image."""
+        if not has_matplotlib:
+            msg = 'This feature needs matplotlib to be installed'
+            raise ImportError(msg)
         if self.use_cache:
             return utils.joblib_read_img_url(self.googleurl.generate_url())
-        else:
-            from matplotlib.image import imread
-            fd = urlopen(self.googleurl.generate_url())
+        from matplotlib.image import imread
+
+        url = self.googleurl.generate_url()
+        if not url.startswith(('http:', 'https:')):
+            msg = "URL must start with 'http:' or 'https:'"
+            raise ValueError(msg)
+        with urlopen(url) as fd:
             return imread(io.BytesIO(fd.read()))
 
-    def get_vardata(self, var_id=0):
+    def get_vardata(self, var_id: int = 0) -> np.ndarray:
         """Return and subset the image."""
-        return self._img[self.sub_y[0]:self.sub_y[1]+1,
-                         self.sub_x[0]:self.sub_x[1]+1, :]
+        return self._img[
+            self.sub_y[0] : self.sub_y[1] + 1,
+            self.sub_x[0] : self.sub_x[1] + 1,
+            :,
+        ]
 
 
 class GoogleVisibleMap(GoogleCenterMap):
@@ -566,9 +697,21 @@ class GoogleVisibleMap(GoogleCenterMap):
     for pricing.
     """
 
-    def __init__(self, x, y, crs=wgs84, size_x=640, size_y=640, scale=1,
-                 maptype='satellite', use_cache=True, key=None, **kwargs):
-        """Initialize
+    def __init__(
+        self,
+        x: np.ndarray,
+        y: np.ndarray,
+        crs: pyproj.Proj | Grid = wgs84,
+        size_x: int = 640,
+        size_y: int = 640,
+        scale: int = 1,
+        maptype: str = 'satellite',
+        key: str | None = None,
+        *,
+        use_cache: bool = True,
+        **kwargs,
+    ) -> None:
+        """Initialize a Google Visible Map instance.
 
         Parameters
         ----------
@@ -587,11 +730,11 @@ def __init__(self, x, y, crs=wgs84, size_x=640, size_y=640, scale=1,
             longer to download
         maptype : str, default: 'satellite'
             'roadmap', 'satellite', 'hybrid', 'terrain'
-        use_cache : bool, default: True
-            store the downloaded image in the cache to avoid future downloads
         key : str, default: None
             Google API key. If None, it will try to read it from the
             environment variable STATIC_MAP_API_KEY
+        use_cache : bool, default: True
+            store the downloaded image in the cache to avoid future downloads
         kwargs : **
             any keyword accepted by motionless.CenterMap
 
@@ -599,18 +742,22 @@ def __init__(self, x, y, crs=wgs84, size_x=640, size_y=640, scale=1,
         -----
         To obtain the exact domain specified in `x` and `y` you may have to
         play with the `size_x` and `size_y` kwargs.
-        """
+
+        """
         if key is None:
             try:
                 key = os.environ['STATIC_MAP_API_KEY']
-            except KeyError:
-                raise ValueError('You need to provide a Google API key'
-                                 ' or set the STATIC_MAP_API_KEY environment'
-                                 ' variable.')
+            except KeyError as err:
+                msg = (
+                    'You need to provide a Google API key'
+                    ' or set the STATIC_MAP_API_KEY environment'
+                    ' variable.'
+                )
+                raise ValueError(msg) from err
 
         if 'zoom' in kwargs or 'center_ll' in kwargs:
-            raise ValueError('incompatible kwargs.')
+            msg = 'incompatible kwargs.'
+ raise ValueError(msg) # Transform to lonlat crs = gis.check_crs(crs) @@ -619,22 +766,30 @@ def __init__(self, x, y, crs=wgs84, size_x=640, size_y=640, scale=1, elif isinstance(crs, Grid): lon, lat = crs.ij_to_crs(x, y, crs=wgs84) else: - raise NotImplementedError() + raise NotImplementedError # surely not the smartest way to do but should be enough for now mc = (np.mean(lon), np.mean(lat)) zoom = 20 while zoom >= 0: - grid = gis.googlestatic_mercator_grid(center_ll=mc, nx=size_x, - ny=size_y, zoom=zoom, - scale=scale) + grid = gis.googlestatic_mercator_grid( + center_ll=mc, nx=size_x, ny=size_y, zoom=zoom, scale=scale + ) dx, dy = grid.transform(lon, lat, maskout=True) if np.any(dx.mask): zoom -= 1 else: break - GoogleCenterMap.__init__(self, center_ll=mc, size_x=size_x, - size_y=size_y, zoom=zoom, scale=scale, - maptype=maptype, use_cache=use_cache, - key=key, **kwargs) + GoogleCenterMap.__init__( + self, + center_ll=mc, + size_x=size_x, + size_y=size_y, + zoom=zoom, + scale=scale, + maptype=maptype, + use_cache=use_cache, + key=key, + **kwargs, + ) diff --git a/salem/descartes.py b/salem/descartes.py index 65a06f0..81c8c11 100644 --- a/salem/descartes.py +++ b/salem/descartes.py @@ -1,4 +1,4 @@ -"""Paths and patches +"""Paths and patches. This file is part of the package "descartes" by sgilles, apparently discontinued today. @@ -13,7 +13,7 @@ from numpy import asarray, concatenate, ones -class Polygon(object): +class Polygon: # Adapt Shapely or GeoJSON/geo_interface polygons to a common interface def __init__(self, context): if hasattr(context, 'interiors'): @@ -23,13 +23,14 @@ def __init__(self, context): @property def geom_type(self): - return (getattr(self.context, 'geom_type', None) - or self.context['type']) + return getattr(self.context, 'geom_type', None) or self.context['type'] @property def exterior(self): - return (getattr(self.context, 'exterior', None) - or self.context['coordinates'][0]) + return ( + getattr(self.context, 'exterior', None) + or self.context['coordinates'][0] + ) @property def interiors(self): @@ -41,9 +42,12 @@ def interiors(self): def PolygonPath(polygon): """Constructs a compound matplotlib path from a Shapely or GeoJSON-like - geometric object""" + geometric object + """ this = Polygon(polygon) - assert this.geom_type == 'Polygon' + if this.geom_type != 'Polygon': + msg = 'Not a Polygon: {}'.format(this.geom_type) + raise ValueError(msg) def coding(ob): # The codes will be all "LINETO" commands, except for "MOVETO"s at the @@ -55,10 +59,11 @@ def coding(ob): vertices = concatenate( [asarray(this.exterior.coords)[:, :2]] - + [asarray(r.coords)[:, :2] for r in this.interiors]) + + [asarray(r.coords)[:, :2] for r in this.interiors] + ) codes = concatenate( - [coding(this.exterior)] - + [coding(r) for r in this.interiors]) + [coding(this.exterior)] + [coding(r) for r in this.interiors] + ) return Path(vertices, codes) diff --git a/salem/gis.py b/salem/gis.py index 85b91ae..6efdfac 100644 --- a/salem/gis.py +++ b/salem/gis.py @@ -1,40 +1,62 @@ -""" -Projections and grids -""" +"""Projections and grids.""" + # Python 2 stuff -from __future__ import division +from __future__ import annotations # Builtins +import contextlib import copy import warnings from functools import partial -from packaging.version import Version +from typing import TYPE_CHECKING, Any, Callable + +import numpy as np # External libs import pyproj -import numpy as np -from scipy.interpolate import RegularGridInterpolator, RectBivariateSpline +import xarray as xr +from packaging.version 
import Version
+from scipy.interpolate import RectBivariateSpline, RegularGridInterpolator
 
-try:
+# Locals
+from salem import lazy_property, wgs84
+from salem.utils import deprecated_arg, import_if_exists
+
+has_cartopy = import_if_exists('cartopy')
+if TYPE_CHECKING:
+    from pathlib import Path
+
+    import geopandas as gpd
+    from numpy._typing import NDArray
+    from pyproj.transformer import AreaOfInterest
+    from shapely.geometry.base import BaseGeometry
+
+    if has_cartopy:
+        import cartopy.crs as ccrs
+
+has_gdal = import_if_exists('osgeo')
+if has_gdal:
     from osgeo import osr
+
     osr.UseExceptions()
-    has_gdal = True
-except ImportError:
-    has_gdal = False
 
-# Locals
-from salem import lazy_property, wgs84
 
 try:
     crs_type = pyproj.crs.CRS
 except AttributeError:
-    class Dummy():
+
+    class Dummy:
         pass
+
     crs_type = Dummy
 
 
-def check_crs(crs, raise_on_error=False):
-    """Checks if the crs represents a valid grid, projection or ESPG string.
+def check_crs(
+    crs: pyproj.Proj | Grid | str | xr.DataArray,
+    *,
+    raise_on_error: bool = False,
+) -> pyproj.Proj | Grid | None:
+    """Check if the crs represents a valid grid, projection or EPSG string.
 
     Examples
     --------
@@ -48,20 +70,19 @@ def check_crs(crs, raise_on_error=False):
     Returns
     -------
     A valid crs if possible, otherwise None
-    """
-    try:
-        crs = crs.salem.grid  # try xarray
-    except:
-        pass
+
+    """
+    if isinstance(crs, (xr.DataArray, xr.Dataset)):
+        with contextlib.suppress(Exception):
+            crs = crs.salem.grid  # try xarray
 
     err1, err2 = None, None
-    if isinstance(crs, pyproj.Proj) or isinstance(crs, Grid):
+    if isinstance(crs, (pyproj.Proj, Grid)):
         out = crs
     elif isinstance(crs, crs_type):
         out = pyproj.Proj(crs.to_wkt(), preserve_units=True)
-    elif isinstance(crs, dict) or isinstance(crs, str):
+    elif isinstance(crs, (dict, str)):
         if isinstance(crs, str):
             # quick fix for https://github.com/pyproj4/pyproj/issues/345
             crs = crs.replace(' ', '').replace('+', ' +')
@@ -83,12 +104,14 @@ def check_crs(crs, raise_on_error=False):
     out = None
 
     if raise_on_error and out is None:
-        msg = ('salem could not properly parse the provided coordinate '
-               'reference system (crs). This could be due to errors in your '
-               'data, in PyProj, or with salem itself. If this occurs '
-               'unexpectedly, report an issue to https://github.com/fmaussion/'
-               'salem/issues. Full log: \n'
-               'crs: {} ; \n'.format(crs))
+        msg = (
+            'salem could not properly parse the provided coordinate '
+            'reference system (crs). This could be due to errors in your '
+            'data, in PyProj, or with salem itself. If this occurs '
+            'unexpectedly, report an issue to https://github.com/fmaussion/'
+            'salem/issues. Full log: \n'
+            f'crs: {crs} ; \n'
+        )
         if err1 is not None:
             msg += 'Output of `pyproj.Proj(crs, preserve_units=True)`: {} ; \n'
             msg = msg.format(err1)
@@ -100,7 +123,7 @@ def check_crs(crs, raise_on_error=False):
     return out
 
 
-class Grid(object):
+class Grid:
     """A structured grid on a map projection.
Central class in the library, taking over user concerns about the @@ -136,7 +159,6 @@ class Grid(object): Attributes ---------- - proj nx ny @@ -157,12 +179,22 @@ class Grid(object): center_grid corner_grid extent + """ - def __init__(self, proj=wgs84, nxny=None, dxdy=None, x0y0=None, - pixel_ref='center', - corner=None, ul_corner=None, ll_corner=None): - """ + def __init__( + self, + proj: pyproj.Proj | Grid | None = wgs84, + nxny: tuple[int, int] | None = None, + dxdy: tuple[float, float] | None = None, + x0y0: tuple[float, float] | None = None, + pixel_ref: str = 'center', + corner: tuple[float, float] | None = None, + ul_corner: tuple[float, float] | None = None, + ll_corner: tuple[float, float] | None = None, + ) -> None: + """Initialize a Grid object. + Parameters ---------- proj : pyproj.Proj instance @@ -214,70 +246,82 @@ def __init__(self, proj=wgs84, nxny=None, dxdy=None, x0y0=None, [ 0.5, 0.5, 0.5]]) >>> g.corner_grid == g.center_grid # the two reprs are equivalent True - """ + """ # Check for coordinate system proj = check_crs(proj) if proj is None: - raise ValueError('proj must be of type pyproj.Proj') + msg = 'proj should not be None' + raise TypeError(msg) + self._proj = proj - # deprecations if corner is not None: - warnings.warn('The `corner` kwarg is deprecated: ' - 'use `x0y0` instead.', DeprecationWarning) + deprecated_arg( + 'The `corner` kwarg is deprecated: use `x0y0` instead.' + ) x0y0 = corner if ul_corner is not None: - warnings.warn('The `ul_corner` kwarg is deprecated: ' - 'use `x0y0` instead.', DeprecationWarning) - if dxdy[1] > 0.: - raise ValueError('dxdy and input params not compatible') + deprecated_arg( + 'The `ul_corner` kwarg is deprecated: use `x0y0` instead.' + ) + if dxdy is not None and dxdy[1] > 0.0: + msg = 'dxdy and input params not compatible' + raise ValueError(msg) x0y0 = ul_corner if ll_corner is not None: - warnings.warn('The `ll_corner` kwarg is deprecated: ' - 'use `x0y0` instead.', DeprecationWarning) - if dxdy[1] < 0.: - raise ValueError('dxdy and input params not compatible') + deprecated_arg( + 'The `ll_corner` kwarg is deprecated: use `x0y0` instead.' + ) + if dxdy is not None and dxdy[1] < 0.0: + msg = 'dxdy and input params not compatible' + raise ValueError(msg) x0y0 = ll_corner # Check for shortcut - if dxdy[1] < 0.: + if dxdy is not None and dxdy[1] < 0.0: ul_corner = x0y0 else: ll_corner = x0y0 # Initialise the rest - self._check_input(nxny=nxny, dxdy=dxdy, - ul_corner=ul_corner, - ll_corner=ll_corner, - pixel_ref=pixel_ref) + self._check_input( + nxny=nxny, + dxdy=dxdy, + ul_corner=ul_corner, + ll_corner=ll_corner, + pixel_ref=pixel_ref, + ) - def _check_input(self, **kwargs): + def _check_input(self, **kwargs) -> None: """See which parameter combination we have and set everything.""" - combi_a = ['nxny', 'dxdy', 'ul_corner'] combi_b = ['nxny', 'dxdy', 'll_corner'] if all(kwargs[k] is not None for k in combi_a): nx, ny = kwargs['nxny'] dx, dy = kwargs['dxdy'] x0, y0 = kwargs['ul_corner'] - if (dx <= 0.) or (dy >= 0.): - raise ValueError('dxdy and input params not compatible') + if (dx <= 0.0) or (dy >= 0.0): + msg = 'dxdy and input params not compatible' + raise ValueError(msg) origin = 'upper-left' elif all(kwargs[k] is not None for k in combi_b): nx, ny = kwargs['nxny'] dx, dy = kwargs['dxdy'] x0, y0 = kwargs['ll_corner'] - if (dx <= 0.) 
or (dy <= 0.): - raise ValueError('dxdy and input params not compatible') + if (dx <= 0.0) or (dy <= 0.0): + msg = 'dxdy and input params not compatible' + raise ValueError(msg) origin = 'lower-left' else: - raise ValueError('Input params not compatible') + msg = 'Input params not compatible' + raise ValueError(msg) self._nx = int(nx) self._ny = int(ny) if (self._nx <= 0) or (self._ny <= 0): - raise ValueError('nxny not valid') + msg = 'nxny not valid' + raise ValueError(msg) self._dx = float(dx) self._dy = float(dy) self._x0 = float(x0) @@ -287,10 +331,13 @@ def _check_input(self, **kwargs): # Check for pixel ref self._pixel_ref = kwargs['pixel_ref'].lower() if self._pixel_ref not in ['corner', 'center']: - raise ValueError('pixel_ref not recognized') + msg = 'pixel_ref not recognized' + raise ValueError(msg) - def __eq__(self, other): - """Two grids are considered equal when their defining coordinates + def __eq__(self, other: Grid) -> bool: + """Check equality with another grid. + + Two grids are considered equal when their defining coordinates and projection are equal. Note: equality also means floating point equality, with all the @@ -298,17 +345,23 @@ def __eq__(self, other): (independent of the grid's cornered or centered representation.) """ - # Attributes defining the instance ckeys = ['x0', 'y0', 'nx', 'ny', 'dx', 'dy', 'origin'] - a = dict((k, getattr(self.corner_grid, k)) for k in ckeys) - b = dict((k, getattr(other.corner_grid, k)) for k in ckeys) + a = {k: getattr(self.corner_grid, k) for k in ckeys} + b = {k: getattr(other.corner_grid, k) for k in ckeys} p1 = self.corner_grid.proj p2 = other.corner_grid.proj return (a == b) and proj_is_same(p1, p2) - def __repr__(self): + def __repr__(self) -> str: + """Print a string representation of the grid. + + Returns + ------- + a string + + """ srs = '+'.join(sorted(self.proj.srs.split('+'))).strip() summary = [''] summary += [' proj: ' + srs] @@ -320,201 +373,193 @@ def __repr__(self): return '\n'.join(summary) + '\n' @property - def proj(self): + def proj(self) -> pyproj.Proj: """``pyproj.Proj`` instance defining the grid's map projection.""" return self._proj @property - def nx(self): - """number of grid points in the x direction.""" + def nx(self) -> int: + """Number of grid points in the x direction.""" return self._nx @property - def ny(self): - """number of grid points in the y direction.""" + def ny(self) -> int: + """Number of grid points in the y direction.""" return self._ny @property - def dx(self): - """x grid spacing (always positive).""" + def dx(self) -> float: + """X grid spacing (always positive).""" return self._dx @property - def dy(self): - """y grid spacing (positive if ll_corner, negative if ul_corner).""" + def dy(self) -> float: + """Y grid spacing (positive if ll_corner, negative if ul_corner).""" return self._dy @property - def x0(self): + def x0(self) -> float: """X reference point in projection coordinates.""" return self._x0 @property - def y0(self): + def y0(self) -> float: """Y reference point in projection coordinates.""" return self._y0 @property - def origin(self): + def origin(self) -> str: """``'upper-left'`` or ``'lower-left'``.""" return self._origin @property - def pixel_ref(self): - """if coordinates are at the ``'center'`` or ``'corner'`` of the grid. 
- """ + def pixel_ref(self) -> str: + """If coordinates are at the ``'center'`` or ``'corner'`` of the grid.""" return self._pixel_ref @lazy_property - def center_grid(self): - """``salem.Grid`` instance representing the grid in center coordinates. - """ - + def center_grid(self) -> Grid: + """``salem.Grid`` instance representing the grid in center coordinates.""" if self.pixel_ref == 'center': return self - else: - # shift the grid - x0y0 = ((self.x0 + self.dx / 2.), (self.y0 + self.dy / 2.)) - args = dict(nxny=(self.nx, self.ny), dxdy=(self.dx, self.dy), - proj=self.proj, pixel_ref='center', x0y0=x0y0) - return Grid(**args) + # shift the grid + x0y0 = ((self.x0 + self.dx / 2.0), (self.y0 + self.dy / 2.0)) + args = { + 'nxny': (self.nx, self.ny), + 'dxdy': (self.dx, self.dy), + 'proj': self.proj, + 'pixel_ref': 'center', + 'x0y0': x0y0, + } + return Grid(**args) @lazy_property - def corner_grid(self): - """``salem.Grid`` instance representing the grid in corner coordinates. - """ - + def corner_grid(self) -> Grid: + """``salem.Grid`` instance representing the grid in corner coordinates.""" if self.pixel_ref == 'corner': return self - else: - # shift the grid - x0y0 = ((self.x0 - self.dx / 2.), (self.y0 - self.dy / 2.)) - args = dict(nxny=(self.nx, self.ny), dxdy=(self.dx, self.dy), - proj=self.proj, pixel_ref='corner', x0y0=x0y0) - return Grid(**args) + # shift the grid + x0y0 = ((self.x0 - self.dx / 2.0), (self.y0 - self.dy / 2.0)) + args = { + 'nxny': (self.nx, self.ny), + 'dxdy': (self.dx, self.dy), + 'proj': self.proj, + 'pixel_ref': 'corner', + 'x0y0': x0y0, + } + return Grid(**args) @property - def ij_coordinates(self): + def ij_coordinates(self) -> tuple[NDArray[Any], ...]: """Tuple of i, j coordinates of the grid points. (dependent of the grid's cornered or centered representation.) """ - x = np.arange(self.nx) y = np.arange(self.ny) return np.meshgrid(x, y) @property - def x_coord(self): - """x coordinates of the grid points (1D, no mesh)""" - + def x_coord(self) -> NDArray[Any]: + """X coordinates of the grid points (1D, no mesh).""" return self.x0 + np.arange(self.nx) * self.dx @property - def y_coord(self): - """y coordinates of the grid points (1D, no mesh)""" - + def y_coord(self) -> NDArray[Any]: + """Y coordinates of the grid points (1D, no mesh).""" return self.y0 + np.arange(self.ny) * self.dy @property - def xy_coordinates(self): + def xy_coordinates(self) -> tuple[NDArray[Any], ...]: """Tuple of x, y coordinates of the grid points. (dependent of the grid's cornered or centered representation.) """ - return np.meshgrid(self.x_coord, self.y_coord) @lazy_property - def ll_coordinates(self): + def ll_coordinates(self) -> tuple[np.ndarray, np.ndarray]: """Tuple of longitudes, latitudes of the grid points. (dependent of the grid's cornered or centered representation.) """ - x, y = self.xy_coordinates proj_out = check_crs('EPSG:4326') return transform_proj(self.proj, proj_out, x, y) @property - def xstagg_xy_coordinates(self): + def xstagg_xy_coordinates(self) -> tuple[NDArray[Any], ...]: """Tuple of x, y coordinates of the X staggered grid. (independent of the grid's cornered or centered representation.) """ - - x_s = self.corner_grid.x0 + np.arange(self.nx+1) * self.dx + x_s = self.corner_grid.x0 + np.arange(self.nx + 1) * self.dx y = self.center_grid.y0 + np.arange(self.ny) * self.dy return np.meshgrid(x_s, y) @property - def ystagg_xy_coordinates(self): + def ystagg_xy_coordinates(self) -> tuple[NDArray[Any], ...]: """Tuple of x, y coordinates of the Y staggered grid. 
(independent of the grid's cornered or centered representation.) """ - x = self.center_grid.x0 + np.arange(self.nx) * self.dx - y_s = self.corner_grid.y0 + np.arange(self.ny+1) * self.dy + y_s = self.corner_grid.y0 + np.arange(self.ny + 1) * self.dy return np.meshgrid(x, y_s) @lazy_property - def xstagg_ll_coordinates(self): + def xstagg_ll_coordinates(self) -> tuple[NDArray[Any], ...]: """Tuple of longitudes, latitudes of the X staggered grid. (independent of the grid's cornered or centered representation.) """ - x, y = self.xstagg_xy_coordinates proj_out = check_crs('EPSG:4326') return transform_proj(self.proj, proj_out, x, y) @lazy_property - def ystagg_ll_coordinates(self): + def ystagg_ll_coordinates(self) -> tuple[NDArray[Any], ...]: """Tuple of longitudes, latitudes of the Y staggered grid. (independent of the grid's cornered or centered representation.) """ - x, y = self.ystagg_xy_coordinates proj_out = check_crs('EPSG:4326') return transform_proj(self.proj, proj_out, x, y) @lazy_property - def pixcorner_ll_coordinates(self): - """Tuple of longitudes, latitudes (dims: ny+1, nx+1) at the corners of - the grid. + def pixcorner_ll_coordinates(self) -> tuple[NDArray[Any], ...]: + """Tuple of lons, lats (dims: ny+1, nx+1) at the corners of the grid. Useful for graphics.Map essentially (independant of the grid's cornered or centered representation.) """ - - x = self.corner_grid.x0 + np.arange(self.nx+1) * self.dx - y = self.corner_grid.y0 + np.arange(self.ny+1) * self.dy + x = self.corner_grid.x0 + np.arange(self.nx + 1) * self.dx + y = self.corner_grid.y0 + np.arange(self.ny + 1) * self.dy x, y = np.meshgrid(x, y) proj_out = check_crs('EPSG:4326') return transform_proj(self.proj, proj_out, x, y) @lazy_property - def extent(self): - """[left, right, bottom, top] boundaries of the grid in the grid's - projection. + def extent(self) -> list[float]: + """[left, right, bottom, top] boundaries of the grid in the grid's projection. The boundaries are the pixels leftmost, rightmost, lowermost and uppermost corners, meaning that they are independent from the grid's representation. """ - x = np.array([0, self.nx]) * self.dx + self.corner_grid.x0 ypoint = [0, self.ny] if self.origin == 'lower-left' else [self.ny, 0] y = np.array(ypoint) * self.dy + self.corner_grid.y0 return [x[0], x[1], y[0], y[1]] - def almost_equal(self, other, rtol=1e-05, atol=1e-08): - """A less strict comparison between grids. + def almost_equal( + self, other: Grid, rtol: float = 1e-05, atol: float = 1e-08 + ) -> bool: + """Compare with another grid (less strictly). Two grids are considered equal when their defining coordinates and projection are equal. @@ -523,7 +568,6 @@ def almost_equal(self, other, rtol=1e-05, atol=1e-08): (independent of the grid's cornered or centered representation.) 
""" - # float attributes defining the instance fkeys = ['x0', 'y0', 'dx', 'dy'] # unambiguous attributes @@ -531,9 +575,12 @@ def almost_equal(self, other, rtol=1e-05, atol=1e-08): ok = True for k in fkeys: - ok = ok and np.isclose(getattr(self.corner_grid, k), - getattr(other.corner_grid, k), - rtol=rtol, atol=atol) + ok = ok and np.isclose( + getattr(self.corner_grid, k), + getattr(other.corner_grid, k), + rtol=rtol, + atol=atol, + ) for k in ckeys: _ok = getattr(self.corner_grid, k) == getattr(other.corner_grid, k) ok = ok and _ok @@ -541,7 +588,7 @@ def almost_equal(self, other, rtol=1e-05, atol=1e-08): p2 = other.corner_grid.proj return ok and proj_is_same(p1, p2) - def extent_in_crs(self, crs=wgs84): + def extent_in_crs(self, crs: pyproj.Proj = wgs84) -> list[float]: """Get the extent of the grid in a desired crs. Parameters @@ -552,15 +599,17 @@ def extent_in_crs(self, crs=wgs84): Returns ------- [left, right, bottom, top] boundaries of the grid. - """ + """ # this is not so trivial # for optimisation we will transform the boundaries only poly = self.extent_as_polygon(crs=crs) _i, _j = poly.exterior.xy return [np.min(_i), np.max(_i), np.min(_j), np.max(_j)] - def extent_as_polygon(self, crs=wgs84): + def extent_as_polygon( + self, crs: pyproj.Proj | Grid | None = wgs84 + ) -> BaseGeometry: """Get the extent of the grid in a shapely.Polygon and desired crs. Parameters @@ -571,23 +620,34 @@ def extent_as_polygon(self, crs=wgs84): Returns ------- [left, right, bottom, top] boundaries of the grid. + """ from shapely.geometry import Polygon # this is not so trivial # for optimisation we will transform the boundaries only - _i = np.hstack([np.arange(self.nx+1), - np.ones(self.ny+1)*self.nx, - np.arange(self.nx+1)[::-1], - np.zeros(self.ny+1)]).flatten() - _j = np.hstack([np.zeros(self.nx+1), - np.arange(self.ny+1), - np.ones(self.nx+1)*self.ny, - np.arange(self.ny+1)[::-1]]).flatten() + _i = np.hstack( + [ + np.arange(self.nx + 1), + np.ones(self.ny + 1) * self.nx, + np.arange(self.nx + 1)[::-1], + np.zeros(self.ny + 1), + ] + ).flatten() + _j = np.hstack( + [ + np.zeros(self.nx + 1), + np.arange(self.ny + 1), + np.ones(self.nx + 1) * self.ny, + np.arange(self.ny + 1)[::-1], + ] + ).flatten() _i, _j = self.corner_grid.ij_to_crs(_i, _j, crs=crs) return Polygon(zip(_i, _j)) - def regrid(self, nx=None, ny=None, factor=1): + def regrid( + self, nx: int | None = None, ny: int | None = None, factor: float = 1 + ) -> Grid: """Make a copy of the grid with an updated spatial resolution. The keyword parameters are mutually exclusive, because the x/y ratio @@ -597,7 +657,7 @@ def regrid(self, nx=None, ny=None, factor=1): ---------- nx : int the new number of x pixels - nx : int + ny : int the new number of y pixels factor : int multiplication factor (factor=3 will generate a grid with @@ -606,29 +666,44 @@ def regrid(self, nx=None, ny=None, factor=1): Returns ------- a new Grid object. 
- """ + """ + if nx is not None and ny is not None and factor: + msg = 'You cannot specify both `nx/ny` and `factor`' + raise ValueError(msg) if nx is not None: factor = nx / self.nx if ny is not None: factor = ny / self.ny - nx = self.nx * factor - ny = self.ny * factor + nnx = self.nx * factor + nny = self.ny * factor dx = self.dx / factor dy = self.dy / factor x0 = self.corner_grid.x0 y0 = self.corner_grid.y0 - args = dict(nxny=(nx, ny), dxdy=(dx, dy), x0y0=(x0, y0), - proj=self.proj, pixel_ref='corner') + args = { + 'nxny': (nnx, nny), + 'dxdy': (dx, dy), + 'x0y0': (x0, y0), + 'proj': self.proj, + 'pixel_ref': 'corner', + } g = Grid(**args) if self.pixel_ref == 'center': g = g.center_grid return g - def ij_to_crs(self, i, j, crs=None, nearest=False): - """Converts local i, j to cartesian coordinates in a specified crs + def ij_to_crs( + self, + i: NDArray[Any], + j: NDArray[Any], + crs: pyproj.Proj | None = None, + *, + nearest: bool = False, + ) -> tuple[NDArray[Any], NDArray[Any]]: + """Convert local i, j to cartesian coordinates in a specified crs. Parameters ---------- @@ -644,8 +719,8 @@ def ij_to_crs(self, i, j, crs=None, nearest=False): Returns ------- (x, y) coordinates of the points in the specified crs. - """ + """ # Default if crs is None: crs = self.proj @@ -664,10 +739,24 @@ def ij_to_crs(self, i, j, crs=None, nearest=False): ret = transform_proj(self.proj, _crs, x, y) elif isinstance(_crs, Grid): ret = _crs.transform(x, y, crs=self.proj, nearest=nearest) + else: + msg = 'crs must be a pyproj.Proj or salem.Grid, not {}'.format( + type(crs) + ) + raise TypeError(msg) return ret - def transform(self, x, y, z=None, crs=wgs84, nearest=False, maskout=False): - """Converts any coordinates into the local grid. + def transform( + self, + x: NDArray[Any], + y: NDArray[Any], + z: NDArray[Any] | None = None, + crs: pyproj.Proj | Grid | str | None = wgs84, + *, + nearest: bool = False, + maskout: bool = False, + ) -> tuple[NDArray[Any], NDArray[Any]]: + """Convert any coordinates into the local grid. Parameters ---------- @@ -691,10 +780,13 @@ def transform(self, x, y, z=None, crs=wgs84, nearest=False, maskout=False): Returns ------- (i, j) coordinates of the points in the local grid. - """ + """ x, y = np.ma.array(x), np.ma.array(y) + if crs is None: + msg = 'crs must be a pyproj.Proj or salem.Grid, not None' + raise ValueError(msg) # First to local proj _crs = check_crs(crs, raise_on_error=True) if isinstance(_crs, pyproj.Proj): @@ -715,18 +807,22 @@ def transform(self, x, y, z=None, crs=wgs84, nearest=False, maskout=False): # Mask? if maskout: if self.pixel_ref == 'center': - mask = ~((x >= -0.5) & (x < self.nx-0.5) & - (y >= -0.5) & (y < self.ny-0.5)) + dist = -0.5 + mask = ~( + (x >= dist) + & (x < self.nx + dist) + & (y >= dist) + & (y < self.ny + dist) + ) else: - mask = ~((x >= 0) & (x < self.nx) & - (y >= 0) & (y < self.ny)) + mask = ~((x >= 0) & (x < self.nx) & (y >= 0) & (y < self.ny)) x = np.ma.array(x, mask=mask) y = np.ma.array(y, mask=mask) return x, y - def grid_lookup(self, other): - """Performs forward transformation of any other grid into self. + def grid_lookup(self, other: Grid) -> dict[tuple[int, int], NDArray[Any]]: + """Perform forward transformation of any other grid into self. 
The principle of forward transform is to obtain, for each grid point of ``self`` , all the indices of ``other`` that are located into the @@ -745,26 +841,28 @@ def grid_lookup(self, other): a dict: each key (j, i) contains an array of shape (n, 2) where n is the number of ``other`` 's grid points found within the grid point (j, i) - """ + """ # Input checks - other = check_crs(other) - if not isinstance(other, Grid): - raise ValueError('`other` should be a Grid instance') + _other = check_crs(other) + if not isinstance(_other, Grid): + msg = '`other` should be a Grid instance' + raise TypeError(msg) # Transform the other grid into the local grid (forward transform) # Work in center grid cause that's what we need - i, j = other.center_grid.ij_coordinates + i, j = _other.center_grid.ij_coordinates i, j = i.flatten(), j.flatten() - oi, oj = self.center_grid.transform(i, j, crs=other.center_grid, - nearest=True, maskout=True) + oi, oj = self.center_grid.transform( + i, j, crs=_other.center_grid, nearest=True, maskout=True + ) # keep only valid values oi, oj, i, j = oi[~oi.mask], oj[~oi.mask], i[~oi.mask], j[~oi.mask] out_inds = oi.flatten() + self.nx * oj.flatten() # find the links - ris = np.digitize(out_inds, bins=np.arange(self.nx*self.ny+1)) + ris = np.digitize(out_inds, bins=np.arange(self.nx * self.ny + 1)) # some optim based on the fact that ris has many duplicates sort_idx = np.argsort(ris) @@ -772,15 +870,22 @@ def grid_lookup(self, other): unq_idx = np.split(sort_idx, np.cumsum(unq_count)) # lets go - out = dict() + out = {} for idx, ri in zip(unq_idx, unq_items): - ij = divmod(ri-1, self.nx) + ij = divmod(ri - 1, self.nx) out[ij] = np.stack((j[idx], i[idx]), axis=1) return out - def lookup_transform(self, data, grid=None, method=np.mean, lut=None, - return_lut=False): - """Performs the forward transformation of gridded data into self. + def lookup_transform( + self, + data: NDArray[Any], + grid: Grid | None = None, + method: Callable = np.mean, + lut: NDArray[Any] | None = None, + *, + return_lut: bool = False, + ) -> NDArray[Any]: + """Perform the forward transformation of gridded data into self. This method is suitable when the data grid is of higher resolution than ``self``. ``lookup_transform`` performs aggregation of data @@ -811,23 +916,26 @@ def lookup_transform(self, data, grid=None, method=np.mean, lut=None, ------- An aggregated ndarray of the data, in ``self`` coordinates. 
If ``return_lut==True``, also return the lookup table - """ + """ # Input checks if grid is None: grid = check_crs(data) # xarray if not isinstance(grid, Grid): - raise ValueError('grid should be a Grid instance') - if hasattr(data, 'values'): - data = data.values # xarray + msg = 'grid should be a Grid instance' + raise TypeError(msg) + if isinstance(data, (xr.DataArray, xr.Dataset)): + data = data.to_numpy() # xarray # dimensional check in_shape = data.shape ndims = len(in_shape) if (ndims < 2) or (ndims > 4): - raise ValueError('data dimension not accepted') + msg = 'Expected 2D, 3D or 4D data but got {}D'.format(ndims) + raise ValueError(msg) if (in_shape[-1] != grid.nx) or (in_shape[-2] != grid.ny): - raise ValueError('data dimension not compatible') + msg = 'data dimension not compatible' + raise ValueError(msg) if lut is None: lut = self.grid_lookup(grid) @@ -842,7 +950,7 @@ def lookup_transform(self, data, grid=None, method=np.mean, lut=None, dtype=float if data.dtype.kind == 'i' else data.dtype, ) - def _2d_trafo(ind, outd): + def _2d_trafo(ind: NDArray[Any], outd: NDArray[Any]) -> NDArray[Any]: for ji, l in lut.items(): outd[ji] = method(ind[l[:, 0], l[:, 1]]) return outd @@ -867,11 +975,16 @@ def _2d_trafo(ind, outd): if return_lut: return out_data, lut - else: - return out_data + return out_data - def map_gridded_data(self, data, grid=None, interp='nearest', - ks=3, out=None): + def map_gridded_data( + self, + data: NDArray[Any], + grid: Grid | str | None = None, + interp: str = 'nearest', + ks: int = 3, + out: NDArray[Any] | None = None, + ) -> NDArray[Any]: """Reprojects any structured data onto the local grid. The z and time dimensions of the data (if provided) are conserved, but @@ -902,27 +1015,24 @@ def map_gridded_data(self, data, grid=None, interp='nearest', Returns ------- A projected ndarray of the data, in ``self`` coordinates. 
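# Illustrative sketch (not part of the patch; `coarse` and `fine` are
# hypothetical Grid instances): typical `lookup_transform` calls. `method`
# is any reducer over the fine pixels falling into a coarse cell, and the
# lookup table can be reused across variables to amortize its cost:
#
#     avg = coarse.lookup_transform(fine_data, grid=fine)
#     cnt = coarse.lookup_transform(fine_data, grid=fine, method=len)
#     avg, lut = coarse.lookup_transform(fine_data, grid=fine,
#                                        return_lut=True)
#     std = coarse.lookup_transform(fine_data2, grid=fine,
#                                   method=np.std, lut=lut)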
- """ - if grid is None: - try: + """ + if grid is None and isinstance(data, (xr.DataArray, xr.Dataset)): + with contextlib.suppress(AttributeError): grid = data.salem.grid # try xarray - except AttributeError: - pass # Input checks if not isinstance(grid, Grid): - raise ValueError('grid should be a Grid instance') + msg = 'grid should be a Grid instance' + raise TypeError(msg) - try: # in case someone gave an xarray dataarray - data = data.values - except AttributeError: - pass + if isinstance(data, (xr.DataArray, xr.Dataset)): + with contextlib.suppress(AttributeError): + data = data.to_numpy() - try: # in case someone gave a masked array (won't work with scipy) + # in case someone gave a masked array (won't work with scipy) + with contextlib.suppress(AttributeError): data = data.filled(np.nan) - except AttributeError: - pass if data.dtype == np.float32: # New in scipy - issue with float32 @@ -931,9 +1041,11 @@ def map_gridded_data(self, data, grid=None, interp='nearest', in_shape = data.shape ndims = len(in_shape) if (ndims < 2) or (ndims > 4): - raise ValueError('data dimension not accepted') + msg = 'Expected 2D, 3D or 4D data but got {}D'.format(ndims) + raise ValueError(msg) if (in_shape[-1] != grid.nx) or (in_shape[-2] != grid.ny): - raise ValueError('data dimension not compatible') + msg = 'data dimension not compatible' + raise ValueError(msg) interp = interp.lower() @@ -945,10 +1057,12 @@ def map_gridded_data(self, data, grid=None, interp='nearest', # Work in center grid cause that's what we need # TODO: this stage could be optimized when many variables need transfo i, j = self.center_grid.ij_coordinates - oi, oj = grid.center_grid.transform(i, j, crs=self.center_grid, - nearest=use_nn, maskout=False) - pv = np.nonzero((oi >= 0) & (oi < grid.nx) & - (oj >= 0) & (oj < grid.ny)) + oi, oj = grid.center_grid.transform( + i, j, crs=self.center_grid, nearest=use_nn, maskout=False + ) + pv = np.nonzero( + (oi >= 0) & (oi < grid.nx) & (oj >= 0) & (oj < grid.ny) + ) # Prepare the output if out is not None: @@ -974,7 +1088,8 @@ def map_gridded_data(self, data, grid=None, interp='nearest', if interp == 'nearest': if out is not None: if ndims > 2: - raise ValueError('Need 2D for now.') + msg = 'Need 2D for now but got {}D'.format(ndims) + raise ValueError(msg) vok = np.isfinite(data[oj, oi]) out_data[j[vok], i[vok]] = data[oj[vok], oi[vok]] else: @@ -991,8 +1106,9 @@ def map_gridded_data(self, data, grid=None, interp='nearest', out_data[j, i] = f((oj, oi)) if ndims == 3: for dimi, cdata in enumerate(data): - f = RegularGridInterpolator(points, cdata, - bounds_error=False) + f = RegularGridInterpolator( + points, cdata, bounds_error=False + ) if out is not None: tmp = f((oj, oi)) vok = np.isfinite(tmp) @@ -1002,8 +1118,9 @@ def map_gridded_data(self, data, grid=None, interp='nearest', if ndims == 4: for dimj, cdata in enumerate(data): for dimi, ccdata in enumerate(cdata): - f = RegularGridInterpolator(points, ccdata, - bounds_error=False) + f = RegularGridInterpolator( + points, ccdata, bounds_error=False + ) if out is not None: tmp = f((oj, oi)) vok = np.isfinite(tmp) @@ -1040,20 +1157,27 @@ def map_gridded_data(self, data, grid=None, interp='nearest', else: out_data[dimj, dimi, j, i] = f(oj, oi, grid=False) else: - msg = 'interpolation not understood: {}'.format(interp) + msg = f'interpolation not understood: {interp}' raise ValueError(msg) # we have to catch a warning for an unexplained reason with warnings.catch_warnings(): - mess = "invalid value encountered in isfinite" - 
warnings.filterwarnings("ignore", message=mess) - out_data = np.ma.masked_invalid(out_data) - return out_data - - def region_of_interest(self, shape=None, geometry=None, grid=None, - corners=None, crs=wgs84, roi=None, - all_touched=False): - """Computes a region of interest (ROI). + mess = 'invalid value encountered in isfinite' + warnings.filterwarnings('ignore', message=mess) + return np.ma.masked_invalid(out_data) + + def region_of_interest( + self, + shape: Path | None = None, + geometry: BaseGeometry | None = None, + grid: Grid | None = None, + corners: tuple[float, float] | None = None, + crs: pyproj.Proj | Grid | str | xr.DataArray | None = wgs84, + roi: np.ndarray | None = None, + *, + all_touched: bool = False, + ) -> np.ndarray: + """Compute a region of interest (ROI). A ROI is simply a mask of the same size as the grid. @@ -1077,8 +1201,8 @@ def region_of_interest(self, shape=None, geometry=None, grid=None, pass-through argument for rasterio.features.rasterize, indicating that all grid cells which are clipped by the shapefile defining the region of interest should be included (default=False) - """ + """ # Initial mask if roi is not None: mask = np.array(roi, dtype=np.int16) @@ -1087,29 +1211,35 @@ def region_of_interest(self, shape=None, geometry=None, grid=None, # Collect keyword arguments, overriding anything the user # inadvertently added - rasterize_kws = dict(out=mask, all_touched=all_touched) + rasterize_kws = {'out': mask, 'all_touched': all_touched} # Several cases if shape is not None: import pandas as pd + inplace = False if not isinstance(shape, pd.DataFrame): from salem.sio import read_shapefile + shape = read_shapefile(shape) inplace = True # corner grid is needed for rasterio - shape = transform_geopandas(shape, to_crs=self.corner_grid, - inplace=inplace) + gps_shape = transform_geopandas( + shape, to_crs=self.corner_grid, inplace=inplace + ) import rasterio from rasterio.features import rasterize + with rasterio.Env(): - mask = rasterize(shape.geometry, **rasterize_kws) + mask = rasterize(gps_shape.geometry, **rasterize_kws) if geometry is not None: import rasterio from rasterio.features import rasterize + # corner grid is needed for rasterio - geom = transform_geometry(geometry, crs=crs, - to_crs=self.corner_grid) + geom = transform_geometry( + geometry, crs=crs, to_crs=self.corner_grid + ) with rasterio.Env(): mask = rasterize(np.atleast_1d(geom), **rasterize_kws) if grid is not None: @@ -1120,12 +1250,14 @@ def region_of_interest(self, shape=None, geometry=None, grid=None, xy0, xy1 = corners x0, y0 = cgrid.transform(*xy0, crs=crs, nearest=True) x1, y1 = cgrid.transform(*xy1, crs=crs, nearest=True) - mask[np.min([y0, y1]):np.max([y0, y1]) + 1, - np.min([x0, x1]):np.max([x0, x1]) + 1] = 1 + mask[ + np.min([y0, y1]) : np.max([y0, y1]) + 1, + np.min([x0, x1]) : np.max([x0, x1]) + 1, + ] = 1 return mask - def to_dict(self): + def to_dict(self) -> dict: """Serialize this grid to a dictionary. 
Returns @@ -1135,14 +1267,19 @@ def to_dict(self): See Also -------- from_dict : create a Grid from a dict - """ - return dict(proj=self.proj.srs, x0y0=(self.x0, self.y0), - nxny=(self.nx, self.ny), dxdy=(self.dx, self.dy), - pixel_ref=self.pixel_ref) - @classmethod - def from_dict(self, d): - """Create a Grid from a dictionary + """ + return { + 'proj': self.proj.srs, + 'x0y0': (self.x0, self.y0), + 'nxny': (self.nx, self.ny), + 'dxdy': (self.dx, self.dy), + 'pixel_ref': self.pixel_ref, + } + + @staticmethod + def from_dict(d: dict) -> Grid: + """Create a Grid from a dictionary. Parameters ---------- @@ -1156,10 +1293,11 @@ def from_dict(self, d): See Also -------- to_dict : create a dict from a Grid + """ return Grid(**d) - def to_json(self, fpath): + def to_json(self, fpath: Path) -> None: """Serialize this grid to a json file. Parameters @@ -1170,14 +1308,16 @@ def to_json(self, fpath): See Also -------- from_json : read a json file + """ import json - with open(fpath, 'w') as fp: + + with fpath.open('w') as fp: json.dump(self.to_dict(), fp) - @classmethod - def from_json(self, fpath): - """Create a Grid from a json file + @staticmethod + def from_json(fpath: Path) -> Grid: + """Create a Grid from a json file. Parameters ---------- @@ -1191,28 +1331,37 @@ def from_json(self, fpath): See Also -------- to_json : create a json file + """ import json - with open(fpath, 'r') as fp: + + with fpath.open() as fp: d = json.load(fp) return Grid.from_dict(d) - def to_dataset(self): - """Creates an empty dataset based on the Grid's geolocalisation. + def to_dataset(self) -> xr.Dataset: + """Create an empty dataset based on the Grid's geolocalisation. Returns ------- An xarray.Dataset object ready to be filled with data + """ import xarray as xr - ds = xr.Dataset(coords={'x': (['x', ], self.center_grid.x_coord), - 'y': (['y', ], self.center_grid.y_coord)} - ) + + ds = xr.Dataset( + coords={ + 'x': (['x'], self.center_grid.x_coord), + 'y': (['y'], self.center_grid.y_coord), + } + ) ds.attrs['pyproj_srs'] = self.proj.srs return ds - def to_geometry(self, to_crs=None): - """Makes a geometrical representation of the grid (e.g. for drawing). + def to_geometry( + self, to_crs: pyproj.Proj | Grid | None = None + ) -> BaseGeometry: + """Make a geometrical representation of the grid (e.g. for drawing). This can come also handy when doing shape-to-raster operations. @@ -1222,15 +1371,17 @@ def to_geometry(self, to_crs=None): Returns ------- a geopandas.GeoDataFrame + """ from geopandas import GeoDataFrame from shapely.geometry import Polygon + out = GeoDataFrame() geoms = [] ii = [] jj = [] - xx = self.corner_grid.x0 + np.arange(self.nx+1) * self.dx - yy = self.corner_grid.y0 + np.arange(self.ny+1) * self.dy + xx = self.corner_grid.x0 + np.arange(self.nx + 1) * self.dx + yy = self.corner_grid.y0 + np.arange(self.ny + 1) * self.dy for j, (y0, y1) in enumerate(zip(yy[:-1], yy[1:])): for i, (x0, x1) in enumerate(zip(xx[:-1], xx[1:])): coords = [(x0, y0), (x1, y0), (x1, y1), (x0, y1), (x0, y0)] @@ -1246,8 +1397,8 @@ def to_geometry(self, to_crs=None): return out -def proj_is_same(p1, p2): - """Checks is two pyproj projections are equal. +def proj_is_same(p1: pyproj.Proj, p2: pyproj.Proj) -> bool: + """Check is two pyproj projections are equal. 
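# Illustrative sketch (not part of the patch; assumes Grid equality,
# which salem defines on the grid's defining parameters): the pairs above
# are exact round trips. `to_json`/`from_json` now call `fpath.open()`,
# so they require a pathlib.Path rather than a str.
#
#     from pathlib import Path
#     assert Grid.from_dict(grid.to_dict()) == grid
#     grid.to_json(Path('grid.json'))
#     assert Grid.from_json(Path('grid.json')) == grid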
See https://github.com/jswhit/pyproj/issues/15#issuecomment-208862786 @@ -1257,38 +1408,61 @@ def proj_is_same(p1, p2): first projection p2 : pyproj.Proj second projection + """ if has_gdal: # this is more robust, but gdal is a pain + from osgeo import osr + s1 = osr.SpatialReference() s1.ImportFromProj4(p1.srs) s2 = osr.SpatialReference() s2.ImportFromProj4(p2.srs) return s1.IsSame(s2) == 1 # IsSame returns 1 or 0 - else: - # at least we can try to sort it - p1 = '+'.join(sorted(p1.srs.split('+'))) - p2 = '+'.join(sorted(p2.srs.split('+'))) - return p1 == p2 - - -def _transform_internal(p1, p2, x, y, **kwargs): + # at least we can try to sort it + p1_str = '+'.join(sorted(p1.srs.split('+'))) + p2_str = '+'.join(sorted(p2.srs.split('+'))) + return p1_str == p2_str + + +def _transform_internal( + p1: pyproj.Proj, + p2: pyproj.Proj, + x: np.ndarray, + y: np.ndarray, + *, + always_xy: bool = False, + area_of_interest: AreaOfInterest | None = None, + **kwargs: bool, +) -> tuple[Any, Any]: if hasattr(pyproj, 'Transformer'): - trf = pyproj.Transformer.from_proj(p1, p2, **kwargs) + trf = pyproj.Transformer.from_proj( + p1, p2, always_xy=always_xy, area_of_interest=area_of_interest + ) with warnings.catch_warnings(): # https://github.com/pyproj4/pyproj/issues/1415 - warnings.filterwarnings("ignore", category=DeprecationWarning, - message=".*ndim > 0 to a scalar.*") + warnings.filterwarnings( + 'ignore', + category=DeprecationWarning, + message='.*ndim > 0 to a scalar.*', + ) return trf.transform(x, y) else: return pyproj.transform(p1, p2, x, y, **kwargs) -def transform_proj(p1, p2, x, y, nocopy=False): - """Wrapper around the pyproj.transform function. +def transform_proj( + p1: pyproj.Proj, + p2: pyproj.Proj, + x: np.ndarray, + y: np.ndarray, + *, + nocopy: bool = False, +) -> tuple[np.ndarray, np.ndarray]: + """Transform points between two coordinate systems. - Transform points between two coordinate systems defined by the Proj - instances p1 and p2. + Wrapper around the pyproj.transform function. + The coordinate systems are defined by the Proj instances p1 and p2. When two projections are equal, this function avoids quite a bunch of useless calculations. See https://github.com/jswhit/pyproj/issues/15 @@ -1305,8 +1479,8 @@ def transform_proj(p1, p2, x, y, nocopy=False): northings nocopy : bool in case the two projections are equal, you can use nocopy if you wish - """ + """ try: # This always makes a copy, even if projections are equivalent return _transform_internal(p1, p2, x, y, always_xy=True) @@ -1314,13 +1488,16 @@ def transform_proj(p1, p2, x, y, nocopy=False): if proj_is_same(p1, p2): if nocopy: return x, y - else: - return copy.deepcopy(x), copy.deepcopy(y) + return copy.deepcopy(x), copy.deepcopy(y) return _transform_internal(p1, p2, x, y) -def transform_geometry(geom, crs=wgs84, to_crs=wgs84): +def transform_geometry( + geom: BaseGeometry, + crs: pyproj.Proj | Grid = wgs84, + to_crs: pyproj.Proj | Grid = wgs84, +) -> BaseGeometry: """Reprojects a shapely geometry. 
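# Illustrative sketch (not part of the patch): the modern pyproj call that
# `_transform_internal` wraps. `always_xy=True` pins the axis order to
# (lon, lat), matching the legacy pyproj.transform behaviour.
import numpy as np
import pyproj

p1 = pyproj.Proj(proj='latlong', datum='WGS84')
p2 = pyproj.Proj(proj='utm', zone=32, datum='WGS84')
trf = pyproj.Transformer.from_proj(p1, p2, always_xy=True)
easting, northing = trf.transform(np.array([10.0]), np.array([46.0]))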
Parameters @@ -1335,8 +1512,8 @@ def transform_geometry(geom, crs=wgs84, to_crs=wgs84): Returns ------- A reprojected geometry - """ + """ from_crs = check_crs(crs) to_crs = check_crs(to_crs) @@ -1347,13 +1524,20 @@ def transform_geometry(geom, crs=wgs84, to_crs=wgs84): elif isinstance(from_crs, Grid): project = partial(from_crs.ij_to_crs, crs=to_crs) else: - raise NotImplementedError() + raise NotImplementedError from shapely.ops import transform + return transform(project, geom) -def transform_geopandas(gdf, from_crs=None, to_crs=wgs84, inplace=False): +def transform_geopandas( + gdf: gpd.GeoDataFrame, + from_crs: pyproj.Proj | Grid | None = None, + to_crs: pyproj.Proj | Grid = wgs84, + *, + inplace: bool = False, +) -> gpd.GeoDataFrame: """Reprojects a geopandas dataframe. Parameters @@ -1370,20 +1554,18 @@ def transform_geopandas(gdf, from_crs=None, to_crs=wgs84, inplace=False): Returns ------- A projected dataframe + """ - from shapely.ops import transform import geopandas as gpd + from shapely.ops import transform - if from_crs is None: - from_crs = check_crs(gdf.crs) - else: - from_crs = check_crs(from_crs) + if gdf.crs is None and from_crs is None: + msg = 'You need to set from_crs or gdf needs a crs' + raise ValueError(msg) + from_crs = check_crs(gdf.crs) if from_crs is None else check_crs(from_crs) to_crs = check_crs(to_crs) - if inplace: - out = gdf - else: - out = gdf.copy() + out = gdf if inplace else gdf.copy() if isinstance(to_crs, pyproj.Proj) and isinstance(from_crs, pyproj.Proj): project = partial(transform_proj, from_crs, to_crs) @@ -1392,22 +1574,22 @@ def transform_geopandas(gdf, from_crs=None, to_crs=wgs84, inplace=False): elif isinstance(from_crs, Grid): project = partial(from_crs.ij_to_crs, crs=to_crs) else: - raise NotImplementedError() + raise NotImplementedError # Do the job and set the new attributes result = out.geometry.apply(lambda geom: transform(project, geom)) result.__class__ = gpd.GeoSeries if isinstance(to_crs, pyproj.Proj): - to_crs = to_crs.srs + to_crs_str = to_crs.srs elif isinstance(to_crs, Grid): - to_crs = None + to_crs_str = None out['geometry'] = result try: - out.set_crs(to_crs, allow_override=True, inplace=True) + out.set_crs(to_crs_str, allow_override=True, inplace=True) except ValueError: # Older versions of geopandas - out.crs = to_crs - out.geometry.crs = to_crs + out.crs = to_crs_str + out.geometry.crs = to_crs_str out['min_x'] = [g.bounds[0] for g in out.geometry] out['max_x'] = [g.bounds[2] for g in out.geometry] out['min_y'] = [g.bounds[1] for g in out.geometry] @@ -1415,17 +1597,16 @@ def transform_geopandas(gdf, from_crs=None, to_crs=wgs84, inplace=False): return out -def proj_is_latlong(proj): +def proj_is_latlong(proj: pyproj.Proj) -> bool: """Shortcut function because of deprecation.""" - try: - return proj.is_latlong() + return 'longlat' in proj.definition_string() except AttributeError: return proj.crs.is_geographic -def proj_to_cartopy(proj): - """Converts a pyproj.Proj to a cartopy.crs.Projection +def proj_to_cartopy(proj: pyproj.Proj | Grid) -> ccrs.Projection: + """Convert a pyproj.Proj to a cartopy.crs.Projection. 
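# Illustrative sketch (not part of the patch): the partial +
# shapely.ops.transform pattern that `transform_geometry` and
# `transform_geopandas` rely on, shown standalone.
from functools import partial

import pyproj
from shapely.geometry import Point
from shapely.ops import transform

from salem.gis import transform_proj

wgs84 = pyproj.Proj(proj='latlong', datum='WGS84')
utm32 = pyproj.Proj(proj='utm', zone=32, datum='WGS84')
project = partial(transform_proj, wgs84, utm32)  # called as project(x, y)
point_utm = transform(project, Point(10.0, 46.0))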
Parameters ---------- @@ -1437,65 +1618,71 @@ def proj_to_cartopy(proj): a cartopy.crs.Projection object """ - - import cartopy + if not has_cartopy: + msg = 'cartopy is not installed' + raise ImportError(msg) import cartopy.crs as ccrs - proj = check_crs(proj) + cproj = check_crs(proj) - if proj_is_latlong(proj): + if proj_is_latlong(cproj): return ccrs.PlateCarree() - srs = proj.srs + srs = cproj.srs if has_gdal: # this is more robust, as srs could be anything (espg, etc.) + from osgeo import osr + s1 = osr.SpatialReference() - s1.ImportFromProj4(proj.srs) + s1.ImportFromProj4(cproj.srs) if s1.ExportToProj4(): srs = s1.ExportToProj4() - km_proj = {'lon_0': 'central_longitude', - 'lat_0': 'central_latitude', - 'x_0': 'false_easting', - 'y_0': 'false_northing', - 'lat_ts': 'latitude_true_scale', - 'o_lon_p': 'central_rotated_longitude', - 'o_lat_p': 'pole_latitude', - 'k': 'scale_factor', - 'zone': 'zone', - } - km_globe = {'a': 'semimajor_axis', - 'b': 'semiminor_axis', - } - km_std = {'lat_1': 'lat_1', - 'lat_2': 'lat_2', - } - kw_proj = dict() - kw_globe = dict() - kw_std = dict() - for s in srs.split('+'): - s = s.split('=') + km_proj = { + 'lon_0': 'central_longitude', + 'lat_0': 'central_latitude', + 'x_0': 'false_easting', + 'y_0': 'false_northing', + 'lat_ts': 'latitude_true_scale', + 'o_lon_p': 'central_rotated_longitude', + 'o_lat_p': 'pole_latitude', + 'k': 'scale_factor', + 'zone': 'zone', + } + km_globe = { + 'a': 'semimajor_axis', + 'b': 'semiminor_axis', + } + km_std = { + 'lat_1': 'lat_1', + 'lat_2': 'lat_2', + } + kw_proj = {} + kw_globe = {} + kw_std = {} + cl = None + v = None + for i in srs.split('+'): + s = i.split('=') if len(s) != 2: continue k = s[0].strip() v = s[1].strip() - try: + with contextlib.suppress(Exception): v = float(v) - except: - pass if k == 'proj': if v == 'tmerc': cl = ccrs.TransverseMercator kw_proj['approx'] = True - if v == 'lcc': + elif v == 'lcc': cl = ccrs.LambertConformal - if v == 'merc': + elif v == 'merc': cl = ccrs.Mercator - if v == 'utm': + elif v == 'utm': cl = ccrs.UTM - if v == 'stere': + elif v == 'stere': cl = ccrs.Stereographic - if v == 'ob_tran': + elif v == 'ob_tran': cl = ccrs.RotatedPole if k in km_proj: if k == 'zone': @@ -1506,6 +1693,11 @@ def proj_to_cartopy(proj): if k in km_std: kw_std[km_std[k]] = v + if cl is None: + msg = 'Could not determine the projection type. {} not known.'.format( + v + ) + raise ValueError(msg) globe = None if kw_globe: globe = ccrs.Globe(ellipse='sphere', **kw_globe) @@ -1516,8 +1708,11 @@ def proj_to_cartopy(proj): if cl.__name__ == 'Mercator': kw_proj.pop('false_easting', None) kw_proj.pop('false_northing', None) - if Version(cartopy.__version__) < Version('0.15'): - kw_proj.pop('latitude_true_scale', None) + if has_cartopy: + import cartopy + + if Version(cartopy.__version__) < Version('0.15'): + kw_proj.pop('latitude_true_scale', None) elif cl.__name__ == 'Stereographic': kw_proj.pop('scale_factor', None) if 'latitude_true_scale' in kw_proj: @@ -1538,8 +1733,15 @@ def proj_to_cartopy(proj): return cl(globe=globe, **kw_proj) -def mercator_grid(center_ll=None, extent=None, ny=600, nx=None, - origin='lower-left', transverse=True): +def mercator_grid( + center_ll: tuple[float, float], + extent: tuple[float, float], + ny: int = 600, + nx: int | None = None, + origin: str = 'lower-left', + *, + transverse: bool = True, +) -> Grid: """Local (transverse) mercator map centered on a specified point. 
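# Illustrative sketch (not part of the patch; requires cartopy): how the
# km_proj/km_std tables above translate a proj4 string. A Lambert
# conformal definition maps onto cartopy.crs.LambertConformal, with
# `standard_parallels` taken from lat_1/lat_2.
#
#     import pyproj
#     from salem.gis import proj_to_cartopy
#     p = pyproj.Proj('+proj=lcc +lat_0=47 +lon_0=10 +lat_1=45 +lat_2=49')
#     crs = proj_to_cartopy(p)  # -> cartopy.crs.LambertConformal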
Parameters @@ -1558,45 +1760,64 @@ def mercator_grid(center_ll=None, extent=None, ny=600, nx=None, transverse : bool wether to use a transverse or regular mercator. Default should have been false, but for backwards compatibility reasons we keep it to True - """ + """ + # if nx is not None and ny is not None: + # msg = 'You cannot specify both nx and ny' + # raise ValueError(msg) # Make a local proj pname = 'tmerc' if transverse else 'merc' lon, lat = center_ll - proj_params = dict(proj=pname, lat_0=0., lon_0=lon, - k=0.9996, x_0=0, y_0=0, datum='WGS84') + proj_params = { + 'proj': pname, + 'lat_0': 0.0, + 'lon_0': lon, + 'k': 0.9996, + 'x_0': 0, + 'y_0': 0, + 'datum': 'WGS84', + } projloc = pyproj.Proj(proj_params) # Define a spatial resolution xx = extent[0] yy = extent[1] if nx is None: - nx = ny * xx / yy + nnx = np.rint(ny * xx / yy) + nny = np.rint(ny) else: - ny = nx * yy / xx - - nx = np.rint(nx) - ny = np.rint(ny) + nnx = np.rint(nx) + nny = np.rint(nx * yy / xx) - e, n = transform_proj(wgs84, projloc, lon, lat) + e, n = transform_proj(wgs84, projloc, np.array(lon), np.array(lat)) if origin == 'upper-left': - corner = (-xx / 2. + e, yy / 2. + n) - dxdy = (xx / nx, - yy / ny) + corner = (float(-xx / 2.0 + e), float(yy / 2.0 + n)) + dxdy = (xx / nnx, -yy / nny) else: - corner = (-xx / 2. + e, -yy / 2. + n) - dxdy = (xx / nx, yy / ny) - - return Grid(proj=projloc, x0y0=corner, nxny=(nx, ny), dxdy=dxdy, - pixel_ref='corner') - - -def googlestatic_mercator_grid(center_ll=None, nx=640, ny=640, zoom=12, scale=1): + corner = (float(-xx / 2.0 + e), float(-yy / 2.0 + n)) + dxdy = (xx / nnx, yy / nny) + + return Grid( + proj=projloc, + x0y0=corner, + nxny=(nnx, nny), + dxdy=dxdy, + pixel_ref='corner', + ) + + +def googlestatic_mercator_grid( + center_ll: tuple[float, float], + nx: int = 640, + ny: int = 640, + zoom: int = 12, + scale: int = 1, +) -> Grid: """Mercator map centered on a specified point (google API conventions). Mostly useful for google maps. """ - # Number of pixels in an image with a zoom level of 0. google_pix = 256 * scale # The equatorial radius of the Earth assuming WGS-84 ellipsoid. @@ -1615,10 +1836,10 @@ def googlestatic_mercator_grid(center_ll=None, nx=640, ny=640, zoom=12, scale=1) xx = nx * mpix yy = ny * mpix - e, n = transform_proj(wgs84, projloc, lon, lat) - corner = (-xx / 2. + e, yy / 2. + n) - dxdy = (xx / nx, - yy / ny) + e, n = transform_proj(wgs84, projloc, np.array(lon), np.array(lat)) + corner = (-xx / 2.0 + e, yy / 2.0 + n) + dxdy = (xx / nx, -yy / ny) - return Grid(proj=projloc, x0y0=corner, - nxny=(nx, ny), dxdy=dxdy, - pixel_ref='corner') + return Grid( + proj=projloc, x0y0=corner, nxny=(nx, ny), dxdy=dxdy, pixel_ref='corner' + ) diff --git a/salem/graphics.py b/salem/graphics.py index 9f61582..e4a518c 100644 --- a/salem/graphics.py +++ b/salem/graphics.py @@ -1,53 +1,74 @@ -""" -Color handling and maps. 
-""" -from __future__ import division +"""Color handling and maps.""" # Builtins -import warnings -import os -from os import path +from __future__ import annotations + +import contextlib import copy -# External libs +import warnings +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal + import numpy as np +# External libs +from matplotlib.colors import Normalize + try: from skimage.transform import resize as imresize + has_skimage = True except ImportError: has_skimage = False -import pandas as pd - import matplotlib as mpl import matplotlib.pyplot as plt +import pandas as pd +import xarray as xr +from matplotlib.collections import LineCollection, PatchCollection from matplotlib.colors import LinearSegmentedColormap -from mpl_toolkits.axes_grid1 import make_axes_locatable -from matplotlib.collections import PatchCollection, LineCollection -from shapely.geometry import MultiPoint, LineString, Polygon -from salem.descartes import PolygonPatch from matplotlib.transforms import Transform as MPLTranform +from mpl_toolkits.axes_grid1 import make_axes_locatable +from shapely.geometry import LineString, MultiPoint, Polygon -from salem import utils, gis, sio, Grid, wgs84, sample_data_dir, GeoTiff +from salem import GeoTiff, Grid, gis, sample_data_dir, sio, utils, wgs84 +from salem.descartes import PolygonPatch -shapefiles = dict() -shapefiles['world_borders'] = path.join(sample_data_dir, 'shapes', - 'world_borders', 'world_borders.shp') -shapefiles['oceans'] = path.join(sample_data_dir, 'shapes', 'oceans', - 'ne_50m_ocean.shp') -shapefiles['rivers'] = path.join(sample_data_dir, 'shapes', 'rivers', - 'ne_50m_rivers_lake_centerlines.shp') -shapefiles['lakes'] = path.join(sample_data_dir, 'shapes', 'lakes', - 'ne_50m_lakes.shp') +if TYPE_CHECKING: + import pyproj + from matplotlib.artist import Artist + from matplotlib.axes import Axes + from matplotlib.colorbar import ColorbarBase + from numpy.typing import NDArray + +shapefiles = {} +shapefiles['world_borders'] = ( + sample_data_dir / 'shapes' / 'world_borders' / 'world_borders.shp' +) +shapefiles['oceans'] = ( + sample_data_dir / 'shapes' / 'oceans' / 'ne_50m_ocean.shp' +) +shapefiles['rivers'] = ( + sample_data_dir + / 'shapes' + / 'rivers' + / 'ne_50m_rivers_lake_centerlines.shp' +) +shapefiles['lakes'] = sample_data_dir / 'shapes' / 'lakes' / 'ne_50m_lakes.shp' # Be sure we have the directory -if not os.path.exists(shapefiles['world_borders']): +if not (sample_data_dir / 'shapes' / 'world_borders').exists(): from salem.utils import get_demo_file + _ = get_demo_file('world_borders.shp') +ExtendChoices = Literal['neither', 'both', 'min', 'max'] | None +InterpChoices = Literal['nearest', 'linear', 'spline'] + + class ExtendedNorm(mpl.colors.BoundaryNorm): - """ A better BoundaryNorm with an ``extend'' keyword. + """A better BoundaryNorm with an ``extend'' keyword. 
TODO: remove this when PR is accepted @@ -55,22 +76,21 @@ class ExtendedNorm(mpl.colors.BoundaryNorm): https://github.com/matplotlib/matplotlib/pull/5034 """ - def __init__(self, boundaries, ncolors, extend='neither'): - + def __init__(self, boundaries, ncolors, extend='neither') -> None: _b = np.atleast_1d(boundaries).astype(float) mpl.colors.BoundaryNorm.__init__(self, _b, ncolors, clip=False) # 'neither' | 'both' | 'min' | 'max' if extend == 'both': - _b = np.append(_b, _b[-1]+1) - _b = np.insert(_b, 0, _b[0]-1) + _b = np.append(_b, _b[-1] + 1) + _b = np.insert(_b, 0, _b[0] - 1) elif extend == 'min': - _b = np.insert(_b, 0, _b[0]-1) + _b = np.insert(_b, 0, _b[0] - 1) elif extend == 'max': - _b = np.append(_b, _b[-1]+1) + _b = np.append(_b, _b[-1] + 1) self._b = _b self._N = len(self._b) - if self._N - 1 == self.Ncmap: + if self.Ncmap == self._N - 1: self._interp = False else: self._interp = True @@ -93,14 +113,13 @@ def __call__(self, value): return ret -def get_cmap(cmap='viridis'): +def get_cmap(cmap: str = 'viridis') -> mpl.colors.Colormap: """Get a colormap from mpl, and also those defined by Salem. Currently we have: topo, dem, nrwc see https://github.com/fmaussion/salem-sample-data/tree/master/colormaps """ - try: return plt.get_cmap(cmap) except ValueError: @@ -108,7 +127,7 @@ def get_cmap(cmap='viridis'): return LinearSegmentedColormap.from_list(cmap, cl, N=256) -class DataLevels(object): +class DataLevels: """Assigns the right color to your data. Simple tool that ensures the full compatibility of the plot @@ -118,20 +137,21 @@ class DataLevels(object): def __init__( self, - data=None, - levels=None, - nlevels=None, - vmin=None, - vmax=None, - extend=None, - cmap=None, - norm=None, - ): + data: NDArray[Any] | None = None, + levels: NDArray[Any] | None = None, + nlevels: int | None = None, + vmin: float | None = None, + vmax: float | None = None, + extend: ExtendChoices = None, + cmap: str | None = None, + norm: Normalize | None = None, + ) -> None: """Instanciate. Parameters ---------- see the set_* functions + """ self.set_data(data) self.set_levels(levels) @@ -142,64 +162,63 @@ def __init__( self.set_cmap(cmap) self.set_norm(norm) - def update(self, d): - """ - Update the properties of :class:`DataLevels` from the dictionary *d*. - """ - + def update(self, d: dict) -> None: + """Update the properties of :class:`DataLevels` from the dictionary *d*.""" for k, v in d.items(): func = getattr(self, 'set_' + k, None) if func is None or not callable(func): - raise AttributeError('Unknown property %s' % k) + msg = 'Unknown property {}'.format(k) + raise AttributeError(msg) func(v) - def set_data(self, data=None): + def set_data(self, data: NDArray[Any] | None = None) -> None: """Any kind of data array (also masked).""" if data is not None: self.data = np.ma.masked_invalid(np.atleast_1d(data), copy=False) else: - self.data = np.ma.asarray([0., 1.]) + self.data = np.ma.asarray([0.0, 1.0]) - def set_levels(self, levels=None): + def set_levels(self, levels: NDArray[Any] | None = None) -> None: """Levels you define. Must be monotically increasing.""" self._levels = levels - def set_nlevels(self, nlevels=None): + def set_nlevels(self, nlevels: int | None = None) -> None: """Automatic N levels. Ignored if set_levels has been set.""" self._nlevels = nlevels - def set_vmin(self, val=None): + def set_vmin(self, val: float | None = None) -> None: """Mininum level value. 
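# Illustrative sketch (not part of the patch): minimal use of the setters
# defined above, with the new keyword-style signature.
import numpy as np

from salem.graphics import DataLevels

dl = DataLevels(np.array([1.0, 2.5, 4.0]), vmin=0, vmax=5, nlevels=6)
rgba = dl.to_rgb()  # one RGBA quadruple per data point
dl.update({'cmap': 'viridis', 'extend': 'max'})  # dispatches to set_*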
Ignored if set_levels has been set.""" self._vmin = val - def set_vmax(self, val=None): + def set_vmax(self, val: float | None = None) -> None: """Maximum level value. Ignored if set_levels has been set.""" self._vmax = val - def set_cmap(self, cm=None): + def set_cmap(self, cm: str | None = None) -> None: """Set a colormap.""" self.cmap = get_cmap(cm or 'viridis') - def set_norm(self, norm=None): + def set_norm(self, norm: Normalize | None = None) -> None: """Set a normalization function. Related parameters will be ignored if set. - (e.g., vmin and vmax will be ignored if using LogNorm)""" + (e.g., vmin and vmax will be ignored if using LogNorm) + """ self._norm = norm - def set_extend(self, extend=None): - """Colorbar extensions: 'neither' | 'both' | 'min' | 'max'""" + def set_extend(self, extend: ExtendChoices = None) -> None: + """Set colorbar extensions: 'neither' | 'both' | 'min' | 'max'.""" self._extend = extend def set_plot_params( self, - levels=None, - nlevels=None, - vmin=None, - vmax=None, - extend=None, - cmap=None, - norm=None, - ): + levels: NDArray[Any] | None = None, + nlevels: int | None = None, + vmin: float | None = None, + vmax: float | None = None, + extend: ExtendChoices = None, + cmap: str | None = None, + norm: Normalize | None = None, + ) -> None: """Shortcut to all parameters related to the plot. As a side effect, running set_plot_params() without arguments will @@ -215,7 +234,7 @@ def set_plot_params( self.set_norm(norm) @property - def levels(self): + def levels(self) -> NDArray[Any]: """Clever getter.""" levels = self._levels nlevels = self._nlevels @@ -223,41 +242,38 @@ def levels(self): self.set_vmin(levels[0]) self.set_vmax(levels[-1]) return levels - else: - if nlevels is None: - if self.extend in ['max', 'min']: - nlevels = self.cmap.N - 1 - elif self.extend in ['both']: - nlevels = self.cmap.N - 2 - else: - nlevels = self.cmap.N - if self.vmax == self.vmin: - return np.linspace(self.vmin, self.vmax+1, nlevels) - return np.linspace(self.vmin, self.vmax, nlevels) + if nlevels is None: + if self.extend in ['max', 'min']: + nlevels = self.cmap.N - 1 + elif self.extend in ['both']: + nlevels = self.cmap.N - 2 + else: + nlevels = self.cmap.N + if self.vmax == self.vmin: + return np.linspace(self.vmin, self.vmax + 1, nlevels) + return np.linspace(self.vmin, self.vmax, nlevels) @property - def nlevels(self): + def nlevels(self) -> int: """Clever getter.""" return len(self.levels) @property - def vmin(self): + def vmin(self) -> float: """Clever getter.""" if self._vmin is None: return np.min(self.data) - else: - return self._vmin + return self._vmin @property - def vmax(self): + def vmax(self) -> float: """Clever getter.""" if self._vmax is None: return np.max(self.data) - else: - return self._vmax + return self._vmax @property - def extend(self): + def extend(self) -> str: """Clever getter.""" if self._extend is None: # If the user didnt set it, we decide @@ -271,65 +287,77 @@ def extend(self): else: out = 'neither' return out - else: - return self._extend + return self._extend @property - def norm(self): + def norm(self) -> Normalize: """Clever getter.""" - l = self.levels - e = self.extend + lev = self.levels + ext = self.extend if self._norm is None: # Warnings - if e not in ["both", "min"] and (np.min(l) > np.min(self.data)): - warnings.warn("Minimum data out of bounds.", RuntimeWarning) - if e not in ["both", "max"] and (np.max(l) < np.max(self.data)): - warnings.warn("Maximum data out of bounds.", RuntimeWarning) + if ext not in ['both', 'min'] and ( + 
np.min(lev) > np.min(self.data) + ): + warnings.warn( + 'Minimum data out of bounds.', RuntimeWarning, stacklevel=1 + ) + if ext not in ['both', 'max'] and ( + np.max(lev) < np.max(self.data) + ): + warnings.warn( + 'Maximum data out of bounds.', RuntimeWarning, stacklevel=1 + ) try: # Added in mpl 3.3.0 - return mpl.colors.BoundaryNorm(l, self.cmap.N, extend=e) + return mpl.colors.BoundaryNorm(lev, self.cmap.N, extend=ext) except TypeError: - return ExtendedNorm(l, self.cmap.N, extend=e) + return ExtendedNorm(lev, self.cmap.N, extend=ext) else: return self._norm - def to_rgb(self): + def to_rgb(self) -> NDArray[Any]: """Transform the data to RGB triples.""" - if np.all(self.data.mask): # unfortunately the functions below can't handle this one - return np.zeros(self.data.shape + (4, )) + return np.zeros((*self.data.shape, 4)) return self.cmap(self.norm(self.data)) - def get_colorbarbase_kwargs(self): + def get_colorbarbase_kwargs(self) -> dict: """If you need to make a colorbar based on a given DataLevel state.""" - # This is a discutable choice: with more than 60 colors (could be # less), we assume a continuous colorbar. if self.nlevels < 60 or self._norm is not None: norm = self.norm else: norm = mpl.colors.Normalize(vmin=self.vmin, vmax=self.vmax) - return dict(extend=self.extend, cmap=self.cmap, norm=norm) + return {'extend': self.extend, 'cmap': self.cmap, 'norm': norm} - def colorbarbase(self, cax, **kwargs): - """Returns a ColorbarBase to add to the cax axis. All keywords are - passed to matplotlib.colorbar.ColorbarBase - """ + def colorbarbase(self, cax: Axes, **kwargs) -> ColorbarBase: + """Return a ColorbarBase to add to the cax axis. + All keywords are passed to matplotlib.colorbar.ColorbarBase + """ # This is a discutable choice: with more than 60 colors (could be # less), we assume a continuous colorbar. if self.nlevels < 60 or self._norm is not None: norm = self.norm else: norm = mpl.colors.Normalize(vmin=self.vmin, vmax=self.vmax) - return mpl.colorbar.ColorbarBase(cax, extend=self.extend, - cmap=self.cmap, norm=norm, **kwargs) + return mpl.colorbar.ColorbarBase( + cax, extend=self.extend, cmap=self.cmap, norm=norm, **kwargs + ) - def append_colorbar(self, ax, position='right', size='5%', pad=0.5, - **kwargs): - """Appends a colorbar to existing axes + def append_colorbar( + self, + ax: Axes, + position: str = 'right', + size: float = 0.05, + pad: float = 0.5, + **kwargs, + ) -> ColorbarBase: + """Append a colorbar to existing axes. It uses matplotlib's make_axes_locatable toolkit. @@ -340,15 +368,15 @@ def append_colorbar(self, ax, position='right', size='5%', pad=0.5, size: the size of the colorbar (e.g. in % of the ax) pad: pad between axes given in inches or tuple-like of floats, (horizontal padding, vertical padding) - """ + """ orientation = 'horizontal' if position in ['left', 'right']: orientation = 'vertical' cax = make_axes_locatable(ax).append_axes(position, size=size, pad=pad) return self.colorbarbase(cax, orientation=orientation, **kwargs) - def plot(self, ax): + def plot(self, ax: Axes) -> Artist: """Add a kind of plot of the data to an axis. More useful for child classes. 
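# Illustrative sketch (not part of the patch; `dl` is a DataLevels
# instance, the axes are hypothetical): the colorbar helpers above in use.
# Note that `size` is now a float fraction (0.05) rather than the
# mpl-style '5%' string.
#
#     import matplotlib.pyplot as plt
#     fig, ax = plt.subplots()
#     dl.plot(ax)
#     dl.append_colorbar(ax, position='right', size=0.05, pad=0.2)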
@@ -357,11 +385,18 @@ def plot(self, ax): """ data = np.atleast_2d(self.data) toplot = self.cmap(self.norm(data)) - primitive = ax.imshow(toplot, interpolation='none', origin='lower') - return primitive + return ax.imshow(toplot, interpolation='none', origin='lower') - def visualize(self, ax=None, title=None, orientation='vertical', - add_values=False, addcbar=True, cbar_title=''): + def visualize( + self, + ax: Axes | None = None, + title: str | None = None, + orientation: str = 'vertical', + cbar_title: str = '', + *, + add_values: bool = False, + addcbar: bool = True, + ) -> Artist: """Quick plot, useful for debugging. Parameters @@ -372,8 +407,8 @@ def visualize(self, ax=None, title=None, orientation='vertical', add_values: add the data values as text in the pixels (for testing) Returns a dict containing the primitives of the various plot calls - """ + """ # Do we make our own fig? if ax is None: ax = plt.gca() @@ -385,20 +420,28 @@ def visualize(self, ax=None, title=None, orientation='vertical', addcbar = (self.vmin != self.vmax) and addcbar if addcbar: if orientation == 'horizontal': - self.append_colorbar(ax, "top", size=0.2, pad=0.5, - label=cbar_title) + self.append_colorbar( + ax, 'top', size=0.2, pad=0.5, label=cbar_title + ) else: - self.append_colorbar(ax, "right", size="5%", pad=0.2, - label=cbar_title) + self.append_colorbar( + ax, 'right', size=0.05, pad=0.2, label=cbar_title + ) # Mini add-on if add_values: data = np.atleast_2d(self.data) - x, y = np.meshgrid(np.arange(data.shape[1]), - np.arange(data.shape[0])) + x, y = np.meshgrid( + np.arange(data.shape[1]), np.arange(data.shape[0]) + ) for v, i, j in zip(data.flatten(), x.flatten(), y.flatten()): - ax.text(i, j, v, horizontalalignment='center', - verticalalignment='center') + ax.text( + i, + j, + v, + horizontalalignment='center', + verticalalignment='center', + ) # Details if title is not None: @@ -425,8 +468,16 @@ class Map(DataLevels): regional maps. """ - def __init__(self, grid, nx=500, ny=None, factor=None, - countries=True, **kwargs): + def __init__( + self, + grid: Grid, + nx: int | None = 500, + ny: int | None = None, + factor: float | None = None, + *, + countries: bool = True, + **kwargs, + ) -> None: """Make a new map. Parameters @@ -445,8 +496,8 @@ def __init__(self, grid, nx=500, ny=None, factor=None, it later with a call to set_shapefile) kwargs: ** all keywords accepted by DataLevels - """ + """ if factor is not None: nx = None ny = None @@ -466,67 +517,80 @@ def __init__(self, grid, nx=500, ny=None, factor=None, self._contourf_data = None self._contour_data = None - def _check_data(self, data=None, crs=None, interp='nearest', - overplot=False): + def _check_data( + self, + data: xr.DataArray | NDArray[Any], + crs: pyproj.Proj | Grid | str | xr.DataArray | None = None, + interp: InterpChoices = 'nearest', + *, + overplot: bool = False, + ) -> xr.DataArray | NDArray[Any]: """Interpolates the data to the map grid.""" - - if crs is None: + if crs is None and isinstance(data, xr.DataArray): # try xarray # TODO: note that this might slow down the plotting a bit # if the data already matches the grid... - try: + with contextlib.suppress(Exception): crs = data.salem.grid - except: - pass data = np.ma.fix_invalid(np.squeeze(data)) shp = data.shape if len(shp) != 2: - raise ValueError('Data should be 2D.') + msg = 'Data should be 2D.' 
+ raise ValueError(msg) if crs is None: # Reform case, but with a sanity check - if not np.isclose(shp[0] / shp[1], self.grid.ny / self.grid.nx, - atol=1e-2): - raise ValueError('Dimensions of data do not match the map.') + if not np.isclose( + shp[0] / shp[1], self.grid.ny / self.grid.nx, atol=1e-2 + ): + msg = 'Dimensions of data do not match the map.' + raise ValueError(msg) # need to resize if not same if not ((shp[0] == self.grid.ny) and (shp[1] == self.grid.nx)): - # We convert to float for img resizing if data.dtype not in [np.float32, np.float64]: data = data.astype(np.float64) - if interp.lower() == 'nearest': - interp = 0 - elif interp.lower() == 'linear': - interp = 1 - elif interp.lower() == 'spline': - interp = 3 + interp_dict = {'nearest': 0, 'linear': 1, 'spline': 3} + interpolation = interp_dict[interp] + if not has_skimage: - raise ImportError('Needs scikit-image to be installed.') + msg = 'Needs scikit-image to be installed.' + raise ImportError(msg) + from skimage.transform import resize as imresize + with warnings.catch_warnings(): - mess = "invalid value encountered in reduce" - warnings.filterwarnings("ignore", message=mess) - mess = "All-NaN slice encountered" - warnings.filterwarnings("ignore", message=mess) - mess = ("Possible precision loss when converting from " - "int64 to float64") - warnings.filterwarnings("ignore", message=mess) - mess = "Passing `np.nan` to mean no clipping in np.clip" - warnings.filterwarnings("ignore", message=mess) + msg = 'invalid value encountered in reduce' + warnings.filterwarnings('ignore', message=msg) + msg = 'All-NaN slice encountered' + warnings.filterwarnings('ignore', message=msg) + msg = ( + 'Possible precision loss when converting from ' + 'int64 to float64' + ) + warnings.filterwarnings('ignore', message=msg) + msg = 'Passing `np.nan` to mean no clipping in np.clip' + warnings.filterwarnings('ignore', message=msg) nans = data.filled(np.nan) try: - data = imresize(nans, - (self.grid.ny, self.grid.nx), - order=interp, mode='edge', - anti_aliasing=True) + data = imresize( + nans, + (self.grid.ny, self.grid.nx), + order=interpolation, + mode='edge', + anti_aliasing=True, + ) except RuntimeError: # For some order anti_aliasing doesnt work with 'edge' - data = imresize(nans, - (self.grid.ny, self.grid.nx), - order=interp, mode='edge', - anti_aliasing=False) + data = imresize( + nans, + (self.grid.ny, self.grid.nx), + order=interpolation, + mode='edge', + anti_aliasing=False, + ) return data @@ -534,18 +598,28 @@ def _check_data(self, data=None, crs=None, interp='nearest', if isinstance(crs, Grid): # Remap if overplot: - data = self.grid.map_gridded_data(data, crs, interp=interp, - out=self.data) + data = self.grid.map_gridded_data( + data, crs, interp=interp, out=self.data + ) else: data = self.grid.map_gridded_data(data, crs, interp=interp) else: - raise ValueError('crs should be a grid, not a proj') + msg = 'crs should be a grid, not a proj' + raise TypeError(msg) return data - def set_data(self, data=None, crs=None, interp='nearest', - overplot=False): - """Adds data to the plot. The data has to be georeferenced, i.e. by + def set_data( + self, + data: NDArray[Any] | xr.DataArray | None = None, + crs: str | pyproj.Proj | Grid | None = None, + interp: InterpChoices = 'nearest', + *, + overplot: bool = False, + ) -> None: + """Add data to the plot. + + The data has to be georeferenced, i.e. 
by setting crs (if omitted the data is assumed to be defined on the map's grid) @@ -556,40 +630,53 @@ def set_data(self, data=None, crs=None, interp='nearest', interp: 'nearest' (default) or 'linear', the interpolation algorithm overplot: add the data to an existing plot (useful for mosaics for example) - """ + """ # Check input if data is None: self.data = np.ma.zeros((self.grid.ny, self.grid.nx)) self.data.mask = self.data + 1 return - data = self._check_data(data=data, crs=crs, interp=interp, - overplot=overplot) + data = self._check_data( + data=data, crs=crs, interp=interp, overplot=overplot + ) DataLevels.set_data(self, data) - def set_contourf(self, data=None, crs=None, interp='nearest', **kwargs): - """Adds data to contourfill on the map. + def set_contourf( + self, + data: NDArray[Any] | xr.DataArray | None = None, + crs: pyproj.Proj | Grid | str | None = None, + interp: InterpChoices = 'nearest', + **kwargs, + ) -> None: + """Add data to contourfill on the map. Parameters ---------- - mask: bool array (2d) crs: the data coordinate reference system interp: 'nearest' (default) or 'linear', the interpolation algorithm kwargs: anything accepted by contourf - """ + """ # Check input if data is None: self._contourf_data = None return - self._contourf_data = self._check_data(data=data, crs=crs, - interp=interp) + self._contourf_data = self._check_data( + data=data, crs=crs, interp=interp + ) kwargs.setdefault('zorder', 1.4) self._contourf_kw = kwargs - def set_contour(self, data=None, crs=None, interp='nearest', **kwargs): - """Adds data to contour on the map. + def set_contour( + self, + data: NDArray[Any] | xr.DataArray | None = None, + crs: pyproj.Proj | Grid | str | None = None, + interp: InterpChoices = 'nearest', + **kwargs, + ) -> None: + """Add data to contour on the map. Parameters ---------- @@ -597,20 +684,28 @@ def set_contour(self, data=None, crs=None, interp='nearest', **kwargs): crs: the data coordinate reference system interp: 'nearest' (default) or 'linear', the interpolation algorithm kwargs: anything accepted by contour - """ + """ # Check input if data is None: self._contour_data = None return - self._contour_data = self._check_data(data=data, crs=crs, - interp=interp) + self._contour_data = self._check_data( + data=data, crs=crs, interp=interp + ) kwargs.setdefault('zorder', 1.4) self._contour_kw = kwargs - def set_geometry(self, geometry=None, crs=wgs84, text=None, - text_delta=(0.01, 0.01), text_kwargs=dict(), **kwargs): + def set_geometry( + self, + geometry=None, + crs=wgs84, + text=None, + text_delta=(0.01, 0.01), + text_kwargs=dict(), + **kwargs, + ) -> None: """Adds any Shapely geometry to the map. If called without arguments, it removes all previous geometries. @@ -633,15 +728,15 @@ def set_geometry(self, geometry=None, crs=wgs84, text=None, facecolor, linestyle, linewidth, alpha... """ - # Reset? 
if geometry is None: self._geometries = [] return # Transform - geom = gis.transform_geometry(geometry, crs=crs, - to_crs=self.grid.center_grid) + geom = gis.transform_geometry( + geometry, crs=crs, to_crs=self.grid.center_grid + ) # Text if text is not None: @@ -649,8 +744,7 @@ def set_geometry(self, geometry=None, crs=wgs84, text=None, x = x[0] + text_delta[0] * self.grid.nx sign = self.grid.dy / np.abs(self.grid.dy) y = y[0] + text_delta[1] * self.grid.ny * sign - self.set_text(x, y, text, crs=self.grid.center_grid, - **text_kwargs) + self.set_text(x, y, text, crs=self.grid.center_grid, **text_kwargs) # Save if 'Multi' in geom.geom_type: @@ -673,7 +767,6 @@ def set_text(self, x=None, y=None, text='', crs=wgs84, **kwargs): Keyword arguments will be passed to mpl's text() function. """ - # Reset? if x is None: self._text = [] @@ -685,8 +778,16 @@ def set_text(self, x=None, y=None, text='', crs=wgs84, **kwargs): x, y = self.grid.center_grid.transform(x, y, crs=crs) self._text.append((x, y, text, kwargs)) - def set_shapefile(self, shape=None, countries=False, oceans=False, - rivers=False, lakes=False, **kwargs): + def set_shapefile( + self, + shape: Path | str | None = None, + *, + countries: bool = False, + oceans: bool = False, + rivers: bool = False, + lakes: bool = False, + **kwargs, + ) -> None: """Add a shapefile to the plot. Salem is shipped with a few default settings for country borders, @@ -706,8 +807,8 @@ def set_shapefile(self, shape=None, countries=False, oceans=False, linewidths, colors, linestyles, ... For Polygons:: alpha, edgecolor, facecolor, fill, linestyle, linewidth, color, ... - """ + """ # See if the user wanted defaults settings if oceans: kwargs.setdefault('facecolor', (0.36862745, 0.64313725, 0.8)) @@ -731,7 +832,7 @@ def set_shapefile(self, shape=None, countries=False, oceans=False, # Reset? 
         if shape is None:
             self._collections = []
-            return
+            return None
 
         # Transform
         if isinstance(shape, pd.DataFrame):
@@ -739,7 +840,7 @@ def set_shapefile(self, shape=None, countries=False, oceans=False,
         else:
             shape = sio.read_shapefile_to_grid(shape, grid=self.grid)
         if len(shape) == 0:
-            return
+            return None
 
         # Different collection for each type
         geomtype = shape.iloc[0].geometry.geom_type
@@ -747,49 +848,68 @@ def set_shapefile(self, shape=None, countries=False, oceans=False,
             patches = []
             for g in shape.geometry:
                 if 'Multi' in g.geom_type:
-                    for gg in g.geoms:
-                        patches.append(PolygonPatch(gg))
+                    patches.extend(PolygonPatch(gg) for gg in g.geoms)
                 else:
-                    patches.append(PolygonPatch(g))
+                    patches.append(PolygonPatch(g))
             kwargs.setdefault('facecolor', 'none')
             if 'color' in kwargs:
                 kwargs.setdefault('edgecolor', kwargs['color'])
                 del kwargs['color']
             self._collections.append(PatchCollection(patches, **kwargs))
-        elif 'LineString' in geomtype:
+            return None
+        if 'LineString' in geomtype:
             lines = []
             for g in shape.geometry:
                 if 'Multi' in g.geom_type:
-                    for gg in g.geoms:
-                        lines.append(np.array(gg.coords))
+                    lines.extend(np.array(gg.coords) for gg in g.geoms)
                 else:
-                    lines.append(np.array(g.coords))
+                    lines.append(np.array(g.coords))
             self._collections.append(LineCollection(lines, **kwargs))
-        else:
-            raise NotImplementedError(geomtype)
+            return None
+        raise NotImplementedError(geomtype)
 
     def _find_interval(self, max_nticks):
         """Quick n dirty function to find a suitable lonlat interval."""
-        candidates = [0.001, 0.002, 0.005,
-                      0.01, 0.02, 0.05,
-                      0.1, 0.2, 0.5,
-                      1, 2, 5, 10, 20]
+        candidates = [
+            0.001,
+            0.002,
+            0.005,
+            0.01,
+            0.02,
+            0.05,
+            0.1,
+            0.2,
+            0.5,
+            1,
+            2,
+            5,
+            10,
+            20,
+        ]
         xx, yy = self.grid.pixcorner_ll_coordinates
         for inter in candidates:
             _xx = xx / inter
             _yy = yy / inter
             mm_x = [np.ceil(np.min(_xx)), np.floor(np.max(_xx))]
             mm_y = [np.ceil(np.min(_yy)), np.floor(np.max(_yy))]
-            nx = mm_x[1]-mm_x[0]+1
-            ny = mm_y[1]-mm_y[0]+1
+            nx = mm_x[1] - mm_x[0] + 1
+            ny = mm_y[1] - mm_y[0] + 1
             if np.max([nx, ny]) <= max_nticks:
                 break
         return inter
 
-    def set_lonlat_contours(self, interval=None, xinterval=None,
-                            yinterval=None, add_tick_labels=True,
-                            add_xtick_labels=True, add_ytick_labels=True,
-                            max_nticks=8, **kwargs):
+    def set_lonlat_contours(
+        self,
+        interval: float | None = None,
+        xinterval: float | None = None,
+        yinterval: float | None = None,
+        *,
+        add_tick_labels: bool = True,
+        add_xtick_labels: bool = True,
+        add_ytick_labels: bool = True,
+        max_nticks: int = 8,
+        **kwargs,
+    ) -> None:
         """Add longitude and latitude contours to the map.
 
         Calling it with interval=0 will remove all contours.
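# Illustrative sketch (not part of the patch; `smap` is a hypothetical
# salem.Map): the contour controls defined above. The add_* switches are
# now keyword-only.
#
#     smap.set_lonlat_contours(interval=2)           # every 2 degrees
#     smap.set_lonlat_contours(xinterval=5, yinterval=2.5)
#     smap.set_lonlat_contours(interval=0)           # remove the contours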
@@ -813,8 +933,8 @@ def set_lonlat_contours(self, interval=None, xinterval=None, Ignore if ``interval`` is set to a value kwargs : {} any keyword accepted by contour() - """ + """ # Defaults if interval is None: interval = self._find_interval(max_nticks) @@ -835,10 +955,12 @@ def set_lonlat_contours(self, interval=None, xinterval=None, _yy = yy / yinterval mm_x = [np.ceil(np.min(_xx)), np.floor(np.max(_xx))] mm_y = [np.ceil(np.min(_yy)), np.floor(np.max(_yy))] - self.xtick_levs = (mm_x[0] + np.arange(mm_x[1]-mm_x[0]+1)) * \ - xinterval - self.ytick_levs = (mm_y[0] + np.arange(mm_y[1]-mm_y[0]+1)) * \ - yinterval + self.xtick_levs = ( + mm_x[0] + np.arange(mm_x[1] - mm_x[0] + 1) + ) * xinterval + self.ytick_levs = ( + mm_y[0] + np.arange(mm_y[1] - mm_y[0] + 1) + ) * yinterval # Decide on float format d = np.array(['4', '3', '2', '1', '0']) @@ -852,7 +974,7 @@ def set_lonlat_contours(self, interval=None, xinterval=None, if add_tick_labels: if add_xtick_labels: _xx = xx[0 if self.origin == 'lower' else -1, :] - _xi = np.arange(self.grid.nx+1) + _xi = np.arange(self.grid.nx + 1) for xl in self.xtick_levs: if (xl > _xx[-1]) or (xl < _xx[0]): continue @@ -864,7 +986,7 @@ def set_lonlat_contours(self, interval=None, xinterval=None, self.xtick_val.append(label) if add_ytick_labels: _yy = np.sort(yy[:, 0]) - _yi = np.arange(self.grid.ny+1) + _yi = np.arange(self.grid.ny + 1) if self.origin == 'upper': _yi = _yi[::-1] for yl in self.ytick_levs: @@ -884,9 +1006,10 @@ def set_lonlat_contours(self, interval=None, xinterval=None, kwargs.setdefault('zorder', 1.5) self.ll_contour_kw = kwargs - def _shading_base(self, slope=None, relief_factor=0.7): + def _shading_base( + self, slope: NDArray[Any] | None = None, relief_factor: float = 0.7 + ) -> None: """Compute the shading factor out of the slope.""" - # reset? 
if slope is None: self.slope = None @@ -896,7 +1019,7 @@ def _shading_base(self, slope=None, relief_factor=0.7): p = np.nonzero(slope > 0) if len(p[0]) > 0: temp = np.clip(slope[p] / (2 * np.std(slope)), -1, 1) - slope[p] = 0.4 * np.sin(0.5*np.pi*temp) + slope[p] = 0.4 * np.sin(0.5 * np.pi * temp) self.relief_factor = relief_factor self.slope = slope @@ -913,28 +1036,33 @@ def set_topography(self, topo=None, crs=None, relief_factor=0.7, **kwargs): Returns ------- the topography if needed (bonus) - """ + """ if topo is None: self._shading_base() - return + return None kwargs.setdefault('interp', 'spline') if isinstance(topo, str): - _, ext = os.path.splitext(topo) - if ext.lower() == '.tif': + topo = Path(topo) + if isinstance(topo, Path): + ext = topo.suffix.lower() + if ext == '.tif': g = GeoTiff(topo) # Spare memory ex = self.grid.extent_in_crs(crs=wgs84) # l, r, b, t - g.set_subset(corners=((ex[0], ex[2]), (ex[1], ex[3])), - crs=wgs84, margin=10) + g.set_subset( + corners=((ex[0], ex[2]), (ex[1], ex[3])), + crs=wgs84, + margin=10, + ) z = g.get_vardata() z[z < -999] = 0 z = self.grid.map_gridded_data(z, g.grid, **kwargs) else: - raise ValueError('File extension not recognised: {}' - .format(ext)) + msg = f'File extension not recognised: {ext}' + raise ValueError(msg) else: z = self._check_data(topo, crs=crs, **kwargs) @@ -951,9 +1079,14 @@ def set_topography(self, topo=None, crs=None, relief_factor=0.7, **kwargs): self._shading_base(dx - dy, relief_factor=relief_factor) return z - def set_rgb(self, img=None, crs=None, interp='nearest', - natural_earth=None): - """Manually force to a rgb img + def set_rgb( + self, + img: NDArray[Any] | None = None, + crs: pyproj.Proj | Grid | str | None = None, + interp: InterpChoices = 'nearest', + natural_earth: str | None = None, + ) -> None: + """Manually force to a rgb img. Parameters ---------- @@ -966,18 +1099,23 @@ def set_rgb(self, img=None, crs=None, interp='nearest', natural_earth : str 'lr', 'mr' or 'hr' (low res, medium or high res) natural earth background img - """ + """ if natural_earth is not None: from matplotlib.image import imread + with warnings.catch_warnings(): # DecompressionBombWarning - warnings.simplefilter("ignore") + warnings.simplefilter('ignore') img = imread(utils.get_natural_earth_file(natural_earth)) ny, nx = img.shape[0], img.shape[1] - dx, dy = 360. / nx, 180. / ny - grid = Grid(nxny=(nx, ny), dxdy=(dx, -dy), x0y0=(-180., 90.), - pixel_ref='corner').center_grid + dx, dy = 360.0 / nx, 180.0 / ny + grid = Grid( + nxny=(nx, ny), + dxdy=(dx, -dy), + x0y0=(-180.0, 90.0), + pixel_ref='corner', + ).center_grid return self.set_rgb(img, grid, interp='linear') if (len(img.shape) != 3) or (img.shape[-1] not in [3, 4]): @@ -989,9 +1127,17 @@ def set_rgb(self, img=None, crs=None, interp='nearest', out.append(self._check_data(img[..., i], crs=crs, interp=interp)) self._rgb = np.dstack(out) - def set_scale_bar(self, location=None, length=None, maxlen=0.25, - add_bbox=False, bbox_dx=1.2, bbox_dy=1.2, - bbox_kwargs=None, **kwargs): + def set_scale_bar( + self, + location=None, + length=None, + maxlen=0.25, + add_bbox=False, + bbox_dx=1.2, + bbox_dy=1.2, + bbox_kwargs=None, + **kwargs, + ): """Add a legend bar showing the scale to the plot. Parameters @@ -1019,8 +1165,8 @@ def set_scale_bar(self, location=None, length=None, maxlen=0.25, any kwarg accepted by ``set_geometry``. Defaults are put on ``color``, ``linewidth``, ``text``, ``text_kwargs``... 
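# Illustrative sketch (not part of the patch; `smap` is a hypothetical
# salem.Map): the background, topography and scale-bar features above.
#
#     from salem.utils import get_demo_file
#     smap.set_rgb(natural_earth='lr')   # low-res natural earth image
#     smap.set_topography(get_demo_file('hef_srtm.tif'))
#     smap.set_scale_bar(location=(0.05, 0.05), add_bbox=True)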
But you can do whatever you want - """ + """ x0, x1, y0, y1 = self.grid.extent # Find a sensible length for the scale @@ -1028,24 +1174,27 @@ def set_scale_bar(self, location=None, length=None, maxlen=0.25, length = utils.nice_scale(x1 - x0, maxlen=maxlen) if location is None: - location = (0.96 - length/2/(x1 - x0), 0.04) + location = (0.96 - length / 2 / (x1 - x0), 0.04) # scalebar center location in proj coordinates sbcx, sbcy = x0 + (x1 - x0) * location[0], y0 + (y1 - y0) * location[1] # coordinates for the scalebar - line = LineString(([sbcx - length/2, sbcy], [sbcx + length/2, sbcy])) + line = LineString( + ([sbcx - length / 2, sbcy], [sbcx + length / 2, sbcy]) + ) # Of the bounding box - bbox = [[sbcx - length / 2 * bbox_dx, sbcy - length / 4 * bbox_dy], - [sbcx - length / 2 * bbox_dx, sbcy + length / 4 * bbox_dy], - [sbcx + length / 2 * bbox_dx, sbcy + length / 4 * bbox_dy], - [sbcx + length / 2 * bbox_dx, sbcy - length / 4 * bbox_dy], - ] + bbox = [ + [sbcx - length / 2 * bbox_dx, sbcy - length / 4 * bbox_dy], + [sbcx - length / 2 * bbox_dx, sbcy + length / 4 * bbox_dy], + [sbcx + length / 2 * bbox_dx, sbcy + length / 4 * bbox_dy], + [sbcx + length / 2 * bbox_dx, sbcy - length / 4 * bbox_dy], + ] # Units if gis.proj_is_latlong(self.grid.proj): units = 'deg' - elif length >= 1000.: + elif length >= 1000.0: length /= 1000 units = 'km' else: @@ -1055,7 +1204,7 @@ def set_scale_bar(self, location=None, length=None, maxlen=0.25, length = int(length) # Defaults kwargs.setdefault('color', 'k') - kwargs.setdefault('text', '{} '.format(length) + units) + kwargs.setdefault('text', f'{length} ' + units) kwargs.setdefault('text_delta', (0.0, 0.015)) kwargs.setdefault('linewidth', 3) kwargs.setdefault('zorder', 99) @@ -1074,43 +1223,39 @@ def set_scale_bar(self, location=None, length=None, maxlen=0.25, self.set_geometry(poly, crs=self.grid.proj, **bbox_kwargs) self.set_geometry(line, crs=self.grid.proj, **kwargs) - def transform(self, crs=wgs84, ax=None): - """Get a matplotlib transform object for a given reference system + def transform( + self, crs: pyproj.Proj | Grid | str = wgs84, ax: Axes | None = None + ) -> MPLTranform: + """Get a matplotlib transform object for a given reference system. Parameters ---------- crs : coordinate reference system a Grid or a Proj, basically. If a grid is given, the grid's proj will be used. + ax : matplotlib.axes.Axes + the axis to use for the transformation Returns ------- a matplotlib.transforms.Transform instance + """ - try: + if isinstance(crs, (xr.DataArray, xr.Dataset)): crs = crs.salem.grid - except: - pass - try: + if isinstance(crs, Grid): crs = crs.proj - except: - pass - return _SalemTransform(target_grid=self.grid, - source_crs=crs, ax=ax) + return _SalemTransform(target_grid=self.grid, source_crs=crs, ax=ax) - def to_rgb(self): + def to_rgb(self) -> NDArray[Any]: """Transform the data to a RGB image and add topographical shading.""" - - if self._rgb is None: - toplot = DataLevels.to_rgb(self) - else: - toplot = self._rgb + toplot = DataLevels.to_rgb(self) if self._rgb is None else self._rgb # Shading if self.slope is not None: # remove alphas? try: - pno = np.where(toplot[:, :, 3] == 0.) 
+ pno = np.where(toplot[:, :, 3] == 0.0) for i in [0, 1, 2]: toplot[pno[0], pno[1], i] = 1 toplot[:, :, 3] = 1 @@ -1134,15 +1279,16 @@ def plot(self, ax): Returns a dict containing the primitives of the various plot calls """ - - out = {'imshow': None, - 'contour': [], - 'contourf': [], - } + out = { + 'imshow': None, + 'contour': [], + 'contourf': [], + } # Image is the easiest - out['imshow'] = ax.imshow(self.to_rgb(), interpolation='none', - origin=self.origin) + out['imshow'] = ax.imshow( + self.to_rgb(), interpolation='none', origin=self.origin + ) ax.autoscale(enable=False) # Contour @@ -1162,13 +1308,19 @@ def plot(self, ax): # Lon lat contours lon, lat = self.grid.pixcorner_ll_coordinates if len(self.xtick_levs) > 0: - ax.contour(lon, levels=self.xtick_levs, - extent=(-0.5, self.grid.nx-0.5, -0.5, self.grid.ny), - **self.ll_contour_kw) + ax.contour( + lon, + levels=self.xtick_levs, + extent=(-0.5, self.grid.nx - 0.5, -0.5, self.grid.ny), + **self.ll_contour_kw, + ) if len(self.ytick_levs) > 0: - ax.contour(lat, levels=self.ytick_levs, - extent=(-0.5, self.grid.nx, -0.5, self.grid.ny-0.5), - **self.ll_contour_kw) + ax.contour( + lat, + levels=self.ytick_levs, + extent=(-0.5, self.grid.nx, -0.5, self.grid.ny - 0.5), + **self.ll_contour_kw, + ) # Geometries for g, kwargs in self._geometries: @@ -1207,8 +1359,8 @@ def plot(self, ax): # Ticks if (len(self.xtick_pos) > 0) or (len(self.ytick_pos) > 0): - ax.xaxis.set_ticks(np.array(self.xtick_pos)-0.5) - ax.yaxis.set_ticks(np.array(self.ytick_pos)-0.5) + ax.xaxis.set_ticks(np.array(self.xtick_pos) - 0.5) + ax.yaxis.set_ticks(np.array(self.ytick_pos) - 0.5) ax.set_xticklabels(self.xtick_val) ax.set_yticklabels(self.ytick_val) ax.xaxis.set_zorder(5) @@ -1221,8 +1373,7 @@ def plot(self, ax): def plot_polygon(ax, poly, edgecolor='black', **kwargs): - """ Plot a single Polygon geometry """ - + """Plot a single Polygon geometry""" a = np.asarray(poly.exterior.coords) # without Descartes, we could make a Patch of exterior ax.add_patch(PolygonPatch(poly, **kwargs)) @@ -1233,9 +1384,7 @@ def plot_polygon(ax, poly, edgecolor='black', **kwargs): class _SalemTransform(MPLTranform): - """ - A transform class for mpl axes using Grids. - """ + """A transform class for mpl axes using Grids.""" input_dims = 2 output_dims = 2 @@ -1243,13 +1392,14 @@ class _SalemTransform(MPLTranform): has_inverse = False def __init__(self, target_grid=None, source_crs=None, ax=None): - """ Instanciate. + """Instanciate. 
Parameters ---------- target_grid : salem.Grid typically, the map grid source_grid + """ self.source_crs = source_crs self.target_grid = target_grid diff --git a/salem/sio.py b/salem/sio.py index c6a739f..2a9ee72 100644 --- a/salem/sio.py +++ b/salem/sio.py @@ -1,31 +1,46 @@ -""" -Input output functions (but mostly input) -""" -from __future__ import division +"""Input output functions (but mostly input).""" -import os -from glob import glob +from __future__ import annotations + +import contextlib import pickle from datetime import datetime from functools import partial -import warnings +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable -import numpy as np -import netCDF4 import cftime - -from salem.utils import memory, cached_shapefile_path -from salem import gis, utils, wgs84, wrftools, proj_to_cartopy - +import netCDF4 +import numpy as np import xarray as xr from xarray.backends.netCDF4_ import NetCDF4DataStore from xarray.core import dtypes + +from salem import gis, utils, wgs84, wrftools +from salem.gis import check_crs, proj_to_cartopy +from salem.utils import ( + cached_shapefile_path, + deprecated_arg, + import_if_exists, + memory, +) + +if TYPE_CHECKING: + import threading + + import pandas as pd + from matplotlib import axes + from numpy._typing import NDArray + + from salem.graphics import Map + try: - from xarray.backends.locks import (NETCDFC_LOCK, HDF5_LOCK, combine_locks) - NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK]) + from xarray.backends.locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks + + netcdf4_python_lock = combine_locks([NETCDFC_LOCK, HDF5_LOCK]) except ImportError: # xarray < v0.11 - from xarray.backends.api import _default_lock as NETCDF4_PYTHON_LOCK + from xarray.backends.api import _default_lock as netcdf4_python_lock try: from xarray.core.pycompat import basestring except ImportError: @@ -33,11 +48,25 @@ basestring = str # Locals -from salem import transform_proj +from salem.gis import transform_proj +has_cartopy = import_if_exists('cartopy') +if has_cartopy: + from cartopy import crs +has_geopandas = import_if_exists('geopandas') +if has_geopandas: + from geopandas import GeoDataFrame, read_file -def read_shapefile(fpath, cached=False): - """Reads a shapefile using geopandas. +tolerance_by_proj_id = { + 1: 1e-3, + 2: 5e-3, + 3: 1e-3, + 6: 1e-3, +} + + +def read_shapefile(fpath: Path, *, cached: bool = False) -> GeoDataFrame: + """Read a shapefile using geopandas. For convenience, it adds four columns to the dataframe: [min_x, max_x, min_y, max_y] @@ -46,60 +75,58 @@ def read_shapefile(fpath, cached=False): caching utility (cached=True). This will save a pickle of the shapefile in the cache directory. 
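+
+    A minimal usage sketch (a hypothetical illustration: it assumes geopandas
+    and salem's demo data are available, and is not part of the patch):
+
+    >>> from salem.utils import get_demo_file
+    >>> df = read_shapefile(get_demo_file('Hintereisferner.shp'), cached=True)  # doctest: +SKIP
+    >>> bounds = df[['min_x', 'max_x', 'min_y', 'max_y']]  # doctest: +SKIP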
""" - - import geopandas as gpd - - _, ext = os.path.splitext(fpath) - - if ext.lower() in ['.shp', '.p']: + if not has_geopandas: + msg = 'read_shapefile requires geopandas to be installed' + raise RuntimeError(msg) + if fpath.suffix.lower() in ['.shp', '.p']: if cached: cpath = cached_shapefile_path(fpath) # unpickle if cached, read and pickle if not - if os.path.exists(cpath): - with open(cpath, 'rb') as f: - out = pickle.load(f) - else: - out = read_shapefile(fpath, cached=False) - with open(cpath, 'wb') as f: - pickle.dump(out, f) - else: - out = gpd.read_file(fpath) - out['min_x'] = [g.bounds[0] for g in out.geometry] - out['max_x'] = [g.bounds[2] for g in out.geometry] - out['min_y'] = [g.bounds[1] for g in out.geometry] - out['max_y'] = [g.bounds[3] for g in out.geometry] - else: - raise ValueError('File extension not recognised: {}'.format(ext)) - - return out + if cpath.exists(): + with cpath.open('rb') as f: + return pickle.load(f) + out = read_shapefile(fpath, cached=False) + with cpath.open('wb') as f: + pickle.dump(out, f) + return out + out = read_file(fpath) + out['min_x'] = [g.bounds[0] for g in out.geometry] + out['max_x'] = [g.bounds[2] for g in out.geometry] + out['min_y'] = [g.bounds[1] for g in out.geometry] + out['max_y'] = [g.bounds[3] for g in out.geometry] + return out + msg = f'File extension not recognised: {fpath.suffix}' + raise ValueError(msg) @memory.cache(ignore=['grid']) -def _memory_shapefile_to_grid(shape_cpath, grid=None, - nxny=None, pixel_ref=None, x0y0=None, dxdy=None, - proj=None): - """Quick solution using joblib in order to not transform many times the - same shape (useful for maps). +def _memory_shapefile_to_grid( + shape_cpath: Path, grid: gis.Grid, **kwargs +) -> gis.Grid: + """Quick solution to not transform many times the same shape (useful for maps). - Since grid is a complex object, joblib seems to have trouble with it. + Using joblib. Since grid is a complex object, joblib seems to have trouble with it. So joblib is checking its cache according to the grid params while the job is done with grid. """ - shape = read_shapefile(shape_cpath, cached=True) e = grid.extent_in_crs(crs=shape.crs) - p = np.nonzero(~((shape['min_x'].to_numpy() > e[1]) | - (shape['max_x'].to_numpy() < e[0]) | - (shape['min_y'].to_numpy() > e[3]) | - (shape['max_y'].to_numpy() < e[2]))) + p = np.nonzero( + ~( + (shape['min_x'].to_numpy() > e[1]) + | (shape['max_x'].to_numpy() < e[0]) + | (shape['min_y'].to_numpy() > e[3]) + | (shape['max_y'].to_numpy() < e[2]) + ) + ) shape = shape.iloc[p] - shape = gis.transform_geopandas(shape, to_crs=grid, inplace=True) - return shape + return gis.transform_geopandas(shape, to_crs=grid, inplace=True) -def read_shapefile_to_grid(fpath, grid): - """Same as read_shapefile but directly transformed to a grid. +def read_shapefile_to_grid(fpath: Path, grid: gis.Grid) -> gis.Grid: + """Read a shapefile and transform to salem.Grid object. + Same as read_shapefile but directly transformed to a grid. The whole thing is cached so that the second call will will be much faster. 
@@ -107,32 +134,37 @@ def read_shapefile_to_grid(fpath, grid):
     ----------
     fpath: path to the file
     grid: the arrival grid
-    """
+
+    Returns
+    -------
+    a GeoDataFrame, with the geometries reprojected onto ``grid``
+
+    """
     # ensure it is a cached pickle (copy code smell)
     shape_cpath = cached_shapefile_path(fpath)
-    if not os.path.exists(shape_cpath):
+    if not shape_cpath.exists():
         out = read_shapefile(fpath, cached=False)
-        with open(shape_cpath, 'wb') as f:
+        with shape_cpath.open('wb') as f:
             pickle.dump(out, f)

-    return _memory_shapefile_to_grid(shape_cpath, grid=grid,
-                                     **grid.to_dict())
+    return _memory_shapefile_to_grid(shape_cpath, grid=grid, **grid.to_dict())


 # TODO: remove this once we sure that we have all WRF files right
 tmp_check_wrf = True


-def _wrf_grid_from_dataset(ds):
+def _wrf_grid_from_dataset(ds: xr.Dataset) -> gis.Grid:
     """Get the WRF projection out of the file."""
     pargs = {}
     if hasattr(ds, 'PROJ_ENVI_STRING'):
         # HAR and other TU Berlin files
         dx = ds.GRID_DX if hasattr(ds, 'GRID_DX') else ds.DX
         dy = ds.GRID_DY if hasattr(ds, 'GRID_DY') else ds.DY
-        if ds.PROJ_NAME in ['Lambert Conformal Conic',
-                            'WRF Lambert Conformal']:
+        if ds.PROJ_NAME in [
+            'Lambert Conformal Conic',
+            'WRF Lambert Conformal',
+        ]:
             proj_id = 1
             pargs['lat_1'] = ds.PROJ_STANDARD_PAR1
             pargs['lat_2'] = ds.PROJ_STANDARD_PAR2
@@ -141,7 +173,7 @@ def _wrf_grid_from_dataset(ds):
             pargs['center_lon'] = ds.PROJ_CENTRAL_LON
         elif ds.PROJ_NAME in ['lat-lon']:
             proj_id = 6
-        elif "mercator" in ds.PROJ_NAME.lower():
+        elif 'mercator' in ds.PROJ_NAME.lower():
             proj_id = 3
             pargs['lat_ts'] = ds.TRUELAT1
             pargs['center_lon'] = ds.CEN_LON
@@ -160,36 +192,40 @@ def _wrf_grid_from_dataset(ds):
         pargs['center_lon'] = ds.CEN_LON
         proj_id = ds.MAP_PROJ

-    if proj_id == 1:
+    wrf_projection_key = {
+        'Lambert Conformal': 1,
+        'Polar Stereographic': 2,
+        'Mercator': 3,
+        'Lat-long': 6,
+    }
+    if proj_id == wrf_projection_key['Lambert Conformal']:
         # Lambert
-        p4 = '+proj=lcc +lat_1={lat_1} +lat_2={lat_2} ' \
-             '+lat_0={lat_0} +lon_0={lon_0} ' \
-             '+x_0=0 +y_0=0 +a=6370000 +b=6370000'
-        p4 = p4.format(**pargs)
-    elif proj_id == 2:
+        p4 = (
+            f"+proj=lcc +lat_1={pargs['lat_1']} +lat_2={pargs['lat_2']} "
+            f"+lat_0={pargs['lat_0']} +lon_0={pargs['lon_0']} "
+            "+x_0=0 +y_0=0 +a=6370000 +b=6370000"
+        )
+    elif proj_id == wrf_projection_key['Polar Stereographic']:
         # Polar stereo
-        p4 = '+proj=stere +lat_ts={lat_1} +lon_0={lon_0} +lat_0=90.0' \
-             '+x_0=0 +y_0=0 +a=6370000 +b=6370000'
-        p4 = p4.format(**pargs)
-    elif proj_id == 3:
+        p4 = (
+            f"+proj=stere +lat_ts={pargs['lat_1']} +lon_0={pargs['lon_0']} "
+            "+lat_0=90.0 +x_0=0 +y_0=0 +a=6370000 +b=6370000"
+        )
+    elif proj_id == wrf_projection_key['Mercator']:
         # Mercator
-        p4 = '+proj=merc +lat_ts={lat_1} ' \
-             '+lon_0={center_lon} ' \
-             '+x_0=0 +y_0=0 +a=6370000 +b=6370000'
-        p4 = p4.format(**pargs)
-    elif proj_id == 6:
+        p4 = (
+            f"+proj=merc +lat_ts={pargs['lat_1']} "
+            f"+lon_0={pargs['center_lon']} "
+            "+x_0=0 +y_0=0 +a=6370000 +b=6370000"
+        )
+    elif proj_id == wrf_projection_key['Lat-long']:
         # Lat-long
-        p4 = '+proj=eqc ' \
-             '+lon_0={lon_0} ' \
-             '+x_0=0 +y_0=0 +a=6370000 +b=6370000'
-        p4 = p4.format(**pargs)
+        p4 = (
+            f"+proj=eqc +lon_0={pargs['lon_0']} "
+            "+x_0=0 +y_0=0 +a=6370000 +b=6370000"
+        )
     else:
-        raise NotImplementedError('WRF proj not implemented yet: '
-                                  '{}'.format(proj_id))
+        msg = f'WRF proj not implemented yet: {proj_id}'
+        raise NotImplementedError(msg)

     proj = gis.check_crs(p4)
-    if proj is None:
-        raise RuntimeError('WRF proj not understood: {}'.format(p4))
+    if proj is None:
+        msg = f'WRF proj not understood: {p4}'
+        raise RuntimeError(msg)

     # Here we have to accept xarray and netCDF4 datasets
     try:
@@ -201,13 +237,13 @@ def _wrf_grid_from_dataset(ds):
     ny = ds.sizes['south_north']

     if hasattr(ds, 'PROJ_ENVI_STRING'):
         # HAR
-        x0 = ds['west_east'][0]
-        y0 = ds['south_north'][0]
+        x0 = float(ds['west_east'][0])
+        y0 = float(ds['south_north'][0])
     else:
         # Normal WRF file
         e, n = gis.transform_proj(wgs84, proj, cen_lon, cen_lat)
-        x0 = -(nx-1) / 2. * dx + e  # DL corner
-        y0 = -(ny-1) / 2. * dy + n  # DL corner
+        x0 = -(nx - 1) / 2.0 * dx + e  # DL corner
+        y0 = -(ny - 1) / 2.0 * dy + n  # DL corner
     grid = gis.Grid(nxny=(nx, ny), x0y0=(x0, y0), dxdy=(dx, dy), proj=proj)

     if tmp_check_wrf:
@@ -225,38 +261,43 @@ def _wrf_grid_from_dataset(ds):
             reflon = ds.variables['lon']
             reflat = ds.variables['lat']
         else:
-            raise RuntimeError("couldn't test for correct WRF lon-lat")
+            msg = "couldn't test for correct WRF lon-lat"
+            raise RuntimeError(msg)

         if len(reflon.shape) == 3:
             reflon = reflon[0, :, :]
             reflat = reflat[0, :, :]
         mylon, mylat = grid.ll_coordinates

-        atol = 5e-3 if proj_id == 2 else 1e-3
+        atol = tolerance_by_proj_id[proj_id]
         check = np.isclose(reflon, mylon, atol=atol)
         if not np.all(check):
             n_pix = np.sum(~check)
             maxe = np.max(np.abs(reflon - mylon))
             if maxe < (360 - atol):
-                warnings.warn('For {} grid points, the expected accuracy ({}) '
-                              'of our lons did not match those of the WRF '
-                              'file. Max error: {}'.format(n_pix, atol, maxe))
+                msg = (
+                    f'For {n_pix} grid points, the expected accuracy '
+                    f'({atol}) of our lons did not match those of the '
+                    f'WRF file. Max error: {maxe}'
+                )
+                deprecated_arg(msg)
         check = np.isclose(reflat, mylat, atol=atol)
         if not np.all(check):
             n_pix = np.sum(~check)
             maxe = np.max(np.abs(reflat - mylat))
-            warnings.warn('For {} grid points, the expected accuracy ({}) '
-                          'of our lats did not match those of the WRF file. '
-                          'Max error: {}'.format(n_pix, atol, maxe))
+            msg = (
+                f'For {n_pix} grid points, the expected accuracy '
+                f'({atol}) of our lats did not match those of the '
+                f'WRF file. Max error: {maxe}'
+            )
+            deprecated_arg(msg)

     return grid


-def _lonlat_grid_from_dataset(ds):
+def _lonlat_grid_from_dataset(ds: xr.Dataset) -> gis.Grid | None:
     """Seek for longitude and latitude coordinates."""
-    # Do we have some standard names as variable?
-    vns = ds.variables.keys()
+    keys_view = ds.variables.keys()
+    vns = [str(k) for k in keys_view]
     xc = utils.str_in_list(vns, utils.valid_names['x_dim'])
     yc = utils.str_in_list(vns, utils.valid_names['y_dim'])
@@ -278,37 +319,42 @@ def _lonlat_grid_from_dataset(ds):
     lat = ds.variables[y][:]

     # double check for dubious variables
-    if not utils.str_in_list([x], utils.valid_names['lon_var']) or \
-            not utils.str_in_list([y], utils.valid_names['lat_var']):
-        # name not usual. see if at least the range follows some conv
-        if (np.max(np.abs(lon)) > 360.1) or (np.max(np.abs(lat)) > 90.1):
-            return None
+    # name not usual: see if at least the range follows some conv
+    if (
+        not utils.str_in_list([x], utils.valid_names['lon_var'])
+        or not utils.str_in_list([y], utils.valid_names['lat_var'])
+    ) and ((np.max(np.abs(lon)) > 360.1) or (np.max(np.abs(lat)) > 90.1)):
+        return None

     # Make the grid
-    dx = lon[1]-lon[0]
-    dy = lat[1]-lat[0]
-    args = dict(nxny=(lon.shape[0], lat.shape[0]), proj=wgs84, dxdy=(dx, dy),
-                x0y0=(lon[0], lat[0]))
+    dx = lon[1] - lon[0]
+    dy = lat[1] - lat[0]
+    args = {
+        'nxny': (lon.shape[0], lat.shape[0]),
+        'proj': wgs84,
+        'dxdy': (dx, dy),
+        'x0y0': (lon[0], lat[0]),
+    }
     return gis.Grid(**args)


-def _salem_grid_from_dataset(ds):
+def _salem_grid_from_dataset(ds: xr.Dataset) -> gis.Grid | None:
     """Seek for coordinates that Salem might have created.

     Current convention: x_coord, y_coord, pyproj_srs as attribute
     """
-    # Projection
     try:
-        proj = ds.pyproj_srs
+        proj = str(ds.pyproj_srs)
     except AttributeError:
         proj = None
     proj = gis.check_crs(proj)
     if proj is None:
         return None

     # Do we have some standard names as variable?
-    vns = ds.variables.keys()
+    keys_view = ds.variables.keys()
+    vns = [str(k) for k in keys_view]
     xc = utils.str_in_list(vns, utils.valid_names['x_dim'])
     yc = utils.str_in_list(vns, utils.valid_names['y_dim'])
@@ -330,14 +376,18 @@ def _salem_grid_from_dataset(ds):
     y = ds.variables[y][:]

     # Make the grid
-    dx = x[1]-x[0]
-    dy = y[1]-y[0]
-    args = dict(nxny=(x.shape[0], y.shape[0]), proj=proj, dxdy=(dx, dy),
-                x0y0=(x[0], y[0]))
+    dx = x[1] - x[0]
+    dy = y[1] - y[0]
+    args = {
+        'nxny': (x.shape[0], y.shape[0]),
+        'proj': proj,
+        'dxdy': (dx, dy),
+        'x0y0': (x[0], y[0]),
+    }
     return gis.Grid(**args)


-def grid_from_dataset(ds):
+def grid_from_dataset(ds: xr.Dataset) -> gis.Grid | None:
     """Find out if the dataset contains enough info for Salem to understand.
``ds`` can be an xarray dataset or a NetCDF dataset, or anything @@ -345,7 +395,6 @@ def grid_from_dataset(ds): Returns a :py:class:`~salem.Grid` if successful, ``None`` otherwise """ - # try if it is a salem file out = _salem_grid_from_dataset(ds) if out is not None: @@ -360,15 +409,17 @@ def grid_from_dataset(ds): return _lonlat_grid_from_dataset(ds) -def netcdf_time(ncobj, monthbegin=False): +def netcdf_time( + ncobj: netCDF4.Dataset, *, monthbegin: bool = False +) -> pd.DatetimeIndex | None: """Check if the netcdf file contains a time that Salem understands.""" - import pandas as pd time = None try: - vt = utils.str_in_list(ncobj.variables.keys(), - utils.valid_names['time_var'])[0] + keys_view = ncobj.variables.keys() + vns = [str(k) for k in keys_view] + vt = utils.str_in_list(vns, utils.valid_names['time_var'])[0] except IndexError: # no time variable return None @@ -384,51 +435,62 @@ def netcdf_time(ncobj, monthbegin=False): except AttributeError: stimes = ncobj.variables['Times'][:] for t in stimes: - time.append(pd.to_datetime(t.tobytes().decode(), - errors='raise', - format='%Y-%m-%d_%H:%M:%S')) + time.append( + pd.to_datetime( + t.tobytes().decode(), + errors='raise', + format='%Y-%m-%d_%H:%M:%S', + ) + ) elif vt is not None: # CF time var = ncobj.variables[vt] try: # We want python times because pandas doesn't understand # CFtime - time = cftime.num2date(var[:], var.units, - only_use_cftime_datetimes=False, - only_use_python_datetimes=True) + time = cftime.num2date( + var[:], + var.units, + only_use_cftime_datetimes=False, + only_use_python_datetimes=True, + ) except TypeError: # Old versions of cftime did return python times when possible time = cftime.num2date(var[:], var.units) if monthbegin: # sometimes monthly data is centered in the month (stupid) - time = [datetime(t.year, t.month, 1) for t in time] + time = [ + datetime(t.year, t.month, 1, tzinfo=t.tzinfo) for t in time + ] - return time + if time is None: + return None + return pd.DatetimeIndex(time) -class _XarrayAccessorBase(object): +class _XarrayAccessorBase: """Common logic for for both data structures (DataArray and Dataset). http://xarray.pydata.org/en/stable/internals.html#extending-xarray """ - def __init__(self, xarray_obj): - + def __init__(self, xarray_obj: xr.Dataset | xr.DataArray) -> None: self._obj = xarray_obj if isinstance(xarray_obj, xr.DataArray): xarray_obj = xarray_obj.to_dataset(name='var') - try: # maybe there was already some georef + # maybe there was already some georef + with contextlib.suppress(Exception): xarray_obj.attrs['pyproj_srs'] = xarray_obj['var'].pyproj_srs - except: - pass self.grid = grid_from_dataset(xarray_obj) if self.grid is None: - raise RuntimeError('dataset Grid not understood.') + msg = 'dataset Grid not understood.' 
+            raise RuntimeError(msg)

-        dn = xarray_obj.sizes.keys()
+        keys_view = xarray_obj.sizes.keys()
+        dn = [str(k) for k in keys_view]
         self.x_dim = utils.str_in_list(dn, utils.valid_names['x_dim'])[0]
         self.y_dim = utils.str_in_list(dn, utils.valid_names['y_dim'])[0]
         dim = utils.str_in_list(dn, utils.valid_names['t_dim'])
@@ -436,16 +498,25 @@ def __init__(self, xarray_obj):
         dim = utils.str_in_list(dn, utils.valid_names['z_dim'])
         self.z_dim = dim[0] if dim else None

-    def subset(self, margin=0, ds=None, **kwargs):
-        """subset(self, margin=0, shape=None, geometry=None, grid=None,
-        corners=None, crs=wgs84, roi=None)
+    def subset(
+        self,
+        margin: int = 0,
+        ds: xr.Dataset | xr.DataArray | None = None,
+        **kwargs,
+    ) -> xr.Dataset | xr.DataArray:
+        """Get a subset of the dataset.

-        Get a subset of the dataset.
+        subset(self, margin=0, shape=None, geometry=None, grid=None,
+        corners=None, crs=wgs84, roi=None)

         Accepts all keywords of :py:func:`~Grid.roi`

         Parameters
         ----------
+        margin : int
+            add a margin to the region to subset (can be negative!). Can
+            be used as a single keyword, too: set_subset(margin=-5) will
+            remove five pixels from each boundary of the dataset.
         ds : Dataset or DataArray
             form the ROI from the extent of the Dataset or DataArray
         shape : str
@@ -461,12 +532,8 @@ def subset(self, margin=0, ds=None, **kwargs):
             coordinate reference system of the geometry and corners
         roi : ndarray
             a mask for the region of interest to subset the dataset onto
-        margin : int
-            add a margin to the region to subset (can be negative!). Can
-            be used a single keyword, too: set_subset(margin=-5) will remove
-            five pixels from each boundary of the dataset.
-
-        """
+
+        """
         if ds is not None:
             grid = ds.salem.grid
             kwargs.setdefault('grid', grid)
@@ -480,17 +547,25 @@ def subset(self, margin=0, ds=None, **kwargs):
         sub_x = [np.min(ids[1]) - margin, np.max(ids[1]) + margin]
         sub_y = [np.min(ids[0]) - margin, np.max(ids[0]) + margin]

-        out_ds = self._obj[{self.x_dim: slice(sub_x[0], sub_x[1]+1),
-                            self.y_dim: slice(sub_y[0], sub_y[1]+1)}
-                           ]
-        return out_ds
-
-    def roi(self, ds=None, **kwargs):
-        """roi(self, shape=None, geometry=None, grid=None, corners=None,
+        return self._obj[
+            {
+                self.x_dim: slice(sub_x[0], sub_x[1] + 1),
+                self.y_dim: slice(sub_y[0], sub_y[1] + 1),
+            }
+        ]
+
+    def roi(
+        self,
+        ds: xr.Dataset | xr.DataArray | None = None,
+        **kwargs,
+    ) -> xr.Dataset | xr.DataArray:
+        """Make a region of interest (ROI) for the dataset.
+
+        roi(self, shape=None, geometry=None, grid=None, corners=None,
         crs=wgs84, roi=None, all_touched=False, other=None)
-        Make a region of interest (ROI) for the dataset.
-
+
         All grid points outside the ROI will be masked out.

         Parameters
@@ -518,6 +593,7 @@ def roi(self, ds=None, **kwargs):
             Value to use for locations in this object where cond is False.
             By default, these locations filled with NA.
             As in http://xarray.pydata.org/en/stable/generated/xarray.DataArray.where.html
+
         """
         other = kwargs.pop('other', dtypes.NA)
         if ds is not None:
@@ -544,47 +620,60 @@ def roi(self, ds=None, **kwargs):
             out.variables[v].attrs['pyproj_srs'] = self.grid.proj.srs
         return out

-    def get_map(self, **kwargs):
+    def get_map(self, **kwargs) -> Map:
         """Make a salem.Map out of the dataset.
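+
+        A hypothetical sketch (``countries`` is one of salem.Map's keywords;
+        ``ds`` is assumed to carry a salem-readable grid):
+
+        >>> smap = ds.salem.get_map(countries=False)  # doctest: +SKIP
+        >>> smap.set_data(ds['T2'].isel(time=0).values)  # doctest: +SKIP
+        >>> smap.visualize()  # doctest: +SKIP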
 All keywords are passed to :py:class:salem.Map

         """
-        from salem.graphics import Map
+        # runtime import: salem.graphics needs matplotlib (optional dependency)
+        from salem.graphics import Map
+
+        return Map(self.grid, **kwargs)

-    def _quick_map(self, obj, ax=None, interp='nearest', **kwargs):
+    def _quick_map(
+        self,
+        obj: xr.DataArray,
+        ax: axes.Axes | None = None,
+        interp: str = 'nearest',
+        **kwargs,
+    ) -> Map:
         """Make a plot of a data array."""
-
-        # some metadata?
-        title = obj.name or ''
-        if obj._title_for_slice():
-            title += ' (' + obj._title_for_slice() + ')'
-        cb = obj.attrs['units'] if 'units' in obj.attrs else ''
+        # some metadata?
+        title = str(obj.name) if obj.name is not None else ''
+        title_add = obj._title_for_slice()
+        if title_add:
+            title += f' ({title_add})'
+        cb = obj.attrs.get('units', '')
         smap = self.get_map(**kwargs)
         smap.set_data(obj.values, interp=interp)
         smap.visualize(ax=ax, title=title, cbar_title=cb)
         return smap

-    def cartopy(self):
+    def cartopy(self) -> crs.Projection:
         """Get a cartopy.crs.Projection for this dataset."""
         return proj_to_cartopy(self.grid.proj)

-    def _apply_transform(self, transform, grid, other, return_lut=False):
-        """Common transform mixin"""
-
+    def _apply_transform(
+        self,
+        transform: Callable,
+        grid: gis.Grid | None,
+        other: xr.Dataset | xr.DataArray,
+        *,
+        return_lut: bool = False,
+    ) -> xr.Dataset | xr.DataArray:
+        """Apply a transform to ``other`` and rebuild the xarray structure."""
         was_dataarray = False
         if not isinstance(other, xr.Dataset):
             try:
                 other = other.to_dataset(name=other.name)
                 was_dataarray = True
-            except AttributeError:
+            except AttributeError as att_err:
                 # must be a ndarray
                 if return_lut:
                     rdata, lut = transform(other, grid=grid, return_lut=True)
                 else:
                     rdata = transform(other, grid=grid)
+                    lut = None

                 # let's guess
                 sh = rdata.shape
                 nd = len(sh)
@@ -598,7 +687,8 @@ def _apply_transform(self, transform, grid, other, return_lut=False):
                     newdim = self.z_dim
                     dims = (newdim, self.y_dim, self.x_dim)
                 else:
-                    raise NotImplementedError('more than 3 dims not ok yet.')
+                    msg = 'more than 3 dims not ok yet.'
+                    raise NotImplementedError(msg) from att_err

                 coords = {}
                 for d in dims:
@@ -609,8 +699,7 @@ def _apply_transform(self, transform, grid, other, return_lut=False):
                 out.attrs['pyproj_srs'] = self.grid.proj.srs
                 if return_lut:
                     return out, lut
-                else:
-                    return out
+                return out

         # go
         out = xr.Dataset()
@@ -638,8 +727,9 @@ def _apply_transform(self, transform, grid, other, return_lut=False):
                 coords[self.x_dim] = self._obj[self.x_dim]
                 coords[self.y_dim] = self._obj[self.y_dim]

-                rdata = xr.DataArray(rdata, coords=coords, attrs=var.attrs,
-                                     dims=dims)
+                rdata = xr.DataArray(
+                    rdata, coords=coords, attrs=var.attrs, dims=dims
+                )
                 rdata.attrs['pyproj_srs'] = self.grid.proj.srs
                 out[v] = rdata

@@ -650,10 +740,15 @@ def _apply_transform(self, transform, grid, other, return_lut=False):
         if return_lut:
             return out, lut
-        else:
-            return out
+        return out

-    def transform(self, other, grid=None, interp='nearest', ks=3):
+    def transform(
+        self,
+        other: xr.Dataset | xr.DataArray | np.ndarray,
+        grid: gis.Grid | None = None,
+        interp: str = 'nearest',
+        ks: int = 3,
+    ) -> xr.Dataset | xr.DataArray:
         """Reprojects an other Dataset or DataArray onto self.

         The returned object has the same data structure as ``other`` (i.e.
@@ -674,15 +769,21 @@ def transform(self, other, grid=None, interp='nearest', ks=3):
         Returns
         -------
         a dataset or a dataarray
-        """
+
+        """
         transform = partial(self.grid.map_gridded_data, interp=interp, ks=ks)
         return self._apply_transform(transform, grid, other)

-    def lookup_transform(self, other, grid=None, method=np.mean, lut=None,
-                         return_lut=False):
-        """Reprojects an other Dataset or DataArray onto self using the
-        forward tranform lookup.
+    def lookup_transform(
+        self,
+        other: xr.Dataset | xr.DataArray | np.ndarray,
+        grid: gis.Grid | None = None,
+        method: Callable = np.mean,
+        lut: NDArray[Any] | None = None,
+        *,
+        return_lut: bool = False,
+    ) -> xr.Dataset | xr.DataArray:
+        """Project another Dataset or DataArray onto self via forward transform lookup.

         See : :py:meth:`Grid.lookup_transform`

@@ -707,21 +808,27 @@ def lookup_transform(self, other, grid=None, method=np.mean, lut=None,
         -------
         a dataset or a dataarray
         If ``return_lut==True``, also return the lookup table
-        """
+
+        """
         transform = partial(self.grid.lookup_transform, method=method, lut=lut)
-        return self._apply_transform(transform, grid, other,
-                                     return_lut=return_lut)
+        return self._apply_transform(
+            transform, grid, other, return_lut=return_lut
+        )


 @xr.register_dataarray_accessor('salem')
 class DataArrayAccessor(_XarrayAccessorBase):
+    """Salem's xarray accessor for DataArrays."""

-    def quick_map(self, ax=None, interp='nearest', **kwargs):
+    def quick_map(
+        self, ax: axes.Axes | None = None, interp: str = 'nearest', **kwargs
+    ) -> Map:
         """Make a plot of the DataArray."""
         return self._quick_map(self._obj, ax=ax, interp=interp, **kwargs)

-    def deacc(self, as_rate=True):
+    def deacc(self, *, as_rate: bool = True) -> xr.DataArray:
         """De-accumulates the variable (i.e. compute the variable's rate).

         The returned variable has one element less over the time dimension.
@@ -733,26 +840,37 @@ def deacc(self, as_rate=True):
         as_rate: bool
             set to false if you don't want units per hour, but units per
             given data timestep
-        """
+
+        """
         out = self._obj[{self.t_dim: slice(1, len(self._obj[self.t_dim]))}]
-        diff = self._obj[{self.t_dim: slice(0, len(self._obj[self.t_dim])-1)}]
-        out.values = out.values - diff.values
-        out.attrs['description'] = out.attrs['description'].replace('ACCUMULATED ', '')
+        diff = self._obj[
+            {self.t_dim: slice(0, len(self._obj[self.t_dim]) - 1)}
+        ]
+        out.values = out.to_numpy() - diff.to_numpy()
+        out.attrs['description'] = out.attrs['description'].replace(
+            'ACCUMULATED ', ''
+        )
         if as_rate:
-            dth = self._obj.time[1].values - self._obj.time[0].values
+            dth = self._obj.time[1].to_numpy() - self._obj.time[0].to_numpy()
             dth = dth.astype('timedelta64[h]').astype(float)
-            out.values = out.values / dth
+            out.values = out.to_numpy() / dth
             out.attrs['units'] += ' h-1'
         else:
             out.attrs['units'] += ' step-1'
         return out

-    def interpz(self, zcoord, levels, dim_name='', fill_value=np.nan,
-                use_multiprocessing=True):
-        """Interpolates the array along the vertical dimension
+    def interpz(
+        self,
+        zcoord: xr.DataArray,
+        levels: list[float],
+        dim_name: str = '',
+        fill_value: float = np.nan,
+        *,
+        use_multiprocessing: bool = True,
+    ) -> xr.DataArray:
+        """Interpolate the array along the vertical dimension.
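+
+        A hypothetical sketch (``da`` and ``zcoord`` are assumed to be 3D
+        DataArrays sharing the same dimensions, e.g. from a WRF run):
+
+        >>> heights = [100.0, 500.0, 1000.0]
+        >>> on_z = da.salem.interpz(zcoord, heights, dim_name='z')  # doctest: +SKIP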
         Parameters
         ----------
@@ -771,14 +889,19 @@ def interpz(self, zcoord, levels, dim_name='', fill_value=np.nan,
         Returns
         -------
         a new DataArray with the interpolated data
-        """
+
+        """
         if self.z_dim is None:
-            raise RuntimeError('zdimension not recognized')
+            msg = 'z dimension not recognized'
+            raise RuntimeError(msg)

-        data = wrftools.interp3d(self._obj.values, zcoord.values,
-                                 np.atleast_1d(levels), fill_value=fill_value,
-                                 use_multiprocessing=use_multiprocessing)
+        data = wrftools.interp3d(
+            self._obj.values,
+            zcoord.values,
+            np.atleast_1d(levels),
+            fill_value=fill_value,
+            use_multiprocessing=use_multiprocessing,
+        )

         dims = list(self._obj.dims)
         zd = np.nonzero([self.z_dim == d for d in dims])[0][0]
@@ -795,14 +918,28 @@ def interpz(self, zcoord, levels, dim_name='', fill_value=np.nan,

 @xr.register_dataset_accessor('salem')
 class DatasetAccessor(_XarrayAccessorBase):
-
-    def quick_map(self, varname, ax=None, interp='nearest', **kwargs):
+    """Salem's xarray accessor for Datasets."""
+
+    def quick_map(
+        self,
+        varname: str,
+        ax: axes.Axes | None = None,
+        interp: str = 'nearest',
+        **kwargs,
+    ) -> Map:
         """Make a plot of a variable of the DataSet."""
-        return self._quick_map(self._obj[varname], ax=ax, interp=interp,
-                               **kwargs)
-
-    def transform_and_add(self, other, grid=None, interp='nearest', ks=3,
-                          name=None):
+        return self._quick_map(
+            self._obj[varname], ax=ax, interp=interp, **kwargs
+        )
+
+    def transform_and_add(
+        self,
+        other: xr.Dataset | xr.DataArray | np.ndarray,
+        grid: gis.Grid | None = None,
+        interp: str = 'nearest',
+        ks: int = 3,
+        name: str | dict[str, str] | None = None,
+    ) -> xr.Dataset:
         """Reprojects an other Dataset and adds it's content to the current one.

         Parameters
@@ -821,25 +958,33 @@ def transform_and_add(self, other, grid=None, interp='nearest', ks=3,
             conflict). Set to a str to to rename the variable (if unique) or
             set to a dict for mapping the old names to the new names for
             datasets.
-        """
+
+        """
         out = self.transform(other, grid=grid, interp=interp, ks=ks)

         if isinstance(out, xr.DataArray):
             new_name = name or out.name
             if new_name is None:
-                raise ValueError('You need to set name')
+                msg = 'You need to set name'
+                raise ValueError(msg)
             self._obj[new_name] = out
-        else:
+        elif isinstance(out, xr.Dataset):
             for v in out.data_vars:
                 try:
                     new_name = name[v]
                 except (KeyError, TypeError):
                     new_name = v
                 self._obj[new_name] = out[v]
+        return self._obj

-    def wrf_zlevel(self, varname, levels=None, fill_value=np.nan,
-                   use_multiprocessing=True):
+    def wrf_zlevel(
+        self,
+        varname: str,
+        levels: list[float] | NDArray[Any] | None = None,
+        fill_value: float = np.nan,
+        *,
+        use_multiprocessing: bool = True,
+    ) -> xr.DataArray:
         """Interpolates to a specified height above sea level.
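+
+        A hypothetical sketch ('WS' is one of salem's WRF diagnostic
+        variables; the file name is made up):
+
+        >>> ds = open_wrf_dataset('wrfout_d01.nc')  # doctest: +SKIP
+        >>> ws_10km = ds.salem.wrf_zlevel('WS', levels=[10000.0])  # doctest: +SKIP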
         Parameters

@@ -857,28 +1002,57 @@ def wrf_zlevel(self, varname, levels=None, fill_value=np.nan,
         Returns
         -------
         an interpolated DataArray
+
         """
         if levels is None:
-            levels = np.array([10, 20, 30, 50, 75, 100, 200, 300, 500, 750,
-                               1000, 2000, 3000, 5000, 7500, 10000])
+            levels = np.array(
+                [10, 20, 30, 50, 75, 100, 200, 300, 500, 750,
+                 1000, 2000, 3000, 5000, 7500, 10000]
+            )

         zcoord = self._obj['Z']
-        out = self._obj[varname].salem.interpz(zcoord, levels, dim_name='z',
-                                               fill_value=fill_value,
-                                               use_multiprocessing=use_multiprocessing)
+        out = self._obj[varname].salem.interpz(
+            zcoord,
+            levels,
+            dim_name='z',
+            fill_value=fill_value,
+            use_multiprocessing=use_multiprocessing,
+        )
         out['z'].attrs['description'] = 'height above sea level'
         out['z'].attrs['units'] = 'm'
         return out

-    def wrf_plevel(self, varname, levels=None, fill_value=np.nan,
-                   use_multiprocessing=True):
+    def wrf_plevel(
+        self,
+        varname: str,
+        levels: list[float] | NDArray[Any] | None = None,
+        fill_value: float = np.nan,
+        *,
+        use_multiprocessing: bool = True,
+    ) -> xr.DataArray:
         """Interpolates to a specified pressure level (hPa).

         Parameters
         ----------
         varname: str
             the name of the variable to interpolate
-        levels: 1d array
+        levels: 1d array, optional
             levels at which to interpolate (default: some levels I thought of)
         fill_value : np.nan or 'extrapolate', optional
             how to handle levels below the topography. Default is to mark them
@@ -889,21 +1063,46 @@ def wrf_plevel(self, varname, levels=None, fill_value=np.nan,
         Returns
         -------
         an interpolated DataArray
+
         """
         if levels is None:
-            levels = np.array([1000, 975, 950, 925, 900, 850, 800, 750, 700,
-                               650, 600, 550, 500, 450, 400, 300, 200, 100])
+            levels = np.array(
+                [1000, 975, 950, 925, 900, 850, 800, 750, 700,
+                 650, 600, 550, 500, 450, 400, 300, 200, 100]
+            )

         zcoord = self._obj['PRESSURE']
-        out = self._obj[varname].salem.interpz(zcoord, levels, dim_name='p',
-                                               fill_value=fill_value,
-                                               use_multiprocessing=use_multiprocessing)
+        out = self._obj[varname].salem.interpz(
+            zcoord,
+            levels,
+            dim_name='p',
+            fill_value=fill_value,
+            use_multiprocessing=use_multiprocessing,
+        )
         out['p'].attrs['description'] = 'pressure'
         out['p'].attrs['units'] = 'hPa'
         return out


-def open_xr_dataset(file):
+def open_xr_dataset(file: Path | str) -> xr.Dataset:
     """Thin wrapper around xarray's open_dataset.

     This is needed because variables often have not enough georef attrs
@@ -913,18 +1112,25 @@ def open_xr_dataset(file):
     Returns
     -------
     an xarray Dataset
-    """
+
+    """
     # if geotiff, use Salem
-    p, ext = os.path.splitext(file)
-    if (ext.lower() == '.tif') or (ext.lower() == '.tiff'):
+    if isinstance(file, str):
+        file = Path(file)
+    ext = file.suffix.lower()
+    if ext in ('.tif', '.tiff'):
         from salem import GeoTiff
+
         geo = GeoTiff(file)
         # TODO: currently everything is loaded in memory (baaad)
-        da = xr.DataArray(geo.get_vardata(),
-                          coords={'x': geo.grid.center_grid.x_coord,
-                                  'y': geo.grid.center_grid.y_coord},
-                          dims=['y', 'x'])
+        da = xr.DataArray(
+            geo.get_vardata(),
+            coords={
+                'x': geo.grid.center_grid.x_coord,
+                'y': geo.grid.center_grid.y_coord,
+            },
+            dims=['y', 'x'],
+        )
         ds = xr.Dataset()
         ds.attrs['pyproj_srs'] = geo.grid.proj.srs
         ds['data'] = da
@@ -950,8 +1156,10 @@ def open_xr_dataset(file):
     return ds


-def open_wrf_dataset(file, **kwargs):
-    """Wrapper around xarray's open_dataset to make WRF files a bit better.
+def open_wrf_dataset(file: Path | str, **kwargs) -> xr.Dataset:
+    """Use Salem to open a WRF dataset.
+
+    Thin wrapper around xarray's open_dataset to make WRF files a bit better.

     This is needed because variables often have not enough georef attrs
     to be understood alone, and datasets tend to loose their attrs with
@@ -967,8 +1175,8 @@ def open_wrf_dataset(file, **kwargs):
     Returns
     -------
     an xarray Dataset
-    """
+
+    """
     nc = netCDF4.Dataset(file)
     nc.set_auto_mask(False)

@@ -987,12 +1195,12 @@ def open_wrf_dataset(file, **kwargs):
     ds = xr.open_dataset(NetCDF4DataStore(nc), **kwargs)

     # remove time dimension to lon lat (each variable on its own, so that a
     # missing one does not prevent the other from being converted)
     for vn in ['XLONG', 'XLAT']:
-        try:
+        with contextlib.suppress(ValueError, KeyError):
             v = ds[vn].isel(Time=0)
             ds[vn] = xr.DataArray(v.values, dims=['south_north', 'west_east'])
-        except (ValueError, KeyError):
-            pass

     # Convert time (if necessary)
     if 'Time' in ds.dims:
@@ -1001,7 +1209,7 @@ def open_wrf_dataset(file, **kwargs):
         ds['Time'] = time
     ds = ds.rename({'Time': 'time'})
     tr = {'Time': 'time', 'XLAT': 'lat', 'XLONG': 'lon', 'XTIME': 'xtime'}
-    tr = {k: tr[k] for k in tr.keys() if k in ds.variables}
+    tr = {k: tr[k] for k in tr if k in ds.variables}
     ds = ds.rename(tr)

     # drop ugly vars
@@ -1024,24 +1232,40 @@ def open_wrf_dataset(file, **kwargs):
     return ds


-def is_rotated_proj_working():
+def is_rotated_proj_working() -> bool:
+    """Check if the installed pyproj version works with rotated projections.
+
+    Returns
+    -------
+    True if rotated projections transform as expected.
+
+    """
     import pyproj
-    srs = ('+ellps=WGS84 +proj=ob_tran +o_proj=latlon '
-           '+to_meter=0.0174532925199433 +o_lon_p=0.0 +o_lat_p=80.5 '
-           '+lon_0=357.5 +no_defs')
+
+    srs = (
+        '+ellps=WGS84 +proj=ob_tran +o_proj=latlon '
+        '+to_meter=0.0174532925199433 +o_lon_p=0.0 +o_lat_p=80.5 '
+        '+lon_0=357.5 +no_defs'
+    )

     p1 = pyproj.Proj(srs)
     p2 = wgs84

-    return np.isclose(transform_proj(p1, p2, -20, -9),
-                      [-22.243473889042903, -0.06328365194179102],
-                      atol=1e-5).all()
+    return bool(
+        np.isclose(
+            transform_proj(p1, p2, np.array(-20), np.array(-9)),
+            [-22.243473889042903, -0.06328365194179102],
+            atol=1e-5,
+        ).all()
+    )


-def open_metum_dataset(file, pole_longitude=None, pole_latitude=None,
-                       central_rotated_longitude=0., **kwargs):
-    """Wrapper to Met Office Unified Model files (experimental)
+def open_metum_dataset(
+    file: Path | str,
+    pole_longitude: float | None = None,
+    pole_latitude: float | None = None,
+    central_rotated_longitude: float = 0.0,
+    **kwargs,
+) -> xr.Dataset:
+    """Wrapper to Met Office Unified Model files (experimental).

     This is needed because these files are a little messy.
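+
+    A hypothetical sketch (the file name is made up; the pole coordinates
+    are read from the file's attributes when not given explicitly):
+
+    >>> ds = open_metum_dataset('um_rotated_pole.nc')  # doctest: +SKIP
+    >>> grid = ds.salem.grid  # doctest: +SKIP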
@@ -1064,12 +1288,15 @@ def open_metum_dataset(file, pole_longitude=None, pole_latitude=None, Returns ------- an xarray Dataset - """ + """ if not is_rotated_proj_working(): - raise RuntimeError('open_metum_dataset currently does not ' - 'work with certain PROJ versions: ' - 'https://github.com/pyproj4/pyproj/issues/424') + msg = ( + 'open_metum_dataset currently does not ' + 'work with certain PROJ versions: ' + 'https://github.com/pyproj4/pyproj/issues/424' + ) + raise RuntimeError(msg) # open with xarray ds = xr.open_dataset(file, **kwargs) @@ -1091,17 +1318,22 @@ def open_metum_dataset(file, pole_longitude=None, pole_latitude=None, pole_latitude = ds.attrs.get(n_lat, None) # then as variable attribute if pole_longitude is None or pole_latitude is None: - for k, v in ds.variables.items(): + for v in ds.variables.values(): if n_lon in v.attrs: pole_longitude = v.attrs[n_lon] if n_lat in v.attrs: pole_latitude = v.attrs[n_lat] if pole_longitude is not None and pole_latitude is not None: break - - srs = ('+ellps=WGS84 +proj=ob_tran +o_proj=latlon ' - '+to_meter=0.0174532925199433 ' - '+o_lon_p={o_lon_p} +o_lat_p={o_lat_p} +lon_0={lon_0} +no_defs') + if pole_longitude is None or pole_latitude is None: + msg = 'Could not determine pole longitude and/or latitude' + raise RuntimeError(msg) + + srs = ( + '+ellps=WGS84 +proj=ob_tran +o_proj=latlon ' + '+to_meter=0.0174532925199433 ' + '+o_lon_p={o_lon_p} +o_lat_p={o_lat_p} +lon_0={lon_0} +no_defs' + ) params = { 'o_lon_p': central_rotated_longitude, 'o_lat_p': pole_latitude, @@ -1117,8 +1349,14 @@ def open_metum_dataset(file, pole_longitude=None, pole_latitude=None, return ds -def open_mf_wrf_dataset(paths, chunks=None, compat='no_conflicts', lock=None, - preprocess=None): +def open_mf_wrf_dataset( + paths: list[Path] | Path | str, + chunks: int | dict[str, float] | None = None, + compat: str = 'no_conflicts', + preprocess: Callable | None = None, + *, + lock: bool | threading.Lock = False, +) -> xr.Dataset: """Open multiple WRF files as a single WRF dataset. Requires dask to be installed. Note that if your files are sliced by time, @@ -1130,9 +1368,8 @@ def open_mf_wrf_dataset(paths, chunks=None, compat='no_conflicts', lock=None, Parameters ---------- - paths : str or sequence - Either a string glob in the form `path/to/my/files/*.nc` or an - explicit list of files to open. + paths : + Either a list of Path object or a Path object. chunks : int or dict, optional Dictionary with keys given by dimension names and values given by chunk sizes. In general, these should divide the dimensions of each dataset. 
@@ -1163,18 +1400,27 @@ def open_mf_wrf_dataset(paths, chunks=None, compat='no_conflicts', lock=None, Returns ------- xarray.Dataset - """ - if isinstance(paths, basestring): - paths = sorted(glob(paths)) + """ if not paths: - raise IOError('no files to open') - - if lock is None: - lock = NETCDF4_PYTHON_LOCK + msg = 'no files to open' + raise OSError(msg) + + if isinstance(paths, Path): + paths = str(paths) + if isinstance(paths, str): + # NOTE: this only works on posix systems + split = paths.split('/') + glob = split.pop(-1) + path = '/'.join(split) + paths = sorted(Path(path).glob(glob)) + + used_lock = lock if lock else netcdf4_python_lock try: - datasets = [open_wrf_dataset(p, chunks=chunks or {}, lock=lock) - for p in paths] + datasets = [ + open_wrf_dataset(p, chunks=chunks or {}, lock=used_lock) + for p in paths + ] except TypeError as err: if 'lock' not in str(err): raise @@ -1183,7 +1429,7 @@ def open_mf_wrf_dataset(paths, chunks=None, compat='no_conflicts', lock=None, orig_datasets = datasets - def ds_closer(): + def ds_closer() -> None: for ods in orig_datasets: ods.close() @@ -1191,12 +1437,17 @@ def ds_closer(): datasets = [preprocess(ds) for ds in datasets] try: - combined = xr.combine_nested(datasets, combine_attrs='drop_conflicts', - concat_dim='time', compat=compat) + combined = xr.combine_nested( + datasets, + combine_attrs='drop_conflicts', + concat_dim='time', + compat=compat, + ) except ValueError: # Older xarray - combined = xr.combine_nested(datasets, concat_dim='time', - compat=compat) + combined = xr.combine_nested( + datasets, concat_dim='time', compat=compat + ) except AttributeError: # Even older combined = xr.auto_combine(datasets, concat_dim='time', compat=compat) @@ -1206,6 +1457,7 @@ def ds_closer(): combined.set_close(ds_closer) except AttributeError: from xarray.backends.api import _MultiFileCloser + mfc = _MultiFileCloser([ods._file_obj for ods in orig_datasets]) combined._file_obj = mfc diff --git a/salem/tests/__init__.py b/salem/tests/__init__.py index 3d46811..c4cd526 100644 --- a/salem/tests/__init__.py +++ b/salem/tests/__init__.py @@ -1,42 +1,48 @@ -from __future__ import division import unittest +from typing import Callable +from urllib.error import URLError +from urllib.request import urlopen + from packaging.version import Version -import os + from salem import python_version -from urllib.request import urlopen -from urllib.error import URLError -def has_internet(): +def has_internet() -> bool: """Not so recommended it seems""" try: _ = urlopen('http://www.google.com', timeout=1) - return True except URLError: pass + else: + return True return False try: import shapely + has_shapely = True except ImportError: has_shapely = False try: import geopandas + has_geopandas = True except ImportError: has_geopandas = False try: import motionless + has_motionless = True except ImportError: has_motionless = False try: import matplotlib + mpl_version = Version(matplotlib.__version__) has_matplotlib = mpl_version >= Version('2') except ImportError: @@ -45,65 +51,70 @@ def has_internet(): try: import rasterio + has_rasterio = True except ImportError: has_rasterio = False try: import cartopy + has_cartopy = True except ImportError: has_cartopy = False try: import dask + has_dask = True except ImportError: has_dask = False -def requires_internet(test): - msg = "requires internet" +def requires_internet(test: Callable) -> Callable: + msg = 'requires internet' return test if has_internet() else unittest.skip(msg)(test) -def requires_matplotlib_and_py3(test): - 
msg = "requires matplotlib and py3" - return test if has_matplotlib and (python_version == 'py3') \ +def requires_matplotlib_and_py3(test: Callable) -> Callable: + msg = 'requires matplotlib and py3' + return ( + test + if has_matplotlib and (python_version == 'py3') else unittest.skip(msg)(test) + ) -def requires_matplotlib(test): - msg = "requires matplotlib" +def requires_matplotlib(test: Callable) -> Callable: + msg = 'requires matplotlib' return test if has_matplotlib else unittest.skip(msg)(test) -def requires_motionless(test): - msg = "requires motionless" +def requires_motionless(test: Callable) -> Callable: + msg = 'requires motionless' return test if has_motionless else unittest.skip(msg)(test) -def requires_rasterio(test): - msg = "requires rasterio" +def requires_rasterio(test: Callable) -> Callable: + msg = 'requires rasterio' return test if has_rasterio else unittest.skip(msg)(test) -def requires_cartopy(test): - msg = "requires cartopy" +def requires_cartopy(test: Callable) -> Callable: + msg = 'requires cartopy' return test if has_cartopy else unittest.skip(msg)(test) -def requires_shapely(test): - msg = "requires shapely" +def requires_shapely(test: Callable) -> Callable: + msg = 'requires shapely' return test if has_shapely else unittest.skip(msg)(test) -def requires_geopandas(test): - msg = "requires geopandas" +def requires_geopandas(test: Callable) -> Callable: + msg = 'requires geopandas' return test if has_geopandas else unittest.skip(msg)(test) -def requires_dask(test): - msg = "requires dask" +def requires_dask(test: Callable) -> Callable: + msg = 'requires dask' return test if has_dask else unittest.skip(msg)(test) - diff --git a/salem/tests/test_datasets.py b/salem/tests/test_datasets.py index 04bfcc7..d575345 100644 --- a/salem/tests/test_datasets.py +++ b/salem/tests/test_datasets.py @@ -1,40 +1,45 @@ -from __future__ import division - import unittest import warnings from datetime import datetime - -import numpy as np import netCDF4 - - +import numpy as np import pandas as pd +import pytest import xarray as xr - -from numpy.testing import assert_array_equal, assert_allclose -from salem import Grid +from numpy.testing import assert_allclose, assert_array_equal + +from salem import Grid, mercator_grid, wgs84, wrftools +from salem.datasets import ( + WRF, + EsriITMIX, + GeoDataset, + GeoNetcdf, + GeoTiff, + GoogleCenterMap, + GoogleVisibleMap, +) +from salem.tests import ( + requires_geopandas, + requires_internet, + requires_matplotlib, + requires_motionless, + requires_rasterio, + requires_shapely, +) from salem.utils import get_demo_file -from salem import wgs84 -from salem import wrftools, mercator_grid -from salem.datasets import (GeoDataset, GeoNetcdf, GeoTiff, WRF, - GoogleCenterMap, GoogleVisibleMap, EsriITMIX) -from salem.tests import (requires_rasterio, requires_motionless, - requires_geopandas, requires_internet, - requires_matplotlib, requires_shapely) class TestDataset(unittest.TestCase): - - def test_period(self): + def test_period(self) -> None: """See if simple operations work well""" g = Grid(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84) d = GeoDataset(g) - self.assertTrue(d.time is None) - self.assertTrue(d.sub_t is None) - self.assertTrue(d.t0 is None) - self.assertTrue(d.t1 is None) + assert d.time is None + assert d.sub_t is None + assert d.t0 is None + assert d.t1 is None t = pd.date_range('1/1/2011', periods=72, freq='D') d = GeoDataset(g, time=t) @@ -66,64 +71,65 @@ def test_period(self): assert_array_equal(d.t0, t[0]) 
assert_array_equal(d.t1, t[-1]) - self.assertRaises(NotImplementedError, d.get_vardata) + with pytest.raises(NotImplementedError): + d.get_vardata() @requires_rasterio @requires_shapely - def test_subset(self): + def test_subset(self) -> None: """See if simple operations work well""" import shapely.geometry as shpg g = Grid(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84) d = GeoDataset(g) - self.assertTrue(isinstance(d, GeoDataset)) - self.assertEqual(g, d.grid) + assert isinstance(d, GeoDataset) + assert g == d.grid - d.set_subset(corners=([0, 0], [2, 2]), crs=wgs84) - self.assertEqual(g, d.grid) + d.set_subset(corners=((0, 0), (2, 2)), crs=wgs84) + assert g == d.grid d.set_subset() - self.assertEqual(g, d.grid) + assert g == d.grid d.set_subset(margin=-1) lon, lat = d.grid.ll_coordinates - self.assertEqual(lon, 1) - self.assertEqual(lat, 1) + assert lon == 1 + assert lat == 1 - d.set_subset(corners=([0.1, 0.1], [1.9, 1.9]), crs=wgs84) - self.assertEqual(g, d.grid) + d.set_subset(corners=((0.1, 0.1), (1.9, 1.9)), crs=wgs84) + assert g == d.grid - d.set_subset(corners=([0.51, 0.51], [1.9, 1.9]), crs=wgs84) - self.assertNotEqual(g, d.grid) + d.set_subset(corners=((0.51, 0.51), (1.9, 1.9)), crs=wgs84) + assert g != d.grid gm = Grid(nxny=(1, 1), dxdy=(1, 1), x0y0=(1, 1), proj=wgs84) - d.set_subset(corners=([1, 1], [1, 1]), crs=wgs84) - self.assertEqual(gm, d.grid) + d.set_subset(corners=((1, 1), (1, 1)), crs=wgs84) + assert gm == d.grid d.set_subset() d.set_roi() - d.set_roi(corners=([1, 1], [1, 1]), crs=wgs84) + d.set_roi(corners=((1, 1), (1, 1)), crs=wgs84) d.set_subset(toroi=True) - self.assertEqual(gm, d.grid) + assert gm == d.grid gm = Grid(nxny=(1, 1), dxdy=(1, 1), x0y0=(2, 2), proj=wgs84) - d.set_subset(corners=([2, 2], [2, 2]), crs=wgs84) - self.assertEqual(gm, d.grid) + d.set_subset(corners=((2, 2), (2, 2)), crs=wgs84) + assert gm == d.grid with warnings.catch_warnings(record=True) as w: # Cause all warnings to always be triggered. - warnings.simplefilter("always") + warnings.simplefilter('always') # Trigger a warning. 
- d.set_subset(corners=([-4, -4], [5, 5]), crs=wgs84) - self.assertEqual(g, d.grid) + d.set_subset(corners=((-4, -4), (5, 5)), crs=wgs84) + assert g == d.grid # Verify some things assert len(w) >= 2 - self.assertRaises(RuntimeError, d.set_subset, corners=([-1, -1], - [-1, -1])) - self.assertRaises(RuntimeError, d.set_subset, corners=([5, 5], - [5, 5])) + with pytest.raises(RuntimeError): + d.set_subset(corners=((-1, -1), (-1, -1))) + with pytest.raises(RuntimeError): + d.set_subset(corners=((5, 5), (5, 5))) shpf = get_demo_file('Hintereisferner.shp') reff = get_demo_file('hef_roi.tif') @@ -131,16 +137,21 @@ def test_subset(self): d.set_roi(shape=shpf) ref = d.get_vardata() # same errors as IDL: ENVI is just wrong - self.assertTrue(np.sum(ref != d.roi) < 9) - - g = Grid(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84, - pixel_ref='corner') - p = shpg.Polygon([(1.5, 1.), (2., 1.5), (1.5, 2.), (1., 1.5)]) + assert np.sum(ref != d.roi) < 9 + + g = Grid( + nxny=(3, 3), + dxdy=(1, 1), + x0y0=(0, 0), + proj=wgs84, + pixel_ref='corner', + ) + p = shpg.Polygon([(1.5, 1.0), (2.0, 1.5), (1.5, 2.0), (1.0, 1.5)]) roi = g.region_of_interest(geometry=p) np.testing.assert_array_equal([[0, 0, 0], [0, 1, 0], [0, 0, 0]], roi) d = GeoDataset(g) - d.set_roi(corners=([1.1, 1.1], [1.9, 1.9])) + d.set_roi(corners=((1.1, 1.1), (1.9, 1.9))) d.set_subset(toroi=True) np.testing.assert_array_equal([[1]], d.roi) d.set_subset() @@ -149,13 +160,13 @@ def test_subset(self): np.testing.assert_array_equal([[0, 0, 0], [0, 0, 0], [0, 0, 0]], d.roi) # Raises - self.assertRaises(RuntimeError, d.set_subset, toroi=True) + with pytest.raises(RuntimeError): + d.set_subset(toroi=True) class TestGeotiff(unittest.TestCase): - @requires_rasterio - def test_subset(self): + def test_subset(self) -> None: """Open geotiff, do subsets and stuff""" go = get_demo_file('hef_srtm.tif') gs = get_demo_file('hef_srtm_subset.tif') @@ -175,9 +186,11 @@ def test_subset(self): eps = 1e-5 ex = gs.grid.extent_in_crs(crs=wgs84) # [left, right, bot, top - go.set_subset(corners=((ex[0], ex[2]+eps), (ex[1], ex[3]-eps)), - crs=wgs84, - margin=-2) + go.set_subset( + corners=((ex[0], ex[2] + eps), (ex[1], ex[3] - eps)), + crs=wgs84, + margin=-2, + ) ref = gs.get_vardata()[2:-2, 2:-2] totest = go.get_vardata() np.testing.assert_array_equal(ref.shape, totest.shape) @@ -186,15 +199,13 @@ def test_subset(self): go.set_subset() @requires_rasterio - def test_itmix(self): - + def test_itmix(self) -> None: gf = get_demo_file('02_surface_Academy_1997_UTM47.asc') ds = EsriITMIX(gf) ds.get_vardata() @requires_rasterio - def test_xarray(self): - + def test_xarray(self) -> None: from salem import open_xr_dataset go = get_demo_file('hef_srtm.tif') @@ -208,7 +219,9 @@ def test_xarray(self): ref = gs['data'] totest = gos['data'] - np.testing.assert_array_equal(ref.shape, (gos.salem.grid.ny, gos.salem.grid.nx)) + np.testing.assert_array_equal( + ref.shape, (gos.salem.grid.ny, gos.salem.grid.nx) + ) np.testing.assert_array_equal(ref.shape, totest.shape) np.testing.assert_array_equal(ref, totest) rlon, rlat = geo.grid.center_grid.ll_coordinates @@ -218,9 +231,7 @@ def test_xarray(self): class TestGeoNetcdf(unittest.TestCase): - - def test_eraint(self): - + def test_eraint(self) -> None: f = get_demo_file('era_interim_tibet.nc') d = GeoNetcdf(f) assert d.grid.origin == 'upper-left' @@ -242,15 +253,18 @@ def test_eraint(self): assert_array_equal(flon[alon], d.get_vardata('longitude')) assert_allclose(flat[alat], d.get_vardata('latitude')) - assert_allclose(nc.variables['t2m'][:, 
alat, alon], - np.squeeze(d.get_vardata('t2m'))) + assert_allclose( + nc.variables['t2m'][:, alat, alon], + np.squeeze(d.get_vardata('t2m')), + ) d.set_period(t0='2012-06-01 06:00:00', t1='2012-06-01 12:00:00') - assert_allclose(nc.variables['t2m'][1:3, alat, alon], - np.squeeze(d.get_vardata('t2m'))) - - def test_as_xarray(self): + assert_allclose( + nc.variables['t2m'][1:3, alat, alon], + np.squeeze(d.get_vardata('t2m')), + ) + def test_as_xarray(self) -> None: f = get_demo_file('era_interim_tibet.nc') d = GeoNetcdf(f) t2 = d.get_vardata('t2m', as_xarray=True) @@ -271,15 +285,15 @@ def test_as_xarray(self): # TODO: the z dim is not ok @requires_geopandas - def test_wrf(self): + def test_wrf(self) -> None: """Open WRF, do subsets and stuff""" fs = get_demo_file('chinabang.shp') - for d in ['1', '2']: - fw = get_demo_file('wrf_tip_d{}.nc'.format(d)) + for i in ['1', '2']: + fw = get_demo_file('wrf_tip_d{}.nc'.format(i)) d = GeoNetcdf(fw) - self.assertTrue(isinstance(d, GeoDataset)) + assert isinstance(d, GeoDataset) mylon, mylat = d.grid.ll_coordinates reflon = d.get_vardata('XLONG') reflat = d.get_vardata('XLAT') @@ -293,17 +307,31 @@ def test_wrf(self): d2 = GeoNetcdf(get_demo_file('wrf_tip_d2.nc')) # Auto dimensions - self.assertTrue(d1.t_dim == 'Time') - self.assertTrue(d1.x_dim == 'west_east') - self.assertTrue(d1.y_dim == 'south_north') - self.assertTrue(d1.z_dim is None) + assert d1.t_dim == 'Time' + assert d1.x_dim == 'west_east' + assert d1.y_dim == 'south_north' + assert d1.z_dim is None # Time - assert_array_equal(d1.time, pd.to_datetime([datetime(2005, 9, 21), - datetime(2005, 9, 21, 3)])) - - assert_array_equal(d2.time, pd.to_datetime([datetime(2005, 9, 21), - datetime(2005, 9, 21, 1)])) + assert_array_equal( + d1.time, + pd.to_datetime( + [ + datetime(2005, 9, 21), + datetime(2005, 9, 21, 3), + ] + ), + ) + + assert_array_equal( + d2.time, + pd.to_datetime( + [ + datetime(2005, 9, 21), + datetime(2005, 9, 21, 1), + ] + ), + ) bef = d2.get_vardata('T2') d2.set_period(t0=datetime(2005, 9, 21, 1)) assert_array_equal(bef[[1], ...], d2.get_vardata('T2')) @@ -315,9 +343,9 @@ def test_wrf(self): # ROIS d1.set_roi(grid=d2.grid) d1.set_subset(toroi=True) - self.assertEqual(d1.grid.nx * 3, d2.grid.nx) - self.assertEqual(d1.grid.ny * 3, d2.grid.ny) - self.assertTrue(np.min(d1.roi) == 1) + assert d1.grid.nx * 3 == d2.grid.nx + assert d1.grid.ny * 3 == d2.grid.ny + assert np.min(d1.roi) == 1 mylon, mylat = d1.grid.ll_coordinates reflon = d1.get_vardata('XLONG') @@ -352,8 +380,7 @@ def test_wrf(self): np.testing.assert_allclose(reflat, mylat, atol=1e-4) @requires_geopandas - def test_wrf_polar(self): - + def test_wrf_polar(self) -> None: d = GeoNetcdf(get_demo_file('geo_em_d01_polarstereo.nc')) mylon, mylat = d.grid.ll_coordinates reflon = np.squeeze(d.get_vardata('XLONG_M')) @@ -371,8 +398,7 @@ def test_wrf_polar(self): np.testing.assert_allclose(reflat, mylat, atol=1e-4) @requires_geopandas - def test_wrf_latlon(self): - + def test_wrf_latlon(self) -> None: d = GeoNetcdf(get_demo_file('geo_em.d01_lon-lat.nc')) mylon, mylat = d.grid.ll_coordinates reflon = np.squeeze(d.get_vardata('XLONG_M')) @@ -381,6 +407,7 @@ def test_wrf_latlon(self): np.testing.assert_allclose(reflon, mylon, atol=1e-4) np.testing.assert_allclose(reflat, mylat, atol=1e-4) + del d d = GeoNetcdf(get_demo_file('geo_em.d04_lon-lat.nc')) mylon, mylat = d.grid.ll_coordinates reflon = np.squeeze(d.get_vardata('XLONG_M')) @@ -389,21 +416,21 @@ def test_wrf_latlon(self): np.testing.assert_allclose(reflon, mylon, atol=1e-4) 
np.testing.assert_allclose(reflat, mylat, atol=1e-4) - def test_longtime(self): + def test_longtime(self) -> None: """There was a bug with time""" fs = get_demo_file('test_longtime.nc') c = GeoNetcdf(fs) - self.assertEqual(len(c.time), 2424) - assert_array_equal(c.time[0:2], pd.to_datetime([datetime(1801, 10, 1), - datetime(1801, 11, - 1)])) - - def test_diagnostic_vars(self): + assert len(c.time) == 2424 + assert_array_equal( + c.time[0:2], + pd.to_datetime([datetime(1801, 10, 1), datetime(1801, 11, 1)]), + ) + def test_diagnostic_vars(self) -> None: d = WRF(get_demo_file('wrf_tip_d1.nc')) d2 = GeoNetcdf(get_demo_file('wrf_tip_d2.nc')) - self.assertTrue('T2C' in d.variables) + assert 'T2C' in d.variables ref = d.get_vardata('T2') tot = d.get_vardata('T2C') + 273.15 @@ -414,91 +441,105 @@ def test_diagnostic_vars(self): ref = d.get_vardata('T2') tot = d.get_vardata('T2C') + 273.15 - self.assertEqual(tot.shape[-1] * 3, d2.grid.nx) - self.assertEqual(tot.shape[-2] * 3, d2.grid.ny) + assert tot.shape[-1] * 3 == d2.grid.nx + assert tot.shape[-2] * 3 == d2.grid.ny np.testing.assert_allclose(ref, tot) d = WRF(get_demo_file('wrf_tip_d1.nc')) ref = d.variables['T2'][:] d.set_subset(margin=-5) tot = d.get_vardata('T2') - assert_array_equal(ref.shape[1]-10, tot.shape[1]) - assert_array_equal(ref.shape[2]-10, tot.shape[2]) + assert_array_equal(ref.shape[1] - 10, tot.shape[1]) + assert_array_equal(ref.shape[2] - 10, tot.shape[2]) assert_array_equal(ref[:, 5:-5, 5:-5], tot) class TestGoogleStaticMap(unittest.TestCase): - @requires_internet @requires_motionless @requires_matplotlib - def test_center(self): + def test_center(self) -> None: import matplotlib as mpl - gm = GoogleCenterMap(center_ll=(10.762660, 46.794221), zoom=13, - size_x=500, size_y=500, use_cache=False) + + gm = GoogleCenterMap( + center_ll=(10.762660, 46.794221), + zoom=13, + size_x=500, + size_y=500, + use_cache=False, + ) gm.set_roi(shape=get_demo_file('Hintereisferner.shp')) gm.set_subset(toroi=True, margin=10) img = gm.get_vardata()[..., :3] - img[np.nonzero(gm.roi == 0)] /= 2. + img[np.nonzero(gm.roi == 0)] /= 2.0 # from PIL import Image # Image.fromarray((img * 255).astype(np.uint8)).save( # get_demo_file('hef_google_roi.png')) ref = mpl.image.imread(get_demo_file('hef_google_roi.png')) - rmsd = np.sqrt(np.mean((ref - img)**2)) - self.assertTrue(rmsd < 0.2) + rmsd = np.sqrt(np.mean((ref - img) ** 2)) + assert rmsd < 0.2 # assert_allclose(ref, img, atol=2e-2) - gm = GoogleCenterMap(center_ll=(10.762660, 46.794221), zoom=13, - size_x=500, size_y=500) + gm = GoogleCenterMap( + center_ll=(10.762660, 46.794221), zoom=13, size_x=500, size_y=500 + ) gm.set_roi(shape=get_demo_file('Hintereisferner.shp')) gm.set_subset(toroi=True, margin=10) img = gm.get_vardata()[..., :3] - img[np.nonzero(gm.roi == 0)] /= 2. 
-        rmsd = np.sqrt(np.mean((ref - img)**2))
-        self.assertTrue(rmsd < 0.2)
-
-        gm = GoogleCenterMap(center_ll=(10.762660, 46.794221), zoom=13,
-                             size_x=500, size_y=500)
-        gm2 = GoogleCenterMap(center_ll=(10.762660, 46.794221), zoom=13,
-                              size_x=500, size_y=500, scale=2)
+        img[np.nonzero(gm.roi == 0)] /= 2.0
+        rmsd = np.sqrt(np.mean((ref - img) ** 2))
+        assert rmsd < 0.2
+
+        gm = GoogleCenterMap(
+            center_ll=(10.762660, 46.794221), zoom=13, size_x=500, size_y=500
+        )
+        gm2 = GoogleCenterMap(
+            center_ll=(10.762660, 46.794221),
+            zoom=13,
+            size_x=500,
+            size_y=500,
+            scale=2,
+        )
         assert (gm.grid.nx * 2) == gm2.grid.nx
         assert gm.grid.extent == gm2.grid.extent

     @requires_internet
     @requires_motionless
     @requires_matplotlib
-    def test_visible(self):
+    def test_visible(self) -> None:
         import matplotlib as mpl

-        x = [91.176036, 92.05, 88.880927]
-        y = [29.649702, 31.483333, 29.264956]
+        x = np.array([91.176036, 92.05, 88.880927])
+        y = np.array([29.649702, 31.483333, 29.264956])

-        g = GoogleVisibleMap(x=x, y=y, size_x=400, size_y=400,
-                             maptype='terrain')
+        g = GoogleVisibleMap(
+            x=x, y=y, size_x=400, size_y=400, maptype='terrain'
+        )
         img = g.get_vardata()[..., :3]
         i, j = g.grid.transform(x, y, nearest=True)
         for _i, _j in zip(i, j):
-            img[_j-3:_j+4, _i-3:_i+4, 0] = 1
-            img[_j-3:_j+4, _i-3:_i+4, 1:] = 0
+            img[_j - 3 : _j + 4, _i - 3 : _i + 4, 0] = 1
+            img[_j - 3 : _j + 4, _i - 3 : _i + 4, 1:] = 0

         # from PIL import Image
         # Image.fromarray((img * 255).astype(np.uint8)).save(
         #     get_demo_file('hef_google_visible.png'))
         ref = mpl.image.imread(get_demo_file('hef_google_visible.png'))
-        rmsd = np.sqrt(np.mean((ref-img)**2))
-        self.assertTrue(rmsd < 1e-1)
+        rmsd = np.sqrt(np.mean((ref - img) ** 2))
+        assert rmsd < 0.1

-        self.assertRaises(ValueError, GoogleVisibleMap, x=x, y=y, zoom=12)
+        with pytest.raises(ValueError):
+            GoogleVisibleMap(x=x, y=y, zoom=12)

         fw = get_demo_file('wrf_tip_d1.nc')
         d = GeoNetcdf(fw)
         i, j = d.grid.ij_coordinates
         g = GoogleVisibleMap(x=i, y=j, crs=d.grid, size_x=500, size_y=500)
         img = g.get_vardata()[..., :3]
-        mask = g.grid.map_gridded_data(i*0+1, d.grid)
+        mask = g.grid.map_gridded_data(i * 0 + 1, d.grid)

         img[np.nonzero(mask)] = np.clip(img[np.nonzero(mask)] + 0.3, 0, 1)

@@ -506,26 +547,27 @@ def test_visible(self):
         # Image.fromarray((img * 255).astype(np.uint8)).save(
         #     get_demo_file('hef_google_visible_grid.png'))
         ref = mpl.image.imread(get_demo_file('hef_google_visible_grid.png'))
-        rmsd = np.sqrt(np.mean((ref-img)**2))
-        self.assertTrue(rmsd < 5e-1)
+        rmsd = np.sqrt(np.mean((ref - img) ** 2))
+        assert rmsd < 0.5

-        gm = GoogleVisibleMap(x=i, y=j, crs=d.grid,
-                              size_x=500, size_y=500)
-        gm2 = GoogleVisibleMap(x=i, y=j, crs=d.grid, scale=2,
-                               size_x=500, size_y=500)
+        gm = GoogleVisibleMap(x=i, y=j, crs=d.grid, size_x=500, size_y=500)
+        gm2 = GoogleVisibleMap(
+            x=i, y=j, crs=d.grid, scale=2, size_x=500, size_y=500
+        )
         assert (gm.grid.nx * 2) == gm2.grid.nx
         assert gm.grid.extent == gm2.grid.extent

         # Test regression for non array inputs
-        grid = mercator_grid(center_ll=(72.5, 30.),
-                             extent=(2.0e6, 2.0e6))
-        GoogleVisibleMap(x=[0, grid.nx - 1], y=[0, grid.ny - 1], crs=grid)
+        grid = mercator_grid(center_ll=(72.5, 30.0), extent=(2.0e6, 2.0e6))
+        GoogleVisibleMap(
+            x=np.array([0, grid.nx - 1]),
+            y=np.array([0, grid.ny - 1]),
+            crs=grid,
+        )


 class TestWRF(unittest.TestCase):
-
-    def test_unstagger(self):
-
+    def test_unstagger(self) -> None:
         wf = get_demo_file('wrf_cropped.nc')
         with netCDF4.Dataset(wf) as nc:
             nc.set_auto_mask(False)
@@ -536,45 +578,45 @@ def test_unstagger(self):
            # Own constructor
             v = wrftools.Unstaggerer(nc['PH'])
             assert_allclose(v[:], ref)
-            assert_allclose(v[0:2, 2:12, ...],
-                            ref[0:2, 2:12, ...])
-            assert_allclose(v[:, 2:12, ...],
-                            ref[:, 2:12, ...])
-            assert_allclose(v[0:2, 2:12, 5:10, 15:17],
-                            ref[0:2, 2:12, 5:10, 15:17])
-            assert_allclose(v[1:2, 2:, 5:10, 15:17],
-                            ref[1:2, 2:, 5:10, 15:17])
-            assert_allclose(v[1:2, :-2, 5:10, 15:17],
-                            ref[1:2, :-2, 5:10, 15:17])
-            assert_allclose(v[1:2, 2:-4, 5:10, 15:17],
-                            ref[1:2, 2:-4, 5:10, 15:17])
-            assert_allclose(v[[0, 2], ...],
-                            ref[[0, 2], ...])
-            assert_allclose(v[..., [0, 2]],
-                            ref[..., [0, 2]])
+            assert_allclose(v[0:2, 2:12, ...], ref[0:2, 2:12, ...])
+            assert_allclose(v[:, 2:12, ...], ref[:, 2:12, ...])
+            assert_allclose(
+                v[0:2, 2:12, 5:10, 15:17], ref[0:2, 2:12, 5:10, 15:17]
+            )
+            assert_allclose(v[1:2, 2:, 5:10, 15:17], ref[1:2, 2:, 5:10, 15:17])
+            assert_allclose(
+                v[1:2, :-2, 5:10, 15:17], ref[1:2, :-2, 5:10, 15:17]
+            )
+            assert_allclose(
+                v[1:2, 2:-4, 5:10, 15:17], ref[1:2, 2:-4, 5:10, 15:17]
+            )
+            assert_allclose(v[[0, 2], ...], ref[[0, 2], ...])
+            assert_allclose(v[..., [0, 2]], ref[..., [0, 2]])
             assert_allclose(v[0, ...], ref[0, ...])

             # Under WRF
-            nc = WRF(wf)
-            assert_allclose(nc.get_vardata('PH'), ref)
-            nc.set_period(1, 2)
-            assert_allclose(nc.get_vardata('PH'), ref[1:3, ...])
-
-    def test_unstagger_compressed(self):
+            nc_wrf = WRF(wf)
+            assert_allclose(nc_wrf.get_vardata('PH'), ref)
+            nc_wrf.set_period(1, 2)
+            assert_allclose(nc_wrf.get_vardata('PH'), ref[1:3, ...])

+    def test_unstagger_compressed(self) -> None:
         wf = get_demo_file('wrf_cropped.nc')
         wfc = get_demo_file('wrf_cropped_compressed.nc')

         # Under WRF
         nc = WRF(wf)
         ncc = WRF(wfc)
-        assert_allclose(nc.get_vardata('PH'), ncc.get_vardata('PH'), rtol=.003)
+        assert_allclose(
+            nc.get_vardata('PH'), ncc.get_vardata('PH'), rtol=0.003
+        )

         nc.set_period(1, 2)
         ncc.set_period(1, 2)
-        assert_allclose(nc.get_vardata('PH'), ncc.get_vardata('PH'), rtol=.003)
-
-    def test_ncl_diagvars(self):
+        assert_allclose(
+            nc.get_vardata('PH'), ncc.get_vardata('PH'), rtol=0.003
+        )

+    def test_ncl_diagvars(self) -> None:
         wf = get_demo_file('wrf_cropped.nc')
         ncl_out = get_demo_file('wrf_cropped_ncl.nc')
@@ -591,8 +633,7 @@ def test_ncl_diagvars(self):
         tot = w.get_vardata('SLP')
         assert_allclose(ref, tot, rtol=1e-6)

-    def test_ncl_diagvars_compressed(self):
-
+    def test_ncl_diagvars_compressed(self) -> None:
         wf = get_demo_file('wrf_cropped_compressed.nc')
         ncl_out = get_demo_file('wrf_cropped_ncl.nc')
@@ -609,38 +650,43 @@ def test_ncl_diagvars_compressed(self):
         tot = w.get_vardata('SLP')
         assert_allclose(ref, tot, rtol=1e-4)

-    def test_staggeredcoords(self):
-
+    def test_staggeredcoords(self) -> None:
         wf = get_demo_file('wrf_cropped.nc')
         nc = GeoNetcdf(wf)
         lon, lat = nc.grid.xstagg_ll_coordinates
-        assert_allclose(np.squeeze(nc.variables['XLONG_U'][0, ...]), lon,
-                        atol=1e-4)
-        assert_allclose(np.squeeze(nc.variables['XLAT_U'][0, ...]), lat,
-                        atol=1e-4)
+        assert_allclose(
+            np.squeeze(nc.variables['XLONG_U'][0, ...]), lon, atol=1e-4
+        )
+        assert_allclose(
+            np.squeeze(nc.variables['XLAT_U'][0, ...]), lat, atol=1e-4
+        )
         lon, lat = nc.grid.ystagg_ll_coordinates
-        assert_allclose(np.squeeze(nc.variables['XLONG_V'][0, ...]), lon,
-                        atol=1e-4)
-        assert_allclose(np.squeeze(nc.variables['XLAT_V'][0, ...]), lat,
-                        atol=1e-4)
-
-    def test_staggeredcoords_compressed(self):
-
+        assert_allclose(
+            np.squeeze(nc.variables['XLONG_V'][0, ...]), lon, atol=1e-4
+        )
+        assert_allclose(
+            np.squeeze(nc.variables['XLAT_V'][0, ...]), lat, atol=1e-4
+        )
+
+    def test_staggeredcoords_compressed(self) -> None:
         wf = get_demo_file('wrf_cropped_compressed.nc')
         nc = GeoNetcdf(wf)
         lon, lat = nc.grid.xstagg_ll_coordinates
-        assert_allclose(np.squeeze(nc.variables['XLONG_U'][0, ...]), lon,
-                        atol=1e-4)
-        assert_allclose(np.squeeze(nc.variables['XLAT_U'][0, ...]), lat,
-                        atol=1e-4)
+        assert_allclose(
+            np.squeeze(nc.variables['XLONG_U'][0, ...]), lon, atol=1e-4
+        )
+        assert_allclose(
+            np.squeeze(nc.variables['XLAT_U'][0, ...]), lat, atol=1e-4
+        )
         lon, lat = nc.grid.ystagg_ll_coordinates
-        assert_allclose(np.squeeze(nc.variables['XLONG_V'][0, ...]), lon,
-                        atol=1e-4)
-        assert_allclose(np.squeeze(nc.variables['XLAT_V'][0, ...]), lat,
-                        atol=1e-4)
-
-    def test_har(self):
-
+        assert_allclose(
+            np.squeeze(nc.variables['XLONG_V'][0, ...]), lon, atol=1e-4
+        )
+        assert_allclose(
+            np.squeeze(nc.variables['XLAT_V'][0, ...]), lat, atol=1e-4
+        )
+
+    def test_har(self) -> None:
         # HAR
         hf = get_demo_file('har_d30km_y_2d_t2_2000.nc')
         d = GeoNetcdf(hf)
diff --git a/salem/tests/test_gis.py b/salem/tests/test_gis.py
index d2ae42f..fce5c77 100644
--- a/salem/tests/test_gis.py
+++ b/salem/tests/test_gis.py
@@ -1,33 +1,42 @@
-from __future__ import division
+from __future__ import annotations

+import contextlib
 import unittest
 import warnings
-import os
+from pathlib import Path
+from typing import TYPE_CHECKING

+import netCDF4
+import numpy as np
 import pyproj
 import pytest
-import numpy as np
-import netCDF4
-from numpy.testing import assert_array_equal, assert_allclose
-
-from salem import Grid
-from salem import wgs84
-import salem.gis as gis
+from numpy.testing import assert_allclose, assert_array_equal
+from typing_extensions import Self
+
+from salem import Grid, gis, wgs84
+from salem.tests import (
+    python_version,
+    requires_cartopy,
+    requires_geopandas,
+    requires_rasterio,
+    requires_shapely,
+)
 from salem.utils import get_demo_file
-from salem.tests import (requires_shapely, requires_geopandas,
-                         requires_cartopy, requires_rasterio, python_version)
+
+if TYPE_CHECKING:
+    import types


-class SimpleNcDataSet():
+class SimpleNcDataSet:
     """Exploratory object to play around.
For testing only.""" - def __init__(self, file): + def __init__(self, file: Path) -> None: self.nc = netCDF4.Dataset(file) self.nc.set_auto_mask(False) proj = gis.check_crs(str(self.nc.proj4_str)) x = self.nc.variables['x'] y = self.nc.variables['y'] - dxdy = (x[1]-x[0], y[1]-y[0]) + dxdy = (x[1] - x[0], y[1] - y[0]) nxny = (len(x), len(y)) x0y0 = None if dxdy[1] > 0: @@ -36,41 +45,55 @@ def __init__(self, file): x0y0 = (x[0], y[0]) self.grid = Grid(nxny=nxny, dxdy=dxdy, proj=proj, x0y0=x0y0) - def __enter__(self): + def __enter__(self) -> Self: return self - def __exit__(self, exception_type, exception_value, traceback): + def __exit__( + self, + exception_type: type[BaseException] | None, + exception_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: self.close() - def close(self): + def close(self) -> None: self.nc.close() class TestGrid(unittest.TestCase): - - def test_constructor(self): - + def test_constructor(self) -> None: # It should work exact same for any projection projs = [wgs84, gis.check_crs('epsg:26915')] + test_file = Path('test.json') for proj in projs: - args = dict(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=proj) + args = { + 'nxny': (3, 3), + 'dxdy': (1, 1), + 'x0y0': (0, 0), + 'proj': proj, + } g = Grid(**args) - self.assertTrue(isinstance(g, Grid)) - self.assertEqual(g.center_grid, g.corner_grid) + assert isinstance(g, Grid) + assert g.center_grid == g.corner_grid # serialization d = g.to_dict() rg = Grid.from_dict(d) - self.assertEqual(g, rg) - g.to_json('test.json') - rg = Grid.from_json('test.json') - os.remove('test.json') - self.assertEqual(g, rg) - - oargs = dict(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=proj) + assert g == rg + g.to_json(test_file) + rg = Grid.from_json(test_file) + test_file.unlink() + assert g == rg + + oargs = { + 'nxny': (3, 3), + 'dxdy': (1, 1), + 'x0y0': (0, 0), + 'proj': proj, + } og = Grid(**oargs) - self.assertEqual(g, og) + assert g == og # very simple test exp_i, exp_j = np.meshgrid(np.arange(3), np.arange(3)) @@ -88,11 +111,13 @@ def test_constructor(self): assert_allclose(j, exp_j) args['proj'] = 'dummy' - self.assertRaises(ValueError, Grid, **args) + with pytest.raises(TypeError, match='proj should not be None'): + Grid(**args) args['proj'] = proj args['nxny'] = (1, -1) - self.assertRaises(ValueError, Grid, **args) + with pytest.raises(ValueError, match='nxny not valid'): + Grid(**args) args['nxny'] = (3, 3) args['dxdy'] = (1, -1) @@ -100,37 +125,44 @@ def test_constructor(self): del args['x0y0'] with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') - self.assertRaises(ValueError, Grid, **args) - self.assertEqual(len(w), 1) + with pytest.raises(ValueError): + Grid(**args) + assert len(w) == 1 args['x0y0'] = args['ll_corner'] del args['ll_corner'] # Center VS corner - multiple times because it was a bug - assert_allclose(g.center_grid.xy_coordinates, - g.xy_coordinates) - assert_allclose(g.center_grid.center_grid.xy_coordinates, - g.xy_coordinates) - assert_allclose(g.corner_grid.corner_grid.xy_coordinates, - g.corner_grid.xy_coordinates) + assert_allclose(g.center_grid.xy_coordinates, g.xy_coordinates) + assert_allclose( + g.center_grid.center_grid.xy_coordinates, g.xy_coordinates + ) + assert_allclose( + g.corner_grid.corner_grid.xy_coordinates, + g.corner_grid.xy_coordinates, + ) ex = g.corner_grid.extent assert_allclose([-0.5, 2.5, -0.5, 2.5], ex) - assert_allclose(g.center_grid.extent, - g.corner_grid.extent) + assert_allclose(g.center_grid.extent, 
g.corner_grid.extent) args['x0y0'] = (0, 0) g = Grid(**args) - self.assertTrue(isinstance(g, Grid)) - - oargs = dict(nxny=(3, 3), dxdy=(1, -1), x0y0=(0, 0), proj=proj) + assert isinstance(g, Grid) + + oargs = { + 'nxny': (3, 3), + 'dxdy': (1, -1), + 'x0y0': (0, 0), + 'proj': proj, + } og = Grid(**oargs) - self.assertEqual(g, og) + assert g == og # serialization d = og.to_dict() rg = Grid.from_dict(d) - self.assertEqual(og, rg) + assert og == rg # The simple test should work here too i, j = g.ij_coordinates @@ -149,17 +181,18 @@ def test_constructor(self): assert_allclose(j, exp_y) # Center VS corner - multiple times because it was a bug - assert_allclose(g.center_grid.xy_coordinates, - g.xy_coordinates) - assert_allclose(g.center_grid.center_grid.xy_coordinates, - g.xy_coordinates) - assert_allclose(g.corner_grid.corner_grid.xy_coordinates, - g.corner_grid.xy_coordinates) + assert_allclose(g.center_grid.xy_coordinates, g.xy_coordinates) + assert_allclose( + g.center_grid.center_grid.xy_coordinates, g.xy_coordinates + ) + assert_allclose( + g.corner_grid.corner_grid.xy_coordinates, + g.corner_grid.xy_coordinates, + ) ex = g.corner_grid.extent assert_allclose([-0.5, 2.5, -2.5, 0.5], ex) - assert_allclose(g.center_grid.extent, - g.corner_grid.extent) + assert_allclose(g.center_grid.extent, g.corner_grid.extent) # The equivalents g = g.corner_grid @@ -167,83 +200,82 @@ def test_constructor(self): assert_allclose(i, exp_i) assert_allclose(j, exp_j) - exp_x, exp_y = np.meshgrid(np.arange(3)-0.5, -np.arange(3)+0.5) + exp_x, exp_y = np.meshgrid(np.arange(3) - 0.5, -np.arange(3) + 0.5) x, y = g.xy_coordinates assert_allclose(x, exp_x) assert_allclose(y, exp_y) - args = dict(nxny=(3, 2), dxdy=(1, 1), x0y0=(0, 0)) + args = {'nxny': (3, 2), 'dxdy': (1, 1), 'x0y0': (0, 0)} g = Grid(**args) - self.assertTrue(isinstance(g, Grid)) - self.assertTrue(g.xy_coordinates[0].shape == (2, 3)) - self.assertTrue(g.xy_coordinates[1].shape == (2, 3)) + assert isinstance(g, Grid) + assert g.xy_coordinates[0].shape == (2, 3) + assert g.xy_coordinates[1].shape == (2, 3) - def test_comparisons(self): + def test_comparisons(self) -> None: """See if the grids can compare themselves""" - args = dict(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84) + args = {'nxny': (3, 3), 'dxdy': (1, 1), 'x0y0': (0, 0), 'proj': wgs84} g1 = Grid(**args) - self.assertEqual(g1.center_grid, g1.corner_grid) - self.assertTrue(g1.center_grid.almost_equal(g1.center_grid)) + assert g1.center_grid == g1.corner_grid + assert g1.center_grid.almost_equal(g1.center_grid) g2 = Grid(**args) - self.assertEqual(g1, g2) - self.assertTrue(g1.almost_equal(g2)) + assert g1 == g2 + assert g1.almost_equal(g2) - args['dxdy'] = (1. + 1e-6, 1. 
+ 1e-6) + args['dxdy'] = (1.0 + 1e-6, 1.0 + 1e-6) g2 = Grid(**args) - self.assertNotEqual(g1, g2) - self.assertTrue(g1.almost_equal(g2)) + assert g1 != g2 + assert g1.almost_equal(g2) # serialization d = g1.to_dict() rg = Grid.from_dict(d) - self.assertEqual(g1, rg) + assert g1 == rg d = g2.to_dict() rg = Grid.from_dict(d) - self.assertEqual(g2, rg) - self.assertNotEqual(g1, rg) - self.assertTrue(g1.almost_equal(rg)) - g1.to_json('test.json') - rg = Grid.from_json('test.json') - os.remove('test.json') - self.assertEqual(g1, rg) - g2.to_json('test.json') - rg = Grid.from_json('test.json') - os.remove('test.json') - self.assertEqual(g2, rg) - self.assertNotEqual(g1, rg) - self.assertTrue(g1.almost_equal(rg)) + assert g2 == rg + assert g1 != rg + assert g1.almost_equal(rg) + test_file = Path('test.json') + g1.to_json(test_file) + rg = Grid.from_json(test_file) + test_file.unlink() + assert g1 == rg + g2.to_json(test_file) + rg = Grid.from_json(test_file) + test_file.unlink() + assert g2 == rg + assert g1 != rg + assert g1.almost_equal(rg) args['proj'] = gis.check_crs('epsg:26915') g2 = Grid(**args) - self.assertNotEqual(g1, g2) - self.assertFalse(g1.almost_equal(g2)) + assert g1 != g2 + assert not g1.almost_equal(g2) # New instance, same proj args['proj'] = gis.check_crs('epsg:26915') g1 = Grid(**args) - self.assertEqual(g1, g2) - self.assertTrue(g1.almost_equal(g2)) + assert g1 == g2 + assert g1.almost_equal(g2) # serialization d = g1.to_dict() rg = Grid.from_dict(d) - self.assertEqual(g1, rg) - self.assertTrue(g1.almost_equal(rg)) - g1.to_json('test.json') - rg = Grid.from_json('test.json') - os.remove('test.json') - self.assertEqual(g1, rg) - self.assertTrue(g1.almost_equal(rg)) - - def test_reprs(self): - from textwrap import dedent - - args = dict(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84) + assert g1 == rg + assert g1.almost_equal(rg) + g1.to_json(test_file) + rg = Grid.from_json(test_file) + test_file.unlink() + assert g1 == rg + assert g1.almost_equal(rg) + + def test_reprs(self) -> None: + args = {'nxny': (3, 3), 'dxdy': (1, 1), 'x0y0': (0, 0), 'proj': wgs84} g1 = Grid(**args) - self.assertEqual(g1.__repr__(), g1.__str__()) + assert g1.__repr__() == g1.__str__() - def test_errors(self): + def test_errors(self) -> None: """Check that errors are occurring""" # It should work exact same for any projection @@ -251,57 +283,125 @@ def test_errors(self): for proj in projs: with warnings.catch_warnings(): - warnings.simplefilter("ignore") - args = dict(nxny=(3, 3), dxdy=(1, -1), ll_corner=(0, 0), proj=proj) - self.assertRaises(ValueError, Grid, **args) - args = dict(nxny=(3, 3), dxdy=(-1, 0), ul_corner=(0, 0), proj=proj) - self.assertRaises(ValueError, Grid, **args) - args = dict(nxny=(3, 3), dxdy=(1, 1), proj=proj) - self.assertRaises(ValueError, Grid, **args) - args = dict(nxny=(3, -3), dxdy=(1, 1), ll_corner=(0, 0), proj=proj) - self.assertRaises(ValueError, Grid, **args) - args = dict(nxny=(3, 3), dxdy=(1, 1), ll_corner=(0, 0), - proj=proj, pixel_ref='areyoudumb') - self.assertRaises(ValueError, Grid, **args) - - args = dict(nxny=(3, 3), dxdy=(1, 1), ll_corner=(0, 0), proj=proj) + warnings.simplefilter('ignore') + args = { + 'nxny': (3, 3), + 'dxdy': (1, -1), + 'll_corner': (0, 0), + 'proj': proj, + } + with pytest.raises( + ValueError, match='dxdy and input params not compatible' + ): + Grid(**args) + args = { + 'nxny': (3, 3), + 'dxdy': (-1, 0), + 'ul_corner': (0, 0), + 'proj': proj, + } + with pytest.raises( + ValueError, match='dxdy and input params not compatible' + ): + 
Grid(**args) + args = {'nxny': (3, 3), 'dxdy': (1, 1), 'proj': proj} + with pytest.raises( + ValueError, match='Input params not compatible' + ): + Grid(**args) + args = { + 'nxny': (3, -3), + 'dxdy': (1, 1), + 'll_corner': (0, 0), + 'proj': proj, + } + with pytest.raises(ValueError, match='nxny not valid'): + Grid(**args) + args = { + 'nxny': (3, 3), + 'dxdy': (1, 1), + 'll_corner': (0, 0), + 'proj': proj, + 'pixel_ref': 'areyoudumb', + } + with pytest.raises( + ValueError, match='pixel_ref not recognized' + ): + Grid(**args) + + args = { + 'nxny': (3, 3), + 'dxdy': (1, 1), + 'll_corner': (0, 0), + 'proj': proj, + } g = Grid(**args) - self.assertRaises(ValueError, g.transform, 0, 0, crs=None) - self.assertRaises(ValueError, g.transform, 0, 0, crs='areyou?') - self.assertRaises(ValueError, g.map_gridded_data, - np.zeros((3, 3)), 'areyou?') - self.assertRaises(ValueError, g.map_gridded_data, - np.zeros(3), g) - self.assertRaises(ValueError, g.map_gridded_data, - np.zeros((3, 4)), g) - self.assertRaises(ValueError, g.map_gridded_data, - np.zeros((3, 3)), g, interp='youare') + with pytest.raises( + ValueError, + match='crs must be a pyproj.Proj or salem.Grid, not None', + ): + g.transform(np.array(0), np.array(0), crs=None) + with pytest.raises( + ValueError, + match='salem could not properly parse the provided coordinate', + ): + g.transform(np.array(0), np.array(0), crs='areyou?') + with pytest.raises(TypeError): + g.map_gridded_data(np.zeros((3, 3)), 'areyou?') + with pytest.raises( + ValueError, match='Expected 2D, 3D or 4D data' + ): + g.map_gridded_data(np.zeros(3), g) + with pytest.raises( + ValueError, match='dimension not compatible' + ): + g.map_gridded_data(np.zeros((3, 4)), g) + with pytest.raises( + ValueError, match='interpolation not understood' + ): + g.map_gridded_data(np.zeros((3, 3)), g, interp='youare') # deprecation warnings for proj in projs: with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - args = dict(nxny=(3, 3), dxdy=(1, -1), corner=(0, 0), - proj=proj) + warnings.simplefilter('always') + args = { + 'nxny': (3, 3), + 'dxdy': (1, -1), + 'corner': (0, 0), + 'proj': proj, + } Grid(**args) - args = dict(nxny=(3, 3), dxdy=(1, -1), ul_corner=(0, 0), - proj=proj) + args = { + 'nxny': (3, 3), + 'dxdy': (1, -1), + 'ul_corner': (0, 0), + 'proj': proj, + } Grid(**args) - args = dict(nxny=(3, 3), dxdy=(1, 1), ll_corner=(0, 0), - proj=proj) + args = { + 'nxny': (3, 3), + 'dxdy': (1, 1), + 'll_corner': (0, 0), + 'proj': proj, + } Grid(**args) if python_version == 'py3': - self.assertEqual(len(w), 3) + assert len(w) == 3 - def test_ij_to_crs(self): + def test_ij_to_crs(self) -> None: """Converting to projection""" # It should work exact same for any projection projs = [wgs84, gis.check_crs('epsg:26915')] for proj in projs: - - args = dict(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=proj) + args = { + 'nxny': (3, 3), + 'dxdy': (1, 1), + 'x0y0': (0, 0), + 'proj': proj, + } g = Grid(**args) exp_i, exp_j = np.meshgrid(np.arange(3), np.arange(3)) @@ -315,7 +415,7 @@ def test_ij_to_crs(self): # The equivalents gc = g.corner_grid - r_i, r_j = gc.ij_to_crs(exp_i+0.5, exp_j+0.5) + r_i, r_j = gc.ij_to_crs(exp_i + 0.5, exp_j + 0.5) assert_allclose(exp_i, r_i, atol=1e-03) assert_allclose(exp_j, r_j, atol=1e-03) gc = g.center_grid @@ -323,7 +423,12 @@ def test_ij_to_crs(self): assert_allclose(exp_i, r_i, atol=1e-03) assert_allclose(exp_j, r_j, atol=1e-03) - args = dict(nxny=(3, 3), dxdy=(1, -1), x0y0=(0, 0), proj=proj) + args = { + 'nxny': (3, 3), + 'dxdy': (1, 
-1), + 'x0y0': (0, 0), + 'proj': proj, + } g = Grid(**args) exp_i, exp_j = np.meshgrid(np.arange(3), -np.arange(3)) in_i, in_j = np.meshgrid(np.arange(3), np.arange(3)) @@ -338,51 +443,66 @@ def test_ij_to_crs(self): # The equivalents gc = g.corner_grid r_i, r_j = gc.ij_to_crs(in_i, in_j) - assert_allclose(exp_i-0.5, r_i, atol=1e-03) - assert_allclose(exp_j+0.5, r_j, atol=1e-03) + assert_allclose(exp_i - 0.5, r_i, atol=1e-03) + assert_allclose(exp_j + 0.5, r_j, atol=1e-03) gc = g.center_grid r_i, r_j = gc.ij_to_crs(in_i, in_j) assert_allclose(exp_i, r_i, atol=1e-03) assert_allclose(exp_j, r_j, atol=1e-03) # if we take some random projection it wont work - proj_out = pyproj.Proj(proj="utm", zone=10, datum='NAD27') + proj_out = pyproj.Proj(proj='utm', zone=10, datum='NAD27') r_i, r_j = g.ij_to_crs(exp_i, exp_j, crs=proj_out) - self.assertFalse(np.allclose(exp_i, r_i)) - self.assertFalse(np.allclose(exp_j, r_j)) + assert not np.allclose(exp_i, r_i) + assert not np.allclose(exp_j, r_j) # Raise - self.assertRaises(ValueError, g.ij_to_crs, exp_i, exp_j, crs='ups') + with pytest.raises(ValueError): + g.ij_to_crs(exp_i, exp_j, crs='ups') - def test_regrid(self): + def test_regrid(self) -> None: """New grids""" # It should work exact same for any projection projs = [wgs84, gis.check_crs('epsg:26915')] for proj in projs: - - kargs = [dict(nxny=(3, 2), dxdy=(1, 1), x0y0=(0, 0), - proj=proj), - dict(nxny=(3, 2), dxdy=(1, -1), x0y0=(0, 0), - proj=proj), - dict(nxny=(3, 2), dxdy=(1, 1), x0y0=(0, 0), - proj=proj, pixel_ref='corner'), - dict(nxny=(3, 2), dxdy=(1, -1), x0y0=(0, 0), - proj=proj, pixel_ref='corner')] + kargs = [ + {'nxny': (3, 2), 'dxdy': (1, 1), 'x0y0': (0, 0), 'proj': proj}, + { + 'nxny': (3, 2), + 'dxdy': (1, -1), + 'x0y0': (0, 0), + 'proj': proj, + }, + { + 'nxny': (3, 2), + 'dxdy': (1, 1), + 'x0y0': (0, 0), + 'proj': proj, + 'pixel_ref': 'corner', + }, + { + 'nxny': (3, 2), + 'dxdy': (1, -1), + 'x0y0': (0, 0), + 'proj': proj, + 'pixel_ref': 'corner', + }, + ] for ka in kargs: g = Grid(**ka) rg = g.regrid() - self.assertTrue(g == rg) + assert g == rg rg = g.regrid(factor=3) assert_array_equal(g.extent, rg.extent) assert_array_equal(g.extent, rg.extent) - bg = rg.regrid(factor=1/3) - self.assertEqual(g, bg) + bg = rg.regrid(factor=1 / 3) + assert g == bg gx, gy = g.center_grid.xy_coordinates rgx, rgy = rg.center_grid.xy_coordinates @@ -395,20 +515,24 @@ def test_regrid(self): assert_allclose(gy, rgy[1::3, 1::3], atol=1e-7) nrg = g.regrid(nx=9) - self.assertTrue(nrg == rg) + assert nrg == rg nrg = g.regrid(ny=6) - self.assertTrue(nrg == rg) + assert nrg == rg - def test_transform(self): + def test_transform(self) -> None: """Converting to the grid""" # It should work exact same for any projection projs = [wgs84, gis.check_crs('epsg:26915')] for proj in projs: - - args = dict(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=proj) + args = { + 'nxny': (3, 3), + 'dxdy': (1, 1), + 'x0y0': (0, 0), + 'proj': proj, + } g = Grid(**args) exp_i, exp_j = np.meshgrid(np.arange(3), np.arange(3)) @@ -419,11 +543,11 @@ def test_transform(self): assert_allclose(exp_i, r_i, atol=1e-03) assert_allclose(exp_j, r_j, atol=1e-03) r_i, r_j = g.corner_grid.transform(exp_i, exp_j, crs=proj) - assert_allclose(exp_i+0.5, r_i, atol=1e-03) - assert_allclose(exp_j+0.5, r_j, atol=1e-03) + assert_allclose(exp_i + 0.5, r_i, atol=1e-03) + assert_allclose(exp_j + 0.5, r_j, atol=1e-03) r_i, r_j = g.corner_grid.transform(exp_i, exp_j, crs=g) - assert_allclose(exp_i+0.5, r_i, atol=1e-03) - assert_allclose(exp_j+0.5, r_j, 
atol=1e-03) + assert_allclose(exp_i + 0.5, r_i, atol=1e-03) + assert_allclose(exp_j + 0.5, r_j, atol=1e-03) args['pixel_ref'] = 'corner' g = Grid(**args) @@ -441,17 +565,16 @@ def test_transform(self): assert_allclose(exp_i, r_i, atol=1e-03) assert_allclose(exp_j, r_j, atol=1e-03) r_i, r_j = g.center_grid.transform(exp_i, exp_j, crs=proj) - assert_allclose(exp_i-0.5, r_i, atol=1e-03) - assert_allclose(exp_j-0.5, r_j, atol=1e-03) + assert_allclose(exp_i - 0.5, r_i, atol=1e-03) + assert_allclose(exp_j - 0.5, r_j, atol=1e-03) r_i, r_j = g.center_grid.transform(exp_i, exp_j, crs=g) - assert_allclose(exp_i-0.5, r_i, atol=1e-03) - assert_allclose(exp_j-0.5, r_j, atol=1e-03) + assert_allclose(exp_i - 0.5, r_i, atol=1e-03) + assert_allclose(exp_j - 0.5, r_j, atol=1e-03) ex = g.corner_grid.extent assert_allclose([0, 3, 0, 3], ex, atol=1e-03) - assert_allclose(g.center_grid.extent, - g.corner_grid.extent, - atol=1e-03) - + assert_allclose( + g.center_grid.extent, g.corner_grid.extent, atol=1e-03 + ) # Masked xi = [-0.6, 0.5, 1.2, 2.9, 3.1, 3.6] @@ -469,14 +592,16 @@ def test_transform(self): assert_array_equal(ey, r_j) ex = np.ma.masked_array(ex, mask=[1, 0, 0, 0, 1, 1]) ey = ex - r_i, r_j = g.center_grid.transform(xi, yi, crs=proj, - nearest=True, maskout=True) + r_i, r_j = g.center_grid.transform( + xi, yi, crs=proj, nearest=True, maskout=True + ) assert_array_equal(ex, r_i) assert_array_equal(ey, r_j) assert_array_equal(ex.mask, r_i.mask) assert_array_equal(ey.mask, r_j.mask) - r_i, r_j = g.corner_grid.transform(xi, yi, crs=proj, - nearest=True, maskout=True) + r_i, r_j = g.corner_grid.transform( + xi, yi, crs=proj, nearest=True, maskout=True + ) assert_array_equal(ex, r_i) assert_array_equal(ey, r_j) assert_array_equal(ex.mask, r_i.mask) @@ -496,43 +621,46 @@ def test_transform(self): assert_allclose(exp_i, r_i, atol=1e-03) assert_allclose(exp_j, r_j, atol=1e-03) - def test_lookup_grid(self): - + def test_lookup_grid(self) -> None: data = np.arange(12).reshape((4, 3)) - args = dict(nxny=(3, 4), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84) + args = {'nxny': (3, 4), 'dxdy': (1, 1), 'x0y0': (0, 0), 'proj': wgs84} g = Grid(**args) lut = g.grid_lookup(g) for ji, l in lut.items(): - self.assertEqual(data[ji], data[l[:, 0], l[:, 1]]) + assert data[ji] == data[l[:, 0], l[:, 1]] - args = dict(nxny=(2, 3), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84) + args = {'nxny': (2, 3), 'dxdy': (1, 1), 'x0y0': (0, 0), 'proj': wgs84} g2 = Grid(**args) lut = g2.grid_lookup(g) - for ji, l in lut.items(): - self.assertEqual(data[ji], data[l[:, 0], l[:, 1]]) + for ji, l in lut.items(): # noqa: E741 + assert data[ji] == data[l[:, 0], l[:, 1]] lut = g.grid_lookup(g2) for (j, i), l in lut.items(): if j > 2 or i > 1: assert l is None else: - self.assertEqual(data[j, i], data[l[:, 0], l[:, 1]]) - - args = dict(nxny=(1, 1), dxdy=(10, 10), x0y0=(0, 0), proj=wgs84) + assert data[j, i] == data[l[:, 0], l[:, 1]] + + args = { + 'nxny': (1, 1), + 'dxdy': (10, 10), + 'x0y0': (0, 0), + 'proj': wgs84, + } g3 = Grid(**args) lut = g3.grid_lookup(g) od = data[lut[(0, 0)][:, 0], lut[(0, 0)][:, 1]] - self.assertEqual(len(od), 12) + assert len(od) == 12 assert_allclose(np.mean(od), np.mean(data)) - def test_lookup_transform(self): - + def test_lookup_transform(self) -> None: data2d = np.arange(12).reshape((4, 3)) data3d = np.stack([data2d, data2d, data2d]) data4d = np.stack([data3d, data3d]) - args = dict(nxny=(3, 4), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84) + args = {'nxny': (3, 4), 'dxdy': (1, 1), 'x0y0': (0, 0), 'proj': wgs84} g = Grid(**args) odata = 
g.lookup_transform(data2d, g) @@ -542,13 +670,13 @@ def test_lookup_transform(self): odata = g.lookup_transform(data4d, g) assert_allclose(odata, data4d) odata, lut = g.lookup_transform(data4d, g, method=len, return_lut=True) - assert_allclose(odata, data4d*0+1) + assert_allclose(odata, data4d * 0 + 1) # set lut odata = g.lookup_transform(data2d, g, lut=lut) assert_allclose(odata, data2d) - args = dict(nxny=(2, 3), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84) + args = {'nxny': (2, 3), 'dxdy': (1, 1), 'x0y0': (0, 0), 'proj': wgs84} g2 = Grid(**args) odata = g2.lookup_transform(data2d, g) assert_allclose(odata, data2d[:-1, :-1]) @@ -557,7 +685,11 @@ def test_lookup_transform(self): odata = g2.lookup_transform(data4d, g) assert_allclose(odata, data4d[..., :-1, :-1]) - f = self.assertRaisesRegex if python_version == 'py3' else self.assertRaisesRegexp + f = ( + self.assertRaisesRegex + if python_version == 'py3' + else self.assertRaisesRegexp + ) with f(ValueError, 'dimension not compatible'): g.lookup_transform(data2d[:-1, :-1], g) @@ -569,60 +701,70 @@ def test_lookup_transform(self): assert_allclose(data2d, odata) odata = g.lookup_transform(data2d[:-1, :-1], g2, method=len) - assert_allclose(odata, 1-ref) - - args = dict(nxny=(1, 1), dxdy=(10, 10), x0y0=(0, 0), proj=wgs84) + assert_allclose(odata, 1 - ref) + + args = { + 'nxny': (1, 1), + 'dxdy': (10, 10), + 'x0y0': (0, 0), + 'proj': wgs84, + } g3 = Grid(**args) odata = g3.lookup_transform(data2d, g) - self.assertEqual(odata.shape, (1, 1)) + assert odata.shape == (1, 1) assert_allclose(odata, np.mean(data2d)) odata = g3.lookup_transform(data2d, g, method=np.sum) - self.assertEqual(odata.shape, (1, 1)) + assert odata.shape == (1, 1) assert_allclose(odata, np.sum(data2d)) odata = g3.lookup_transform(data2d, g, method=len) - self.assertEqual(odata.shape, (1, 1)) + assert odata.shape == (1, 1) assert_allclose(odata, 12) # total back and forth data = np.arange(12).reshape((4, 3)) - args = dict(nxny=(3, 4), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84) + args = {'nxny': (3, 4), 'dxdy': (1, 1), 'x0y0': (0, 0), 'proj': wgs84} g = Grid(**args) rg = g.regrid(factor=3) tdata = rg.map_gridded_data(data, g, interp='nearest') odata = g.lookup_transform(tdata, rg) assert_allclose(odata, data) odata = g.lookup_transform(tdata, rg, method=len) - assert_allclose(odata, data*0.+9) + assert_allclose(odata, data * 0.0 + 9) - def test_stagg(self): + def test_stagg(self) -> None: """Staggered grids.""" # It should work exact same for any projection projs = [wgs84, gis.check_crs('epsg:26915')] for proj in projs: - args = dict(nxny=(3, 2), dxdy=(1, 1), x0y0=(0, 0), - proj=proj, pixel_ref='corner') + args = { + 'nxny': (3, 2), + 'dxdy': (1, 1), + 'x0y0': (0, 0), + 'proj': proj, + 'pixel_ref': 'corner', + } g = Grid(**args) x, y = g.xstagg_xy_coordinates - assert_array_equal(x, np.array([[0,1,2,3], [0,1,2,3]])) - assert_array_equal(y, np.array([[0.5, 0.5, 0.5, 0.5], - [1.5, 1.5, 1.5, 1.5]])) + assert_array_equal(x, np.array([[0, 1, 2, 3], [0, 1, 2, 3]])) + assert_array_equal( + y, np.array([[0.5, 0.5, 0.5, 0.5], [1.5, 1.5, 1.5, 1.5]]) + ) xx, yy = g.corner_grid.xstagg_xy_coordinates assert_array_equal(x, xx) assert_array_equal(y, yy) xt, yt = x, y x, y = g.ystagg_xy_coordinates - assert_array_equal(x, np.array([[0.5, 1.5, 2.5], - [0.5, 1.5, 2.5], - [0.5, 1.5, 2.5]])) - assert_array_equal(y, np.array([[0, 0, 0], - [1, 1, 1], - [2, 2, 2]])) + assert_array_equal( + x, + np.array([[0.5, 1.5, 2.5], [0.5, 1.5, 2.5], [0.5, 1.5, 2.5]]), + ) + assert_array_equal(y, np.array([[0, 0, 0], [1, 
1, 1], [2, 2, 2]])) xx, yy = g.corner_grid.ystagg_xy_coordinates assert_array_equal(x, xx) assert_array_equal(y, yy) @@ -636,14 +778,14 @@ def test_stagg(self): assert_allclose(yt, yy) x, y = g.pixcorner_ll_coordinates - assert_allclose(x, np.array([[0, 1, 2, 3], - [0, 1, 2, 3], - [0, 1, 2, 3]])) - assert_allclose(y, np.array([[0, 0, 0, 0], - [1, 1, 1, 1], - [2, 2, 2, 2]])) - - def test_map_gridded_data(self): + assert_allclose( + x, np.array([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]) + ) + assert_allclose( + y, np.array([[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]) + ) + + def test_map_gridded_data(self) -> None: """Ok now the serious stuff starts with some fake data""" # It should work exact same for any projection @@ -654,51 +796,74 @@ def test_map_gridded_data(self): data = np.arange(nx * ny).reshape((ny, nx)) # Nearest Neighbor - args = dict(nxny=(nx, ny), dxdy=(1, 1), x0y0=(0, 0), proj=proj) + args = { + 'nxny': (nx, ny), + 'dxdy': (1, 1), + 'x0y0': (0, 0), + 'proj': proj, + } g = Grid(**args) odata = g.map_gridded_data(data, g) - self.assertTrue(odata.shape == data.shape) + assert odata.shape == data.shape assert_allclose(data, odata, atol=1e-03) # Out of the grid go = Grid(nxny=(nx, ny), dxdy=(1, 1), x0y0=(9, 9), proj=proj) odata = g.map_gridded_data(data, go) odata.set_fill_value(-999) - self.assertTrue(odata.shape == data.shape) - self.assertTrue(np.all(odata.mask)) - - args = dict(nxny=(nx - 1, ny - 1), dxdy=(1, 1), x0y0=(0, 0), - proj=proj) + assert odata.shape == data.shape + assert np.all(odata.mask) + + args = { + 'nxny': (nx - 1, ny - 1), + 'dxdy': (1, 1), + 'x0y0': (0, 0), + 'proj': proj, + } ig = Grid(**args) - odata = g.map_gridded_data(data[0:ny - 1, 0:nx - 1], ig) - self.assertTrue(odata.shape == (ny, nx)) - assert_allclose(data[0:ny - 1, 0:nx - 1], - odata[0:ny - 1, 0:nx - 1], atol=1e-03) + odata = g.map_gridded_data(data[0 : ny - 1, 0 : nx - 1], ig) + assert odata.shape == (ny, nx) + assert_allclose( + data[0 : ny - 1, 0 : nx - 1], + odata[0 : ny - 1, 0 : nx - 1], + atol=1e-03, + ) assert_array_equal([True] * 3, odata.mask[ny - 1, :]) data = np.arange(nx * ny).reshape((ny, nx)) * 1.2 - odata = g.map_gridded_data(data[0:ny - 1, 0:nx - 1], ig) - self.assertTrue(odata.shape == (ny, nx)) - assert_allclose(data[0:ny - 1, 0:nx - 1], - odata[0:ny - 1, 0:nx - 1], atol=1e-03) - self.assertTrue( - np.sum(np.isfinite(odata)) == ((ny - 1) * (nx - 1))) + odata = g.map_gridded_data(data[0 : ny - 1, 0 : nx - 1], ig) + assert odata.shape == (ny, nx) + assert_allclose( + data[0 : ny - 1, 0 : nx - 1], + odata[0 : ny - 1, 0 : nx - 1], + atol=1e-03, + ) + assert np.sum(np.isfinite(odata)) == (ny - 1) * (nx - 1) # Bilinear data = np.arange(nx * ny).reshape((ny, nx)) - exp_data = np.array([2., 3., 5., 6., 8., 9.]).reshape( - (ny - 1, nx - 1)) - args = dict(nxny=(nx, ny), dxdy=(1, 1), x0y0=(0, 0), proj=proj) + exp_data = np.array([2.0, 3.0, 5.0, 6.0, 8.0, 9.0]).reshape( + (ny - 1, nx - 1) + ) + args = { + 'nxny': (nx, ny), + 'dxdy': (1, 1), + 'x0y0': (0, 0), + 'proj': proj, + } gfrom = Grid(**args) - args = dict(nxny=(nx - 1, ny - 1), dxdy=(1, 1), x0y0=(0.5, 0.5), - proj=proj) + args = { + 'nxny': (nx - 1, ny - 1), + 'dxdy': (1, 1), + 'x0y0': (0.5, 0.5), + 'proj': proj, + } gto = Grid(**args) odata = gto.map_gridded_data(data, gfrom, interp='linear') - self.assertTrue(odata.shape == (ny - 1, nx - 1)) + assert odata.shape == (ny - 1, nx - 1) assert_allclose(exp_data, odata, atol=1e-03) - def test_map_gridded_data_over(self): - + def test_map_gridded_data_over(self) -> None: # It should work 
exact same for any projection projs = [wgs84, gis.check_crs('epsg:26915')] @@ -710,40 +875,49 @@ def test_map_gridded_data_over(self): in_data[0, :] = 78 # Nearest Neighbor - args = dict(nxny=(nx, ny), dxdy=(1, 1), x0y0=(0, 0), proj=proj) + args = { + 'nxny': (nx, ny), + 'dxdy': (1, 1), + 'x0y0': (0, 0), + 'proj': proj, + } g = Grid(**args) odata = g.map_gridded_data(data, g, out=data.copy()) - self.assertTrue(odata.shape == data.shape) + assert odata.shape == data.shape assert_allclose(data, odata, atol=1e-03) odata = g.map_gridded_data(in_data, g, out=data.copy()) - self.assertTrue(odata.shape == data.shape) + assert odata.shape == data.shape assert_allclose(data[1:, :], odata[1:, :], atol=1e-03) assert_allclose(odata[0, :], 78, atol=1e-03) # Bilinear - odata = g.map_gridded_data(data, g, interp='linear', - out=data.copy()) - self.assertTrue(odata.shape == data.shape) + odata = g.map_gridded_data( + data, g, interp='linear', out=data.copy() + ) + assert odata.shape == data.shape assert_allclose(data, odata, atol=1e-03) # Spline - odata = g.map_gridded_data(data, g, interp='spline', - out=data.copy()) - self.assertTrue(odata.shape == data.shape) + odata = g.map_gridded_data( + data, g, interp='spline', out=data.copy() + ) + assert odata.shape == data.shape assert_allclose(data, odata, atol=1e-03) - @requires_shapely - def test_extent(self): - + def test_extent(self) -> None: # It should work exact same for any projection - args = dict(nxny=(9, 9), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84) + args = {'nxny': (9, 9), 'dxdy': (1, 1), 'x0y0': (0, 0), 'proj': wgs84} g1 = Grid(**args) assert_allclose(g1.extent, g1.extent_in_crs(crs=g1.proj), atol=1e-3) - args = dict(nxny=(9, 9), dxdy=(30000, 30000), x0y0=(0., 1577463), - proj=gis.check_crs('epsg:26915')) + args = { + 'nxny': (9, 9), + 'dxdy': (30000, 30000), + 'x0y0': (0.0, 1577463), + 'proj': gis.check_crs('epsg:26915'), + } g2 = Grid(**args) assert_allclose(g2.extent, g2.extent_in_crs(crs=g2.proj), atol=1e-3) @@ -751,24 +925,25 @@ def test_extent(self): exgx, exgy = g1.ij_to_crs(exg[[0, 1]], exg[[2, 3]], crs=wgs84) lon, lat = g2.corner_grid.ll_coordinates - assert_allclose([np.min(lon), np.min(lat)], [exgx[0], exgy[0]], - rtol=0.1) + assert_allclose( + [np.min(lon), np.min(lat)], [exgx[0], exgy[0]], rtol=0.1 + ) p = g2.extent_as_polygon(crs=g2.proj) assert p.is_valid x, y = p.exterior.coords.xy - assert_allclose([np.min(x), np.max(x), np.min(y), np.max(y)], - g2.extent) + assert_allclose( + [np.min(x), np.max(x), np.min(y), np.max(y)], g2.extent + ) - def test_simple_dataset(self): + def test_simple_dataset(self) -> None: # see if with is working with SimpleNcDataSet(get_demo_file('dem_wgs84.nc')) as nc: - nc = SimpleNcDataSet(get_demo_file('dem_wgs84.nc')) grid_from = nc.grid - self.assertTrue(gis.check_crs(grid_from)) + assert gis.check_crs(grid_from) - def test_map_real_data(self): + def test_map_real_data(self) -> None: """Ok now the serious stuff starts with some real data""" nc = SimpleNcDataSet(get_demo_file('dem_wgs84.nc')) @@ -786,7 +961,7 @@ def test_map_real_data(self): nc = SimpleNcDataSet(get_demo_file('dem_mercator_ul.nc')) data_gdal = nc.nc.variables['dem_gdal'] grid_to = nc.grid - self.assertTrue(grid_to.origin == 'upper-left') + assert grid_to.origin == 'upper-left' odata = grid_to.map_gridded_data(data_from, grid_from, interp='linear') assert_allclose(data_gdal, odata) @@ -828,8 +1003,8 @@ def test_map_real_data(self): assert_allclose(ref_data, odata.filled(np.nan), atol=1e-3) odata = grid_to.map_gridded_data(data, grid_from, 
interp='spline') - odata[np.where(~ np.isfinite(ref_data))] = np.nan - ref_data[np.where(~ np.isfinite(odata))] = np.nan + odata[np.where(~np.isfinite(ref_data))] = np.nan + ref_data[np.where(~np.isfinite(odata))] = np.nan assert np.sum(np.isfinite(ref_data)) != 0 assert_allclose(ref_data, odata, rtol=0.2, atol=3) @@ -838,13 +1013,13 @@ def test_map_real_data(self): ref_data = np.array([ref_data, ref_data]) odata = grid_to.map_gridded_data(data, grid_from, interp='linear') odata = odata.filled(np.nan) - ref_data[np.where(~ np.isfinite(odata))] = np.nan + ref_data[np.where(~np.isfinite(odata))] = np.nan assert np.sum(np.isfinite(ref_data)) != 0 assert_allclose(ref_data, odata, atol=1e-3) odata = grid_to.map_gridded_data(data, grid_from, interp='spline') - odata[np.where(~ np.isfinite(ref_data))] = np.nan - ref_data[np.where(~ np.isfinite(odata))] = np.nan + odata[np.where(~np.isfinite(ref_data))] = np.nan + ref_data[np.where(~np.isfinite(odata))] = np.nan assert np.sum(np.isfinite(ref_data)) != 0 assert_allclose(ref_data, odata, rtol=0.2, atol=3) @@ -854,193 +1029,243 @@ def test_map_real_data(self): odata = grid_to.map_gridded_data(data, grid_from) # At the borders IDL and Python take other decision on wether it # should be a NaN or not (Python seems to be more conservative) - self.assertTrue(odata.dtype == ref_data.dtype) + assert odata.dtype == ref_data.dtype ref_data[np.where(odata == -999)] = -999 assert_allclose(ref_data, odata.filled(-999)) @requires_shapely - def test_roi(self): - + def test_roi(self) -> None: import shapely.geometry as shpg - g = Grid(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84, - pixel_ref='corner') - p = shpg.Polygon([(1.5, 1.), (2., 1.5), (1.5, 2.), (1., 1.5)]) + g = Grid( + nxny=(3, 3), + dxdy=(1, 1), + x0y0=(0, 0), + proj=wgs84, + pixel_ref='corner', + ) + p = shpg.Polygon([(1.5, 1.0), (2.0, 1.5), (1.5, 2.0), (1.0, 1.5)]) roi = g.region_of_interest(geometry=p) - np.testing.assert_array_equal([[0,0,0],[0,1,0],[0,0,0]], roi) + np.testing.assert_array_equal([[0, 0, 0], [0, 1, 0], [0, 0, 0]], roi) roi = g.region_of_interest(corners=([0, 0], [2, 2]), crs=wgs84) np.testing.assert_array_equal([[1, 1, 1], [1, 1, 1], [1, 1, 1]], roi) roi = g.region_of_interest(corners=([1.3, 1.3], [1.7, 1.7]), crs=wgs84) - np.testing.assert_array_equal([[0,0,0],[0,1,0],[0,0,0]], roi) + np.testing.assert_array_equal([[0, 0, 0], [0, 1, 0], [0, 0, 0]], roi) roi = g.region_of_interest() - np.testing.assert_array_equal([[0,0,0],[0,0,0],[0,0,0]], roi) + np.testing.assert_array_equal([[0, 0, 0], [0, 0, 0], [0, 0, 0]], roi) - mask = [[0,0,0],[0,1,0],[0,0,0]] + mask = [[0, 0, 0], [0, 1, 0], [0, 0, 0]] roi = g.region_of_interest(roi=mask) - np.testing.assert_array_equal([[0,0,0],[0,1,0],[0,0,0]], roi) + np.testing.assert_array_equal([[0, 0, 0], [0, 1, 0], [0, 0, 0]], roi) nc = np.array(p.exterior.coords) + 0.1 p = shpg.Polygon(nc) roi = g.region_of_interest(geometry=p, roi=roi) - np.testing.assert_array_equal([[0,0,0],[0,1,0],[0,0,0]], roi) + np.testing.assert_array_equal([[0, 0, 0], [0, 1, 0], [0, 0, 0]], roi) nc = np.array(p.exterior.coords) + 0.5 p = shpg.Polygon(nc) roi = g.region_of_interest(geometry=p, roi=roi) - np.testing.assert_array_equal([[0,0,0],[0,1,0],[0,0,0]], roi) + np.testing.assert_array_equal([[0, 0, 0], [0, 1, 0], [0, 0, 0]], roi) nc = np.array(p.exterior.coords) + 0.5 p = shpg.Polygon(nc) roi = g.region_of_interest(geometry=p, roi=roi) - np.testing.assert_array_equal([[0,0,0],[0,1,0],[0,0,1]], roi) - - g = Grid(nxny=(4, 2), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84, - 
pixel_ref='corner') - p = shpg.Polygon([(1.5, 1.), (2., 1.5), (1.5, 2.), (1., 1.5)]) + np.testing.assert_array_equal([[0, 0, 0], [0, 1, 0], [0, 0, 1]], roi) + + g = Grid( + nxny=(4, 2), + dxdy=(1, 1), + x0y0=(0, 0), + proj=wgs84, + pixel_ref='corner', + ) + p = shpg.Polygon([(1.5, 1.0), (2.0, 1.5), (1.5, 2.0), (1.0, 1.5)]) roi = g.region_of_interest(geometry=p) - np.testing.assert_array_equal([[0,0,0,0],[0,1,0,0]], roi) - - g = Grid(nxny=(2, 4), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84, - pixel_ref='corner') - p = shpg.Polygon([(1.5, 1.), (2., 1.5), (1.5, 2.), (1., 1.5)]) + np.testing.assert_array_equal([[0, 0, 0, 0], [0, 1, 0, 0]], roi) + + g = Grid( + nxny=(2, 4), + dxdy=(1, 1), + x0y0=(0, 0), + proj=wgs84, + pixel_ref='corner', + ) + p = shpg.Polygon([(1.5, 1.0), (2.0, 1.5), (1.5, 2.0), (1.0, 1.5)]) roi = g.region_of_interest(geometry=p) - np.testing.assert_array_equal([[0,0], [0,1], [0,0], [0,0]], roi) - - g = Grid(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84, - pixel_ref='corner') - g2 = Grid(nxny=(1, 1), dxdy=(0.2, 0.2), x0y0=(1.4, 1.4), - proj=wgs84, pixel_ref='corner') + np.testing.assert_array_equal([[0, 0], [0, 1], [0, 0], [0, 0]], roi) + + g = Grid( + nxny=(3, 3), + dxdy=(1, 1), + x0y0=(0, 0), + proj=wgs84, + pixel_ref='corner', + ) + g2 = Grid( + nxny=(1, 1), + dxdy=(0.2, 0.2), + x0y0=(1.4, 1.4), + proj=wgs84, + pixel_ref='corner', + ) roi = g.region_of_interest(grid=g2) - np.testing.assert_array_equal([[0,0,0],[0,1,0],[0,0,0]], roi) + np.testing.assert_array_equal([[0, 0, 0], [0, 1, 0], [0, 0, 0]], roi) - def test_to_dataset(self): + def test_to_dataset(self) -> None: projs = [wgs84, gis.check_crs('epsg:26915')] for proj in projs: g = Grid(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=proj) ds = g.to_dataset() - self.assertTrue(g == ds.salem.grid) - - g = Grid(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=proj, - pixel_ref='corner') + assert g == ds.salem.grid + + g = Grid( + nxny=(3, 3), + dxdy=(1, 1), + x0y0=(0, 0), + proj=proj, + pixel_ref='corner', + ) ds = g.to_dataset() - self.assertTrue(g == ds.salem.grid) + assert g == ds.salem.grid @requires_geopandas - def test_geometry(self): + def test_geometry(self) -> None: projs = [wgs84, gis.check_crs('epsg:26915')] from shapely.geometry import Point + for proj in projs: g = Grid(nxny=(3, 3), dxdy=(1, 1), x0y0=(0.5, 0.5), proj=proj) gdf = g.to_geometry() - self.assertEqual(len(gdf), 9) - self.assertTrue(gdf.contains(Point(1.5, 1.5))[4]) - self.assertFalse(gdf.contains(Point(1.5, 1.5))[5]) + assert len(gdf) == 9 + assert gdf.contains(Point(1.5, 1.5))[4] + assert not gdf.contains(Point(1.5, 1.5))[5] gdf = g.to_geometry(to_crs=wgs84) # This is now quite off - self.assertFalse(gdf.contains(Point(1.5, 1.5))[4]) + assert not gdf.contains(Point(1.5, 1.5))[4] - def test_xarray_support(self): + def test_xarray_support(self) -> None: # what happens if we use salem's funcs with xarray? 
import xarray as xr projs = [wgs84, gis.check_crs('epsg:26915')] for proj in projs: - args = dict(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=proj) + args = { + 'nxny': (3, 3), + 'dxdy': (1, 1), + 'x0y0': (0, 0), + 'proj': proj, + } g = Grid(**args) exp_i, exp_j = np.meshgrid(np.arange(3), np.arange(3)) - exp_i, exp_j = (xr.DataArray(exp_i, dims=['y', 'x']), - xr.DataArray(exp_j, dims=['y', 'x'])) + exp_i, exp_j = ( + xr.DataArray(exp_i, dims=['y', 'x']), + xr.DataArray(exp_j, dims=['y', 'x']), + ) r_i, r_j = g.ij_to_crs(exp_i, exp_j) assert_allclose(exp_i, r_i, atol=1e-03) assert_allclose(exp_j, r_j, atol=1e-03) - self.assertTrue(r_i.shape == exp_i.shape) + assert r_i.shape == exp_i.shape # transform r_i, r_j = g.transform(exp_i, exp_j, crs=proj) assert_allclose(exp_i, r_i, atol=1e-03) assert_allclose(exp_j, r_j, atol=1e-03) - self.assertTrue(r_i.shape == exp_i.shape) + assert r_i.shape == exp_i.shape # map nx, ny = (3, 4) data = np.arange(nx * ny).reshape((ny, nx)) - data = xr.DataArray(data, coords={'y':np.arange(ny), - 'x':np.arange(nx)}, - dims=['y', 'x']) + data = xr.DataArray( + data, + coords={'y': np.arange(ny), 'x': np.arange(nx)}, + dims=['y', 'x'], + ) data.attrs = {'test': 'attr'} # Nearest Neighbor - args = dict(nxny=(nx, ny), dxdy=(1, 1), x0y0=(0, 0), proj=proj) + args = { + 'nxny': (nx, ny), + 'dxdy': (1, 1), + 'x0y0': (0, 0), + 'proj': proj, + } g = Grid(**args) odata = g.map_gridded_data(data, g) - self.assertTrue(odata.shape == data.shape) + assert odata.shape == data.shape assert_allclose(data, odata, atol=1e-03) # Transform can understand a grid data.attrs['pyproj_srs'] = g.proj.srs odata = g.map_gridded_data(data) - self.assertTrue(odata.shape == data.shape) + assert odata.shape == data.shape assert_allclose(data, odata, atol=1e-03) class TestTransform(unittest.TestCase): - - def test_check_crs_log(self): - + def test_check_crs_log(self) -> None: assert gis.check_crs('wrong') is None - with pytest.raises(ValueError): + with pytest.raises(ValueError, match='salem could not properly parse'): gis.check_crs('wrong', raise_on_error=True) - def test_same_proj(self): - + def test_same_proj(self) -> None: # this should work regardless of gdal or not: - p1 = pyproj.Proj('+proj=utm +zone=15 +datum=NAD83 ' - '+ellps=GRS80 +towgs84=0,0,0 +units=m +no_defs') - p2 = pyproj.Proj('+proj=utm +zone=15 +datum=NAD83 +units=m +no_defs ' - '+ellps=GRS80 +towgs84=0,0,0') - self.assertTrue(gis.proj_is_same(p1, p2)) + p1 = pyproj.Proj( + '+proj=utm +zone=15 +datum=NAD83 ' + '+ellps=GRS80 +towgs84=0,0,0 +units=m +no_defs' + ) + p2 = pyproj.Proj( + '+proj=utm +zone=15 +datum=NAD83 +units=m +no_defs ' + '+ellps=GRS80 +towgs84=0,0,0' + ) + assert gis.proj_is_same(p1, p2) # this needs gdal p1 = gis.check_crs('epsg:26915') - p2 = pyproj.Proj('+proj=utm +zone=15 +ellps=GRS80 +datum=NAD83 ' - '+units=m +no_defs') + p2 = pyproj.Proj( + '+proj=utm +zone=15 +ellps=GRS80 +datum=NAD83 +units=m +no_defs' + ) if gis.has_gdal: - self.assertTrue(gis.proj_is_same(p1, p2)) - - def test_pyproj_trafo(self): + assert gis.proj_is_same(p1, p2) + def test_pyproj_trafo(self) -> None: x = np.random.randn(int(1e6)) * 60 y = np.random.randn(int(1e6)) * 60 - for i in np.arange(3): + for _ in np.arange(3): xx, yy = gis.transform_proj(wgs84, wgs84, x, y) assert_allclose(xx, x) assert_allclose(yy, y) - for i in np.arange(3): + for _ in np.arange(3): xx, yy = gis.transform_proj(wgs84, wgs84, x, y, nocopy=True) assert_allclose(xx, x) assert_allclose(yy, y) - xx, yy = gis.transform_proj(gis.check_crs('epsg:26915'), - 
gis.check_crs('epsg:26915'), x, y) + xx, yy = gis.transform_proj( + gis.check_crs('epsg:26915'), gis.check_crs('epsg:26915'), x, y + ) assert_allclose(xx, x) assert_allclose(yy, y) @requires_shapely - def test_geometry(self): - + def test_geometry(self) -> None: import shapely.geometry as shpg - g = Grid(nxny=(3, 3), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84, - pixel_ref='corner') - p = shpg.Polygon([(1.5, 1.), (2., 1.5), (1.5, 2.), (1., 1.5)]) + g = Grid( + nxny=(3, 3), + dxdy=(1, 1), + x0y0=(0, 0), + proj=wgs84, + pixel_ref='corner', + ) + p = shpg.Polygon([(1.5, 1.0), (2.0, 1.5), (1.5, 2.0), (1.0, 1.5)]) o = gis.transform_geometry(p, to_crs=g) assert_allclose(p.exterior.coords, o.exterior.coords) @@ -1052,14 +1277,16 @@ def test_geometry(self): assert_allclose(p.exterior.coords, totest) x, y = g.corner_grid.xy_coordinates - p = shpg.MultiPoint([shpg.Point(i, j) for i, j in zip(x.flatten(), - y.flatten())]) + p = shpg.MultiPoint( + [shpg.Point(i, j) for i, j in zip(x.flatten(), y.flatten())] + ) o = gis.transform_geometry(p, to_crs=g.proj) - assert_allclose([_p.coords for _p in o.geoms], - [_p.coords for _p in p.geoms]) + assert_allclose( + [_p.coords for _p in o.geoms], [_p.coords for _p in p.geoms] + ) @requires_geopandas - def test_shape(self): + def test_shape(self) -> None: """Is the transformation doing well?""" from salem import read_shapefile @@ -1067,18 +1294,21 @@ def test_shape(self): so = read_shapefile(get_demo_file('Hintereisferner.shp')) sref = read_shapefile(get_demo_file('Hintereisferner_UTM.shp')) st = gis.transform_geopandas(so, to_crs=sref.crs) - self.assertFalse(st is so) - assert_allclose(st.geometry[0].exterior.coords, - sref.geometry[0].exterior.coords) + assert st is not so + assert_allclose( + st.geometry[0].exterior.coords, sref.geometry[0].exterior.coords + ) sti = gis.transform_geopandas(so, to_crs=sref.crs, inplace=True) - self.assertTrue(sti is so) - assert_allclose(so.geometry[0].exterior.coords, - sref.geometry[0].exterior.coords) - assert_allclose(sti.geometry[0].exterior.coords, - sref.geometry[0].exterior.coords) - - g = Grid(nxny=(1, 1), dxdy=(1, 1), x0y0=(10., 46.), proj=wgs84) + assert sti is so + assert_allclose( + so.geometry[0].exterior.coords, sref.geometry[0].exterior.coords + ) + assert_allclose( + sti.geometry[0].exterior.coords, sref.geometry[0].exterior.coords + ) + + g = Grid(nxny=(1, 1), dxdy=(1, 1), x0y0=(10.0, 46.0), proj=wgs84) so = read_shapefile(get_demo_file('Hintereisferner.shp')) st = gis.transform_geopandas(so, to_crs=g) @@ -1088,21 +1318,23 @@ def test_shape(self): # round trip so_back = gis.transform_geopandas(st, from_crs=g, to_crs=so.crs) - assert_allclose(so_back.geometry[0].exterior.coords, - so.geometry[0].exterior.coords) + assert_allclose( + so_back.geometry[0].exterior.coords, so.geometry[0].exterior.coords + ) class TestGrids(unittest.TestCase): - - def test_mercatorgrid(self): - - grid = gis.mercator_grid(center_ll=(11.38, 47.26), - extent=(2000000, 2000000)) + def test_mercatorgrid(self) -> None: + grid = gis.mercator_grid( + center_ll=(11.38, 47.26), extent=(2000000, 2000000) + ) lon1, lat1 = grid.center_grid.ll_coordinates e1 = grid.extent - grid = gis.mercator_grid(center_ll=(11.38, 47.26), - extent=(2000000, 2000000), - origin='upper-left') + grid = gis.mercator_grid( + center_ll=(11.38, 47.26), + extent=(2000000, 2000000), + origin='upper-left', + ) lon2, lat2 = grid.center_grid.ll_coordinates e2 = grid.extent @@ -1110,15 +1342,17 @@ def test_mercatorgrid(self): assert_allclose(lon1, lon2[::-1, :]) 
 assert_allclose(lat1, lat2[::-1, :])

-        grid = gis.mercator_grid(center_ll=(11.38, 47.26),
-                                 extent=(2000, 2000),
-                                 nx=100)
+        grid = gis.mercator_grid(
+            center_ll=(11.38, 47.26), extent=(2000, 2000), nx=100
+        )
         lon1, lat1 = grid.pixcorner_ll_coordinates
         e1 = grid.extent
-        grid = gis.mercator_grid(center_ll=(11.38, 47.26),
-                                 extent=(2000, 2000),
-                                 origin='upper-left',
-                                 nx=100)
+        grid = gis.mercator_grid(
+            center_ll=(11.38, 47.26),
+            extent=(2000, 2000),
+            origin='upper-left',
+            nx=100,
+        )
         lon2, lat2 = grid.pixcorner_ll_coordinates
         e2 = grid.extent

@@ -1126,59 +1360,58 @@ def test_mercatorgrid(self):
         assert_allclose(lon1, lon2[::-1, :])
         assert_allclose(lat1, lat2[::-1, :])

-        grid = gis.mercator_grid(center_ll=(11.38, 47.26),
-                                 extent=(2000, 2000),
-                                 nx=10)
+        grid = gis.mercator_grid(
+            center_ll=(11.38, 47.26), extent=(2000, 2000), nx=10
+        )
         e1 = grid.extent
-        grid = gis.mercator_grid(center_ll=(11.38, 47.26),
-                                 extent=(2000, 2000),
-                                 origin='upper-left',
-                                 nx=9)
+        grid = gis.mercator_grid(
+            center_ll=(11.38, 47.26),
+            extent=(2000, 2000),
+            origin='upper-left',
+            nx=9,
+        )
         e2 = grid.extent
         assert_allclose(e1, e2)


-def fuzzy_proj_tester(p1, p2, atol=1e-16):
-
-    d1 = dict()
-    d2 = dict()
+def fuzzy_proj_tester(p1, p2, atol=1e-16) -> None:
+    d1 = {}
+    d2 = {}
     for d, p in zip((d1, d2), (p1, p2)):
-        for s in p.srs.split('+'):
-            s = s.split('=')
+        for i in p.srs.split('+'):
+            s = i.split('=')
             if len(s) != 2:
                 continue
             k = s[0].strip()
             v = s[1].strip()
-            try:
+            with contextlib.suppress(Exception):
                 v = float(v)
-            except:
-                pass
             d[k] = v

-    for k in d1.keys():
+    for k in d1:
         if k in d2:
             if d1[k] == d2[k]:
                 # strings
                 continue
-            else:
-                try:
-                    assert_allclose(d1[k], d2[k], atol=atol,
-                                    err_msg='key: {}'.format(k))
-                except TypeError:
-                    assert d1[k] == d2[k]
+            try:
+                assert_allclose(
+                    d1[k], d2[k], atol=atol, err_msg='key: {}'.format(k)
+                )
+            except TypeError:
+                assert d1[k] == d2[k]


 class TestCartopy(unittest.TestCase):
-
     @requires_cartopy
     @requires_rasterio
-    def test_to_cartopy(self):
-
+    def test_to_cartopy(self) -> None:
         import cartopy.crs as ccrs
+
         from salem import GeoNetcdf, GeoTiff

-        grid = gis.mercator_grid(center_ll=(11.38, 47.26),
-                                 extent=(2000000, 2000000))
+        grid = gis.mercator_grid(
+            center_ll=(11.38, 47.26), extent=(2000000, 2000000)
+        )
         p = gis.proj_to_cartopy(grid.proj)
         assert isinstance(p, ccrs.TransverseMercator)
         fuzzy_proj_tester(grid.proj, pyproj.Proj(p.proj4_params))
diff --git a/salem/tests/test_graphics.py b/salem/tests/test_graphics.py
index 68c186c..0edb644 100644
--- a/salem/tests/test_graphics.py
+++ b/salem/tests/test_graphics.py
@@ -1,79 +1,84 @@
-from __future__ import division
-from packaging.version import Version
-
-import warnings
-import os
 import shutil
+import unittest
+import warnings
+from pathlib import Path

 import numpy as np
 import pytest
-import unittest
 from numpy.testing import assert_array_equal
+from packaging.version import Version

 try:
     import matplotlib as mpl
+    import matplotlib.pyplot as plt
 except ImportError:
-    pytest.skip("Requires matplotlib", allow_module_level=True)
+    pytest.skip('Requires matplotlib', allow_module_level=True)

 try:
     import shapely.geometry as shpg
 except ImportError:
-    pytest.skip("Requires shapely", allow_module_level=True)
+    pytest.skip('Requires shapely', allow_module_level=True)

+import geopandas as gpd
 import matplotlib.pyplot as plt
 from mpl_toolkits.axes_grid1 import make_axes_locatable
-import geopandas as gpd

 MPL_VERSION = Version(mpl.__version__)
 ftver = Version(mpl.ft2font.__freetype_version__)
-if ftver >= Version('2.8.0'):
Version('2.8.0'): - freetype_subdir = 'freetype_28' -else: - freetype_subdir = 'freetype_old' - -from salem.graphics import ExtendedNorm, DataLevels, Map, get_cmap, shapefiles -from salem import graphics -from salem import (Grid, wgs84, mercator_grid, GeoNetcdf, - read_shapefile_to_grid, GeoTiff, GoogleCenterMap, - GoogleVisibleMap, open_wrf_dataset, open_xr_dataset, - python_version, cache_dir, sample_data_dir) +freetype_subdir = ( + 'freetype_28' if ftver >= Version('2.8.0') else 'freetype_old' +) + +from salem import ( + GeoNetcdf, + GeoTiff, + GoogleCenterMap, + GoogleVisibleMap, + Grid, + graphics, + mercator_grid, + open_wrf_dataset, + open_xr_dataset, + python_version, + read_shapefile_to_grid, + sample_data_dir, + wgs84, +) +from salem.graphics import DataLevels, ExtendedNorm, Map, get_cmap, shapefiles +from salem.tests import requires_cartopy, requires_matplotlib from salem.utils import get_demo_file -from salem.tests import (requires_matplotlib, requires_cartopy) # Globals -current_dir = os.path.dirname(os.path.abspath(__file__)) -testdir = os.path.join(current_dir, 'tmp') +current_dir = Path(__file__).parent +testdir = current_dir / 'tmp' baseline_subdir = '2.0.x' -baseline_dir = os.path.join(sample_data_dir, 'baseline_images', - baseline_subdir, freetype_subdir) +baseline_dir = ( + sample_data_dir / 'baseline_images' / baseline_subdir / freetype_subdir +) tolpy2 = 5 if python_version == 'py3' else 10 -def _create_dummy_shp(fname): - if not os.path.exists(testdir): - os.makedirs(testdir) +def _create_dummy_shp(fname: Path) -> Path: + if not testdir.exists(): + testdir.mkdir(parents=True) - e_line = shpg.LinearRing([(1.5, 1), (2., 1.5), (1.5, 2.), (1, 1.5)]) + e_line = shpg.LinearRing([(1.5, 1), (2.0, 1.5), (1.5, 2.0), (1, 1.5)]) i_line = shpg.LinearRing([(1.4, 1.4), (1.6, 1.4), (1.6, 1.6), (1.4, 1.6)]) p1 = shpg.Polygon(e_line, [i_line]) - p2 = shpg.Polygon([(2.5, 1.3), (3., 1.8), (2.5, 2.3), (2, 1.8)]) - p3 = shpg.Point(0.5, 0.5) - p4 = shpg.Point(1, 1) + p2 = shpg.Polygon([(2.5, 1.3), (3.0, 1.8), (2.5, 2.3), (2, 1.8)]) df = gpd.GeoDataFrame() df['name'] = ['Polygon', 'Line'] df.set_geometry(gpd.GeoSeries([p1, p2]), crs='epsg:4326', inplace=True) - of = os.path.join(testdir, fname) + of = testdir / fname df.to_file(of) return of class TestColors(unittest.TestCase): - @requires_matplotlib - def test_extendednorm(self): - + def test_extendednorm(self) -> None: bounds = [1, 2, 3] cm = mpl.colormaps.get_cmap('jet') @@ -82,7 +87,7 @@ def test_extendednorm(self): x = np.random.randn(100) * 10 - 5 np.testing.assert_array_equal(refnorm(x), mynorm(x)) - refnorm = mpl.colors.BoundaryNorm([0] + bounds + [4], cm.N) + refnorm = mpl.colors.BoundaryNorm([0, *bounds, 4], cm.N) mynorm = graphics.ExtendedNorm(bounds, cm.N, extend='both') x = np.random.random(100) + 1.5 np.testing.assert_array_equal(refnorm(x), mynorm(x)) @@ -101,7 +106,9 @@ def test_extendednorm(self): np.testing.assert_array_equal(refnorm.vmin, mynorm.vmin) np.testing.assert_array_equal(refnorm.vmax, mynorm.vmax) x = [-1, 1.2, 2.3, 9.6] - np.testing.assert_array_equal(cmshould([0,1,2,3]), cmshould(mynorm(x))) + np.testing.assert_array_equal( + cmshould([0, 1, 2, 3]), cmshould(mynorm(x)) + ) x = np.random.randn(100) * 10 + 2 np.testing.assert_array_equal(cmref(refnorm(x)), cmshould(mynorm(x))) @@ -122,7 +129,7 @@ def test_extendednorm(self): np.testing.assert_array_equal(refnorm.vmin, mynorm.vmin) np.testing.assert_array_equal(refnorm.vmax, mynorm.vmax) x = [-1, 1.2, 2.3] - np.testing.assert_array_equal(cmshould([0,1,2]), 
cmshould(mynorm(x))) + np.testing.assert_array_equal(cmshould([0, 1, 2]), cmshould(mynorm(x))) x = np.random.randn(100) * 10 + 2 np.testing.assert_array_equal(cmref(refnorm(x)), cmshould(mynorm(x))) @@ -139,7 +146,7 @@ def test_extendednorm(self): np.testing.assert_array_equal(refnorm.vmin, mynorm.vmin) np.testing.assert_array_equal(refnorm.vmax, mynorm.vmax) x = [1.2, 2.3, 4] - np.testing.assert_array_equal(cmshould([0,1,2]), cmshould(mynorm(x))) + np.testing.assert_array_equal(cmshould([0, 1, 2]), cmshould(mynorm(x))) x = np.random.randn(100) * 10 + 2 np.testing.assert_array_equal(cmref(refnorm(x)), cmshould(mynorm(x))) @@ -147,114 +154,122 @@ def test_extendednorm(self): bounds = [1, 2, 3, 4] cm = mpl.colormaps.get_cmap('jet') mynorm = graphics.ExtendedNorm(bounds, cm.N, extend='both') - refnorm = mpl.colors.BoundaryNorm([-100] + bounds + [100], cm.N) + refnorm = mpl.colors.BoundaryNorm([-100, *bounds, 100], cm.N) x = np.random.randn(100) * 10 - 5 ref = refnorm(x) ref = np.where(ref == 0, -1, ref) - ref = np.where(ref == cm.N-1, cm.N, ref) + ref = np.where(ref == cm.N - 1, cm.N, ref) np.testing.assert_array_equal(ref, mynorm(x)) class TestGraphics(unittest.TestCase): - @requires_matplotlib - def test_datalevels_output(self): - + def test_datalevels_output(self) -> None: # Test basic stuffs c = graphics.DataLevels(nlevels=2) assert_array_equal(c.levels, [0, 1]) - c.set_data([1, 2, 3, 4]) + c.set_data(np.array([1, 2, 3, 4])) assert_array_equal(c.levels, [1, 4]) - c = graphics.DataLevels(levels=[1, 2, 3]) + c = graphics.DataLevels(levels=np.array([1, 2, 3])) assert_array_equal(c.levels, [1, 2, 3]) - c = graphics.DataLevels(nlevels=10, data=[0, 9]) + c = graphics.DataLevels(nlevels=10, data=np.array([0, 9])) assert_array_equal(c.levels, np.linspace(0, 9, num=10)) - self.assertTrue(c.extend == 'neither') + assert c.extend == 'neither' - c = graphics.DataLevels(nlevels=10, data=[0, 9], vmin=2, vmax=3) + c = graphics.DataLevels( + nlevels=10, data=np.array([0, 9]), vmin=2, vmax=3 + ) assert_array_equal(c.levels, np.linspace(2, 3, num=10)) - self.assertTrue(c.extend == 'both') + assert c.extend == 'both' c.set_extend('neither') - self.assertTrue(c.extend == 'neither') + assert c.extend == 'neither' with warnings.catch_warnings(record=True) as w: # Cause all warnings to always be triggered. - warnings.simplefilter("always") + warnings.simplefilter('always') # Trigger a warning. 
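# Sketch of the capture idiom this test relies on; emit_oob_warning is a
# hypothetical stand-in for the RuntimeWarning that to_rgb() emits when
# data fall outside [vmin, vmax].
import warnings

def emit_oob_warning() -> None:
    warnings.warn('data out of bounds', RuntimeWarning, stacklevel=2)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')  # record every warning, even duplicates
    emit_oob_warning()
assert len(caught) == 1
assert issubclass(caught[0].category, RuntimeWarning)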
- out = c.to_rgb() + c.to_rgb() # Verify some things assert len(w) == 2 assert issubclass(w[0].category, RuntimeWarning) assert issubclass(w[1].category, RuntimeWarning) - c = graphics.DataLevels(nlevels=10, data=[2.5], vmin=2, vmax=3) + c = graphics.DataLevels( + nlevels=10, data=np.array([2.5]), vmin=2, vmax=3 + ) assert_array_equal(c.levels, np.linspace(2, 3, num=10)) - self.assertTrue(c.extend == 'neither') - c.update(dict(extend='both')) - self.assertTrue(c.extend == 'both') - self.assertRaises(AttributeError, c.update, dict(dummy='t')) + assert c.extend == 'neither' + c.update({'extend': 'both'}) + assert c.extend == 'both' + with pytest.raises(AttributeError): + c.update({'dummy': 't'}) - c = graphics.DataLevels(nlevels=10, data=[0, 9], vmax=3) + c = graphics.DataLevels(nlevels=10, data=np.array([0, 9]), vmax=3) assert_array_equal(c.levels, np.linspace(0, 3, num=10)) - self.assertTrue(c.extend == 'max') + assert c.extend == 'max' - c = graphics.DataLevels(nlevels=10, data=[0, 9], vmin=1) + c = graphics.DataLevels(nlevels=10, data=np.array([0, 9]), vmin=1) assert_array_equal(c.levels, np.linspace(1, 9, num=10)) - self.assertTrue(c.extend == 'min') + assert c.extend == 'min' - c = graphics.DataLevels(nlevels=10, data=[0, 9], vmin=-1) + c = graphics.DataLevels(nlevels=10, data=np.array([0, 9]), vmin=-1) assert_array_equal(c.levels, np.linspace(-1, 9, num=10)) - self.assertTrue(c.extend == 'neither') + assert c.extend == 'neither' c.set_plot_params() - self.assertTrue(c.extend == 'neither') + assert c.extend == 'neither' assert_array_equal(c.vmin, 0) assert_array_equal(c.vmax, 9) c.set_plot_params(vmin=1) assert_array_equal(c.vmin, 1) - c.set_data([-12, 8]) + c.set_data(np.array([-12, 8])) assert_array_equal(c.vmin, 1) - self.assertTrue(c.extend == 'min') - c.set_data([2, 8]) - self.assertTrue(c.extend == 'neither') + assert c.extend == 'min' + c.set_data(np.array([2, 8])) + assert c.extend == 'neither' c.set_extend('both') - self.assertTrue(c.extend == 'both') - c.set_data([3, 3]) - self.assertTrue(c.extend == 'both') + assert c.extend == 'both' + c.set_data(np.array([3, 3])) + assert c.extend == 'both' c.set_extend() - self.assertTrue(c.extend == 'neither') + assert c.extend == 'neither' # Test the conversion cm = mpl.colors.ListedColormap(['white', 'blue', 'red', 'black']) x = [-1, 0.9, 1.2, 2, 999, 0.8] - c = graphics.DataLevels(levels=[0, 1, 2], data=x, cmap=cm) + c = graphics.DataLevels(levels=np.array([0, 1, 2]), data=x, cmap=cm) r = c.to_rgb() - self.assertTrue(len(x) == len(r)) - self.assertTrue(c.extend == 'both') + assert len(x) == len(r) + assert c.extend == 'both' assert_array_equal(r, cm([0, 1, 2, 3, 3, 1])) x = [0.9, 1.2] - c = graphics.DataLevels(levels=[0, 1, 2], data=x, cmap=cm, extend='both') + c = graphics.DataLevels( + levels=np.array([0, 1, 2]), data=x, cmap=cm, extend='both' + ) r = c.to_rgb() - self.assertTrue(len(x) == len(r)) - self.assertTrue(c.extend == 'both') + assert len(x) == len(r) + assert c.extend == 'both' assert_array_equal(r, cm([1, 2])) cm = mpl.colors.ListedColormap(['white', 'blue', 'red']) - c = graphics.DataLevels(levels=[0, 1, 2], data=x, cmap=cm, extend='min') + c = graphics.DataLevels( + levels=np.array([0, 1, 2]), data=x, cmap=cm, extend='min' + ) r = c.to_rgb() - self.assertTrue(len(x) == len(r)) + assert len(x) == len(r) assert_array_equal(r, cm([1, 2])) cm = mpl.colors.ListedColormap(['blue', 'red', 'black']) - c = graphics.DataLevels(levels=[0, 1, 2], data=x, cmap=cm, extend='max') + c = graphics.DataLevels( + levels=np.array([0, 1, 
2]), data=x, cmap=cm, extend='max' + ) r = c.to_rgb() - self.assertTrue(len(x) == len(r)) + assert len(x) == len(r) assert_array_equal(r, cm([0, 1])) @requires_matplotlib - def test_map(self): - + def test_map(self) -> None: a = np.zeros((4, 5)) a[0, 0] = -1 a[1, 1] = 1.1 @@ -265,14 +280,20 @@ def test_map(self): cmap = mpl.colormaps.get_cmap('jet').copy() except AttributeError: import copy + cmap = copy.deepcopy(mpl.colormaps.get_cmap('jet')) # ll_corner (type geotiff) - g = Grid(nxny=(5, 4), dxdy=(1, 1), x0y0=(0, 0), proj=wgs84, - pixel_ref='corner') + g = Grid( + nxny=(5, 4), + dxdy=(1, 1), + x0y0=(0, 0), + proj=wgs84, + pixel_ref='corner', + ) c = graphics.Map(g, ny=4, countries=False) - c.set_cmap(cmap) - c.set_plot_params(levels=[0, 1, 2, 3]) + c.set_cmap(cmap.name) + c.set_plot_params(levels=np.array([0, 1, 2, 3])) c.set_data(a) rgb1 = c.to_rgb() c.set_data(a, crs=g) @@ -283,11 +304,16 @@ def test_map(self): assert_array_equal(rgb1, c.to_rgb()) # centergrid (type WRF) - g = Grid(nxny=(5, 4), dxdy=(1, 1), x0y0=(0.5, 0.5), proj=wgs84, - pixel_ref='center') + g = Grid( + nxny=(5, 4), + dxdy=(1, 1), + x0y0=(0.5, 0.5), + proj=wgs84, + pixel_ref='center', + ) c = graphics.Map(g, ny=4, countries=False) - c.set_cmap(cmap) - c.set_plot_params(levels=[0, 1, 2, 3]) + c.set_cmap(cmap.name) + c.set_plot_params(levels=np.array([0, 1, 2, 3])) c.set_data(a) rgb1 = c.to_rgb() c.set_data(a, crs=g) @@ -302,7 +328,7 @@ def test_map(self): # More pixels c = graphics.Map(g, ny=500, countries=False) c.set_cmap(cmap) - c.set_plot_params(levels=[0, 1, 2, 3]) + c.set_plot_params(levels=np.array([0, 1, 2, 3])) c.set_data(a) rgb1 = c.to_rgb() c.set_data(a, crs=g) @@ -315,12 +341,16 @@ def test_map(self): # The interpolation is conservative with the grid... srgb = np.sum(rgb2[..., 0:3], axis=2) pok = np.nonzero(srgb != srgb[0, 0]) - rgb1 = rgb1[np.min(pok[0])+1:np.max(pok[0]-1), - np.min(pok[1])+1:np.max(pok[1]-1), - ...] - rgb2 = rgb2[np.min(pok[0])+1:np.max(pok[0]-1), - np.min(pok[1])+1:np.max(pok[1]-1), - ...] 
+        rgb1 = rgb1[
+            np.min(pok[0]) + 1 : np.max(pok[0] - 1),
+            np.min(pok[1]) + 1 : np.max(pok[1] - 1),
+            ...,
+        ]
+        rgb2 = rgb2[
+            np.min(pok[0]) + 1 : np.max(pok[0] - 1),
+            np.min(pok[1]) + 1 : np.max(pok[1] - 1),
+            ...,
+        ]
 
         assert_array_equal(rgb1, rgb2)
 
@@ -338,25 +368,26 @@ def test_map(self):
         rgb1 = c.to_rgb()
         c.set_data(a, crs=g, interp='linear')
         rgb2 = c.to_rgb()
-        # Todo: there's something sensibly wrong about imresize here
+        # TODO: there's something noticeably wrong about imresize here
         # but I think it is out of my scope
         # assert_array_equal(rgb1, rgb2)
 
     @requires_matplotlib
-    def test_increase_coverage(self):
-
+    def test_increase_coverage(self) -> None:
         # Just for coverage -> empty shapes should not trigger an error
         grid = mercator_grid(center_ll=(-20, 40), extent=(2000, 2000), nx=10)
         c = graphics.Map(grid)
         # Assigning wrongly shaped data should, however
-        self.assertRaises(ValueError, c.set_data, np.zeros((3, 8)))
+        with pytest.raises(
+            ValueError, match='Dimensions of data do not match'
+        ):
+            c.set_data(np.zeros((3, 8)))
 
 
 @requires_matplotlib
-@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir,
-                               tolerance=10)
-def test_extendednorm():
+@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=10)
+def test_extendednorm() -> plt.Figure:
     a = np.zeros((4, 5))
     a[0, 0] = -9999
     a[1, 1] = 1.1
@@ -372,25 +403,25 @@ def test_extendednorm():
     fig = plt.figure()
     ax1 = fig.add_subplot(1, 2, 1)
     ax2 = fig.add_subplot(1, 2, 2)
-    imax = ax1.imshow(a, interpolation='None', norm=norm, cmap=cm,
-                      origin='lower');
+    imax = ax1.imshow(
+        a, interpolation='None', norm=norm, cmap=cm, origin='lower'
+    )
     divider = make_axes_locatable(ax1)
-    cax = divider.append_axes("right", size="5%", pad=0.2)
+    cax = divider.append_axes('right', size='5%', pad=0.2)
     plt.colorbar(imax, cax=cax, extend='both')
     ti = cm(norm(a))
     ax2.imshow(ti, interpolation='None', origin='lower')
     divider = make_axes_locatable(ax2)
-    cax = divider.append_axes("right", size="5%", pad=0.2)
-    cbar = mpl.colorbar.ColorbarBase(cax, extend='both', cmap=cm,
-                                     norm=norm)
+    cax = divider.append_axes('right', size='5%', pad=0.2)
+    mpl.colorbar.ColorbarBase(cax, extend='both', cmap=cm, norm=norm)
     fig.tight_layout()
     return fig
 
 
 @requires_matplotlib
 @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=10)
-def test_datalevels():
+def test_datalevels() -> plt.Figure:
     plt.close()
 
     a = np.zeros((4, 5))
@@ -404,6 +435,7 @@ def test_datalevels():
         cm = mpl.colormaps.get_cmap('jet').copy()
     except AttributeError:
         import copy
+
         cm = copy.deepcopy(mpl.colormaps.get_cmap('jet'))
     cm.set_bad('pink')
 
@@ -412,33 +444,35 @@ def test_datalevels():
     ax = iter([fig.add_subplot(3, 2, i) for i in [1, 2, 3, 4, 5, 6]])
 
     # The extended version should be automated
-    c = DataLevels(levels=[0, 1, 2, 3], data=a, cmap=cm)
+    c = DataLevels(levels=np.array([0, 1, 2, 3]), data=a, cmap=cm.name)
     c.visualize(next(ax), title='levels=[0,1,2,3]')
 
     # Without min
     a[0, 0] = 0
-    c = DataLevels(levels=[0, 1, 2, 3], data=a, cmap=cm)
+    c = DataLevels(levels=np.array([0, 1, 2, 3]), data=a, cmap=cm.name)
     c.visualize(next(ax), title='modified a for no min oob')
 
     # Without max
     a[3, 3] = 0
-    c = DataLevels(levels=[0, 1, 2, 3], data=a, cmap=cm)
+    c = DataLevels(levels=np.array([0, 1, 2, 3]), data=a, cmap=cm.name)
     c.visualize(next(ax), title='modified a for no max oob')
 
     # Forced bounds
-    c = DataLevels(levels=[0, 1, 2, 3], data=a, cmap=cm, extend='both')
+    c = DataLevels(
+        levels=np.array([0, 1, 2, 3]), data=a, cmap=cm.name, extend='both'
+    )
     c.visualize(next(ax),
title="extend='both'") # Autom nlevels a[0, 0] = -1 a[3, 3] = 9 - c = DataLevels(nlevels=127, vmin=0, vmax=3, data=a, cmap=cm) - c.visualize(next(ax), title="Auto levels with oob data") + c = DataLevels(nlevels=127, vmin=0, vmax=3, data=a, cmap=cm.name) + c.visualize(next(ax), title='Auto levels with oob data') # Missing data a[3, 0] = np.nan - c = DataLevels(nlevels=127, vmin=0, vmax=3, data=a, cmap=cm) - c.visualize(next(ax), title="missing data") + c = DataLevels(nlevels=127, vmin=0, vmax=3, data=a, cmap=cm.name) + c.visualize(next(ax), title='missing data') plt.tight_layout() @@ -447,11 +481,11 @@ def test_datalevels(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=5) -def test_datalevels_visu_h(): - a = np.array([-1., 0., 1.1, 1.9, 9.]) +def test_datalevels_visu_h() -> plt.Figure: + a = np.array([-1.0, 0.0, 1.1, 1.9, 9.0]) cm = mpl.colormaps.get_cmap('RdYlBu_r') - dl = DataLevels(a, cmap=cm, levels=[0, 1, 2, 3]) + dl = DataLevels(a, cmap=cm.name, levels=np.array([0, 1, 2, 3])) fig, ax = plt.subplots(1) dl.visualize(ax=ax, orientation='horizontal', add_values=True) @@ -461,11 +495,13 @@ def test_datalevels_visu_h(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir) -def test_datalevels_visu_v(): - a = np.array([-1., 0., 1.1, 1.9, 9.]) +def test_datalevels_visu_v() -> plt.Figure: + a = np.array([-1.0, 0.0, 1.1, 1.9, 9.0]) cm = mpl.colormaps.get_cmap('RdYlBu_r') - dl = DataLevels(a.reshape((5, 1)), cmap=cm, levels=[0, 1, 2, 3]) + dl = DataLevels( + a.reshape((5, 1)), cmap=cm.name, levels=np.array([0, 1, 2, 3]) + ) fig, ax = plt.subplots(1) dl.visualize(ax=ax, orientation='vertical', add_values=True) @@ -475,7 +511,7 @@ def test_datalevels_visu_v(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=10) -def test_simple_map(): +def test_simple_map() -> plt.Figure: a = np.zeros((4, 5)) a[0, 0] = -1 a[1, 1] = 1.1 @@ -483,22 +519,24 @@ def test_simple_map(): a[2, 4] = 1.9 a[3, 3] = 9 a_inv = a[::-1, :] - fs = _create_dummy_shp('fs.shp') + fs = _create_dummy_shp(Path('fs.shp')) # UL Corner - g1 = Grid(nxny=(5, 4), dxdy=(1, -1), x0y0=(-1, 3), proj=wgs84, - pixel_ref='corner') + g1 = Grid( + nxny=(5, 4), dxdy=(1, -1), x0y0=(-1, 3), proj=wgs84, pixel_ref='corner' + ) c1 = Map(g1, ny=4, countries=False) # LL Corner - g2 = Grid(nxny=(5, 4), dxdy=(1, 1), x0y0=(-1, -1), proj=wgs84, - pixel_ref='corner') + g2 = Grid( + nxny=(5, 4), dxdy=(1, 1), x0y0=(-1, -1), proj=wgs84, pixel_ref='corner' + ) c2 = Map(g2, ny=4, countries=False) # Settings for c, data in zip([c1, c2], [a_inv, a]): - c.set_cmap(mpl.colormaps.get_cmap('jet')) - c.set_plot_params(levels=[0, 1, 2, 3]) + c.set_cmap(mpl.colormaps.get_cmap('jet').name) + c.set_plot_params(levels=np.array([0, 1, 2, 3])) c.set_data(data) c.set_shapefile(fs) c.set_lonlat_contours(interval=0.5) @@ -514,7 +552,7 @@ def test_simple_map(): c2 = Map(g2, ny=400, countries=False) # Settings for c, data, g in zip([c1, c2], [a_inv, a], [g1, g2]): - c.set_cmap(mpl.colormaps.get_cmap('jet')) + c.set_cmap(mpl.colormaps.get_cmap('jet').name) c.set_data(data, crs=g) c.set_shapefile(fs) c.set_plot_params(nlevels=256) @@ -535,14 +573,14 @@ def test_simple_map(): c2.visualize(ax2) fig.tight_layout() - if os.path.exists(testdir): + if testdir.exists(): shutil.rmtree(testdir) return fig @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=12) -def test_contourf(): +def test_contourf() -> plt.Figure: a = np.zeros((4, 5)) a[0, 0] = -1 
a[1, 1] = 1.1 @@ -551,23 +589,26 @@ def test_contourf(): a[3, 3] = 9 # UL Corner - g = Grid(nxny=(5, 4), dxdy=(1, -1), x0y0=(-1, 3), proj=wgs84, - pixel_ref='corner') + g = Grid( + nxny=(5, 4), dxdy=(1, -1), x0y0=(-1, 3), proj=wgs84, pixel_ref='corner' + ) c = Map(g, ny=400, countries=False) - c.set_cmap(mpl.colormaps.get_cmap('viridis')) - c.set_plot_params(levels=[0, 1, 2, 3]) + c.set_cmap(mpl.colormaps.get_cmap('viridis').name) + c.set_plot_params(levels=np.array([0, 1, 2, 3])) c.set_data(a) - s = a * 0. + s = a * 0.0 s[2, 2] = 1 - c.set_contourf(s, interp='linear', hatches=['xxx'], colors='none', - levels=[0.5, 1.5]) + c.set_contourf( + s, interp='linear', hatches=['xxx'], colors='none', levels=[0.5, 1.5] + ) - s = a * 0. + s = a * 0.0 s[0:2, 3:] = 1 s[0, 4] = 2 - c.set_contour(s, interp='linear', colors='k', linewidths=6, - levels=[0.5, 1., 1.5]) + c.set_contour( + s, interp='linear', colors='k', linewidths=6, levels=[0.5, 1.0, 1.5] + ) c.set_lonlat_contours(interval=0.5) @@ -588,18 +629,20 @@ def test_contourf(): return fig + @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir) -def test_merca_map(): - grid = mercator_grid(center_ll=(11.38, 47.26), - extent=(2000000, 2000000)) +def test_merca_map() -> plt.Figure: + grid = mercator_grid(center_ll=(11.38, 47.26), extent=(2000000, 2000000)) m1 = Map(grid) m1.set_scale_bar(color='red') - grid = mercator_grid(center_ll=(11.38, 47.26), - extent=(2000000, 2000000), - origin='upper-left') + grid = mercator_grid( + center_ll=(11.38, 47.26), + extent=(2000000, 2000000), + origin='upper-left', + ) m2 = Map(grid) m2.set_scale_bar(length=700000, location=(0.3, 0.05)) @@ -613,9 +656,8 @@ def test_merca_map(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir) -def test_merca_nolabels(): - grid = mercator_grid(center_ll=(11.38, 47.26), - extent=(2000000, 2000000)) +def test_merca_nolabels() -> plt.Figure: + grid = mercator_grid(center_ll=(11.38, 47.26), extent=(2000000, 2000000)) m1 = Map(grid) @@ -629,8 +671,8 @@ def test_merca_nolabels(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=5) -def test_oceans(): - f = os.path.join(get_demo_file('wrf_tip_d1.nc')) +def test_oceans() -> plt.Figure: + f = get_demo_file('wrf_tip_d1.nc') grid = GeoNetcdf(f).grid m = Map(grid, countries=False) m.set_shapefile(rivers=True, linewidths=2) @@ -646,19 +688,24 @@ def test_oceans(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir) -def test_geometries(): - +def test_geometries() -> plt.Figure: # UL Corner - g = Grid(nxny=(5, 4), dxdy=(10, 10), x0y0=(-20, -15), proj=wgs84, - pixel_ref='corner') + g = Grid( + nxny=(5, 4), + dxdy=(10, 10), + x0y0=(-20, -15), + proj=wgs84, + pixel_ref='corner', + ) c = Map(g, ny=4) - c.set_lonlat_contours(interval=10., colors='crimson', linewidths=1) + c.set_lonlat_contours(interval=10.0, colors='crimson', linewidths=1) c.set_geometry(shpg.Point(10, 10), color='darkred', markersize=60) - c.set_geometry(shpg.Point(5, 5), s=500, marker='s', - facecolor='green', hatch='||||') + c.set_geometry( + shpg.Point(5, 5), s=500, marker='s', facecolor='green', hatch='||||' + ) - s = np.array([(-5, -10), (0., -5), (-5, 0.), (-10, -5)]) + s = np.array([(-5, -10), (0.0, -5), (-5, 0.0), (-10, -5)]) l1 = shpg.LineString(s) l2 = shpg.LinearRing(s + 3) c.set_geometry(l1) @@ -672,8 +719,7 @@ def test_geometries(): p2 = shpg.Point(20, 20) p3 = shpg.Point(10, 20) mpoints = shpg.MultiPoint([p1, p2, p3]) - c.set_geometry(mpoints, s=250, 
marker='s', - c='purple', hatch='||||') + c.set_geometry(mpoints, s=250, marker='s', c='purple', hatch='||||') c.set_scale_bar(color='blue') @@ -689,34 +735,62 @@ def test_geometries(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=8) -def test_text(): +def test_text() -> plt.Figure: # UL Corner - g = Grid(nxny=(5, 4), dxdy=(10, 10), x0y0=(-20, -15), proj=wgs84, - pixel_ref='corner') + g = Grid( + nxny=(5, 4), + dxdy=(10, 10), + x0y0=(-20, -15), + proj=wgs84, + pixel_ref='corner', + ) c = Map(g, ny=4, countries=False) - c.set_lonlat_contours(interval=5., colors='crimson', linewidths=1) + c.set_lonlat_contours(interval=5.0, colors='crimson', linewidths=1) c.set_text(-5, -5, 'Less Middle', color='green', style='italic', size=25) - c.set_geometry(shpg.Point(-10, -10), s=500, marker='o', - text='My point', text_delta=[0, 0]) + c.set_geometry( + shpg.Point(-10, -10), + s=500, + marker='o', + text='My point', + text_delta=[0, 0], + ) shape = read_shapefile_to_grid(shapefiles['world_borders'], c.grid) had_c = set() - for index, row in shape.iloc[::-1].iterrows(): + for _index, row in shape.iloc[::-1].iterrows(): if row.CNTRY_NAME in had_c: c.set_geometry(row.geometry, crs=c.grid) else: - c.set_geometry(row.geometry, text=row.CNTRY_NAME, crs=c.grid, - text_kwargs=dict(horizontalalignment='center', - verticalalignment='center', - clip_on=True, - color='gray'), text_delta=[0, 0]) + c.set_geometry( + row.geometry, + text=row.CNTRY_NAME, + crs=c.grid, + text_kwargs={ + 'horizontalalignment': 'center', + 'verticalalignment': 'center', + 'clip_on': True, + 'color': 'gray', + }, + text_delta=[0, 0], + ) had_c.add(row.CNTRY_NAME) - c.set_points([20, 20, 10], [10, 20, 20], s=250, marker='s', - c='purple', hatch='||||', text='baaaaad', text_delta=[0, 0], - text_kwargs=dict(horizontalalignment='center', - verticalalignment='center', color='red')) + c.set_points( + [20, 20, 10], + [10, 20, 20], + s=250, + marker='s', + c='purple', + hatch='||||', + text='baaaaad', + text_delta=[0, 0], + text_kwargs={ + 'horizontalalignment': 'center', + 'verticalalignment': 'center', + 'color': 'red', + }, + ) fig, ax = plt.subplots(1, 1) c.visualize(ax=ax, addcbar=False) @@ -733,14 +807,12 @@ def test_text(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir) -def test_hef_linear(): - grid = mercator_grid(center_ll=(10.76, 46.798444), - extent=(10000, 7000)) +def test_hef_linear() -> plt.Figure: + grid = mercator_grid(center_ll=(10.76, 46.798444), extent=(10000, 7000)) c = Map(grid, countries=False) c.set_lonlat_contours(interval=10) c.set_shapefile(get_demo_file('Hintereisferner_UTM.shp')) - c.set_topography(get_demo_file('hef_srtm.tif'), - interp='linear') + c.set_topography(get_demo_file('hef_srtm.tif'), interp='linear') fig, ax = plt.subplots(1, 1) c.visualize(ax=ax, addcbar=False, title='linear') @@ -750,9 +822,8 @@ def test_hef_linear(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir) -def test_hef_default_spline(): - grid = mercator_grid(center_ll=(10.76, 46.798444), - extent=(10000, 7000)) +def test_hef_default_spline() -> plt.Figure: + grid = mercator_grid(center_ll=(10.76, 46.798444), extent=(10000, 7000)) c = Map(grid, countries=False) c.set_lonlat_contours(interval=0) c.set_shapefile(get_demo_file('Hintereisferner_UTM.shp')) @@ -766,9 +837,8 @@ def test_hef_default_spline(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=6) -def test_hef_from_array(): - grid = 
mercator_grid(center_ll=(10.76, 46.798444), - extent=(10000, 7000)) +def test_hef_from_array() -> plt.Figure: + grid = mercator_grid(center_ll=(10.76, 46.798444), extent=(10000, 7000)) c = Map(grid, countries=False) c.set_lonlat_contours(interval=0) c.set_shapefile(get_demo_file('Hintereisferner_UTM.shp')) @@ -784,11 +854,9 @@ def test_hef_from_array(): @requires_matplotlib -@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, - tolerance=15) -def test_hef_topo_withnan(): - grid = mercator_grid(center_ll=(10.76, 46.798444), - extent=(10000, 7000)) +@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=15) +def test_hef_topo_withnan() -> plt.Figure: + grid = mercator_grid(center_ll=(10.76, 46.798444), extent=(10000, 7000)) c = Map(grid, countries=False) c.set_lonlat_contours(interval=10) c.set_shapefile(get_demo_file('Hintereisferner_UTM.shp')) @@ -797,6 +865,7 @@ def test_hef_topo_withnan(): mytopo = dem.get_vardata() h = c.set_topography(mytopo, crs=dem.grid, interp='spline') + assert h is not None c.set_lonlat_contours() c.set_cmap(get_cmap('topo')) c.set_plot_params(nlevels=256) @@ -811,14 +880,16 @@ def test_hef_topo_withnan(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=25) -def test_gmap(): - g = GoogleCenterMap(center_ll=(10.762660, 46.794221), zoom=13, - size_x=640, size_y=640) +def test_gmap() -> plt.Figure: + g = GoogleCenterMap( + center_ll=(10.762660, 46.794221), zoom=13, size_x=640, size_y=640 + ) m = Map(g.grid, countries=False, factor=1) m.set_lonlat_contours(interval=0.025) - m.set_shapefile(get_demo_file('Hintereisferner.shp'), - linewidths=2, edgecolor='darkred') + m.set_shapefile( + get_demo_file('Hintereisferner.shp'), linewidths=2, edgecolor='darkred' + ) m.set_rgb(g.get_vardata()) fig, ax = plt.subplots(1, 1) @@ -829,12 +900,11 @@ def test_gmap(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=25) -def test_gmap_transformed(): +def test_gmap_transformed() -> plt.Figure: dem = GeoTiff(get_demo_file('hef_srtm.tif')) dem.set_subset(margin=-100) - dem = mercator_grid(center_ll=(10.76, 46.798444), - extent=(10000, 7000)) + dem = mercator_grid(center_ll=(10.76, 46.798444), extent=(10000, 7000)) i, j = dem.ij_coordinates g = GoogleVisibleMap(x=i, y=j, crs=dem, size_x=500, size_y=400) @@ -846,8 +916,9 @@ def test_gmap_transformed(): m.set_data(img) m.set_lonlat_contours(interval=0.025) - m.set_shapefile(get_demo_file('Hintereisferner.shp'), - linewidths=2, edgecolor='darkred') + m.set_shapefile( + get_demo_file('Hintereisferner.shp'), linewidths=2, edgecolor='darkred' + ) m.set_rgb(img, g.grid) fig, ax = plt.subplots(1, 1) @@ -858,7 +929,7 @@ def test_gmap_transformed(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=10) -def test_gmap_llconts(): +def test_gmap_llconts() -> plt.Figure: # This was because some problems were left unnoticed by other tests g = GoogleCenterMap(center_ll=(11.38, 47.26), zoom=9) m = Map(g.grid) @@ -873,14 +944,18 @@ def test_gmap_llconts(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=13) -def test_plot_on_map(): +def test_plot_on_map() -> plt.Figure: import salem from salem.utils import get_demo_file + ds = salem.open_wrf_dataset(get_demo_file('wrfout_d01.nc')) - t2_sub = ds.salem.subset(corners=((77., 20.), (97., 35.)), crs=salem.wgs84).T2.isel(time=2) + t2_sub = ds.salem.subset( + corners=((77.0, 20.0), (97.0, 35.0)), crs=salem.wgs84 + ).T2.isel(time=2) shdf 
= salem.read_shapefile(get_demo_file('world_borders.shp')) - shdf = shdf.loc[shdf['CNTRY_NAME'].isin( - ['Nepal', 'Bhutan'])] # GeoPandas' GeoDataFrame + shdf = shdf.loc[ + shdf['CNTRY_NAME'].isin(['Nepal', 'Bhutan']) + ] # GeoPandas' GeoDataFrame t2_sub = t2_sub.salem.subset(shape=shdf, margin=2) # add 2 grid points t2_roi = t2_sub.salem.roi(shape=shdf) fig, ax = plt.subplots(1, 1) @@ -891,26 +966,30 @@ def test_plot_on_map(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir) -def test_example_docs(): - +def test_example_docs() -> plt.Figure: import salem from salem.utils import get_demo_file + ds = salem.open_xr_dataset(get_demo_file('wrfout_d01.nc')) t2 = ds.T2.isel(Time=2) - t2_sub = t2.salem.subset(corners=((77., 20.), (97., 35.)), - crs=salem.wgs84) + t2_sub = t2.salem.subset( + corners=((77.0, 20.0), (97.0, 35.0)), crs=salem.wgs84 + ) shdf = salem.read_shapefile(get_demo_file('world_borders.shp')) - shdf = shdf.loc[shdf['CNTRY_NAME'].isin( - ['Nepal', 'Bhutan'])] # GeoPandas' GeoDataFrame + shdf = shdf.loc[ + shdf['CNTRY_NAME'].isin(['Nepal', 'Bhutan']) + ] # GeoPandas' GeoDataFrame t2_sub = t2_sub.salem.subset(shape=shdf, margin=2) # add 2 grid points t2_roi = t2_sub.salem.roi(shape=shdf) - smap = t2_roi.salem.get_map(data=t2_roi-273.15, cmap='RdYlBu_r', vmin=-14, vmax=18) + smap = t2_roi.salem.get_map( + data=t2_roi - 273.15, cmap='RdYlBu_r', vmin=-14, vmax=18 + ) _ = smap.set_topography(get_demo_file('himalaya.tif')) smap.set_shapefile(shape=shdf, color='grey', linewidth=3, zorder=5) smap.set_points(91.1, 29.6) smap.set_text(91.2, 29.7, 'Lhasa', fontsize=17) - smap.set_data(ds.T2.isel(Time=1)-273.15, crs=ds.salem.grid) + smap.set_data(ds.T2.isel(Time=1) - 273.15, crs=ds.salem.grid) fig, ax = plt.subplots(1, 1) smap.visualize(ax=ax) @@ -920,31 +999,35 @@ def test_example_docs(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=5) -def test_colormaps(): - +def test_colormaps() -> plt.Figure: fig = plt.figure(figsize=(8, 3)) - axs = [fig.add_axes([0.05, 0.80, 0.9, 0.15]), - fig.add_axes([0.05, 0.475, 0.9, 0.15]), - fig.add_axes([0.05, 0.15, 0.9, 0.15])] + axs = [ + fig.add_axes((0.05, 0.80, 0.9, 0.15)), + fig.add_axes((0.05, 0.475, 0.9, 0.15)), + fig.add_axes((0.05, 0.15, 0.9, 0.15)), + ] for ax, cm in zip(axs, ['topo', 'dem', 'nrwc']): - cb = mpl.colorbar.ColorbarBase(ax, cmap=get_cmap(cm), - orientation='horizontal') - cb.set_label(cm); + cb = mpl.colorbar.ColorbarBase( + ax, cmap=get_cmap(cm), orientation='horizontal' + ) + cb.set_label(cm) return fig @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=5) -def test_geogrid_simulator(): +def test_geogrid_simulator() -> plt.Figure: from salem.wrftools import geogrid_simulator - g, maps = geogrid_simulator(get_demo_file('namelist_mercator.wps'), - do_maps=True) + + g, maps = geogrid_simulator( + get_demo_file('namelist_mercator.wps'), do_maps=True + ) assert len(g) == 4 fig, axs = plt.subplots(2, 2) axs = np.asarray(axs).flatten() - for i, (m, ax) in enumerate(zip(maps, axs)): + for _, (m, ax) in enumerate(zip(maps, axs)): m.set_rgb(natural_earth='lr') m.plot(ax=ax) return fig @@ -952,16 +1035,16 @@ def test_geogrid_simulator(): @requires_matplotlib @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=5) -def test_lookup_transform(): - +def test_lookup_transform() -> plt.Figure: dsw = open_wrf_dataset(get_demo_file('wrfout_d01.nc')) dse = open_xr_dataset(get_demo_file('era_interim_tibet.nc')) out = 
dse.salem.lookup_transform(dsw.T2C.isel(time=0), method=len) fig, ax = plt.subplots(1, 1) sm = out.salem.get_map() sm.set_data(out) - sm.set_geometry(dsw.salem.grid.extent_as_polygon(), edgecolor='r', - linewidth=2) + sm.set_geometry( + dsw.salem.grid.extent_as_polygon(), edgecolor='r', linewidth=2 + ) sm.visualize(ax=ax) return fig @@ -970,8 +1053,7 @@ def test_lookup_transform(): @requires_cartopy @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=10) @pytest.mark.skip(reason='There is an unknown issue with cartopy') -def test_cartopy(): - +def test_cartopy() -> plt.Figure: import cartopy fig = plt.figure(figsize=(8, 11)) @@ -1020,8 +1102,7 @@ def test_cartopy(): @requires_cartopy @pytest.mark.mpl_image_compare(baseline_dir=baseline_dir, tolerance=7) @pytest.mark.skip(reason='There is an unknown issue with cartopy') -def test_cartopy_polar(): - +def test_cartopy_polar() -> plt.Figure: import cartopy fig = plt.figure(figsize=(8, 8)) @@ -1043,8 +1124,7 @@ def test_cartopy_polar(): ax = plt.subplot(2, 2, 3) smap = ds.HGT_M.salem.quick_map(ax=ax, cmap='Oranges') - ax.scatter(ds.XLONG_M, ds.XLAT_M, s=5, - transform=smap.transform(ax=ax)) + ax.scatter(ds.XLONG_M, ds.XLAT_M, s=5, transform=smap.transform(ax=ax)) p = ds.salem.cartopy() ax = plt.subplot(2, 2, 4, projection=p) diff --git a/salem/tests/test_misc.py b/salem/tests/test_misc.py index f82baa8..bd5a326 100644 --- a/salem/tests/test_misc.py +++ b/salem/tests/test_misc.py @@ -1,93 +1,98 @@ -from __future__ import division - -import unittest +from __future__ import annotations +import copy import shutil -import os import time -import warnings -import copy +import unittest +from pathlib import Path -import pytest import netCDF4 import numpy as np +import pytest from numpy.testing import assert_allclose -from salem.tests import (requires_geopandas, requires_dask, - requires_matplotlib, requires_cartopy) -from salem import utils, transform_geopandas, GeoTiff, read_shapefile, sio -from salem import read_shapefile_to_grid +from salem import ( + GeoTiff, + read_shapefile, + read_shapefile_to_grid, + sio, + transform_geopandas, + utils, +) +from salem.tests import ( + requires_cartopy, + requires_dask, + requires_geopandas, + requires_matplotlib, +) from salem.utils import get_demo_file +current_dir = Path(__file__).parent +testdir = current_dir / 'tmp' -current_dir = os.path.dirname(os.path.abspath(__file__)) -testdir = os.path.join(current_dir, 'tmp') - -def is_cartopy_rotated_working(): +def is_cartopy_rotated_working() -> bool: + import pyproj + from cartopy.crs import PlateCarree from salem.gis import proj_to_cartopy - from cartopy.crs import PlateCarree - import pyproj - cp = pyproj.Proj('+ellps=WGS84 +proj=ob_tran +o_proj=latlon ' - '+to_meter=0.0174532925199433 +o_lon_p=0.0 +o_lat_p=80.5 ' - '+lon_0=357.5 +no_defs') + cp = pyproj.Proj( + '+ellps=WGS84 +proj=ob_tran +o_proj=latlon ' + '+to_meter=0.0174532925199433 +o_lon_p=0.0 +o_lat_p=80.5 ' + '+lon_0=357.5 +no_defs' + ) cp = proj_to_cartopy(cp) out = PlateCarree().transform_points(cp, np.array([-20]), np.array([-9])) - if not (np.allclose(out[0, 0], -22.243473889042903, atol=1e-5) and - np.allclose(out[0, 1], -0.06328365194179102, atol=1e-5)): - - # Cartopy also had issues - return False - - return True + # Cartopy also had issues + return np.allclose( + out[0, 0], -22.243473889042903, atol=1e-5 + ) and np.allclose(out[0, 1], -0.06328365194179102, atol=1e-5) @requires_geopandas -def create_dummy_shp(fname): - - import shapely.geometry as shpg +def create_dummy_shp(fname: 
Path | str) -> Path: import geopandas as gpd + import shapely.geometry as shpg - e_line = shpg.LinearRing([(1.5, 1), (2., 1.5), (1.5, 2.), (1, 1.5)]) + if isinstance(fname, str): + fname = Path(fname) + e_line = shpg.LinearRing([(1.5, 1), (2.0, 1.5), (1.5, 2.0), (1, 1.5)]) i_line = shpg.LinearRing([(1.4, 1.4), (1.6, 1.4), (1.6, 1.6), (1.4, 1.6)]) p1 = shpg.Polygon(e_line, [i_line]) - p2 = shpg.Polygon([(2.5, 1.3), (3., 1.8), (2.5, 2.3), (2, 1.8)]) + p2 = shpg.Polygon([(2.5, 1.3), (3.0, 1.8), (2.5, 2.3), (2, 1.8)]) df = gpd.GeoDataFrame(crs='EPSG:4326', geometry=gpd.GeoSeries([p1, p2])) df['name'] = ['Polygon', 'Line'] - of = os.path.join(testdir, fname) + of = testdir / fname df.to_file(of) return of -def delete_test_dir(): - if os.path.exists(testdir): +def delete_test_dir() -> None: + if testdir.exists(): shutil.rmtree(testdir) class TestUtils(unittest.TestCase): + def setUp(self) -> None: + if not testdir.exists(): + testdir.mkdir(parents=True) - def setUp(self): - if not os.path.exists(testdir): - os.makedirs(testdir) - - def tearDown(self): + def tearDown(self) -> None: delete_test_dir() - def test_hash_cache_dir(self): + def test_hash_cache_dir(self) -> None: h1 = utils._hash_cache_dir() h2 = utils._hash_cache_dir() - self.assertEqual(h1, h2) - - def test_demofiles(self): + assert h1 == h2 - self.assertTrue(os.path.exists(utils.get_demo_file('dem_wgs84.nc'))) - self.assertTrue(utils.get_demo_file('dummy') is None) - - def test_read_colormap(self): + def test_demofiles(self) -> None: + assert utils.get_demo_file('dem_wgs84.nc').exists() + with pytest.raises(FileNotFoundError): + utils.get_demo_file('dummy') + def test_read_colormap(self) -> None: cl = utils.read_colormap('topo') * 256 assert_allclose(cl[4, :], (177, 242, 196)) assert_allclose(cl[-1, :], (235, 233, 235)) @@ -96,8 +101,7 @@ def test_read_colormap(self): assert_allclose(cl[4, :], (153, 100, 43)) assert_allclose(cl[-1, :], (255, 255, 255)) - def test_reduce(self): - + def test_reduce(self) -> None: arr = [[1, 1, 2, 2], [1, 1, 2, 2]] assert_allclose(utils.reduce(arr, 1), arr) assert_allclose(utils.reduce(arr, 2), [[1, 2]]) @@ -107,94 +111,98 @@ def test_reduce(self): assert_allclose(arr.shape, (3, 2, 4)) assert_allclose(utils.reduce(arr, 1), arr) assert_allclose(utils.reduce(arr, 2), [[[1, 2]], [[1, 2]], [[1, 2]]]) - assert_allclose(utils.reduce(arr, 2, how=np.sum), - [[[4, 8]], [[4, 8]], [[4, 8]]]) + assert_allclose( + utils.reduce(arr, 2, how=np.sum), [[[4, 8]], [[4, 8]], [[4, 8]]] + ) arr[0, ...] = 0 - assert_allclose(utils.reduce(arr, 2, how=np.sum), - [[[0, 0]], [[4, 8]], [[4, 8]]]) + assert_allclose( + utils.reduce(arr, 2, how=np.sum), [[[0, 0]], [[4, 8]], [[4, 8]]] + ) arr[1, ...] 
= 1 - assert_allclose(utils.reduce(arr, 2, how=np.sum), - [[[0, 0]], [[4, 4]], [[4, 8]]]) + assert_allclose( + utils.reduce(arr, 2, how=np.sum), [[[0, 0]], [[4, 4]], [[4, 8]]] + ) class TestIO(unittest.TestCase): + def setUp(self) -> None: + if not testdir.exists(): + testdir.mkdir(parents=True) - def setUp(self): - if not os.path.exists(testdir): - os.makedirs(testdir) - - def tearDown(self): + def tearDown(self) -> None: delete_test_dir() @requires_geopandas - def test_cache_working(self): - + def test_cache_working(self) -> None: f1 = 'f1.shp' f1 = create_dummy_shp(f1) cf1 = utils.cached_shapefile_path(f1) - self.assertFalse(os.path.exists(cf1)) + assert not cf1.exists() _ = read_shapefile(f1) - self.assertFalse(os.path.exists(cf1)) + assert not cf1.exists() _ = read_shapefile(f1, cached=True) - self.assertTrue(os.path.exists(cf1)) + assert cf1.exists() # nested calls - self.assertTrue(cf1 == utils.cached_shapefile_path(cf1)) + assert cf1 == utils.cached_shapefile_path(cf1) # wait a bit time.sleep(0.1) f1 = create_dummy_shp(f1) cf2 = utils.cached_shapefile_path(f1) - self.assertFalse(os.path.exists(cf1)) + assert not cf1.exists() _ = read_shapefile(f1, cached=True) - self.assertFalse(os.path.exists(cf1)) - self.assertTrue(os.path.exists(cf2)) + assert not cf1.exists() + assert cf2.exists() df = read_shapefile(f1, cached=True) - np.testing.assert_allclose(df.min_x, [1., 2.]) - np.testing.assert_allclose(df.max_x, [2., 3.]) - np.testing.assert_allclose(df.min_y, [1., 1.3]) - np.testing.assert_allclose(df.max_y, [2., 2.3]) + np.testing.assert_allclose(df.min_x, [1.0, 2.0]) + np.testing.assert_allclose(df.max_x, [2.0, 3.0]) + np.testing.assert_allclose(df.min_y, [1.0, 1.3]) + np.testing.assert_allclose(df.max_y, [2.0, 2.3]) - self.assertRaises(ValueError, read_shapefile, 'f1.sph') - self.assertRaises(ValueError, utils.cached_shapefile_path, 'f1.splash') + with pytest.raises(ValueError, match='File extension not recognised'): + read_shapefile(Path('f1.sph')) + with pytest.raises(ValueError, match='File extension not recognised'): + utils.cached_shapefile_path(Path('f1.splash')) @requires_geopandas - def test_read_to_grid(self): - + def test_read_to_grid(self) -> None: g = GeoTiff(utils.get_demo_file('hef_srtm.tif')) sf = utils.get_demo_file('Hintereisferner_UTM.shp') df1 = read_shapefile_to_grid(sf, g.grid) df2 = transform_geopandas(read_shapefile(sf), to_crs=g.grid) - assert_allclose(df1.geometry[0].exterior.coords, - df2.geometry[0].exterior.coords) + assert_allclose( + df1.geometry[0].exterior.coords, df2.geometry[0].exterior.coords + ) # test for caching d = g.grid.to_dict() # change key ordering by chance - d2 = dict((k, v) for k, v in d.items()) + d2 = dict((k, v) for k, v in d.items()) # noqa: C402 from salem.sio import _memory_shapefile_to_grid, cached_shapefile_path + shape_cpath = cached_shapefile_path(sf) - res = _memory_shapefile_to_grid.call_and_shelve(shape_cpath, - grid=g.grid, - **d) + res = _memory_shapefile_to_grid.call_and_shelve( + shape_cpath, grid=g.grid, **d + ) try: h1 = res.timestamp except AttributeError: h1 = res.argument_hash - res = _memory_shapefile_to_grid.call_and_shelve(shape_cpath, - grid=g.grid, - **d2) + res = _memory_shapefile_to_grid.call_and_shelve( + shape_cpath, grid=g.grid, **d2 + ) try: h2 = res.timestamp except AttributeError: h2 = res.argument_hash - self.assertEqual(h1, h2) - - def test_notimevar(self): + assert h1 == h2 + def test_notimevar(self) -> None: import xarray as xr + da = xr.DataArray(np.arange(12).reshape(3, 4), dims=['lat', 'lon']) 
 ds = da.to_dataset(name='var')
@@ -203,16 +211,15 @@ def test_notimevar(self):
 
 
 class TestSkyIsFalling(unittest.TestCase):
-
     @requires_matplotlib
-    def test_projplot(self):
-
+    def test_projplot(self) -> None:
         # this caused many problems on fabien's laptop.
         # this is just to be sure that on your system, everything is fine
-        import pyproj
         import matplotlib.pyplot as plt
-        from salem.gis import transform_proj, check_crs
+        import pyproj
+
+        from salem.gis import check_crs, transform_proj
 
         pyproj.Proj(proj='latlong', datum='WGS84')
         plt.figure()
@@ -223,61 +230,67 @@ def test_projplot(self):
 
         proj_out = check_crs('EPSG:4326')
         proj_in = pyproj.Proj(srs, preserve_units=True)
-        lon, lat = transform_proj(proj_in, proj_out, -2235000, -2235000)
+        lon, _ = transform_proj(
+            proj_in, proj_out, np.array(-2235000), np.array(-2235000)
+        )
         np.testing.assert_allclose(lon, 70.75731, atol=1e-5)
 
-    def test_gh_152(self):
-
+    def test_gh_152(self) -> None:
         # https://github.com/fmaussion/salem/issues/152
         import xarray as xr
-        da = xr.DataArray(np.arange(20).reshape(4, 5), dims=['lat', 'lon'],
-                          coords={'lat': np.linspace(0, 30, 4),
-                                  'lon': np.linspace(-20, 20, 5)})
+
+        da = xr.DataArray(
+            np.arange(20).reshape(4, 5),
+            dims=['lat', 'lon'],
+            coords={
+                'lat': np.linspace(0, 30, 4),
+                'lon': np.linspace(-20, 20, 5),
+            },
+        )
         da.salem.roi()
 
 
 class TestXarray(unittest.TestCase):
+    def setUp(self) -> None:
+        if not testdir.exists():
+            testdir.mkdir(parents=True)
 
-    def setUp(self):
-        if not os.path.exists(testdir):
-            os.makedirs(testdir)
-
-    def tearDown(self):
+    def tearDown(self) -> None:
         delete_test_dir()
 
     @requires_dask
-    def test_era(self):
-
+    def test_era(self) -> None:
         ds = sio.open_xr_dataset(get_demo_file('era_interim_tibet.nc')).chunk()
-        self.assertEqual(ds.salem.x_dim, 'longitude')
-        self.assertEqual(ds.salem.y_dim, 'latitude')
+        assert ds.salem.x_dim == 'longitude'
+        assert ds.salem.y_dim == 'latitude'
         dss = ds.salem.subset(ds=ds)
-        self.assertEqual(dss.salem.grid, ds.salem.grid)
+        assert dss.salem.grid == ds.salem.grid
         lon = 91.1
         lat = 31.1
         dss = ds.salem.subset(corners=((lon, lat), (lon, lat)), margin=1)
-        self.assertEqual(len(dss.latitude), 3)
-        self.assertEqual(len(dss.longitude), 3)
+        assert len(dss.latitude) == 3
+        assert len(dss.longitude) == 3
         np.testing.assert_almost_equal(dss.longitude, [90.0, 90.75, 91.5])
 
-    def test_roi(self):
+    def test_roi(self) -> None:
         import xarray as xr
+
         # Check that all attrs are preserved
         with sio.open_xr_dataset(get_demo_file('era_interim_tibet.nc')) as ds:
            ds.encoding = {'_FillValue': np.nan}
            ds['t2m'].encoding = {'_FillValue': np.nan}
-            ds_ = ds.salem.roi(roi=np.ones_like(ds.t2m.values[0, ...]))
+            ds_ = ds.salem.roi(roi=np.ones_like(ds.t2m.to_numpy()[0, ...]))
            xr.testing.assert_identical(ds, ds_)
            assert ds.encoding == ds_.encoding
            assert ds.t2m.encoding == ds_.t2m.encoding
 
     @requires_geopandas  # because of the grid tests, more robust with GDAL
-    def test_basic_wrf(self):
+    def test_basic_wrf(self) -> None:
         import xarray as xr
 
         ds = sio.open_xr_dataset(get_demo_file('wrf_tip_d1.nc')).chunk()
@@ -302,21 +315,20 @@ def test_basic_wrf(self):
         # the grid should not be misunderstood as lonlat
         t2 = ds.T2.isel(Time=0) - 273.15
         with pytest.raises(RuntimeError):
-            t2.salem.grid
+            _ = t2.salem.grid
 
     @requires_dask
-    def test_geo_em(self):
-
+    def test_geo_em(self) -> None:
         for i in [1, 2, 3]:
             fg = get_demo_file('geo_em_d0{}_lambert.nc'.format(i))
             ds = sio.open_wrf_dataset(fg).chunk()
-            self.assertFalse('Time' in ds.dims)
-            self.assertTrue('time' in ds.dims)
-            self.assertTrue('south_north' in ds.dims)
-            self.assertTrue('south_north' in ds.coords)
+            assert 'Time' not in ds.dims
+            assert 'time' in ds.dims
+            assert 'south_north' in ds.dims
+            assert 'south_north' in ds.coords
 
     @requires_geopandas  # because of the grid tests, more robust with GDAL
-    def test_wrf(self):
+    def test_wrf(self) -> None:
         import xarray as xr
 
         ds = sio.open_wrf_dataset(get_demo_file('wrf_tip_d1.nc')).chunk()
@@ -341,12 +353,12 @@ def test_wrf(self):
         # the grid should not be misunderstood as lonlat
         t2 = ds.T2.isel(time=0) - 273.15
         with pytest.raises(RuntimeError):
-            t2.salem.grid
+            _ = t2.salem.grid
 
     @requires_dask
-    def test_ncl_diagvars(self):
-
+    def test_ncl_diagvars(self) -> None:
         import xarray as xr
+
         wf = get_demo_file('wrf_cropped.nc')
         ncl_out = get_demo_file('wrf_cropped_ncl.nc')
 
@@ -359,7 +371,7 @@ def test_ncl_diagvars(self):
 
         ref = nc['SLP']
         tot = w['SLP']
-        tot = tot.values
+        tot = tot.to_numpy()
         assert_allclose(ref, tot, rtol=1e-6)
 
         w = w.isel(time=1, south_north=slice(12, 16), west_east=slice(9, 16))
@@ -371,7 +383,7 @@ def test_ncl_diagvars(self):
 
         ref = nc['SLP']
         tot = w['SLP']
-        tot = tot.values
+        tot = tot.to_numpy()
         assert_allclose(ref, tot, rtol=1e-6)
 
         w = w.isel(bottom_top=slice(3, 5))
@@ -383,14 +395,14 @@ def test_ncl_diagvars(self):
 
         ref = nc['SLP']
         tot = w['SLP']
-        tot = tot.values
+        tot = tot.to_numpy()
         assert_allclose(ref, tot, rtol=1e-6)
 
     @requires_dask
-    def test_ncl_diagvars_compressed(self):
-
+    def test_ncl_diagvars_compressed(self) -> None:
         rtol = 2e-5
 
         import xarray as xr
+
         wf = get_demo_file('wrf_cropped_compressed.nc')
         ncl_out = get_demo_file('wrf_cropped_ncl.nc')
 
@@ -428,22 +440,27 @@ def test_ncl_diagvars_compressed(self):
         assert_allclose(ref, tot, rtol=rtol)
 
     @requires_dask
-    def test_unstagger(self):
-
+    def test_unstagger(self) -> None:
         wf = get_demo_file('wrf_cropped.nc')
 
         w = sio.open_wrf_dataset(wf).chunk()
         nc = sio.open_xr_dataset(wf).chunk()
 
-        nc['PH_UNSTAGG'] = nc['P']*0.
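# Sketch of the destaggering rule this test checks: an unstaggered field is
# the mean of the two neighbouring staggered levels. Shapes are illustrative,
# not taken from the test file.
import numpy as np

ph_stag = np.arange(24.0).reshape(4, 3, 2)  # (bottom_top_stag, y, x)
ph_unstag = 0.5 * (ph_stag[:-1, ...] + ph_stag[1:, ...])
assert ph_unstag.shape == (3, 3, 2)  # one level fewer on the staggered axis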
- uns = nc['PH'].isel(bottom_top_stag=slice(0, -1)).values + \ - nc['PH'].isel(bottom_top_stag=slice(1, len(nc.bottom_top_stag))).values + nc['PH_UNSTAGG'] = nc['P'] * 0.0 + uns = ( + nc['PH'].isel(bottom_top_stag=slice(0, -1)).to_numpy() + + nc['PH'] + .isel(bottom_top_stag=slice(1, len(nc.bottom_top_stag))) + .to_numpy() + ) nc['PH_UNSTAGG'].values = uns * 0.5 assert_allclose(w['PH'], nc['PH_UNSTAGG']) # chunk - v = w['PH'].chunk({'time': 1, 'bottom_top': 6, 'south_north': 13, 'west_east': 13}) + v = w['PH'].chunk( + {'time': 1, 'bottom_top': 6, 'south_north': 13, 'west_east': 13} + ) assert_allclose(v.mean(), nc['PH_UNSTAGG'].mean(), atol=1e-2) wn = w.isel(west_east=slice(4, 8)) @@ -473,8 +490,7 @@ def test_unstagger(self): w['PH'].chunk() @requires_dask - def test_unstagger_compressed(self): - + def test_unstagger_compressed(self) -> None: wf = get_demo_file('wrf_cropped.nc') wfc = get_demo_file('wrf_cropped_compressed.nc') @@ -484,53 +500,59 @@ def test_unstagger_compressed(self): assert_allclose(w['PH'], wc['PH'], rtol=0.003) @requires_dask - def test_diagvars(self): - + def test_diagvars(self) -> None: wf = get_demo_file('wrf_d01_allvars_cropped.nc') w = sio.open_wrf_dataset(wf).chunk() # ws - w['ws_ref'] = np.sqrt(w['U']**2 + w['V']**2) + w['ws_ref'] = np.sqrt(w['U'] ** 2 + w['V'] ** 2) assert_allclose(w['ws_ref'], w['WS']) wcrop = w.isel(west_east=slice(4, 8), bottom_top=4) assert_allclose(wcrop['ws_ref'], wcrop['WS']) @requires_dask - def test_diagvars_compressed(self): - + def test_diagvars_compressed(self) -> None: wf = get_demo_file('wrf_d01_allvars_cropped_compressed.nc') w = sio.open_wrf_dataset(wf).chunk() # ws - w['ws_ref'] = np.sqrt(w['U']**2 + w['V']**2) + w['ws_ref'] = np.sqrt(w['U'] ** 2 + w['V'] ** 2) assert_allclose(w['ws_ref'], w['WS']) wcrop = w.isel(west_east=slice(4, 8), bottom_top=4) assert_allclose(wcrop['ws_ref'], wcrop['WS']) @requires_dask - def test_prcp(self): - + def test_prcp(self) -> None: wf = get_demo_file('wrfout_d01.nc') w = sio.open_wrf_dataset(wf).chunk() nc = sio.open_xr_dataset(wf) - nc['REF_PRCP_NC'] = nc['RAINNC']*0. - uns = nc['RAINNC'].isel(Time=slice(1, len(nc.bottom_top_stag))).values - \ - nc['RAINNC'].isel(Time=slice(0, -1)).values - nc['REF_PRCP_NC'].values[1:, ...] = uns * 60 / 180. # for three hours - nc['REF_PRCP_NC'].values[0, ...] = np.nan - - nc['REF_PRCP_C'] = nc['RAINC']*0. - uns = nc['RAINC'].isel(Time=slice(1, len(nc.bottom_top_stag))).values - \ - nc['RAINC'].isel(Time=slice(0, -1)).values - nc['REF_PRCP_C'].values[1:, ...] = uns * 60 / 180. # for three hours - nc['REF_PRCP_C'].values[0, ...] = np.nan + nc['REF_PRCP_NC'] = nc['RAINNC'] * 0.0 + uns = ( + nc['RAINNC'] + .isel(Time=slice(1, len(nc.bottom_top_stag))) + .to_numpy() + - nc['RAINNC'].isel(Time=slice(0, -1)).to_numpy() + ) + nc['REF_PRCP_NC'].to_numpy()[1:, ...] = ( + uns * 60 / 180.0 + ) # for three hours + nc['REF_PRCP_NC'].to_numpy()[0, ...] = np.nan + + nc['REF_PRCP_C'] = nc['RAINC'] * 0.0 + uns = ( + nc['RAINC'].isel(Time=slice(1, len(nc.bottom_top_stag))).to_numpy() + - nc['RAINC'].isel(Time=slice(0, -1)).to_numpy() + ) + nc['REF_PRCP_C'].to_numpy()[1:, ...] = ( + uns * 60 / 180.0 + ) # for three hours + nc['REF_PRCP_C'].to_numpy()[0, ...] 
= np.nan
 
         nc['REF_PRCP'] = nc['REF_PRCP_C'] + nc['REF_PRCP_NC']
 
         for suf in ['_NC', '_C', '']:
-
             assert_allclose(w['PRCP' + suf], nc['REF_PRCP' + suf], rtol=1e-5)
 
             wn = w.isel(time=slice(1, 3))
@@ -546,7 +568,7 @@ def test_prcp(self):
             assert_allclose(wn['PRCP' + suf], ncn['REF_PRCP' + suf], rtol=1e-5)
 
             wn = w.isel(time=0)
-            self.assertTrue(~np.any(np.isfinite(wn['PRCP' + suf].values)))
+            assert ~np.any(np.isfinite(wn['PRCP' + suf].values))
 
             wn = w.isel(time=slice(1, 3), south_north=slice(50, -1))
             ncn = nc.isel(Time=slice(1, 3), south_north=slice(50, -1))
@@ -561,11 +583,10 @@ def test_prcp(self):
             assert_allclose(wn['PRCP' + suf], ncn['REF_PRCP' + suf], rtol=1e-5)
 
             wn = w.isel(time=0, south_north=slice(50, -1))
-            self.assertTrue(~np.any(np.isfinite(wn['PRCP' + suf].values)))
+            assert ~np.any(np.isfinite(wn['PRCP' + suf].values))
 
     @requires_dask
-    def test_prcp_compressed(self):
-
+    def test_prcp_compressed(self) -> None:
         wf = get_demo_file('wrfout_d01.nc')
         wfc = get_demo_file('wrfout_d01_compressed.nc')
 
@@ -576,8 +597,7 @@ def test_prcp_compressed(self):
             assert_allclose(w['PRCP' + suf], wc['PRCP' + suf], atol=0.0003)
 
     @requires_geopandas  # because of the grid tests, more robust with GDAL
-    def test_transform_logic(self):
-
+    def test_transform_logic(self) -> None:
         # This is just for the naming and dim logic, the rest is tested elsewhere
         ds1 = sio.open_wrf_dataset(get_demo_file('wrfout_d01.nc')).chunk()
         ds2 = sio.open_wrf_dataset(get_demo_file('wrfout_d01.nc')).chunk()
@@ -587,38 +607,49 @@ def test_transform_logic(self):
         with pytest.raises(ValueError):
             ds1.salem.transform_and_add(t2.values, grid=t2.salem.grid)
 
-        ds1.salem.transform_and_add(t2.values, grid=t2.salem.grid, name='t2_2darr')
+        ds1.salem.transform_and_add(
+            t2.values, grid=t2.salem.grid, name='t2_2darr'
+        )
         assert 't2_2darr' in ds1
-        assert_allclose(ds1.t2_2darr.coords['south_north'],
-                        t2.coords['south_north'])
-        assert_allclose(ds1.t2_2darr.coords['west_east'],
-                        t2.coords['west_east'])
+        assert_allclose(
+            ds1.t2_2darr.coords['south_north'], t2.coords['south_north']
+        )
+        assert_allclose(
+            ds1.t2_2darr.coords['west_east'], t2.coords['west_east']
+        )
         assert ds1.salem.grid == ds1.t2_2darr.salem.grid
 
         # 3darray case
         t2 = ds2.T2
-        ds1.salem.transform_and_add(t2.values, grid=t2.salem.grid, name='t2_3darr')
+        ds1.salem.transform_and_add(
+            t2.values, grid=t2.salem.grid, name='t2_3darr'
+        )
         assert 't2_3darr' in ds1
-        assert_allclose(ds1.t2_3darr.coords['south_north'],
-                        t2.coords['south_north'])
-        assert_allclose(ds1.t2_3darr.coords['west_east'],
-                        t2.coords['west_east'])
+        assert_allclose(
+            ds1.t2_3darr.coords['south_north'], t2.coords['south_north']
+        )
+        assert_allclose(
+            ds1.t2_3darr.coords['west_east'], t2.coords['west_east']
+        )
         assert 'time' in ds1.t2_3darr.coords
 
         # dataarray case
         ds1.salem.transform_and_add(t2, name='NEWT2')
         assert 'NEWT2' in ds1
         assert_allclose(ds1.NEWT2, ds1.T2)
-        assert_allclose(ds1.t2_3darr.coords['south_north'],
-                        t2.coords['south_north'])
-        assert_allclose(ds1.t2_3darr.coords['west_east'],
-                        t2.coords['west_east'])
+        assert_allclose(
+            ds1.t2_3darr.coords['south_north'], t2.coords['south_north']
+        )
+        assert_allclose(
+            ds1.t2_3darr.coords['west_east'], t2.coords['west_east']
+        )
         assert 'time' in ds1.t2_3darr.coords
 
         # dataset case
-        ds1.salem.transform_and_add(ds2[['RAINC', 'RAINNC']],
-                                    name={'RAINC': 'PRCPC',
-                                          'RAINNC': 'PRCPNC'})
+        ds1.salem.transform_and_add(
+            ds2[['RAINC', 'RAINNC']],
+            name={'RAINC': 'PRCPC', 'RAINNC': 'PRCPNC'},
+        )
         assert 'PRCPC' in ds1
         assert_allclose(ds1.PRCPC, ds1.RAINC)
         assert
'time' in ds1.PRCPNC.coords @@ -626,16 +657,15 @@ def test_transform_logic(self): # what happens with external data? dse = sio.open_xr_dataset(get_demo_file('era_interim_tibet.nc')) out = ds1.salem.transform(dse.t2m, interp='linear') - assert_allclose(out.coords['south_north'], - t2.coords['south_north']) - assert_allclose(out.coords['west_east'], - t2.coords['west_east']) + assert_allclose(out.coords['south_north'], t2.coords['south_north']) + assert_allclose(out.coords['west_east'], t2.coords['west_east']) @requires_geopandas - def test_lookup_transform(self): - + def test_lookup_transform(self) -> None: dsw = sio.open_wrf_dataset(get_demo_file('wrfout_d01.nc')).chunk() - dse = sio.open_xr_dataset(get_demo_file('era_interim_tibet.nc')).chunk() + dse = sio.open_xr_dataset( + get_demo_file('era_interim_tibet.nc') + ).chunk() out = dse.salem.lookup_transform(dsw.T2C.isel(time=0), method=len) # qualitative tests (quantitative testing done elsewhere) assert out[0, 0] == 0 @@ -643,16 +673,17 @@ def test_lookup_transform(self): dsw = sio.open_wrf_dataset(get_demo_file('wrfout_d01.nc')) dse = sio.open_xr_dataset(get_demo_file('era_interim_tibet.nc')) - _, lut = dse.salem.lookup_transform(dsw.T2C.isel(time=0), method=len, - return_lut=True) - out2 = dse.salem.lookup_transform(dsw.T2C.isel(time=0), method=len, - lut=lut) + _, lut = dse.salem.lookup_transform( + dsw.T2C.isel(time=0), method=len, return_lut=True + ) + out2 = dse.salem.lookup_transform( + dsw.T2C.isel(time=0), method=len, lut=lut + ) # qualitative tests (quantitative testing done elsewhere) assert_allclose(out, out2) @requires_dask - def test_full_wrf_wfile(self): - + def test_full_wrf_wfile(self) -> None: from salem.wrftools import var_classes # TODO: these tests are qualitative and should be compared against ncl @@ -665,16 +696,20 @@ def test_full_wrf_wfile(self): # just check that the data is here var_classes = copy.deepcopy(var_classes) for vn in var_classes: - _ = ds[vn].values - dss = ds.isel(west_east=slice(2, 6), south_north=slice(2, 5), - bottom_top=slice(0, 15)) - _ = dss[vn].values - dss = ds.isel(west_east=1, south_north=2, - bottom_top=3, time=2) - _ = dss[vn].values + _ = ds[vn].to_numpy() + dss = ds.isel( + west_east=slice(2, 6), + south_north=slice(2, 5), + bottom_top=slice(0, 15), + ) + _ = dss[vn].to_numpy() + dss = ds.isel(west_east=1, south_north=2, bottom_top=3, time=2) + _ = dss[vn].to_numpy() # some chunking experiments - v = ds.WS.chunk({'time': 2, 'bottom_top': 1, 'south_north': 4, 'west_east': 5}) + v = ds.WS.chunk( + {'time': 2, 'bottom_top': 1, 'south_north': 4, 'west_east': 5} + ) assert_allclose(v.mean(), ds.WS.mean(), atol=1e-3) ds = ds.isel(time=slice(1, 4)) v = ds.PRCP.chunk({'time': 1, 'south_north': 2, 'west_east': 2}) @@ -682,8 +717,7 @@ def test_full_wrf_wfile(self): assert_allclose(v.max(), ds.PRCP.max()) @requires_dask - def test_full_wrf_wfile_compressed(self): - + def test_full_wrf_wfile_compressed(self) -> None: from salem.wrftools import var_classes # TODO: these tests are qualitative and should be compared against ncl @@ -696,16 +730,20 @@ def test_full_wrf_wfile_compressed(self): # just check that the data is here var_classes = copy.deepcopy(var_classes) for vn in var_classes: - _ = ds[vn].values - dss = ds.isel(west_east=slice(2, 6), south_north=slice(2, 5), - bottom_top=slice(0, 15)) - _ = dss[vn].values - dss = ds.isel(west_east=1, south_north=2, - bottom_top=3, time=2) - _ = dss[vn].values + _ = ds[vn].to_numpy() + dss = ds.isel( + west_east=slice(2, 6), + south_north=slice(2, 5), + 
bottom_top=slice(0, 15), + ) + _ = dss[vn].to_numpy() + dss = ds.isel(west_east=1, south_north=2, bottom_top=3, time=2) + _ = dss[vn].to_numpy() # some chunking experiments - v = ds.WS.chunk({'time': 2, 'bottom_top': 1, 'south_north': 4, 'west_east': 5}) + v = ds.WS.chunk( + {'time': 2, 'bottom_top': 1, 'south_north': 4, 'west_east': 5} + ) assert_allclose(v.mean(), ds.WS.mean(), atol=1e-3) ds = ds.isel(time=slice(1, 4)) v = ds.PRCP.chunk({'time': 1, 'south_north': 2, 'west_east': 2}) @@ -713,95 +751,95 @@ def test_full_wrf_wfile_compressed(self): assert_allclose(v.max(), ds.PRCP.max()) @requires_dask - def test_3d_interp(self): - + def test_3d_interp(self) -> None: f = get_demo_file('wrf_d01_allvars_cropped.nc') ds = sio.open_wrf_dataset(f).chunk() - out = ds.salem.wrf_zlevel('Z', levels=6000.) - ref_2d = out * 0. + 6000. + out = ds.salem.wrf_zlevel('Z', levels=6000.0) + ref_2d = out * 0.0 + 6000.0 assert_allclose(out, ref_2d) # this used to raise an error _ = out.isel(time=1) - out = ds.salem.wrf_zlevel('Z', levels=[6000., 7000.]) - assert_allclose(out.sel(z=6000.), ref_2d) - assert_allclose(out.sel(z=7000.), ref_2d * 0. + 7000.) + out = ds.salem.wrf_zlevel('Z', levels=[6000.0, 7000.0]) + assert_allclose(out.sel(z=6000.0), ref_2d) + assert_allclose(out.sel(z=7000.0), ref_2d * 0.0 + 7000.0) assert np.all(np.isfinite(out)) out = ds.salem.wrf_zlevel('Z') - assert_allclose(out.sel(z=7500.), ref_2d * 0. + 7500.) + assert_allclose(out.sel(z=7500.0), ref_2d * 0.0 + 7500.0) - out = ds.salem.wrf_plevel('PRESSURE', levels=400.) - ref_2d = out * 0. + 400. + out = ds.salem.wrf_plevel('PRESSURE', levels=400.0) + ref_2d = out * 0.0 + 400.0 assert_allclose(out, ref_2d) - out = ds.salem.wrf_plevel('PRESSURE', levels=[400., 300.]) - assert_allclose(out.sel(p=400.), ref_2d) - assert_allclose(out.sel(p=300.), ref_2d * 0. + 300.) + out = ds.salem.wrf_plevel('PRESSURE', levels=[400.0, 300.0]) + assert_allclose(out.sel(p=400.0), ref_2d) + assert_allclose(out.sel(p=300.0), ref_2d * 0.0 + 300.0) out = ds.salem.wrf_plevel('PRESSURE') - assert_allclose(out.sel(p=300.), ref_2d * 0. + 300.) + assert_allclose(out.sel(p=300.0), ref_2d * 0.0 + 300.0) assert np.any(~np.isfinite(out)) out = ds.salem.wrf_plevel('PRESSURE', fill_value='extrapolate') - assert_allclose(out.sel(p=300.), ref_2d * 0. + 300.) + assert_allclose(out.sel(p=300.0), ref_2d * 0.0 + 300.0) assert np.all(np.isfinite(out)) ds = sio.open_wrf_dataset(get_demo_file('wrfout_d01.nc')) - ws_h = ds.isel(time=1).salem.wrf_zlevel('WS', levels=8000., - use_multiprocessing=False) + ws_h = ds.isel(time=1).salem.wrf_zlevel( + 'WS', levels=8000.0, use_multiprocessing=False + ) assert np.all(np.isfinite(ws_h)) - ws_h2 = ds.isel(time=1).salem.wrf_zlevel('WS', levels=8000.) + ws_h2 = ds.isel(time=1).salem.wrf_zlevel('WS', levels=8000.0) assert_allclose(ws_h, ws_h2) @requires_dask - def test_3d_interp_compressed(self): + def test_3d_interp_compressed(self) -> None: f = get_demo_file('wrf_d01_allvars_cropped_compressed.nc') ds = sio.open_wrf_dataset(f).chunk() - out = ds.salem.wrf_zlevel('Z', levels=6000.) - ref_2d = out * 0. + 6000. + out = ds.salem.wrf_zlevel('Z', levels=6000.0) + ref_2d = out * 0.0 + 6000.0 assert_allclose(out, ref_2d) # this used to raise an error _ = out.isel(time=1) - out = ds.salem.wrf_zlevel('Z', levels=[6000., 7000.]) - assert_allclose(out.sel(z=6000.), ref_2d) - assert_allclose(out.sel(z=7000.), ref_2d * 0. + 7000.) 
+ out = ds.salem.wrf_zlevel('Z', levels=[6000.0, 7000.0]) + assert_allclose(out.sel(z=6000.0), ref_2d) + assert_allclose(out.sel(z=7000.0), ref_2d * 0.0 + 7000.0) assert np.all(np.isfinite(out)) out = ds.salem.wrf_zlevel('Z') - assert_allclose(out.sel(z=7500.), ref_2d * 0. + 7500.) + assert_allclose(out.sel(z=7500.0), ref_2d * 0.0 + 7500.0) - out = ds.salem.wrf_plevel('PRESSURE', levels=400.) - ref_2d = out * 0. + 400. + out = ds.salem.wrf_plevel('PRESSURE', levels=400.0) + ref_2d = out * 0.0 + 400.0 assert_allclose(out, ref_2d) - out = ds.salem.wrf_plevel('PRESSURE', levels=[400., 300.]) - assert_allclose(out.sel(p=400.), ref_2d) - assert_allclose(out.sel(p=300.), ref_2d * 0. + 300.) + out = ds.salem.wrf_plevel('PRESSURE', levels=[400.0, 300.0]) + assert_allclose(out.sel(p=400.0), ref_2d) + assert_allclose(out.sel(p=300.0), ref_2d * 0.0 + 300.0) out = ds.salem.wrf_plevel('PRESSURE') - assert_allclose(out.sel(p=300.), ref_2d * 0. + 300.) + assert_allclose(out.sel(p=300.0), ref_2d * 0.0 + 300.0) assert np.any(~np.isfinite(out)) out = ds.salem.wrf_plevel('PRESSURE', fill_value='extrapolate') - assert_allclose(out.sel(p=300.), ref_2d * 0. + 300.) + assert_allclose(out.sel(p=300.0), ref_2d * 0.0 + 300.0) assert np.all(np.isfinite(out)) ds = sio.open_wrf_dataset(get_demo_file('wrfout_d01.nc')) - ws_h = ds.isel(time=1).salem.wrf_zlevel('WS', levels=8000., - use_multiprocessing=False) + ws_h = ds.isel(time=1).salem.wrf_zlevel( + 'WS', levels=8000.0, use_multiprocessing=False + ) assert np.all(np.isfinite(ws_h)) - ws_h2 = ds.isel(time=1).salem.wrf_zlevel('WS', levels=8000.) + ws_h2 = ds.isel(time=1).salem.wrf_zlevel('WS', levels=8000.0) assert_allclose(ws_h, ws_h2) @requires_dask - def test_mf_datasets(self): - + def test_mf_datasets(self) -> None: import xarray as xr # prepare the data @@ -809,10 +847,10 @@ def test_mf_datasets(self): ds = xr.open_dataset(f) for i in range(4): dss = ds.isel(Time=[i]) - dss.to_netcdf(os.path.join(testdir, 'wrf_slice_{}.nc'.format(i))) + dss.to_netcdf(testdir / f'wrf_slice_{i}.nc') dss.close() ds = sio.open_wrf_dataset(f) - dsm = sio.open_mf_wrf_dataset(os.path.join(testdir, 'wrf_slice_*.nc')) + dsm = sio.open_mf_wrf_dataset(testdir / 'wrf_slice_*.nc') assert_allclose(ds['RAINNC'], dsm['RAINNC']) assert_allclose(ds['GEOPOTENTIAL'], dsm['GEOPOTENTIAL']) @@ -820,31 +858,31 @@ def test_mf_datasets(self): assert 'PRCP' not in dsm.variables prcp_nc_r = dsm.RAINNC.salem.deacc(as_rate=False) - self.assertEqual(prcp_nc_r.units, 'mm step-1') - self.assertEqual(prcp_nc_r.description, 'TOTAL GRID SCALE PRECIPITATION') + assert prcp_nc_r.units == 'mm step-1' + assert prcp_nc_r.description == 'TOTAL GRID SCALE PRECIPITATION' prcp_nc = dsm.RAINNC.salem.deacc() - self.assertEqual(prcp_nc.units, 'mm h-1') - self.assertEqual(prcp_nc.description, 'TOTAL GRID SCALE PRECIPITATION') + assert prcp_nc.units == 'mm h-1' + assert prcp_nc.description == 'TOTAL GRID SCALE PRECIPITATION' - assert_allclose(prcp_nc_r/3, prcp_nc) + assert_allclose(prcp_nc_r / 3, prcp_nc) # note that this is needed because there are variables which just # can't be computed lazily (i.e. 
prcp) - fo = os.path.join(testdir, 'wrf_merged.nc') - if os.path.exists(fo): - os.remove(fo) + fo = testdir / 'wrf_merged.nc' + if fo.exists(): + fo.unlink() dsm = dsm[['RAINNC', 'RAINC']].load() dsm.to_netcdf(fo) dsm.close() dsm = sio.open_wrf_dataset(fo) assert_allclose(ds['PRCP'], dsm['PRCP'], rtol=1e-6) - assert_allclose(prcp_nc, dsm['PRCP_NC'].isel(time=slice(1, 4)), - rtol=1e-6) + assert_allclose( + prcp_nc, dsm['PRCP_NC'].isel(time=slice(1, 4)), rtol=1e-6 + ) @requires_cartopy - def test_metum(self): - + def test_metum(self) -> None: if not sio.is_rotated_proj_working(): with pytest.raises(RuntimeError): sio.open_metum_dataset(get_demo_file('rotated_grid.nc')) @@ -867,8 +905,10 @@ def test_metum(self): if not is_cartopy_rotated_working(): return - from salem.gis import proj_to_cartopy from cartopy.crs import PlateCarree + + from salem.gis import proj_to_cartopy + cp = proj_to_cartopy(ds.salem.grid.proj) xx, yy = ds.salem.grid.xy_coordinates @@ -877,21 +917,21 @@ def test_metum(self): assert_allclose(out[:, 1].reshape(ii.shape), ds.latitude_t, atol=1e-7) # Round trip - out = cp.transform_points(PlateCarree(), - ds.longitude_t.values.flatten(), - ds.latitude_t.values.flatten()) + out = cp.transform_points( + PlateCarree(), + ds.longitude_t.to_numpy().flatten(), + ds.latitude_t.to_numpy().flatten(), + ) assert_allclose(out[:, 0].reshape(ii.shape), xx, atol=1e-7) assert_allclose(out[:, 1].reshape(ii.shape), yy, atol=1e-7) class TestGeogridSim(unittest.TestCase): - @requires_geopandas - def test_lambert(self): - + def test_lambert(self) -> None: from salem.wrftools import geogrid_simulator - g, m = geogrid_simulator(get_demo_file('namelist_lambert.wps')) + g, _ = geogrid_simulator(get_demo_file('namelist_lambert.wps')) assert len(g) == 3 @@ -899,16 +939,15 @@ def test_lambert(self): fg = get_demo_file('geo_em_d0{}_lambert.nc'.format(i)) with netCDF4.Dataset(fg) as nc: nc.set_auto_mask(False) - lon, lat = g[i-1].ll_coordinates + lon, lat = g[i - 1].ll_coordinates assert_allclose(lon, nc['XLONG_M'][0, ...], atol=1e-4) assert_allclose(lat, nc['XLAT_M'][0, ...], atol=1e-4) @requires_geopandas - def test_lambert_tuto(self): - + def test_lambert_tuto(self) -> None: from salem.wrftools import geogrid_simulator - g, m = geogrid_simulator(get_demo_file('namelist_tutorial.wps')) + g, _ = geogrid_simulator(get_demo_file('namelist_tutorial.wps')) assert len(g) == 1 @@ -920,11 +959,10 @@ def test_lambert_tuto(self): assert_allclose(lat, nc['XLAT_M'][0, ...], atol=1e-4) @requires_geopandas - def test_mercator(self): - + def test_mercator(self) -> None: from salem.wrftools import geogrid_simulator - g, m = geogrid_simulator(get_demo_file('namelist_mercator.wps')) + g, _ = geogrid_simulator(get_demo_file('namelist_mercator.wps')) assert len(g) == 4 @@ -932,16 +970,15 @@ def test_mercator(self): fg = get_demo_file('geo_em_d0{}_mercator.nc'.format(i)) with netCDF4.Dataset(fg) as nc: nc.set_auto_mask(False) - lon, lat = g[i-1].ll_coordinates + lon, lat = g[i - 1].ll_coordinates assert_allclose(lon, nc['XLONG_M'][0, ...], atol=1e-4) assert_allclose(lat, nc['XLAT_M'][0, ...], atol=1e-4) @requires_geopandas - def test_polar(self): - + def test_polar(self) -> None: from salem.wrftools import geogrid_simulator - g, m = geogrid_simulator(get_demo_file('namelist_polar.wps')) + g, _ = geogrid_simulator(get_demo_file('namelist_polar.wps')) assert len(g) == 2 @@ -949,6 +986,6 @@ def test_polar(self): fg = get_demo_file('geo_em_d0{}_polarstereo.nc'.format(i)) with netCDF4.Dataset(fg) as nc: nc.set_auto_mask(False) - 
lon, lat = g[i-1].ll_coordinates + lon, lat = g[i - 1].ll_coordinates assert_allclose(lon, nc['XLONG_M'][0, ...], atol=5e-3) assert_allclose(lat, nc['XLAT_M'][0, ...], atol=5e-3) diff --git a/salem/utils.py b/salem/utils.py index e0ec917..2b23419 100644 --- a/salem/utils.py +++ b/salem/utils.py @@ -1,22 +1,48 @@ -""" -Some useful functions -""" -from __future__ import division +"""Some useful functions.""" +from __future__ import annotations + +import importlib.util import io import os import shutil +import sys +import warnings import zipfile from collections import OrderedDict +from pathlib import Path +from typing import List +from urllib.request import urlopen, urlretrieve import numpy as np from joblib import Memory -from salem import (cache_dir, sample_data_dir, sample_data_gh_commit, - download_dir, python_version) -from urllib.request import urlretrieve, urlopen + +from salem import ( + cache_dir, + download_dir, + python_version, + sample_data_dir, + sample_data_gh_commit, +) + + +def import_if_exists(module_name: str, package: str | None = None) -> bool: + """Import a module if it exists and is not yet imported.""" + if module_name in sys.modules: + return True + try: + importlib.import_module(module_name, package=package) + except ImportError: + return False + return True + + +def deprecated_arg(msg: str) -> None: + """Warns that an argument is deprecated.""" + warnings.warn(msg, DeprecationWarning, stacklevel=2) -def _hash_cache_dir(): +def _hash_cache_dir() -> str: """Get the path to the right cache directory. We need to make sure that cached files correspond to the same @@ -28,6 +54,7 @@ def _hash_cache_dir(): Returns ------- path to the dir + """ import hashlib @@ -35,42 +62,49 @@ def _hash_cache_dir(): try: import shapely + out['shapely_version'] = shapely.__version__ out['shapely_file'] = shapely.__file__ except ImportError: pass try: import fiona + out['fiona_version'] = fiona.__version__ out['fiona_file'] = fiona.__file__ except ImportError: pass try: import pandas + out['pandas_version'] = pandas.__version__ out['pandas_file'] = pandas.__file__ except ImportError: pass try: import geopandas + out['geopandas_version'] = geopandas.__version__ out['geopandas_file'] = geopandas.__file__ except ImportError: pass try: import osgeo + out['osgeo_version'] = osgeo.__version__ out['osgeo_file'] = osgeo.__file__ except ImportError: pass try: import pyproj + out['pyproj_version'] = pyproj.__version__ out['pyproj_file'] = pyproj.__file__ except ImportError: pass try: import salem + out['salem_version'] = salem.__version__ out['salem_file'] = salem.__file__ except ImportError: @@ -94,15 +128,51 @@ def _hash_cache_dir(): # A series of variables and dimension names that Salem will understand valid_names = dict() -valid_names['x_dim'] = ['west_east', 'lon', 'longitude', 'longitudes', 'lons', - 'xlong', 'xlong_m', 'dimlon', 'x', 'lon_3', 'long', - 'phony_dim_0', 'eastings', 'easting', 'nlon', 'nlong', - 'grid_longitude_t'] -valid_names['y_dim'] = ['south_north', 'lat', 'latitude', 'latitudes', 'lats', - 'xlat', 'xlat_m', 'dimlat', 'y', 'lat_3', 'phony_dim_1', - 'northings', 'northing', 'nlat', 'grid_latitude_t'] -valid_names['z_dim'] = ['levelist', 'level', 'pressure', 'press', 'zlevel', 'z', - 'bottom_top'] +valid_names['x_dim'] = [ + 'west_east', + 'lon', + 'longitude', + 'longitudes', + 'lons', + 'xlong', + 'xlong_m', + 'dimlon', + 'x', + 'lon_3', + 'long', + 'phony_dim_0', + 'eastings', + 'easting', + 'nlon', + 'nlong', + 'grid_longitude_t', +] +valid_names['y_dim'] = [ + 
'south_north', + 'lat', + 'latitude', + 'latitudes', + 'lats', + 'xlat', + 'xlat_m', + 'dimlat', + 'y', + 'lat_3', + 'phony_dim_1', + 'northings', + 'northing', + 'nlat', + 'grid_latitude_t', +] +valid_names['z_dim'] = [ + 'levelist', + 'level', + 'pressure', + 'press', + 'zlevel', + 'z', + 'bottom_top', +] valid_names['t_dim'] = ['time', 'times', 'xtime'] valid_names['lon_var'] = ['lon', 'longitude', 'longitudes', 'lons', 'long'] @@ -113,7 +183,7 @@ def _hash_cache_dir(): nearth_base = 'http://shadedrelief.com/natural3/ne3_data/' -def str_in_list(l1, l2): +def str_in_list(l1: List[str], l2: List[str]) -> List[str]: """Check if one element of l1 is in l2 and if yes, returns the name of that element in a list (could be more than one. @@ -123,55 +193,56 @@ def str_in_list(l1, l2): ['time'] >>> print(str_in_list(['time', 'lon'], ['temp','time','prcp','lon'])) ['time', 'lon'] + """ return [i for i in l1 if i.lower() in l2] -def empty_cache(): +def empty_cache() -> None: """Empty salem's cache directory.""" - - if os.path.exists(cache_dir): + if cache_dir.exists(): shutil.rmtree(cache_dir) - os.makedirs(cache_dir) + cache_dir.mkdir() -def cached_shapefile_path(fpath): - """Checks if a shapefile is cached and returns the corresponding path. +def cached_shapefile_path(fpath: Path | str) -> Path: + """Check if a shapefile is cached and returns the corresponding path. This function checks for the last time the file has changed, so it should be safe to use. """ - - p, ext = os.path.splitext(fpath) - - if ext.lower() == '.p': + if isinstance(fpath, str): + fpath = Path(fpath) + if fpath.suffix.lower() == '.p': # No need to recache pickled files (this is for nested calls) return fpath - if ext.lower() != '.shp': - raise ValueError('File extension not recognised: {}'.format(ext)) + if fpath.suffix.lower() != '.shp': + msg = f'File extension not recognised: {fpath.suffix.lower()}' + raise ValueError(msg) # Cached directory and file - cp = os.path.commonprefix([cache_dir, p]) - cp = os.path.join(cache_dir, hash_cache_dir + '_shp', - os.path.relpath(p, cp)) - ct = '{:d}'.format(int(round(os.path.getmtime(fpath)*1000.))) - of = os.path.join(cp, ct + '.p') - if os.path.exists(cp): + cp = Path(os.path.commonprefix([cache_dir, fpath.with_suffix('')])) + cp = ( + cache_dir + / f'{hash_cache_dir}_shp' + / fpath.with_suffix('').relative_to(cp) + ) + ct = f'{int(round(fpath.stat().st_mtime*1000.)):d}' + of = cp / (ct + '.p') + if cp.exists(): # We have to check if the file changed - if os.path.exists(of): + if of.exists(): return of - else: - # the file has changed - shutil.rmtree(cp) + # the file has changed + shutil.rmtree(cp) - os.makedirs(cp) + cp.mkdir(parents=True, exist_ok=True) return of def _urlretrieve(url, ofile, *args, **kwargs): """Wrapper for urlretrieve which overwrites.""" - try: return urlretrieve(url, ofile, *args, **kwargs) except: @@ -191,11 +262,13 @@ def download_demo_files(): Borrowed from OGGM. 
""" - - master_zip_url = 'https://github.com/%s/archive/%s.zip' % \ - (sample_data_gh_repo, sample_data_gh_commit) - ofile = os.path.join(cache_dir, - 'salem-sample-data-%s.zip' % sample_data_gh_commit) + master_zip_url = 'https://github.com/{}/archive/{}.zip'.format( + sample_data_gh_repo, + sample_data_gh_commit, + ) + ofile = os.path.join( + cache_dir, 'salem-sample-data-%s.zip' % sample_data_gh_commit + ) odir = os.path.join(cache_dir) # download only if necessary @@ -224,14 +297,13 @@ def download_demo_files(): return out -def get_demo_file(fname): - """Returns the path to the desired demo file.""" - +def get_demo_file(fname: str | Path) -> Path: + """Return the path to the desired demo file.""" d = download_demo_files() - if fname in d: - return d[fname] - else: - return None + if str(fname) in d: + return Path(d[str(fname)]) + msg = 'File not found in demo files: {}'.format(fname) + raise FileNotFoundError(msg) def get_natural_earth_file(res='lr'): @@ -243,8 +315,8 @@ def get_natural_earth_file(res='lr'): ---------- res : str 'lr' or 'hr' (low res or high res) - """ + """ if not os.path.exists(download_dir): os.makedirs(download_dir) @@ -267,25 +339,24 @@ def get_natural_earth_file(res='lr'): @memory.cache def read_colormap(name): """Reads a colormap from the custom files in Salem.""" - path = get_demo_file(name + '.c3g') out = [] - with open(path, 'r') as file: + with open(path) as file: for line in file: if 'rgb(' not in line: continue line = line.split('(')[-1].split(')')[0] out.append([float(n) for n in line.split(',')]) - return np.asarray(out).astype(float) / 256. + return np.asarray(out).astype(float) / 256.0 @memory.cache def joblib_read_img_url(url): """Prevent to re-download from GoogleStaticMap if it was done before""" - from matplotlib.image import imread + fd = urlopen(url, timeout=10) return imread(io.BytesIO(fd.read())) @@ -306,9 +377,10 @@ def nice_scale(mapextent, maxlen=0.15): 20.0 >>> print(nice_scale(140, maxlen=0.5)) 50.0 + """ d = np.array([1, 2, 5]) - e = (np.ones(12) * 10) ** (np.arange(12)-5) + e = (np.ones(12) * 10) ** (np.arange(12) - 5) candidates = np.matmul(e[:, None], d[None, :]).flatten() return np.max(candidates[candidates / mapextent <= maxlen]) @@ -332,10 +404,17 @@ def reduce(arr, factor=1, how=np.mean): Returns ------- the reduced array + """ arr = np.asarray(arr) shape = list(arr.shape) - newshape = shape[:-2] + [np.round(shape[-2] / factor).astype(int), factor, - np.round(shape[-1] / factor).astype(int), factor] - return how(how(arr.reshape(*newshape), axis=len(newshape)-3), - axis=len(newshape)-2) + newshape = shape[:-2] + [ + np.round(shape[-2] / factor).astype(int), + factor, + np.round(shape[-1] / factor).astype(int), + factor, + ] + return how( + how(arr.reshape(*newshape), axis=len(newshape) - 3), + axis=len(newshape) - 2, + ) diff --git a/salem/version.py b/salem/version.py index 055d4af..54aef3c 100644 --- a/salem/version.py +++ b/salem/version.py @@ -1,5 +1,6 @@ try: - from importlib.metadata import version, PackageNotFoundError + from importlib.metadata import PackageNotFoundError, version + try: __version__ = version(__name__.split('.', maxsplit=1)[0]) except PackageNotFoundError: @@ -8,9 +9,12 @@ finally: del version, PackageNotFoundError except ModuleNotFoundError: - from pkg_resources import get_distribution, DistributionNotFound + from pkg_resources import DistributionNotFound, get_distribution + try: - __version__ = get_distribution(__name__.split('.', maxsplit=1)[0]).version + __version__ = get_distribution( + 
__name__.split('.', maxsplit=1)[0]
+        ).version
     except DistributionNotFound:
         # package is not installed
         pass
diff --git a/salem/wrftools.py b/salem/wrftools.py
index 8b95d32..4d51da8 100644
--- a/salem/wrftools.py
+++ b/salem/wrftools.py
@@ -3,17 +3,16 @@
 Diagnostic variables are simply a subclass of FakeVariable that implement
 __getitem__. See examples below.
 """
-from __future__ import division
+
 import copy
 
 import numpy as np
-import pyproj
-from scipy.interpolate import interp1d
 from netCDF4 import num2date
 from pandas import to_datetime
+from scipy.interpolate import interp1d
 from xarray.core import indexing
 
-from salem import lazy_property, wgs84, gis
+from salem import gis, lazy_property, wgs84
 
 POOL = None
 
@@ -22,6 +21,7 @@ def _init_pool():
     global POOL
     if POOL is None:
         import multiprocessing as mp
+
         POOL = mp.Pool()
 
 
@@ -29,8 +29,7 @@ def dummy_func(*args):
     pass
 
 
-class ScaledVar():
-
+class ScaledVar:
     def __init__(self, ncvar):
         self.ncvar = ncvar
         try:
@@ -46,11 +45,11 @@ def __exit__(self, type, value, traceback):
         self.ncvar.set_auto_scale(self.scale)
 
 
-class Unstaggerer(object):
+class Unstaggerer:
     """Duck NetCDF4.Variable class which "unstaggers" WRF variables.
 
-    It looks for the staggered dimension and automatically unstaggers it.
-    """
+    It looks for the staggered dimension and automatically unstaggers it.
+    """
 
     def __init__(self, ncvar):
         """Instantiate.
@@ -58,8 +57,8 @@ def __init__(self, ncvar):
         Parameters
         ----------
         ncvar: the netCDF variable to unstagger.
-        """
 
+        """
         self.ncvar = ncvar
 
         # Attributes
@@ -91,11 +90,13 @@ def __init__(self, ncvar):
 
         def filter_attrs():
             return attrs
+
         self.ncattrs = filter_attrs
         self.filters = ncvar.filters
 
         def _chunking():
             return self.shape
+
         self.chunking = _chunking
 
         for attr in self.ncattrs():
@@ -116,12 +117,12 @@ def can_do(ncvar):
         Parameters
         ----------
         ncvar: the netCDF variable candidate for unstaggering.
+ """ return np.any(['_stag' in d for d in ncvar.dimensions]) def __getitem__(self, item): """Override __getitem__.""" - # take care of ellipsis and other strange indexes item = list(indexing.expanded_indexer(item, len(self.dimensions))) @@ -129,36 +130,36 @@ def __getitem__(self, item): was_scalar = False sl = item[self.ds] if np.isscalar(sl) and not isinstance(sl, slice): - sl = slice(sl, sl+1) + sl = slice(sl, sl + 1) was_scalar = True # Ok, get the indexes right start = sl.start or 0 stop = sl.stop or self._ds_shape if stop < 0: - stop += self._ds_shape-1 - stop = np.clip(stop+1, 0, self._ds_shape) + stop += self._ds_shape - 1 + stop = np.clip(stop + 1, 0, self._ds_shape) itemr = copy.deepcopy(item) if was_scalar: item[self.ds] = start - itemr[self.ds] = start+1 + itemr[self.ds] = start + 1 else: - item[self.ds] = slice(start, stop-1) - itemr[self.ds] = slice(start+1, stop) + item[self.ds] = slice(start, stop - 1) + itemr[self.ds] = slice(start + 1, stop) with ScaledVar(self.ncvar) as var: - return 0.5*(var[tuple(item)] + var[tuple(itemr)]) + return 0.5 * (var[tuple(item)] + var[tuple(itemr)]) -class FakeVariable(object): - """Duck NetCDF4.Variable class - """ +class FakeVariable: + """Duck NetCDF4.Variable class""" + def __init__(self, nc): self.name = self.__class__.__name__ self.nc = nc @staticmethod def can_do(): - raise NotImplementedError() + raise NotImplementedError def _copy_attrs_from(self, ncvar): # copies the necessary nc attributes from a template variable @@ -170,6 +171,7 @@ def _copy_attrs_from(self, ncvar): def filter_attrs(): return attrs + self.ncattrs = filter_attrs self.filters = ncvar.filters self.chunking = ncvar.chunking @@ -187,7 +189,7 @@ def getncattr(self, name): return getattr(self, name) def __getitem__(self, item): - raise NotImplementedError() + raise NotImplementedError class T2C(FakeVariable): @@ -234,9 +236,13 @@ def _factor(self): time = [] stimes = vars['Times'][0:2] for t in stimes: - time.append(to_datetime(t.tobytes().decode(), - errors='raise', - format='%Y-%m-%d_%H:%M:%S')) + time.append( + to_datetime( + t.tobytes().decode(), + errors='raise', + format='%Y-%m-%d_%H:%M:%S', + ) + ) dt_minutes = time[1] - time[0] dt_minutes = dt_minutes.seconds / 60 return 60 / dt_minutes @@ -251,7 +257,6 @@ def can_do(nc): return can_do def __getitem__(self, item): - # take care of ellipsis and other strange indexes item = list(indexing.expanded_indexer(item, len(self.dimensions))) @@ -260,25 +265,25 @@ def __getitem__(self, item): was_scalar = False if np.isscalar(sl) and not isinstance(sl, slice): was_scalar = True - sl = slice(sl, sl+1) + sl = slice(sl, sl + 1) # Ok, get the indexes right start = sl.start or 0 stop = sl.stop or self._nel if stop < 0: - stop += self._nel-1 + stop += self._nel - 1 start -= 1 do_nan = False if start < 0: do_nan = True itemr = copy.deepcopy(item) - item[0] = slice(start, stop-1) - itemr[0] = slice(start+1, stop) + item[0] = slice(start, stop - 1) + itemr[0] = slice(start + 1, stop) # done with ScaledVar(self.nc.variables[self.accvn]) as var: if do_nan: - item[0] = slice(0, stop-1) + item[0] = slice(0, stop - 1) out = var[itemr] try: # in case we have a masked array @@ -296,7 +301,6 @@ def __getitem__(self, item): class PRCP_NC(AccumulatedVariable): - def __init__(self, nc): AccumulatedVariable.__init__(self, nc, 'RAINNC') self.units = 'mm h-1' @@ -308,7 +312,6 @@ def can_do(nc): class PRCP_C(AccumulatedVariable): - def __init__(self, nc): AccumulatedVariable.__init__(self, nc, 'RAINC') self.units = 'mm h-1' @@ -328,13 +331,17 @@ 
def __init__(self, nc): @staticmethod def can_do(nc): - return (AccumulatedVariable.can_do(nc) and - 'RAINC' in nc.variables and - 'RAINNC' in nc.variables) + return ( + AccumulatedVariable.can_do(nc) + and 'RAINC' in nc.variables + and 'RAINNC' in nc.variables + ) def __getitem__(self, item): - with ScaledVar(self.nc.variables['PRCP_NC']) as p1, \ - ScaledVar(self.nc.variables['PRCP_C']) as p2: + with ( + ScaledVar(self.nc.variables['PRCP_NC']) as p1, + ScaledVar(self.nc.variables['PRCP_C']) as p2, + ): return p1[item] + p2[item] @@ -351,7 +358,7 @@ def can_do(nc): def __getitem__(self, item): with ScaledVar(self.nc.variables['T']) as var: - return var[item] + 300. + return var[item] + 300.0 class TK(FakeVariable): @@ -366,15 +373,17 @@ def can_do(nc): return np.all([n in nc.variables for n in ['T', 'P', 'PB']]) def __getitem__(self, item): - p1000mb = 100000. + p1000mb = 100000.0 r_d = 287.04 - cp = 7 * r_d / 2. + cp = 7 * r_d / 2.0 with ScaledVar(self.nc.variables['T']) as var: - t = var[item] + 300. - with ScaledVar(self.nc.variables['P']) as p, \ - ScaledVar(self.nc.variables['PB']) as pb: + t = var[item] + 300.0 + with ( + ScaledVar(self.nc.variables['P']) as p, + ScaledVar(self.nc.variables['PB']) as pb, + ): p = p[item] + pb[item] - return (p/p1000mb)**(r_d/cp) * t + return (p / p1000mb) ** (r_d / cp) * t class WS(FakeVariable): @@ -390,9 +399,9 @@ def can_do(nc): def __getitem__(self, item): with ScaledVar(self.nc.variables['U']) as var: - ws = var[item]**2 + ws = var[item] ** 2 with ScaledVar(self.nc.variables['V']) as var: - ws += var[item]**2 + ws += var[item] ** 2 return np.sqrt(ws) @@ -408,9 +417,10 @@ def can_do(nc): return np.all([n in nc.variables for n in ['P', 'PB']]) def __getitem__(self, item): - - with ScaledVar(self.nc.variables['P']) as p, \ - ScaledVar(self.nc.variables['PB']) as pb: + with ( + ScaledVar(self.nc.variables['P']) as p, + ScaledVar(self.nc.variables['PB']) as pb, + ): res = p[item] + pb[item] if p.units == 'Pa': res /= 100 @@ -431,8 +441,10 @@ def can_do(nc): return np.all([n in nc.variables for n in ['PH', 'PHB']]) def __getitem__(self, item): - with ScaledVar(self.nc.variables['PH']) as p, \ - ScaledVar(self.nc.variables['PHB']) as pb: + with ( + ScaledVar(self.nc.variables['PH']) as p, + ScaledVar(self.nc.variables['PHB']) as pb, + ): return p[item] + pb[item] @@ -469,17 +481,16 @@ def can_do(nc): return np.all([n in nc.variables for n in need]) def __getitem__(self, item): - # take care of ellipsis and other strange indexes item = list(indexing.expanded_indexer(item, len(self.dimensions))) # we need the empty dims for _ncl_slp() to work squeezax = [] for i, c in enumerate(item): if np.isscalar(c) and not isinstance(c, slice): - item[i] = slice(c, c+1) + item[i] = slice(c, c + 1) squeezax.append(i) # add a slice in the 4th dim - item.insert(self.ds, slice(0, self._ds_shape+1)) + item.insert(self.ds, slice(0, self._ds_shape + 1)) item = tuple(item) # get data @@ -497,19 +508,20 @@ def __getitem__(self, item): # Diagnostic variable classes in a list var_classes = [cls.__name__ for cls in vars()['FakeVariable'].__subclasses__()] -var_classes.extend([cls.__name__ for cls in - vars()['AccumulatedVariable'].__subclasses__()]) +var_classes.extend( + [cls.__name__ for cls in vars()['AccumulatedVariable'].__subclasses__()] +) var_classes.remove('AccumulatedVariable') def _interp1d(args): - f = interp1d(args[0], args[1], fill_value=args[3], - bounds_error=False) + f = interp1d(args[0], args[1], fill_value=args[3], bounds_error=False) return f(args[2]) -def 
interp3d(data, zcoord, levels, fill_value=np.nan, - use_multiprocessing=True): +def interp3d( + data, zcoord, levels, fill_value=np.nan, use_multiprocessing=True +): """Interpolate on the first dimension of a 3d var Useful for WRF pressure or geopotential levels @@ -531,14 +543,17 @@ def interp3d(data, zcoord, levels, fill_value=np.nan, Returns ------- a ndarray, with the first dimension now begin of shape nlevels - """ + """ ndims = len(data.shape) if ndims == 4: out = [] for d, z in zip(data, zcoord): - out.append(np.expand_dims(interp3d(d, z, levels, - fill_value=fill_value), 0)) + out.append( + np.expand_dims( + interp3d(d, z, levels, fill_value=fill_value), 0 + ) + ) return np.concatenate(out, axis=0) if ndims != 3: raise ValueError('ndims must be 3') @@ -547,8 +562,9 @@ def interp3d(data, zcoord, levels, fill_value=np.nan, inp = [] for j in range(data.shape[-2]): for i in range(data.shape[-1]): - inp.append((zcoord[:, j, i], data[:, j, i], levels, - fill_value)) + inp.append( + (zcoord[:, j, i], data[:, j, i], levels, fill_value) + ) _init_pool() out = POOL.map(_interp1d, inp, chunksize=1000) out = np.asarray(out).T @@ -560,8 +576,12 @@ def interp3d(data, zcoord, levels, fill_value=np.nan, out = np.zeros((len(levels), data.shape[-2], data.shape[-1])) for i in range(data.shape[-1]): for j in range(data.shape[-2]): - f = interp1d(zcoord[:, j, i], data[:, j, i], - fill_value=fill_value, bounds_error=False) + f = interp1d( + zcoord[:, j, i], + data[:, j, i], + fill_value=fill_value, + bounds_error=False, + ) out[:, j, i] = f(levels) return out @@ -579,8 +599,8 @@ def _ncl_slp(z, t, p, q): T: temp P: pressure Q: specific humidity - """ + """ ndims = len(z.shape) if ndims == 4: out = [] @@ -596,7 +616,7 @@ def _ncl_slp(z, t, p, q): g = 9.81 gamma = 0.0065 tc = 273.16 + 17.5 - pconst = 10000. + pconst = 10000.0 # Find least zeta level that is pconst Pa above the surface. We # later use this level to extrapolate a surface pressure and @@ -613,14 +633,13 @@ def _ncl_slp(z, t, p, q): if np.any(level == -1): raise RuntimeError('Error_in_finding_100_hPa_up') # pragma: no cover - klo = (level-1).clip(0, nz-1) - khi = (klo+1).clip(0, nz-1) + klo = (level - 1).clip(0, nz - 1) + khi = (klo + 1).clip(0, nz - 1) if np.any((klo - khi) == 0): raise RuntimeError('Trapping levels are weird.') # pragma: no cover - x, y = np.meshgrid(np.arange(nx, dtype=int), - np.arange(ny, dtype=int)) + x, y = np.meshgrid(np.arange(nx, dtype=int), np.arange(ny, dtype=int)) plo = p[klo, y, x] phi = p[khi, y, x] @@ -634,13 +653,17 @@ def _ncl_slp(z, t, p, q): qlo = q[klo, y, x] qhi = q[khi, y, x] - tlo *= (1. + 0.608 * qlo) - thi *= (1. 
+ 0.608 * qhi) + tlo *= 1.0 + 0.608 * qlo + thi *= 1.0 + 0.608 * qhi p_at_pconst = p0 - pconst - t_at_pconst = thi - (thi-tlo) * np.log(p_at_pconst/phi) * np.log(plo/phi) - z_at_pconst = zhi - (zhi-zlo) * np.log(p_at_pconst/phi) * np.log(plo/phi) - t_surf = t_at_pconst * ((p0/p_at_pconst)**(gamma*r/g)) + t_at_pconst = thi - (thi - tlo) * np.log(p_at_pconst / phi) * np.log( + plo / phi + ) + z_at_pconst = zhi - (zhi - zlo) * np.log(p_at_pconst / phi) * np.log( + plo / phi + ) + t_surf = t_at_pconst * ((p0 / p_at_pconst) ** (gamma * r / g)) t_sea_level = t_at_pconst + gamma * z_at_pconst # If we follow a traditional computation, there is a correction to the @@ -649,7 +672,7 @@ def _ncl_slp(z, t, p, q): l1 = t_sea_level < tc l2 = t_surf <= tc l3 = ~l1 - t_sea_level = tc - 0.005 * (t_surf-tc)**2 + t_sea_level = tc - 0.005 * (t_surf - tc) ** 2 pok = np.nonzero(l2 & l3) t_sea_level[pok] = tc @@ -657,7 +680,9 @@ def _ncl_slp(z, t, p, q): z_half_lowest = z[0, ...] # Convert to hPa in this step - return 0.01 * (p0 * np.exp((2.*g*z_half_lowest)/(r*(t_sea_level+t_surf)))) + return 0.01 * ( + p0 * np.exp((2.0 * g * z_half_lowest) / (r * (t_sea_level + t_surf))) + ) def geogrid_simulator(fpath, do_maps=True, map_kwargs=None): @@ -678,8 +703,8 @@ def geogrid_simulator(fpath, do_maps=True, map_kwargs=None): - grids: a list of Grids corresponding to the domains defined in the namelist - maps: a list of maps corresponding to the grids (if do_maps==True) - """ + """ with open(fpath) as f: lines = f.readlines() @@ -725,39 +750,46 @@ def geogrid_simulator(fpath, do_maps=True, map_kwargs=None): # define projection if map_proj == 'LAMBERT': - pwrf = '+proj=lcc +lat_1={lat_1} +lat_2={lat_2} ' \ - '+lat_0={lat_0} +lon_0={lon_0} ' \ - '+x_0=0 +y_0=0 +a=6370000 +b=6370000' + pwrf = ( + '+proj=lcc +lat_1={lat_1} +lat_2={lat_2} ' + '+lat_0={lat_0} +lon_0={lon_0} ' + '+x_0=0 +y_0=0 +a=6370000 +b=6370000' + ) pwrf = pwrf.format(**pargs) elif map_proj == 'MERCATOR': - pwrf = '+proj=merc +lat_ts={lat_1} +lon_0={lon_0} ' \ - '+x_0=0 +y_0=0 +a=6370000 +b=6370000' + pwrf = ( + '+proj=merc +lat_ts={lat_1} +lon_0={lon_0} ' + '+x_0=0 +y_0=0 +a=6370000 +b=6370000' + ) pwrf = pwrf.format(**pargs) elif map_proj == 'POLAR': - pwrf = '+proj=stere +lat_ts={lat_1} +lat_0=90.0 +lon_0={lon_0} ' \ - '+x_0=0 +y_0=0 +a=6370000 +b=6370000' + pwrf = ( + '+proj=stere +lat_ts={lat_1} +lat_0=90.0 +lon_0={lon_0} ' + '+x_0=0 +y_0=0 +a=6370000 +b=6370000' + ) pwrf = pwrf.format(**pargs) else: - raise NotImplementedError('WRF proj not implemented yet: ' - '{}'.format(map_proj)) + raise NotImplementedError( + 'WRF proj not implemented yet: ' '{}'.format(map_proj) + ) pwrf = gis.check_crs(pwrf) # get easting and northings from dom center (probably unnecessary here) e, n = gis.transform_proj(wgs84, pwrf, pargs['ref_lon'], pargs['lat_0']) # LL corner - nx, ny = e_we[0]-1, e_sn[0]-1 - x0 = -(nx-1) / 2. * dx + e # -2 because of staggered grid - y0 = -(ny-1) / 2. 
* dy + n + nx, ny = e_we[0] - 1, e_sn[0] - 1 + x0 = -(nx - 1) / 2.0 * dx + e # -2 because of staggered grid + y0 = -(ny - 1) / 2.0 * dy + n # parent grid grid = gis.Grid(nxny=(nx, ny), x0y0=(x0, y0), dxdy=(dx, dy), proj=pwrf) # child grids out = [grid] - for ips, jps, pid, ratio, we, sn in zip(i_parent_start, j_parent_start, - parent_id, parent_ratio, - e_we, e_sn): + for ips, jps, pid, ratio, we, sn in zip( + i_parent_start, j_parent_start, parent_id, parent_ratio, e_we, e_sn + ): if ips == 1: continue ips -= 1 @@ -767,28 +799,35 @@ def geogrid_simulator(fpath, do_maps=True, map_kwargs=None): nx = we / ratio ny = sn / ratio if nx != (we / ratio): - raise RuntimeError('e_we and ratios are incompatible: ' - '(e_we - 1) / ratio must be integer!') + raise RuntimeError( + 'e_we and ratios are incompatible: ' + '(e_we - 1) / ratio must be integer!' + ) if ny != (sn / ratio): - raise RuntimeError('e_sn and ratios are incompatible: ' - '(e_sn - 1) / ratio must be integer!') + raise RuntimeError( + 'e_sn and ratios are incompatible: ' + '(e_sn - 1) / ratio must be integer!' + ) prevgrid = out[pid - 1] xx, yy = prevgrid.corner_grid.x_coord, prevgrid.corner_grid.y_coord dx = prevgrid.dx / ratio dy = prevgrid.dy / ratio - grid = gis.Grid(nxny=(we, sn), - x0y0=(xx[ips], yy[jps]), - dxdy=(dx, dy), - pixel_ref='corner', - proj=pwrf) + grid = gis.Grid( + nxny=(we, sn), + x0y0=(xx[ips], yy[jps]), + dxdy=(dx, dy), + pixel_ref='corner', + proj=pwrf, + ) out.append(grid.center_grid) maps = None if do_maps: - from salem import Map import shapely.geometry as shpg + from salem import Map + if map_kwargs is None: map_kwargs = {} @@ -796,15 +835,22 @@ def geogrid_simulator(fpath, do_maps=True, map_kwargs=None): for i, g in enumerate(out): m = Map(g, **map_kwargs) - for j in range(i+1, len(out)): + for j in range(i + 1, len(out)): cg = out[j] left, right, bottom, top = cg.extent - s = np.array([(left, bottom), (right, bottom), - (right, top), (left, top)]) + s = np.array( + [ + (left, bottom), + (right, bottom), + (right, top), + (left, top), + ] + ) l1 = shpg.LinearRing(s) - m.set_geometry(l1, crs=cg.proj, linewidth=(len(out)-j), - zorder=5) + m.set_geometry( + l1, crs=cg.proj, linewidth=(len(out) - j), zorder=5 + ) maps.append(m) diff --git a/setup.py b/setup.py index 41efb01..687a493 100755 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 import setuptools + if __name__ == "__main__": setuptools.setup()
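
The new `import_if_exists` helper in salem/utils.py replaces the module-level
try/except import blocks with availability flags. A minimal self-contained
sketch of that pattern as the patch uses it for optional dependencies (the
`print` fallback is illustrative, not from the patch):

    from __future__ import annotations

    import importlib
    import sys


    def import_if_exists(module_name: str, package: str | None = None) -> bool:
        """Import a module if it exists and report whether it is available."""
        if module_name in sys.modules:
            return True
        try:
            importlib.import_module(module_name, package=package)
        except ImportError:
            return False
        return True


    has_rasterio = import_if_exists('rasterio')

    if has_rasterio:
        import rasterio  # cheap: the module is already cached in sys.modules
    else:
        print('rasterio not available, GeoTiff support disabled')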
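`Unstaggerer.__getitem__` destaggers lazily, slice by slice, but the
underlying arithmetic is just the mean of neighboring values along the
staggered dimension (`0.5 * (var[tuple(item)] + var[tuple(itemr)])`). The same
operation on an eagerly loaded array, assuming the staggered dimension is the
last axis (array names are illustrative, not from salem):

    import numpy as np

    # A wind component staggered in x: one more point along the last axis
    # than the mass grid (2 x 3 x 5 staggered -> 2 x 3 x 4 unstaggered).
    u_stag = np.arange(2 * 3 * 5, dtype=float).reshape((2, 3, 5))

    # Average each pair of neighbors along the staggered axis.
    u = 0.5 * (u_stag[..., :-1] + u_stag[..., 1:])

    assert u.shape == (2, 3, 4)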
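The accumulated-variable handling (`AccumulatedVariable._factor` above, and
`deacc` in the tests) boils down to a first difference of the accumulated
field, scaled from 'mm step-1' to 'mm h-1' with `60 / dt_minutes`. With the
3-hourly output used in the tests, that scaling is the division by 3 seen in
`test_mf_datasets`. A sketch with made-up numbers:

    import numpy as np

    # Accumulated precipitation (e.g. RAINNC, in mm) at 3-hourly output steps
    rainnc = np.array([0.0, 1.5, 4.5, 6.0])
    dt_minutes = 180.0  # from the difference of two consecutive 'Times'

    prcp_step = np.diff(rainnc)              # 'mm step-1'
    prcp_rate = prcp_step * 60 / dt_minutes  # 'mm h-1'

    np.testing.assert_allclose(prcp_rate, prcp_step / 3)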
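The serial branch of `interp3d` runs one `scipy.interpolate.interp1d` per
(j, i) column; the multiprocessing branch only batches the same calls through
a pool. A runnable column-wise sketch on synthetic data (the field, levels,
and shapes are made up for illustration):

    import numpy as np
    from scipy.interpolate import interp1d

    nz, ny, nx = 5, 2, 3
    # Monotonic vertical coordinate per column, here identical everywhere
    zcoord = np.linspace(0.0, 4000.0, nz)[:, None, None] * np.ones((nz, ny, nx))
    data = 0.01 * zcoord  # any field defined on the model levels
    levels = np.array([500.0, 1500.0])

    out = np.empty((len(levels), ny, nx))
    for j in range(ny):
        for i in range(nx):
            f = interp1d(zcoord[:, j, i], data[:, j, i],
                         fill_value=np.nan, bounds_error=False)
            out[:, j, i] = f(levels)

    np.testing.assert_allclose(out[0], 5.0)   # 500 m * 0.01
    np.testing.assert_allclose(out[1], 15.0)  # 1500 m * 0.01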