Skip to content

Commit bcfa1db

Browse files
author
Jouni Helske
committed
fix start in ame
1 parent 7e05098 commit bcfa1db

File tree

8 files changed

+56
-215
lines changed

8 files changed

+56
-215
lines changed

R/RcppExports.R

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -49,42 +49,10 @@ fast_quantiles <- function(X, probs) {
4949
.Call(`_seqHMM_fast_quantiles`, X, probs)
5050
}
5151

52-
get_omega <- function(gamma, X) {
53-
.Call(`_seqHMM_get_omega`, gamma, X)
54-
}
55-
56-
get_log_omega <- function(gamma, X) {
57-
.Call(`_seqHMM_get_log_omega`, gamma, X)
58-
}
59-
6052
get_omega_all <- function(gamma, X) {
6153
.Call(`_seqHMM_get_omega_all`, gamma, X)
6254
}
6355

64-
get_pi <- function(gamma, X) {
65-
.Call(`_seqHMM_get_pi`, gamma, X)
66-
}
67-
68-
get_log_pi <- function(gamma, X) {
69-
.Call(`_seqHMM_get_log_pi`, gamma, X)
70-
}
71-
72-
get_A <- function(gamma, X, tv) {
73-
.Call(`_seqHMM_get_A`, gamma, X, tv)
74-
}
75-
76-
get_log_A <- function(gamma, X, tv) {
77-
.Call(`_seqHMM_get_log_A`, gamma, X, tv)
78-
}
79-
80-
get_B <- function(gamma, X, tv, add_missing) {
81-
.Call(`_seqHMM_get_B`, gamma, X, tv, add_missing)
82-
}
83-
84-
get_log_B <- function(gamma, X, tv, add_missing) {
85-
.Call(`_seqHMM_get_log_B`, gamma, X, tv, add_missing)
86-
}
87-
8856
get_pi_all <- function(gamma, X) {
8957
.Call(`_seqHMM_get_pi_all`, gamma, X)
9058
}

R/ame_obs.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ ame_obs.nhmm <- function(
9797
)
9898
newdata <- model$data
9999
}
100+
stopifnot_(
101+
!missing(start_time) && checkmate::test_choice(start_time, newdata[[time]]),
102+
"Argument {.arg start_time} must be a single value matching the
103+
time point in {.arg newdata}."
104+
)
100105
newdata[[variable]][newdata[[time]] >= start_time] <- values[1]
101106
X1 <- update(model, newdata)[c("X_pi", "X_A", "X_B")]
102107
newdata[[variable]][newdata[[time]] >= start_time] <- values[2]
@@ -247,6 +252,13 @@ ame_obs.fanhmm <- function(
247252
)
248253
newdata <- model$data
249254
}
255+
stopifnot_(
256+
!missing(start_time) &&
257+
checkmate::test_choice(start_time, newdata[[time]]) &&
258+
start_time != min(newdata[[time]]),
259+
"Argument {.arg start_time} must be a single value matching the
260+
time point in {.arg newdata}, excluding the first time point."
261+
)
250262
newdata[[variable]][newdata[[time]] >= start_time] <- values[1]
251263
X1 <- update(model, newdata)[c("X_pi", "X_A", "X_B")]
252264
W1_A <- W1_B <- vector("list", model$n_symbols)

R/ame_param.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ ame_param.nhmm <- function(
149149
unlist(get_A_all(model$gammas$A, X2, tv_A)),
150150
c(S, S, T_, N)
151151
),
152-
1:3, mean)
152+
1:3, mean, na.rm = TRUE)
153153
)
154154
)
155155
if (return_quantiles) {

src/RcppExports.cpp

Lines changed: 0 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -250,30 +250,6 @@ BEGIN_RCPP
250250
return rcpp_result_gen;
251251
END_RCPP
252252
}
253-
// get_omega
254-
arma::vec get_omega(const arma::mat& gamma, const arma::vec& X);
255-
RcppExport SEXP _seqHMM_get_omega(SEXP gammaSEXP, SEXP XSEXP) {
256-
BEGIN_RCPP
257-
Rcpp::RObject rcpp_result_gen;
258-
Rcpp::RNGScope rcpp_rngScope_gen;
259-
Rcpp::traits::input_parameter< const arma::mat& >::type gamma(gammaSEXP);
260-
Rcpp::traits::input_parameter< const arma::vec& >::type X(XSEXP);
261-
rcpp_result_gen = Rcpp::wrap(get_omega(gamma, X));
262-
return rcpp_result_gen;
263-
END_RCPP
264-
}
265-
// get_log_omega
266-
arma::vec get_log_omega(const arma::mat& gamma, const arma::vec& X);
267-
RcppExport SEXP _seqHMM_get_log_omega(SEXP gammaSEXP, SEXP XSEXP) {
268-
BEGIN_RCPP
269-
Rcpp::RObject rcpp_result_gen;
270-
Rcpp::RNGScope rcpp_rngScope_gen;
271-
Rcpp::traits::input_parameter< const arma::mat& >::type gamma(gammaSEXP);
272-
Rcpp::traits::input_parameter< const arma::vec& >::type X(XSEXP);
273-
rcpp_result_gen = Rcpp::wrap(get_log_omega(gamma, X));
274-
return rcpp_result_gen;
275-
END_RCPP
276-
}
277253
// get_omega_all
278254
arma::mat get_omega_all(const arma::mat& gamma, const arma::mat& X);
279255
RcppExport SEXP _seqHMM_get_omega_all(SEXP gammaSEXP, SEXP XSEXP) {
@@ -286,84 +262,6 @@ BEGIN_RCPP
286262
return rcpp_result_gen;
287263
END_RCPP
288264
}
289-
// get_pi
290-
arma::vec get_pi(const arma::mat& gamma, const arma::vec& X);
291-
RcppExport SEXP _seqHMM_get_pi(SEXP gammaSEXP, SEXP XSEXP) {
292-
BEGIN_RCPP
293-
Rcpp::RObject rcpp_result_gen;
294-
Rcpp::RNGScope rcpp_rngScope_gen;
295-
Rcpp::traits::input_parameter< const arma::mat& >::type gamma(gammaSEXP);
296-
Rcpp::traits::input_parameter< const arma::vec& >::type X(XSEXP);
297-
rcpp_result_gen = Rcpp::wrap(get_pi(gamma, X));
298-
return rcpp_result_gen;
299-
END_RCPP
300-
}
301-
// get_log_pi
302-
arma::vec get_log_pi(const arma::mat& gamma, const arma::vec& X);
303-
RcppExport SEXP _seqHMM_get_log_pi(SEXP gammaSEXP, SEXP XSEXP) {
304-
BEGIN_RCPP
305-
Rcpp::RObject rcpp_result_gen;
306-
Rcpp::RNGScope rcpp_rngScope_gen;
307-
Rcpp::traits::input_parameter< const arma::mat& >::type gamma(gammaSEXP);
308-
Rcpp::traits::input_parameter< const arma::vec& >::type X(XSEXP);
309-
rcpp_result_gen = Rcpp::wrap(get_log_pi(gamma, X));
310-
return rcpp_result_gen;
311-
END_RCPP
312-
}
313-
// get_A
314-
arma::cube get_A(const arma::cube& gamma, const arma::mat& X, const bool tv);
315-
RcppExport SEXP _seqHMM_get_A(SEXP gammaSEXP, SEXP XSEXP, SEXP tvSEXP) {
316-
BEGIN_RCPP
317-
Rcpp::RObject rcpp_result_gen;
318-
Rcpp::RNGScope rcpp_rngScope_gen;
319-
Rcpp::traits::input_parameter< const arma::cube& >::type gamma(gammaSEXP);
320-
Rcpp::traits::input_parameter< const arma::mat& >::type X(XSEXP);
321-
Rcpp::traits::input_parameter< const bool >::type tv(tvSEXP);
322-
rcpp_result_gen = Rcpp::wrap(get_A(gamma, X, tv));
323-
return rcpp_result_gen;
324-
END_RCPP
325-
}
326-
// get_log_A
327-
arma::cube get_log_A(const arma::cube& gamma, const arma::mat& X, const bool tv);
328-
RcppExport SEXP _seqHMM_get_log_A(SEXP gammaSEXP, SEXP XSEXP, SEXP tvSEXP) {
329-
BEGIN_RCPP
330-
Rcpp::RObject rcpp_result_gen;
331-
Rcpp::RNGScope rcpp_rngScope_gen;
332-
Rcpp::traits::input_parameter< const arma::cube& >::type gamma(gammaSEXP);
333-
Rcpp::traits::input_parameter< const arma::mat& >::type X(XSEXP);
334-
Rcpp::traits::input_parameter< const bool >::type tv(tvSEXP);
335-
rcpp_result_gen = Rcpp::wrap(get_log_A(gamma, X, tv));
336-
return rcpp_result_gen;
337-
END_RCPP
338-
}
339-
// get_B
340-
arma::cube get_B(const arma::cube& gamma, const arma::mat& X, const bool tv, const bool add_missing);
341-
RcppExport SEXP _seqHMM_get_B(SEXP gammaSEXP, SEXP XSEXP, SEXP tvSEXP, SEXP add_missingSEXP) {
342-
BEGIN_RCPP
343-
Rcpp::RObject rcpp_result_gen;
344-
Rcpp::RNGScope rcpp_rngScope_gen;
345-
Rcpp::traits::input_parameter< const arma::cube& >::type gamma(gammaSEXP);
346-
Rcpp::traits::input_parameter< const arma::mat& >::type X(XSEXP);
347-
Rcpp::traits::input_parameter< const bool >::type tv(tvSEXP);
348-
Rcpp::traits::input_parameter< const bool >::type add_missing(add_missingSEXP);
349-
rcpp_result_gen = Rcpp::wrap(get_B(gamma, X, tv, add_missing));
350-
return rcpp_result_gen;
351-
END_RCPP
352-
}
353-
// get_log_B
354-
arma::cube get_log_B(const arma::cube& gamma, const arma::mat& X, const bool tv, const bool add_missing);
355-
RcppExport SEXP _seqHMM_get_log_B(SEXP gammaSEXP, SEXP XSEXP, SEXP tvSEXP, SEXP add_missingSEXP) {
356-
BEGIN_RCPP
357-
Rcpp::RObject rcpp_result_gen;
358-
Rcpp::RNGScope rcpp_rngScope_gen;
359-
Rcpp::traits::input_parameter< const arma::cube& >::type gamma(gammaSEXP);
360-
Rcpp::traits::input_parameter< const arma::mat& >::type X(XSEXP);
361-
Rcpp::traits::input_parameter< const bool >::type tv(tvSEXP);
362-
Rcpp::traits::input_parameter< const bool >::type add_missing(add_missingSEXP);
363-
rcpp_result_gen = Rcpp::wrap(get_log_B(gamma, X, tv, add_missing));
364-
return rcpp_result_gen;
365-
END_RCPP
366-
}
367265
// get_pi_all
368266
arma::mat get_pi_all(const arma::mat& gamma, const arma::mat& X);
369267
RcppExport SEXP _seqHMM_get_pi_all(SEXP gammaSEXP, SEXP XSEXP) {
@@ -1709,15 +1607,7 @@ static const R_CallMethodDef CallEntries[] = {
17091607
{"_seqHMM_eta_to_gamma_mat_field", (DL_FUNC) &_seqHMM_eta_to_gamma_mat_field, 1},
17101608
{"_seqHMM_eta_to_gamma_cube_field", (DL_FUNC) &_seqHMM_eta_to_gamma_cube_field, 1},
17111609
{"_seqHMM_fast_quantiles", (DL_FUNC) &_seqHMM_fast_quantiles, 2},
1712-
{"_seqHMM_get_omega", (DL_FUNC) &_seqHMM_get_omega, 2},
1713-
{"_seqHMM_get_log_omega", (DL_FUNC) &_seqHMM_get_log_omega, 2},
17141610
{"_seqHMM_get_omega_all", (DL_FUNC) &_seqHMM_get_omega_all, 2},
1715-
{"_seqHMM_get_pi", (DL_FUNC) &_seqHMM_get_pi, 2},
1716-
{"_seqHMM_get_log_pi", (DL_FUNC) &_seqHMM_get_log_pi, 2},
1717-
{"_seqHMM_get_A", (DL_FUNC) &_seqHMM_get_A, 3},
1718-
{"_seqHMM_get_log_A", (DL_FUNC) &_seqHMM_get_log_A, 3},
1719-
{"_seqHMM_get_B", (DL_FUNC) &_seqHMM_get_B, 4},
1720-
{"_seqHMM_get_log_B", (DL_FUNC) &_seqHMM_get_log_B, 4},
17211611
{"_seqHMM_get_pi_all", (DL_FUNC) &_seqHMM_get_pi_all, 2},
17221612
{"_seqHMM_get_A_all", (DL_FUNC) &_seqHMM_get_A_all, 3},
17231613
{"_seqHMM_get_B_all", (DL_FUNC) &_seqHMM_get_B_all, 3},

src/get_parameters.cpp

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
11
#include "get_parameters.h"
22

3-
// gamma_omega is D x K (start from, covariates)
4-
// X a vector of length K
5-
// [[Rcpp::export]]
63
arma::vec get_omega(const arma::mat& gamma, const arma::vec& X) {
74
return softmax(gamma * X);
85
}
9-
// eta_omega is D x K (start from, covariates)
10-
// X a vector of length K
11-
// [[Rcpp::export]]
126
arma::vec get_log_omega(const arma::mat& gamma, const arma::vec& X) {
137
return arma::log(softmax(gamma * X));
148
}
@@ -21,21 +15,12 @@ arma::mat get_omega_all(const arma::mat& gamma, const arma::mat& X) {
2115
return omega;
2216
}
2317

24-
// gamma is S x K (start from, covariates)
25-
// X a vector of length K
26-
// [[Rcpp::export]]
2718
arma::vec get_pi(const arma::mat& gamma, const arma::vec& X) {
2819
return softmax(gamma * X);
2920
}
30-
// gamma is S x K (start from, covariates)
31-
// X a vector of length K
32-
// [[Rcpp::export]]
3321
arma::vec get_log_pi(const arma::mat& gamma, const arma::vec& X) {
3422
return arma::log(softmax(gamma * X));
3523
}
36-
// gamma is S x K x S (transition to, covariates, transition from)
37-
// X is K x T matrix (covariates, time points)
38-
// [[Rcpp::export]]
3924
arma::cube get_A(const arma::cube& gamma, const arma::mat& X,
4025
const bool tv) {
4126
arma::uword S = gamma.n_slices;
@@ -57,9 +42,6 @@ arma::cube get_A(const arma::cube& gamma, const arma::mat& X,
5742
}
5843
return A;
5944
}
60-
// gamma is S x K x S (transition to, covariates, transition from)
61-
// X is K x T matrix (covariates, time points)
62-
// [[Rcpp::export]]
6345
arma::cube get_log_A(const arma::cube& gamma, const arma::mat& X,
6446
const bool tv) {
6547
arma::uword S = gamma.n_slices;
@@ -81,9 +63,6 @@ arma::cube get_log_A(const arma::cube& gamma, const arma::mat& X,
8163
}
8264
return arma::log(A);
8365
}
84-
// gamma is M x K x S (symbols, covariates, transition from)
85-
// X is K x T (covariates, time points)
86-
// [[Rcpp::export]]
8766
arma::cube get_B(const arma::cube& gamma, const arma::mat& X,
8867
const bool tv, const bool add_missing) {
8968
arma::uword S = gamma.n_slices;
@@ -111,8 +90,6 @@ arma::cube get_B(const arma::cube& gamma, const arma::mat& X,
11190
}
11291
return B;
11392
}
114-
// gamma is a a field of M_c x K x S cubes
115-
// X is K x T (covariates, time point)
11693
arma::field<arma::cube> get_B(
11794
const arma::field<arma::cube>& gamma, const arma::mat& X,
11895
const arma::uvec& M, const bool tv, const bool add_missing) {
@@ -123,9 +100,6 @@ arma::field<arma::cube> get_B(
123100
}
124101
return B;
125102
}
126-
// gamma is M x K x S (symbols, covariates, transition from)
127-
// X is K x T (covariates, time points)
128-
// [[Rcpp::export]]
129103
arma::cube get_log_B(const arma::cube& gamma, const arma::mat& X,
130104
const bool tv, const bool add_missing) {
131105
arma::uword S = gamma.n_slices;
@@ -153,8 +127,6 @@ arma::cube get_log_B(const arma::cube& gamma, const arma::mat& X,
153127
}
154128
return arma::log(B);
155129
}
156-
// gamma is a a field of M_c x K x S cubes
157-
// X is K x T (covariates, time point)
158130
arma::field<arma::cube> get_log_B(
159131
const arma::field<arma::cube>& gamma, const arma::mat& X,
160132
const arma::uvec& M, const bool tv, const bool add_missing) {
@@ -166,6 +138,7 @@ arma::field<arma::cube> get_log_B(
166138
return log_B;
167139
}
168140

141+
169142
// gamma is S x K (start from, covariates)
170143
// X a K x N matrix
171144
// [[Rcpp::export]]

src/mnhmm_base.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,6 @@ struct mnhmm_base {
236236
}
237237
log_A(d) = arma::log(A(d));
238238
}
239-
240239
void estep_omega(const arma::uword i, const arma::vec ll_i,
241240
const double ll) {
242241
E_omega.col(i) = arma::exp(ll_i - ll);

0 commit comments

Comments
 (0)