1,572 changes: 1,572 additions & 0 deletions notebooks/ADVI Guide API.ipynb

Large diffs are not rendered by default.

Empty file.
93 changes: 93 additions & 0 deletions pymc_extras/inference/advi/autoguide.py
@@ -0,0 +1,93 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field

import numpy as np
import pytensor.tensor as pt

from pymc.distributions import Normal
from pymc.logprob.basic import conditional_logp
from pymc.model.core import Deterministic, Model
from pytensor import graph_replace
from pytensor.gradient import disconnected_grad
from pytensor.graph.basic import Variable

from pymc_extras.inference.advi.pytensorf import get_symbolic_rv_shapes


@dataclass(frozen=True)
class AutoGuideModel:
model: Model
params_init_values: dict[Variable, np.ndarray]
name_to_param: dict[str, Variable] = field(init=False)

def __post_init__(self):
object.__setattr__(
self,
"name_to_param",
{x.name: x for x in self.params_init_values.keys()},
)

@property
def params(self) -> tuple[Variable, ...]:
return tuple(self.params_init_values.keys())

def __getitem__(self, name: str) -> Variable:
return self.name_to_param[name]

def stochastic_logq(self, stick_the_landing: bool = True) -> pt.TensorVariable:
"""Returns a graph representing the logp of the guide model, evaluated under draws from its random variables."""
# Scoring the Deterministics at their own sampled values allows the guide's
# variational distribution to be defined through arbitrary deterministic
# transformations of its base RVs
logp_terms = conditional_logp(
{rv: rv for rv in self.model.deterministics},
warn_rvs=False,
)
logq = pt.sum([logp_term.sum() for logp_term in logp_terms.values()])

if stick_the_landing:
# Detach variational parameters from the gradient computation of logq
repl = {p: disconnected_grad(p) for p in self.params}
logq = graph_replace(logq, repl)

return logq


def AutoDiagonalNormal(model) -> AutoGuideModel:
coords = model.coords
free_rvs = model.free_RVs

free_rv_shapes = dict(zip(free_rvs, get_symbolic_rv_shapes(free_rvs)))
params_init_values = {}

with Model(coords=coords) as guide_model:
for rv in free_rvs:
loc = pt.tensor(f"{rv.name}_loc", shape=rv.type.shape)
scale = pt.tensor(f"{rv.name}_scale", shape=rv.type.shape)
# TODO: Make these customizable
params_init_values[loc] = pt.random.uniform(-1, 1, size=free_rv_shapes[rv]).eval()
params_init_values[scale] = pt.full(free_rv_shapes[rv], 0.1).eval()

z = Normal(
f"{rv.name}_z",
mu=0,
sigma=1,
shape=free_rv_shapes[rv],
)
Deterministic(
rv.name,
loc + pt.softplus(scale) * z,
dims=model.named_vars_to_dims.get(rv.name, None),
)

return AutoGuideModel(guide_model, params_init_values)
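
As a usage illustration (not part of this diff): AutoDiagonalNormal builds a mean-field guide whose draws are loc + softplus(scale) * z with z ~ Normal(0, 1), one (loc, scale) pair per free RV. A minimal sketch; the toy model below is invented for illustration:

import pymc as pm

from pymc_extras.inference.advi.autoguide import AutoDiagonalNormal

with pm.Model() as toy_model:
    mu = pm.Normal("mu", 0, 1)
    sigma = pm.HalfNormal("sigma", 1)
    pm.Normal("y", mu, sigma, observed=[0.1, -0.3, 0.2])

guide = AutoDiagonalNormal(toy_model)

# One (loc, scale) parameter pair per free RV, named f"{rv.name}_loc" / f"{rv.name}_scale"
print([p.name for p in guide.params])  # e.g. ['mu_loc', 'mu_scale', 'sigma_loc', 'sigma_scale']

# Symbolic log-density of the guide, evaluated at a draw from its own RVs
logq = guide.stochastic_logq()
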
74 changes: 74 additions & 0 deletions pymc_extras/inference/advi/compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Protocol

import numpy as np

from pymc import Model, compile
from pymc.pytensorf import rewrite_pregrad
from pytensor import tensor as pt

from pymc_extras.inference.advi.autoguide import AutoGuideModel
from pymc_extras.inference.advi.objective import advi_objective, get_logp_logq
from pymc_extras.inference.advi.pytensorf import vectorize_random_graph


class TrainingFn(Protocol):
def __call__(self, draws: int, *params: np.ndarray) -> tuple[np.ndarray, ...]: ...


def compile_svi_training_fn(
model: Model,
guide: AutoGuideModel,
stick_the_landing: bool = True,
minibatch: bool = False,
**compile_kwargs,
) -> TrainingFn:
draws = pt.scalar("draws", dtype=int)
params = guide.params
inputs = [draws, *params]

# Extra scaling factor applied to logp (currently always 1)
logp_scale = 1

if minibatch:
data = model.data_vars
inputs = [*inputs, *data]

logp, logq = get_logp_logq(model, guide, stick_the_landing=stick_the_landing)

scalar_negative_elbo = advi_objective(logp / logp_scale, logq)
[negative_elbo_draws] = vectorize_random_graph([scalar_negative_elbo], batch_draws=draws)
negative_elbo = negative_elbo_draws.mean(axis=0)

negative_elbo_grads = pt.grad(rewrite_pregrad(negative_elbo), wrt=params)

if "trust_input" not in compile_kwargs:
compile_kwargs["trust_input"] = True

f_loss_dloss = compile(
inputs=inputs, outputs=[negative_elbo, *negative_elbo_grads], **compile_kwargs
)

return f_loss_dloss


def compile_sampling_fn(model: Model, guide: AutoGuideModel, **compile_kwargs) -> TrainingFn:
draws = pt.scalar("draws", dtype=int)
params = guide.params

parameterized_value_vars = [
guide.model[rv.name] for rv in model.rvs_to_values.keys() if rv not in model.observed_RVs
]
transformed_vars = [
transform.backward(parameterized_var)
if (transform := model.rvs_to_transforms[rv]) is not None
else parameterized_var
for rv, parameterized_var in zip(model.rvs_to_values.keys(), parameterized_value_vars)
]

sampled_rvs_draws = vectorize_random_graph(transformed_vars, batch_draws=draws)

if "trust_input" not in compile_kwargs:
compile_kwargs["trust_input"] = True

f_sample = compile(inputs=[draws, *params], outputs=sampled_rvs_draws, **compile_kwargs)

return f_sample
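
As a usage illustration (not part of this diff): compile_svi_training_fn returns a function mapping (draws, *params) to the mean negative-ELBO estimate and its gradients with respect to the variational parameters. A minimal training-loop sketch; the toy model and the plain gradient-descent update are assumptions for illustration:

import numpy as np
import pymc as pm

from pymc_extras.inference.advi.autoguide import AutoDiagonalNormal
from pymc_extras.inference.advi.compile import compile_svi_training_fn

with pm.Model() as toy_model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("y", mu, 1, observed=np.array([0.1, -0.3, 0.2]))

guide = AutoDiagonalNormal(toy_model)
f_loss_dloss = compile_svi_training_fn(toy_model, guide)

# Parameter values in the same order as guide.params; the number of draws is passed
# as a NumPy integer because the function is compiled with trust_input=True.
param_values = [np.asarray(guide.params_init_values[p]) for p in guide.params]
n_draws = np.asarray(64)

for step in range(500):
    loss, *grads = f_loss_dloss(n_draws, *param_values)
    param_values = [v - 1e-2 * g for v, g in zip(param_values, grads)]
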
61 changes: 61 additions & 0 deletions pymc_extras/inference/advi/objective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

from pymc import Model
from pytensor import graph_replace
from pytensor.tensor import TensorVariable

from pymc_extras.inference.advi.autoguide import AutoGuideModel


def get_logp_logq(model: Model, guide: AutoGuideModel, stick_the_landing: bool = True):
"""
Compute the log probability of the model and the guide.

Parameters
----------
model : Model
The probabilistic model.
guide : AutoGuideModel
The variational guide.
stick_the_landing : bool, optional
Whether to use the stick-the-landing (STL) gradient estimator, by default True.
The STL estimator has lower gradient variance by removing the score function term
from the gradient. When True, gradients are stopped from flowing through logq.

Returns
-------
logp : TensorVariable
Log probability of the model.
logq : TensorVariable
Log probability of the guide.
"""

inputs_to_guide_rvs = {
model_value_var: guide.model[rv.name]
for rv, model_value_var in model.rvs_to_values.items()
if rv not in model.observed_RVs
}

logp = graph_replace(model.logp(), inputs_to_guide_rvs)
logq = guide.stochastic_logq(stick_the_landing=stick_the_landing)

return logp, logq


def advi_objective(logp: TensorVariable, logq: TensorVariable):
"""Compute the negative ELBO objective for ADVI.

Parameters
----------
logp : TensorVariable
Log probability of the model.
logq : TensorVariable
Log probability of the guide.

Returns
-------
TensorVariable
The negative ELBO.
"""
negative_elbo = logq - logp
return negative_elbo
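
As a worked illustration (not part of this diff): the negative ELBO returned here, logq - logp, is a single-draw estimate of E_q[log q(z) - log p(x, z)], the quantity minimized by ADVI. One such estimate can be evaluated directly by supplying values for the variational parameters; the toy model below is invented for illustration:

import numpy as np
import pymc as pm

from pymc_extras.inference.advi.autoguide import AutoDiagonalNormal
from pymc_extras.inference.advi.objective import advi_objective, get_logp_logq

with pm.Model() as toy_model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("y", mu, 1, observed=np.array([0.1, -0.3, 0.2]))

guide = AutoDiagonalNormal(toy_model)
logp, logq = get_logp_logq(toy_model, guide)
negative_elbo = advi_objective(logp, logq)  # logq - logp

# Evaluate one stochastic estimate at the guide's initial parameter values
value = negative_elbo.eval({p: guide.params_init_values[p] for p in guide.params})
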
58 changes: 58 additions & 0 deletions pymc_extras/inference/advi/pytensorf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import cast

from pymc import SymbolicRandomVariable
from pymc.distributions.shape_utils import change_dist_size
from pytensor import config
from pytensor import tensor as pt
from pytensor.graph import FunctionGraph, ancestors, vectorize_graph
from pytensor.tensor import TensorLike, TensorVariable
from pytensor.tensor.basic import infer_shape_db
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.rewriting.shape import ShapeFeature


def vectorize_random_graph(
graph: Sequence[TensorVariable], batch_draws: TensorLike
) -> list[TensorVariable]:
# Find the root random nodes
rvs = tuple(
var
for var in ancestors(graph)
if (
var.owner is not None
and isinstance(var.owner.op, RandomVariable | SymbolicRandomVariable)
)
)
rvs_set = set(rvs)
# Treat an RV as a root when none of its direct inputs are themselves RVs
root_rvs = tuple(rv for rv in rvs if not (set(rv.owner.inputs) & rvs_set))

# Vectorize graph by vectorizing root RVs
batch_draws = pt.as_tensor(batch_draws, dtype=int)
vectorized_replacements = {
root_rv: change_dist_size(root_rv, new_size=batch_draws, expand=True)
for root_rv in root_rvs
}
return cast(list[TensorVariable], vectorize_graph(graph, replace=vectorized_replacements))


def get_symbolic_rv_shapes(
rvs: Sequence[TensorVariable], raise_if_rvs_in_graph: bool = True
) -> tuple[TensorVariable, ...]:
# TODO: Move me to pymc.pytensorf, this is needed often

rv_shapes = [rv.shape for rv in rvs]
shape_fg = FunctionGraph(outputs=rv_shapes, features=[ShapeFeature()], clone=True)
with config.change_flags(optdb__max_use_ratio=10, cxx=""):
infer_shape_db.default_query.rewrite(shape_fg)
rv_shapes = shape_fg.outputs

if raise_if_rvs_in_graph and (overlap := (set(rvs) & set(ancestors(rv_shapes)))):
raise ValueError(f"rv_shapes still depend on the following rvs: {overlap}")

return cast(tuple[TensorVariable, ...], tuple(rv_shapes))
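
As a usage illustration (not part of this diff): vectorize_random_graph adds a leading "draws" dimension by resizing the root random variables of a graph and vectorizing everything downstream. A minimal sketch:

import pytensor.tensor as pt

from pymc_extras.inference.advi.pytensorf import vectorize_random_graph

z = pt.random.normal(0, 1, size=(3,))
y = pt.exp(z)  # deterministic transform of a random draw, shape (3,)

[y_draws] = vectorize_random_graph([y], batch_draws=5)
print(y_draws.eval().shape)  # (5, 3): one row per independent draw
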