diff --git a/.gitignore b/.gitignore
index e234b91..573e7c4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,3 +41,5 @@ uv.lock

 # results from examples
 examples/*.png
+!examples/visualization_gamma.png
+!examples/visualization_normal.png
diff --git a/examples/time_series_learning.py b/examples/time_series_learning.py
index 0f3fa9a..9a2a8e9 100644
--- a/examples/time_series_learning.py
+++ b/examples/time_series_learning.py
@@ -8,6 +8,7 @@
 """

 import pathlib
+from typing import Literal

 import matplotlib.pyplot as plt
 import numpy
@@ -15,12 +16,10 @@
 import torch
 from matplotlib.figure import Figure

-from torch_crps.analytical import crps_analytical
+from torch_crps import crps_analytical, scrps_analytical

 EXAMPLES_DIR = pathlib.Path(pathlib.Path(__file__).parent)

-torch.set_default_dtype(torch.float32)
-

 class SimpleDistributionalModel(torch.nn.Module):
     """A model that makes independent predictions given a sequence of inputs, yielding a StudentT distribution."""
@@ -100,7 +99,7 @@ def simple_training(
     packed_targets: torch.Tensor,
     dataset_name: str,
     normalize_data: bool,
-    use_crps: bool,
+    score_fcn: Literal["nll", "crps", "scrps"],
     device: torch.device,
 ) -> None:
     """A bare bones training loop for the time series model that works on windowed data.
@@ -113,7 +112,7 @@
         packed_targets: Target tensor of shape (num_samples, dim_output).
         dataset_name: Name of the dataset.
         normalize_data: Whether the data is normalized.
-        use_crps: If True, use CRPS loss. If false, use negative log-likelihood loss.
+        score_fcn: Which scoring function to use: "nll", "crps", or "scrps".
         device: Device to run training on.
     """
     # Move data to device.
@@ -123,7 +122,7 @@
     # Use a simple heuristic for the optimization hyper-parameters.
     if dataset_name == "monthly_sunspots":
         if normalize_data:
-            num_epochs = 3001
+            num_epochs = 4001
             lr = 3e-3
         else:
             # The data is in [0, 100] so we need more steps.
@@ -144,8 +143,10 @@
         packed_predictions = model(packed_inputs)

         # Compute the loss, lower is better in both cases.
-        if use_crps:
+        if score_fcn == "crps":
             loss = crps_analytical(packed_predictions, packed_targets).mean()
+        elif score_fcn == "scrps":
+            loss = scrps_analytical(packed_predictions, packed_targets).mean()
         else:
             loss = -packed_predictions.log_prob(packed_targets).mean()

@@ -239,7 +240,7 @@ def plot_results(
     if dataset_name not in ("monthly_sunspots", "mackey_glass"):
         raise NotImplementedError(f"Unknown dataset {dataset_name}! Please specify the necessary parts in the script.")

-    fig, axs = plt.subplots(2, 1, figsize=(16, 9))
+    fig, axs = plt.subplots(2, 1, figsize=(16, 10))

     # Plot training data and predictions.
     axs[0].plot(data_trn, label="data train")
@@ -277,12 +278,14 @@

 if __name__ == "__main__":
     seaborn.set_theme()
+
+    torch.set_default_dtype(torch.float32)
     torch.manual_seed(0)

     # Configure.
     normalize_data = True  # scales the data to be in [-1, 1] (recommended for monthly_sunspots dataset)
     dataset_name = "monthly_sunspots"  # monthly_sunspots or mackey_glass
-    use_crps = True  # if True, use CRPS loss instead of NLL
+    score_fcn = "scrps"  # "nll", "crps", or "scrps"
     len_window = 10  # tested 10 and 20
     dim_hidden = 64

@@ -292,7 +295,7 @@ def plot_results(

     # Prepare the data.
     data_trn, data_tst = load_and_split_data(dataset_name, normalize_data)
-    dim_data = data_trn.size(1)
+    num_training_samples, dim_data = data_trn.size(0), data_trn.size(1)

     # Create the model and move to device
     model = SimpleDistributionalModel(dim_input=dim_data, dim_output=dim_data, hidden_size=dim_hidden)
@@ -305,7 +308,7 @@ def plot_results(
     # i i ... t
     inputs = []
     targets = []
-    for idx in range(len_window, data_trn.size(0)):
+    for idx in range(len_window, num_training_samples):
         # Slice the input.
         idx_begin = max(idx - len_window, 0)
         inp = data_trn[idx_begin:idx, :].view(-1, dim_data)
@@ -329,7 +332,7 @@ def plot_results(
         packed_targets,
         dataset_name=dataset_name,
         normalize_data=normalize_data,
-        use_crps=use_crps,
+        score_fcn=score_fcn,
         device=device,
     )

@@ -347,6 +350,5 @@ def plot_results(
         predictions_tst_mean,
         predictions_tst_std,
     )
-    loss_name = "crps" if use_crps else "nll"
-    plt.savefig(EXAMPLES_DIR / f"time_series_learning_{dataset_name}_{loss_name}.png", dpi=300)
-    print(f"Figure saved to {EXAMPLES_DIR / f'time_series_learning_{dataset_name}.png'}")
+    fig.savefig(EXAMPLES_DIR / f"time_series_learning_{dataset_name}_{score_fcn}.png", dpi=300)
+    print(f"Figure saved to {EXAMPLES_DIR / f'time_series_learning_{dataset_name}_{score_fcn}.png'}")
diff --git a/examples/visualization.py b/examples/visualization.py
new file mode 100644
index 0000000..c6dcb52
--- /dev/null
+++ b/examples/visualization.py
@@ -0,0 +1,157 @@
+import pathlib
+
+import matplotlib.pyplot as plt
+import seaborn
+import torch
+
+from torch_crps import crps_analytical, crps_ensemble, scrps_analytical, scrps_ensemble
+
+EXAMPLES_DIR = pathlib.Path(pathlib.Path(__file__).parent)
+
+
+def gamma_example(
+    concentration: float,
+    rate: float,
+    eval_min: float,
+    eval_max: float,
+    num_eval_points: int,
+    ensemble_size: int,
+) -> None:
+    """Example showing the probability density, negative log-likelihood, CRPS, and SCRPS of a Gamma distribution.
+
+    Args:
+        concentration: The concentration parameter of the Gamma distribution.
+        rate: The rate parameter of the Gamma distribution.
+        eval_min: The minimum grid value of y to evaluate on.
+        eval_max: The maximum grid value of y to evaluate on.
+        num_eval_points: The number of grid points to evaluate the functions on.
+        ensemble_size: The number of ensemble samples to use for the CRPS and SCRPS estimates.
+    """
+    assert concentration > 0 and rate > 0
+    assert eval_min < eval_max
+    assert num_eval_points > 0
+    assert ensemble_size > 0
+
+    # Create a distribution (imagine a model's output).
+    p = torch.distributions.Gamma(concentration=concentration, rate=rate)
+
+    # Define a grid for all evaluations.
+    y = torch.linspace(eval_min, eval_max, num_eval_points)
+
+    # Evaluate the probability, negative log-probability, CRPS, and SCRPS on the grid.
+    p_y = p.log_prob(y).exp()
+    nll_y = -p.log_prob(y)
+    q_samples = p.sample((num_eval_points, ensemble_size))
+    crps_y = crps_ensemble(q_samples, y)
+    scrps_y = scrps_ensemble(q_samples, y)
+    print(f"Evaluated p(y), NLL(p(y), y), CRPS(p(y), y), and SCRPS(p(y), y) on a grid of {num_eval_points} points")
+
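+    # Note: unlike the analytical scores in scale_example below, crps_ensemble and scrps_ensemble are
+    # sample-based estimators: each grid point y[i] is scored against its own row of ensemble_size
+    # draws from p, so the curves carry Monte Carlo noise that shrinks as ensemble_size grows.
+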
+    # Plot the evaluations.
+    fig = plt.figure(figsize=(12, 8))
+    y_plot = y.cpu().numpy()
+    seaborn.lineplot(x=y_plot, y=p_y.cpu().numpy(), label=f"p(x) = Gamma(concentration={concentration}, rate={rate})")
+    seaborn.lineplot(x=y_plot, y=nll_y.cpu().numpy(), label="NLL(p(x), y)")
+    seaborn.lineplot(x=y_plot, y=crps_y.cpu().numpy(), label=f"CRPS_{ensemble_size}(p(x), y)")
+    seaborn.lineplot(x=y_plot, y=scrps_y.cpu().numpy(), label=f"SCRPS_{ensemble_size}(p(x), y)")
+
+    # Plot the mean and the mode as dashed vertical lines.
+    plt.axvline(p.mean.item(), color="C8", linestyle="dashed", label="mean")
+    plt.axvline(p.mode.item(), color="C9", linestyle="dashed", label="mode")
+
+    # Add annotation.
+    plt.xlabel("observation y")
+    plt.ylabel("value")
+    plt.legend(loc="upper right")
+
+    # Save the plot.
+    fig.tight_layout()
+    fig.savefig(EXAMPLES_DIR / "visualization_gamma.png", dpi=300)
+    print("Saved visualization to", EXAMPLES_DIR / "visualization_gamma.png")
+
+
+def scale_example(
+    loc: float,
+    scale: float,
+    num_eval_points: int,
+) -> None:
+    """Example showing the effect of the random variable's scale on the CRPS and SCRPS of a distribution.
+
+    Args:
+        loc: The location parameter of the Normal distribution.
+        scale: The scale parameter of the Normal distribution.
+        num_eval_points: The number of grid points to evaluate the functions on.
+    """
+    assert loc > 0 and scale > 0
+    assert num_eval_points > 0
+
+    # Create a distribution (imagine a model's output).
+    p = torch.distributions.Normal(loc=loc, scale=scale)
+
+    # Define a grid for all evaluations.
+    eval_min, eval_max = loc - 4 * scale, loc + 4 * scale
+    y = torch.linspace(eval_min, eval_max, num_eval_points)
+
+    # Evaluate the probability, negative log-probability, CRPS, and SCRPS on the grid.
+    p_y = p.log_prob(y).exp()
+    nll_y = -p.log_prob(y)
+    crps_y = crps_analytical(p, y)
+    scrps_y = scrps_analytical(p, y)
+    print(f"Evaluated p(y), NLL(p(y), y), CRPS(p(y), y), and SCRPS(p(y), y) on a grid of {num_eval_points} points")
+
+    # Plot the evaluations. Make the upper subplot 1/4 the height of the lower one.
+    fig, ax = plt.subplots(
+        nrows=2,
+        ncols=1,
+        figsize=(12, 8),
+        gridspec_kw={"height_ratios": [1, 4]},
+        sharex=True,
+    )
+    y_plot = y.cpu().numpy()
+
+    # Upper (smaller) subplot: probability density.
+    ax[0].plot(y_plot, p_y.cpu().numpy(), color="C0", label=f"p(x) = Normal(loc={loc}, scale={scale})")
+    ax[0].set_ylabel("p(x)")
+    ax[0].legend(loc="upper right")
+
+    # Lower (larger) subplot: NLL, CRPS, SCRPS.
+    ax[1].plot(y_plot, nll_y.cpu().numpy(), color="C1", label="NLL(p(x), y)")
+    ax[1].plot(y_plot, crps_y.cpu().numpy(), color="C2", label="CRPS(p(x), y)")
+    ax[1].plot(y_plot, scrps_y.cpu().numpy(), color="C3", label="SCRPS(p(x), y)")
+    ax[1].set_xlabel("observation y")
+    ax[1].set_ylabel("value")
+    ax[1].legend(loc="upper center")
+
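+    # Note: CRPS is expressed in the units of y, so its magnitude grows with the spread of the data
+    # (here scale=20 around loc=1000), whereas SCRPS is designed to be locally scale-invariant,
+    # which makes it easier to compare across datasets with very different scales.
+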
+    # Save the plot.
+    fig.tight_layout()
+    fig.savefig(EXAMPLES_DIR / "visualization_normal.png", dpi=300)
+    print("Saved visualization to", EXAMPLES_DIR / "visualization_normal.png")
+
+
+if __name__ == "__main__":
+    seaborn.set_theme()
+
+    torch.set_default_dtype(torch.float32)
+    torch.manual_seed(0)
+
+    gamma_example(
+        concentration=3,
+        rate=4,
+        eval_min=0.01,
+        eval_max=2.5,
+        num_eval_points=5000,
+        ensemble_size=2000,
+    )
+
+    scale_example(
+        loc=1000,
+        scale=20,
+        num_eval_points=1000,
+    )
diff --git a/examples/visualization_gamma.png b/examples/visualization_gamma.png
new file mode 100644
index 0000000..110cc48
Binary files /dev/null and b/examples/visualization_gamma.png differ
diff --git a/examples/visualization_normal.png b/examples/visualization_normal.png
new file mode 100644
index 0000000..00f34e3
Binary files /dev/null and b/examples/visualization_normal.png differ
diff --git a/pyproject.toml b/pyproject.toml
index 9acd520..e7c95aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,9 +75,9 @@ distance-dirty = "{base_version}"

 [tool.mypy]
 ignore_missing_imports = true  # when no stubs are available, e.g. for matplotlib or tabulate
-pretty = true
-show_error_context = true
-show_traceback = true
+pretty = true
+show_error_context = true
+show_traceback = true

 [tool.pytest.ini_options]
 addopts = [
@@ -146,6 +146,9 @@ ignore = [
 preview = true

 [tool.ruff.lint.per-file-ignores]
+"examples/*" = [
+    "S101",  # use of assert detected
+]
 "tests/*" = [
     "ANN201",  # return type information
     "D",  # pydocstyle