Skip to content

Commit

Permalink
Adding statistics to Raster (#638)
Browse files Browse the repository at this point in the history
  • Loading branch information
vschaffn authored Jan 10, 2025
1 parent 5c68af7 commit 514ea82
Show file tree
Hide file tree
Showing 5 changed files with 301 additions and 10 deletions.
28 changes: 28 additions & 0 deletions doc/source/raster_class.md
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,31 @@ rast_reproj.to_pointcloud()
# Export to xarray data array
rast_reproj.to_xarray()
```

## Obtain Statistics
The `get_stats()` method extracts key statistical information from a raster and returns it as a dictionary.
Supported statistics are: mean, median, max, min, sum, sum of squares, 90th percentile, nmad, rmse, std.
Callable functions are supported as well.

### Usage Examples:
- Get all statistics in a dict:
```{code-cell} ipython3
rast.get_stats()
```

- Get a single statistic (e.g., 'mean') as a float:
```{code-cell} ipython3
rast.get_stats("mean")
```

- Get multiple statistics in a dict:
```{code-cell} ipython3
rast.get_stats(["mean", "max", "rmse"])
```

- Using a custom callable statistic:
```{code-cell} ipython3
def custom_stat(data):
    """Count the number of pixels above 100."""
    above_threshold = data > 100
    return np.nansum(above_threshold)
rast.get_stats(custom_stat)
```
177 changes: 167 additions & 10 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import logging
import math
import pathlib
import warnings
Expand Down Expand Up @@ -69,6 +70,7 @@
decode_sensor_metadata,
parse_and_convert_metadata_from_filename,
)
from geoutils.stats import nmad
from geoutils.vector.vector import Vector

# If python38 or above, Literal is builtin. Otherwise, use typing_extensions
Expand Down Expand Up @@ -1870,6 +1872,157 @@ def set_mask(self, mask: NDArrayBool | Mask) -> None:
else:
self.data[mask_arr > 0] = np.ma.masked

def _statistics(self, band: int = 1) -> dict[str, np.floating[Any]]:
    """
    Calculate common statistics for a specified band in the raster.

    Masked values are excluded before computing, and NaNs are ignored by the NaN-aware NumPy reductions.
    If no valid value remains, a warning is logged and every statistic is returned as NaN.

    :param band: The index of the band for which to compute statistics (1-based). Default is 1.
        Ignored for single-band rasters.

    :returns: A dictionary containing the calculated statistics for the selected band, including mean, median, max,
        min, sum, sum of squares, 90th percentile, NMAD, RMSE, and standard deviation.
    """
    stat_names = (
        "Mean",
        "Median",
        "Max",
        "Min",
        "Sum",
        "Sum of squares",
        "90th percentile",
        "NMAD",
        "RMSE",
        "Standard deviation",
    )

    # Band numbering is 1-based; a single-band raster is used as is
    data = self.data if self.count == 1 else self.data[band - 1]

    # If data is a MaskedArray, use the compressed version (without masked values)
    if isinstance(data, np.ma.MaskedArray):
        data = data.compressed()

    # np.nanmax/np.nanmin raise on empty input, so a fully masked band is handled explicitly
    if np.size(data) == 0:
        logging.warning("Empty raster or band: all statistics are returned as NaN")
        return {name: np.float64(np.nan) for name in stat_names}

    # Squares are computed in float64: squaring an integer raster (e.g. uint8) in its own dtype wraps around
    squares = np.square(np.asarray(data, dtype=np.float64))

    return {
        "Mean": np.nanmean(data),
        "Median": np.nanmedian(data),
        "Max": np.nanmax(data),
        "Min": np.nanmin(data),
        "Sum": np.nansum(data),
        "Sum of squares": np.nansum(squares),
        "90th percentile": np.nanpercentile(data, 90),
        "NMAD": nmad(data),
        # Root mean square of the values themselves; the spread around the mean is "Standard deviation"
        "RMSE": np.sqrt(np.nanmean(squares)),
        "Standard deviation": np.nanstd(data),
    }

@overload
def get_stats(
    self,
    stats_name: (
        Literal["mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"]
        | Callable[[NDArrayNum], np.floating[Any]]
    ),
    band: int = 1,
) -> np.floating[Any]: ...


@overload
def get_stats(
    self,
    stats_name: (
        list[
            Literal[
                "mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"
            ]
            | Callable[[NDArrayNum], np.floating[Any]]
        ]
        | None
    ) = None,
    band: int = 1,
) -> dict[str, np.floating[Any]]: ...


def get_stats(
    self,
    stats_name: (
        Literal["mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"]
        | Callable[[NDArrayNum], np.floating[Any]]
        | list[
            Literal[
                "mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"
            ]
            | Callable[[NDArrayNum], np.floating[Any]]
        ]
        | None
    ) = None,
    band: int = 1,
) -> np.floating[Any] | dict[str, np.floating[Any]]:
    """
    Retrieve specified statistics or all available statistics for the raster data. Allows passing custom callables
    to calculate custom stats.

    The raster is loaded first if it is not loaded yet.

    :param stats_name: Name or list of names of the statistics to retrieve. If None, all statistics are returned.
        Accepted names include:
        - "mean", "median", "max", "min", "sum", "sum of squares", "90th percentile", "nmad", "rmse", "std"
        You can also use common aliases for these names (e.g., "average", "maximum", "minimum", etc.).
        Names are matched case-insensitively, ignoring spaces, underscores and hyphens.
        Custom callables can also be provided; each is called with the data of the selected band.
    :param band: The index of the band for which to compute statistics (1-based). Default is 1.

    :returns: The requested statistic, or a dictionary of statistics if a list or None was passed. In a list
        result, a named statistic is keyed by the name as passed and a callable by its ``__name__``. An
        unrecognized name yields NaN.
    """
    if not self.is_loaded:
        self.load()

    if stats_name is None:
        return self._statistics(band=band)

    requested = stats_name if isinstance(stats_name, list) else [stats_name]

    # The built-in statistics scan the whole band, so skip them when only custom callables are requested
    stats_dict: dict[str, np.floating[Any]] = {}
    if not all(callable(item) for item in requested):
        stats_dict = self._statistics(band=band)

    # Define the metric aliases and their actual names
    stats_aliases = {
        "mean": "Mean",
        "average": "Mean",
        "median": "Median",
        "max": "Max",
        "maximum": "Max",
        "min": "Min",
        "minimum": "Min",
        "sum": "Sum",
        "sumofsquares": "Sum of squares",
        "sum2": "Sum of squares",
        "percentile": "90th percentile",
        "90thpercentile": "90th percentile",
        "90percentile": "90th percentile",
        "percentile90": "90th percentile",
        "nmad": "NMAD",
        "rmse": "RMSE",
        "std": "Standard deviation",
        "stddev": "Standard deviation",
        "standarddev": "Standard deviation",
        "standarddeviation": "Standard deviation",
    }

    def _evaluate(item: Any) -> np.floating[Any]:
        """Compute one requested statistic, given either as a callable or as a name/alias."""
        if callable(item):
            # Bands are 1-based, consistently with the built-in statistics
            return item(self.data[band - 1] if self.count > 1 else self.data)
        return self._get_single_stat(stats_dict, stats_aliases, item)

    if isinstance(stats_name, list):
        result = {}
        for item in stats_name:
            # Some callables (e.g. functools.partial objects) have no __name__: fall back to their repr
            key = getattr(item, "__name__", repr(item)) if callable(item) else item
            result[key] = _evaluate(item)
        return result

    return _evaluate(stats_name)

@staticmethod
def _get_single_stat(
    stats_dict: dict[str, np.floating[Any]], stats_aliases: dict[str, str], stat_name: str
) -> np.floating[Any]:
    """
    Look up one statistic from a user-supplied name or alias.

    The name is lower-cased and stripped of spaces, underscores and hyphens before being matched
    against the aliases.

    :param stats_dict: The dictionary of available statistics.
    :param stats_aliases: The dictionary of alias mappings to the actual stat names.
    :param stat_name: The name or alias of the statistic to retrieve.

    :returns: The requested statistic value, or NaN (after logging a warning) if the stat name is not recognized.
    """
    alias_key = "".join(char for char in stat_name.lower() if char not in " _-")

    # Unknown names are reported and mapped to NaN rather than raising
    if alias_key not in stats_aliases:
        logging.warning("Statistic name '%s' is not recognized", stat_name)
        return np.float32(np.nan)

    return stats_dict[stats_aliases[alias_key]]

@overload
def info(self, stats: bool = False, *, verbose: Literal[True] = ...) -> None: ...

Expand Down Expand Up @@ -1904,24 +2057,28 @@ def info(self, stats: bool = False, verbose: bool = True) -> None | str:
]

if stats:
as_str.append("\nStatistics:\n")
if not self.is_loaded:
self.load()

if self.count == 1:
as_str.append(f"[MAXIMUM]: {np.nanmax(self.data):.2f}\n")
as_str.append(f"[MINIMUM]: {np.nanmin(self.data):.2f}\n")
as_str.append(f"[MEDIAN]: {np.ma.median(self.data):.2f}\n")
as_str.append(f"[MEAN]: {np.nanmean(self.data):.2f}\n")
as_str.append(f"[STD DEV]: {np.nanstd(self.data):.2f}\n")
statistics = self.get_stats()

# Determine the maximum length of the stat names for alignment
max_len = max(len(name) for name in statistics.keys())

# Format the stats with aligned names
for name, value in statistics.items():
as_str.append(f"{name.ljust(max_len)}: {value:.2f}\n")
else:
for b in range(self.count):
# try to keep with rasterio convention.
as_str.append(f"Band {b + 1}:\n")
as_str.append(f"[MAXIMUM]: {np.nanmax(self.data[b, :, :]):.2f}\n")
as_str.append(f"[MINIMUM]: {np.nanmin(self.data[b, :, :]):.2f}\n")
as_str.append(f"[MEDIAN]: {np.ma.median(self.data[b, :, :]):.2f}\n")
as_str.append(f"[MEAN]: {np.nanmean(self.data[b, :, :]):.2f}\n")
as_str.append(f"[STD DEV]: {np.nanstd(self.data[b, :, :]):.2f}\n")
statistics = self.get_stats(band=b)
if isinstance(statistics, dict):
max_len = max(len(name) for name in statistics.keys())
for name, value in statistics.items():
as_str.append(f"{name.ljust(max_len)}: {value:.2f}\n")

if verbose:
print("".join(as_str))
Expand Down
26 changes: 26 additions & 0 deletions geoutils/stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
""" Statistical tools"""

from typing import Any

import numpy as np

from geoutils._typing import NDArrayNum


def nmad(data: NDArrayNum, nfact: float = 1.4826) -> np.floating[Any]:
    """
    Compute the normalized median absolute deviation (NMAD) of an array.

    The median absolute deviation (MAD) is multiplied by ``nfact``. The default of 1.4826 scales the MAD to the
    dispersion of a normal distribution (see
    https://en.wikipedia.org/wiki/Median_absolute_deviation#Relation_to_standard_deviation, and
    e.g. Höhle and Höhle (2009), http://dx.doi.org/10.1016/j.isprsjprs.2009.02.003).

    :param data: Input array or raster
    :param nfact: Normalization factor for the data

    :returns nmad: (normalized) median absolute deviation of data.
    """
    # Masked entries are dropped; any other input is converted to a plain array
    values = data.compressed() if isinstance(data, np.ma.masked_array) else np.asarray(data)

    # NaN-aware medians, so NaN entries do not affect the result
    center = np.nanmedian(values)
    abs_deviation = np.abs(values - center)
    return nfact * np.nanmedian(abs_deviation)
50 changes: 50 additions & 0 deletions tests/test_raster/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

from __future__ import annotations

import logging
import os
import pathlib
import re
import tempfile
import warnings
from cmath import isnan
from io import StringIO
from tempfile import TemporaryFile
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -1944,6 +1947,53 @@ def test_split_bands(self) -> None:
red_c.data.data.squeeze().astype("float32"), img.data.data[0, :, :].astype("float32"), equal_nan=True
)

@pytest.mark.parametrize("example", [landsat_b4_path, aster_dem_path, landsat_rgb_path])  # type: ignore
def test_stats(self, example: str, caplog) -> None:
    """Check get_stats() for all statistics, a single name, a custom callable, a mixed list and an unknown name."""
    raster = gu.Raster(example)

    # No argument: every built-in statistic must be present and populated
    all_stats = raster.get_stats()
    for key in (
        "Mean",
        "Median",
        "Max",
        "Min",
        "Sum",
        "Sum of squares",
        "90th percentile",
        "NMAD",
        "RMSE",
        "Standard deviation",
    ):
        assert key in all_stats
        assert all_stats.get(key) is not None

    # A single alias, in any letter case, returns a scalar
    single = raster.get_stats(stats_name="Average")
    assert isinstance(single, np.floating)

    # A custom callable also returns a scalar
    def percentile_95(data: NDArrayNum) -> np.floating[Any]:
        if isinstance(data, np.ma.MaskedArray):
            data = data.compressed()
        return np.nanpercentile(data, 95)

    from_callable = raster.get_stats(stats_name=percentile_95)
    assert isinstance(from_callable, np.floating)

    # A list mixing names and a callable returns a dict keyed by the names as passed / the callable's __name__
    selected = raster.get_stats(stats_name=["mean", "maximum", "std", percentile_95])
    for key in ("mean", "maximum", "std", "percentile_95"):
        assert key in selected
        assert selected.get(key) is not None

    # An unknown name logs a warning and yields NaN
    with caplog.at_level(logging.WARNING):
        unknown = raster.get_stats(stats_name="80 percentile")
        assert isnan(unknown)
    assert "Statistic name '80 percentile' is not recognized" in caplog.text


class TestMask:
# Paths to example data
Expand Down
30 changes: 30 additions & 0 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
Test functions for stats
"""

import scipy

from geoutils import Raster, examples
from geoutils.stats import nmad


class TestStats:
    """Tests for the statistical tools in geoutils.stats."""

    landsat_b4_path = examples.get_path("everest_landsat_b4")
    landsat_raster = Raster(landsat_b4_path)

    def test_nmad(self) -> None:
        """Test NMAD functionality runs on any type of input"""
        raster = self.landsat_raster

        # The masked-array and NaN-array inputs must give the same NMAD, which must match scipy's
        from_masked = nmad(raster.data)
        from_nan_array = nmad(raster.get_nanarray())
        reference = scipy.stats.median_abs_deviation(raster.data, axis=None, scale="normal")

        assert from_masked == from_nan_array
        assert from_masked.round(2) == reference.round(2)

        # The normalization factor scales the result linearly
        unscaled = nmad(raster.data, nfact=1)
        doubled = nmad(raster.data, nfact=2)

        assert unscaled * 2 == doubled

0 comments on commit 514ea82

Please sign in to comment.