diff --git a/pymc_extras/distributions/timeseries.py b/pymc_extras/distributions/timeseries.py index 6bf3acdb2..fe816fed5 100644 --- a/pymc_extras/distributions/timeseries.py +++ b/pymc_extras/distributions/timeseries.py @@ -241,7 +241,7 @@ def discrete_mc_moment(op, rv, P, steps, init_dist, state_rng): def greedy_transition(*args): *states, transition_probs, old_rng = args p = transition_probs[tuple(states)] - return pt.argmax(p) + return pt.argmax(p, axis=-1) chain_moment, moment_updates = pytensor.scan( greedy_transition, @@ -250,8 +250,8 @@ def greedy_transition(*args): n_steps=steps, strict=True, ) - chain_moment = pt.concatenate([init_dist_moment, chain_moment]) - return chain_moment + chain_moment = pt.concatenate([init_dist_moment, chain_moment], axis=0) + return pt.moveaxis(chain_moment, 0, -1) @_logprob.register(DiscreteMarkovChainRV)