diff --git a/src/fastcpd_class.cc b/src/fastcpd_class.cc index 01f10f6e..b0e413af 100644 --- a/src/fastcpd_class.cc +++ b/src/fastcpd_class.cc @@ -354,7 +354,6 @@ double Fastcpd::get_cval_for_r_t_set( const int t, double lambda ) { - double cval = 0; DEBUG_RCOUT(i); int tau = r_t_set(i - 1); if (family == "lasso") { @@ -364,53 +363,72 @@ double Fastcpd::get_cval_for_r_t_set( mat data_segment = data.rows(tau, t - 1); DEBUG_RCOUT(data_segment); if (t > vanilla_percentage * n) { - // fastcpd - update_cost_parameters(t, tau, i, k.get(), lambda, line_search); - colvec theta = theta_sum.col(i - 1) / (t - tau); - DEBUG_RCOUT(theta); - if (!contain(FASTCPD_FAMILIES, family)) { - Function cost_non_null = cost.get(); - SEXP cost_result = cost_non_null(data_segment, theta); - cval = as(cost_result); - } else if ( - (family != "lasso" && t - tau >= p) || - (family == "lasso" && t - tau >= 3) - ) { - cval = cost_function_wrapper( - data_segment, wrap(theta), lambda, false, R_NilValue - ).value; - } else { - // t - tau < p or for lasso t - tau < 3 - } + return get_cval_sen(data_segment, i, t, tau, lambda); } else { - // vanilla PELT - CostResult cost_result; - if (!contain(FASTCPD_FAMILIES, family)) { - cost_result = get_optimized_cost(data_segment); + return get_cval_pelt(data_segment, i, t, tau, lambda); + } +} + +double Fastcpd::get_cval_pelt( + const mat data_segment, + const unsigned int i, + const int t, + const int tau, + const double lambda +) { + double cval = 0; + CostResult cost_result; + if (!contain(FASTCPD_FAMILIES, family)) { + cost_result = get_optimized_cost(data_segment); + } else { + if (warm_start && t - tau >= 10 * p) { + cost_result = cost_function_wrapper( + data_segment, R_NilValue, lambda, false, + wrap(segment_theta_hat[segment_indices(t - 1) - 1]) + // Or use `wrap(start.col(tau))` for warm start. + ); + update_start(tau, colvec(cost_result.par)); } else { - if (warm_start && t - tau >= 10 * p) { - cost_result = cost_function_wrapper( - data_segment, R_NilValue, lambda, false, - wrap(segment_theta_hat[segment_indices(t - 1) - 1]) - // Or use `wrap(start.col(tau))` for warm start. - ); - update_start(tau, colvec(cost_result.par)); - } else { - cost_result = cost_function_wrapper( - data_segment, R_NilValue, lambda, false, R_NilValue - ); - } + cost_result = cost_function_wrapper( + data_segment, R_NilValue, lambda, false, R_NilValue + ); } - cval = cost_result.value; + } + cval = cost_result.value; - // If `vanilla_percentage` is not 1, then we need to keep track of - // thetas for later `fastcpd` steps. - if (vanilla_percentage < 1 && t <= vanilla_percentage * n) { - update_theta_hat(i - 1, cost_result.par); - update_theta_sum(i - 1, cost_result.par); - } + // If `vanilla_percentage` is not 1, then we need to keep track of + // thetas for later `fastcpd` steps. + if (vanilla_percentage < 1 && t <= vanilla_percentage * n) { + update_theta_hat(i - 1, cost_result.par); + update_theta_sum(i - 1, cost_result.par); } + return cval; +} +double Fastcpd::get_cval_sen( + const mat data_segment, + const unsigned int i, + const int t, + const int tau, + const double lambda +) { + double cval = 0; + update_cost_parameters(t, tau, i, k.get(), lambda, line_search); + colvec theta = theta_sum.col(i - 1) / (t - tau); + DEBUG_RCOUT(theta); + if (!contain(FASTCPD_FAMILIES, family)) { + Function cost_non_null = cost.get(); + SEXP cost_result = cost_non_null(data_segment, theta); + cval = as(cost_result); + } else if ( + (family != "lasso" && t - tau >= p) || + (family == "lasso" && t - tau >= 3) + ) { + cval = cost_function_wrapper( + data_segment, wrap(theta), lambda, false, R_NilValue + ).value; + } + // else t - tau < p or for lasso t - tau < 3 return cval; } diff --git a/src/fastcpd_classes.h b/src/fastcpd_classes.h index dbcae54c..a43ca96f 100644 --- a/src/fastcpd_classes.h +++ b/src/fastcpd_classes.h @@ -127,6 +127,22 @@ class Fastcpd { double lambda ); + double get_cval_pelt( + const mat data_segment, + const unsigned int i, + const int t, + const int tau, + const double lambda + ); + + double get_cval_sen( + const mat data_segment, + const unsigned int i, + const int t, + const int tau, + const double lambda + ); + // Update \code{theta_hat}, \code{theta_sum}, and \code{hessian}. // // @param data_segment A data frame containing a segment of the data.