From f856d3d6598ba0bfd8ea066da4ee68dfcaca301c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 8 Jan 2026 18:19:56 +0100 Subject: [PATCH] Fix support point for DiscreteMarkovChain with batch dimensions --- pymc_extras/distributions/timeseries.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc_extras/distributions/timeseries.py b/pymc_extras/distributions/timeseries.py index 6bf3acdb2..fe816fed5 100644 --- a/pymc_extras/distributions/timeseries.py +++ b/pymc_extras/distributions/timeseries.py @@ -241,7 +241,7 @@ def discrete_mc_moment(op, rv, P, steps, init_dist, state_rng): def greedy_transition(*args): *states, transition_probs, old_rng = args p = transition_probs[tuple(states)] - return pt.argmax(p) + return pt.argmax(p, axis=-1) chain_moment, moment_updates = pytensor.scan( greedy_transition, @@ -250,8 +250,8 @@ def greedy_transition(*args): n_steps=steps, strict=True, ) - chain_moment = pt.concatenate([init_dist_moment, chain_moment]) - return chain_moment + chain_moment = pt.concatenate([init_dist_moment, chain_moment], axis=0) + return pt.moveaxis(chain_moment, 0, -1) @_logprob.register(DiscreteMarkovChainRV)