2 changes: 2 additions & 0 deletions .gitignore
@@ -41,3 +41,5 @@ uv.lock

# results from examples
examples/*.png
!examples/visualization_gamma.png
!examples/visualization_normal.png
32 changes: 17 additions & 15 deletions examples/time_series_learning.py
@@ -8,19 +8,18 @@
"""

import pathlib
from typing import Literal

import matplotlib.pyplot as plt
import numpy
import seaborn
import torch
from matplotlib.figure import Figure

from torch_crps.analytical import crps_analytical
from torch_crps import crps_analytical, scrps_analytical

EXAMPLES_DIR = pathlib.Path(__file__).parent

torch.set_default_dtype(torch.float32)


class SimpleDistributionalModel(torch.nn.Module):
"""A model that makes independent predictions given a sequence of inputs, yielding a StudentT distribution."""
@@ -100,7 +99,7 @@ def simple_training(
    packed_targets: torch.Tensor,
    dataset_name: str,
    normalize_data: bool,
    use_crps: bool,
    score_fcn: Literal["nll", "crps", "scrps"],
    device: torch.device,
) -> None:
"""A bare bones training loop for the time series model that works on windowed data.
@@ -113,7 +112,7 @@ def simple_training(
        packed_targets: Target tensor of shape (num_samples, dim_output).
        dataset_name: Name of the dataset.
        normalize_data: Whether the data is normalized.
        use_crps: If True, use CRPS loss. If False, use negative log-likelihood loss.
        score_fcn: Which scoring function to use: "nll", "crps", or "scrps".
        device: Device to run training on.
    """
    # Move data to device.
@@ -123,7 +122,7 @@
    # Use a simple heuristic for the optimization hyper-parameters.
    if dataset_name == "monthly_sunspots":
        if normalize_data:
            num_epochs = 3001
            num_epochs = 4001
            lr = 3e-3
        else:
            # The data is in [0, 100] so we need more steps.
@@ -144,8 +143,10 @@
        packed_predictions = model(packed_inputs)

        # Compute the loss; lower is better for all three scores.
        if use_crps:
        if score_fcn == "crps":
            loss = crps_analytical(packed_predictions, packed_targets).mean()
        elif score_fcn == "scrps":
            loss = scrps_analytical(packed_predictions, packed_targets).mean()
        else:
            loss = -packed_predictions.log_prob(packed_targets).mean()

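Side note on the three objectives: for a Normal prediction the NLL grows quadratically in the prediction error while the CRPS grows only linearly, which is one reason CRPS-style losses tend to be more robust to outlying targets. A minimal sketch comparing the three scores on a toy prediction, using the same torch_crps calls as above (the toy numbers are illustrative only):

import torch

from torch_crps import crps_analytical, scrps_analytical

# A toy prediction: a narrow Normal evaluated against increasingly distant targets.
prediction = torch.distributions.Normal(loc=torch.zeros(4), scale=torch.full((4,), 0.1))
targets = torch.tensor([0.0, 0.5, 1.0, 2.0])

nll = -prediction.log_prob(targets)
crps = crps_analytical(prediction, targets)
scrps = scrps_analytical(prediction, targets)
for i in range(len(targets)):
    print(f"y={targets[i]:.1f}  NLL={nll[i]:8.2f}  CRPS={crps[i]:6.3f}  SCRPS={scrps[i]:7.3f}")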
@@ -239,7 +240,7 @@ def plot_results(
    if dataset_name not in ("monthly_sunspots", "mackey_glass"):
        raise NotImplementedError(f"Unknown dataset {dataset_name}! Please specify the necessary parts in the script.")

    fig, axs = plt.subplots(2, 1, figsize=(16, 9))
    fig, axs = plt.subplots(2, 1, figsize=(16, 10))

    # Plot training data and predictions.
    axs[0].plot(data_trn, label="data train")
@@ -277,12 +278,14 @@

if __name__ == "__main__":
    seaborn.set_theme()

    torch.set_default_dtype(torch.float32)
    torch.manual_seed(0)

    # Configure.
    normalize_data = True  # scales the data to be in [-1, 1] (recommended for the monthly_sunspots dataset)
    dataset_name = "monthly_sunspots"  # monthly_sunspots or mackey_glass
    use_crps = True  # if True, use CRPS loss instead of NLL
    score_fcn = "scrps"  # "nll", "crps", or "scrps"
    len_window = 10  # tested 10 and 20
    dim_hidden = 64

@@ -292,7 +295,7 @@ def plot_results(

    # Prepare the data.
    data_trn, data_tst = load_and_split_data(dataset_name, normalize_data)
    dim_data = data_trn.size(1)
    num_training_samples, dim_data = data_trn.size(0), data_trn.size(1)

    # Create the model and move to device.
    model = SimpleDistributionalModel(dim_input=dim_data, dim_output=dim_data, hidden_size=dim_hidden)
@@ -305,7 +308,7 @@ def plot_results(
    # i i ... t
    inputs = []
    targets = []
    for idx in range(len_window, data_trn.size(0)):
    for idx in range(len_window, num_training_samples):
        # Slice the input.
        idx_begin = max(idx - len_window, 0)
        inp = data_trn[idx_begin:idx, :].view(-1, dim_data)
@@ -329,7 +332,7 @@ def plot_results(
        packed_targets,
        dataset_name=dataset_name,
        normalize_data=normalize_data,
        use_crps=use_crps,
        score_fcn=score_fcn,
        device=device,
    )

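For readers skimming the windowing logic above: each length-len_window slice of the series is paired with the observation that immediately follows it. A self-contained sketch of the same construction on toy data (the make_windows helper is hypothetical, written here only to illustrate the shapes):

import torch

def make_windows(series: torch.Tensor, len_window: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Pair each sliding window with the observation that follows it."""
    inputs, targets = [], []
    for idx in range(len_window, series.size(0)):
        inputs.append(series[idx - len_window : idx])  # shape (len_window, dim_data)
        targets.append(series[idx])  # shape (dim_data,)
    return torch.stack(inputs), torch.stack(targets)

# A toy series of 6 scalar observations yields 4 windows of length 2.
series = torch.arange(6, dtype=torch.float32).unsqueeze(-1)
packed_inputs, packed_targets = make_windows(series, len_window=2)
print(packed_inputs.shape, packed_targets.shape)  # torch.Size([4, 2, 1]) torch.Size([4, 1])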
@@ -347,6 +350,5 @@ def plot_results(
        predictions_tst_mean,
        predictions_tst_std,
    )
    loss_name = "crps" if use_crps else "nll"
    plt.savefig(EXAMPLES_DIR / f"time_series_learning_{dataset_name}_{loss_name}.png", dpi=300)
    print(f"Figure saved to {EXAMPLES_DIR / f'time_series_learning_{dataset_name}.png'}")
    fig.savefig(EXAMPLES_DIR / f"time_series_learning_{dataset_name}_{score_fcn}.png", dpi=300)
    print(f"Figure saved to {EXAMPLES_DIR / f'time_series_learning_{dataset_name}_{score_fcn}.png'}")
149 changes: 149 additions & 0 deletions examples/visualization.py
@@ -0,0 +1,149 @@
import pathlib

import matplotlib.pyplot as plt
import seaborn
import torch

from torch_crps import crps_analytical, crps_ensemble, scrps_analytical, scrps_ensemble

EXAMPLES_DIR = pathlib.Path(__file__).parent


def gamma_example(
    concentration: float,
    rate: float,
    eval_min: float,
    eval_max: float,
    num_eval_points: int,
    ensemble_size: int,
) -> None:
    """Example showing the probability density, negative log-likelihood, CRPS, and SCRPS of a Gamma distribution.

    Args:
        concentration: The concentration parameter of the Gamma distribution.
        rate: The rate parameter of the Gamma distribution.
        eval_min: The minimum grid value of y to evaluate on.
        eval_max: The maximum grid value of y to evaluate on.
        num_eval_points: The number of grid points to evaluate the functions on.
        ensemble_size: The number of ensemble estimates to use for the CRPS and SCRPS evaluation.
    """
    assert concentration > 0 and rate > 0
    assert eval_min < eval_max
    assert num_eval_points > 0
    assert ensemble_size > 0

    # Create a distribution (imagine a model's output).
    p = torch.distributions.Gamma(concentration=concentration, rate=rate)

    # Define a grid for all evaluations.
    y = torch.linspace(eval_min, eval_max, num_eval_points)

    # Evaluate the probability, negative log-probability, CRPS, and SCRPS on the grid.
    p_y = p.log_prob(y).exp()
    nll_y = -p.log_prob(y)
    q_samples = p.sample((num_eval_points, ensemble_size))
    crps_y = crps_ensemble(q_samples, y)
    scrps_y = scrps_ensemble(q_samples, y)
    print(f"Evaluated p(y), NLL(p(y), y), CRPS(p(y), y), and SCRPS(p(y), y) on a grid of {num_eval_points} points")

    # Plot the evaluations.
    fig = plt.figure(figsize=(12, 8))
    y_plot = y.cpu().numpy()
    seaborn.lineplot(x=y_plot, y=p_y.cpu().numpy(), label=f"p(x) = Gamma(concentration={concentration}, rate={rate})")
    seaborn.lineplot(x=y_plot, y=nll_y.cpu().numpy(), label="NLL(p(x), y)")
    seaborn.lineplot(x=y_plot, y=crps_y.cpu().numpy(), label=f"CRPS_{ensemble_size}(p(x), y)")
    seaborn.lineplot(x=y_plot, y=scrps_y.cpu().numpy(), label=f"SCRPS_{ensemble_size}(p(x), y)")

    # Plot the mean and the mode as dashed vertical lines.
    plt.axvline(p.mean.item(), color="C8", linestyle="dashed", label="mean")
    plt.axvline(p.mode.item(), color="C9", linestyle="dashed", label="mode")

    # Add annotation.
    plt.xlabel("observation y")
    plt.ylabel("value")
    plt.legend(loc="upper right")

    # Save the plot.
    fig.tight_layout()
    fig.savefig(EXAMPLES_DIR / "visualization_gamma.png", dpi=300)
    print("Saved visualization to", EXAMPLES_DIR / "visualization_gamma.png")


def scale_example(
    loc: float,
    scale: float,
    num_eval_points: int,
) -> None:
    """Example showing the effect of the random variable's scale on the CRPS and SCRPS of a distribution.

    Args:
        loc: The location parameter of the Normal distribution.
        scale: The scale parameter of the Normal distribution.
        num_eval_points: The number of grid points to evaluate the functions on.
    """
    assert scale > 0  # the location may be any real number
    assert num_eval_points > 0

    # Create a distribution (imagine a model's output).
    p = torch.distributions.Normal(loc=loc, scale=scale)

    # Define a grid for all evaluations.
    eval_min, eval_max = loc - 4 * scale, loc + 4 * scale
    y = torch.linspace(eval_min, eval_max, num_eval_points)

    # Evaluate the probability, negative log-probability, CRPS, and SCRPS on the grid.
    p_y = p.log_prob(y).exp()
    nll_y = -p.log_prob(y)
    crps_y = crps_analytical(p, y)
    scrps_y = scrps_analytical(p, y)
    print(f"Evaluated p(y), NLL(p(y), y), CRPS(p(y), y), and SCRPS(p(y), y) on a grid of {num_eval_points} points")

    # Plot the evaluations. Make the upper subplot 1/4 the height of the lower one.
    fig, ax = plt.subplots(
        nrows=2,
        ncols=1,
        figsize=(12, 8),
        gridspec_kw={"height_ratios": [1, 4]},
        sharex=True,
    )
    y_plot = y.cpu().numpy()

    # Upper (smaller) subplot: probability density.
    ax[0].plot(y_plot, p_y.cpu().numpy(), color="C0", label=f"p(x) = Normal(loc={loc}, scale={scale})")
    ax[0].set_ylabel("p(x)")
    ax[0].legend(loc="upper right")

    # Lower (larger) subplot: NLL, CRPS, SCRPS.
    ax[1].plot(y_plot, nll_y.cpu().numpy(), color="C1", label="NLL(p(x), y)")
    ax[1].plot(y_plot, crps_y.cpu().numpy(), color="C2", label="CRPS(p(x), y)")
    ax[1].plot(y_plot, scrps_y.cpu().numpy(), color="C3", label="SCRPS(p(x), y)")
    ax[1].set_xlabel("observation y")
    ax[1].set_ylabel("value")
    ax[1].legend(loc="upper center")

    # Save the plot.
    fig.tight_layout()
    fig.savefig(EXAMPLES_DIR / "visualization_normal.png", dpi=300)
    print("Saved visualization to", EXAMPLES_DIR / "visualization_normal.png")


if __name__ == "__main__":
    seaborn.set_theme()

    torch.set_default_dtype(torch.float32)
    torch.manual_seed(0)

    gamma_example(
        concentration=3,
        rate=4,
        eval_min=0.01,
        eval_max=2.5,
        num_eval_points=5000,
        ensemble_size=2000,
    )

    scale_example(
        loc=1000,
        scale=20,
        num_eval_points=1000,
    )
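A note on what scale_example is designed to show: the CRPS carries the units of the data, so rescaling the whole problem rescales the score, whereas the SCRPS is locally scale invariant, meaning differences between observations are unchanged under rescaling. A quick numerical check of that property, assuming torch_crps follows the Bolin and Wallin (2023) definition of the SCRPS:

import torch

from torch_crps import crps_analytical, scrps_analytical

# Two observations: one near the mean, one in the tail.
y = torch.tensor([1.0, 2.5])

for c in (1.0, 100.0):
    # Rescale the whole problem (distribution and observations) by the same factor.
    p = torch.distributions.Normal(loc=1.0 * c, scale=0.5 * c)
    crps = crps_analytical(p, y * c)
    scrps = scrps_analytical(p, y * c)
    # The CRPS gap between the two observations grows with c; the SCRPS gap stays put.
    print(f"c={c:5.1f}  CRPS gap={(crps[1] - crps[0]).item():9.4f}  SCRPS gap={(scrps[1] - scrps[0]).item():.4f}")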
Binary file added examples/visualization_gamma.png
Binary file added examples/visualization_normal.png
9 changes: 6 additions & 3 deletions pyproject.toml
@@ -75,9 +75,9 @@ distance-dirty = "{base_version}"

[tool.mypy]
ignore_missing_imports = true # when no stubs are available, e.g. for matplotlib or tabulate
pretty = true
show_error_context = true
show_traceback = true
pretty = true
show_error_context = true
show_traceback = true

[tool.pytest.ini_options]
addopts = [
@@ -146,6 +146,9 @@ ignore = [
preview = true

[tool.ruff.lint.per-file-ignores]
"examples/*" = [
"S101", # use of assert detected
]
"tests/*" = [
"ANN201", # return type information
"D", # pydocstyle