diff --git a/examples/time_series_learning.py b/examples/time_series_learning.py index 562677e..0f3fa9a 100644 --- a/examples/time_series_learning.py +++ b/examples/time_series_learning.py @@ -15,7 +15,7 @@ import torch from matplotlib.figure import Figure -from torch_crps import crps_analytical +from torch_crps.analytical import crps_analytical EXAMPLES_DIR = pathlib.Path(pathlib.Path(__file__).parent) diff --git a/pyproject.toml b/pyproject.toml index 34b5ef6..9acd520 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ classifiers = [ dependencies = [ "torch>=2.7", ] -description = "Implementations of the CRPS using PyTorch" +description = "PyTorch-based implementations of the Continuously-Ranked Probability Score (CRPS) as well as its locally scale-invariant version (SCRPS)" dynamic = ["version"] license = "CC-BY-4.0" maintainers = [{name = "Fabio Muratore", email = "accounts@famura.net"}] @@ -140,6 +140,7 @@ ignore = [ "PLC0415", # import should be at the top-level of a file "PLR", # pylint refactor "RUF001", # allow for greek letters + "RUF003", # allow for greek letters "UP035", # allow importing Callable from typing ] preview = true diff --git a/readme.md b/readme.md index 4da6f30..8a05715 100644 --- a/readme.md +++ b/readme.md @@ -13,18 +13,58 @@ [![Ruff][ruff-badge]][ruff] [![uv][uv-badge]][uv] -Implementations of the Continuously-Ranked Probability Score (CRPS) using PyTorch +PyTorch-based implementations of the Continuously-Ranked Probability Score (CRPS) as well as its locally scale-invariant +version (SCRPS) ## Background -The Continuously-Ranked Probability Score (CRPS) is a strictly proper scoring rule. -It assesses how well a distribution with the cumulative distribution function $F$ is explaining an observation $y$ +### Continuously-Ranked Probability Score (CRPS) -$$ \text{CRPS}(F,y) = \int _{\mathbb {R} }(F(x)-\mathbb {1} (x\geq y))^{2}dx \qquad (\text{integral formulation}) $$ +The CRPS is a strictly proper scoring rule. 
+It assesses how well a distribution with the cumulative distribution function $F(X)$ of the estimate $X$ (a random +variable) is explaining an observation $y$ + +$$ +\text{CRPS}(F,y) = \int _{\mathbb {R}} \left( F(x)-\mathbb {1} (x\geq y) \right)^{2} dx +$$ where $1$ denoted the indicator function. -In Section 2 of this [paper][crps-folumations] Zamo & Naveau list 3 different formulations of the CRPS. +In Section 2 of this [paper][crps-folumations] Zamo & Naveau list 3 different formulations of the CRPS. One of them is + +$$ +\text{CRPS}(F, y) = E[|X - y|] - 0.5 E[|X - X'|] = E[|X - y|] + E[X] - 2 E[X F(X)] +$$ + +which can be shortened to + +$$ +\text{CRPS}(F, y) = A - 0.5 D +$$ + +where $A$ is called the accuracy term and $D$ is called the dispersion term (at least I do it in this repo). + +### Scaled Continuously-Ranked Probability Score (SCRPS) + +The SCRPS is a locally scale-invariant version of the CRPS. +In their [paper][scrps-paper], Bolin & Wallin define it in a positively-oriented way, i.e., higher is better. +In contrast, I implement the SCRPS in this repo negatively-oriented, just like a loss function. + +Oversimplifying the notation, the (negatively-oriented) SCRPS can be written as + +$$ +\text{SCRPS}(F, y) = \frac{E[|X - y|]}{E[|X - X'|]} + 0.5 \log \left( E[|X - X'|] \right) +$$ + +which can be shortened to + +$$ +\text{SCRPS}(F, y) = \frac{A}{D} + 0.5 \log(D) +$$ + +The scale-invariance, i.e., the SCRPS value does not depend on the magnitude of $D$, comes from the division by $D$. + +Note that the SCRPS can, in contrast to the CRPS, yield negative values. 
### Incomplete list of sources that I came across while researching about the CRPS @@ -33,6 +73,7 @@ In Section 2 of this [paper][crps-folumations] Zamo & Naveau list 3 different fo - Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007 - Zamo & Naveau; "Estimation of the Continuous Ranked Probability Score with Limited Information and Applications to Ensemble Weather Forecasts"; 2018 - Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019 +- Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019 - Olivares & Négiar & Ma et al; "CLOVER: Probabilistic Forecasting with Coherent Learning Objective Reparameterization"; 2023 - Vermorel & Tikhonov; "Continuously-Ranked Probability Score (CRPS)" [blog post][Lokad-post]; 2024 - Nvidia; "PhysicsNeMo Framework" [source code][nvidia-crps-implementation]; 2025 @@ -40,8 +81,8 @@ In Section 2 of this [paper][crps-folumations] Zamo & Naveau list 3 different fo ## Application to Machine Learning -The CRPS can be used as a loss function in machine learning, just like the well-known negative log-likelihood loss which -is the log scoring rule. +The CRPS, as well as the SCRPS, can be used as a loss function in machine learning, just like the well-known negative +log-likelihood loss which is the log scoring rule. The parametrized model outputs a distribution $q(x)$. The CRPS loss evaluates how good $q(x)$ is explaining the observation $y$. @@ -54,7 +95,15 @@ There is [work on multi-variate CRPS estimation][multivariate-crps], but it is n ## Implementation -The integral formulation is infeasible to naively evaluate on a computer due to the infinite integration over $x$. +The direct implementation of the integral formulation is not suited to evaluate on a computer due to the infinite +integration over the domain of the random variable $X$. +Nevertheless, this repository includes such an implementation to verify the others. 
+ +The normalization-by-observation variants are improper solutions to normalize the CRPS values. The goal is to use the +CRPS as a loss function in machine learning tasks. For that, it is highly beneficial if the loss does not depend on +the scale of the problem. +However, dividing by the absolute maximum of the observations is a bad proxy for doing this. +I plan on removing these methods once I have gained trust in my SCRPS implementation. I found [Nvidia's implementation][nvidia-crps-implementation] of the CRPS for ensemble preductions in $M log(M)$ time inspiring to read. @@ -89,6 +138,7 @@ inspiring to read. [uv]: https://docs.astral.sh/uv [crps-folumations]: https://link.springer.com/article/10.1007/s11004-017-9709-7 +[scrps-paper]: https://arxiv.org/abs/1912.05642 [Lokad-post]: https://www.lokad.com/continuous-ranked-probability-score/ [multivariate-crps]: https://arxiv.org/pdf/2410.09133 [nvidia-crps-implementation]: https://docs.nvidia.com/physicsnemo/25.11/_modules/physicsnemo/metrics/general/crps.html diff --git a/tests/conftest.py b/tests/conftest.py index 633a7cd..ed0ce0a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,9 @@ import pytest import torch +from torch.distributions import Normal, StudentT + +from torch_crps.analytical.studentt import standardized_studentt_cdf_via_scipy results_dir = Path(__file__).parent / "results" results_dir.mkdir(parents=True, exist_ok=True) @@ -49,3 +52,97 @@ def case_batched_3d(): "y": torch.randn(2, 3) * 10 + 50, "expected_shape": torch.Size([2, 3]), } + + +def crps_analytical_normal_gneiting( + q: Normal, + y: torch.Tensor, +) -> torch.Tensor: + """Compute the analytical CRPS assuming a normal distribution. + + See Also: + Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007 + Equation (5) for the analytical formula for CRPS of Normal distribution. + + Args: + q: A PyTorch Normal distribution object, typically a model's output distribution. 
+ y: Observed values, of shape (num_samples,). + + Returns: + CRPS values for each observation, of shape (num_samples,). + """ + # Compute standard normal CDF and PDF. + z = (y - q.loc) / q.scale + standard_normal = torch.distributions.Normal(0, 1) + phi_z = standard_normal.cdf(z) # Φ(z) + pdf_z = torch.exp(standard_normal.log_prob(z)) # φ(z) + + # Analytical CRPS formula. + crps = q.scale * (z * (2 * phi_z - 1) + 2 * pdf_z - 1 / torch.sqrt(torch.tensor(torch.pi))) + + return crps + + +def crps_analytical_studentt_jordan( + q: StudentT, + y: torch.Tensor, +) -> torch.Tensor: + r"""Compute the (negatively-oriented) CRPS in closed-form assuming a StudentT distribution. + + This is the previous implementation of the analytical CRPS for StudentT distributions.. It is provided here for + testing and comparison purposes. + + See Also: + Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019. + + Args: + q: A PyTorch StudentT distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + CRPS values for each observation, of shape (num_samples,). + """ + # Extract degrees of freedom ν, location μ, and scale σ. + nu, mu, sigma = q.df, q.loc, q.scale + if torch.any(nu <= 1): + raise ValueError("StudentT CRPS requires degrees of freedom > 1") + + # Standardize, and create standard StudentT distribution for CDF and PDF. + z = (y - mu) / sigma + standard_t = torch.distributions.StudentT(nu, loc=0, scale=1) + + # Compute standardized CDF F_ν(z) and PDF f_ν(z). 
+ cdf_z = standardized_studentt_cdf_via_scipy(z, nu) + pdf_z = torch.exp(standard_t.log_prob(z)) + + # Compute the beta function ratio: B(1/2, ν - 1/2) / B(1/2, ν/2)^2 + # Using the relationship: B(a,b) = Gamma(a) * Gamma(b) / Gamma(a+b) + # B(1/2, ν - 1/2) / B(1/2, ν/2)^2 = ( Gamma(1/2) * Gamma(ν-1/2) / Gamma(ν) ) / + # ( Gamma(1/2) * Gamma(ν/2) / Gamma(ν/2 + 1/2) )^2 + # Simplifying to Gamma(ν - 1/2) Gamma(ν/2 + 1/2)^2 / ( Gamma(ν)Gamma(ν/2)^2 ) + # For numerical stability, we compute in log space. + log_gamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device)) + log_gamma_df_minus_half = torch.lgamma(nu - 0.5) + log_gamma_df_half = torch.lgamma(nu / 2) + log_gamma_df_half_plus_half = torch.lgamma(nu / 2 + 0.5) + + # log[B(1/2, ν-1/2)] = log Gamma(1/2) + log Gamma(ν-1/2) - log Gamma(ν) + # log[B(1/2, ν/2)] = log Gamma(1/2) + log Gamma(ν/2) - log Gamma(ν/2 + 1/2) + # log[B(1/2, ν-1/2) / B(1/2, ν/2)^2] = log B(1/2, ν-1/2) - 2*log B(1/2, ν/2) + log_beta_ratio = ( + log_gamma_half + + log_gamma_df_minus_half + - torch.lgamma(nu) + - 2 * (log_gamma_half + log_gamma_df_half - log_gamma_df_half_plus_half) + ) + beta_frac = torch.exp(log_beta_ratio) + + # Compute the CRPS for standardized values. + crps_standard = ( + z * (2 * cdf_z - 1) + 2 * pdf_z * (nu + z**2) / (nu - 1) - (2 * torch.sqrt(nu) / (nu - 1)) * beta_frac + ) + + # Apply location-scale transformation CRPS(F_{ν,μ,σ}, y) = σ * CRPS(F_{ν}, z) with z = (y - μ) / σ. 
+ crps = sigma * crps_standard + + return crps diff --git a/tests/test_analytical.py b/tests/test_analytical.py new file mode 100644 index 0000000..69ebb91 --- /dev/null +++ b/tests/test_analytical.py @@ -0,0 +1,173 @@ +from typing import Any, Callable + +import pytest +import torch +from torch.distributions import Normal, StudentT +from typing_extensions import Literal + +from tests.conftest import crps_analytical_normal_gneiting, crps_analytical_studentt_jordan, needs_cuda +from torch_crps.analytical import crps_analytical, scrps_analytical +from torch_crps.analytical.normal import crps_analytical_normal, scrps_analytical_normal +from torch_crps.analytical.studentt import ( + crps_analytical_studentt, +) + + +@pytest.mark.parametrize( + "use_cuda", + [ + pytest.param(False, id="cpu"), + pytest.param(True, marks=needs_cuda, id="cuda"), + ], +) +@pytest.mark.parametrize("crps_fcn", [crps_analytical_normal, scrps_analytical_normal], ids=["CRPS", "SCRPS"]) +def test_analytical_normal_batched_smoke(use_cuda: bool, crps_fcn: Callable[..., torch.Tensor]): + """Test that analytical solution works with batched Normal distributions.""" + torch.manual_seed(0) + + # Define a batch of 2 independent univariate Normal distributions. + mu = torch.tensor([[0.0, 1.0], [2.0, 3.0], [-2.0, -3.0]], device="cuda" if use_cuda else "cpu") + sigma = torch.tensor([[1.0, 0.5], [1.5, 2.0], [0.01, 0.01]], device="cuda" if use_cuda else "cpu") + normal_dist = torch.distributions.Normal(loc=mu, scale=sigma) + + # Define observed values for each distribution in the batch. + y = torch.tensor([[0.5, 1.5], [2.5, 3.5], [-2.0, -3.0]], device="cuda" if use_cuda else "cpu") + + # Compute CRPS using the analytical method. + crps_analytical = crps_fcn(normal_dist, y) + + # Simple sanity check: CRPS should be non-negative. + assert crps_analytical.shape == y.shape, "CRPS output shape should match input shape." 
+ assert crps_analytical.dtype in [torch.float32, torch.float64], "CRPS output dtype should be float." + assert crps_analytical.device == y.device, "CRPS output device should match input device." + if crps_fcn == crps_analytical_normal: + assert torch.all(crps_analytical >= 0), "CRPS values should be non-negative." + + +@pytest.mark.parametrize( + "loc, scale", + [ + (torch.tensor(0.0), torch.tensor(1.0)), + (torch.tensor(-1.0), torch.tensor(0.5)), + (torch.tensor(1.0), torch.tensor(0.5)), + (torch.tensor(10.0), torch.tensor(20.0)), + (torch.tensor(-10.0), torch.tensor(20.0)), + (torch.tensor(100.0), torch.tensor(5.0)), + (torch.tensor(-100.0), torch.tensor(5.0)), + ], + ids=[ + "standard", + "small-neg-mean_small-var", + "small-pos-mean_small-var", + "pos-mean_large-var", + "neg-mean_large-var", + "large-mean_medium-var", + "large-neg-mean_medium-var", + ], +) +@pytest.mark.parametrize("y", [torch.tensor([-100.0, -10.0, -1.0, 0.0, 0.5, 2.0, 5.0, 50.0])]) +@pytest.mark.parametrize("crps_fcn_type", ["CRPS", "SCRPS"], ids=["CRPS", "SCRPS"]) +def test_studentt_convergence_to_normal( + loc: torch.Tensor, scale: torch.Tensor, y: torch.Tensor, crps_fcn_type: Literal["CRPS", "SCRPS"] +): + """Test that for a high degrees of freedom, the StudentT score converges to the Normal score + when their standard deviations are matched. + """ + # Create the StudentT distribution with a high degree of freedom. + high_df = torch.tensor(1000.0) + q_studentt = StudentT(df=high_df, loc=loc, scale=scale) + + # Calculate the standard deviation of the StudentT distribution. The variance is (df / (df - 2)) * scale^2 + student_t_std_dev = scale * torch.sqrt(high_df / (high_df - 2)) + + # Create the Normal distribution with matching standard deviation. + q_normal = Normal(loc=loc, scale=student_t_std_dev) + + # Calculate the analytical scores for both. 
+ if crps_fcn_type == "CRPS": + score_value_studentt = crps_analytical(q_studentt, y) + score_value_normal = crps_analytical(q_normal, y) + else: + score_value_studentt = scrps_analytical(q_studentt, y) + score_value_normal = scrps_analytical(q_normal, y) + + # Assert that their results are nearly identical. + # The tolerance can be quite tight now. + atol = 6e-3 if crps_fcn_type == "CRPS" else 2e-2 + assert torch.allclose(score_value_studentt, score_value_normal, atol=atol), ( + f"StudentT {crps_fcn_type} with high 'df' should match Normal {crps_fcn_type}." + ) + + +@pytest.mark.parametrize( + "q", + [ + torch.distributions.Normal(loc=torch.zeros(3), scale=torch.ones(3)), + torch.distributions.StudentT(df=5, loc=torch.zeros(3), scale=torch.ones(3)), + "NOT_A_SUPPORTED_DISTRIBUTION", + ], + ids=["Normal", "StudentT", "not_supported"], +) +@pytest.mark.parametrize("crps_fcn", [crps_analytical, scrps_analytical], ids=["CRPS", "SCRPS"]) +def test_analytical_interface_smoke(q: Any, crps_fcn: Callable[..., torch.Tensor]): # noqa: ANN401 + """Test if the top-level interface function is working""" + y = torch.zeros(3) # can be the same for all tests + + if isinstance(q, (Normal, StudentT)): + # Supported, should return a result. + crps = crps_fcn(q, y) + assert isinstance(crps, torch.Tensor) + + else: + # Not supported, should raise an error. + with pytest.raises(NotImplementedError): + crps_fcn(q, y) + + +def test_analytical_crps_normal_consistency(): + """Test if the two ways to compute the CRPS for normal distributions give the same result: + + - old method: `crps_analytical_normal_gneiting` + - new method: `_accuracy_normal_gneiting` and `_dispersion_normal_gneiting` packaged in `crps_analytical_normal` + """ + torch.manual_seed(0) + + # Create a Normal distribution. + loc = torch.tensor([0.0, 1.0, -1.0]) + scale = torch.tensor([1.0, 2.0, 0.5]) + normal_dist = torch.distributions.Normal(loc=loc, scale=scale) + + # Define observed values. 
+ y = torch.tensor([0.5, 2.0, -0.5]) + + # Compute CRPS values. + crps_old = crps_analytical_normal_gneiting(normal_dist, y) + crps_new = crps_analytical_normal(normal_dist, y) + + # Assert that both methods give the same result. + assert torch.allclose(crps_old, crps_new, atol=1e-6), "CRPS values from both methods should match." + + +def test_analytical_crps_studentt_consistency(): + """Test if the two ways to compute the CRPS for StudentT distributions give the same result: + + - old method: `_crps_analytical_studentt_jordan` + - new method: `_accuracy_studentt_jordan` and `_dispersion_studentt_jordan` packaged in `crps_analytical_studentt` + """ + torch.manual_seed(0) + + # Create a StudentT distribution. + df = torch.tensor([3.0, 5.0, 10.0]) + loc = torch.tensor([0.0, 1.0, -1.0]) + scale = torch.tensor([1.0, 2.0, 0.5]) + studentt_dist = torch.distributions.StudentT(df=df, loc=loc, scale=scale) + + # Define observed values. + y = torch.tensor([0.5, 2.0, -0.5]) + + # Compute CRPS values. + crps_old = crps_analytical_studentt_jordan(studentt_dist, y) + crps_new = crps_analytical_studentt(studentt_dist, y) + + # Assert that both methods give the same result. + assert torch.allclose(crps_old, crps_new, atol=1e-6), "CRPS values from both methods should match." 
diff --git a/tests/test_analytical_crps.py b/tests/test_analytical_crps.py deleted file mode 100644 index a88bf29..0000000 --- a/tests/test_analytical_crps.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Any - -import pytest -import torch -from torch.distributions import Normal, StudentT - -from tests.conftest import needs_cuda -from torch_crps import crps_analytical, crps_analytical_normal, crps_analytical_studentt - - -@pytest.mark.parametrize( - "use_cuda", - [ - pytest.param(False, id="cpu"), - pytest.param(True, marks=needs_cuda, id="cuda"), - ], -) -def test_crps_analytical_normal_batched_smoke(use_cuda: bool): - """Test that analytical solution works with batched Normal distributions.""" - torch.manual_seed(0) - - # Define a batch of 2 independent univariate Normal distributions. - mu = torch.tensor([[0.0, 1.0], [2.0, 3.0], [-2.0, -3.0]], device="cuda" if use_cuda else "cpu") - sigma = torch.tensor([[1.0, 0.5], [1.5, 2.0], [0.01, 0.01]], device="cuda" if use_cuda else "cpu") - normal_dist = torch.distributions.Normal(loc=mu, scale=sigma) - - # Define observed values for each distribution in the batch. - y = torch.tensor([[0.5, 1.5], [2.5, 3.5], [-2.0, -3.0]], device="cuda" if use_cuda else "cpu") - - # Compute CRPS using the analytical method. - crps_analytical = crps_analytical_normal(normal_dist, y) - - # Simple sanity check: CRPS should be non-negative. - assert crps_analytical.shape == y.shape, "CRPS output shape should match input shape." - assert crps_analytical.dtype in [torch.float32, torch.float64], "CRPS output dtype should be float." - assert crps_analytical.device == y.device, "CRPS output device should match input device." - assert torch.all(crps_analytical >= 0), "CRPS values should be non-negative." 
- - -@pytest.mark.parametrize( - "loc, scale", - [ - (torch.tensor(0.0), torch.tensor(1.0)), - (torch.tensor(2.0), torch.tensor(0.5)), - (torch.tensor(-5.0), torch.tensor(10.0)), - ], - ids=["standard", "shifted_scaled", "neg-mean_large-var"], -) -@pytest.mark.parametrize("y", [torch.tensor([-10.0, -1.0, 0.0, 0.5, 2.0, 5.0])]) -def test_studentt_convergence_to_normal(loc: torch.Tensor, scale: torch.Tensor, y: torch.Tensor): - """Test that for a very high degrees of freedom, the StudentT CRPS converges to the Normal CRPS. - This validates both implementations against each other. - """ - # Create the two distributions with identical parameters. - high_df = torch.tensor(1000.0) - q_studentt = StudentT(df=high_df, loc=loc, scale=scale) - q_normal = Normal(loc=loc, scale=scale) - - # Calculate the analytical CRPS for both. - crps_studentt = crps_analytical_studentt(q_studentt, y) - crps_normal = crps_analytical_normal(q_normal, y) - - # Assert that their results are nearly identical. - assert torch.allclose(crps_studentt, crps_normal, atol=2e-3), ( - "StudentT CRPS with high 'df' should match Normal CRPS." - ) - - -@pytest.mark.parametrize( - "q", - [ - torch.distributions.Normal(loc=torch.zeros(3), scale=torch.ones(3)), - torch.distributions.StudentT(df=5, loc=torch.zeros(3), scale=torch.ones(3)), - "NOT_A_SUPPORTED_DISTRIBUTION", - ], - ids=["Normal", "StudentT", "not_supported"], -) -def test_crps_analytical_interface_smoke(q: Any): # noqa: ANN401 - """Test if the top-level interface function is working""" - y = torch.zeros(3) # can be the same for all tests - - if isinstance(q, (Normal, StudentT)): - # Supported, should return a result. - crps = crps_analytical(q, y) - assert isinstance(crps, torch.Tensor) - - else: - # Not supported, should raise an error. 
- with pytest.raises(NotImplementedError): - crps_analytical(q, y) diff --git a/tests/test_ensemble.py b/tests/test_ensemble.py new file mode 100644 index 0000000..9859374 --- /dev/null +++ b/tests/test_ensemble.py @@ -0,0 +1,144 @@ +import math +from collections.abc import Callable + +import pytest +import torch +from _pytest.fixtures import FixtureRequest + +from tests.conftest import needs_cuda +from torch_crps.ensemble import crps_ensemble, crps_ensemble_naive, scrps_ensemble + + +@pytest.mark.parametrize( + "test_case_fixture_name", + ["case_flat_1d", "case_batched_2d", "case_batched_3d"], + ids=["case_flat_1d", "case_batched_2d", "case_batched_3d"], +) +@pytest.mark.parametrize("biased", [True, False], ids=["biased", "unbiased"]) +@pytest.mark.parametrize("crps_fcn", [crps_ensemble_naive, crps_ensemble], ids=["naive", "default"]) +@pytest.mark.parametrize( + "use_cuda", + [ + pytest.param(False, id="cpu"), + pytest.param(True, marks=needs_cuda, id="cuda"), + ], +) +def test_ensemble_smoke( + test_case_fixture_name: str, + crps_fcn: Callable[[torch.Tensor, torch.Tensor, bool], torch.Tensor], + biased: bool, + use_cuda: bool, + request: FixtureRequest, +): + """Test that naive ensemble method yield.""" + test_case_fixture: dict = request.getfixturevalue(test_case_fixture_name) + x, y, expected_shape = test_case_fixture["x"], test_case_fixture["y"], test_case_fixture["expected_shape"] + if use_cuda: + x, y = x.cuda(), y.cuda() + + crps = crps_fcn(x, y, biased) + + assert isinstance(crps, torch.Tensor) + assert crps.shape == expected_shape, "The output shape is incorrect!" + assert crps.dtype in [torch.float32, torch.float64], "The output dtype is not float!" + assert crps.device == x.device, "The output device does not match the input device!" + assert torch.all(crps >= 0), "CRPS values should be non-negative!" 
+ + +@pytest.mark.parametrize( + "batch_shape", + [(), (3,), (3, 5)], + ids=["case_flat_1d", "case_batched_2d", "case_batched_3d"], +) +@pytest.mark.parametrize("biased", [True, False], ids=["biased", "unbiased"]) +def test_ensemble_match(batch_shape: tuple[int, ...], biased: bool, dim_ensemble: int = 10): + """Test that both implementations of crps_ensemble yield the same result.""" + torch.manual_seed(0) + + # Create a random ensemble forecast and observation. + if len(batch_shape) > 0: + x = torch.randn(*batch_shape, dim_ensemble) + y = torch.randn(*batch_shape) + else: + x = torch.randn(dim_ensemble) + y = torch.randn(batch_shape) + + crps_naive = crps_ensemble_naive(x, y, biased) + crps_default = crps_ensemble(x, y, biased) + + # Assert that both methods agree within numerical tolerance. + assert torch.allclose(crps_naive, crps_default, atol=1e-8, rtol=1e-6), ( + f"CRPS values do not match: naive={crps_naive}, default={crps_default}" + ) + + +@pytest.mark.parametrize("crps_fcn", [crps_ensemble, scrps_ensemble], ids=["CRPS", "SCRPS"]) +@pytest.mark.parametrize("biased", [True, False], ids=["biased", "unbiased"]) +def test_ensemble_invalid_shapes( + crps_fcn: Callable[[torch.Tensor, torch.Tensor, bool], torch.Tensor], biased: bool, dim_ensemble: int = 10 +): + """Test that crps_ensemble raises an error for invalid input shapes.""" + # Mismatch in the number of batch dimensions. + x = torch.randn(2, 3, dim_ensemble) + y = torch.randn(3) + with pytest.raises(ValueError): + crps_fcn(x, y, biased) + + # Mismatch in batch dimension sizes. + x = torch.randn(4, 5, dim_ensemble) + y = torch.randn(4, 6) + with pytest.raises(ValueError): + crps_fcn(x, y, biased) + + +def test_ensemble_scrps_nonnegativity(num_samples: int = 100, dim_ensemble: int = 50): + """Test that the SCRPS can have negative values (in contrast to the CRPS).""" + torch.manual_seed(0) + + # Create a random ensemble forecast with small dispersion, and observations. 
+ x = 1e-1 * torch.randn(num_samples, dim_ensemble) + y = torch.randn(num_samples) + + scrps_values = scrps_ensemble(x, y, False) + + assert torch.any(scrps_values < 0), "SCRPS should have some negative values!" + + +@pytest.mark.parametrize("seed", [0, 1, 2, 3, 4], ids=["seed_0", "seed_1", "seed_2", "seed_3", "seed_4"]) +@pytest.mark.parametrize("biased", [True, False], ids=["biased", "unbiased"]) +@pytest.mark.parametrize("scale_factor", [1e1, 1e3, 1e6, 1e9], ids=["1e1", "1e3", "1e6", "1e9"]) +def test_ensemble_scrps_scale_invariance( + seed: int, + biased: bool, + scale_factor: float, + num_samples: int = 100, + dim_ensemble: int = 50, +): + """Test that the SCRPS is locally scale-invariant (to a certain extent).""" + torch.manual_seed(seed) + rtol = 0.5 * math.log(scale_factor) # just an educated guess + + # Create a random ensemble forecast and observations. + x = torch.randn(num_samples, dim_ensemble) + y = torch.randn(num_samples) + scrps_original = scrps_ensemble(x, y, biased) + + # Scale the ensemble forecasts by a factor of 1000. + x_scaled = scale_factor * x + scrps_scaled = scrps_ensemble(x_scaled, y, biased) + + # Assert that the SCRPS values are approximately scale-invariant. + assert torch.allclose(scrps_original, scrps_scaled, rtol=rtol), ( + f"The SCRPS values are not scale-invariant as expected, i.e., scaling the forecasts by {scale_factor} " + f"should not change the SCRPS values by more than a factor of {rtol}, but it did." + ) + + # Scale the observations by a factor of 1000. + y_scaled = scale_factor * y + scrps_scaled = scrps_ensemble(x, y_scaled, biased) + + # Assert that the SCRPS values are approximately scale-invariant. + assert torch.allclose(scrps_original, scrps_scaled, rtol=rtol), ( + f"The SCRPS values are not scale-invariant as expected, i.e., scaling the observations by {scale_factor} " + f"should not change the SCRPS values by more than a factor of {rtol}, but it did." 
+ ) diff --git a/tests/test_ensemble_crps.py b/tests/test_ensemble_crps.py deleted file mode 100644 index 5b0db05..0000000 --- a/tests/test_ensemble_crps.py +++ /dev/null @@ -1,82 +0,0 @@ -from collections.abc import Callable - -import pytest -import torch -from _pytest.fixtures import FixtureRequest - -from tests.conftest import needs_cuda -from torch_crps import crps_ensemble, crps_ensemble_naive - - -@pytest.mark.parametrize( - "test_case_fixture_name", - ["case_flat_1d", "case_batched_2d", "case_batched_3d"], - ids=["case_flat_1d", "case_batched_2d", "case_batched_3d"], -) -@pytest.mark.parametrize("crps_fcn", [crps_ensemble_naive, crps_ensemble], ids=["naive", "default"]) -@pytest.mark.parametrize("biased", [True, False], ids=["biased", "unbiased"]) -@pytest.mark.parametrize( - "use_cuda", - [ - pytest.param(False, id="cpu"), - pytest.param(True, marks=needs_cuda, id="cuda"), - ], -) -def test_crps_ensemble_smoke( - test_case_fixture_name: str, crps_fcn: Callable, biased: bool, use_cuda: bool, request: FixtureRequest -): - """Test that naive ensemble method yield.""" - test_case_fixture: dict = request.getfixturevalue(test_case_fixture_name) - x, y, expected_shape = test_case_fixture["x"], test_case_fixture["y"], test_case_fixture["expected_shape"] - if use_cuda: - x, y = x.cuda(), y.cuda() - - crps = crps_fcn(x, y, biased) - - assert isinstance(crps, torch.Tensor) - assert crps.shape == expected_shape, "The output shape is incorrect!" - assert crps.dtype in [torch.float32, torch.float64], "The output dtype is not float!" - assert crps.device == x.device, "The output device does not match the input device!" - assert torch.all(crps >= 0), "CRPS values should be non-negative!" 
- - -@pytest.mark.parametrize( - "batch_shape", - [(), (3,), (3, 5)], - ids=["case_flat_1d", "case_batched_2d", "case_batched_3d"], -) -@pytest.mark.parametrize("biased", [True, False], ids=["biased", "unbiased"]) -def test_crps_ensemble_match(batch_shape: tuple[int, ...], biased: bool, dim_ensemble: int = 10): - """Test that both implementations of crps_ensemble yield the same result.""" - torch.manual_seed(0) - - # Create a random ensemble forecast and observation. - if len(batch_shape) > 0: - x = torch.randn(*batch_shape, dim_ensemble) - y = torch.randn(*batch_shape) - else: - x = torch.randn(dim_ensemble) - y = torch.randn(batch_shape) - - crps_naive = crps_ensemble_naive(x, y, biased) - crps_default = crps_ensemble(x, y, biased) - - # Assert that both methods agree within numerical tolerance. - assert torch.allclose(crps_naive, crps_default, atol=1e-8, rtol=1e-6), ( - f"CRPS values do not match: naive={crps_naive}, default={crps_default}" - ) - - -def test_crps_ensemble_invalid_shapes(dim_ensemble: int = 10): - """Test that crps_ensemble raises an error for invalid input shapes.""" - # Mismatch in the number of batch dimensions. - x = torch.randn(2, 3, dim_ensemble) - y = torch.randn(3) - with pytest.raises(ValueError): - crps_ensemble(x, y) - - # Mismatch in batch dimension sizes. 
- x = torch.randn(4, 5, dim_ensemble) - y = torch.randn(4, 6) - with pytest.raises(ValueError): - crps_ensemble(x, y) diff --git a/tests/test_integral_crps.py b/tests/test_integral.py similarity index 91% rename from tests/test_integral_crps.py rename to tests/test_integral.py index 2906ee5..4d97c05 100644 --- a/tests/test_integral_crps.py +++ b/tests/test_integral.py @@ -2,7 +2,9 @@ import torch from tests.conftest import needs_cuda -from torch_crps import crps_analytical_normal, crps_analytical_studentt, crps_integral +from torch_crps import crps_integral +from torch_crps.analytical.normal import crps_analytical_normal +from torch_crps.analytical.studentt import crps_analytical_studentt @pytest.mark.parametrize( @@ -12,7 +14,7 @@ pytest.param(True, marks=needs_cuda, id="cuda"), ], ) -def test_crps_integral_vs_analytical_normal(use_cuda: bool): +def test_integral_vs_analytical_normal(use_cuda: bool): """Test that naive integral method matches the analytical solution for Normal distributions.""" torch.manual_seed(0) @@ -47,7 +49,7 @@ def test_crps_integral_vs_analytical_normal(use_cuda: bool): pytest.param(True, marks=needs_cuda, id="cuda"), ], ) -def test_crps_integral_vs_analytical_studentt(use_cuda: bool): +def test_integral_vs_analytical_studentt(use_cuda: bool): """Test that naive integral method matches the analytical solution for StudentT distributions.""" torch.manual_seed(0) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 209f48a..274a45b 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -4,41 +4,45 @@ import torch from torch_crps import ( - crps_analytical_normal_normalized, - crps_analytical_normalized, - crps_analytical_studentt_normalized, - crps_ensemble_normalized, - crps_integral_normalized, + crps_analytical_normal_obsnormalized, + crps_analytical_obsnormalized, + crps_analytical_studentt_obsnormalized, + crps_ensemble_obsnormalized, + crps_integral_obsnormalized, ) @pytest.mark.parametrize( 
"wrapped_crps_fcn", [ - crps_analytical_normal_normalized, - crps_analytical_normalized, - crps_analytical_studentt_normalized, - crps_ensemble_normalized, - crps_integral_normalized, + crps_analytical_normal_obsnormalized, + crps_analytical_obsnormalized, + crps_analytical_studentt_obsnormalized, + crps_ensemble_obsnormalized, + crps_integral_obsnormalized, ], ids=[ - "crps_analytical_normal_normalized", - "crps_analytical_normalized", - "crps_analytical_studentt_normalized", - "crps_ensemble_normalized", - "crps_integral_normalized", + "crps_analytical_normal_obsnormalized", + "crps_analytical_obsnormalized", + "crps_analytical_studentt_obsnormalized", + "crps_ensemble_obsnormalized", + "crps_integral_obsnormalized", ], ) -def test_nomrmalization_wrapper_input_errors(wrapped_crps_fcn: Callable, num_y: int = 3): +def test_normalization_wrapper_input_errors(wrapped_crps_fcn: Callable, num_y: int = 3): """Test if the normalization wrapper handles the underlying function's arguments correctly.""" torch.manual_seed(0) # Set up test cases. 
- if wrapped_crps_fcn in (crps_analytical_normal_normalized, crps_analytical_normalized, crps_integral_normalized): + if wrapped_crps_fcn in ( + crps_analytical_normal_obsnormalized, + crps_analytical_obsnormalized, + crps_integral_obsnormalized, + ): q = torch.distributions.Normal(loc=torch.zeros(num_y), scale=torch.ones(num_y)) - elif wrapped_crps_fcn == crps_analytical_studentt_normalized: + elif wrapped_crps_fcn == crps_analytical_studentt_obsnormalized: q = torch.distributions.StudentT(df=5 * torch.ones(num_y), loc=torch.zeros(num_y), scale=torch.ones(num_y)) - elif wrapped_crps_fcn == crps_ensemble_normalized: + elif wrapped_crps_fcn == crps_ensemble_obsnormalized: q = torch.randn(num_y, 10) # dim_ensemble = 10 else: raise NotImplementedError("Test case setup error.") @@ -69,31 +73,35 @@ def test_nomrmalization_wrapper_input_errors(wrapped_crps_fcn: Callable, num_y: @pytest.mark.parametrize( "wrapped_crps_fcn", [ - crps_analytical_normal_normalized, - crps_analytical_normalized, - crps_analytical_studentt_normalized, - crps_ensemble_normalized, - crps_integral_normalized, + crps_analytical_normal_obsnormalized, + crps_analytical_obsnormalized, + crps_analytical_studentt_obsnormalized, + crps_ensemble_obsnormalized, + crps_integral_obsnormalized, ], ids=[ - "crps_analytical_normal_normalized", - "crps_analytical_normalized", - "crps_analytical_studentt_normalized", - "crps_ensemble_normalized", - "crps_integral_normalized", + "crps_analytical_normal_obsnormalized", + "crps_analytical_obsnormalized", + "crps_analytical_studentt_obsnormalized", + "crps_ensemble_obsnormalized", + "crps_integral_obsnormalized", ], ) @pytest.mark.parametrize("num_y", [1, 5, 100], ids=["1_obs", "5_obs", "100_obs"]) -def test_nomrmalization_wrapper_output_consistency(wrapped_crps_fcn: Callable, num_y: int): +def test_normalization_wrapper_output_consistency(wrapped_crps_fcn: Callable, num_y: int): """Test if the normalization wrapper results in normalized CRPS values.""" 
torch.manual_seed(0) # Set up test cases. - if wrapped_crps_fcn in (crps_analytical_normal_normalized, crps_analytical_normalized, crps_integral_normalized): + if wrapped_crps_fcn in ( + crps_analytical_normal_obsnormalized, + crps_analytical_obsnormalized, + crps_integral_obsnormalized, + ): q = torch.distributions.Normal(loc=torch.zeros(num_y), scale=torch.ones(num_y)) - elif wrapped_crps_fcn == crps_analytical_studentt_normalized: + elif wrapped_crps_fcn == crps_analytical_studentt_obsnormalized: q = torch.distributions.StudentT(df=5 * torch.ones(num_y), loc=torch.zeros(num_y), scale=torch.ones(num_y)) - elif wrapped_crps_fcn == crps_ensemble_normalized: + elif wrapped_crps_fcn == crps_ensemble_obsnormalized: q = torch.randn(num_y, 10) # dim_ensemble = 10 else: raise NotImplementedError("Test case setup error.") diff --git a/torch_crps/__init__.py b/torch_crps/__init__.py index 4062956..f213f83 100644 --- a/torch_crps/__init__.py +++ b/torch_crps/__init__.py @@ -1,28 +1,30 @@ -from .analytical_crps import ( - crps_analytical, - crps_analytical_normal, - crps_analytical_studentt, +from .analytical import crps_analytical, scrps_analytical +from .analytical.normal import crps_analytical_normal, scrps_analytical_normal +from .analytical.studentt import crps_analytical_studentt, scrps_analytical_studentt +from .ensemble import crps_ensemble, crps_ensemble_naive, scrps_ensemble +from .integral import crps_integral +from .normalization import ( + crps_analytical_normal_obsnormalized, + crps_analytical_obsnormalized, + crps_analytical_studentt_obsnormalized, + crps_ensemble_obsnormalized, + crps_integral_obsnormalized, ) -from .ensemble_crps import crps_ensemble, crps_ensemble_naive -from .integral_crps import crps_integral -from .normalization import normalize_by_observation - -crps_analytical_normalized = normalize_by_observation(crps_analytical) -crps_analytical_normal_normalized = normalize_by_observation(crps_analytical_normal) -crps_analytical_studentt_normalized 
= normalize_by_observation(crps_analytical_studentt) -crps_ensemble_normalized = normalize_by_observation(crps_ensemble) -crps_integral_normalized = normalize_by_observation(crps_integral) __all__ = [ "crps_analytical", "crps_analytical_normal", - "crps_analytical_normal_normalized", - "crps_analytical_normalized", + "crps_analytical_normal_obsnormalized", + "crps_analytical_obsnormalized", "crps_analytical_studentt", - "crps_analytical_studentt_normalized", + "crps_analytical_studentt_obsnormalized", "crps_ensemble", "crps_ensemble_naive", - "crps_ensemble_normalized", + "crps_ensemble_obsnormalized", "crps_integral", - "crps_integral_normalized", + "crps_integral_obsnormalized", + "scrps_analytical", + "scrps_analytical_normal", + "scrps_analytical_studentt", + "scrps_ensemble", ] diff --git a/torch_crps/abstract.py b/torch_crps/abstract.py new file mode 100644 index 0000000..1c387ab --- /dev/null +++ b/torch_crps/abstract.py @@ -0,0 +1,27 @@ +import torch + + +def crps_abstract(accuracy: torch.Tensor, dispersion: torch.Tensor) -> torch.Tensor: + """High-level function to compute the CRPS from the accuracy and dispersion terms. + + Args: + accuracy: The accuracy term A, independent of the methods used to compute it, of shape (*batch_shape,). + dispersion: The dispersion term D, independent of the methods used to compute it, of shape (*batch_shape,). + + Returns: + The CRPS value for each forecast in the batch, of shape (*batch_shape,). + """ + return accuracy - 0.5 * dispersion + + +def scrps_abstract(accuracy: torch.Tensor, dispersion: torch.Tensor) -> torch.Tensor: + """High-level function to compute the SCRPS from the accuracy and dispersion terms. + + Args: + accuracy: The accuracy term A, independent of the methods used to compute it, of shape (*batch_shape,). + dispersion: The dispersion term D, independent of the methods used to compute it, of shape (*batch_shape,). + + Returns: + The SCRPS value for each forecast in the batch, of shape (*batch_shape,). 
+ """ + return accuracy / dispersion + 0.5 * torch.log(dispersion) diff --git a/torch_crps/analytical/__init__.py b/torch_crps/analytical/__init__.py new file mode 100644 index 0000000..4d864e8 --- /dev/null +++ b/torch_crps/analytical/__init__.py @@ -0,0 +1,3 @@ +from .dispatch import crps_analytical, scrps_analytical # noqa: F401 +from .normal import crps_analytical_normal, scrps_analytical_normal # noqa: F401 +from .studentt import crps_analytical_studentt, scrps_analytical_studentt # noqa: F401 diff --git a/torch_crps/analytical/dispatch.py b/torch_crps/analytical/dispatch.py new file mode 100644 index 0000000..cccb37e --- /dev/null +++ b/torch_crps/analytical/dispatch.py @@ -0,0 +1,68 @@ +import torch +from torch.distributions import Distribution, Normal, StudentT + +from torch_crps.analytical.normal import crps_analytical_normal, scrps_analytical_normal +from torch_crps.analytical.studentt import ( + crps_analytical_studentt, + scrps_analytical_studentt, +) + + +def crps_analytical( + q: Distribution, + y: torch.Tensor, +) -> torch.Tensor: + """Compute the (negatively-oriented, i.e., lower is better) CRPS in closed-form. + + Note: + The input distribution must be either `torch.distributions.Normal` or `torch.distributions.StudentT`. + There exists analytical solutions for other distributions, but they are not implemented, yet. + Feel free to create an issue or pull request. + + Args: + q: A PyTorch distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + CRPS values for each observation, of shape (num_samples,). + """ + if isinstance(q, Normal): + return crps_analytical_normal(q, y) + elif isinstance(q, StudentT): + return crps_analytical_studentt(q, y) + else: + raise NotImplementedError( + f"Detected distribution of type {type(q)}, but there are only analytical solutions for " + "`torch.distributions.Normal` or `torch.distributions.StudentT`. Either use an alternative method, e.g. 
" + "`torch_crps.crps_integral` or `torch_crps.crps_ensemble`, or create an issue for the method you need." + ) + + +def scrps_analytical( + q: Distribution, + y: torch.Tensor, +) -> torch.Tensor: + """Compute the (negatively-oriented, i.e., lower is better) Scaled CRPS (SCRPS) in closed-form. + + Note: + The input distribution must be either `torch.distributions.Normal` or `torch.distributions.StudentT`. + There exists analytical solutions for other distributions, but they are not implemented, yet. + Feel free to create an issue or pull request. + + Args: + q: A PyTorch distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + SCRPS values for each observation, of shape (num_samples,). + """ + if isinstance(q, Normal): + return scrps_analytical_normal(q, y) + elif isinstance(q, StudentT): + return scrps_analytical_studentt(q, y) + else: + raise NotImplementedError( + f"Detected distribution of type {type(q)}, but there are only analytical solutions for " + "`torch.distributions.Normal` or `torch.distributions.StudentT`. Either use an alternative method, e.g. " + "`torch_crps.scrps_integral` or `torch_crps.scrps_ensemble`, or create an issue for the method you need." + ) diff --git a/torch_crps/analytical/normal.py b/torch_crps/analytical/normal.py new file mode 100644 index 0000000..062c1d4 --- /dev/null +++ b/torch_crps/analytical/normal.py @@ -0,0 +1,100 @@ +import torch +from torch.distributions import Normal + +from torch_crps.abstract import crps_abstract, scrps_abstract + + +def _accuracy_normal( + q: Normal, + y: torch.Tensor, +) -> torch.Tensor: + """Compute accuracy term A = E[|X - y|] for a normal distribution. + + Args: + q: A PyTorch Normal distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + Accuracy values for each observation, of shape (num_samples,). 
+ """ + z = (y - q.loc) / q.scale + standard_normal = torch.distributions.Normal(0, 1) + + cdf_z = standard_normal.cdf(z) + pdf_z = torch.exp(standard_normal.log_prob(z)) + + return q.scale * (z * (2 * cdf_z - 1) + 2 * pdf_z) + + +def _dispersion_normal( + q: Normal, +) -> torch.Tensor: + """Compute dispersion term D = E[|X - X'|] for a normal distribution. + + Args: + q: A PyTorch Normal distribution object, typically a model's output distribution. + + Returns: + Dispersion values for each observation, of shape (num_samples,). + """ + sqrt_pi = torch.sqrt(torch.tensor(torch.pi, device=q.loc.device, dtype=q.loc.dtype)) + + return 2 * q.scale / sqrt_pi + + +def crps_analytical_normal( + q: Normal, + y: torch.Tensor, +) -> torch.Tensor: + """Compute the (negatively-oriented) CRPS in closed-form assuming a normal distribution. + + See Also: + Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007. + Equation (5) for the analytical formula for CRPS of Normal distribution. + + Args: + q: A PyTorch Normal distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + CRPS values for each observation, of shape (num_samples,). + """ + accuracy = _accuracy_normal(q, y) + dispersion = _dispersion_normal(q) + + return crps_abstract(accuracy, dispersion) + + +def scrps_analytical_normal( + q: Normal, + y: torch.Tensor, +) -> torch.Tensor: + r"""Compute the (negatively-oriented) Scaled CRPS (SCRPS) in closed-form assuming a normal distribution. + + $$ + \text{SCRPS}(F, y) = -\frac{E[|X - y|]}{E[|X - X'|]} - 0.5 \log \left( E[|X - X'|] \right) + = \frac{A}{D} + 0.5 \log(D) + $$ + + where $X$ and $X'$ are independent random variables drawn from the ensemble distribution, and $F(X)$ is the CDF + of the ensemble distribution evaluated at $X$, and $y$ are the ground truth observations. + + Note: + In contrast to the (negatively-oriented) CRPS, the SCRPS can have negative values. 
+ + See Also: + Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019. + Equation (3) for the definition of the SCRPS. + Appendix A.1 for the component formulas (Accuracy and Dispersion) for the Normal distribution + + Args: + q: A PyTorch Normal distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + SCRPS values for each observation, of shape (num_samples,). + """ + accuracy = _accuracy_normal(q, y) + dispersion = _dispersion_normal(q) + + return scrps_abstract(accuracy, dispersion) diff --git a/torch_crps/analytical/studentt.py b/torch_crps/analytical/studentt.py new file mode 100644 index 0000000..e8dba89 --- /dev/null +++ b/torch_crps/analytical/studentt.py @@ -0,0 +1,205 @@ +import torch +from torch.distributions import StudentT + +from torch_crps.abstract import crps_abstract, scrps_abstract + + +def standardized_studentt_cdf_via_scipy( + z: torch.Tensor, + nu: torch.Tensor | float, +) -> torch.Tensor: + """Since the `torch.distributions.StudentT` class does not have a `cdf()` method, we resort to scipy which has + a stable implementation. + + Note: + - The inputs `z` must be standardized. + - This breaks differentiability and requires to move tensors to the CPU. + + Args: + z: Standardized values at which to evaluate the CDF. + nu: Degrees of freedom of the StudentT distribution. + + Returns: + CDF values of the standardized StudentT distribution at `z`. + """ + try: + from scipy.stats import t as scipy_student_t + except ImportError as e: + raise ImportError( + "scipy is required for the analytical solution for the StudentT distribution. " + "Install `torch-crps` with the 'studentt' dependency group, e.g. `pip install torch-crps[studentt]`." 
+ ) from e + + z_np = z.detach().cpu().numpy() + nu_np = nu.detach().cpu().numpy() if isinstance(nu, torch.Tensor) else nu + + cdf_z_np = scipy_student_t.cdf(x=z_np, df=nu_np) + + return torch.from_numpy(cdf_z_np).to(device=z.device, dtype=z.dtype) + + +def _accuracy_studentt(q: StudentT, y: torch.Tensor) -> torch.Tensor: + r"""Computes the accuracy term $A = E[|Y - y|]$ for the Student-T distribution. + + $$ + A = \sigma \left[ z(2F_{\nu}(z) - 1) + 2 \frac{\nu+z^2}{\nu-1} f_{\nu}(z) \right] + $$ + + See Also: + Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019. + + Args: + q: A PyTorch StudentT distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + Accuracy values for each observation, of shape (num_samples,). + """ + nu, mu, sigma = q.df, q.loc, q.scale + + # Standardize, and create standard StudentT distribution for CDF and PDF. + z = (y - mu) / sigma + standard_t = StudentT(nu, loc=torch.zeros_like(mu), scale=torch.ones_like(sigma)) + + # Compute standardized CDF F_ν(z) and PDF f_ν(z). + cdf_z = standardized_studentt_cdf_via_scipy(z, nu) + pdf_z = torch.exp(standard_t.log_prob(z)) + + # A = sigma * [z * (2*F(z) - 1) + 2*f(z) * (v + z^2) / (v-1) ] + accuracy_unscaled = z * (2 * cdf_z - 1) + 2 * pdf_z * (nu + z**2) / (nu - 1) + + accuracy = sigma * accuracy_unscaled + return accuracy + + +def _dispersion_studentt( + q: StudentT, +) -> torch.Tensor: + r"""Computes the dispersion term $D = E[|Y - Y'|]$ for the Student-T distribution. + + See Also: + Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019. + + Args: + q: A PyTorch StudentT distribution object, typically a model's output distribution. + + Returns: + Dispersion values for each observation, of shape (num_samples,). 
+ """ + nu, sigma = q.df, q.scale + + # Compute the beta function ratio: B(1/2, ν - 1/2) / B(1/2, ν/2)^2 + # Using the relationship: B(a,b) = Gamma(a) * Gamma(b) / Gamma(a+b) + # B(1/2, ν - 1/2) / B(1/2, ν/2)^2 = ( Gamma(1/2) * Gamma(ν-1/2) / Gamma(ν) ) / + # ( Gamma(1/2) * Gamma(ν/2) / Gamma(ν/2 + 1/2) )^2 + # Simplifying to Gamma(ν - 1/2) Gamma(ν/2 + 1/2)^2 / ( Gamma(ν)Gamma(ν/2)^2 ) + # For numerical stability, we compute in log space. + log_gamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device)) + log_gamma_df_minus_half = torch.lgamma(nu - 0.5) + log_gamma_df_half = torch.lgamma(nu / 2) + log_gamma_df_half_plus_half = torch.lgamma(nu / 2 + 0.5) + + # log[B(1/2, ν-1/2)] = log Gamma(1/2) + log Gamma(ν-1/2) - log Gamma(ν) + # log[B(1/2, ν/2)] = log Gamma(1/2) + log Gamma(ν/2) - log Gamma(ν/2 + 1/2) + # log[B(1/2, ν-1/2) / B(1/2, ν/2)^2] = log B(1/2, ν-1/2) - 2*log B(1/2, ν/2) + log_beta_ratio = ( + log_gamma_half + + log_gamma_df_minus_half + - torch.lgamma(nu) + - 2 * (log_gamma_half + log_gamma_df_half - log_gamma_df_half_plus_half) + ) + beta_frac = torch.exp(log_beta_ratio) + + # D = 2σ * 2 * torch.sqrt(v) / (v - 1) * beta_frac + dispersion = 2 * sigma * 2 * torch.sqrt(nu) / (nu - 1) * beta_frac + + return dispersion + + +def crps_analytical_studentt( + q: StudentT, + y: torch.Tensor, +) -> torch.Tensor: + r"""Compute the (negatively-oriented) CRPS in closed-form assuming a StudentT distribution. + + This implements the closed-form formula from Jordan et al. (2019), see Appendix A.2. + + For the standardized StudentT distribution: + + $$ + \text{CRPS}(F_\nu, z) = z(2F_\nu(z) - 1) + 2f_\nu(z)\frac{\nu + z^2}{\nu - 1} + - \frac{2\sqrt{\nu}}{\nu - 1} \frac{B(\frac{1}{2}, \nu - \frac{1}{2})}{B(\frac{1}{2}, \frac{\nu}{2})^2} + $$ + + where $z$ is the standardized value, $F_\nu$ is the CDF, $f_\nu$ is the PDF of the standard StudentT + distribution, $\nu$ is the degrees of freedom, and $B$ is the beta function. 
+ + For the location-scale transformed distribution: + + $$ + \text{CRPS}(F_{\nu,\mu,\sigma}, y) = \sigma \cdot \text{CRPS}\left(F_\nu, \frac{y-\mu}{\sigma}\right) + $$ + + where $\mu$ is the location parameter, $\sigma$ is the scale parameter, and $y$ is the observation. + + Note: + This formula is only valid for degrees of freedom $\nu > 1$. + + See Also: + Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019. + + Args: + q: A PyTorch StudentT distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + CRPS values for each observation, of shape (num_samples,). + """ + if torch.any(q.df <= 1): + raise ValueError("StudentT CRPS requires degrees of freedom > 1") + + accuracy = _accuracy_studentt(q, y) + dispersion = _dispersion_studentt(q) + + return crps_abstract(accuracy, dispersion) + + +def scrps_analytical_studentt( + q: StudentT, + y: torch.Tensor, +) -> torch.Tensor: + r"""Compute the (negatively-oriented) Scaled CRPS (SCRPS) in closed-form assuming a Student-T distribution. + + $$ + \text{SCRPS}(F, y) = -\frac{E[|X - y|]}{E[|X - X'|]} - 0.5 \log \left( E[|X - X'|] \right) + = \frac{A}{D} + 0.5 \log(D) + $$ + + where: + + - $F_{\nu, \mu, \sigma^2}$ is the cumulative Student-T distribution, and $F_{\nu}$ is the standardized version. + - $A = E_F[|X - y|]$ is the accuracy term. + - $A = \sigma [ z(2 F_{\nu}(z) - 1) + 2(\nu + z²) / (\nu*B(\nu/2, 1/2)) * F_{\nu+1}(z * \sqrt{(\nu+1)/(\nu+z²)}) ]$ + - $D = E_F[|X - X'|]$ is the dispersion term. + - $D = \frac{ 4\sigma }{ \nu-1 } * ( \frac{ \Gamma( \nu/2 ) }{ \Gamma( (\nu-1)/2) } )^2$ + + Note: + This formula is only valid for degrees of freedom $\nu > 1$. + + See Also: + Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019. + + Args: + q: A PyTorch StudentT distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,).
+ + Returns: + SCRPS values for each observation, of shape (num_samples,). + """ + if torch.any(q.df <= 1): + raise ValueError("StudentT SCRPS requires degrees of freedom > 1") + + accuracy = _accuracy_studentt(q, y) + dispersion = _dispersion_studentt(q) + + return scrps_abstract(accuracy, dispersion) diff --git a/torch_crps/analytical_crps.py b/torch_crps/analytical_crps.py deleted file mode 100644 index 9492997..0000000 --- a/torch_crps/analytical_crps.py +++ /dev/null @@ -1,172 +0,0 @@ -import torch -from torch.distributions import Distribution, Normal, StudentT - - -def crps_analytical(q: Distribution, y: torch.Tensor) -> torch.Tensor: - """Compute the analytical CRPS. - - Note: - The input distribution must be either `torch.distributions.Normal` or `torch.distributions.StudentT`. - There exists analytical solutions for other distributions, but they are not implemented, yet. - Feel free to create an issue or pull request. - - Args: - q: A PyTorch distribution object, typically a model's output distribution. - y: Observed values, of shape (num_samples,). - - Returns: - CRPS values for each observation, of shape (num_samples,). - """ - if isinstance(q, Normal): - return crps_analytical_normal(q, y) - elif isinstance(q, StudentT): - return crps_analytical_studentt(q, y) - else: - raise NotImplementedError( - f"Detected distribution of type {type(q)}, but there are only analytical solutions for " - "`torch.distributions.Normal` or `torch.distributions.StudentT`. Either use an alternative method, e.g. " - "`torch_crps.crps_integral` or `torch_crps.crps_ensemble`, or create an issue for the method you need." - ) - - -def crps_analytical_normal( - q: Normal, - y: torch.Tensor, -) -> torch.Tensor: - """Compute the analytical CRPS assuming a normal distribution. - - See Also: - Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007 - Equation (5) for the analytical formula for CRPS of Normal distribution. 
- - Args: - q: A PyTorch Normal distribution object, typically a model's output distribution. - y: Observed values, of shape (num_samples,). - - Returns: - CRPS values for each observation, of shape (num_samples,). - """ - # Compute standard normal CDF and PDF. - z = (y - q.loc) / q.scale # standardize - standard_normal = torch.distributions.Normal(0, 1) - phi_z = standard_normal.cdf(z) # Φ(z) - pdf_z = torch.exp(standard_normal.log_prob(z)) # φ(z) - - # Analytical CRPS formula. - crps = q.scale * (z * (2 * phi_z - 1) + 2 * pdf_z - 1 / torch.sqrt(torch.tensor(torch.pi))) - - return crps - - -def standardized_studentt_cdf_via_scipy(z: torch.Tensor, df: torch.Tensor | float) -> torch.Tensor: - """Since the `torch.distributions.StudentT` class does not have a `cdf()` method, we resort to scipy which has - a stable implementation. - - Note: - - The inputs `z` must be standardized. - - This breaks differentiability and requires to move tensors to the CPU. - - Args: - z: Standardized values at which to evaluate the CDF. - df: Degrees of freedom of the StudentT distribution. - - Returns: - CDF values of the standardized StudentT distribution at `z`. - """ - try: - from scipy.stats import t as scipy_student_t - except ImportError as e: - raise ImportError( - "scipy is required for the analytical solution for the StudentT distribution. " - "Install `torch-crps` with the 'studentt' dependency group, e.g. `pip install torch-crps[studentt]`." - ) from e - - z_np = z.detach().cpu().numpy() - df_np = df.detach().cpu().numpy() if isinstance(df, torch.Tensor) else df - - cdf_np = scipy_student_t.cdf(z_np, df=df_np) - - f_cdf_z = torch.from_numpy(cdf_np).to(device=z.device, dtype=z.dtype) - return f_cdf_z - - -def crps_analytical_studentt( - q: StudentT, - y: torch.Tensor, -) -> torch.Tensor: - r"""Compute the analytical CRPS assuming a StudentT distribution. - - This implements the closed-form formula from Jordan et al. (2019), see Appendix A.2. 
- - For the standardized StudentT distribution: - - $$ \text{CRPS}(F_\nu, z) = z(2F_\nu(z) - 1) + 2f_\nu(z)\frac{\nu + z^2}{\nu - 1} - - \frac{2\sqrt{\nu}}{\nu - 1} \frac{B(\frac{1}{2}, \nu - \frac{1}{2})}{B(\frac{1}{2}, \frac{\nu}{2})^2} $$ - - where $z$ is the standardized value, $F_\nu$ is the CDF, $f_\nu$ is the PDF of the standard StudentT - distribution, $\nu$ is the degrees of freedom, and $B$ is the beta function. - - For the location-scale transformed distribution: - - $$ \text{CRPS}(F_{\nu,\mu,\sigma}, y) = \sigma \cdot \text{CRPS}\left(F_\nu, \frac{y-\mu}{\sigma}\right) $$ - - where $\mu$ is the location parameter, $\sigma$ is the scale parameter, and $y$ is the observation. - - Note: - This formula is only valid for degrees of freedom $\nu > 1$. - - See Also: - Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019; Appendix A.2. - - Args: - q: A PyTorch StudentT distribution object, typically a model's output distribution. - y: Observed values, of shape (num_samples,). - - Returns: - CRPS values for each observation, of shape (num_samples,). - """ - # Extract degrees of freedom (nu), location (mu), and scale (sigma). - df, loc, scale = q.df, q.loc, q.scale - - if torch.any(df <= 1): - raise ValueError("StudentT CRPS requires degrees of freedom > 1") - - # Standardize, and create standard StudentT distribution for CDF and PDF. - z = (y - loc) / scale - standard_t = torch.distributions.StudentT(df, loc=0, scale=1) - - # Compute standardized CDF F_nu(z) and PDF f_nu(z). 
- f_cdf_z = standardized_studentt_cdf_via_scipy(z, df) - f_z = torch.exp(standard_t.log_prob(z)) - - # Compute the beta function ratio: B(1/2, nu - 1/2) / B(1/2, nu/2)^2 - # Using the relationship: B(a,b) = Gamma(a) * Gamma(b) / Gamma(a+b) - # B(1/2, nu - 1/2) / B(1/2, nu/2)^2 = ( Gamma(1/2) * Gamma(nu-1/2) / Gamma(nu) ) / - # ( Gamma(1/2) * Gamma(nu/2) / Gamma(nu/2 + 1/2) )^2 - # Simplifying to Gamma(nu - 1/2) Gamma(nu/2 + 1/2)^2 / ( Gamma(nu)Gamma(nu/2)^2 ) - # For numerical stability, we compute in log space. - log_gamma_half = torch.lgamma(torch.tensor(0.5, dtype=df.dtype, device=df.device)) - log_gamma_df_minus_half = torch.lgamma(df - 0.5) - log_gamma_df_half = torch.lgamma(df / 2) - log_gamma_df_half_plus_half = torch.lgamma(df / 2 + 0.5) - - # log[B(1/2, nu-1/2)] = log Gamma(1/2) + log Gamma(nu-1/2) - log Gamma(nu) - # log[B(1/2, nu/2)] = log Gamma(1/2) + log Gamma(nu/2) - log Gamma(nu/2 + 1/2) - # log[B(1/2, nu-1/2) / B(1/2, nu/2)^2] = log B(1/2, nu-1/2) - 2*log B(1/2, nu/2) - log_beta_ratio = ( - log_gamma_half - + log_gamma_df_minus_half - - torch.lgamma(df) - - 2 * (log_gamma_half + log_gamma_df_half - log_gamma_df_half_plus_half) - ) - beta_frac = torch.exp(log_beta_ratio) - - # Compute the CRPS for standardized values. - crps_standard = ( - z * (2 * f_cdf_z - 1) + 2 * f_z * (df + z**2) / (df - 1) - (2 * torch.sqrt(df) / (df - 1)) * beta_frac - ) - - # Apply location-scale transformation CRPS(F_{nu,mu,sigma}, y) = sigma * CRPS(F_nu, z) with z = (y - mu) / sigma. - crps = scale * crps_standard - - return crps diff --git a/torch_crps/ensemble.py b/torch_crps/ensemble.py new file mode 100644 index 0000000..27c062c --- /dev/null +++ b/torch_crps/ensemble.py @@ -0,0 +1,230 @@ +import torch + +from torch_crps.abstract import crps_abstract, scrps_abstract + + +def _accuracy_ensemble( + x: torch.Tensor, + y: torch.Tensor, +) -> torch.Tensor: + """Compute accuracy term $A = E[|X - y|]$, i.e., mean absolute error, for an ensemble forecast. 
+ + Args: + x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). + y: The ground truth observations, of shape (*batch_shape). + + Returns: + Accuracy values for each observation, of shape (*batch_shape). + """ + # Unsqueeze the observation for explicit broadcasting. + return torch.abs(x - y.unsqueeze(-1)).mean(dim=-1) + + +def _dispersion_ensemble_naive( + x: torch.Tensor, + biased: bool, +) -> torch.Tensor: + """Compute dispersion term $D = E[|X - X'|]$ for an ensemble forecast using a naive O(m²) algorithm. + + m is the number of ensemble members. + + Args: + x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). + biased: If True, uses the biased estimator for the dispersion term $D$, i.e., divides by m². If False, uses the + unbiased estimator which instead divides by m * (m - 1). + + Returns: + Dispersion values for each observation, of shape (*batch_shape). + """ + # Create a matrix of all pairwise differences between ensemble members using broadcasting. + x_i = x.unsqueeze(-1) # shape: (*batch_shape, m, 1) + x_j = x.unsqueeze(-2) # shape: (*batch_shape, 1, m) + pairwise_diffs = x_i - x_j # shape: (*batch_shape, m, m) + + # Take the absolute value of every element in the matrix. + abs_pairwise_diffs = torch.abs(pairwise_diffs) + + # Calculate the mean of the m x m matrix for each batch item, i.e, not the batch shapes. + if biased: + # For the biased estimator, we use the mean which divides by m². + dispersion = abs_pairwise_diffs.mean(dim=(-2, -1)) + else: + # For the unbiased estimator, we need to exclude the diagonal (where i=j) and divide by m(m-1). + m = x.shape[-1] # number of ensemble members + dispersion = abs_pairwise_diffs.sum(dim=(-2, -1)) / (m * (m - 1)) + + return dispersion + + +def _dispersion_ensemble( + x: torch.Tensor, + biased: bool, +) -> torch.Tensor: + """Compute dispersion term $D = E[|X - X'|]$ for an ensemble forecast using an efficient O(m log m) algorithm. + + m is the number of ensemble members. 
+ + Args: + x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). + biased: If True, uses the biased estimator for the dispersion term $D$, i.e., divides by m². If False, uses the + unbiased estimator which instead divides by m * (m - 1). + + Returns: + Dispersion values for each observation, of shape (*batch_shape). + """ + m = x.shape[-1] # number of ensemble members + + # Sort the predictions along the ensemble member dimension. + x_sorted, _ = torch.sort(x, dim=-1) + + # Calculate the coefficients (2i - m - 1) for the linear-time sum. These are the same for every item in the batch. + coeffs = 2 * torch.arange(1, m + 1, device=x.device, dtype=x.dtype) - m - 1 + + # Calculate the sum Σᵢ (2i - m - 1)xᵢ for each forecast in the batch along the member dimension. + # We use the efficient O(m log m) implementation with a summation over a single dimension. + x_sum = torch.sum(coeffs * x_sorted, dim=-1) + + # Calculate the full expectation E[|X - X'|] = 2 / m² * Σᵢ (2i - m - 1)xᵢ. + # This is half the mean absolute difference between all pairs of predictions. + denom = m * (m - 1) if not biased else m**2 + dispersion = 2 / denom * x_sum + + return dispersion + + +def crps_ensemble_naive(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> torch.Tensor: + """Computes the Continuous Ranked Probability Score (CRPS) for an ensemble forecast. + + This implementation uses the equality + + $$ CRPS(X, y) = E[|X - y|] - 0.5 E[|X - X'|] $$ + + It is designed to be fully vectorized and handle any number of leading batch dimensions in the input tensors, + as long as they are equal for `x` and `y`. + + See Also: + Zamo & Naveau; "Estimation of the Continuous Ranked Probability Score with Limited Information and Applications + to Ensemble Weather Forecasts"; 2017 + + Note: + - This implementation uses an inefficient algorithm to compute the term E[|X - X'|] in O(m²) where m is + the number of ensemble members. This is done for clarity and educational purposes. 
+ - This implementation exactly matches the energy formula, see (NRG) and (eNRG), in Zamo & Naveau (2017). + + Args: + x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). + y: The ground truth observations, of shape (*batch_shape). + biased: If True, uses the biased estimator for $D$, i.e., divides by m². If False, uses the unbiased estimator. + The unbiased estimator divides by m * (m - 1). + + Returns: + The CRPS value for each forecast in the batch, of shape (*batch_shape). + """ + if x.shape[:-1] != y.shape: + raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!") + + # Accuracy term A := E[|X - y|] + accuracy = _accuracy_ensemble(x, y) + + # Dispersion term D := E[|X - X'|] + dispersion = _dispersion_ensemble_naive(x, biased) + + # CRPS value := A - 0.5 * D + return crps_abstract(accuracy, dispersion) + + +def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> torch.Tensor: + r"""Computes the Continuous Ranked Probability Score (CRPS) for an ensemble forecast. + + This function implements + + $$ + \text{CRPS}(F, y) = E[|X - y|] - 0.5 E[|X - X'|] = E[|X - y|] + E[X] - 2 E[X F(X)] + $$ + + where $X$ and $X'$ are independent random variables drawn from the ensemble distribution, and $F(X)$ is the CDF + of the ensemble distribution evaluated at $X$. + + It is designed to be fully vectorized and handle any number of leading batch dimensions in the input tensors, + as long as they are equal for `x` and `y`. + + See Also: + Zamo & Naveau; "Estimation of the Continuous Ranked Probability Score with Limited Information and Applications + to Ensemble Weather Forecasts"; 2017 + + Note: + - This implementation uses an efficient algorithm to compute the dispersion term E[|X - X'|] in O(m log(m)) + time, where m is the number of ensemble members. This is achieved by sorting the ensemble predictions and using + a mathematical identity to compute the mean absolute difference. 
You can also see this trick
+        [here](https://docs.nvidia.com/physicsnemo/25.11/_modules/physicsnemo/metrics/general/crps.html)
+
+        - This implementation exactly matches the energy formula, see (NRG) and (eNRG), in Zamo & Naveau (2017) while
+        using the computational trick which can be read from (ePWM) in the same paper. The factors $\beta_0$ and
+        $\beta_1$ in (ePWM) together equal the second term, i.e., the half mean dispersion, here. In (ePWM) they pulled
+        the mean out. The energy formula and the probability weighted moment formula are equivalent.
+
+    Args:
+        x: The ensemble predictions, of shape (*batch_shape, dim_ensemble).
+        y: The ground truth observations, of shape (*batch_shape).
+        biased: If True, uses the biased estimator for the dispersion term $D$, i.e., divides by m². If False, uses the
+            unbiased estimator which instead divides by m * (m - 1).
+
+    Returns:
+        The CRPS value for each forecast in the batch, of shape (*batch_shape).
+    """
+    if x.shape[:-1] != y.shape:
+        raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!")
+
+    # Accuracy term A := E[|X - y|]
+    accuracy = _accuracy_ensemble(x, y)
+
+    # Dispersion term D := E[|X - X'|]
+    dispersion = _dispersion_ensemble(x, biased)
+
+    # CRPS value := A - 0.5 * D
+    return crps_abstract(accuracy, dispersion)
+
+
+def scrps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> torch.Tensor:
+    r"""Computes the Scaled Continuous Ranked Probability Score (SCRPS) for an ensemble forecast.
+
+    $$
+    \text{SCRPS}(F, y) = \frac{E[|X - y|]}{E[|X - X'|]} + 0.5 \log \left( E[|X - X'|] \right)
+    = \frac{A}{D} + 0.5 \log(D)
+    $$
+
+    where $X$ and $X'$ are independent random variables drawn from the ensemble distribution, and $F(X)$ is the CDF
+    of the ensemble distribution evaluated at $X$, and $y$ are the ground truth observations.
+
+    It is designed to be fully vectorized and handle any number of leading batch dimensions in the input tensors,
+    as long as they are equal for `x` and `y`.
+
+    See Also:
+        Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019.
+
+    Note:
+        This implementation uses an efficient algorithm to compute the dispersion term E[|X - X'|] in O(m log(m))
+        time, where m is the number of ensemble members. This is achieved by sorting the ensemble predictions and using
+        a mathematical identity to compute the mean absolute difference. You can also see this trick
+        [here](https://docs.nvidia.com/physicsnemo/25.11/_modules/physicsnemo/metrics/general/crps.html)
+
+    Args:
+        x: The ensemble predictions, of shape (*batch_shape, dim_ensemble).
+        y: The ground truth observations, of shape (*batch_shape).
+        biased: If True, uses the biased estimator for the dispersion term $D$, i.e., divides by m². If False, uses the
+            unbiased estimator which instead divides by m * (m - 1).
+
+    Returns:
+        The SCRPS value for each forecast in the batch, of shape (*batch_shape).
+    """
+    if x.shape[:-1] != y.shape:
+        raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!")
+
+    # Accuracy term A := E[|X - y|]
+    accuracy = _accuracy_ensemble(x, y)
+
+    # Dispersion term D := E[|X - X'|]
+    dispersion = _dispersion_ensemble(x, biased)
+
+    # SCRPS value := A/D + 0.5 * log(D)
+    return scrps_abstract(accuracy, dispersion)
diff --git a/torch_crps/ensemble_crps.py b/torch_crps/ensemble_crps.py
deleted file mode 100644
index 6bff535..0000000
--- a/torch_crps/ensemble_crps.py
+++ /dev/null
@@ -1,134 +0,0 @@
-import torch
-
-
-def crps_ensemble_naive(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torch.Tensor:
-    """Computes the Continuous Ranked Probability Score (CRPS) for an ensemble forecast.
- - This implementation uses the equality - - $$ CRPS(X, y) = E[|X - y|] - 0.5 E[|X - X'|] $$ - - It is designed to be fully vectorized and handle any number of leading batch dimensions in the input tensors, - as long as they are equal for `x` and `y`. - - See Also: - Zamo & Naveau; "Estimation of the Continuous Ranked Probability Score with Limited Information and Applications - to Ensemble Weather Forecasts"; 2017 - - Note: - - This implementation uses an inefficient algorithm to compute the term E[|X - X'|] in O(m²) where m is - the number of ensemble members. This is done for clarity and educational purposes. - - This implementation exactly matches the energy formula, see (NRG) and (eNRG), in Zamo & Naveau (2017). - - Args: - x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). - y: The ground truth observations, of shape (*batch_shape). - biased: If True, uses the biased estimator for E[|X - X'|]. If False, uses the unbiased estimator. - The unbiased estimator divides by m * (m - 1) instead of m². - - Returns: - The calculated CRPS value for each forecast in the batch, of shape (*batch_shape). - """ - if x.shape[:-1] != y.shape: - raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!") - - # --- Accuracy term := E[|X - y|] - - # Compute the mean absolute error across all ensemble members. Unsqueeze the observation for explicit broadcasting. - mae = torch.abs(x - y.unsqueeze(-1)).mean(dim=-1) - - # --- Spread term := 0.5 * E[|X - X'|] - # This is half the mean absolute difference between all pairs of predictions. - - # Create a matrix of all pairwise differences between ensemble members using broadcasting. - x_i = x.unsqueeze(-1) # shape: (*batch_shape, m, 1) - x_j = x.unsqueeze(-2) # shape: (*batch_shape, 1, m) - pairwise_diffs = x_i - x_j # shape: (*batch_shape, m, m) - - # Take the absolute value of every element in the matrix. 
- abs_pairwise_diffs = torch.abs(pairwise_diffs) - - # Calculate the mean of the m x m matrix for each batch item, i.e, not the batch shapes. - if biased: - # For the biased estimator, we use the mean which divides by m². - mean_spread = abs_pairwise_diffs.mean(dim=(-2, -1)) - else: - # For the unbiased estimator, we need to exclude the diagonal (where i=j) and divide by m(m-1). - m = x.shape[-1] # number of ensemble members - mean_spread = abs_pairwise_diffs.sum(dim=(-2, -1)) / (m * (m - 1)) - - # --- Assemble the final CRPS value. - crps_value = mae - 0.5 * mean_spread - - return crps_value - - -def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torch.Tensor: - r"""Computes the Continuous Ranked Probability Score (CRPS) for an ensemble forecast. - - This implementation uses the equalities - - $$ CRPS(F, y) = E[|X - y|] - 0.5 E[|X - X'|] $$ - - and - - $$ CRPS(F, y) = E[|X - y|] + E[X] - 2 E[X F(X)] $$ - - It is designed to be fully vectorized and handle any number of leading batch dimensions in the input tensors, - as long as they are equal for `x` and `y`. - - See Also: - Zamo & Naveau; "Estimation of the Continuous Ranked Probability Score with Limited Information and Applications - to Ensemble Weather Forecasts"; 2017 - - Note: - - This implementation uses an efficient algorithm to compute the term E[|X - X'|] in O(m log(m)) time, where m - is the number of ensemble members. This is achieved by sorting the ensemble predictions and using a mathematical - identity to compute the mean absolute difference. You can also see this trick - [here][https://docs.nvidia.com/physicsnemo/25.11/_modules/physicsnemo/metrics/general/crps.html] - - This implementation exactly matches the energy formula, see (NRG) and (eNRG), in Zamo & Naveau (2017) while - using the compuational trick which can be read from (ePWM) in the same paper. The factors &\beta_0$ and - $\beta_1$ in (ePWM) together equal the second term, i.e., the half mean spread, here. 
In (ePWM) they pulled - the mean out. The energy formula and the probability weighted moment formula are equivalent. - - Args: - x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). - y: The ground truth observations, of shape (*batch_shape). - biased: If True, uses the biased estimator for E[|X - X'|]. If False, uses the unbiased estimator. - The unbiased estimator divides by m * (m - 1) instead of m². - - Returns: - The calculated CRPS value for each forecast in the batch, of shape (*batch_shape). - """ - if x.shape[:-1] != y.shape: - raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!") - - # Get the number of ensemble members. - m = x.shape[-1] - - # --- Accuracy term := E[|X - y|] - - # Compute the mean absolute error across all ensemble members. Unsqueeze the observation for explicit broadcasting. - mae = torch.abs(x - y.unsqueeze(-1)).mean(dim=-1) - - # --- Spread term B := 0.5 * E[|X - X'|] - # This is half the mean absolute difference between all pairs of predictions. - # We use the efficient O(m log m) implementation with a summation over a single dimension. - - # Sort the predictions along the ensemble member dimension. - x_sorted, _ = torch.sort(x, dim=-1) - - # Calculate the coefficients (2i - m - 1) for the linear-time sum. These are the same for every item in the batch. - coeffs = 2 * torch.arange(1, m + 1, device=x.device, dtype=x.dtype) - m - 1 - - # Calculate the sum Σᵢ (2i - m - 1)xᵢ for each forecast in the batch along the member dimension. - x_sum = torch.sum(coeffs * x_sorted, dim=-1) - - # Calculate the full expectation E[|X - X'|] = 2 / m² * Σᵢ (2i - m - 1)xᵢ. - denom = m * (m - 1) if not biased else m**2 - half_mean_spread = 1 / denom * x_sum # 2 in numerator here cancels with 0.5 in the next step - - # --- Assemble the final CRPS value. 
- crps_value = mae - half_mean_spread # 0.5 already accounted for above - - return crps_value diff --git a/torch_crps/integral_crps.py b/torch_crps/integral.py similarity index 92% rename from torch_crps/integral_crps.py rename to torch_crps/integral.py index 86cd898..36f404b 100644 --- a/torch_crps/integral_crps.py +++ b/torch_crps/integral.py @@ -1,7 +1,7 @@ import torch from torch.distributions import Distribution, StudentT -from torch_crps.analytical_crps import standardized_studentt_cdf_via_scipy +from torch_crps.analytical.studentt import standardized_studentt_cdf_via_scipy def crps_integral( @@ -55,6 +55,6 @@ def integrand(x: torch.Tensor) -> torch.Tensor: # Compute the integral using the trapezoidal rule. integral_values = integrand(x_values) - crps_values = torch.trapezoid(integral_values, x_values.squeeze(-1), dim=0) + crps = torch.trapezoid(integral_values, x_values.squeeze(-1), dim=0) - return crps_values + return crps diff --git a/torch_crps/normalization.py b/torch_crps/normalization.py index 2739fdb..f1514f8 100644 --- a/torch_crps/normalization.py +++ b/torch_crps/normalization.py @@ -3,6 +3,12 @@ import torch +from torch_crps.analytical.dispatch import crps_analytical +from torch_crps.analytical.normal import crps_analytical_normal +from torch_crps.analytical.studentt import crps_analytical_studentt +from torch_crps.ensemble import crps_ensemble +from torch_crps.integral import crps_integral + WRAPPED_INPUT_TYPE: TypeAlias = torch.distributions.Distribution | torch.Tensor | float @@ -15,7 +21,7 @@ def normalize_by_observation(crps_fcn: Callable) -> Callable: - If the observations `y` are all close to zero, then the normalization is done by 1, so the CRPS can be > 1. Args: - crps_fcn: CRPS-calculating function to be wrapped. The fucntion must accept an argument called y which is + crps_fcn: CRPS-calculating function to be wrapped. The function must accept an argument called y which is at the 2nd position. 
Returns: @@ -25,7 +31,7 @@ def normalize_by_observation(crps_fcn: Callable) -> Callable: @functools.wraps(crps_fcn) def wrapper(*args: WRAPPED_INPUT_TYPE, **kwargs: WRAPPED_INPUT_TYPE) -> torch.Tensor: - """The function returned by the decorator that does the normalization and the forwading to the CRPS function.""" + """The function returned by the decorator that normalizes and forwards to the CRPS function.""" # Find the observation 'y' from the arguments. if "y" in kwargs: y = kwargs["y"] @@ -47,9 +53,16 @@ def wrapper(*args: WRAPPED_INPUT_TYPE, **kwargs: WRAPPED_INPUT_TYPE) -> torch.Te abs_max_y = torch.ones(1, device=abs_max_y.device, dtype=abs_max_y.dtype) # Call the original CRPS function. - crps_result = crps_fcn(*args, **kwargs) + crps = crps_fcn(*args, **kwargs) # Normalize the result. - return crps_result / abs_max_y + return crps / abs_max_y return wrapper + + +crps_analytical_obsnormalized = normalize_by_observation(crps_analytical) +crps_analytical_normal_obsnormalized = normalize_by_observation(crps_analytical_normal) +crps_analytical_studentt_obsnormalized = normalize_by_observation(crps_analytical_studentt) +crps_ensemble_obsnormalized = normalize_by_observation(crps_ensemble) +crps_integral_obsnormalized = normalize_by_observation(crps_integral) diff --git a/torch_crps/utils.py b/torch_crps/utils.py deleted file mode 100644 index e69de29..0000000