From b29007c798002c8e5926fc7d2072d6aca4965b98 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Sat, 11 Dec 2021 02:40:53 +0000 Subject: [PATCH 1/3] Implementation of quantiles for messenger guides --- pyro/infer/autoguide/effect.py | 52 ++++++++++++++++++++++++++++++++++ tests/infer/test_autoguide.py | 2 ++ 2 files changed, 54 insertions(+) diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index d7abe537b0..68df52e03f 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -157,12 +157,16 @@ def __init__( self.init_loc_fn = init_loc_fn self._init_scale = init_scale self._computing_median = False + self._computing_quantiles = False + self._quantile_values = None def get_posterior( self, name: str, prior: Distribution ) -> Union[Distribution, torch.Tensor]: if self._computing_median: return self._get_posterior_median(name, prior) + if self._computing_quantiles: + return self._get_posterior_quantiles(name, prior) with helpful_support_errors({"name": name, "fn": prior}): transform = biject_to(prior.support) @@ -205,11 +209,30 @@ def median(self, *args, **kwargs): finally: self._computing_median = False + @torch.no_grad() def _get_posterior_median(self, name, prior): transform = biject_to(prior.support) loc, scale = self._get_params(name, prior) return transform(loc) + def quantiles(self, quantiles, *args, **kwargs): + self._computing_quantiles = True + self._quantile_values = quantiles + try: + return self(*args, **kwargs) + finally: + self._computing_quantiles = False + + @torch.no_grad() + def _get_posterior_quantiles(self, name, prior): + transform = biject_to(prior.support) + loc, scale = self._get_params(name, prior) + site_quantiles = torch.tensor( + self._quantile_values, dtype=loc.dtype, device=loc.device + ) + site_quantiles_values = dist.Normal(loc, scale).icdf(site_quantiles) + return transform(site_quantiles_values) + class AutoHierarchicalNormalMessenger(AutoNormalMessenger): """ @@ -263,12 +286,16 @@ def __init__( self._init_weight = init_weight self._hierarchical_sites = hierarchical_sites self._computing_median = False + self._computing_quantiles = False + self._quantile_values = None def get_posterior( self, name: str, prior: Distribution ) -> Union[Distribution, torch.Tensor]: if self._computing_median: return self._get_posterior_median(name, prior) + if self._computing_quantiles: + return self._get_posterior_quantiles(name, prior) with helpful_support_errors({"name": name, "fn": prior}): transform = biject_to(prior.support) @@ -351,6 +378,7 @@ def median(self, *args, **kwargs): finally: self._computing_median = False + @torch.no_grad() def _get_posterior_median(self, name, prior): transform = biject_to(prior.support) if (self._hierarchical_sites is None) or (name in self._hierarchical_sites): @@ -360,6 +388,30 @@ def _get_posterior_median(self, name, prior): loc, scale = self._get_params(name, prior) return transform(loc) + def quantiles(self, quantiles, *args, **kwargs): + self._computing_quantiles = True + self._quantile_values = quantiles + try: + return self(*args, **kwargs) + finally: + self._computing_quantiles = False + + @torch.no_grad() + def _get_posterior_quantiles(self, name, prior): + transform = biject_to(prior.support) + if (self._hierarchical_sites is None) or (name in self._hierarchical_sites): + loc, scale, weight = self._get_params(name, prior) + loc = loc + transform.inv(prior.mean) * weight + else: + loc, scale = self._get_params(name, prior) + + site_quantiles = torch.tensor( + self._quantile_values, dtype=loc.dtype, device=loc.device + ) + site_quantiles_values = dist.Normal(loc, scale).icdf(site_quantiles) + raise ValueError(site_quantiles_values.shape) + return transform(site_quantiles_values) + class AutoRegressiveMessenger(AutoMessenger): """ diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 76a190fd91..561107a770 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -522,6 +522,8 @@ def AutoGuideList_x(model): AutoLowRankMultivariateNormal, AutoLaplaceApproximation, AutoGuideList_x, + AutoNormalMessenger, + AutoHierarchicalNormalMessenger, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) From 0d2ac51c1467fde50c2c6b7491727103cea4e741 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Wed, 26 Jan 2022 00:24:55 +0000 Subject: [PATCH 2/3] Deleted ValueError Accidentally committed raise error used for debugging --- pyro/infer/autoguide/effect.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index 68df52e03f..6de36a87b3 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -409,7 +409,6 @@ def _get_posterior_quantiles(self, name, prior): self._quantile_values, dtype=loc.dtype, device=loc.device ) site_quantiles_values = dist.Normal(loc, scale).icdf(site_quantiles) - raise ValueError(site_quantiles_values.shape) return transform(site_quantiles_values) From 9268bff29c129223f060f90c4d693b14ca468ee4 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Sun, 26 Nov 2023 22:21:03 +0000 Subject: [PATCH 3/3] add xfail_messenger to quantile test --- tests/infer/test_autoguide.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index 09332ec8e3..d3d7a7d7e9 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -528,6 +528,7 @@ def AutoGuideList_x(model): ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_quantiles(auto_class, Elbo): + xfail_messenger(auto_class, Elbo) def model(): pyro.sample("y", dist.LogNormal(0.0, 1.0)) pyro.sample("z", dist.Beta(2.0, 2.0).expand([2]).to_event(1))