From a9df8f64870d3843f73936f3f511d5683761a0f2 Mon Sep 17 00:00:00 2001 From: El Mekkaoui Date: Tue, 9 Dec 2025 10:42:50 +0100 Subject: [PATCH 1/2] Update importance_sampling_distribution_uniform_region to use iterative sampling --- src/axtreme/sampling/importance_sampling.py | 92 +++++++++++---------- 1 file changed, 48 insertions(+), 44 deletions(-) diff --git a/src/axtreme/sampling/importance_sampling.py b/src/axtreme/sampling/importance_sampling.py index 642edd8d..b65cba48 100644 --- a/src/axtreme/sampling/importance_sampling.py +++ b/src/axtreme/sampling/importance_sampling.py @@ -12,7 +12,6 @@ from collections.abc import Callable import torch -from torch.distributions.distribution import Distribution torch.set_default_dtype(torch.float64) @@ -119,7 +118,9 @@ def importance_sampling_distribution_uniform_region( 3. The PDF of the sampled points `h_x(x)` is a uniform distribution over the region F. - 3.1 `h_x(x)` is estimated with `1/volume(region)`. + 3.1 `h_x(x)` is estimated as `total_proposals / (volume(region) * total_accepted)`, where + `total_proposals` is the total number of uniform samples generated and `total_accepted` is + the total number of samples that passed the threshold. 4. The importance sampling weights are then calculated as w(x) = p(x)/h_x(x) @@ -157,47 +158,50 @@ def importance_sampling_distribution_uniform_region( """ uniform_dist = torch.distributions.Uniform(region[0], region[1]) - def _create_samples_and_weights( - dist: Distribution, num_samples_to_create: int - ) -> tuple[torch.Tensor, torch.Tensor]: - """Create samples and weights from a uniform distribution over a defined region.""" - # Generate samples from the uniform distribution over the region - samples = dist.sample(torch.Size([num_samples_to_create])) - - # Calculate the probability density of the samples - pdf = env_distribution_pdf(samples) - - # Find the samples that are above the threshold - mask = pdf > threshold - samples = samples[mask] - - # Calculate the volume of the hyper rectangle that contains the samples - volume = torch.prod(region[1] - region[0]) - - # The number of samples that are above the threshold - num_samples = samples.shape[0] - - # Calculate the importance sampling distribution - # The importance sampling distribution is estimated to be - # h_x(x) = num_samples_total/(volume(region) * num_samples) - h_x = num_samples_to_create / (volume * num_samples) - - # Calculate the importance sampling weights - weights = pdf[mask] / h_x - - return samples, weights - - # Keep creating importance samples and weights until there are num_samples_total of them - samples = torch.empty((0,)) - weights = torch.empty((0,)) - while len(samples) < num_samples_total: - # If we were to create num_samples_total-len(samples) samples starting in the second iteration the weights would - # be inconsistent due to the definition of h_x. Hence, we need to create too many samples and then take only - # the needed amount of samples. This is computationally inefficient but as _create_samples_and_weights runs fast - # this is acceptable. - s, w = _create_samples_and_weights(uniform_dist, num_samples_total) - num_missing_samples = min(num_samples_total - len(samples), len(s)) - samples = torch.cat((samples, s[:num_missing_samples])) - weights = torch.cat((weights, w[:num_missing_samples])) + accepted_samples = [] + accepted_pdfs = [] + + total_accepted = 0 # All accepted samples (for h_x estimation) + total_proposals = 0 # All uniform draws + samples_collected = 0 # Samples stored for output + + while samples_collected < num_samples_total: + # Draw a batch of proposals + proposals = uniform_dist.sample(torch.Size([num_samples_total])) + total_proposals += num_samples_total + + # Evaluate PDF and apply threshold + pdf_values = env_distribution_pdf(proposals) + acceptance_mask = pdf_values > threshold + + if not acceptance_mask.any(): + continue + + batch_accepted = proposals[acceptance_mask] + batch_pdfs = pdf_values[acceptance_mask] + + # Track ALL accepted samples for h(x) + num_accepted_this_batch = batch_accepted.shape[0] + total_accepted += num_accepted_this_batch + + # Store only what we still need + remaining_needed = num_samples_total - samples_collected + num_to_store = min(remaining_needed, num_accepted_this_batch) + + accepted_samples.append(batch_accepted[:num_to_store]) + accepted_pdfs.append(batch_pdfs[:num_to_store]) + + samples_collected += num_to_store + + # Concatenate all collected samples + samples = torch.cat(accepted_samples, dim=0) + pdf_values = torch.cat(accepted_pdfs, dim=0) + + # Calculate proposal distribution density using total accepted + volume = torch.prod(region[1] - region[0]) + h_x = total_proposals / (volume * total_accepted) + + # Compute importance weights + weights = pdf_values / h_x return samples, weights From 0fab44f7c0ddc94b550d438c3efe6020f5c4b003 Mon Sep 17 00:00:00 2001 From: am-kaiser <63399571+am-kaiser@users.noreply.github.com> Date: Wed, 10 Dec 2025 14:57:12 +0100 Subject: [PATCH 2/2] fixed path --- tutorials/importance_sampling.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tutorials/importance_sampling.py b/tutorials/importance_sampling.py index d8aa4dca..5d98c5e6 100644 --- a/tutorials/importance_sampling.py +++ b/tutorials/importance_sampling.py @@ -77,6 +77,7 @@ from axtreme.plotting.gp_fit import plot_surface_over_2d_search_space from axtreme.plotting.histogram3d import histogram_surface3d from axtreme.qoi import MarginalCDFExtrapolation +from axtreme.sampling.importance_sampling import importance_sampling_distribution_uniform_region from axtreme.sampling.ut_sampler import UTSampler from axtreme.utils import population_estimators, transforms @@ -85,8 +86,6 @@ root_dir = Path("../") sys.path.append(str(root_dir)) -# TODO (ak:25-08-06): change path when file is moved to src -from examples.crest_heights_north_sea.importance_sampling import importance_sampling_distribution_uniform_region from examples.tutorials.importance_sampling.problem.brute_force import collect_or_calculate_results from examples.tutorials.importance_sampling.problem.env_data import ( calculate_environment_distribution,