Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions dynamax/hidden_markov_model/models/arhmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def __init__(self,
input_dim = num_lags * emission_dim
super().__init__(num_states, input_dim, emission_dim)

def _check_emissions_format(self, emission_weights, emission_biases, emission_covariances):
assert emission_weights.shape == (self.num_states, self.emission_dim, self.input_dim), f"'emission_weights' must have shape (num_states, emission_dim, input_dim)={(self.num_states, self.emission_dim, self.input_dim)} but {emission_weights.shape} provided."
assert emission_biases.shape == (self.num_states, self.emission_dim), f"'emission_biases' must have shape (num_states, emission_dim)={(self.num_states, self.emission_dim)} but {emission_biases.shape} provided."
assert emission_covariances.shape == (self.num_states, self.emission_dim, self.emission_dim), f"'emission_covariances' must have shape (num_states, emission_dim, emission_dim)={(self.num_states, self.emission_dim, self.emission_dim)} but {emission_covariances.shape} provided."

def initialize(self,
key: Array=jr.PRNGKey(0),
method: str="prior",
Expand Down Expand Up @@ -93,6 +98,9 @@ def initialize(self,
weights=ParameterProperties(),
biases=ParameterProperties(),
covs=ParameterProperties(constrainer=RealToPSDBijector()))

self._check_emissions_format(emission_weights=params.weights, emission_biases=params.biases, emission_covariances=params.covs)

return params, props


Expand Down
12 changes: 7 additions & 5 deletions dynamax/hidden_markov_model/models/categorical_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def log_prior(self, params: ParamsCategoricalHMMEmissions) -> Scalar:
"""Return the log prior probability of the emission parameters."""
return tfd.Dirichlet(self.emission_prior_concentration).log_prob(params.probs).sum()

def _check_emissions_format(self, emission_probs):
assert emission_probs.shape == (self.num_states, self.emission_dim, self.num_classes), f"'emission_probs' must have shape (num_states, emission_dim, num_classes)={(self.num_states, self.emission_dim, self.num_classes)} but {emission_probs.shape} provided."
assert jnp.all(emission_probs >= 0), "All entries in 'emission_probs' must be non-negative."
assert jnp.allclose(emission_probs.sum(axis=2), 1.0), "Each row of 'emission_probs' must sum to 1."

def initialize(self,
key:Optional[Array]=jr.PRNGKey(0),
method="prior",
Expand Down Expand Up @@ -91,11 +96,8 @@ def initialize(self,
raise NotImplementedError("kmeans initialization is not yet implemented!")
else:
raise Exception("invalid initialization method: {}".format(method))
else:
assert emission_probs.shape == (self.num_states, self.emission_dim, self.num_classes)
assert jnp.all(emission_probs >= 0)
assert jnp.allclose(emission_probs.sum(axis=2), 1.0)


self._check_emissions_format(emission_probs=emission_probs)
# Add parameters to the dictionary
params = ParamsCategoricalHMMEmissions(probs=emission_probs)
props = ParamsCategoricalHMMEmissions(probs=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
Expand Down
6 changes: 6 additions & 0 deletions dynamax/hidden_markov_model/models/initial.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def distribution(self, params: ParamsStandardHMMInitialState, inputs=None) -> tf
"""Return the distribution object of the initial distribution."""
return tfd.Categorical(probs=params.probs)

def _check_initialization_format(self, initial_probs: Float[Array, " num_states"]) -> None:
assert initial_probs.shape == (self.num_states,), f"'initial_probs' must have shape (num_states,)={(self.num_states,)} but {initial_probs.shape} provided."
assert jnp.all(initial_probs >= 0.0), f"All entries in 'initial_probs' must be non-negative."
assert jnp.isclose(initial_probs.sum(), 1.0), ValueError(f"'initial_probs' must sum to 1.0.")

def initialize(
self,
key: Optional[Array]=None,
Expand All @@ -59,6 +64,7 @@ def initialize(
this_key, key = jr.split(key)
initial_probs = tfd.Dirichlet(self.initial_probs_concentration).sample(seed=this_key)

self._check_initialization_format(initial_probs=initial_probs)
# Package the results into dictionaries
params = ParamsStandardHMMInitialState(probs=initial_probs)
props = ParamsStandardHMMInitialState(probs=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
Expand Down
8 changes: 7 additions & 1 deletion dynamax/hidden_markov_model/models/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ def distribution(self, params: ParamsStandardHMMTransitions, state: IntScalar, i
"""Return the distribution over the next state given the current state."""
return tfd.Categorical(probs=params.transition_matrix[state])

def _check_transitions_format(self, transition_matrix: Float[Array, "num_states num_states"]):
assert transition_matrix.shape == (self.num_states, self.num_states), f"'transition_matrix' must have shape (num_states, num_states)={(self.num_states, self.num_states)} but {transition_matrix.shape} provided."
assert jnp.all(transition_matrix >= 0.0), f"All entries in 'transition_matrix' must be non-negative."
assert jnp.isclose(transition_matrix.sum(axis=1), 1.0).all(), f"Each row of 'transition_matrix' must sum to 1.0."

def initialize(
self,
key: Optional[Array]=None,
Expand All @@ -73,7 +78,8 @@ def initialize(
else:
transition_matrix_sample = tfd.Dirichlet(self.concentration).sample(seed=key)
transition_matrix = cast(Float[Array, "num_states num_states"], transition_matrix_sample)


self._check_transitions_format(transition_matrix=transition_matrix)
# Package the results into dictionaries
params = ParamsStandardHMMTransitions(transition_matrix=transition_matrix)
props = ParamsStandardHMMTransitions(transition_matrix=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
Expand Down