From 0b43debbbd35a0c53e1c559e52d0284c3ca54245 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Thu, 29 Jan 2026 16:36:19 +0100 Subject: [PATCH 01/19] WIP --- pyproject.toml | 7 +- tests/test_analytical_crps.py | 20 ++-- torch_crps/__init__.py | 4 + torch_crps/analytical_crps.py | 211 ++++++++++++++++++++++++++++------ torch_crps/ensemble_crps.py | 30 ++--- 5 files changed, 211 insertions(+), 61 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 34b5ef6..9a2ea71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,9 +75,9 @@ distance-dirty = "{base_version}" [tool.mypy] ignore_missing_imports = true # when no stubs are available, e.g. for matplotlib or tabulate -pretty = true -show_error_context = true -show_traceback = true +pretty = true +show_error_context = true +show_traceback = true [tool.pytest.ini_options] addopts = [ @@ -140,6 +140,7 @@ ignore = [ "PLC0415", # import should be at the top-level of a file "PLR", # pylint refactor "RUF001", # allow for greek letters + "RUF003", # allow for greek letters "UP035", # allow importing Callable from typing ] preview = true diff --git a/tests/test_analytical_crps.py b/tests/test_analytical_crps.py index a88bf29..d5b15f0 100644 --- a/tests/test_analytical_crps.py +++ b/tests/test_analytical_crps.py @@ -1,11 +1,12 @@ -from typing import Any +from typing import Any, Callable import pytest import torch from torch.distributions import Normal, StudentT from tests.conftest import needs_cuda -from torch_crps import crps_analytical, crps_analytical_normal, crps_analytical_studentt +from torch_crps import crps_analytical, crps_analytical_normal, crps_analytical_studentt, scrps_analytical_normal +from torch_crps.analytical_crps import scrps_analytical @pytest.mark.parametrize( @@ -15,7 +16,8 @@ pytest.param(True, marks=needs_cuda, id="cuda"), ], ) -def test_crps_analytical_normal_batched_smoke(use_cuda: bool): +@pytest.mark.parametrize("crps_fcn", [crps_analytical_normal, scrps_analytical_normal], 
ids=["CRPS", "SCRPS"]) +def test_crps_analytical_normal_batched_smoke(use_cuda: bool, crps_fcn: Callable): """Test that analytical solution works with batched Normal distributions.""" torch.manual_seed(0) @@ -28,13 +30,14 @@ def test_crps_analytical_normal_batched_smoke(use_cuda: bool): y = torch.tensor([[0.5, 1.5], [2.5, 3.5], [-2.0, -3.0]], device="cuda" if use_cuda else "cpu") # Compute CRPS using the analytical method. - crps_analytical = crps_analytical_normal(normal_dist, y) + crps_analytical = crps_fcn(normal_dist, y) # Simple sanity check: CRPS should be non-negative. assert crps_analytical.shape == y.shape, "CRPS output shape should match input shape." assert crps_analytical.dtype in [torch.float32, torch.float64], "CRPS output dtype should be float." assert crps_analytical.device == y.device, "CRPS output device should match input device." - assert torch.all(crps_analytical >= 0), "CRPS values should be non-negative." + if crps_fcn == crps_analytical_normal: + assert torch.all(crps_analytical >= 0), "CRPS values should be non-negative." @pytest.mark.parametrize( @@ -75,16 +78,17 @@ def test_studentt_convergence_to_normal(loc: torch.Tensor, scale: torch.Tensor, ], ids=["Normal", "StudentT", "not_supported"], ) -def test_crps_analytical_interface_smoke(q: Any): # noqa: ANN401 +@pytest.mark.parametrize("crps_fcn", [crps_analytical, scrps_analytical], ids=["CRPS", "SCRPS"]) +def test_crps_analytical_interface_smoke(q: Any, crps_fcn: Callable): # noqa: ANN401 """Test if the top-level interface function is working""" y = torch.zeros(3) # can be the same for all tests if isinstance(q, (Normal, StudentT)): # Supported, should return a result. - crps = crps_analytical(q, y) + crps = crps_fcn(q, y) assert isinstance(crps, torch.Tensor) else: # Not supported, should raise an error. 
with pytest.raises(NotImplementedError): - crps_analytical(q, y) + crps_fcn(q, y) diff --git a/torch_crps/__init__.py b/torch_crps/__init__.py index 4062956..8f63dce 100644 --- a/torch_crps/__init__.py +++ b/torch_crps/__init__.py @@ -2,6 +2,8 @@ crps_analytical, crps_analytical_normal, crps_analytical_studentt, + scrps_analytical, + scrps_analytical_normal, ) from .ensemble_crps import crps_ensemble, crps_ensemble_naive from .integral_crps import crps_integral @@ -25,4 +27,6 @@ "crps_ensemble_normalized", "crps_integral", "crps_integral_normalized", + "scrps_analytical", + "scrps_analytical_normal", ] diff --git a/torch_crps/analytical_crps.py b/torch_crps/analytical_crps.py index 9492997..676eff2 100644 --- a/torch_crps/analytical_crps.py +++ b/torch_crps/analytical_crps.py @@ -2,8 +2,11 @@ from torch.distributions import Distribution, Normal, StudentT -def crps_analytical(q: Distribution, y: torch.Tensor) -> torch.Tensor: - """Compute the analytical CRPS. +def crps_analytical( + q: Distribution, + y: torch.Tensor, +) -> torch.Tensor: + """Compute the (negatively-oriented, i.e., lower is better) CRPS in closed-form. Note: The input distribution must be either `torch.distributions.Normal` or `torch.distributions.StudentT`. @@ -29,14 +32,44 @@ def crps_analytical(q: Distribution, y: torch.Tensor) -> torch.Tensor: ) +def scrps_analytical( + q: Distribution, + y: torch.Tensor, +) -> torch.Tensor: + """Compute the (negatively-oriented, i.e., lower is better) scaled CRPS (SCRPS) in closed-form. + + Note: + The input distribution must be either `torch.distributions.Normal` or `torch.distributions.StudentT`. + There exists analytical solutions for other distributions, but they are not implemented, yet. + Feel free to create an issue or pull request. + + Args: + q: A PyTorch distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + SCRPS values for each observation, of shape (num_samples,). 
+ """ + if isinstance(q, Normal): + return scrps_analytical_normal(q, y) + elif isinstance(q, StudentT): + return scrps_analytical_studentt(q, y) + else: + raise NotImplementedError( + f"Detected distribution of type {type(q)}, but there are only analytical solutions for " + "`torch.distributions.Normal` or `torch.distributions.StudentT`. Either use an alternative method, e.g. " + "`torch_crps.scrps_integral` or `torch_crps.scrps_ensemble`, or create an issue for the method you need." + ) + + def crps_analytical_normal( q: Normal, y: torch.Tensor, ) -> torch.Tensor: - """Compute the analytical CRPS assuming a normal distribution. + """Compute the (negatively-oriented) CRPS in closed-form assuming a normal distribution. See Also: - Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007 + Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007. Equation (5) for the analytical formula for CRPS of Normal distribution. Args: @@ -49,16 +82,58 @@ def crps_analytical_normal( # Compute standard normal CDF and PDF. z = (y - q.loc) / q.scale # standardize standard_normal = torch.distributions.Normal(0, 1) - phi_z = standard_normal.cdf(z) # Φ(z) + cdf_z = standard_normal.cdf(z) # Φ(z) pdf_z = torch.exp(standard_normal.log_prob(z)) # φ(z) # Analytical CRPS formula. - crps = q.scale * (z * (2 * phi_z - 1) + 2 * pdf_z - 1 / torch.sqrt(torch.tensor(torch.pi))) + sqrt_pi = torch.sqrt(torch.tensor(torch.pi, device=z.device, dtype=z.dtype)) + crps = q.scale * (z * (2 * cdf_z - 1) + 2 * pdf_z - 1 / sqrt_pi) return crps -def standardized_studentt_cdf_via_scipy(z: torch.Tensor, df: torch.Tensor | float) -> torch.Tensor: +def scrps_analytical_normal( + q: Normal, + y: torch.Tensor, +) -> torch.Tensor: + """Compute the (negatively-oriented) scaled CRPS (SCRPS) in closed-form assuming a normal distribution. + + Note: + In contrast to the (negatively-oriented) CRPS, the SCRPS can have negative values. 
+ + See Also: + Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019. + Equation (3) for the definition of the SCRPS. + Appendix A.1 for the component formulas (Accuracy and Dispersion) for the Normal distribution + + Args: + q: A PyTorch Normal distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + SCRPS values for each observation, of shape (num_samples,). + """ + # --- Dispersion Term D := E[|X - X'|] = 2σ / √π + sqrt_pi = torch.sqrt(torch.tensor(torch.pi, device=y.device, dtype=y.dtype)) + dispersion = 2 * q.scale / sqrt_pi + + # --- Accuracy Term A := E[|X - y|] + z = (y - q.loc) / q.scale # standardize + standard_normal = torch.distributions.Normal(0, 1) + cdf_z = standard_normal.cdf(z) # Φ(z) + pdf_z = torch.exp(standard_normal.log_prob(z)) # φ(z) + accuracy = q.scale * (z * (2 * cdf_z - 1) + 2 * pdf_z) + + # --- SCRPS (negatively-oriented) := (A / D) + 0.5 * log(D) + scrps = accuracy / dispersion + 0.5 * torch.log(dispersion) + + return scrps + + +def standardized_studentt_cdf_via_scipy( + z: torch.Tensor, + df: torch.Tensor | float, +) -> torch.Tensor: """Since the `torch.distributions.StudentT` class does not have a `cdf()` method, we resort to scipy which has a stable implementation. @@ -84,17 +159,16 @@ def standardized_studentt_cdf_via_scipy(z: torch.Tensor, df: torch.Tensor | floa z_np = z.detach().cpu().numpy() df_np = df.detach().cpu().numpy() if isinstance(df, torch.Tensor) else df - cdf_np = scipy_student_t.cdf(z_np, df=df_np) + cdf_z_np = scipy_student_t.cdf(z_np, df=df_np) - f_cdf_z = torch.from_numpy(cdf_np).to(device=z.device, dtype=z.dtype) - return f_cdf_z + return torch.from_numpy(cdf_z_np).to(device=z.device, dtype=z.dtype) def crps_analytical_studentt( q: StudentT, y: torch.Tensor, ) -> torch.Tensor: - r"""Compute the analytical CRPS assuming a StudentT distribution. 
+ r"""Compute the (negatively-oriented) CRPS in closed-form assuming a StudentT distribution. This implements the closed-form formula from Jordan et al. (2019), see Appendix A.2. @@ -125,48 +199,115 @@ def crps_analytical_studentt( Returns: CRPS values for each observation, of shape (num_samples,). """ - # Extract degrees of freedom (nu), location (mu), and scale (sigma). - df, loc, scale = q.df, q.loc, q.scale - - if torch.any(df <= 1): + # Extract degrees of freedom ν, location μ, and scale σ. + nu, mu, sigma = q.df, q.loc, q.scale + if torch.any(nu <= 1): raise ValueError("StudentT CRPS requires degrees of freedom > 1") # Standardize, and create standard StudentT distribution for CDF and PDF. - z = (y - loc) / scale - standard_t = torch.distributions.StudentT(df, loc=0, scale=1) + z = (y - mu) / sigma + standard_t = torch.distributions.StudentT(nu, loc=0, scale=1) - # Compute standardized CDF F_nu(z) and PDF f_nu(z). - f_cdf_z = standardized_studentt_cdf_via_scipy(z, df) + # Compute standardized CDF F_ν(z) and PDF f_ν(z). + f_cdf_z = standardized_studentt_cdf_via_scipy(z, nu) f_z = torch.exp(standard_t.log_prob(z)) - # Compute the beta function ratio: B(1/2, nu - 1/2) / B(1/2, nu/2)^2 + # Compute the beta function ratio: B(1/2, ν - 1/2) / B(1/2, ν/2)^2 # Using the relationship: B(a,b) = Gamma(a) * Gamma(b) / Gamma(a+b) - # B(1/2, nu - 1/2) / B(1/2, nu/2)^2 = ( Gamma(1/2) * Gamma(nu-1/2) / Gamma(nu) ) / - # ( Gamma(1/2) * Gamma(nu/2) / Gamma(nu/2 + 1/2) )^2 - # Simplifying to Gamma(nu - 1/2) Gamma(nu/2 + 1/2)^2 / ( Gamma(nu)Gamma(nu/2)^2 ) + # B(1/2, ν - 1/2) / B(1/2, ν/2)^2 = ( Gamma(1/2) * Gamma(ν-1/2) / Gamma(ν) ) / + # ( Gamma(1/2) * Gamma(ν/2) / Gamma(ν/2 + 1/2) )^2 + # Simplifying to Gamma(ν - 1/2) Gamma(ν/2 + 1/2)^2 / ( Gamma(ν)Gamma(ν/2)^2 ) # For numerical stability, we compute in log space. 
- log_gamma_half = torch.lgamma(torch.tensor(0.5, dtype=df.dtype, device=df.device)) - log_gamma_df_minus_half = torch.lgamma(df - 0.5) - log_gamma_df_half = torch.lgamma(df / 2) - log_gamma_df_half_plus_half = torch.lgamma(df / 2 + 0.5) - - # log[B(1/2, nu-1/2)] = log Gamma(1/2) + log Gamma(nu-1/2) - log Gamma(nu) - # log[B(1/2, nu/2)] = log Gamma(1/2) + log Gamma(nu/2) - log Gamma(nu/2 + 1/2) - # log[B(1/2, nu-1/2) / B(1/2, nu/2)^2] = log B(1/2, nu-1/2) - 2*log B(1/2, nu/2) + log_gamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device)) + log_gamma_df_minus_half = torch.lgamma(nu - 0.5) + log_gamma_df_half = torch.lgamma(nu / 2) + log_gamma_df_half_plus_half = torch.lgamma(nu / 2 + 0.5) + + # log[B(1/2, ν-1/2)] = log Gamma(1/2) + log Gamma(ν-1/2) - log Gamma(ν) + # log[B(1/2, ν/2)] = log Gamma(1/2) + log Gamma(ν/2) - log Gamma(ν/2 + 1/2) + # log[B(1/2, ν-1/2) / B(1/2, ν/2)^2] = log B(1/2, ν-1/2) - 2*log B(1/2, ν/2) log_beta_ratio = ( log_gamma_half + log_gamma_df_minus_half - - torch.lgamma(df) + - torch.lgamma(nu) - 2 * (log_gamma_half + log_gamma_df_half - log_gamma_df_half_plus_half) ) beta_frac = torch.exp(log_beta_ratio) # Compute the CRPS for standardized values. crps_standard = ( - z * (2 * f_cdf_z - 1) + 2 * f_z * (df + z**2) / (df - 1) - (2 * torch.sqrt(df) / (df - 1)) * beta_frac + z * (2 * f_cdf_z - 1) + 2 * f_z * (nu + z**2) / (nu - 1) - (2 * torch.sqrt(nu) / (nu - 1)) * beta_frac ) - # Apply location-scale transformation CRPS(F_{nu,mu,sigma}, y) = sigma * CRPS(F_nu, z) with z = (y - mu) / sigma. - crps = scale * crps_standard + # Apply location-scale transformation CRPS(F_{ν,μ,σ}, y) = σ * CRPS(F_{ν}, z) with z = (y - μ) / σ. + crps = sigma * crps_standard return crps + + +def scrps_analytical_studentt( + q: StudentT, + y: torch.Tensor, +) -> torch.Tensor: + r"""Compute the (negatively-oriented) scaled CRPS (SCRPS) in closed-form assuming a Student-T distribution. 
+ + The score is calculated as: + $$ \text{SCRPS}(F, y) = \frac{A}{D} + 0.5 \cdot \log(D) $$ + + where: + - $A = E_F[|X - y|]$ is the Accuracy term. + - $D = E_F[|X - X'|]$ is the Dispersion term. + - $F$ is the Student-T distribution $t(\nu, \mu, \sigma^2)$. + + Note: + This formula is only valid for degrees of freedom $\nu > 1$. + + See Also: + Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019. + + Args: + q: A PyTorch StudentT distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + SCRPS values for each observation, of shape (num_samples,). + """ + # Extract degrees of freedom ν, location μ, and scale σ. + nu, mu, sigma = q.df, q.loc, q.scale + if torch.any(nu <= 1): + raise ValueError("StudentT SCRPS requires degrees of freedom > 1") + + # Use the device of y for creating new (intermediate) tensors. + device, dtype = y.device, y.dtype + + # --- Dispersion Term D := E[|X - X'|] = (4σ / (ν-1)) * (Γ(ν/2) / Γ((ν-1)/2))² + # We compute in log space for numerical stability. + log_4 = torch.log(torch.tensor(4.0, dtype=dtype, device=device)) + log_dispersion = ( + log_4 + torch.log(sigma) - torch.log(nu - 1) + 2 * (torch.lgamma(nu / 2) - torch.lgamma((nu - 1) / 2)) + ) + dispersion = torch.exp(log_dispersion) + + # --- 2. Accuracy Term A := E[|X - y|] + # Standardize, and create standard StudentT distributions for CDFs. 
+ z = (y - mu) / sigma + standard_t_nu = StudentT(nu, loc=0, scale=1) + standard_t_nu_plus_1 = StudentT(nu + 1, loc=0, scale=1) + + # Compute Beta function term B(ν/2, 1/2) + lgamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device)) + log_beta_term = torch.lgamma(nu / 2) + lgamma_half - torch.lgamma((nu + 1) / 2) + beta_term = torch.exp(log_beta_term) + + # Compute components of the 'A' formula from Bolin & Wallin Appendix A.2 + term_A1 = z * (2 * standard_t_nu.cdf(z) - 1) + + term_A2_factor = (2 * (nu + z**2)) / (nu * beta_term) + term_A2_cdf_arg = z * torch.sqrt((nu + 1) / (nu + z**2)) + term_A2 = term_A2_factor * standard_t_nu_plus_1.cdf(term_A2_cdf_arg) + + accuracy = sigma * (term_A1 + term_A2) + + # --- 3. SCRPS (negatively-oriented) := (A / D) + 0.5 * log(D) + scrps = accuracy / dispersion + 0.5 * log_dispersion + return scrps diff --git a/torch_crps/ensemble_crps.py b/torch_crps/ensemble_crps.py index 6bff535..f531f70 100644 --- a/torch_crps/ensemble_crps.py +++ b/torch_crps/ensemble_crps.py @@ -32,12 +32,12 @@ def crps_ensemble_naive(x: torch.Tensor, y: torch.Tensor, biased: bool = True) - if x.shape[:-1] != y.shape: raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!") - # --- Accuracy term := E[|X - y|] + # --- Accuracy term A := E[|X - y|] # Compute the mean absolute error across all ensemble members. Unsqueeze the observation for explicit broadcasting. mae = torch.abs(x - y.unsqueeze(-1)).mean(dim=-1) - # --- Spread term := 0.5 * E[|X - X'|] + # --- Dispersion term D := E[|X - X'|] # This is half the mean absolute difference between all pairs of predictions. # Create a matrix of all pairwise differences between ensemble members using broadcasting. @@ -51,14 +51,14 @@ def crps_ensemble_naive(x: torch.Tensor, y: torch.Tensor, biased: bool = True) - # Calculate the mean of the m x m matrix for each batch item, i.e, not the batch shapes. 
if biased: # For the biased estimator, we use the mean which divides by m². - mean_spread = abs_pairwise_diffs.mean(dim=(-2, -1)) + mean_dispersion = abs_pairwise_diffs.mean(dim=(-2, -1)) else: # For the unbiased estimator, we need to exclude the diagonal (where i=j) and divide by m(m-1). m = x.shape[-1] # number of ensemble members - mean_spread = abs_pairwise_diffs.sum(dim=(-2, -1)) / (m * (m - 1)) + mean_dispersion = abs_pairwise_diffs.sum(dim=(-2, -1)) / (m * (m - 1)) - # --- Assemble the final CRPS value. - crps_value = mae - 0.5 * mean_spread + # --- Assemble the CRPS value: A - 0.5 * D + crps_value = mae - 0.5 * mean_dispersion return crps_value @@ -82,13 +82,13 @@ def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torc to Ensemble Weather Forecasts"; 2017 Note: - - This implementation uses an efficient algorithm to compute the term E[|X - X'|] in O(m log(m)) time, where m - is the number of ensemble members. This is achieved by sorting the ensemble predictions and using a mathematical - identity to compute the mean absolute difference. You can also see this trick + - This implementation uses an efficient algorithm to compute the dispersion term E[|X - X'|] in O(m log(m)) + time, where m is the number of ensemble members. This is achieved by sorting the ensemble predictions and using + a mathematical identity to compute the mean absolute difference. You can also see this trick [here][https://docs.nvidia.com/physicsnemo/25.11/_modules/physicsnemo/metrics/general/crps.html] - This implementation exactly matches the energy formula, see (NRG) and (eNRG), in Zamo & Naveau (2017) while using the compuational trick which can be read from (ePWM) in the same paper. The factors &\beta_0$ and - $\beta_1$ in (ePWM) together equal the second term, i.e., the half mean spread, here. In (ePWM) they pulled + $\beta_1$ in (ePWM) together equal the second term, i.e., the half mean dispersion, here. In (ePWM) they pulled the mean out. 
The energy formula and the probability weighted moment formula are equivalent. Args: @@ -106,12 +106,12 @@ def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torc # Get the number of ensemble members. m = x.shape[-1] - # --- Accuracy term := E[|X - y|] + # --- Accuracy term A := E[|X - y|] # Compute the mean absolute error across all ensemble members. Unsqueeze the observation for explicit broadcasting. mae = torch.abs(x - y.unsqueeze(-1)).mean(dim=-1) - # --- Spread term B := 0.5 * E[|X - X'|] + # --- Dispersion term D := E[|X - X'|] # This is half the mean absolute difference between all pairs of predictions. # We use the efficient O(m log m) implementation with a summation over a single dimension. @@ -126,9 +126,9 @@ def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torc # Calculate the full expectation E[|X - X'|] = 2 / m² * Σᵢ (2i - m - 1)xᵢ. denom = m * (m - 1) if not biased else m**2 - half_mean_spread = 1 / denom * x_sum # 2 in numerator here cancels with 0.5 in the next step + half_mean_dispersion = 1 / denom * x_sum # 2 in numerator here cancels with 0.5 in the next step - # --- Assemble the final CRPS value. 
- crps_value = mae - half_mean_spread # 0.5 already accounted for above + # --- CRPS value := A - 0.5 * D + crps_value = mae - half_mean_dispersion # 0.5 already accounted for above return crps_value From 86c795aad1774744f50c4afad229ddfd727e3d01 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Thu, 29 Jan 2026 17:16:12 +0100 Subject: [PATCH 02/19] WIP before refactor --- torch_crps/analytical_crps.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/torch_crps/analytical_crps.py b/torch_crps/analytical_crps.py index 676eff2..f6024e9 100644 --- a/torch_crps/analytical_crps.py +++ b/torch_crps/analytical_crps.py @@ -209,8 +209,8 @@ def crps_analytical_studentt( standard_t = torch.distributions.StudentT(nu, loc=0, scale=1) # Compute standardized CDF F_ν(z) and PDF f_ν(z). - f_cdf_z = standardized_studentt_cdf_via_scipy(z, nu) - f_z = torch.exp(standard_t.log_prob(z)) + cdf_z = standardized_studentt_cdf_via_scipy(z, nu) + pdf_z = torch.exp(standard_t.log_prob(z)) # Compute the beta function ratio: B(1/2, ν - 1/2) / B(1/2, ν/2)^2 # Using the relationship: B(a,b) = Gamma(a) * Gamma(b) / Gamma(a+b) @@ -236,7 +236,7 @@ def crps_analytical_studentt( # Compute the CRPS for standardized values. crps_standard = ( - z * (2 * f_cdf_z - 1) + 2 * f_z * (nu + z**2) / (nu - 1) - (2 * torch.sqrt(nu) / (nu - 1)) * beta_frac + z * (2 * cdf_z - 1) + 2 * pdf_z * (nu + z**2) / (nu - 1) - (2 * torch.sqrt(nu) / (nu - 1)) * beta_frac ) # Apply location-scale transformation CRPS(F_{ν,μ,σ}, y) = σ * CRPS(F_{ν}, z) with z = (y - μ) / σ. @@ -255,9 +255,11 @@ def scrps_analytical_studentt( $$ \text{SCRPS}(F, y) = \frac{A}{D} + 0.5 \cdot \log(D) $$ where: - - $A = E_F[|X - y|]$ is the Accuracy term. - - $D = E_F[|X - X'|]$ is the Dispersion term. - - $F$ is the Student-T distribution $t(\nu, \mu, \sigma^2)$. + - $F_{\nu, \mu, \sigma^2}$ is the cumulative Student-T distribution, and $F_{\nu}$ is the standardized version. 
+ - $A = E_F[|X - y|]$ is the accuracy term. + - $A = \sigma [ z(2 F_{\nu}(z) - 1) + 2(\nu + z²) / (\nu*B(\nu/2, 1/2)) * F_{\nu+1}(z * \sqrt{(\nu+1)/(\nu+z²)}) ]$ + - $D = E_F[|X - X'|]$ is the dispersion term. + - $D = \frac{ 4\sigma }{ \nu-1 } * ( \frac{ \Gamma( \nu/2 ) }{ \Gamma( (\nu-1)/2) } )^2$ Note: This formula is only valid for degrees of freedom $\nu > 1$. @@ -288,26 +290,27 @@ def scrps_analytical_studentt( ) dispersion = torch.exp(log_dispersion) - # --- 2. Accuracy Term A := E[|X - y|] - # Standardize, and create standard StudentT distributions for CDFs. + # --- Accuracy Term A := E[|X - y|] + # Standardize. z = (y - mu) / sigma - standard_t_nu = StudentT(nu, loc=0, scale=1) - standard_t_nu_plus_1 = StudentT(nu + 1, loc=0, scale=1) - # Compute Beta function term B(ν/2, 1/2) + # Compute Beta function term B(ν/2, 1/2). lgamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device)) log_beta_term = torch.lgamma(nu / 2) + lgamma_half - torch.lgamma((nu + 1) / 2) beta_term = torch.exp(log_beta_term) - # Compute components of the 'A' formula from Bolin & Wallin Appendix A.2 - term_A1 = z * (2 * standard_t_nu.cdf(z) - 1) - + # Compute components for A = σ * [ z(2F_ν(z) - 1) + (2(ν+z²))/(ν*B(ν/2, 1/2)) * F_{ν+1}(z * sqrt( (ν+1)/(ν+z²)) ) ]. + # Just like for the CRPS, this includes a transformation to and back from standardized values. + cdf_nu_z = standardized_studentt_cdf_via_scipy(z, nu) + term_A1 = z * (2 * cdf_nu_z - 1) term_A2_factor = (2 * (nu + z**2)) / (nu * beta_term) term_A2_cdf_arg = z * torch.sqrt((nu + 1) / (nu + z**2)) - term_A2 = term_A2_factor * standard_t_nu_plus_1.cdf(term_A2_cdf_arg) + cdf_nu_plus_1_term = standardized_studentt_cdf_via_scipy(term_A2_cdf_arg, nu + 1) + term_A2 = term_A2_factor * cdf_nu_plus_1_term accuracy = sigma * (term_A1 + term_A2) - # --- 3. 
SCRPS (negatively-oriented) := (A / D) + 0.5 * log(D) scrps = accuracy / dispersion + 0.5 * log_dispersion + return scrps From f0d84a0521371b2918b74fd6f012e87aedc82522 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Thu, 29 Jan 2026 17:34:26 +0100 Subject: [PATCH 03/19] Does the convergence test make sense for the SCRPS? --- tests/test_analytical_crps.py | 30 ++++++++---- tests/test_ensemble_crps.py | 6 +-- tests/test_integral_crps.py | 4 +- tests/test_normalization.py | 4 +- torch_crps/__init__.py | 2 + torch_crps/analytical_crps.py | 87 +++++++++++++++++++++-------------- 6 files changed, 84 insertions(+), 49 deletions(-) diff --git a/tests/test_analytical_crps.py b/tests/test_analytical_crps.py index d5b15f0..df5a2ee 100644 --- a/tests/test_analytical_crps.py +++ b/tests/test_analytical_crps.py @@ -3,10 +3,17 @@ import pytest import torch from torch.distributions import Normal, StudentT +from typing_extensions import Literal from tests.conftest import needs_cuda -from torch_crps import crps_analytical, crps_analytical_normal, crps_analytical_studentt, scrps_analytical_normal -from torch_crps.analytical_crps import scrps_analytical +from torch_crps import ( + crps_analytical, + crps_analytical_normal, + crps_analytical_studentt, + scrps_analytical, + scrps_analytical_normal, + scrps_analytical_studentt, +) @pytest.mark.parametrize( @@ -17,7 +24,7 @@ ], ) @pytest.mark.parametrize("crps_fcn", [crps_analytical_normal, scrps_analytical_normal], ids=["CRPS", "SCRPS"]) -def test_crps_analytical_normal_batched_smoke(use_cuda: bool, crps_fcn: Callable): +def test_analytical_normal_batched_smoke(use_cuda: bool, crps_fcn: Callable[..., torch.Tensor]): """Test that analytical solution works with batched Normal distributions.""" torch.manual_seed(0) @@ -50,7 +57,10 @@ def test_crps_analytical_normal_batched_smoke(use_cuda: bool, crps_fcn: Callable ids=["standard", "shifted_scaled", "neg-mean_large-var"], ) @pytest.mark.parametrize("y", [torch.tensor([-10.0, -1.0, 0.0, 
0.5, 2.0, 5.0])]) -def test_studentt_convergence_to_normal(loc: torch.Tensor, scale: torch.Tensor, y: torch.Tensor): +@pytest.mark.parametrize("crps_fcn_type", ["CRPS", "SCRPS"], ids=["CRPS", "SCRPS"]) +def test_studentt_convergence_to_normal( + loc: torch.Tensor, scale: torch.Tensor, y: torch.Tensor, crps_fcn_type: Literal["CRPS", "SCRPS"] +): """Test that for a very high degrees of freedom, the StudentT CRPS converges to the Normal CRPS. This validates both implementations against each other. """ @@ -60,11 +70,15 @@ def test_studentt_convergence_to_normal(loc: torch.Tensor, scale: torch.Tensor, q_normal = Normal(loc=loc, scale=scale) # Calculate the analytical CRPS for both. - crps_studentt = crps_analytical_studentt(q_studentt, y) - crps_normal = crps_analytical_normal(q_normal, y) + if crps_fcn_type == "CRPS": + score_value_studentt = crps_analytical_studentt(q_studentt, y) + score_value_normal = crps_analytical_normal(q_normal, y) + else: + score_value_studentt = scrps_analytical_studentt(q_studentt, y) + score_value_normal = scrps_analytical_normal(q_normal, y) # Assert that their results are nearly identical. - assert torch.allclose(crps_studentt, crps_normal, atol=2e-3), ( + assert torch.allclose(score_value_studentt, score_value_normal, atol=2e-3), ( "StudentT CRPS with high 'df' should match Normal CRPS." 
) @@ -79,7 +93,7 @@ def test_studentt_convergence_to_normal(loc: torch.Tensor, scale: torch.Tensor, ids=["Normal", "StudentT", "not_supported"], ) @pytest.mark.parametrize("crps_fcn", [crps_analytical, scrps_analytical], ids=["CRPS", "SCRPS"]) -def test_crps_analytical_interface_smoke(q: Any, crps_fcn: Callable): # noqa: ANN401 +def test_analytical_interface_smoke(q: Any, crps_fcn: Callable[..., torch.Tensor]): # noqa: ANN401 """Test if the top-level interface function is working""" y = torch.zeros(3) # can be the same for all tests diff --git a/tests/test_ensemble_crps.py b/tests/test_ensemble_crps.py index 5b0db05..46c11c8 100644 --- a/tests/test_ensemble_crps.py +++ b/tests/test_ensemble_crps.py @@ -22,7 +22,7 @@ pytest.param(True, marks=needs_cuda, id="cuda"), ], ) -def test_crps_ensemble_smoke( +def test_ensemble_smoke( test_case_fixture_name: str, crps_fcn: Callable, biased: bool, use_cuda: bool, request: FixtureRequest ): """Test that naive ensemble method yield.""" @@ -46,7 +46,7 @@ def test_crps_ensemble_smoke( ids=["case_flat_1d", "case_batched_2d", "case_batched_3d"], ) @pytest.mark.parametrize("biased", [True, False], ids=["biased", "unbiased"]) -def test_crps_ensemble_match(batch_shape: tuple[int, ...], biased: bool, dim_ensemble: int = 10): +def test_ensemble_match(batch_shape: tuple[int, ...], biased: bool, dim_ensemble: int = 10): """Test that both implementations of crps_ensemble yield the same result.""" torch.manual_seed(0) @@ -67,7 +67,7 @@ def test_crps_ensemble_match(batch_shape: tuple[int, ...], biased: bool, dim_ens ) -def test_crps_ensemble_invalid_shapes(dim_ensemble: int = 10): +def test_ensemble_invalid_shapes(dim_ensemble: int = 10): """Test that crps_ensemble raises an error for invalid input shapes.""" # Mismatch in the number of batch dimensions. 
x = torch.randn(2, 3, dim_ensemble) diff --git a/tests/test_integral_crps.py b/tests/test_integral_crps.py index 2906ee5..65a197d 100644 --- a/tests/test_integral_crps.py +++ b/tests/test_integral_crps.py @@ -12,7 +12,7 @@ pytest.param(True, marks=needs_cuda, id="cuda"), ], ) -def test_crps_integral_vs_analytical_normal(use_cuda: bool): +def test_integral_vs_analytical_normal(use_cuda: bool): """Test that naive integral method matches the analytical solution for Normal distributions.""" torch.manual_seed(0) @@ -47,7 +47,7 @@ def test_crps_integral_vs_analytical_normal(use_cuda: bool): pytest.param(True, marks=needs_cuda, id="cuda"), ], ) -def test_crps_integral_vs_analytical_studentt(use_cuda: bool): +def test_integral_vs_analytical_studentt(use_cuda: bool): """Test that naive integral method matches the analytical solution for StudentT distributions.""" torch.manual_seed(0) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 209f48a..80b28f1 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -29,7 +29,7 @@ "crps_integral_normalized", ], ) -def test_nomrmalization_wrapper_input_errors(wrapped_crps_fcn: Callable, num_y: int = 3): +def test_normalization_wrapper_input_errors(wrapped_crps_fcn: Callable, num_y: int = 3): """Test if the normalization wrapper handles the underlying function's arguments correctly.""" torch.manual_seed(0) @@ -84,7 +84,7 @@ def test_nomrmalization_wrapper_input_errors(wrapped_crps_fcn: Callable, num_y: ], ) @pytest.mark.parametrize("num_y", [1, 5, 100], ids=["1_obs", "5_obs", "100_obs"]) -def test_nomrmalization_wrapper_output_consistency(wrapped_crps_fcn: Callable, num_y: int): +def test_normalization_wrapper_output_consistency(wrapped_crps_fcn: Callable, num_y: int): """Test if the normalization wrapper results in normalized CRPS values.""" torch.manual_seed(0) diff --git a/torch_crps/__init__.py b/torch_crps/__init__.py index 8f63dce..c50cb68 100644 --- a/torch_crps/__init__.py +++ 
b/torch_crps/__init__.py @@ -4,6 +4,7 @@ crps_analytical_studentt, scrps_analytical, scrps_analytical_normal, + scrps_analytical_studentt, ) from .ensemble_crps import crps_ensemble, crps_ensemble_naive from .integral_crps import crps_integral @@ -29,4 +30,5 @@ "crps_integral_normalized", "scrps_analytical", "scrps_analytical_normal", + "scrps_analytical_studentt", ] diff --git a/torch_crps/analytical_crps.py b/torch_crps/analytical_crps.py index f6024e9..00072ef 100644 --- a/torch_crps/analytical_crps.py +++ b/torch_crps/analytical_crps.py @@ -62,6 +62,46 @@ def scrps_analytical( ) +def _accuracy_normal( + q: Normal, + y: torch.Tensor, +) -> torch.Tensor: + """Compute accuracy term A = E[|X - y|] for a normal distribution. + + Args: + q: A PyTorch Normal distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + Accuracy values for each observation, of shape (num_samples,). + """ + z = (y - q.loc) / q.scale + standard_normal = torch.distributions.Normal(0, 1) + + cdf_z = standard_normal.cdf(z) + pdf_z = torch.exp(standard_normal.log_prob(z)) + + return q.scale * (z * (2 * cdf_z - 1) + 2 * pdf_z) + + +def _dispersion_normal( + q: Normal, + y: torch.Tensor, +) -> torch.Tensor: + """Compute dispersion term D = E[|X - X'|] for a normal distribution. + + Args: + q: A PyTorch Normal distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + Dispersion values for each observation, of shape (num_samples,). + """ + sqrt_pi = torch.sqrt(torch.tensor(torch.pi, device=y.device, dtype=y.dtype)) + + return 2 * q.scale / sqrt_pi + + def crps_analytical_normal( q: Normal, y: torch.Tensor, @@ -79,17 +119,9 @@ def crps_analytical_normal( Returns: CRPS values for each observation, of shape (num_samples,). """ - # Compute standard normal CDF and PDF. 
- z = (y - q.loc) / q.scale # standardize - standard_normal = torch.distributions.Normal(0, 1) - cdf_z = standard_normal.cdf(z) # Φ(z) - pdf_z = torch.exp(standard_normal.log_prob(z)) # φ(z) - - # Analytical CRPS formula. - sqrt_pi = torch.sqrt(torch.tensor(torch.pi, device=z.device, dtype=z.dtype)) - crps = q.scale * (z * (2 * cdf_z - 1) + 2 * pdf_z - 1 / sqrt_pi) - - return crps + accuracy = _accuracy_normal(q, y) + dispersion = _dispersion_normal(q, y) + return accuracy - dispersion / 2 def scrps_analytical_normal( @@ -113,26 +145,14 @@ def scrps_analytical_normal( Returns: SCRPS values for each observation, of shape (num_samples,). """ - # --- Dispersion Term D := E[|X - X'|] = 2σ / √π - sqrt_pi = torch.sqrt(torch.tensor(torch.pi, device=y.device, dtype=y.dtype)) - dispersion = 2 * q.scale / sqrt_pi - - # --- Accuracy Term A := E[|X - y|] - z = (y - q.loc) / q.scale # standardize - standard_normal = torch.distributions.Normal(0, 1) - cdf_z = standard_normal.cdf(z) # Φ(z) - pdf_z = torch.exp(standard_normal.log_prob(z)) # φ(z) - accuracy = q.scale * (z * (2 * cdf_z - 1) + 2 * pdf_z) - - # --- SCRPS (negatively-oriented) := (A / D) + 0.5 * log(D) - scrps = accuracy / dispersion + 0.5 * torch.log(dispersion) - - return scrps + accuracy = _accuracy_normal(q, y) + dispersion = _dispersion_normal(q, y) + return accuracy / dispersion + 0.5 * torch.log(dispersion) def standardized_studentt_cdf_via_scipy( z: torch.Tensor, - df: torch.Tensor | float, + nu: torch.Tensor | float, ) -> torch.Tensor: """Since the `torch.distributions.StudentT` class does not have a `cdf()` method, we resort to scipy which has a stable implementation. @@ -143,7 +163,7 @@ def standardized_studentt_cdf_via_scipy( Args: z: Standardized values at which to evaluate the CDF. - df: Degrees of freedom of the StudentT distribution. + nu: Degrees of freedom of the StudentT distribution. Returns: CDF values of the standardized StudentT distribution at `z`. 
@@ -157,9 +177,9 @@ def standardized_studentt_cdf_via_scipy( ) from e z_np = z.detach().cpu().numpy() - df_np = df.detach().cpu().numpy() if isinstance(df, torch.Tensor) else df + nu_np = nu.detach().cpu().numpy() if isinstance(nu, torch.Tensor) else nu - cdf_z_np = scipy_student_t.cdf(z_np, df=df_np) + cdf_z_np = scipy_student_t.cdf(x=z_np, df=nu_np) return torch.from_numpy(cdf_z_np).to(device=z.device, dtype=z.dtype) @@ -215,7 +235,7 @@ def crps_analytical_studentt( # Compute the beta function ratio: B(1/2, ν - 1/2) / B(1/2, ν/2)^2 # Using the relationship: B(a,b) = Gamma(a) * Gamma(b) / Gamma(a+b) # B(1/2, ν - 1/2) / B(1/2, ν/2)^2 = ( Gamma(1/2) * Gamma(ν-1/2) / Gamma(ν) ) / - # ( Gamma(1/2) * Gamma(ν/2) / Gamma(ν/2 + 1/2) )^2 + # ( Gamma(1/2) * Gamma(ν/2) / Gamma(ν/2 + 1/2) )^2 # Simplifying to Gamma(ν - 1/2) Gamma(ν/2 + 1/2)^2 / ( Gamma(ν)Gamma(ν/2)^2 ) # For numerical stability, we compute in log space. log_gamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device)) @@ -295,7 +315,7 @@ def scrps_analytical_studentt( z = (y - mu) / sigma # Compute Beta function term B(ν/2, 1/2). - lgamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device)) + lgamma_half = torch.lgamma(torch.tensor(0.5, dtype=dtype, device=device)) log_beta_term = torch.lgamma(nu / 2) + lgamma_half - torch.lgamma((nu + 1) / 2) beta_term = torch.exp(log_beta_term) @@ -307,10 +327,9 @@ def scrps_analytical_studentt( term_A2_cdf_arg = z * torch.sqrt((nu + 1) / (nu + z**2)) cdf_nu_plus_1_term = standardized_studentt_cdf_via_scipy(term_A2_cdf_arg, nu + 1) term_A2 = term_A2_factor * cdf_nu_plus_1_term - accuracy = sigma * (term_A1 + term_A2) - # --- 3. 
SCRPS (negatively-oriented) := (A / D) + 0.5 * log(D) + # --- SCRPS (negatively-oriented) := (A / D) + 0.5 * log(D) scrps = accuracy / dispersion + 0.5 * log_dispersion return scrps From 701ee1bb14777b876989db664d5bbda3c7debd8a Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Thu, 29 Jan 2026 19:21:11 +0100 Subject: [PATCH 04/19] WIP helper functions --- tests/test_analytical_crps.py | 58 +++++++----- torch_crps/analytical_crps.py | 172 +++++++++++++++++++++++----------- 2 files changed, 151 insertions(+), 79 deletions(-) diff --git a/tests/test_analytical_crps.py b/tests/test_analytical_crps.py index df5a2ee..1aebe8c 100644 --- a/tests/test_analytical_crps.py +++ b/tests/test_analytical_crps.py @@ -3,16 +3,13 @@ import pytest import torch from torch.distributions import Normal, StudentT -from typing_extensions import Literal from tests.conftest import needs_cuda from torch_crps import ( crps_analytical, crps_analytical_normal, - crps_analytical_studentt, scrps_analytical, scrps_analytical_normal, - scrps_analytical_studentt, ) @@ -51,35 +48,48 @@ def test_analytical_normal_batched_smoke(use_cuda: bool, crps_fcn: Callable[..., "loc, scale", [ (torch.tensor(0.0), torch.tensor(1.0)), - (torch.tensor(2.0), torch.tensor(0.5)), - (torch.tensor(-5.0), torch.tensor(10.0)), + (torch.tensor(-1.0), torch.tensor(0.5)), + (torch.tensor(1.0), torch.tensor(0.5)), + (torch.tensor(10.0), torch.tensor(20.0)), + (torch.tensor(-10.0), torch.tensor(20.0)), + (torch.tensor(100.0), torch.tensor(5.0)), + (torch.tensor(-100.0), torch.tensor(5.0)), + ], + ids=[ + "standard", + "small-neg-mean_small-var", + "small-pos-mean_small-var", + "pos-mean_large-var", + "neg-mean_large-var", + "large-mean_medium-var", + "large-neg-mean_medium-var", ], - ids=["standard", "shifted_scaled", "neg-mean_large-var"], ) -@pytest.mark.parametrize("y", [torch.tensor([-10.0, -1.0, 0.0, 0.5, 2.0, 5.0])]) -@pytest.mark.parametrize("crps_fcn_type", ["CRPS", "SCRPS"], ids=["CRPS", "SCRPS"]) -def 
test_studentt_convergence_to_normal( - loc: torch.Tensor, scale: torch.Tensor, y: torch.Tensor, crps_fcn_type: Literal["CRPS", "SCRPS"] -): - """Test that for a very high degrees of freedom, the StudentT CRPS converges to the Normal CRPS. - This validates both implementations against each other. +@pytest.mark.parametrize("y", [torch.tensor([-95.0, -80.0, -1.0, 0.0, 0.5, 2.0, 5.0, 50.0])]) +def test_studentt_convergence_to_normal(loc: torch.Tensor, scale: torch.Tensor, y: torch.Tensor, atol: float = 3e-3): + """Test that for a high degrees of freedom, the StudentT score converges to the Normal score + when their standard deviations are matched. + + Note: + This test only works for the CRPS. For the SCRPS, the differences are too big. """ - # Create the two distributions with identical parameters. + # Create the StudentT distribution with a high degree of freedom. high_df = torch.tensor(1000.0) q_studentt = StudentT(df=high_df, loc=loc, scale=scale) - q_normal = Normal(loc=loc, scale=scale) - # Calculate the analytical CRPS for both. - if crps_fcn_type == "CRPS": - score_value_studentt = crps_analytical_studentt(q_studentt, y) - score_value_normal = crps_analytical_normal(q_normal, y) - else: - score_value_studentt = scrps_analytical_studentt(q_studentt, y) - score_value_normal = scrps_analytical_normal(q_normal, y) + # Calculate the standard deviation of the StudentT distribution. The variance is (df / (df - 2)) * scale^2 + student_t_std_dev = scale * torch.sqrt(high_df / (high_df - 2)) + + # Create the Normal distribution with matching standard deviation. + q_normal = Normal(loc=loc, scale=student_t_std_dev) + + # Calculate the analytical scores for both. + score_value_studentt = crps_analytical(q_studentt, y) + score_value_normal = crps_analytical(q_normal, y) # Assert that their results are nearly identical. - assert torch.allclose(score_value_studentt, score_value_normal, atol=2e-3), ( - "StudentT CRPS with high 'df' should match Normal CRPS." 
+ assert torch.allclose(score_value_studentt, score_value_normal, atol=atol), ( + f"StudentT CRPS with high 'df' should match Normal CRPS with atol={atol}." ) diff --git a/torch_crps/analytical_crps.py b/torch_crps/analytical_crps.py index 00072ef..4b911ff 100644 --- a/torch_crps/analytical_crps.py +++ b/torch_crps/analytical_crps.py @@ -121,7 +121,9 @@ def crps_analytical_normal( """ accuracy = _accuracy_normal(q, y) dispersion = _dispersion_normal(q, y) - return accuracy - dispersion / 2 + + crps = accuracy - dispersion / 2 + return crps def scrps_analytical_normal( @@ -147,7 +149,9 @@ def scrps_analytical_normal( """ accuracy = _accuracy_normal(q, y) dispersion = _dispersion_normal(q, y) - return accuracy / dispersion + 0.5 * torch.log(dispersion) + + scrps = accuracy / dispersion + 0.5 * torch.log(dispersion) + return scrps def standardized_studentt_cdf_via_scipy( @@ -184,40 +188,76 @@ def standardized_studentt_cdf_via_scipy( return torch.from_numpy(cdf_z_np).to(device=z.device, dtype=z.dtype) -def crps_analytical_studentt( +def _dispersion_studentt( q: StudentT, - y: torch.Tensor, ) -> torch.Tensor: - r"""Compute the (negatively-oriented) CRPS in closed-form assuming a StudentT distribution. + """Computes the dispersion term D = E[|Y - Y'|] for the Student-T distribution. - This implements the closed-form formula from Jordan et al. (2019), see Appendix A.2. - - For the standardized StudentT distribution: - - $$ \text{CRPS}(F_\nu, z) = z(2F_\nu(z) - 1) + 2f_\nu(z)\frac{\nu + z^2}{\nu - 1} - - \frac{2\sqrt{\nu}}{\nu - 1} \frac{B(\frac{1}{2}, \nu - \frac{1}{2})}{B(\frac{1}{2}, \frac{\nu}{2})^2} $$ + Args: + q: A PyTorch StudentT distribution object, typically a model's output distribution. - where $z$ is the standardized value, $F_\nu$ is the CDF, $f_\nu$ is the PDF of the standard StudentT - distribution, $\nu$ is the degrees of freedom, and $B$ is the beta function. + Returns: + Dispersion values for each observation, of shape (num_samples,). 
+ """ + nu, sigma = q.df, q.scale - For the location-scale transformed distribution: + # D = (4σ / (ν-1)) * (Γ(ν/2) / Γ((ν-1)/2))² + # We compute in log space for numerical stability. + log_4 = torch.log(torch.tensor(4.0, dtype=nu.dtype, device=nu.device)) + log_dispersion = ( + log_4 + torch.log(sigma) - torch.log(nu - 1) + 2 * (torch.lgamma(nu / 2) - torch.lgamma((nu - 1) / 2)) + ) - $$ \text{CRPS}(F_{\nu,\mu,\sigma}, y) = \sigma \cdot \text{CRPS}\left(F_\nu, \frac{y-\mu}{\sigma}\right) $$ + return torch.exp(log_dispersion) - where $\mu$ is the location parameter, $\sigma$ is the scale parameter, and $y$ is the observation. - Note: - This formula is only valid for degrees of freedom $\nu > 1$. +def _accuracy_studentt_bollin_wallin(q: StudentT, y: torch.Tensor) -> torch.Tensor: + """Computes the accuracy term A = E[|Y - y|] for the Student-T distribution. See Also: - Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019; Appendix A.2. + Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019. Args: q: A PyTorch StudentT distribution object, typically a model's output distribution. y: Observed values, of shape (num_samples,). Returns: - CRPS values for each observation, of shape (num_samples,). + Accuracy values for each observation, of shape (num_samples,). + """ + nu, mu, sigma = q.df, q.loc, q.scale + if torch.any(nu <= 1): + raise ValueError("StudentT accuracy requires degrees of freedom > 1") + + # Standardize the observed values. + z = (y - mu) / sigma + + # Compute Beta function term B(ν/2, 1/2). 
+ lgamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device)) + log_beta_term = torch.lgamma(nu / 2) + lgamma_half - torch.lgamma((nu + 1) / 2) + beta_term = torch.exp(log_beta_term) + + # z(2 F_ν(z) - 1) + cdf_nu_z = standardized_studentt_cdf_via_scipy(z, nu) + term_1 = z * (2 * cdf_nu_z - 1) + + # 2(ν+z²) / ( ν*B(ν/2, 1/2) ) * F_{ν+1}(z * sqrt( (ν+1)/(ν+z²)) ) + term_2_factor = (2 * (nu + z**2)) / (nu * beta_term) + term_2_cdf_arg = z * torch.sqrt((nu + 1) / (nu + z**2)) + cdf_nu_plus_1_term = standardized_studentt_cdf_via_scipy(term_2_cdf_arg, nu + 1) + term_2 = term_2_factor * cdf_nu_plus_1_term + + # Just like for the CRPS, this includes a transformation to and back from standardized values. + # A = σ * [ z(2 F_ν(z) - 1) + 2(ν+z²) / ( ν*B(ν/2, 1/2) ) * F_{ν+1}(z * sqrt( (ν+1)/(ν+z²)) ) ]. + return sigma * (term_1 + term_2) + + +def _crps_analytical_studentt_jordan( + q: StudentT, + y: torch.Tensor, +) -> torch.Tensor: + r"""Compute the (negatively-oriented) CRPS in closed-form assuming a StudentT distribution. + + This is a helper function such that we can use it to consistently compute the accuracy term for the SCRPS. """ # Extract degrees of freedom ν, location μ, and scale σ. nu, mu, sigma = q.df, q.loc, q.scale @@ -265,6 +305,59 @@ def crps_analytical_studentt( return crps +def _accuracy_studentt_jordan(q: StudentT, y: torch.Tensor) -> torch.Tensor: + """Compute the consistent accuracy term A by deriving it from the CRPS identity `A = CRPS + 0.5 * D`. + + Args: + q: A PyTorch StudentT distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + Accuracy values for each observation, of shape (num_samples,). 
+ """ + crps = _crps_analytical_studentt_jordan(q, y) + dispersion = _dispersion_studentt(q) + return crps + 0.5 * dispersion + + +def crps_analytical_studentt( + q: StudentT, + y: torch.Tensor, +) -> torch.Tensor: + r"""Compute the (negatively-oriented) CRPS in closed-form assuming a StudentT distribution. + + This implements the closed-form formula from Jordan et al. (2019), see Appendix A.2. + + For the standardized StudentT distribution: + + $$ \text{CRPS}(F_\nu, z) = z(2F_\nu(z) - 1) + 2f_\nu(z)\frac{\nu + z^2}{\nu - 1} + - \frac{2\sqrt{\nu}}{\nu - 1} \frac{B(\frac{1}{2}, \nu - \frac{1}{2})}{B(\frac{1}{2}, \frac{\nu}{2})^2} $$ + + where $z$ is the standardized value, $F_\nu$ is the CDF, $f_\nu$ is the PDF of the standard StudentT + distribution, $\nu$ is the degrees of freedom, and $B$ is the beta function. + + For the location-scale transformed distribution: + + $$ \text{CRPS}(F_{\nu,\mu,\sigma}, y) = \sigma \cdot \text{CRPS}\left(F_\nu, \frac{y-\mu}{\sigma}\right) $$ + + where $\mu$ is the location parameter, $\sigma$ is the scale parameter, and $y$ is the observation. + + Note: + This formula is only valid for degrees of freedom $\nu > 1$. + + See Also: + Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019; Appendix A.2. + + Args: + q: A PyTorch StudentT distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + CRPS values for each observation, of shape (num_samples,). + """ + return _crps_analytical_studentt_jordan(q, y) + + def scrps_analytical_studentt( q: StudentT, y: torch.Tensor, @@ -294,42 +387,11 @@ def scrps_analytical_studentt( Returns: SCRPS values for each observation, of shape (num_samples,). """ - # Extract degrees of freedom ν, location μ, and scale σ. 
- nu, mu, sigma = q.df, q.loc, q.scale - if torch.any(nu <= 1): + if torch.any(q.df <= 1): raise ValueError("StudentT SCRPS requires degrees of freedom > 1") - # Use the device of y for creating new (intermediate) tensors. - device, dtype = y.device, y.dtype - - # --- Dispersion Term D := E[|X - X'|] = (4σ / (ν-1)) * (Γ(ν/2) / Γ((ν-1)/2))² - # We compute in log space for numerical stability. - log_4 = torch.log(torch.tensor(4.0, dtype=dtype, device=device)) - log_dispersion = ( - log_4 + torch.log(sigma) - torch.log(nu - 1) + 2 * (torch.lgamma(nu / 2) - torch.lgamma((nu - 1) / 2)) - ) - dispersion = torch.exp(log_dispersion) - - # --- Accuracy Term A := E[|X - y|] - # Standardize. - z = (y - mu) / sigma - - # Compute Beta function term B(ν/2, 1/2). - lgamma_half = torch.lgamma(torch.tensor(0.5, dtype=dtype, device=device)) - log_beta_term = torch.lgamma(nu / 2) + lgamma_half - torch.lgamma((nu + 1) / 2) - beta_term = torch.exp(log_beta_term) - - # Compute components for A = σ * [ z(2F_ν(z) - 1) + (2(ν+z²))/(ν*B(ν/2, 1/2)) * F_{ν+1}(z * sqrt( (ν+1)/(ν+z²)) ) ]. - # Just like for the CRPS, this includes a transformation to and back from standardized values. 
- cdf_nu_z = standardized_studentt_cdf_via_scipy(z, nu) - term_A1 = z * (2 * cdf_nu_z - 1) - term_A2_factor = (2 * (nu + z**2)) / (nu * beta_term) - term_A2_cdf_arg = z * torch.sqrt((nu + 1) / (nu + z**2)) - cdf_nu_plus_1_term = standardized_studentt_cdf_via_scipy(term_A2_cdf_arg, nu + 1) - term_A2 = term_A2_factor * cdf_nu_plus_1_term - accuracy = sigma * (term_A1 + term_A2) - - # --- SCRPS (negatively-oriented) := (A / D) + 0.5 * log(D) - scrps = accuracy / dispersion + 0.5 * log_dispersion + accuracy = _accuracy_studentt_jordan(q, y) + dispersion = _dispersion_studentt(q) + scrps = accuracy / dispersion + 0.5 * torch.log(dispersion) return scrps From 6ae040c3b292ac4660f2ee7e6f191c4161eb49a2 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Thu, 29 Jan 2026 21:01:52 +0100 Subject: [PATCH 05/19] WIP copilot running in circles --- tests/test_analytical_crps.py | 51 +++++++++++--- torch_crps/analytical_crps.py | 126 ++++++++++++++++++++++++++++++---- 2 files changed, 156 insertions(+), 21 deletions(-) diff --git a/tests/test_analytical_crps.py b/tests/test_analytical_crps.py index 1aebe8c..20445ae 100644 --- a/tests/test_analytical_crps.py +++ b/tests/test_analytical_crps.py @@ -3,14 +3,17 @@ import pytest import torch from torch.distributions import Normal, StudentT +from typing_extensions import Literal from tests.conftest import needs_cuda from torch_crps import ( crps_analytical, crps_analytical_normal, + crps_analytical_studentt, scrps_analytical, scrps_analytical_normal, ) +from torch_crps.analytical_crps import _crps_analytical_studentt_jordan @pytest.mark.parametrize( @@ -65,13 +68,13 @@ def test_analytical_normal_batched_smoke(use_cuda: bool, crps_fcn: Callable[..., "large-neg-mean_medium-var", ], ) -@pytest.mark.parametrize("y", [torch.tensor([-95.0, -80.0, -1.0, 0.0, 0.5, 2.0, 5.0, 50.0])]) -def test_studentt_convergence_to_normal(loc: torch.Tensor, scale: torch.Tensor, y: torch.Tensor, atol: float = 3e-3): +@pytest.mark.parametrize("y", 
[torch.tensor([-10.0, -1.0, 0.0, 0.5, 2.0, 5.0, 50])]) +@pytest.mark.parametrize("crps_fcn_type", ["CRPS", "SCRPS"], ids=["CRPS", "SCRPS"]) +def test_studentt_convergence_to_normal( + loc: torch.Tensor, scale: torch.Tensor, y: torch.Tensor, crps_fcn_type: Literal["CRPS", "SCRPS"] +): """Test that for a high degrees of freedom, the StudentT score converges to the Normal score when their standard deviations are matched. - - Note: - This test only works for the CRPS. For the SCRPS, the differences are too big. """ # Create the StudentT distribution with a high degree of freedom. high_df = torch.tensor(1000.0) @@ -84,12 +87,17 @@ def test_studentt_convergence_to_normal(loc: torch.Tensor, scale: torch.Tensor, q_normal = Normal(loc=loc, scale=student_t_std_dev) # Calculate the analytical scores for both. - score_value_studentt = crps_analytical(q_studentt, y) - score_value_normal = crps_analytical(q_normal, y) + if crps_fcn_type == "CRPS": + score_value_studentt = crps_analytical(q_studentt, y) + score_value_normal = crps_analytical(q_normal, y) + else: + score_value_studentt = scrps_analytical(q_studentt, y) + score_value_normal = scrps_analytical(q_normal, y) # Assert that their results are nearly identical. - assert torch.allclose(score_value_studentt, score_value_normal, atol=atol), ( - f"StudentT CRPS with high 'df' should match Normal CRPS with atol={atol}." + # The tolerance can be quite tight now. + assert torch.allclose(score_value_studentt, score_value_normal, atol=3e-3), ( + f"StudentT {crps_fcn_type} with high 'df' should match Normal {crps_fcn_type}." ) @@ -116,3 +124,28 @@ def test_analytical_interface_smoke(q: Any, crps_fcn: Callable[..., torch.Tensor # Not supported, should raise an error. 
with pytest.raises(NotImplementedError): crps_fcn(q, y) + + +def test_analytical_studentt_consistency(): + """Test if the two ways to compute the CRPS for StudentT distributions give the same result: + + - old method: `_crps_analytical_studentt_jordan` + - new method: `_accuracy_studentt_jordan` and `_dispersion_studentt_jordan` packaged in `crps_analytical_studentt` + """ + torch.manual_seed(0) + + # Create a StudentT distribution. + df = torch.tensor([3.0, 5.0, 10.0]) + loc = torch.tensor([0.0, 1.0, -1.0]) + scale = torch.tensor([1.0, 2.0, 0.5]) + studentt_dist = torch.distributions.StudentT(df=df, loc=loc, scale=scale) + + # Define observed values. + y = torch.tensor([0.5, 2.0, -0.5]) + + # Compute CRPS values. + crps_old = _crps_analytical_studentt_jordan(studentt_dist, y) + crps_new = crps_analytical_studentt(studentt_dist, y) + + # Assert that both methods give the same result. + assert torch.allclose(crps_old, crps_new, atol=1e-4), "CRPS values from both methods should match." diff --git a/torch_crps/analytical_crps.py b/torch_crps/analytical_crps.py index 4b911ff..29b5634 100644 --- a/torch_crps/analytical_crps.py +++ b/torch_crps/analytical_crps.py @@ -102,6 +102,35 @@ def _dispersion_normal( return 2 * q.scale / sqrt_pi +def _crps_analytical_normal_gneiting( + q: Normal, + y: torch.Tensor, +) -> torch.Tensor: + """Compute the analytical CRPS assuming a normal distribution. + + See Also: + Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007 + Equation (5) for the analytical formula for CRPS of Normal distribution. + + Args: + q: A PyTorch Normal distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + CRPS values for each observation, of shape (num_samples,). + """ + # Compute standard normal CDF and PDF. 
+ z = (y - q.loc) / q.scale # standardize + standard_normal = torch.distributions.Normal(0, 1) + phi_z = standard_normal.cdf(z) # Φ(z) + pdf_z = torch.exp(standard_normal.log_prob(z)) # φ(z) + + # Analytical CRPS formula. + crps = q.scale * (z * (2 * phi_z - 1) + 2 * pdf_z - 1 / torch.sqrt(torch.tensor(torch.pi))) + + return crps + + def crps_analytical_normal( q: Normal, y: torch.Tensor, @@ -188,11 +217,47 @@ def standardized_studentt_cdf_via_scipy( return torch.from_numpy(cdf_z_np).to(device=z.device, dtype=z.dtype) -def _dispersion_studentt( +def _dispersion_studentt_jordan( + q: StudentT, +) -> torch.Tensor: + r"""Computes the dispersion term D = E[|Y - Y'|] for the Student-T distribution. + + $$ + D = 2 \sigma \left( \frac{2\sqrt{\nu}}{\nu-1} \frac{\Gamma(\nu/2)}{\Gamma((\nu-1)/2)\Gamma(1/2)} \right) + $$ + + See Also: + Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019. + + Args: + q: A PyTorch StudentT distribution object, typically a model's output distribution. + + Returns: + Dispersion values for each observation, of shape (num_samples,). + """ + nu, sigma = q.df, q.scale + + # D = 2σ * (2 * sqrt(v) / (v - 1)) * Γ(v/2) / (Γ((v-1)/2) * Γ(0.5)) + # We compute in log space for numerical stability (prevent under- or overflow). + lgamma_nu_half = torch.lgamma(nu / 2) + lgamma_nu_minus_1_half = torch.lgamma((nu - 1) / 2) + # lgamma_half = torch.log(torch.sqrt(torch.tensor(torch.pi, device=nu.device, dtype=nu.dtype))) # Γ(0.5) = sqrt(π) + # TODO + lgamma_half = torch.lgamma(torch.tensor(0.5, device=nu.device)) + gamma_term = torch.exp(lgamma_nu_half - lgamma_nu_minus_1_half - lgamma_half) + + dispersion = 2 * sigma * (2 * torch.sqrt(nu) / (nu - 1)) * gamma_term + return dispersion + + +def _dispersion_studentt_bollin( q: StudentT, ) -> torch.Tensor: """Computes the dispersion term D = E[|Y - Y'|] for the Student-T distribution. + See Also: + Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019. 
+ Args: q: A PyTorch StudentT distribution object, typically a model's output distribution. @@ -202,7 +267,7 @@ def _dispersion_studentt( nu, sigma = q.df, q.scale # D = (4σ / (ν-1)) * (Γ(ν/2) / Γ((ν-1)/2))² - # We compute in log space for numerical stability. + # We compute in log space for numerical stability (prevent under- or overflow). log_4 = torch.log(torch.tensor(4.0, dtype=nu.dtype, device=nu.device)) log_dispersion = ( log_4 + torch.log(sigma) - torch.log(nu - 1) + 2 * (torch.lgamma(nu / 2) - torch.lgamma((nu - 1) / 2)) @@ -211,7 +276,41 @@ def _dispersion_studentt( return torch.exp(log_dispersion) -def _accuracy_studentt_bollin_wallin(q: StudentT, y: torch.Tensor) -> torch.Tensor: +def _accuracy_studentt_jordan(q: StudentT, y: torch.Tensor) -> torch.Tensor: + r"""Computes the accuracy term A = E[|Y - y|] for the Student-T distribution. + + $$ + A = \sigma \left[ z(2F_{t,\nu}(z) - 1) + 2 \frac{\nu+z^2}{\nu-1} f_{t,\nu}(z) \right] + $$ + + See Also: + Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019. + + Args: + q: A PyTorch StudentT distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + Accuracy values for each observation, of shape (num_samples,). + """ + nu, mu, sigma = q.df, q.loc, q.scale + + # Standardize, and create standard StudentT distribution for CDF and PDF. + z = (y - mu) / sigma + standard_t = StudentT(nu, loc=torch.zeros_like(mu), scale=torch.ones_like(sigma)) + + # Compute standardized CDF F_ν(z) and PDF f_ν(z). 
+ cdf_z = standardized_studentt_cdf_via_scipy(z, nu) + pdf_z = torch.exp(standard_t.log_prob(z)) + + # A = sigma * [z * (2*F(z) - 1) + 2*f(z) * (v + z^2) / (v-1) ] + accuracy_unscaled = z * (2 * cdf_z - 1) + 2 * pdf_z * (nu + z**2) / (nu - 1) + + accuracy = sigma * accuracy_unscaled + return accuracy + + +def _accuracy_studentt_bollin(q: StudentT, y: torch.Tensor) -> torch.Tensor: """Computes the accuracy term A = E[|Y - y|] for the Student-T distribution. See Also: @@ -225,8 +324,6 @@ def _accuracy_studentt_bollin_wallin(q: StudentT, y: torch.Tensor) -> torch.Tens Accuracy values for each observation, of shape (num_samples,). """ nu, mu, sigma = q.df, q.loc, q.scale - if torch.any(nu <= 1): - raise ValueError("StudentT accuracy requires degrees of freedom > 1") # Standardize the observed values. z = (y - mu) / sigma @@ -304,8 +401,6 @@ def _crps_analytical_studentt_jordan( return crps - -def _accuracy_studentt_jordan(q: StudentT, y: torch.Tensor) -> torch.Tensor: """Compute the consistent accuracy term A by deriving it from the CRPS identity `A = CRPS + 0.5 * D`. Args: @@ -316,7 +411,7 @@ def _accuracy_studentt_jordan(q: StudentT, y: torch.Tensor) -> torch.Tensor: Accuracy values for each observation, of shape (num_samples,). """ crps = _crps_analytical_studentt_jordan(q, y) - dispersion = _dispersion_studentt(q) + dispersion = _dispersion_studentt_jordan(q) return crps + 0.5 * dispersion @@ -346,7 +441,7 @@ def crps_analytical_studentt( This formula is only valid for degrees of freedom $\nu > 1$. See Also: - Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019; Appendix A.2. + Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019. Args: q: A PyTorch StudentT distribution object, typically a model's output distribution. @@ -355,7 +450,14 @@ def crps_analytical_studentt( Returns: CRPS values for each observation, of shape (num_samples,). 
""" - return _crps_analytical_studentt_jordan(q, y) + if torch.any(q.df <= 1): + raise ValueError("StudentT SCRPS requires degrees of freedom > 1") + + accuracy = _accuracy_studentt_jordan(q, y) + dispersion = _dispersion_studentt_jordan(q) + + crps = accuracy - dispersion / 2 + return crps def scrps_analytical_studentt( @@ -391,7 +493,7 @@ def scrps_analytical_studentt( raise ValueError("StudentT SCRPS requires degrees of freedom > 1") accuracy = _accuracy_studentt_jordan(q, y) - dispersion = _dispersion_studentt(q) + dispersion = _dispersion_studentt_jordan(q) - scrps = accuracy / dispersion + 0.5 * torch.log(dispersion) + scrps = accuracy / dispersion + torch.log(dispersion) / 2 return scrps From e3491774d8914a5d4df27ccf61848a8151692c01 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Thu, 29 Jan 2026 21:08:32 +0100 Subject: [PATCH 06/19] Promising direction --- torch_crps/analytical_crps.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/torch_crps/analytical_crps.py b/torch_crps/analytical_crps.py index 29b5634..deee90a 100644 --- a/torch_crps/analytical_crps.py +++ b/torch_crps/analytical_crps.py @@ -237,16 +237,23 @@ def _dispersion_studentt_jordan( """ nu, sigma = q.df, q.scale - # D = 2σ * (2 * sqrt(v) / (v - 1)) * Γ(v/2) / (Γ((v-1)/2) * Γ(0.5)) - # We compute in log space for numerical stability (prevent under- or overflow). 
- lgamma_nu_half = torch.lgamma(nu / 2) - lgamma_nu_minus_1_half = torch.lgamma((nu - 1) / 2) - # lgamma_half = torch.log(torch.sqrt(torch.tensor(torch.pi, device=nu.device, dtype=nu.dtype))) # Γ(0.5) = sqrt(π) - # TODO - lgamma_half = torch.lgamma(torch.tensor(0.5, device=nu.device)) - gamma_term = torch.exp(lgamma_nu_half - lgamma_nu_minus_1_half - lgamma_half) - - dispersion = 2 * sigma * (2 * torch.sqrt(nu) / (nu - 1)) * gamma_term + # D = 2σ * 2 * torch.sqrt(v) / (v - 1) * beta_frac, + # where beta_frac = B(1/2, v - 1/2) / B(1/2, v/2)^2 + log_gamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device)) + log_gamma_df_minus_half = torch.lgamma(nu - 0.5) + log_gamma_df_half = torch.lgamma(nu / 2) + log_gamma_df_half_plus_half = torch.lgamma(nu / 2 + 0.5) + + log_beta_ratio = ( + log_gamma_half + + log_gamma_df_minus_half + - torch.lgamma(nu) + - 2 * (log_gamma_half + log_gamma_df_half - log_gamma_df_half_plus_half) + ) + beta_frac = torch.exp(log_beta_ratio) + + dispersion_constant = (2 * torch.sqrt(nu) / (nu - 1)) * beta_frac + dispersion = sigma * 2 * dispersion_constant return dispersion From cb4bfa85af00e2b423fd17732b99be41a78eac21 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Thu, 29 Jan 2026 21:23:21 +0100 Subject: [PATCH 07/19] Tests running --- tests/conftest.py | 68 ++++++++++++++ tests/test_analytical_crps.py | 8 +- torch_crps/analytical_crps.py | 163 ++++------------------------------ 3 files changed, 90 insertions(+), 149 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 633a7cd..42387a8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,9 @@ import pytest import torch +from torch.distributions import StudentT + +from torch_crps.analytical_crps import standardized_studentt_cdf_via_scipy results_dir = Path(__file__).parent / "results" results_dir.mkdir(parents=True, exist_ok=True) @@ -49,3 +52,68 @@ def case_batched_3d(): "y": torch.randn(2, 3) * 10 + 50, "expected_shape": torch.Size([2, 3]), 
} + + +def _crps_analytical_studentt_jordan( + q: StudentT, + y: torch.Tensor, +) -> torch.Tensor: + r"""Compute the (negatively-oriented) CRPS in closed-form assuming a StudentT distribution. + + This is the previous implementation of the analytical CRPS for StudentT distributions.. It is provided here for + testing and comparison purposes. + + See Also: + Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019. + + Args: + q: A PyTorch StudentT distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + CRPS values for each observation, of shape (num_samples,). + """ + # Extract degrees of freedom ν, location μ, and scale σ. + nu, mu, sigma = q.df, q.loc, q.scale + if torch.any(nu <= 1): + raise ValueError("StudentT CRPS requires degrees of freedom > 1") + + # Standardize, and create standard StudentT distribution for CDF and PDF. + z = (y - mu) / sigma + standard_t = torch.distributions.StudentT(nu, loc=0, scale=1) + + # Compute standardized CDF F_ν(z) and PDF f_ν(z). + cdf_z = standardized_studentt_cdf_via_scipy(z, nu) + pdf_z = torch.exp(standard_t.log_prob(z)) + + # Compute the beta function ratio: B(1/2, ν - 1/2) / B(1/2, ν/2)^2 + # Using the relationship: B(a,b) = Gamma(a) * Gamma(b) / Gamma(a+b) + # B(1/2, ν - 1/2) / B(1/2, ν/2)^2 = ( Gamma(1/2) * Gamma(ν-1/2) / Gamma(ν) ) / + # ( Gamma(1/2) * Gamma(ν/2) / Gamma(ν/2 + 1/2) )^2 + # Simplifying to Gamma(ν - 1/2) Gamma(ν/2 + 1/2)^2 / ( Gamma(ν)Gamma(ν/2)^2 ) + # For numerical stability, we compute in log space. 
+ log_gamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device)) + log_gamma_df_minus_half = torch.lgamma(nu - 0.5) + log_gamma_df_half = torch.lgamma(nu / 2) + log_gamma_df_half_plus_half = torch.lgamma(nu / 2 + 0.5) + + # log[B(1/2, ν-1/2)] = log Gamma(1/2) + log Gamma(ν-1/2) - log Gamma(ν) + # log[B(1/2, ν/2)] = log Gamma(1/2) + log Gamma(ν/2) - log Gamma(ν/2 + 1/2) + # log[B(1/2, ν-1/2) / B(1/2, ν/2)^2] = log B(1/2, ν-1/2) - 2*log B(1/2, ν/2) + log_beta_ratio = ( + log_gamma_half + + log_gamma_df_minus_half + - torch.lgamma(nu) + - 2 * (log_gamma_half + log_gamma_df_half - log_gamma_df_half_plus_half) + ) + beta_frac = torch.exp(log_beta_ratio) + + # Compute the CRPS for standardized values. + crps_standard = ( + z * (2 * cdf_z - 1) + 2 * pdf_z * (nu + z**2) / (nu - 1) - (2 * torch.sqrt(nu) / (nu - 1)) * beta_frac + ) + + # Apply location-scale transformation CRPS(F_{ν,μ,σ}, y) = σ * CRPS(F_{ν}, z) with z = (y - μ) / σ. + crps = sigma * crps_standard + + return crps diff --git a/tests/test_analytical_crps.py b/tests/test_analytical_crps.py index 20445ae..351cd3f 100644 --- a/tests/test_analytical_crps.py +++ b/tests/test_analytical_crps.py @@ -5,7 +5,7 @@ from torch.distributions import Normal, StudentT from typing_extensions import Literal -from tests.conftest import needs_cuda +from tests.conftest import _crps_analytical_studentt_jordan, needs_cuda from torch_crps import ( crps_analytical, crps_analytical_normal, @@ -13,7 +13,6 @@ scrps_analytical, scrps_analytical_normal, ) -from torch_crps.analytical_crps import _crps_analytical_studentt_jordan @pytest.mark.parametrize( @@ -68,7 +67,7 @@ def test_analytical_normal_batched_smoke(use_cuda: bool, crps_fcn: Callable[..., "large-neg-mean_medium-var", ], ) -@pytest.mark.parametrize("y", [torch.tensor([-10.0, -1.0, 0.0, 0.5, 2.0, 5.0, 50])]) +@pytest.mark.parametrize("y", [torch.tensor([-100.0, -10.0, -1.0, 0.0, 0.5, 2.0, 5.0, 50.0])]) @pytest.mark.parametrize("crps_fcn_type", ["CRPS", 
"SCRPS"], ids=["CRPS", "SCRPS"]) def test_studentt_convergence_to_normal( loc: torch.Tensor, scale: torch.Tensor, y: torch.Tensor, crps_fcn_type: Literal["CRPS", "SCRPS"] @@ -96,7 +95,8 @@ def test_studentt_convergence_to_normal( # Assert that their results are nearly identical. # The tolerance can be quite tight now. - assert torch.allclose(score_value_studentt, score_value_normal, atol=3e-3), ( + atol = 6e-3 if crps_fcn_type == "CRPS" else 2e-2 + assert torch.allclose(score_value_studentt, score_value_normal, atol=atol), ( f"StudentT {crps_fcn_type} with high 'df' should match Normal {crps_fcn_type}." ) diff --git a/torch_crps/analytical_crps.py b/torch_crps/analytical_crps.py index deee90a..946c04c 100644 --- a/torch_crps/analytical_crps.py +++ b/torch_crps/analytical_crps.py @@ -217,15 +217,11 @@ def standardized_studentt_cdf_via_scipy( return torch.from_numpy(cdf_z_np).to(device=z.device, dtype=z.dtype) -def _dispersion_studentt_jordan( +def _dispersion_studentt( q: StudentT, ) -> torch.Tensor: r"""Computes the dispersion term D = E[|Y - Y'|] for the Student-T distribution. - $$ - D = 2 \sigma \left( \frac{2\sqrt{\nu}}{\nu-1} \frac{\Gamma(\nu/2)}{\Gamma((\nu-1)/2)\Gamma(1/2)} \right) - $$ - See Also: Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019. @@ -237,13 +233,20 @@ def _dispersion_studentt_jordan( """ nu, sigma = q.df, q.scale - # D = 2σ * 2 * torch.sqrt(v) / (v - 1) * beta_frac, - # where beta_frac = B(1/2, v - 1/2) / B(1/2, v/2)^2 + # Compute the beta function ratio: B(1/2, ν - 1/2) / B(1/2, ν/2)^2 + # Using the relationship: B(a,b) = Gamma(a) * Gamma(b) / Gamma(a+b) + # B(1/2, ν - 1/2) / B(1/2, ν/2)^2 = ( Gamma(1/2) * Gamma(ν-1/2) / Gamma(ν) ) / + # ( Gamma(1/2) * Gamma(ν/2) / Gamma(ν/2 + 1/2) )^2 + # Simplifying to Gamma(ν - 1/2) Gamma(ν/2 + 1/2)^2 / ( Gamma(ν)Gamma(ν/2)^2 ) + # For numerical stability, we compute in log space. 
log_gamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device)) log_gamma_df_minus_half = torch.lgamma(nu - 0.5) log_gamma_df_half = torch.lgamma(nu / 2) log_gamma_df_half_plus_half = torch.lgamma(nu / 2 + 0.5) + # log[B(1/2, ν-1/2)] = log Gamma(1/2) + log Gamma(ν-1/2) - log Gamma(ν) + # log[B(1/2, ν/2)] = log Gamma(1/2) + log Gamma(ν/2) - log Gamma(ν/2 + 1/2) + # log[B(1/2, ν-1/2) / B(1/2, ν/2)^2] = log B(1/2, ν-1/2) - 2*log B(1/2, ν/2) log_beta_ratio = ( log_gamma_half + log_gamma_df_minus_half @@ -252,38 +255,13 @@ def _dispersion_studentt_jordan( ) beta_frac = torch.exp(log_beta_ratio) - dispersion_constant = (2 * torch.sqrt(nu) / (nu - 1)) * beta_frac - dispersion = sigma * 2 * dispersion_constant - return dispersion - - -def _dispersion_studentt_bollin( - q: StudentT, -) -> torch.Tensor: - """Computes the dispersion term D = E[|Y - Y'|] for the Student-T distribution. - - See Also: - Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019. - - Args: - q: A PyTorch StudentT distribution object, typically a model's output distribution. - - Returns: - Dispersion values for each observation, of shape (num_samples,). - """ - nu, sigma = q.df, q.scale - - # D = (4σ / (ν-1)) * (Γ(ν/2) / Γ((ν-1)/2))² - # We compute in log space for numerical stability (prevent under- or overflow). - log_4 = torch.log(torch.tensor(4.0, dtype=nu.dtype, device=nu.device)) - log_dispersion = ( - log_4 + torch.log(sigma) - torch.log(nu - 1) + 2 * (torch.lgamma(nu / 2) - torch.lgamma((nu - 1) / 2)) - ) + # D = 2σ * 2 * torch.sqrt(v) / (v - 1) * beta_frac + dispersion = 2 * sigma * 2 * torch.sqrt(nu) / (nu - 1) * beta_frac - return torch.exp(log_dispersion) + return dispersion -def _accuracy_studentt_jordan(q: StudentT, y: torch.Tensor) -> torch.Tensor: +def _accuracy_studentt(q: StudentT, y: torch.Tensor) -> torch.Tensor: r"""Computes the accuracy term A = E[|Y - y|] for the Student-T distribution. 
$$ @@ -317,111 +295,6 @@ def _accuracy_studentt_jordan(q: StudentT, y: torch.Tensor) -> torch.Tensor: return accuracy -def _accuracy_studentt_bollin(q: StudentT, y: torch.Tensor) -> torch.Tensor: - """Computes the accuracy term A = E[|Y - y|] for the Student-T distribution. - - See Also: - Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019. - - Args: - q: A PyTorch StudentT distribution object, typically a model's output distribution. - y: Observed values, of shape (num_samples,). - - Returns: - Accuracy values for each observation, of shape (num_samples,). - """ - nu, mu, sigma = q.df, q.loc, q.scale - - # Standardize the observed values. - z = (y - mu) / sigma - - # Compute Beta function term B(ν/2, 1/2). - lgamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device)) - log_beta_term = torch.lgamma(nu / 2) + lgamma_half - torch.lgamma((nu + 1) / 2) - beta_term = torch.exp(log_beta_term) - - # z(2 F_ν(z) - 1) - cdf_nu_z = standardized_studentt_cdf_via_scipy(z, nu) - term_1 = z * (2 * cdf_nu_z - 1) - - # 2(ν+z²) / ( ν*B(ν/2, 1/2) ) * F_{ν+1}(z * sqrt( (ν+1)/(ν+z²)) ) - term_2_factor = (2 * (nu + z**2)) / (nu * beta_term) - term_2_cdf_arg = z * torch.sqrt((nu + 1) / (nu + z**2)) - cdf_nu_plus_1_term = standardized_studentt_cdf_via_scipy(term_2_cdf_arg, nu + 1) - term_2 = term_2_factor * cdf_nu_plus_1_term - - # Just like for the CRPS, this includes a transformation to and back from standardized values. - # A = σ * [ z(2 F_ν(z) - 1) + 2(ν+z²) / ( ν*B(ν/2, 1/2) ) * F_{ν+1}(z * sqrt( (ν+1)/(ν+z²)) ) ]. - return sigma * (term_1 + term_2) - - -def _crps_analytical_studentt_jordan( - q: StudentT, - y: torch.Tensor, -) -> torch.Tensor: - r"""Compute the (negatively-oriented) CRPS in closed-form assuming a StudentT distribution. - - This is a helper function such that we can use it to consistently compute the accuracy term for the SCRPS. - """ - # Extract degrees of freedom ν, location μ, and scale σ. 
- nu, mu, sigma = q.df, q.loc, q.scale - if torch.any(nu <= 1): - raise ValueError("StudentT CRPS requires degrees of freedom > 1") - - # Standardize, and create standard StudentT distribution for CDF and PDF. - z = (y - mu) / sigma - standard_t = torch.distributions.StudentT(nu, loc=0, scale=1) - - # Compute standardized CDF F_ν(z) and PDF f_ν(z). - cdf_z = standardized_studentt_cdf_via_scipy(z, nu) - pdf_z = torch.exp(standard_t.log_prob(z)) - - # Compute the beta function ratio: B(1/2, ν - 1/2) / B(1/2, ν/2)^2 - # Using the relationship: B(a,b) = Gamma(a) * Gamma(b) / Gamma(a+b) - # B(1/2, ν - 1/2) / B(1/2, ν/2)^2 = ( Gamma(1/2) * Gamma(ν-1/2) / Gamma(ν) ) / - # ( Gamma(1/2) * Gamma(ν/2) / Gamma(ν/2 + 1/2) )^2 - # Simplifying to Gamma(ν - 1/2) Gamma(ν/2 + 1/2)^2 / ( Gamma(ν)Gamma(ν/2)^2 ) - # For numerical stability, we compute in log space. - log_gamma_half = torch.lgamma(torch.tensor(0.5, dtype=nu.dtype, device=nu.device)) - log_gamma_df_minus_half = torch.lgamma(nu - 0.5) - log_gamma_df_half = torch.lgamma(nu / 2) - log_gamma_df_half_plus_half = torch.lgamma(nu / 2 + 0.5) - - # log[B(1/2, ν-1/2)] = log Gamma(1/2) + log Gamma(ν-1/2) - log Gamma(ν) - # log[B(1/2, ν/2)] = log Gamma(1/2) + log Gamma(ν/2) - log Gamma(ν/2 + 1/2) - # log[B(1/2, ν-1/2) / B(1/2, ν/2)^2] = log B(1/2, ν-1/2) - 2*log B(1/2, ν/2) - log_beta_ratio = ( - log_gamma_half - + log_gamma_df_minus_half - - torch.lgamma(nu) - - 2 * (log_gamma_half + log_gamma_df_half - log_gamma_df_half_plus_half) - ) - beta_frac = torch.exp(log_beta_ratio) - - # Compute the CRPS for standardized values. - crps_standard = ( - z * (2 * cdf_z - 1) + 2 * pdf_z * (nu + z**2) / (nu - 1) - (2 * torch.sqrt(nu) / (nu - 1)) * beta_frac - ) - - # Apply location-scale transformation CRPS(F_{ν,μ,σ}, y) = σ * CRPS(F_{ν}, z) with z = (y - μ) / σ. - crps = sigma * crps_standard - - return crps - - """Compute the consistent accuracy term A by deriving it from the CRPS identity `A = CRPS + 0.5 * D`. 
- - Args: - q: A PyTorch StudentT distribution object, typically a model's output distribution. - y: Observed values, of shape (num_samples,). - - Returns: - Accuracy values for each observation, of shape (num_samples,). - """ - crps = _crps_analytical_studentt_jordan(q, y) - dispersion = _dispersion_studentt_jordan(q) - return crps + 0.5 * dispersion - - def crps_analytical_studentt( q: StudentT, y: torch.Tensor, @@ -460,8 +333,8 @@ def crps_analytical_studentt( if torch.any(q.df <= 1): raise ValueError("StudentT SCRPS requires degrees of freedom > 1") - accuracy = _accuracy_studentt_jordan(q, y) - dispersion = _dispersion_studentt_jordan(q) + accuracy = _accuracy_studentt(q, y) + dispersion = _dispersion_studentt(q) crps = accuracy - dispersion / 2 return crps @@ -499,8 +372,8 @@ def scrps_analytical_studentt( if torch.any(q.df <= 1): raise ValueError("StudentT SCRPS requires degrees of freedom > 1") - accuracy = _accuracy_studentt_jordan(q, y) - dispersion = _dispersion_studentt_jordan(q) + accuracy = _accuracy_studentt(q, y) + dispersion = _dispersion_studentt(q) scrps = accuracy / dispersion + torch.log(dispersion) / 2 return scrps From 4f3d71d2ac43396c78ae5a86fad00351b0045027 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Thu, 29 Jan 2026 21:44:57 +0100 Subject: [PATCH 08/19] Refactored --- examples/time_series_learning.py | 2 +- tests/conftest.py | 35 ++- tests/test_analytical_crps.py | 40 ++- tests/test_integral_crps.py | 4 +- torch_crps/__init__.py | 11 +- torch_crps/analytical/__init__.py | 3 + torch_crps/analytical/dispatch.py | 68 +++++ torch_crps/analytical/normal.py | 92 +++++++ .../studentt.py} | 251 +++--------------- torch_crps/integral_crps.py | 2 +- 10 files changed, 269 insertions(+), 239 deletions(-) create mode 100644 torch_crps/analytical/__init__.py create mode 100644 torch_crps/analytical/dispatch.py create mode 100644 torch_crps/analytical/normal.py rename torch_crps/{analytical_crps.py => analytical/studentt.py} (52%) diff --git 
a/examples/time_series_learning.py b/examples/time_series_learning.py index 562677e..0f3fa9a 100644 --- a/examples/time_series_learning.py +++ b/examples/time_series_learning.py @@ -15,7 +15,7 @@ import torch from matplotlib.figure import Figure -from torch_crps import crps_analytical +from torch_crps.analytical import crps_analytical EXAMPLES_DIR = pathlib.Path(pathlib.Path(__file__).parent) diff --git a/tests/conftest.py b/tests/conftest.py index 42387a8..ed0ce0a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,9 +2,9 @@ import pytest import torch -from torch.distributions import StudentT +from torch.distributions import Normal, StudentT -from torch_crps.analytical_crps import standardized_studentt_cdf_via_scipy +from torch_crps.analytical.studentt import standardized_studentt_cdf_via_scipy results_dir = Path(__file__).parent / "results" results_dir.mkdir(parents=True, exist_ok=True) @@ -54,7 +54,36 @@ def case_batched_3d(): } -def _crps_analytical_studentt_jordan( +def crps_analytical_normal_gneiting( + q: Normal, + y: torch.Tensor, +) -> torch.Tensor: + """Compute the analytical CRPS assuming a normal distribution. + + See Also: + Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007 + Equation (5) for the analytical formula for CRPS of Normal distribution. + + Args: + q: A PyTorch Normal distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + CRPS values for each observation, of shape (num_samples,). + """ + # Compute standard normal CDF and PDF. + z = (y - q.loc) / q.scale + standard_normal = torch.distributions.Normal(0, 1) + phi_z = standard_normal.cdf(z) # Φ(z) + pdf_z = torch.exp(standard_normal.log_prob(z)) # φ(z) + + # Analytical CRPS formula. 
+ crps = q.scale * (z * (2 * phi_z - 1) + 2 * pdf_z - 1 / torch.sqrt(torch.tensor(torch.pi))) + + return crps + + +def crps_analytical_studentt_jordan( q: StudentT, y: torch.Tensor, ) -> torch.Tensor: diff --git a/tests/test_analytical_crps.py b/tests/test_analytical_crps.py index 351cd3f..69ebb91 100644 --- a/tests/test_analytical_crps.py +++ b/tests/test_analytical_crps.py @@ -5,13 +5,11 @@ from torch.distributions import Normal, StudentT from typing_extensions import Literal -from tests.conftest import _crps_analytical_studentt_jordan, needs_cuda -from torch_crps import ( - crps_analytical, - crps_analytical_normal, +from tests.conftest import crps_analytical_normal_gneiting, crps_analytical_studentt_jordan, needs_cuda +from torch_crps.analytical import crps_analytical, scrps_analytical +from torch_crps.analytical.normal import crps_analytical_normal, scrps_analytical_normal +from torch_crps.analytical.studentt import ( crps_analytical_studentt, - scrps_analytical, - scrps_analytical_normal, ) @@ -126,7 +124,31 @@ def test_analytical_interface_smoke(q: Any, crps_fcn: Callable[..., torch.Tensor crps_fcn(q, y) -def test_analytical_studentt_consistency(): +def test_analytical_crps_normal_consistency(): + """Test if the two ways to compute the CRPS for normal distributions give the same result: + + - old method: `crps_analytical_normal_gneiting` + - new method: `_accuracy_normal_gneiting` and `_dispersion_normal_gneiting` packaged in `crps_analytical_normal` + """ + torch.manual_seed(0) + + # Create a Normal distribution. + loc = torch.tensor([0.0, 1.0, -1.0]) + scale = torch.tensor([1.0, 2.0, 0.5]) + normal_dist = torch.distributions.Normal(loc=loc, scale=scale) + + # Define observed values. + y = torch.tensor([0.5, 2.0, -0.5]) + + # Compute CRPS values. + crps_old = crps_analytical_normal_gneiting(normal_dist, y) + crps_new = crps_analytical_normal(normal_dist, y) + + # Assert that both methods give the same result. 
+ assert torch.allclose(crps_old, crps_new, atol=1e-6), "CRPS values from both methods should match." + + +def test_analytical_crps_studentt_consistency(): """Test if the two ways to compute the CRPS for StudentT distributions give the same result: - old method: `_crps_analytical_studentt_jordan` @@ -144,8 +166,8 @@ def test_analytical_studentt_consistency(): y = torch.tensor([0.5, 2.0, -0.5]) # Compute CRPS values. - crps_old = _crps_analytical_studentt_jordan(studentt_dist, y) + crps_old = crps_analytical_studentt_jordan(studentt_dist, y) crps_new = crps_analytical_studentt(studentt_dist, y) # Assert that both methods give the same result. - assert torch.allclose(crps_old, crps_new, atol=1e-4), "CRPS values from both methods should match." + assert torch.allclose(crps_old, crps_new, atol=1e-6), "CRPS values from both methods should match." diff --git a/tests/test_integral_crps.py b/tests/test_integral_crps.py index 65a197d..4d97c05 100644 --- a/tests/test_integral_crps.py +++ b/tests/test_integral_crps.py @@ -2,7 +2,9 @@ import torch from tests.conftest import needs_cuda -from torch_crps import crps_analytical_normal, crps_analytical_studentt, crps_integral +from torch_crps import crps_integral +from torch_crps.analytical.normal import crps_analytical_normal +from torch_crps.analytical.studentt import crps_analytical_studentt @pytest.mark.parametrize( diff --git a/torch_crps/__init__.py b/torch_crps/__init__.py index c50cb68..c8289a4 100644 --- a/torch_crps/__init__.py +++ b/torch_crps/__init__.py @@ -1,11 +1,6 @@ -from .analytical_crps import ( - crps_analytical, - crps_analytical_normal, - crps_analytical_studentt, - scrps_analytical, - scrps_analytical_normal, - scrps_analytical_studentt, -) +from .analytical import crps_analytical, scrps_analytical +from .analytical.normal import crps_analytical_normal, scrps_analytical_normal +from .analytical.studentt import crps_analytical_studentt, scrps_analytical_studentt from .ensemble_crps import crps_ensemble, 
crps_ensemble_naive from .integral_crps import crps_integral from .normalization import normalize_by_observation diff --git a/torch_crps/analytical/__init__.py b/torch_crps/analytical/__init__.py new file mode 100644 index 0000000..4d864e8 --- /dev/null +++ b/torch_crps/analytical/__init__.py @@ -0,0 +1,3 @@ +from .dispatch import crps_analytical, scrps_analytical # noqa: F401 +from .normal import crps_analytical_normal, scrps_analytical_normal # noqa: F401 +from .studentt import crps_analytical_studentt, scrps_analytical_studentt # noqa: F401 diff --git a/torch_crps/analytical/dispatch.py b/torch_crps/analytical/dispatch.py new file mode 100644 index 0000000..e1ae0d1 --- /dev/null +++ b/torch_crps/analytical/dispatch.py @@ -0,0 +1,68 @@ +import torch +from torch.distributions import Distribution, Normal, StudentT + +from torch_crps.analytical.normal import crps_analytical_normal, scrps_analytical_normal +from torch_crps.analytical.studentt import ( + crps_analytical_studentt, + scrps_analytical_studentt, +) + + +def crps_analytical( + q: Distribution, + y: torch.Tensor, +) -> torch.Tensor: + """Compute the (negatively-oriented, i.e., lower is better) CRPS in closed-form. + + Note: + The input distribution must be either `torch.distributions.Normal` or `torch.distributions.StudentT`. + There exists analytical solutions for other distributions, but they are not implemented, yet. + Feel free to create an issue or pull request. + + Args: + q: A PyTorch distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + CRPS values for each observation, of shape (num_samples,). + """ + if isinstance(q, Normal): + return crps_analytical_normal(q, y) + elif isinstance(q, StudentT): + return crps_analytical_studentt(q, y) + else: + raise NotImplementedError( + f"Detected distribution of type {type(q)}, but there are only analytical solutions for " + "`torch.distributions.Normal` or `torch.distributions.StudentT`. 
Either use an alternative method, e.g. " + "`torch_crps.crps_integral` or `torch_crps.crps_ensemble`, or create an issue for the method you need." + ) + + +def scrps_analytical( + q: Distribution, + y: torch.Tensor, +) -> torch.Tensor: + """Compute the (negatively-oriented, i.e., lower is better) scaled CRPS (SCRPS) in closed-form. + + Note: + The input distribution must be either `torch.distributions.Normal` or `torch.distributions.StudentT`. + There exists analytical solutions for other distributions, but they are not implemented, yet. + Feel free to create an issue or pull request. + + Args: + q: A PyTorch distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + SCRPS values for each observation, of shape (num_samples,). + """ + if isinstance(q, Normal): + return scrps_analytical_normal(q, y) + elif isinstance(q, StudentT): + return scrps_analytical_studentt(q, y) + else: + raise NotImplementedError( + f"Detected distribution of type {type(q)}, but there are only analytical solutions for " + "`torch.distributions.Normal` or `torch.distributions.StudentT`. Either use an alternative method, e.g. " + "`torch_crps.scrps_integral` or `torch_crps.scrps_ensemble`, or create an issue for the method you need." + ) diff --git a/torch_crps/analytical/normal.py b/torch_crps/analytical/normal.py new file mode 100644 index 0000000..26d771f --- /dev/null +++ b/torch_crps/analytical/normal.py @@ -0,0 +1,92 @@ +import torch +from torch.distributions import Normal + + +def _accuracy_normal( + q: Normal, + y: torch.Tensor, +) -> torch.Tensor: + """Compute accuracy term A = E[|X - y|] for a normal distribution. + + Args: + q: A PyTorch Normal distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + Accuracy values for each observation, of shape (num_samples,). 
+ """ + z = (y - q.loc) / q.scale + standard_normal = torch.distributions.Normal(0, 1) + + cdf_z = standard_normal.cdf(z) + pdf_z = torch.exp(standard_normal.log_prob(z)) + + return q.scale * (z * (2 * cdf_z - 1) + 2 * pdf_z) + + +def _dispersion_normal( + q: Normal, +) -> torch.Tensor: + """Compute dispersion term D = E[|X - X'|] for a normal distribution. + + Args: + q: A PyTorch Normal distribution object, typically a model's output distribution. + + Returns: + Dispersion values for each observation, of shape (num_samples,). + """ + sqrt_pi = torch.sqrt(torch.tensor(torch.pi, device=q.loc.device, dtype=q.loc.dtype)) + + return 2 * q.scale / sqrt_pi + + +def crps_analytical_normal( + q: Normal, + y: torch.Tensor, +) -> torch.Tensor: + """Compute the (negatively-oriented) CRPS in closed-form assuming a normal distribution. + + See Also: + Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007. + Equation (5) for the analytical formula for CRPS of Normal distribution. + + Args: + q: A PyTorch Normal distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + CRPS values for each observation, of shape (num_samples,). + """ + accuracy = _accuracy_normal(q, y) + dispersion = _dispersion_normal(q) + + crps = accuracy - dispersion / 2 + return crps + + +def scrps_analytical_normal( + q: Normal, + y: torch.Tensor, +) -> torch.Tensor: + """Compute the (negatively-oriented) scaled CRPS (SCRPS) in closed-form assuming a normal distribution. + + Note: + In contrast to the (negatively-oriented) CRPS, the SCRPS can have negative values. + + See Also: + Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019. + Equation (3) for the definition of the SCRPS. + Appendix A.1 for the component formulas (Accuracy and Dispersion) for the Normal distribution + + Args: + q: A PyTorch Normal distribution object, typically a model's output distribution. 
+ y: Observed values, of shape (num_samples,). + + Returns: + SCRPS values for each observation, of shape (num_samples,). + """ + accuracy = _accuracy_normal(q, y) + dispersion = _dispersion_normal(q) + + scrps = accuracy / dispersion + 0.5 * torch.log(dispersion) + return scrps diff --git a/torch_crps/analytical_crps.py b/torch_crps/analytical/studentt.py similarity index 52% rename from torch_crps/analytical_crps.py rename to torch_crps/analytical/studentt.py index 946c04c..4407468 100644 --- a/torch_crps/analytical_crps.py +++ b/torch_crps/analytical/studentt.py @@ -1,186 +1,5 @@ import torch -from torch.distributions import Distribution, Normal, StudentT - - -def crps_analytical( - q: Distribution, - y: torch.Tensor, -) -> torch.Tensor: - """Compute the (negatively-oriented, i.e., lower is better) CRPS in closed-form. - - Note: - The input distribution must be either `torch.distributions.Normal` or `torch.distributions.StudentT`. - There exists analytical solutions for other distributions, but they are not implemented, yet. - Feel free to create an issue or pull request. - - Args: - q: A PyTorch distribution object, typically a model's output distribution. - y: Observed values, of shape (num_samples,). - - Returns: - CRPS values for each observation, of shape (num_samples,). - """ - if isinstance(q, Normal): - return crps_analytical_normal(q, y) - elif isinstance(q, StudentT): - return crps_analytical_studentt(q, y) - else: - raise NotImplementedError( - f"Detected distribution of type {type(q)}, but there are only analytical solutions for " - "`torch.distributions.Normal` or `torch.distributions.StudentT`. Either use an alternative method, e.g. " - "`torch_crps.crps_integral` or `torch_crps.crps_ensemble`, or create an issue for the method you need." - ) - - -def scrps_analytical( - q: Distribution, - y: torch.Tensor, -) -> torch.Tensor: - """Compute the (negatively-oriented, i.e., lower is better) scaled CRPS (SCRPS) in closed-form. 
- - Note: - The input distribution must be either `torch.distributions.Normal` or `torch.distributions.StudentT`. - There exists analytical solutions for other distributions, but they are not implemented, yet. - Feel free to create an issue or pull request. - - Args: - q: A PyTorch distribution object, typically a model's output distribution. - y: Observed values, of shape (num_samples,). - - Returns: - SCRPS values for each observation, of shape (num_samples,). - """ - if isinstance(q, Normal): - return scrps_analytical_normal(q, y) - elif isinstance(q, StudentT): - return scrps_analytical_studentt(q, y) - else: - raise NotImplementedError( - f"Detected distribution of type {type(q)}, but there are only analytical solutions for " - "`torch.distributions.Normal` or `torch.distributions.StudentT`. Either use an alternative method, e.g. " - "`torch_crps.scrps_integral` or `torch_crps.scrps_ensemble`, or create an issue for the method you need." - ) - - -def _accuracy_normal( - q: Normal, - y: torch.Tensor, -) -> torch.Tensor: - """Compute accuracy term A = E[|X - y|] for a normal distribution. - - Args: - q: A PyTorch Normal distribution object, typically a model's output distribution. - y: Observed values, of shape (num_samples,). - - Returns: - Accuracy values for each observation, of shape (num_samples,). - """ - z = (y - q.loc) / q.scale - standard_normal = torch.distributions.Normal(0, 1) - - cdf_z = standard_normal.cdf(z) - pdf_z = torch.exp(standard_normal.log_prob(z)) - - return q.scale * (z * (2 * cdf_z - 1) + 2 * pdf_z) - - -def _dispersion_normal( - q: Normal, - y: torch.Tensor, -) -> torch.Tensor: - """Compute dispersion term D = E[|X - X'|] for a normal distribution. - - Args: - q: A PyTorch Normal distribution object, typically a model's output distribution. - y: Observed values, of shape (num_samples,). - - Returns: - Dispersion values for each observation, of shape (num_samples,). 
- """ - sqrt_pi = torch.sqrt(torch.tensor(torch.pi, device=y.device, dtype=y.dtype)) - - return 2 * q.scale / sqrt_pi - - -def _crps_analytical_normal_gneiting( - q: Normal, - y: torch.Tensor, -) -> torch.Tensor: - """Compute the analytical CRPS assuming a normal distribution. - - See Also: - Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007 - Equation (5) for the analytical formula for CRPS of Normal distribution. - - Args: - q: A PyTorch Normal distribution object, typically a model's output distribution. - y: Observed values, of shape (num_samples,). - - Returns: - CRPS values for each observation, of shape (num_samples,). - """ - # Compute standard normal CDF and PDF. - z = (y - q.loc) / q.scale # standardize - standard_normal = torch.distributions.Normal(0, 1) - phi_z = standard_normal.cdf(z) # Φ(z) - pdf_z = torch.exp(standard_normal.log_prob(z)) # φ(z) - - # Analytical CRPS formula. - crps = q.scale * (z * (2 * phi_z - 1) + 2 * pdf_z - 1 / torch.sqrt(torch.tensor(torch.pi))) - - return crps - - -def crps_analytical_normal( - q: Normal, - y: torch.Tensor, -) -> torch.Tensor: - """Compute the (negatively-oriented) CRPS in closed-form assuming a normal distribution. - - See Also: - Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007. - Equation (5) for the analytical formula for CRPS of Normal distribution. - - Args: - q: A PyTorch Normal distribution object, typically a model's output distribution. - y: Observed values, of shape (num_samples,). - - Returns: - CRPS values for each observation, of shape (num_samples,). - """ - accuracy = _accuracy_normal(q, y) - dispersion = _dispersion_normal(q, y) - - crps = accuracy - dispersion / 2 - return crps - - -def scrps_analytical_normal( - q: Normal, - y: torch.Tensor, -) -> torch.Tensor: - """Compute the (negatively-oriented) scaled CRPS (SCRPS) in closed-form assuming a normal distribution. 
- - Note: - In contrast to the (negatively-oriented) CRPS, the SCRPS can have negative values. - - See Also: - Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019. - Equation (3) for the definition of the SCRPS. - Appendix A.1 for the component formulas (Accuracy and Dispersion) for the Normal distribution - - Args: - q: A PyTorch Normal distribution object, typically a model's output distribution. - y: Observed values, of shape (num_samples,). - - Returns: - SCRPS values for each observation, of shape (num_samples,). - """ - accuracy = _accuracy_normal(q, y) - dispersion = _dispersion_normal(q, y) - - scrps = accuracy / dispersion + 0.5 * torch.log(dispersion) - return scrps +from torch.distributions import StudentT def standardized_studentt_cdf_via_scipy( @@ -217,6 +36,40 @@ def standardized_studentt_cdf_via_scipy( return torch.from_numpy(cdf_z_np).to(device=z.device, dtype=z.dtype) +def _accuracy_studentt(q: StudentT, y: torch.Tensor) -> torch.Tensor: + r"""Computes the accuracy term A = E[|Y - y|] for the Student-T distribution. + + $$ + A = \sigma \left[ z(2F_{\nu}(z) - 1) + 2 \frac{\nu+z^2}{\nu-1} f_{\nu}(z) \right] + $$ + + See Also: + Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019. + + Args: + q: A PyTorch StudentT distribution object, typically a model's output distribution. + y: Observed values, of shape (num_samples,). + + Returns: + Accuracy values for each observation, of shape (num_samples,). + """ + nu, mu, sigma = q.df, q.loc, q.scale + + # Standardize, and create standard StudentT distribution for CDF and PDF. + z = (y - mu) / sigma + standard_t = StudentT(nu, loc=torch.zeros_like(mu), scale=torch.ones_like(sigma)) + + # Compute standardized CDF F_ν(z) and PDF f_ν(z). 
+ cdf_z = standardized_studentt_cdf_via_scipy(z, nu) + pdf_z = torch.exp(standard_t.log_prob(z)) + + # A = sigma * [z * (2*F(z) - 1) + 2*f(z) * (v + z^2) / (v-1) ] + accuracy_unscaled = z * (2 * cdf_z - 1) + 2 * pdf_z * (nu + z**2) / (nu - 1) + + accuracy = sigma * accuracy_unscaled + return accuracy + + def _dispersion_studentt( q: StudentT, ) -> torch.Tensor: @@ -261,40 +114,6 @@ def _dispersion_studentt( return dispersion -def _accuracy_studentt(q: StudentT, y: torch.Tensor) -> torch.Tensor: - r"""Computes the accuracy term A = E[|Y - y|] for the Student-T distribution. - - $$ - A = \sigma \left[ z(2F_{t,\nu}(z) - 1) + 2 \frac{\nu+z^2}{\nu-1} f_{t,\nu}(z) \right] - $$ - - See Also: - Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019. - - Args: - q: A PyTorch StudentT distribution object, typically a model's output distribution. - y: Observed values, of shape (num_samples,). - - Returns: - Accuracy values for each observation, of shape (num_samples,). - """ - nu, mu, sigma = q.df, q.loc, q.scale - - # Standardize, and create standard StudentT distribution for CDF and PDF. - z = (y - mu) / sigma - standard_t = StudentT(nu, loc=torch.zeros_like(mu), scale=torch.ones_like(sigma)) - - # Compute standardized CDF F_ν(z) and PDF f_ν(z). 
- cdf_z = standardized_studentt_cdf_via_scipy(z, nu) - pdf_z = torch.exp(standard_t.log_prob(z)) - - # A = sigma * [z * (2*F(z) - 1) + 2*f(z) * (v + z^2) / (v-1) ] - accuracy_unscaled = z * (2 * cdf_z - 1) + 2 * pdf_z * (nu + z**2) / (nu - 1) - - accuracy = sigma * accuracy_unscaled - return accuracy - - def crps_analytical_studentt( q: StudentT, y: torch.Tensor, diff --git a/torch_crps/integral_crps.py b/torch_crps/integral_crps.py index 86cd898..4d40e7c 100644 --- a/torch_crps/integral_crps.py +++ b/torch_crps/integral_crps.py @@ -1,7 +1,7 @@ import torch from torch.distributions import Distribution, StudentT -from torch_crps.analytical_crps import standardized_studentt_cdf_via_scipy +from torch_crps.analytical.studentt import standardized_studentt_cdf_via_scipy def crps_integral( From 7e1e1ea564f9fe7967bd44df255345b619b00b3b Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Sun, 8 Feb 2026 09:21:19 +0100 Subject: [PATCH 09/19] Use abstract interface --- torch_crps/abstract.py | 27 +++++++++++++++++++++++++++ torch_crps/analytical/normal.py | 8 ++++---- torch_crps/analytical/studentt.py | 28 +++++++++++++++++----------- torch_crps/ensemble_crps.py | 6 +++--- torch_crps/integral_crps.py | 4 ++-- torch_crps/normalization.py | 4 ++-- torch_crps/utils.py | 0 7 files changed, 55 insertions(+), 22 deletions(-) create mode 100644 torch_crps/abstract.py delete mode 100644 torch_crps/utils.py diff --git a/torch_crps/abstract.py b/torch_crps/abstract.py new file mode 100644 index 0000000..1c387ab --- /dev/null +++ b/torch_crps/abstract.py @@ -0,0 +1,27 @@ +import torch + + +def crps_abstract(accuracy: torch.Tensor, dispersion: torch.Tensor) -> torch.Tensor: + """High-level function to compute the CRPS from the accuracy and dispersion terms. + + Args: + accuracy: The accuracy term A, independent of the methods used to compute it, of shape (*batch_shape,). + dispersion: The dispersion term D, independent of the methods used to compute it, of shape (*batch_shape,). 
+ + Returns: + The CRPS value for each forecast in the batch, of shape (*batch_shape,). + """ + return accuracy - 0.5 * dispersion + + +def scrps_abstract(accuracy: torch.Tensor, dispersion: torch.Tensor) -> torch.Tensor: + """High-level function to compute the SCRPS from the accuracy and dispersion terms. + + Args: + accuracy: The accuracy term A, independent of the methods used to compute it, of shape (*batch_shape,). + dispersion: The dispersion term D, independent of the methods used to compute it, of shape (*batch_shape,). + + Returns: + The SCRPS value for each forecast in the batch, of shape (*batch_shape,). + """ + return accuracy / dispersion + 0.5 * torch.log(dispersion) diff --git a/torch_crps/analytical/normal.py b/torch_crps/analytical/normal.py index 26d771f..748d461 100644 --- a/torch_crps/analytical/normal.py +++ b/torch_crps/analytical/normal.py @@ -1,6 +1,8 @@ import torch from torch.distributions import Normal +from torch_crps.abstract import crps_abstract, scrps_abstract + def _accuracy_normal( q: Normal, @@ -60,8 +62,7 @@ def crps_analytical_normal( accuracy = _accuracy_normal(q, y) dispersion = _dispersion_normal(q) - crps = accuracy - dispersion / 2 - return crps + return crps_abstract(accuracy, dispersion) def scrps_analytical_normal( @@ -88,5 +89,4 @@ def scrps_analytical_normal( accuracy = _accuracy_normal(q, y) dispersion = _dispersion_normal(q) - scrps = accuracy / dispersion + 0.5 * torch.log(dispersion) - return scrps + return scrps_abstract(accuracy, dispersion) diff --git a/torch_crps/analytical/studentt.py b/torch_crps/analytical/studentt.py index 4407468..25e1034 100644 --- a/torch_crps/analytical/studentt.py +++ b/torch_crps/analytical/studentt.py @@ -1,6 +1,8 @@ import torch from torch.distributions import StudentT +from torch_crps.abstract import crps_abstract, scrps_abstract + def standardized_studentt_cdf_via_scipy( z: torch.Tensor, @@ -37,7 +39,7 @@ def standardized_studentt_cdf_via_scipy( def _accuracy_studentt(q: StudentT, 
y: torch.Tensor) -> torch.Tensor: - r"""Computes the accuracy term A = E[|Y - y|] for the Student-T distribution. + r"""Computes the accuracy term $A = E[|Y - y|]$ for the Student-T distribution. $$ A = \sigma \left[ z(2F_{\nu}(z) - 1) + 2 \frac{\nu+z^2}{\nu-1} f_{\nu}(z) \right] @@ -73,7 +75,7 @@ def _accuracy_studentt(q: StudentT, y: torch.Tensor) -> torch.Tensor: def _dispersion_studentt( q: StudentT, ) -> torch.Tensor: - r"""Computes the dispersion term D = E[|Y - Y'|] for the Student-T distribution. + r"""Computes the dispersion term $D = E[|Y - Y'|]$ for the Student-T distribution. See Also: Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019. @@ -124,15 +126,19 @@ def crps_analytical_studentt( For the standardized StudentT distribution: - $$ \text{CRPS}(F_\nu, z) = z(2F_\nu(z) - 1) + 2f_\nu(z)\frac{\nu + z^2}{\nu - 1} - - \frac{2\sqrt{\nu}}{\nu - 1} \frac{B(\frac{1}{2}, \nu - \frac{1}{2})}{B(\frac{1}{2}, \frac{\nu}{2})^2} $$ + $$ + \text{CRPS}(F_\nu, z) = z(2F_\nu(z) - 1) + 2f_\nu(z)\frac{\nu + z^2}{\nu - 1} + - \frac{2\sqrt{\nu}}{\nu - 1} \frac{B(\frac{1}{2}, \nu - \frac{1}{2})}{B(\frac{1}{2}, \frac{\nu}{2})^2} + $$ where $z$ is the standardized value, $F_\nu$ is the CDF, $f_\nu$ is the PDF of the standard StudentT distribution, $\nu$ is the degrees of freedom, and $B$ is the beta function. For the location-scale transformed distribution: - $$ \text{CRPS}(F_{\nu,\mu,\sigma}, y) = \sigma \cdot \text{CRPS}\left(F_\nu, \frac{y-\mu}{\sigma}\right) $$ + $$ + \text{CRPS}(F_{\nu,\mu,\sigma}, y) = \sigma \cdot \text{CRPS}\left(F_\nu, \frac{y-\mu}{\sigma}\right) + $$ where $\mu$ is the location parameter, $\sigma$ is the scale parameter, and $y$ is the observation. 
@@ -155,8 +161,7 @@ def crps_analytical_studentt( accuracy = _accuracy_studentt(q, y) dispersion = _dispersion_studentt(q) - crps = accuracy - dispersion / 2 - return crps + return crps_abstract(accuracy, dispersion) def scrps_analytical_studentt( @@ -166,13 +171,15 @@ def scrps_analytical_studentt( r"""Compute the (negatively-oriented) scaled CRPS (SCRPS) in closed-form assuming a Student-T distribution. The score is calculated as: - $$ \text{SCRPS}(F, y) = \frac{A}{D} + 0.5 \cdot \log(D) $$ + $$ + \text{SCRPS}(F, y) = \frac{A}{D} + 0.5 \log(D) + $$ where: - $F_{\nu, \mu, \sigma^2}$ is the cumulative Student-T distribution, and $F_{\nu}$ is the standardized version. - $A = E_F[|X - y|]$ is the accuracy term. - $A = \sigma [ z(2 F_{\nu}(z) - 1) + 2(\nu + z²) / (\nu*B(\nu/2, 1/2)) * F_{\nu+1}(z * \sqrt{(\nu+1)/(\nu+z²)}) ]$ - - $D = E_F[|X - X'|]$ is the cispersion term. + - $D = E_F[|X - X'|]$ is the dispersion term. - $D = \frac{ 4\sigma }{ \nu-1 } * ( \frac{ \Gamma( \nu/2 ) }{ \Gamma( (\nu-1)/2) } )^2}$ Note: @@ -194,5 +201,4 @@ def scrps_analytical_studentt( accuracy = _accuracy_studentt(q, y) dispersion = _dispersion_studentt(q) - scrps = accuracy / dispersion + torch.log(dispersion) / 2 - return scrps + return scrps_abstract(accuracy, dispersion) diff --git a/torch_crps/ensemble_crps.py b/torch_crps/ensemble_crps.py index f531f70..2148737 100644 --- a/torch_crps/ensemble_crps.py +++ b/torch_crps/ensemble_crps.py @@ -109,7 +109,7 @@ def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torc # --- Accuracy term A := E[|X - y|] # Compute the mean absolute error across all ensemble members. Unsqueeze the observation for explicit broadcasting. - mae = torch.abs(x - y.unsqueeze(-1)).mean(dim=-1) + accuracy = torch.abs(x - y.unsqueeze(-1)).mean(dim=-1) # --- Dispersion term D := E[|X - X'|] # This is half the mean absolute difference between all pairs of predictions. 
@@ -128,7 +128,7 @@ def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torc denom = m * (m - 1) if not biased else m**2 half_mean_dispersion = 1 / denom * x_sum # 2 in numerator here cancels with 0.5 in the next step - # --- CRPS value := A - 0.5 * D - crps_value = mae - half_mean_dispersion # 0.5 already accounted for above + # --- CRPS value := A - D / 2 + crps_value = accuracy - half_mean_dispersion # 0.5 already accounted for above return crps_value diff --git a/torch_crps/integral_crps.py b/torch_crps/integral_crps.py index 4d40e7c..36f404b 100644 --- a/torch_crps/integral_crps.py +++ b/torch_crps/integral_crps.py @@ -55,6 +55,6 @@ def integrand(x: torch.Tensor) -> torch.Tensor: # Compute the integral using the trapezoidal rule. integral_values = integrand(x_values) - crps_values = torch.trapezoid(integral_values, x_values.squeeze(-1), dim=0) + crps = torch.trapezoid(integral_values, x_values.squeeze(-1), dim=0) - return crps_values + return crps diff --git a/torch_crps/normalization.py b/torch_crps/normalization.py index 2739fdb..2076929 100644 --- a/torch_crps/normalization.py +++ b/torch_crps/normalization.py @@ -47,9 +47,9 @@ def wrapper(*args: WRAPPED_INPUT_TYPE, **kwargs: WRAPPED_INPUT_TYPE) -> torch.Te abs_max_y = torch.ones(1, device=abs_max_y.device, dtype=abs_max_y.dtype) # Call the original CRPS function. - crps_result = crps_fcn(*args, **kwargs) + crps = crps_fcn(*args, **kwargs) # Normalize the result. 
- return crps_result / abs_max_y + return crps / abs_max_y return wrapper diff --git a/torch_crps/utils.py b/torch_crps/utils.py deleted file mode 100644 index e69de29..0000000 From d32fc0cc15651a744aa5e723be18a386728b914f Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Sun, 8 Feb 2026 09:24:37 +0100 Subject: [PATCH 10/19] Renaming --- tests/{test_analytical_crps.py => test_analytical.py} | 0 tests/{test_ensemble_crps.py => test_ensemble.py} | 0 tests/{test_integral_crps.py => test_integral.py} | 0 torch_crps/{ensemble_crps.py => ensemble.py} | 0 torch_crps/{integral_crps.py => integral.py} | 0 5 files changed, 0 insertions(+), 0 deletions(-) rename tests/{test_analytical_crps.py => test_analytical.py} (100%) rename tests/{test_ensemble_crps.py => test_ensemble.py} (100%) rename tests/{test_integral_crps.py => test_integral.py} (100%) rename torch_crps/{ensemble_crps.py => ensemble.py} (100%) rename torch_crps/{integral_crps.py => integral.py} (100%) diff --git a/tests/test_analytical_crps.py b/tests/test_analytical.py similarity index 100% rename from tests/test_analytical_crps.py rename to tests/test_analytical.py diff --git a/tests/test_ensemble_crps.py b/tests/test_ensemble.py similarity index 100% rename from tests/test_ensemble_crps.py rename to tests/test_ensemble.py diff --git a/tests/test_integral_crps.py b/tests/test_integral.py similarity index 100% rename from tests/test_integral_crps.py rename to tests/test_integral.py diff --git a/torch_crps/ensemble_crps.py b/torch_crps/ensemble.py similarity index 100% rename from torch_crps/ensemble_crps.py rename to torch_crps/ensemble.py diff --git a/torch_crps/integral_crps.py b/torch_crps/integral.py similarity index 100% rename from torch_crps/integral_crps.py rename to torch_crps/integral.py From 59816c55b873f44489afd73adc78600df42302cc Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Sun, 8 Feb 2026 09:44:50 +0100 Subject: [PATCH 11/19] Moved the normalization by observation --- 
tests/test_normalization.py | 70 +++++++++++++++++++++---------------- torch_crps/__init__.py | 28 +++++++-------- torch_crps/normalization.py | 13 +++++++ 3 files changed, 66 insertions(+), 45 deletions(-) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 80b28f1..274a45b 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -4,29 +4,29 @@ import torch from torch_crps import ( - crps_analytical_normal_normalized, - crps_analytical_normalized, - crps_analytical_studentt_normalized, - crps_ensemble_normalized, - crps_integral_normalized, + crps_analytical_normal_obsnormalized, + crps_analytical_obsnormalized, + crps_analytical_studentt_obsnormalized, + crps_ensemble_obsnormalized, + crps_integral_obsnormalized, ) @pytest.mark.parametrize( "wrapped_crps_fcn", [ - crps_analytical_normal_normalized, - crps_analytical_normalized, - crps_analytical_studentt_normalized, - crps_ensemble_normalized, - crps_integral_normalized, + crps_analytical_normal_obsnormalized, + crps_analytical_obsnormalized, + crps_analytical_studentt_obsnormalized, + crps_ensemble_obsnormalized, + crps_integral_obsnormalized, ], ids=[ - "crps_analytical_normal_normalized", - "crps_analytical_normalized", - "crps_analytical_studentt_normalized", - "crps_ensemble_normalized", - "crps_integral_normalized", + "crps_analytical_normal_obsnormalized", + "crps_analytical_obsnormalized", + "crps_analytical_studentt_obsnormalized", + "crps_ensemble_obsnormalized", + "crps_integral_obsnormalized", ], ) def test_normalization_wrapper_input_errors(wrapped_crps_fcn: Callable, num_y: int = 3): @@ -34,11 +34,15 @@ def test_normalization_wrapper_input_errors(wrapped_crps_fcn: Callable, num_y: i torch.manual_seed(0) # Set up test cases. 
- if wrapped_crps_fcn in (crps_analytical_normal_normalized, crps_analytical_normalized, crps_integral_normalized): + if wrapped_crps_fcn in ( + crps_analytical_normal_obsnormalized, + crps_analytical_obsnormalized, + crps_integral_obsnormalized, + ): q = torch.distributions.Normal(loc=torch.zeros(num_y), scale=torch.ones(num_y)) - elif wrapped_crps_fcn == crps_analytical_studentt_normalized: + elif wrapped_crps_fcn == crps_analytical_studentt_obsnormalized: q = torch.distributions.StudentT(df=5 * torch.ones(num_y), loc=torch.zeros(num_y), scale=torch.ones(num_y)) - elif wrapped_crps_fcn == crps_ensemble_normalized: + elif wrapped_crps_fcn == crps_ensemble_obsnormalized: q = torch.randn(num_y, 10) # dim_ensemble = 10 else: raise NotImplementedError("Test case setup error.") @@ -69,18 +73,18 @@ def test_normalization_wrapper_input_errors(wrapped_crps_fcn: Callable, num_y: i @pytest.mark.parametrize( "wrapped_crps_fcn", [ - crps_analytical_normal_normalized, - crps_analytical_normalized, - crps_analytical_studentt_normalized, - crps_ensemble_normalized, - crps_integral_normalized, + crps_analytical_normal_obsnormalized, + crps_analytical_obsnormalized, + crps_analytical_studentt_obsnormalized, + crps_ensemble_obsnormalized, + crps_integral_obsnormalized, ], ids=[ - "crps_analytical_normal_normalized", - "crps_analytical_normalized", - "crps_analytical_studentt_normalized", - "crps_ensemble_normalized", - "crps_integral_normalized", + "crps_analytical_normal_obsnormalized", + "crps_analytical_obsnormalized", + "crps_analytical_studentt_obsnormalized", + "crps_ensemble_obsnormalized", + "crps_integral_obsnormalized", ], ) @pytest.mark.parametrize("num_y", [1, 5, 100], ids=["1_obs", "5_obs", "100_obs"]) @@ -89,11 +93,15 @@ def test_normalization_wrapper_output_consistency(wrapped_crps_fcn: Callable, nu torch.manual_seed(0) # Set up test cases. 
- if wrapped_crps_fcn in (crps_analytical_normal_normalized, crps_analytical_normalized, crps_integral_normalized): + if wrapped_crps_fcn in ( + crps_analytical_normal_obsnormalized, + crps_analytical_obsnormalized, + crps_integral_obsnormalized, + ): q = torch.distributions.Normal(loc=torch.zeros(num_y), scale=torch.ones(num_y)) - elif wrapped_crps_fcn == crps_analytical_studentt_normalized: + elif wrapped_crps_fcn == crps_analytical_studentt_obsnormalized: q = torch.distributions.StudentT(df=5 * torch.ones(num_y), loc=torch.zeros(num_y), scale=torch.ones(num_y)) - elif wrapped_crps_fcn == crps_ensemble_normalized: + elif wrapped_crps_fcn == crps_ensemble_obsnormalized: q = torch.randn(num_y, 10) # dim_ensemble = 10 else: raise NotImplementedError("Test case setup error.") diff --git a/torch_crps/__init__.py b/torch_crps/__init__.py index c8289a4..590dd41 100644 --- a/torch_crps/__init__.py +++ b/torch_crps/__init__.py @@ -1,28 +1,28 @@ from .analytical import crps_analytical, scrps_analytical from .analytical.normal import crps_analytical_normal, scrps_analytical_normal from .analytical.studentt import crps_analytical_studentt, scrps_analytical_studentt -from .ensemble_crps import crps_ensemble, crps_ensemble_naive -from .integral_crps import crps_integral -from .normalization import normalize_by_observation - -crps_analytical_normalized = normalize_by_observation(crps_analytical) -crps_analytical_normal_normalized = normalize_by_observation(crps_analytical_normal) -crps_analytical_studentt_normalized = normalize_by_observation(crps_analytical_studentt) -crps_ensemble_normalized = normalize_by_observation(crps_ensemble) -crps_integral_normalized = normalize_by_observation(crps_integral) +from .ensemble import crps_ensemble, crps_ensemble_naive +from .integral import crps_integral +from .normalization import ( + crps_analytical_normal_obsnormalized, + crps_analytical_obsnormalized, + crps_analytical_studentt_obsnormalized, + crps_ensemble_obsnormalized, + 
crps_integral_obsnormalized, +) __all__ = [ "crps_analytical", "crps_analytical_normal", - "crps_analytical_normal_normalized", - "crps_analytical_normalized", + "crps_analytical_normal_obsnormalized", + "crps_analytical_obsnormalized", "crps_analytical_studentt", - "crps_analytical_studentt_normalized", + "crps_analytical_studentt_obsnormalized", "crps_ensemble", "crps_ensemble_naive", - "crps_ensemble_normalized", + "crps_ensemble_obsnormalized", "crps_integral", - "crps_integral_normalized", + "crps_integral_obsnormalized", "scrps_analytical", "scrps_analytical_normal", "scrps_analytical_studentt", diff --git a/torch_crps/normalization.py b/torch_crps/normalization.py index 2076929..b7e70ab 100644 --- a/torch_crps/normalization.py +++ b/torch_crps/normalization.py @@ -3,6 +3,12 @@ import torch +from torch_crps.analytical.dispatch import crps_analytical +from torch_crps.analytical.normal import crps_analytical_normal +from torch_crps.analytical.studentt import crps_analytical_studentt +from torch_crps.ensemble import crps_ensemble +from torch_crps.integral import crps_integral + WRAPPED_INPUT_TYPE: TypeAlias = torch.distributions.Distribution | torch.Tensor | float @@ -53,3 +59,10 @@ def wrapper(*args: WRAPPED_INPUT_TYPE, **kwargs: WRAPPED_INPUT_TYPE) -> torch.Te return crps / abs_max_y return wrapper + + +crps_analytical_obsnormalized = normalize_by_observation(crps_analytical) +crps_analytical_normal_obsnormalized = normalize_by_observation(crps_analytical_normal) +crps_analytical_studentt_obsnormalized = normalize_by_observation(crps_analytical_studentt) +crps_ensemble_obsnormalized = normalize_by_observation(crps_ensemble) +crps_integral_obsnormalized = normalize_by_observation(crps_integral) From df64e8b7420ef7e0c2f5b245b001d13a5e8380dc Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Sun, 8 Feb 2026 10:07:31 +0100 Subject: [PATCH 12/19] Refactored the ensemble methods --- torch_crps/ensemble.py | 179 +++++++++++++++++++++++------------- 
torch_crps/normalization.py | 4 +- 2 files changed, 117 insertions(+), 66 deletions(-) diff --git a/torch_crps/ensemble.py b/torch_crps/ensemble.py index 2148737..a5eadf1 100644 --- a/torch_crps/ensemble.py +++ b/torch_crps/ensemble.py @@ -1,7 +1,98 @@ import torch +from torch_crps.abstract import crps_abstract -def crps_ensemble_naive(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torch.Tensor: + +def _accuracy_ensemble( + x: torch.Tensor, + y: torch.Tensor, +) -> torch.Tensor: + """Compute accuracy term $A = E[|X - y|]$, i.e., mean absolute error, for an ensemble forecast. + + Args: + x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). + y: The ground truth observations, of shape (*batch_shape). + + Returns: + Accuracy values for each observation, of shape (*batch_shape). + """ + # Unsqueeze the observation for explicit broadcasting. + return torch.abs(x - y.unsqueeze(-1)).mean(dim=-1) + + +def _dispersion_ensemble_naive( + x: torch.Tensor, + biased: bool, +) -> torch.Tensor: + """Compute dispersion term $D = E[|X - X'|]$ for an ensemble forecast using a naive O(m²) algorithm. + + m is the number of ensemble members. + + Args: + x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). + biased: If True, uses the biased estimator for the dispersion term $D$, i.e., divides by m². If False, uses the + unbiased estimator which instead divides by m * (m - 1). + + Returns: + Dispersion values for each observation, of shape (*batch_shape). + """ + # Create a matrix of all pairwise differences between ensemble members using broadcasting. + x_i = x.unsqueeze(-1) # shape: (*batch_shape, m, 1) + x_j = x.unsqueeze(-2) # shape: (*batch_shape, 1, m) + pairwise_diffs = x_i - x_j # shape: (*batch_shape, m, m) + + # Take the absolute value of every element in the matrix. + abs_pairwise_diffs = torch.abs(pairwise_diffs) + + # Calculate the mean of the m x m matrix for each batch item, i.e, not the batch shapes. 
+ if biased: + # For the biased estimator, we use the mean which divides by m². + dispersion = abs_pairwise_diffs.mean(dim=(-2, -1)) + else: + # For the unbiased estimator, we need to exclude the diagonal (where i=j) and divide by m(m-1). + m = x.shape[-1] # number of ensemble members + dispersion = abs_pairwise_diffs.sum(dim=(-2, -1)) / (m * (m - 1)) + + return dispersion + + +def _dispersion_ensemble( + x: torch.Tensor, + biased: bool, +) -> torch.Tensor: + """Compute dispersion term $D = E[|X - X'|]$ for an ensemble forecast using an efficient O(m log m) algorithm. + + m is the number of ensemble members. + + Args: + x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). + biased: If True, uses the biased estimator for the dispersion term $D$, i.e., divides by m². If False, uses the + unbiased estimator which instead divides by m * (m - 1). + + Returns: + Dispersion values for each observation, of shape (*batch_shape). + """ + m = x.shape[-1] # number of ensemble members + + # Sort the predictions along the ensemble member dimension. + x_sorted, _ = torch.sort(x, dim=-1) + + # Calculate the coefficients (2i - m - 1) for the linear-time sum. These are the same for every item in the batch. + coeffs = 2 * torch.arange(1, m + 1, device=x.device, dtype=x.dtype) - m - 1 + + # Calculate the sum Σᵢ (2i - m - 1)xᵢ for each forecast in the batch along the member dimension. + # We use the efficient O(m log m) implementation with a summation over a single dimension. + x_sum = torch.sum(coeffs * x_sorted, dim=-1) + + # Calculate the full expectation E[|X - X'|] = 2 / m² * Σᵢ (2i - m - 1)xᵢ. + # This is half the mean absolute difference between all pairs of predictions. + denom = m * (m - 1) if not biased else m**2 + dispersion = 2 / denom * x_sum + + return dispersion + + +def crps_ensemble_naive(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> torch.Tensor: """Computes the Continuous Ranked Probability Score (CRPS) for an ensemble forecast. 
This implementation uses the equality @@ -23,8 +114,8 @@ def crps_ensemble_naive(x: torch.Tensor, y: torch.Tensor, biased: bool = True) - Args: x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). y: The ground truth observations, of shape (*batch_shape). - biased: If True, uses the biased estimator for E[|X - X'|]. If False, uses the unbiased estimator. - The unbiased estimator divides by m * (m - 1) instead of m². + biased: If True, uses the biased estimator for $D$, i.e., divides by m². If False, uses the unbiased estimator. + The unbiased estimator divides by m * (m - 1). Returns: The calculated CRPS value for each forecast in the batch, of shape (*batch_shape). @@ -32,47 +123,27 @@ def crps_ensemble_naive(x: torch.Tensor, y: torch.Tensor, biased: bool = True) - if x.shape[:-1] != y.shape: raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!") - # --- Accuracy term A := E[|X - y|] - - # Compute the mean absolute error across all ensemble members. Unsqueeze the observation for explicit broadcasting. - mae = torch.abs(x - y.unsqueeze(-1)).mean(dim=-1) + # Accuracy term A := E[|X - y|] + accuracy = _accuracy_ensemble(x, y) - # --- Dispersion term D := E[|X - X'|] - # This is half the mean absolute difference between all pairs of predictions. + # Dispersion term D := E[|X - X'|] + dispersion = _dispersion_ensemble_naive(x, biased) - # Create a matrix of all pairwise differences between ensemble members using broadcasting. - x_i = x.unsqueeze(-1) # shape: (*batch_shape, m, 1) - x_j = x.unsqueeze(-2) # shape: (*batch_shape, 1, m) - pairwise_diffs = x_i - x_j # shape: (*batch_shape, m, m) + # CRPS value := A - 0.5 * D + return crps_abstract(accuracy, dispersion) - # Take the absolute value of every element in the matrix. - abs_pairwise_diffs = torch.abs(pairwise_diffs) - # Calculate the mean of the m x m matrix for each batch item, i.e, not the batch shapes. 
-    if biased:
-        # For the biased estimator, we use the mean which divides by m².
-        mean_dispersion = abs_pairwise_diffs.mean(dim=(-2, -1))
-    else:
-        # For the unbiased estimator, we need to exclude the diagonal (where i=j) and divide by m(m-1).
-        m = x.shape[-1]  # number of ensemble members
-        mean_dispersion = abs_pairwise_diffs.sum(dim=(-2, -1)) / (m * (m - 1))
-
-    # --- Assemble the CRPS value: A - 0.5 * D
-    crps_value = mae - 0.5 * mean_dispersion
-
-    return crps_value
-
-
-def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torch.Tensor:
+def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> torch.Tensor:
     r"""Computes the Continuous Ranked Probability Score (CRPS) for an ensemble forecast.
 
     This implementation uses the equalities
 
-    $$ CRPS(F, y) = E[|X - y|] - 0.5 E[|X - X'|] $$
-
-    and
+    $$
+    CRPS(F, y) = E[|X - y|] - 0.5 E[|X - X'|] = E[|X - y|] + E[X] - 2 E[X F(X)]
+    $$
 
-    $$ CRPS(F, y) = E[|X - y|] + E[X] - 2 E[X F(X)] $$
+    where $X$ and $X'$ are independent random variables drawn from the ensemble distribution, and $F(X)$ is the CDF
+    of the ensemble distribution evaluated at $X$.
 
     It is designed to be fully vectorized and handle any number of leading batch dimensions in the input tensors,
     as long as they are equal for `x` and `y`.
@@ -86,6 +157,7 @@ def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torc
        time, where m is the number of ensemble members. This is achieved by sorting the ensemble predictions and
        using a mathematical identity to compute the mean absolute difference. You can also see this trick
        [here][https://docs.nvidia.com/physicsnemo/25.11/_modules/physicsnemo/metrics/general/crps.html]
+    - This implementation exactly matches the energy formula, see (NRG) and (eNRG), in Zamo & Naveau (2017) while
+      using the computational trick which can be read from (ePWM) in the same paper. The factors $\beta_0$ and
+      $\beta_1$ in (ePWM) together equal the second term, i.e., the half mean dispersion, here. In (ePWM) they pulled
@@ -94,8 +166,8 @@ def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torc
     Args:
         x: The ensemble predictions, of shape (*batch_shape, dim_ensemble).
         y: The ground truth observations, of shape (*batch_shape).
-        biased: If True, uses the biased estimator for E[|X - X'|]. If False, uses the unbiased estimator.
-            The unbiased estimator divides by m * (m - 1) instead of m².
+        biased: If True, uses the biased estimator for the dispersion term $D$, i.e., divides by m². If False, uses the
+            unbiased estimator which instead divides by m * (m - 1).
 
     Returns:
         The calculated CRPS value for each forecast in the batch, of shape (*batch_shape).
@@ -103,32 +175,11 @@ def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = True) -> torc
-    # Get the number of ensemble members.
-    m = x.shape[-1]
-
-    # --- Accuracy term A := E[|X - y|]
-
-    # Compute the mean absolute error across all ensemble members. Unsqueeze the observation for explicit broadcasting.
-    accuracy = torch.abs(x - y.unsqueeze(-1)).mean(dim=-1)
-
-    # --- Dispersion term D := E[|X - X'|]
-    # This is half the mean absolute difference between all pairs of predictions.
-    # We use the efficient O(m log m) implementation with a summation over a single dimension.
-
-    # Sort the predictions along the ensemble member dimension.
-    x_sorted, _ = torch.sort(x, dim=-1)
-
-    # Calculate the coefficients (2i - m - 1) for the linear-time sum. These are the same for every item in the batch.
-    coeffs = 2 * torch.arange(1, m + 1, device=x.device, dtype=x.dtype) - m - 1
-
-    # Calculate the sum Σᵢ (2i - m - 1)xᵢ for each forecast in the batch along the member dimension.
- x_sum = torch.sum(coeffs * x_sorted, dim=-1) - - # Calculate the full expectation E[|X - X'|] = 2 / m² * Σᵢ (2i - m - 1)xᵢ. - denom = m * (m - 1) if not biased else m**2 - half_mean_dispersion = 1 / denom * x_sum # 2 in numerator here cancels with 0.5 in the next step + # Accuracy term A := E[|X - y|] + accuracy = _accuracy_ensemble(x, y) - # --- CRPS value := A - D / 2 - crps_value = accuracy - half_mean_dispersion # 0.5 already accounted for above + # Dispersion term D := E[|X - X'|] + dispersion = _dispersion_ensemble(x, biased) - return crps_value + # CRPS value := A - 0.5 * D + return crps_abstract(accuracy, dispersion) diff --git a/torch_crps/normalization.py b/torch_crps/normalization.py index b7e70ab..f1514f8 100644 --- a/torch_crps/normalization.py +++ b/torch_crps/normalization.py @@ -21,7 +21,7 @@ def normalize_by_observation(crps_fcn: Callable) -> Callable: - If the observations `y` are all close to zero, then the normalization is done by 1, so the CRPS can be > 1. Args: - crps_fcn: CRPS-calculating function to be wrapped. The fucntion must accept an argument called y which is + crps_fcn: CRPS-calculating function to be wrapped. The function must accept an argument called y which is at the 2nd position. Returns: @@ -31,7 +31,7 @@ def normalize_by_observation(crps_fcn: Callable) -> Callable: @functools.wraps(crps_fcn) def wrapper(*args: WRAPPED_INPUT_TYPE, **kwargs: WRAPPED_INPUT_TYPE) -> torch.Tensor: - """The function returned by the decorator that does the normalization and the forwading to the CRPS function.""" + """The function returned by the decorator that normalizes and forwards to the CRPS function.""" # Find the observation 'y' from the arguments. 
if "y" in kwargs: y = kwargs["y"] From 87f341f9dbed3fbb2cde4ce07c08fc4961d64738 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Sun, 8 Feb 2026 10:34:54 +0100 Subject: [PATCH 13/19] SCRPS tests are running --- tests/test_ensemble.py | 33 +++++++++++++++---- torch_crps/analytical/dispatch.py | 2 +- torch_crps/analytical/normal.py | 10 +++++- torch_crps/analytical/studentt.py | 6 ++-- torch_crps/ensemble.py | 55 ++++++++++++++++++++++++++++--- 5 files changed, 90 insertions(+), 16 deletions(-) diff --git a/tests/test_ensemble.py b/tests/test_ensemble.py index 46c11c8..6245274 100644 --- a/tests/test_ensemble.py +++ b/tests/test_ensemble.py @@ -5,7 +5,7 @@ from _pytest.fixtures import FixtureRequest from tests.conftest import needs_cuda -from torch_crps import crps_ensemble, crps_ensemble_naive +from torch_crps.ensemble import crps_ensemble, crps_ensemble_naive, scrps_ensemble @pytest.mark.parametrize( @@ -13,8 +13,8 @@ ["case_flat_1d", "case_batched_2d", "case_batched_3d"], ids=["case_flat_1d", "case_batched_2d", "case_batched_3d"], ) -@pytest.mark.parametrize("crps_fcn", [crps_ensemble_naive, crps_ensemble], ids=["naive", "default"]) @pytest.mark.parametrize("biased", [True, False], ids=["biased", "unbiased"]) +@pytest.mark.parametrize("crps_fcn", [crps_ensemble_naive, crps_ensemble], ids=["naive", "default"]) @pytest.mark.parametrize( "use_cuda", [ @@ -23,7 +23,11 @@ ], ) def test_ensemble_smoke( - test_case_fixture_name: str, crps_fcn: Callable, biased: bool, use_cuda: bool, request: FixtureRequest + test_case_fixture_name: str, + crps_fcn: Callable[[torch.Tensor, torch.Tensor, bool], torch.Tensor], + biased: bool, + use_cuda: bool, + request: FixtureRequest, ): """Test that naive ensemble method yield.""" test_case_fixture: dict = request.getfixturevalue(test_case_fixture_name) @@ -67,16 +71,33 @@ def test_ensemble_match(batch_shape: tuple[int, ...], biased: bool, dim_ensemble ) -def test_ensemble_invalid_shapes(dim_ensemble: int = 10): 
+@pytest.mark.parametrize("crps_fcn", [crps_ensemble, scrps_ensemble], ids=["CRPS", "SCRPS"]) +@pytest.mark.parametrize("biased", [True, False], ids=["biased", "unbiased"]) +def test_ensemble_invalid_shapes( + crps_fcn: Callable[[torch.Tensor, torch.Tensor, bool], torch.Tensor], biased: bool, dim_ensemble: int = 10 +): """Test that crps_ensemble raises an error for invalid input shapes.""" # Mismatch in the number of batch dimensions. x = torch.randn(2, 3, dim_ensemble) y = torch.randn(3) with pytest.raises(ValueError): - crps_ensemble(x, y) + crps_fcn(x, y, biased) # Mismatch in batch dimension sizes. x = torch.randn(4, 5, dim_ensemble) y = torch.randn(4, 6) with pytest.raises(ValueError): - crps_ensemble(x, y) + crps_fcn(x, y, biased) + + +def test_ensemble_scrps_nonnegativity(num_samples: int = 100, dim_ensemble: int = 50): + """Test that the SCRPS can have negative values (in contrast to the CRPS).""" + torch.manual_seed(0) + + # Create a random ensemble forecast with small dispersion, and observations. + x = 1e-1 * torch.randn(num_samples, dim_ensemble) + y = torch.randn(num_samples) + + scrps_values = scrps_ensemble(x, y, False) + + assert torch.any(scrps_values < 0), "SCRPS should have some negative values!" diff --git a/torch_crps/analytical/dispatch.py b/torch_crps/analytical/dispatch.py index e1ae0d1..cccb37e 100644 --- a/torch_crps/analytical/dispatch.py +++ b/torch_crps/analytical/dispatch.py @@ -42,7 +42,7 @@ def scrps_analytical( q: Distribution, y: torch.Tensor, ) -> torch.Tensor: - """Compute the (negatively-oriented, i.e., lower is better) scaled CRPS (SCRPS) in closed-form. + """Compute the (negatively-oriented, i.e., lower is better) Scaled CRPS (SCRPS) in closed-form. Note: The input distribution must be either `torch.distributions.Normal` or `torch.distributions.StudentT`. 
diff --git a/torch_crps/analytical/normal.py b/torch_crps/analytical/normal.py
index 748d461..1a07d2c 100644
--- a/torch_crps/analytical/normal.py
+++ b/torch_crps/analytical/normal.py
@@ -69,7 +69,15 @@ def scrps_analytical_normal(
     q: Normal,
     y: torch.Tensor,
 ) -> torch.Tensor:
-    """Compute the (negatively-oriented) scaled CRPS (SCRPS) in closed-form assuming a normal distribution.
+    r"""Compute the (negatively-oriented) Scaled CRPS (SCRPS) in closed-form assuming a normal distribution.
+
+    $$
+    \text{SCRPS}(F, y) = \frac{E[|X - y|]}{E[|X - X'|]} + 0.5 \log \left( E[|X - X'|] \right)
+    = \frac{A}{D} + 0.5 \log(D)
+    $$
+
+    where $X$ and $X'$ are independent random variables drawn from the forecast distribution, and $F(X)$ is the CDF
+    of the forecast distribution evaluated at $X$, and $y$ are the ground truth observations.
 
     Note:
         In contrast to the (negatively-oriented) CRPS, the SCRPS can have negative values.
diff --git a/torch_crps/analytical/studentt.py b/torch_crps/analytical/studentt.py
index 25e1034..1897ac3 100644
--- a/torch_crps/analytical/studentt.py
+++ b/torch_crps/analytical/studentt.py
@@ -168,11 +168,11 @@ def scrps_analytical_studentt(
     q: StudentT,
     y: torch.Tensor,
 ) -> torch.Tensor:
-    r"""Compute the (negatively-oriented) scaled CRPS (SCRPS) in closed-form assuming a Student-T distribution.
+    r"""Compute the (negatively-oriented) Scaled CRPS (SCRPS) in closed-form assuming a Student-T distribution.
-    The score is calculated as:
     $$
-    \text{SCRPS}(F, y) = \frac{A}{D} + 0.5 \log(D)
+    \text{SCRPS}(F, y) = \frac{E[|X - y|]}{E[|X - X'|]} + 0.5 \log \left( E[|X - X'|] \right)
+    = \frac{A}{D} + 0.5 \log(D)
     $$
 
     where:
diff --git a/torch_crps/ensemble.py b/torch_crps/ensemble.py
index a5eadf1..27c062c 100644
--- a/torch_crps/ensemble.py
+++ b/torch_crps/ensemble.py
@@ -1,6 +1,6 @@
 import torch
 
-from torch_crps.abstract import crps_abstract
+from torch_crps.abstract import crps_abstract, scrps_abstract
 
 
 def _accuracy_ensemble(
@@ -118,7 +118,7 @@ def crps_ensemble_naive(x: torch.Tensor, y: torch.Tensor, biased: bool = False)
         The unbiased estimator divides by m * (m - 1).
 
     Returns:
-        The calculated CRPS value for each forecast in the batch, of shape (*batch_shape).
+        The CRPS value for each forecast in the batch, of shape (*batch_shape).
     """
     if x.shape[:-1] != y.shape:
         raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!")
@@ -136,10 +136,10 @@ def crps_ensemble_naive(x: torch.Tensor, y: torch.Tensor, biased: bool = False)
 def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> torch.Tensor:
     r"""Computes the Continuous Ranked Probability Score (CRPS) for an ensemble forecast.
 
-    This implementation uses the equalities
+    This function implements
 
     $$
-    CRPS(F, y) = E[|X - y|] - 0.5 E[|X - X'|] = E[|X - y|] + E[X] - 2 E[X F(X)]
+    \text{CRPS}(F, y) = E[|X - y|] - 0.5 E[|X - X'|] = E[|X - y|] + E[X] - 2 E[X F(X)]
     $$
 
     where $X$ and $X'$ are independent random variables drawn from the ensemble distribution, and $F(X)$ is the CDF
@@ -170,7 +170,7 @@ def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> tor
         unbiased estimator which instead divides by m * (m - 1).
 
     Returns:
-        The calculated CRPS value for each forecast in the batch, of shape (*batch_shape).
+        The CRPS value for each forecast in the batch, of shape (*batch_shape).
""" if x.shape[:-1] != y.shape: raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!") @@ -183,3 +183,48 @@ def crps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> tor # CRPS value := A - 0.5 * D return crps_abstract(accuracy, dispersion) + + +def scrps_ensemble(x: torch.Tensor, y: torch.Tensor, biased: bool = False) -> torch.Tensor: + r"""Computes the Scaled Continuous Ranked Probability Score (SCRPS) for an ensemble forecast. + + $$ + \text{SCRPS}(F, y) = -\frac{E[|X - y|]}{E[|X - X'|]} - 0.5 \log \left( E[|X - X'|] \right) + = \frac{A}{D} + 0.5 \log(D) + $$ + + where $X$ and $X'$ are independent random variables drawn from the ensemble distribution, and $F(X)$ is the CDF + of the ensemble distribution evaluated at $X$, and $y$ are the ground truth observations. + + It is designed to be fully vectorized and handle any number of leading batch dimensions in the input tensors, + as long as they are equal for `x` and `y`. + + See Also: + Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019. + + Note: + This implementation uses an efficient algorithm to compute the dispersion term E[|X - X'|] in O(m log(m)) + time, where m is the number of ensemble members. This is achieved by sorting the ensemble predictions and using + a mathematical identity to compute the mean absolute difference. You can also see this trick + [here][https://docs.nvidia.com/physicsnemo/25.11/_modules/physicsnemo/metrics/general/crps.html] + + Args: + x: The ensemble predictions, of shape (*batch_shape, dim_ensemble). + y: The ground truth observations, of shape (*batch_shape). + biased: If True, uses the biased estimator for the dispersion term $D$, i.e., divides by m². If False, uses the + unbiased estimator which instead divides by m * (m - 1). + + Returns: + The SCRPS value for each forecast in the batch, of shape (*batch_shape). 
+ """ + if x.shape[:-1] != y.shape: + raise ValueError(f"The batch dimension(s) of x {x.shape[:-1]} and y {y.shape} must be equal!") + + # Accuracy term A := E[|X - y|] + accuracy = _accuracy_ensemble(x, y) + + # Dispersion term D := E[|X - X'|] + dispersion = _dispersion_ensemble(x, biased) + + # SCRPS value := A/D + 0.5 * log(D) + return scrps_abstract(accuracy, dispersion) From 13f77c55d8da11e0042f0e9fb895bb4be92cd9f9 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Sun, 8 Feb 2026 10:39:47 +0100 Subject: [PATCH 14/19] Expose scrps_ensemble --- torch_crps/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_crps/__init__.py b/torch_crps/__init__.py index 590dd41..f213f83 100644 --- a/torch_crps/__init__.py +++ b/torch_crps/__init__.py @@ -1,7 +1,7 @@ from .analytical import crps_analytical, scrps_analytical from .analytical.normal import crps_analytical_normal, scrps_analytical_normal from .analytical.studentt import crps_analytical_studentt, scrps_analytical_studentt -from .ensemble import crps_ensemble, crps_ensemble_naive +from .ensemble import crps_ensemble, crps_ensemble_naive, scrps_ensemble from .integral import crps_integral from .normalization import ( crps_analytical_normal_obsnormalized, @@ -26,4 +26,5 @@ "scrps_analytical", "scrps_analytical_normal", "scrps_analytical_studentt", + "scrps_ensemble", ] From 451ed5dd62ff2f0def0241445fdec97df46577da Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Sun, 8 Feb 2026 10:53:43 +0100 Subject: [PATCH 15/19] Added tests for scale-invariance --- tests/test_ensemble.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/test_ensemble.py b/tests/test_ensemble.py index 6245274..9859374 100644 --- a/tests/test_ensemble.py +++ b/tests/test_ensemble.py @@ -1,3 +1,4 @@ +import math from collections.abc import Callable import pytest @@ -101,3 +102,43 @@ def test_ensemble_scrps_nonnegativity(num_samples: int = 100, dim_ensemble: int 
scrps_values = scrps_ensemble(x, y, False) assert torch.any(scrps_values < 0), "SCRPS should have some negative values!" + + +@pytest.mark.parametrize("seed", [0, 1, 2, 3, 4], ids=["seed_0", "seed_1", "seed_2", "seed_3", "seed_4"]) +@pytest.mark.parametrize("biased", [True, False], ids=["biased", "unbiased"]) +@pytest.mark.parametrize("scale_factor", [1e1, 1e3, 1e6, 1e9], ids=["1e1", "1e3", "1e6", "1e9"]) +def test_ensemble_scrps_scale_invariance( + seed: int, + biased: bool, + scale_factor: float, + num_samples: int = 100, + dim_ensemble: int = 50, +): + """Test that the SCRPS is locally scale-invariant (to a certain extent).""" + torch.manual_seed(seed) + rtol = 0.5 * math.log(scale_factor) # just an educated guess + + # Create a random ensemble forecast and observations. + x = torch.randn(num_samples, dim_ensemble) + y = torch.randn(num_samples) + scrps_original = scrps_ensemble(x, y, biased) + + # Scale the ensemble forecasts by a factor of 1000. + x_scaled = scale_factor * x + scrps_scaled = scrps_ensemble(x_scaled, y, biased) + + # Assert that the SCRPS values are approximately scale-invariant. + assert torch.allclose(scrps_original, scrps_scaled, rtol=rtol), ( + f"The SCRPS values are not scale-invariant as expected, i.e., scaling the forecasts by {scale_factor} " + f"should not change the SCRPS values by more than a factor of {rtol}, but it did." + ) + + # Scale the observations by a factor of 1000. + y_scaled = scale_factor * y + scrps_scaled = scrps_ensemble(x, y_scaled, biased) + + # Assert that the SCRPS values are approximately scale-invariant. + assert torch.allclose(scrps_original, scrps_scaled, rtol=rtol), ( + f"The SCRPS values are not scale-invariant as expected, i.e., scaling the observations by {scale_factor} " + f"should not change the SCRPS values by more than a factor of {rtol}, but it did." 
+ ) From 0e3717427ca8ab9a47170258f8ee6cb15a391f67 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Sun, 8 Feb 2026 11:21:10 +0100 Subject: [PATCH 16/19] Improved the doc --- pyproject.toml | 2 +- readme.md | 66 ++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9a2ea71..896543d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ classifiers = [ dependencies = [ "torch>=2.7", ] -description = "Implementations of the CRPS using PyTorch" +description = "PyTorch-based implementations of the Continuously-Ranked Probability Score (CRPS) as well as its locally scale-invariant version (SCRPS)" dynamic = ["version"] license = "CC-BY-4.0" maintainers = [{name = "Fabio Muratore", email = "accounts@famura.net"}] diff --git a/readme.md b/readme.md index 4da6f30..e4ea4c8 100644 --- a/readme.md +++ b/readme.md @@ -13,18 +13,58 @@ [![Ruff][ruff-badge]][ruff] [![uv][uv-badge]][uv] -Implementations of the Continuously-Ranked Probability Score (CRPS) using PyTorch +PyTorch-based implementations of the Continuously-Ranked Probability Score (CRPS) as well as its locally scale-invariant +version (SCRPS) ## Background -The Continuously-Ranked Probability Score (CRPS) is a strictly proper scoring rule. -It assesses how well a distribution with the cumulative distribution function $F$ is explaining an observation $y$ +### Continuously-Ranked Probability Score (CRPS) -$$ \text{CRPS}(F,y) = \int _{\mathbb {R} }(F(x)-\mathbb {1} (x\geq y))^{2}dx \qquad (\text{integral formulation}) $$ +The CRPS is a strictly proper scoring rule. +It assesses how well a distribution with the cumulative distribution function $F(X)$ of the estimate $X$ (a random +variable) is explaining an observation $y$ + +$$ +\text{CRPS}(F,y) = \int _{\mathbb {R}} \left( F(x)-\mathbb {1} (x\geq y) \right)^{2} dx +$$ where $1$ denoted the indicator function. 
-In Section 2 of this [paper][crps-folumations] Zamo & Naveau list 3 different formulations of the CRPS. +In Section 2 of this [paper][crps-folumations] Zamo & Naveau list 3 different formulations of the CRPS. One of them is + +$$ +\text{CRPS}(F, y) = E[|X - y|] - 0.5 E[|X - X'|] = E[|X - y|] + E[X] - 2 E[X F(X)] +$$ + +which can be shortened to + +$$ +\text{CRPS}(F, y) = A - 0.5 D +$$ + +where $A$ is called the accuracy term and $D$ is called the dispersion term (at least I do it in this repo). + +### Scaled Continuously-Ranked Probability Score (SCRPS) + +The SCRPS is a locally scale-invariant version of the CRPS. +In their [paper][scrps-paper], Bolin & Wallin define it in a positively-oriented way, i.e., higher is better. +In contrast, I implement the SCRPS in this repo negatively-oriented, just like a loss function. + +Oversimplifying the notation, the (negatively-oriented) SCRPS can be written as + +$$ +\text{SCRPS}(F, y) = -\frac{E[|X - y|]}{E[|X - X'|]} - 0.5 \log \left( E[|X - X'|] \right) +$$ + +which can be shortened to + +$$ +\text{SCRPS}(F, y) = \frac{A}{D} + 0.5 \log(D) +$$ + +The scale-invariance, i.e., the SCRPS value does not depend on the magnitude of $D$, comes from the division by $D$. + +Note that the SCRPS can, in contrast to the CRPS, yield negative values.
### Incomplete list of sources that I came across while researching about the CRPS @@ -33,6 +73,7 @@ In Section 2 of this [paper][crps-folumations] Zamo & Naveau list 3 different fo - Gneiting & Raftery; "Strictly Proper Scoring Rules, Prediction, and Estimation"; 2007 - Zamo & Naveau; "Estimation of the Continuous Ranked Probability Score with Limited Information and Applications to Ensemble Weather Forecasts"; 2018 - Jordan et al.; "Evaluating Probabilistic Forecasts with scoringRules"; 2019 +- Bolin & Wallin; "Local scale invariance and robustness of proper scoring rules"; 2019 - Olivares & Négiar & Ma et al; "CLOVER: Probabilistic Forecasting with Coherent Learning Objective Reparameterization"; 2023 - Vermorel & Tikhonov; "Continuously-Ranked Probability Score (CRPS)" [blog post][Lokad-post]; 2024 - Nvidia; "PhysicsNeMo Framework" [source code][nvidia-crps-implementation]; 2025 @@ -40,8 +81,8 @@ In Section 2 of this [paper][crps-folumations] Zamo & Naveau list 3 different fo ## Application to Machine Learning -The CRPS can be used as a loss function in machine learning, just like the well-known negative log-likelihood loss which -is the log scoring rule. +The CRPS, as well as the SCRPS, can be used as a loss function in machine learning, just like the well-known negative +log-likelihood loss which is the log scoring rule. The parametrized model outputs a distribution $q(x)$. The CRPS loss evaluates how good $q(x)$ is explaining the observation $y$. @@ -54,7 +95,15 @@ There is [work on multi-variate CRPS estimation][multivariate-crps], but it is n ## Implementation -The integral formulation is infeasible to naively evaluate on a computer due to the infinite integration over $x$. +The direct implementation of the integral formulation is not suited for evaluation on a computer due to the infinite +integration over the domain of the random variable $X$. +Nevertheless, this repository includes such an implementation to verify the others.
+ +The normalization-by-observation variants are improper solutions to normalize the CRPS values. The goal is to use the +CRPS as a loss function in machine learning tasks. For that, it is highly beneficial if the loss does not depend on +the scale of the problem. +However, dividing by the absolute maximum of the observations is a bad proxy for doing this. +I plan on removing these methods once I have gained trust in my SCRPS implementation. I found [Nvidia's implementation][nvidia-crps-implementation] of the CRPS for ensemble preductions in $M log(M)$ time inspiring to read. @@ -89,6 +138,7 @@ inspiring to read. [uv]: https://docs.astral.sh/uv [crps-folumations]: https://link.springer.com/article/10.1007/s11004-017-9709-7 +[scrps-paper]: https://arxiv.org/abs/1912.05642 [Lokad-post]: https://www.lokad.com/continuous-ranked-probability-score/ [multivariate-crps]: https://arxiv.org/pdf/2410.09133 [nvidia-crps-implementation]: https://docs.nvidia.com/physicsnemo/25.11/_modules/physicsnemo/metrics/general/crps.html From c3ab697c680c53112c0ca24d12a690e310115a88 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Sun, 8 Feb 2026 11:22:54 +0100 Subject: [PATCH 17/19] Ran `pre-commit run --all-files` --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 896543d..9acd520 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,9 +75,9 @@ [tool.mypy] ignore_missing_imports = true # when no stubs are available, e.g.
for matplotlib or tabulate -pretty = true -show_error_context = true -show_traceback = true +pretty = true +show_error_context = true +show_traceback = true [tool.pytest.ini_options] addopts = [ From d9baf71fc9750bf2908f3ceeb3368d84eefc35bc Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Sun, 8 Feb 2026 11:23:07 +0100 Subject: [PATCH 18/19] pre-commit run --all-files --- readme.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/readme.md b/readme.md index e4ea4c8..8a05715 100644 --- a/readme.md +++ b/readme.md @@ -36,7 +36,7 @@ $$ \text{CRPS}(F, y) = E[|X - y|] - 0.5 E[|X - X'|] = E[|X - y|] + E[X] - 2 E[X F(X)] $$ -which can be shortened to +which can be shortened to $$ \text{CRPS}(F, y) = A - 0.5 D @@ -56,7 +56,7 @@ $$ \text{SCRPS}(F, y) = -\frac{E[|X - y|]}{E[|X - X'|]} - 0.5 \log \left( E[|X - X'|] \right) $$ -which can be shortened to +which can be shortened to $$ \text{SCRPS}(F, y) = \frac{A}{D} + 0.5 \log(D) From 0c58de80ab15887a4724cfae8f6d09a28476a474 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Sun, 8 Feb 2026 11:36:14 +0100 Subject: [PATCH 19/19] Doc string fixes --- torch_crps/analytical/normal.py | 4 ++-- torch_crps/analytical/studentt.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_crps/analytical/normal.py b/torch_crps/analytical/normal.py index 1a07d2c..062c1d4 100644 --- a/torch_crps/analytical/normal.py +++ b/torch_crps/analytical/normal.py @@ -72,8 +72,8 @@ def scrps_analytical_normal( r"""Compute the (negatively-oriented) Scaled CRPS (SCRPS) in closed-form assuming a normal distribution. 
$$ - \text{SCRPS}(F, y) = -\frac{E[|X - y|]}{E[|X - X'|]} - 0.5 \\log \\left( E[|X - X'|] \right) - = \frac{A}{D} + 0.5 \\log(D) + \text{SCRPS}(F, y) = \frac{E[|X - y|]}{E[|X - X'|]} + 0.5 \log \left( E[|X - X'|] \right) + = \frac{A}{D} + 0.5 \log(D) $$ where $X$ and $X'$ are independent random variables drawn from the ensemble distribution, and $F(X)$ is the CDF diff --git a/torch_crps/analytical/studentt.py b/torch_crps/analytical/studentt.py index 1897ac3..e8dba89 100644 --- a/torch_crps/analytical/studentt.py +++ b/torch_crps/analytical/studentt.py @@ -176,11 +176,12 @@ def scrps_analytical_studentt( $$ where: + - $F_{\nu, \mu, \sigma^2}$ is the cumulative Student-T distribution, and $F_{\nu}$ is the standardized version. - $A = E_F[|X - y|]$ is the accuracy term. - $A = \sigma [ z(2 F_{\nu}(z) - 1) + 2(\nu + z²) / (\nu*B(\nu/2, 1/2)) * F_{\nu+1}(z * \sqrt{(\nu+1)/(\nu+z²)}) ]$ - $D = E_F[|X - X'|]$ is the dispersion term. - - $D = \frac{ 4\sigma }{ \nu-1 } * ( \frac{ \Gamma( \nu/2 ) }{ \Gamma( (\nu-1)/2) } )^2}$ + - $D = \frac{ 4\sigma }{ \nu-1 } * ( \frac{ \Gamma( \nu/2 ) }{ \Gamma( (\nu-1)/2) } )^2$ Note: This formula is only valid for degrees of freedom $\nu > 1$.