Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@

import numpy as np
import torch
from importance_sampling import importance_sampling_distribution_uniform_region # type: ignore[import-not-found]
from matplotlib import pyplot as plt
from usecase.env_data import env_pdf # type: ignore[import-not-found]

from axtreme.sampling.importance_sampling import ( # type: ignore[import-not-found]
importance_sampling_distribution_uniform_region,
)

torch.set_default_dtype(torch.float64)

# Initialize _ as Any to avoid mypy type checking issues
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""This module provides methods for importance sampling.
# %%
"""This module provides methods to create importance samples and weights if the env data distribution is known.

Importance sampling focuses computational effort on regions of interest. This is especially useful when only a small
part of the environment data contributes meaningfully to the quantity of interest (QoI).

In this file the following is included:
- Create importance sample and weights for a given importance distribution.
- Create importance sample and weights for a uniform region.

TODO(sw 25-05-26): This should be moved to src/axtreme/sampling once sufficiently tested.
"""

from collections.abc import Callable
Expand All @@ -17,7 +16,7 @@
torch.set_default_dtype(torch.float64)


# TODO(sw25-05-26): make he docstring here a bit clearer. Bit more explanation about the input functions.
# TODO(sw25-05-26): make the docstring here a bit clearer. Bit more explanation about the input functions.
# Maybe change the order
def importance_sampling_from_distribution(
env_distribution_pdf: Callable[[torch.Tensor], torch.Tensor],
Expand Down Expand Up @@ -59,8 +58,6 @@ def importance_sampling_from_distribution(


# TODO(sw25-05-26): make the docstring here a bit clearer
# num_samples_total: this should probably define the number of samples to return, and then continue generating them
# until we get enough. (This will also need to track some summary statistics.)
def importance_sampling_distribution_uniform_region(
env_distribution_pdf: Callable[[torch.Tensor], torch.Tensor],
region: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
Expand All @@ -78,8 +75,8 @@ def importance_sampling_distribution_uniform_region(

Args:
env_distribution_pdf: The pdf function of the real environment distribution.
It should be callable with a tensor of shape (num_samples, d) and return a tensor of shape (num_samples,)
Where d is the size of the input space.
It should be callable with a tensor of shape (num_samples_total, d) and return a tensor of shape
(num_samples_total,), where d is the size of the input space.
region: The bounds of the region to generate samples from. Can be a tuple of two tensors or a single tensor.

if a single tensor:
Expand All @@ -93,14 +90,13 @@ def importance_sampling_distribution_uniform_region(

threshold: Environment regions with pdf values less than this threshold will not be explored by the importance
samples. See `Details` for more information on how this threshold is used.
# TODO(ak-06-10): see comment above by Sebastian: num_samples_total should equal number of returned samples
num_samples_total: Total number of samples to draw uniformly before filtering. The actual number of
returned samples may be smaller depending on how many pass the threshold filter.

num_samples_total: Total number of samples to return.

Returns:
A tuple (Tensor, Tensor) containing:
The filtered samples drawn from the uniform distribution. Shape (n_samples,d)
Importance sampling weights for each sample. Shape (n_samples,)
The filtered samples drawn from the uniform distribution. Shape (num_samples_total,d)
Importance sampling weights for each sample. Shape (num_samples_total,)

Details:
The mathematical justification for this algorithm is given in
Expand All @@ -118,10 +114,13 @@ def importance_sampling_distribution_uniform_region(
2. Generate `num_samples_total` uniform samples from the region. The region must cover all of F.

2.1 Discard any points not in F. `num_samples` is the number of points that are left after discarding.
2.2. If `num_samples` is less than `num_samples_total`, repeat step 2 until enough samples are generated.

3. The PDF of the sampled points `h_x(x)` is a uniform distribution over the region F.

3.1 `h_x(x)` is estimated with `1/volume(region) * num_samples_total/num_samples`.
3.1 `h_x(x)` is estimated as `total_proposals / (volume(region) * total_accepted)`, where
`total_proposals` is the total number of uniform samples generated and `total_accepted` is
the total number of samples that passed the threshold.

4. The importance sampling weights are then calculated as w(x) = p(x)/h_x(x)

Expand Down Expand Up @@ -156,35 +155,53 @@ def importance_sampling_distribution_uniform_region(
- The point would add 0 to the non-importance weighted sum.
- It will produce an approximate result if r(x_i) != 0.
- This is a reasonable approximation if p(x_i) is considered to be close enough to 0.

Todo: TODO
- (ak 2025-07-09): add seeded option for testing purposes

"""
uniform_dist = torch.distributions.Uniform(region[0], region[1])

# Generate samples from the uniform distribution over the region
samples = uniform_dist.sample(torch.Size([num_samples_total]))
accepted_samples = []
accepted_pdfs = []

# Calculate the probability density of the samples
pdf = env_distribution_pdf(samples)
total_accepted = 0 # All accepted samples (for h_x estimation)
total_proposals = 0 # All uniform draws
samples_collected = 0 # Samples stored for output

# Find the samples that are above the threshold
mask = pdf > threshold
samples = samples[mask]
while samples_collected < num_samples_total:
# Draw a batch of proposals
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This loop can run forever: if env_distribution_pdf never returns a value > threshold, no proposal is ever accepted and samples_collected never reaches num_samples_total. Suggested solution: add a limit on the maximum number of proposals/trials and raise a clear error if that limit is reached without enough accepted samples.

proposals = uniform_dist.sample(torch.Size([num_samples_total]))
total_proposals += num_samples_total

# Calculate the volume of the hyper rectangle that contains the samples
volume = torch.prod(region[1] - region[0])
# Evaluate PDF and apply threshold
pdf_values = env_distribution_pdf(proposals)
acceptance_mask = pdf_values > threshold

if not acceptance_mask.any():
continue

# The number of samples that are above the threshold
num_samples = samples.shape[0]
batch_accepted = proposals[acceptance_mask]
batch_pdfs = pdf_values[acceptance_mask]

# Calculate the importance sampling distribution
# The importance sampling distribution is estimated to be
# h_x(x) = num_samples_total/(volume(region) * num_samples)
h_x = num_samples_total / (volume * num_samples)
# Track ALL accepted samples for h(x)
num_accepted_this_batch = batch_accepted.shape[0]
total_accepted += num_accepted_this_batch

# Store only what we still need
remaining_needed = num_samples_total - samples_collected
num_to_store = min(remaining_needed, num_accepted_this_batch)

accepted_samples.append(batch_accepted[:num_to_store])
accepted_pdfs.append(batch_pdfs[:num_to_store])

samples_collected += num_to_store

# Concatenate all collected samples
samples = torch.cat(accepted_samples, dim=0)
pdf_values = torch.cat(accepted_pdfs, dim=0)

# Calculate proposal distribution density using total accepted
volume = torch.prod(region[1] - region[0])
h_x = total_proposals / (volume * total_accepted)

# Calculate the importance sampling weights
weights = pdf[mask] / h_x
# Compute importance weights
weights = pdf_values / h_x

return samples, weights
79 changes: 79 additions & 0 deletions tests/sampling/test_importance_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
Tests for the creation of importance samples and importance weights.

The theory is derived and explained in:

Winter, Sebastian, Christian Agrell, Juan Camilo Guevara Gómez, and Erik Vanem. “Efficient Long-Term Structural
Reliability Estimation with Non-Gaussian Stochastic Models: A Design of Experiments Approach.” arXiv, March 3, 2025.
https://doi.org/10.48550/arXiv.2503.01566.

and is therefore not part of the tests. The focus here is on unit tests of the two functions used to create the
importance samples and weights for a given environment distribution.
"""

from unittest.mock import patch

import torch

from axtreme.sampling.importance_sampling import (
importance_sampling_distribution_uniform_region,
importance_sampling_from_distribution,
)


def test_importance_sampling_from_distribution() -> None:
    """Check the samples and weights returned by `importance_sampling_from_distribution` for simple callables.

    A deterministic sampler is used, so the returned samples must be exactly the first `num_samples` fixed values.
    With env pdf `x + 1` and importance pdf `x + 2`, the expected weight for each sample is `(x + 1) / (x + 2)`.
    """
    # Kept under its own name: the sampler closure reads this tensor, so it must not be rebound by the
    # function's return values below.
    sampler_pool = torch.tensor([1.0, 2.0, 3.0, 4.0])

    def _mock_sampler(size: torch.Size) -> torch.Tensor:
        # Deterministic stand-in for a random sampler: return the first `size[0]` fixed values.
        return sampler_pool[: size[0]]

    def _env_distribution_pdf(x: torch.Tensor) -> torch.Tensor:
        return x + 1

    def _importance_distribution_pdf(x: torch.Tensor) -> torch.Tensor:
        return x + 2

    samples, weights = importance_sampling_from_distribution(
        env_distribution_pdf=_env_distribution_pdf,
        importance_distribution_pdf=_importance_distribution_pdf,
        importance_sampling_sampler=_mock_sampler,
        num_samples=3,
    )

    expected_samples = torch.tensor([1.0, 2.0, 3.0])
    expected_weights = torch.tensor([2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0])

    # Samples are passed through unchanged, so exact equality is appropriate.
    assert torch.equal(samples, expected_samples)
    # Weights are the result of floating point division; compare with a tolerance, as the uniform-region test does.
    torch.testing.assert_close(weights, expected_weights)


def test_importance_sampling_distribution_uniform_region() -> None:
    """Check the samples and weights returned by `importance_sampling_distribution_uniform_region`.

    The env pdf is the identity, the region is [0, 4] (volume 4) and the threshold is 2. Every mocked batch of 6
    proposals contains exactly two values strictly above the threshold (2.5 and 3.0), so the expected output is
    those two values repeated until `num_samples_total` samples are collected. The acceptance ratio is 2/6, hence
    the expected proposal density is `6 / (4 * 2) = 3/4` and the expected weights are `(4 / 3) * samples`.
    """

    def _env_distribution_pdf(x: torch.Tensor) -> torch.Tensor:
        return x

    region = (torch.tensor([0.0]), torch.tensor([4.0]))

    threshold = 2
    num_samples_total = 6
    fixed_samples = torch.tensor([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

    # Temporarily replace the .sample() method of torch.distributions.Uniform with deterministic version that always
    # returns the predefined fixed_samples during the with block. This is necessary as no seeding is implemented for
    # the function importance_sampling_distribution_uniform_region.
    with patch.object(torch.distributions.Uniform, "sample", return_value=fixed_samples):
        samples, weights = importance_sampling_distribution_uniform_region(
            env_distribution_pdf=_env_distribution_pdf,
            region=region,
            threshold=threshold,
            num_samples_total=num_samples_total,
        )

    # Use the `torch.tensor` factory (not the legacy `torch.Tensor` constructor) so the expected values are built
    # the same way as `fixed_samples` and share its dtype.
    expected_samples = torch.tensor([2.5, 3.0, 2.5, 3.0, 2.5, 3.0])
    expected_weights = (4 / 3) * expected_samples

    assert torch.equal(samples, expected_samples)
    torch.testing.assert_close(weights, expected_weights)
3 changes: 1 addition & 2 deletions tutorials/importance_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
from axtreme.plotting.gp_fit import plot_surface_over_2d_search_space
from axtreme.plotting.histogram3d import histogram_surface3d
from axtreme.qoi import MarginalCDFExtrapolation
from axtreme.sampling.importance_sampling import importance_sampling_distribution_uniform_region
from axtreme.sampling.ut_sampler import UTSampler
from axtreme.utils import population_estimators, transforms

Expand All @@ -85,8 +86,6 @@
root_dir = Path("../")
sys.path.append(str(root_dir))

# TODO (ak:25-08-06): change path when file is moved to src
from examples.crest_heights_north_sea.importance_sampling import importance_sampling_distribution_uniform_region
from examples.tutorials.importance_sampling.problem.brute_force import collect_or_calculate_results
from examples.tutorials.importance_sampling.problem.env_data import (
calculate_environment_distribution,
Expand Down
Loading