Tests for importance sample creation #100
Open
am-kaiser wants to merge 7 commits into main from 97_test_importance_sampling_creation
+136
−38
9d8588c
Merge branch 'main' into 97_test_importance_sampling_creation
am-kaiser b5609c1
Merge branch 'main' into 97_test_importance_sampling_creation
am-kaiser a9df8f6
Update importance_sampling_distribution_uniform_region to use iterati…
saraelme 0f00115
Merge branch '97_test_importance_sampling_creation' of https://github…
saraelme 0fab44f
fixed path
am-kaiser 77106c8
Merge branch 'main' into 97_test_importance_sampling_creation
am-kaiser 48f89eb
Merge branch 'main' into 97_test_importance_sampling_creation
am-kaiser
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,13 +1,12 @@ | ||
| """This module provides methods for importance sampling. | ||
| # %% | ||
| """This module provides methods to create importance samples and weights if the env data distribution is known. | ||
|
|
||
| Importance sampling focuses computational effort on regions of interest. This is especially useful when only a small | ||
| part of the environment data contributes meaningfully to the quantity of interest (QoI). | ||
|
|
||
| In this file the following is included: | ||
| - Create importance sample and weights for a given importance distribution. | ||
| - Create importance sample and weights for a uniform region. | ||
|
|
||
| TODO(sw 25-05-26): This should be moved to src/axtreme/sampling once sufficiently tested. | ||
| """ | ||
|
|
||
| from collections.abc import Callable | ||
|
|
@@ -17,7 +16,7 @@ | |
| torch.set_default_dtype(torch.float64) | ||
|
|
||
|
|
||
| # TODO(sw25-05-26): make he docstring here a bit clearer. Bit more explanation about the input functions. | ||
| # TODO(sw25-05-26): make the docstring here a bit clearer. Bit more explanation about the input functions. | ||
| # Maybe change the order | ||
| def importance_sampling_from_distribution( | ||
| env_distribution_pdf: Callable[[torch.Tensor], torch.Tensor], | ||
|
|
@@ -59,8 +58,6 @@ def importance_sampling_from_distribution( | |
|
|
||
|
|
||
| # TODO(sw25-05-26): make the docstring here a bit clearer | ||
| # num_samples_total; this should probably define the number of samples to return, and then continue generating them | ||
| # we get enough. (This will also need to track some summary statistics) | ||
| def importance_sampling_distribution_uniform_region( | ||
| env_distribution_pdf: Callable[[torch.Tensor], torch.Tensor], | ||
| region: torch.Tensor | tuple[torch.Tensor, torch.Tensor], | ||
|
|
@@ -78,8 +75,8 @@ def importance_sampling_distribution_uniform_region( | |
|
|
||
| Args: | ||
| env_distribution_pdf: The pdf function of the real environment distribution. | ||
| It should be callable with a tensor of shape (num_samples, d) and return a tensor of shape (num_samples,) | ||
| Where d is the size of the input space. | ||
| It should be callable with a tensor of shape (num_samples_total, d) and return a tensor of shape | ||
| (num_samples_total,). Where d is the size of the input space. | ||
| region: The bounds of the region to generate samples from. Can be a tuple of two tensors or a single tensor. | ||
|
|
||
| if a single tensor: | ||
|
|
@@ -93,14 +90,13 @@ def importance_sampling_distribution_uniform_region( | |
|
|
||
| threshold: Environment regions with pdf values less than this threshold will not be explored by the importance | ||
| samples. See `Details` for more information on how this threshold is used. | ||
| # TODO(ak-06-10): see comment above by Sebastian: num_samples_total should equal number of returned samples | ||
| num_samples_total: Total number of samples to draw uniformly before filtering. The actual number of | ||
| returned samples may be smaller depending on how many pass the threshold filter. | ||
|
|
||
| num_samples_total: Total number of samples to return. | ||
|
|
||
| Returns: | ||
| A tuple (Tensor, Tensor) containing: | ||
| The filtered samples drawn from the uniform distribution. Shape (n_samples,d) | ||
| Importance sampling weights for each sample. Shape (n_samples,) | ||
| The filtered samples drawn from the uniform distribution. Shape (num_samples_total,d) | ||
| Importance sampling weights for each sample. Shape (num_samples_total,) | ||
|
|
||
| Details: | ||
| The mathematical justification for this algorithm is given in | ||
|
|
@@ -118,10 +114,13 @@ def importance_sampling_distribution_uniform_region( | |
| 2. Generate `num_samples_total` uniform samples from the region. The region must cover all of F. | ||
|
|
||
| 2.1 Discard any points not in F. `num_samples` is the number of points that are left after discarding. | ||
| 2.2. If `num_samples` is less than `num_samples_total`, repeat step 2 until enough samples are generated. | ||
|
|
||
| 3. The PDF of the sampled points `h_x(x)` is a uniform distribution over the region F. | ||
|
|
||
| 3.1 `h_x(x)` is estimated with `1/volume(region) * num_samples_total/num_samples`. | ||
| 3.1 `h_x(x)` is estimated as `total_proposals / (volume(region) * total_accepted)`, where | ||
| `total_proposals` is the total number of uniform samples generated and `total_accepted` is | ||
| the total number of samples that passed the threshold. | ||
|
|
||
| 4. The importance sampling weights are then calculated as w(x) = p(x)/h_x(x) | ||
|
|
||
|
|
@@ -156,35 +155,53 @@ def importance_sampling_distribution_uniform_region( | |
| - The point would add 0 to the non-importance weighted sum. | ||
| - It will produce an approximate result if r(x_i) != 0. | ||
| - This is a reasonable approximation if p(x_i) is considered to be close enough to 0. | ||
|
|
||
| Todo: TODO | ||
| - (ak 2025-07-09): add seeded option for testing purposes | ||
|
|
||
| """ | ||
| uniform_dist = torch.distributions.Uniform(region[0], region[1]) | ||
|
|
||
| # Generate samples from the uniform distribution over the region | ||
| samples = uniform_dist.sample(torch.Size([num_samples_total])) | ||
| accepted_samples = [] | ||
| accepted_pdfs = [] | ||
|
|
||
| # Calculate the probability density of the samples | ||
| pdf = env_distribution_pdf(samples) | ||
| total_accepted = 0 # All accepted samples (for h_x estimation) | ||
| total_proposals = 0 # All uniform draws | ||
| samples_collected = 0 # Samples stored for output | ||
|
|
||
| # Find the samples that are above the threshold | ||
| mask = pdf > threshold | ||
| samples = samples[mask] | ||
| while samples_collected < num_samples_total: | ||
| # Draw a batch of proposals | ||
|
Collaborator
Author
There is a risk of an infinite loop here: if env_distribution_pdf never returns values > threshold, the loop runs forever. Suggested solution: add a maximum number of proposals/trials and raise a clear error if nothing is accepted.
||
| proposals = uniform_dist.sample(torch.Size([num_samples_total])) | ||
| total_proposals += num_samples_total | ||
|
|
||
| # Calculate the volume of the hyper rectangle that contains the samples | ||
| volume = torch.prod(region[1] - region[0]) | ||
| # Evaluate PDF and apply threshold | ||
| pdf_values = env_distribution_pdf(proposals) | ||
| acceptance_mask = pdf_values > threshold | ||
|
|
||
| if not acceptance_mask.any(): | ||
| continue | ||
|
|
||
| # The number of samples that are above the threshold | ||
| num_samples = samples.shape[0] | ||
| batch_accepted = proposals[acceptance_mask] | ||
| batch_pdfs = pdf_values[acceptance_mask] | ||
|
|
||
| # Calculate the importance sampling distribution | ||
| # The importance sampling distribution is estimated to be | ||
| # h_x(x) = num_samples_total/(volume(region) * num_samples) | ||
| h_x = num_samples_total / (volume * num_samples) | ||
| # Track ALL accepted samples for h(x) | ||
| num_accepted_this_batch = batch_accepted.shape[0] | ||
| total_accepted += num_accepted_this_batch | ||
|
|
||
| # Store only what we still need | ||
| remaining_needed = num_samples_total - samples_collected | ||
| num_to_store = min(remaining_needed, num_accepted_this_batch) | ||
|
|
||
| accepted_samples.append(batch_accepted[:num_to_store]) | ||
| accepted_pdfs.append(batch_pdfs[:num_to_store]) | ||
|
|
||
| samples_collected += num_to_store | ||
|
|
||
| # Concatenate all collected samples | ||
| samples = torch.cat(accepted_samples, dim=0) | ||
| pdf_values = torch.cat(accepted_pdfs, dim=0) | ||
|
|
||
| # Calculate proposal distribution density using total accepted | ||
| volume = torch.prod(region[1] - region[0]) | ||
| h_x = total_proposals / (volume * total_accepted) | ||
|
|
||
| # Calculate the importance sampling weights | ||
| weights = pdf[mask] / h_x | ||
| # Compute importance weights | ||
| weights = pdf_values / h_x | ||
|
|
||
| return samples, weights | ||
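The review comment on the `while` loop suggests capping the number of proposals. The following is a minimal sketch of that idea, not the PR's implementation. It is written in plain Python rather than torch so that it runs without dependencies, and it draws one proposal at a time instead of a batch. The function name `uniform_region_importance_samples`, the `max_proposals` and `rng` parameters, and the error message are all assumptions made for illustration.

```python
import random
from collections.abc import Callable, Sequence
from typing import Optional


def uniform_region_importance_samples(
    env_pdf: Callable[[Sequence[float]], float],
    lower: Sequence[float],
    upper: Sequence[float],
    threshold: float,
    num_samples: int,
    max_proposals: int = 100_000,  # hypothetical safety limit, not part of the PR
    rng: Optional[random.Random] = None,  # hypothetical seeding hook for tests
) -> tuple[list[list[float]], list[float]]:
    """Rejection-sample `num_samples` points with env_pdf > threshold, with a proposal cap."""
    rng = rng or random.Random()

    # Volume of the hyper-rectangle the uniform proposals are drawn from.
    volume = 1.0
    for lo, hi in zip(lower, upper):
        volume *= hi - lo

    kept: list[list[float]] = []
    kept_pdfs: list[float] = []
    total_proposals = 0

    while len(kept) < num_samples:
        if total_proposals >= max_proposals:
            # Fail loudly instead of looping forever when (almost) nothing is accepted.
            raise RuntimeError(
                f"Only {len(kept)} of {num_samples} samples accepted after "
                f"{total_proposals} proposals; check threshold and region."
            )
        point = [rng.uniform(lo, hi) for lo, hi in zip(lower, upper)]
        total_proposals += 1
        pdf_value = env_pdf(point)
        if pdf_value > threshold:
            kept.append(point)
            kept_pdfs.append(pdf_value)

    # Drawing one point at a time means every accepted point is kept,
    # so total_accepted equals len(kept) here.
    total_accepted = len(kept)
    h_x = total_proposals / (volume * total_accepted)
    weights = [p / h_x for p in kept_pdfs]
    return kept, weights
```

Because the loop stops at the first proposal that completes the sample, the accepted count always equals the returned count. The batched version in the diff instead counts every accepted proposal in the final batch, including those it does not store. A cap on batches rather than on individual proposals would fit that version more naturally.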
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| """ | ||
| Test the creation of importance sampling. | ||
|
|
||
| The theory is derived and explained in: | ||
|
|
||
| Winter, Sebastian, Christian Agrell, Juan Camilo Guevara Gómez, and Erik Vanem. “Efficient Long-Term Structural | ||
| Reliability Estimation with Non-Gaussian Stochastic Models: A Design of Experiments Approach.” arXiv, March 3, 2025. | ||
| https://doi.org/10.48550/arXiv.2503.01566. | ||
|
|
||
| and is therefore not part of the tests. The focus here is on unit tests of the two functions used to create the | ||
| importance samples and weights for a given environment distribution. | ||
| """ | ||
|
|
||
| from unittest.mock import patch | ||
|
|
||
| import torch | ||
|
|
||
| from axtreme.sampling.importance_sampling import ( | ||
| importance_sampling_distribution_uniform_region, | ||
| importance_sampling_from_distribution, | ||
| ) | ||
|
|
||
|
|
||
| def test_importance_sampling_from_distribution(): | ||
| """Basic test to see if the function runs and returns the expected output with simple Callables.""" | ||
|
|
||
| samples = torch.tensor([1.0, 2.0, 3.0, 4.0]) | ||
|
|
||
| def _mock_sampler(size: torch.Size) -> torch.Tensor: | ||
| return samples[: size[0]] | ||
|
|
||
| def _env_distribution_pdf(x: torch.Tensor) -> torch.Tensor: | ||
| return x + 1 | ||
|
|
||
| def _importance_distribution_pdf(x: torch.Tensor) -> torch.Tensor: | ||
| return x + 2 | ||
|
|
||
| samples, weights = importance_sampling_from_distribution( | ||
| env_distribution_pdf=_env_distribution_pdf, | ||
| importance_distribution_pdf=_importance_distribution_pdf, | ||
| importance_sampling_sampler=_mock_sampler, | ||
| num_samples=3, | ||
| ) | ||
|
|
||
| expected_samples = torch.tensor([1.0, 2.0, 3.0]) | ||
| expected_weights = torch.tensor([2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0]) | ||
|
|
||
| assert torch.equal(samples, expected_samples) | ||
| assert torch.equal(weights, expected_weights) | ||
|
|
||
|
|
||
| def test_importance_sampling_distribution_uniform_region(): | ||
| """Basic test to see if the function runs and returns the expected output with simple Callables.""" | ||
|
|
||
| def _env_distribution_pdf(x: torch.Tensor) -> torch.Tensor: | ||
| return x | ||
|
|
||
| region = (torch.tensor([0.0]), torch.tensor([4.0])) | ||
|
|
||
| threshold = 2 | ||
| num_samples_total = 6 | ||
| fixed_samples = torch.tensor([0.5, 1.0, 1.5, 2.0, 2.5, 3.0]) | ||
|
|
||
| # Temporarily replace the .sample() method of torch.distributions.Uniform with deterministic version that always | ||
| # returns the predefined fixed_samples during the with block. This is necessary as no seeding is implemented for | ||
| # the function importance_sampling_distribution_uniform_region. | ||
| with patch.object(torch.distributions.Uniform, "sample", return_value=fixed_samples): | ||
| samples, weights = importance_sampling_distribution_uniform_region( | ||
| env_distribution_pdf=_env_distribution_pdf, | ||
| region=region, | ||
| threshold=threshold, | ||
| num_samples_total=num_samples_total, | ||
| ) | ||
|
|
||
| expected_samples = torch.tensor([2.5, 3.0, 2.5, 3.0, 2.5, 3.0]) | ||
| expected_weights = (4 / 3) * expected_samples | ||
|
|
||
| assert torch.equal(samples, expected_samples) | ||
| torch.testing.assert_close(weights, expected_weights) |
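The factor `4 / 3` in `expected_weights` follows from the iterative algorithm once `Uniform.sample` is patched to return the same six values on every call. The sketch below retraces that arithmetic in plain Python, without torch, as an illustration of the bookkeeping. The variable names mirror those in the diff, but the loop is a simplified stand-in for the function under test, not a copy of it.

```python
# Inputs taken from the test: pdf(x) = x, region [0, 4], threshold 2.
fixed_samples = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
threshold = 2
num_samples_total = 6
volume = 4.0 - 0.0

# Every patched batch accepts the same two points (pdf(x) = x > 2).
accepted_per_batch = [x for x in fixed_samples if x > threshold]

collected = []
total_proposals = 0
total_accepted = 0
while len(collected) < num_samples_total:
    total_proposals += len(fixed_samples)
    total_accepted += len(accepted_per_batch)
    remaining = num_samples_total - len(collected)
    collected.extend(accepted_per_batch[:remaining])

# Three batches are needed: 18 proposals, 6 accepted.
h_x = total_proposals / (volume * total_accepted)  # 18 / (4 * 6) = 0.75
weights = [x / h_x for x in collected]  # pdf(x) / h_x = x / 0.75 = (4 / 3) * x
```

Each batch contributes two accepted points, so three batches fill the six requested samples and the weights come out as `(4 / 3) * x`, matching the assertion in the test.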