Skip to content

Commit

Permalink
fix handling of initial observations
Browse files Browse the repository at this point in the history
  • Loading branch information
Jouni Helske committed Jan 16, 2025
1 parent ea788cb commit bdc37b2
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 57 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ simulate_mnhmm_multichannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, et
.Call(`_seqHMM_simulate_mnhmm_multichannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, eta_omega, X_omega, M)
}

simulate_fanhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs_0) {
.Call(`_seqHMM_simulate_fanhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs_0)
# Thin R-level wrapper: forwards all arguments, unchanged and in the same
# order, to the compiled routine `_seqHMM_simulate_fanhmm_singlechannel`
# via `.Call()` and returns whatever that routine returns.
# NOTE(review): this lives in R/RcppExports.R, which is presumably generated
# by Rcpp::compileAttributes() -- confirm before editing by hand, as manual
# changes would be overwritten on regeneration.
simulate_fanhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs_1) {
.Call(`_seqHMM_simulate_fanhmm_singlechannel`, eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs_1)
}

viterbi_nhmm_singlechannel <- function(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs, Ti, icpt_only_pi, icpt_only_A, icpt_only_B, iv_A, iv_B, tv_A, tv_B) {
Expand Down
8 changes: 2 additions & 6 deletions R/build_fanhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,10 @@ build_fanhmm <- function(
paste("~ . + ", terms_feedback)
)
}
obs_0 <- data[[observations]][data[[time]] == min(data[[time]])]
data <- data[data[[time]] > min(data[[time]]), ]
out <- create_base_nhmm(
observations, data, time, id, n_states, state_names, channel_names = NULL,
initial_formula, transition_formula, emission_formula, scale = scale,
check_formulas = FALSE)
check_formulas = FALSE, fanhmm = TRUE)
stopifnot_(
!any(out$model$observations == attr(out$model$observations, "nr")),
"FAN-HMM does not support missing values in the observations."
Expand All @@ -92,8 +90,6 @@ build_fanhmm <- function(
out$model$etas <- setNames(
create_initial_values(list(), out$model, 0), c("pi", "A", "B")
)

out$model$obs_0 <- obs_0
structure(
c(
out$model,
Expand All @@ -103,7 +99,7 @@ build_fanhmm <- function(
)
),
class = c("fanhmm", "nhmm"),
nobs = attr(out$model$observations, "nobs"),
nobs = attr(out$model$observations, "nobs") - out$model$n_sequences,
df = out$extras$np_pi + out$extras$np_A + out$extras$np_B,
type = paste0(out$extras$multichannel, "fanhmm"),
intercept_only = out$extras$intercept_only,
Expand Down
4 changes: 2 additions & 2 deletions R/create_base_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ create_base_nhmm <- function(observations, data, time, id, n_states,
initial_formula, transition_formula,
emission_formula, cluster_formula = NA,
cluster_names = "", scale = TRUE,
check_formulas = TRUE) {
check_formulas = TRUE, fanhmm = FALSE) {

stopifnot_(
!missing(n_states) && checkmate::test_int(x = n_states, lower = 2L),
Expand Down Expand Up @@ -124,7 +124,7 @@ create_base_nhmm <- function(observations, data, time, id, n_states,
)
B <- model_matrix_emission_formula(
emission_formula, data, n_sequences, length_of_sequences, n_states,
n_symbols, time, id, sequence_lengths, scale = scale
n_symbols, time, id, sequence_lengths, scale = scale, fanhmm = fanhmm
)
if (mixture) {
omega <- model_matrix_cluster_formula(
Expand Down
8 changes: 6 additions & 2 deletions R/model_matrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ model_matrix_emission_formula <- function(formula, data, n_sequences,
n_symbols,
time, id, sequence_lengths,
X_mean, X_sd, check = TRUE,
scale = TRUE) {
scale = TRUE, fanhmm = FALSE) {
icpt_only <- intercept_only(formula)
if (icpt_only) {
n_pars <- sum(n_states * (n_symbols - 1L))
Expand Down Expand Up @@ -255,9 +255,13 @@ model_matrix_emission_formula <- function(formula, data, n_sequences,
X[, cols] <- X_scaled
X_mean <- attr(X_scaled, "scaled:center")
X_sd <- attr(X_scaled, "scaled:scale")

complete <- complete.cases(X)
missing_values <- which(!complete)
if (fanhmm) {
# first observation is fixed so missing (lagged) values do not matter
missing_values <- setdiff(missing_values, which(data[[time]] == min(data[[time]])))
#complete[which(data[[time]] == min(data[[time]]))] <- TRUE
}
if (length(missing_values) > 0) {
ends <- sequence_lengths[match(data[[id]], unique(data[[id]]))]
stopifnot_(
Expand Down
21 changes: 7 additions & 14 deletions R/simulate_fanhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ simulate_fanhmm <- function(
n_sequences, sequence_lengths, n_symbols, n_states,
initial_formula = ~1, transition_formula = ~1,
emission_formula = ~1, autoregression_formula = ~1,
feedback_formula = ~1, obs_0, data, time, id, coefs = "random",
feedback_formula = ~1, obs_1, data, time, id, coefs = "random",
response_name = "y", init_sd = 2) {

stopifnot_(
Expand Down Expand Up @@ -54,25 +54,18 @@ simulate_fanhmm <- function(
"Currently only single-channel responses are supported for FAN-HMM."
)
stopifnot_(
!missing(obs_0) &&
checkmate::test_integerish(x = obs_0, lower = 1L, upper = n_symbols),
"Argument {.arg obs_0} should be an integer vector of length {.val n_sequences}."
!missing(obs_1) &&
checkmate::test_integerish(x = obs_1, lower = 1L, upper = n_symbols),
"Argument {.arg obs_1} should be an integer vector of length {.val n_sequences}."
)
symbol_names <- as.character(seq_len(n_symbols))
T_ <- max(sequence_lengths)
data[[response_name]] <- factor(rep(symbol_names, length = nrow(data)),
levels = symbol_names)
data0 <- data[data[[time]] == min(data[[time]]), ]
data0[[time]] <- min(data[[time]]) - min(diff(sort(unique(data[[time]]))))
data0[[response_name]] <- obs_0
data0 <- rbind(
data0,
data
)


model <- build_fanhmm(
response_name, n_states, initial_formula, transition_formula, emission_formula,
autoregression_formula, feedback_formula, data0, time, id, scale = FALSE
autoregression_formula, feedback_formula, data, time, id, scale = FALSE
)

X_A <- X_B <- vector("list", n_symbols)
Expand Down Expand Up @@ -104,7 +97,7 @@ simulate_fanhmm <- function(
model$etas$pi, model$X_pi,
model$etas$A, X_A,
model$etas$B, X_B,
as.integer(model$obs_0) - 1
as.integer(obs_1) - 1
)
for (i in seq_len(model$n_sequences)) {
Ti <- sequence_lengths[i]
Expand Down
3 changes: 2 additions & 1 deletion R/update.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ update.nhmm <- function(object, newdata, ...) {
object$length_of_sequences, object$n_states, object$n_symbols,
object$time_variable, object$id_variable,
object$sequence_lengths,
attr(object$X_B, "X_mean"), attr(object$X_B, "X_sd"), FALSE
attr(object$X_B, "X_mean"), attr(object$X_B, "X_sd"), FALSE,
fanhmm = inherits(object, "fanhmm")
)$X
object
}
Expand Down
6 changes: 5 additions & 1 deletion R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' @noRd
group_lag <- function(d, id, response) {
  # Shift the response column down by one row. The value placed in the first
  # position is only a placeholder: the first row always starts a new id, so
  # it is overwritten with NA below.
  values <- d[[response]]
  shifted <- c(values[1], values[-nrow(d)])
  # Rows where an id appears for the first time have no previous value
  # within that id, so their lag is set to missing.
  first_rows <- which(!duplicated(d[[id]]))
  shifted[first_rows] <- NA
  shifted
}
#' Convert return code from estimate_nhmm and estimate_mnhmm to text
Expand Down Expand Up @@ -221,7 +221,11 @@ create_obsArray <- function(model) {
sum(obsArray[, , i] < model$n_symbols[i]) > 0,
"One channel contains only missing values, model is degenerate."
)
if (inherits(model, "fanhmm")) {
obsArray[, 1, i] <- model$n_symbols[i]
}
}

aperm(obsArray)
}
#' Create emissionArray for Various C++ functions
Expand Down
8 changes: 4 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1079,8 +1079,8 @@ BEGIN_RCPP
END_RCPP
}
// simulate_fanhmm_singlechannel
Rcpp::List simulate_fanhmm_singlechannel(const arma::mat& eta_pi, const arma::mat& X_pi, const arma::cube& eta_A, const arma::field<arma::cube>& X_A, const arma::cube& eta_B, const arma::field<arma::cube>& X_B, const arma::uvec& obs_0);
RcppExport SEXP _seqHMM_simulate_fanhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obs_0SEXP) {
Rcpp::List simulate_fanhmm_singlechannel(const arma::mat& eta_pi, const arma::mat& X_pi, const arma::cube& eta_A, const arma::field<arma::cube>& X_A, const arma::cube& eta_B, const arma::field<arma::cube>& X_B, const arma::uvec& obs_1);
RcppExport SEXP _seqHMM_simulate_fanhmm_singlechannel(SEXP eta_piSEXP, SEXP X_piSEXP, SEXP eta_ASEXP, SEXP X_ASEXP, SEXP eta_BSEXP, SEXP X_BSEXP, SEXP obs_1SEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand All @@ -1090,8 +1090,8 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< const arma::field<arma::cube>& >::type X_A(X_ASEXP);
Rcpp::traits::input_parameter< const arma::cube& >::type eta_B(eta_BSEXP);
Rcpp::traits::input_parameter< const arma::field<arma::cube>& >::type X_B(X_BSEXP);
Rcpp::traits::input_parameter< const arma::uvec& >::type obs_0(obs_0SEXP);
rcpp_result_gen = Rcpp::wrap(simulate_fanhmm_singlechannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs_0));
Rcpp::traits::input_parameter< const arma::uvec& >::type obs_1(obs_1SEXP);
rcpp_result_gen = Rcpp::wrap(simulate_fanhmm_singlechannel(eta_pi, X_pi, eta_A, X_A, eta_B, X_B, obs_1));
return rcpp_result_gen;
END_RCPP
}
Expand Down
19 changes: 0 additions & 19 deletions src/nhmm_gradients.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,4 @@ void gradient_wrt_B(
const arma::cube& X, const arma::uvec& M, const arma::uword i,
const arma::uword s, const arma::uword t, const arma::uword c,
const arma::uword d);
// FAN-HMM
void gradient_wrt_A(
arma::mat& grad, arma::mat& tmpmat, const arma::umat& obs,
const arma::mat& log_py, const arma::mat& log_alpha,
const arma::mat& log_beta, const double ll, const arma::cube& A,
const arma::cube& X, const arma::cube& W, const arma::uword i,
const arma::uword t, const arma::uword s);
void gradient_wrt_B_t0(
arma::mat& grad, arma::vec& tmpvec, const arma::umat& obs,
const arma::uvec& obs_0, const arma::vec& log_pi, const arma::mat& log_beta,
const double ll, const arma::cube& B, const arma::cube& X,
const arma::cube& W, const arma::uword i, const arma::uword s);
void gradient_wrt_B(
arma::mat& grad, arma::vec& tmpvec, const arma::umat& obs,
const arma::mat& log_alpha,
const arma::mat& log_beta, const double ll,
const arma::cube& log_A, const arma::cube& B, const arma::cube& X,
const arma::cube& W, const arma::uword i, const arma::uword s,
const arma::uword t);
#endif
12 changes: 6 additions & 6 deletions src/nhmm_simulate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ Rcpp::List simulate_fanhmm_singlechannel(
const arma::mat& eta_pi, const arma::mat& X_pi,
const arma::cube& eta_A, const arma::field<arma::cube>& X_A,
const arma::cube& eta_B, const arma::field<arma::cube>& X_B,
const arma::uvec& obs_0) {
arma::uword N = obs_0.n_elem;
const arma::uvec& obs_1) {
arma::uword N = obs_1.n_elem;
arma::uword T = X_A(0).n_cols;
arma::uword S = eta_A.n_slices;
arma::uword M = eta_B.n_rows + 1;
Expand All @@ -209,10 +209,10 @@ Rcpp::List simulate_fanhmm_singlechannel(
for (arma::uword i = 0; i < N; i++) {
pi = get_pi(gamma_pi, X_pi.col(i));
z(0, i) = arma::as_scalar(Rcpp::RcppArmadillo::sample(seqS, 1, false, pi));
B = softmax(
gamma_B.slice(z(0, i)) * X_B(obs_0(i)).slice(i).col(0)
);
y(0, i) = arma::as_scalar(Rcpp::RcppArmadillo::sample(seqM, 1, false, B));
// B = softmax(
// gamma_B.slice(z(0, i)) * X_B(obs_0(i)).slice(i).col(0)
// );
y(0, i) = obs_1(i);
for (arma::uword t = 1; t < T; t++) {
A = softmax(
gamma_A.slice(z(t - 1, i)) * X_A(y(t - 1, i)).slice(i).col(t - 1)
Expand Down

0 comments on commit bdc37b2

Please sign in to comment.