Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 96 additions & 9 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc.util import RandomState
from pytensor import Variable, graph_replace
from pytensor.graph.traversal import explicit_graph_inputs
from rich.box import SIMPLE_HEAD
from rich.console import Console
from rich.table import Table
Expand Down Expand Up @@ -266,6 +267,7 @@ def __init__(
self._fit_dims: dict[str, Sequence[str]] | None = None
self._fit_data: pt.TensorVariable | None = None
self._fit_exog_data: dict[str, dict] = {}
self._shared_timestep: pt.TensorVariable | None = None

self._needs_exog_data = None
self._tensor_variable_info = SymbolicVariableInfo()
Expand All @@ -279,6 +281,9 @@ def __init__(

self._populate_properties()

# Placeholder for time-varying matrices that depend on data length
self._n_timesteps_placeholder = pt.iscalar("n_timesteps")

# All models contain a state space representation and a Kalman filter
self.ssm = PytensorRepresentation(k_endog, k_states, k_posdef)

Expand Down Expand Up @@ -494,6 +499,11 @@ def unpack_statespace(self) -> list[pt.TensorVariable]:

return self.subbed_ssm

@property
def n_timesteps(self) -> Variable:
    """
    Symbolic placeholder standing in for the number of time steps in the data.

    Returns
    -------
    Variable
        The object stored on ``self._n_timesteps_placeholder``.
    """
    return self._n_timesteps_placeholder

@property
def param_names(self) -> tuple[str, ...]:
"""
Expand Down Expand Up @@ -894,6 +904,56 @@ def _insert_data_variables(self):
replacement_dict = {data: pymc_model[name] for name, data in self._name_to_data.items()}
self.subbed_ssm = graph_replace(self.subbed_ssm, replace=replacement_dict, strict=True)

def _insert_data_shape_into_n_timesteps(self, data):
    """
    Replace any occurrence of the n_timesteps symbolic variable with the length of the data.

    n_timesteps is a special symbolic variable used by graphs with time-varying matrices, whose shapes won't be
    known until the user provides data. Every graph input named ``"n_timesteps"`` is collected and replaced with
    a single common expression, ``data.shape[0]`` cast to int32, so that all shapes are defined consistently.

    When at least one such input is found, the common expression is stored on ``self._shared_timestep`` and the
    substituted matrices are stored on ``self.subbed_ssm``. When none is found, nothing is modified.

    Parameters
    ----------
    data : pt.TensorVariable
        The observed data, already converted to a pytensor variable.
    """
    # This method should only be called after data has been ingested and transformed into a pytensor variable.
    # Otherwise, we don't get symbolic linkage between time-varying matrix shapes and the data when the user calls
    # pm.set_data
    assert isinstance(data, pt.TensorVariable)

    # Work on the already-substituted matrices when they exist, otherwise fall back to the placeholder graph.
    matrices = (
        self.subbed_ssm
        if self.subbed_ssm is not None
        else self._unpack_statespace_with_placeholders()
    )

    # Match by name rather than identity: all inputs called "n_timesteps" are collapsed onto one expression.
    n_timestep_variables = tuple(
        variable
        for variable in explicit_graph_inputs(matrices)
        if variable.name == "n_timesteps"
    )

    if not n_timestep_variables:
        return

    self._shared_timestep = data.shape[0].astype("int32")
    replacement_dict = {var: self._shared_timestep for var in n_timestep_variables}

    # Replace inside ``matrices`` (the graph that was actually searched). Using ``self.subbed_ssm`` here would
    # pass None to graph_replace whenever the placeholder fallback above was taken.
    self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=False)

def _insert_constant_timestep(self, matrices, step: int | pt.TensorVariable):
    """
    Substitute a fixed step count for every n_timesteps symbolic variable found in ``matrices``.

    This is used for constructing graphs where no data is available to supply the length, such as prior
    predictive sampling.

    Parameters
    ----------
    matrices
        The graph(s) whose explicit inputs are searched for variables named ``"n_timesteps"``.
    step : int or pt.TensorVariable
        The number of time steps. It is converted to a tensor variable and cast to int32 before substitution.

    Returns
    -------
    ``matrices`` itself when it has no input named ``"n_timesteps"``; otherwise the result of ``graph_replace``
    with each such input replaced by the int32 step.
    """
    # Convert before inspecting the graph, so the conversion happens even when nothing ends up being replaced.
    timestep = pt.as_tensor_variable(step).astype("int32")

    substitutions = {
        graph_input: timestep
        for graph_input in explicit_graph_inputs(matrices)
        if graph_input.name == "n_timesteps"
    }

    if substitutions:
        return graph_replace(matrices, replace=substitutions, strict=False)
    return matrices

def _register_matrices_with_pymc_model(self) -> list[pt.TensorVariable]:
"""
Add all statespace matrices to the PyMC model currently on the context stack as pm.Deterministic nodes, and
Expand Down Expand Up @@ -1048,6 +1108,9 @@ def build_statespace_graph(
missing_fill_value=missing_fill_value,
)

# Order is important here: only call _insert_data_shape_into_n_timesteps after data has been registered.
self._insert_data_shape_into_n_timesteps(data)

filter_outputs = self.kalman_filter.build_graph(
pt.as_tensor_variable(data),
*self.unpack_statespace(),
Expand Down Expand Up @@ -1217,11 +1280,16 @@ def _kalman_filter_outputs_from_dummy_graph(
if name in scenario.keys():
pm.set_data({name: scenario[name]})

x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace()
matrices = self.unpack_statespace()

if data is None:
data = self._fit_data

# Replace n_timesteps with data length for time-varying matrices
data_len = data.shape[0] if hasattr(data, "shape") else len(data)
matrices = self._insert_constant_timestep(matrices, data_len)
x0, P0, c, d, T, Z, R, H, Q = matrices

obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None)

data, nan_mask = register_data_with_pymc(
Expand Down Expand Up @@ -1481,8 +1549,8 @@ def _sample_unconditional(
pm.Data(**self._fit_exog_data[name])

self._insert_data_variables()

matrices = [x0, P0, c, d, T, Z, R, H, Q] = self.unpack_statespace()
matrices = self._insert_constant_timestep(self.unpack_statespace(), step=steps)
x0, P0, c, d, T, Z, R, H, Q = matrices

if not self.measurement_error:
H_jittered = pm.Deterministic(
Expand Down Expand Up @@ -1814,7 +1882,11 @@ def sample_statespace_matrices(
return matrix_idata

def sample_filter_outputs(
self, idata, filter_output_names: str | list[str] | None, group: str = "posterior", **kwargs
self,
idata,
filter_output_names: str | list[str] | None = None,
group: str = "posterior",
**kwargs,
):
if isinstance(filter_output_names, str):
filter_output_names = [filter_output_names]
Expand Down Expand Up @@ -2302,11 +2374,15 @@ def _build_forecast_model(
"P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
)

# Get matrices with n_timesteps set to forecast length for time-varying models
# Note: matrices already excludes x0 and P0, which are skipped in _kalman_filter_outputs_from_dummy_graph
forecast_matrices = self._insert_constant_timestep(matrices, len(forecast_index))

_ = LinearGaussianStateSpace(
"forecast",
x0,
P0,
*matrices,
*forecast_matrices,
steps=len(forecast_index),
dims=dims,
sequence_names=self.kalman_filter.seq_names,
Expand Down Expand Up @@ -2603,7 +2679,8 @@ def impulse_response_function(
self._build_dummy_graph()
self._insert_random_variables()

P0, _, c, d, T, Z, R, H, post_Q = self.unpack_statespace()
matrices = self._insert_constant_timestep(self.unpack_statespace(), step=n_steps)
P0, _, c, d, T, Z, R, H, post_Q = matrices
x0 = pm.Deterministic("x0_new", pt.zeros(self.k_states), dims=[ALL_STATE_DIM])

if use_posterior_cov:
Expand Down Expand Up @@ -2632,15 +2709,25 @@ def impulse_response_function(
else:
shock_trajectory = pt.as_tensor_variable(shock_trajectory)

def irf_step(shock, x, c, T, R):
time_varying_T = T.ndim == 3

def irf_step(*args):
if time_varying_T:
shock, T, x, c, R = args
else:
shock, x, c, T, R = args

next_x = c + T @ x + R @ shock
return next_x

sequences = [shock_trajectory, T] if time_varying_T else [shock_trajectory]
non_sequences = [c, R] if time_varying_T else [c, T, R]

irf = pytensor.scan(
irf_step,
sequences=[shock_trajectory],
sequences=sequences,
outputs_info=[x0],
non_sequences=[c, T, R],
non_sequences=non_sequences,
n_steps=n_steps,
strict=True,
return_updates=False,
Expand Down
Loading
Loading