Skip to content

Commit

Permalink
Improve CodeFactor
Browse files Browse the repository at this point in the history
  • Loading branch information
doccstat committed Mar 30, 2024
1 parent a877690 commit 96ad16c
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 42 deletions.
102 changes: 60 additions & 42 deletions src/fastcpd_class.cc
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@ double Fastcpd::get_cval_for_r_t_set(
const int t,
double lambda
) {
double cval = 0;
DEBUG_RCOUT(i);
int tau = r_t_set(i - 1);
if (family == "lasso") {
Expand All @@ -364,53 +363,72 @@ double Fastcpd::get_cval_for_r_t_set(
mat data_segment = data.rows(tau, t - 1);
DEBUG_RCOUT(data_segment);
if (t > vanilla_percentage * n) {
// fastcpd
update_cost_parameters(t, tau, i, k.get(), lambda, line_search);
colvec theta = theta_sum.col(i - 1) / (t - tau);
DEBUG_RCOUT(theta);
if (!contain(FASTCPD_FAMILIES, family)) {
Function cost_non_null = cost.get();
SEXP cost_result = cost_non_null(data_segment, theta);
cval = as<double>(cost_result);
} else if (
(family != "lasso" && t - tau >= p) ||
(family == "lasso" && t - tau >= 3)
) {
cval = cost_function_wrapper(
data_segment, wrap(theta), lambda, false, R_NilValue
).value;
} else {
// t - tau < p or for lasso t - tau < 3
}
return get_cval_sen(data_segment, i, t, tau, lambda);
} else {
// vanilla PELT
CostResult cost_result;
if (!contain(FASTCPD_FAMILIES, family)) {
cost_result = get_optimized_cost(data_segment);
return get_cval_pelt(data_segment, i, t, tau, lambda);
}
}

double Fastcpd::get_cval_pelt(
const mat data_segment,
const unsigned int i,
const int t,
const int tau,
const double lambda
) {
double cval = 0;
CostResult cost_result;
if (!contain(FASTCPD_FAMILIES, family)) {
cost_result = get_optimized_cost(data_segment);
} else {
if (warm_start && t - tau >= 10 * p) {
cost_result = cost_function_wrapper(
data_segment, R_NilValue, lambda, false,
wrap(segment_theta_hat[segment_indices(t - 1) - 1])
// Or use `wrap(start.col(tau))` for warm start.
);
update_start(tau, colvec(cost_result.par));
} else {
if (warm_start && t - tau >= 10 * p) {
cost_result = cost_function_wrapper(
data_segment, R_NilValue, lambda, false,
wrap(segment_theta_hat[segment_indices(t - 1) - 1])
// Or use `wrap(start.col(tau))` for warm start.
);
update_start(tau, colvec(cost_result.par));
} else {
cost_result = cost_function_wrapper(
data_segment, R_NilValue, lambda, false, R_NilValue
);
}
cost_result = cost_function_wrapper(
data_segment, R_NilValue, lambda, false, R_NilValue
);
}
cval = cost_result.value;
}
cval = cost_result.value;

// If `vanilla_percentage` is not 1, then we need to keep track of
// thetas for later `fastcpd` steps.
if (vanilla_percentage < 1 && t <= vanilla_percentage * n) {
update_theta_hat(i - 1, cost_result.par);
update_theta_sum(i - 1, cost_result.par);
}
// If `vanilla_percentage` is not 1, then we need to keep track of
// thetas for later `fastcpd` steps.
if (vanilla_percentage < 1 && t <= vanilla_percentage * n) {
update_theta_hat(i - 1, cost_result.par);
update_theta_sum(i - 1, cost_result.par);
}
return cval;
}

double Fastcpd::get_cval_sen(
const mat data_segment,
const unsigned int i,
const int t,
const int tau,
const double lambda
) {
double cval = 0;
update_cost_parameters(t, tau, i, k.get(), lambda, line_search);
colvec theta = theta_sum.col(i - 1) / (t - tau);
DEBUG_RCOUT(theta);
if (!contain(FASTCPD_FAMILIES, family)) {
Function cost_non_null = cost.get();
SEXP cost_result = cost_non_null(data_segment, theta);
cval = as<double>(cost_result);
} else if (
(family != "lasso" && t - tau >= p) ||
(family == "lasso" && t - tau >= 3)
) {
cval = cost_function_wrapper(
data_segment, wrap(theta), lambda, false, R_NilValue
).value;
}
// else t - tau < p or for lasso t - tau < 3
return cval;
}

Expand Down
16 changes: 16 additions & 0 deletions src/fastcpd_classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,22 @@ class Fastcpd {
double lambda
);

double get_cval_pelt(
const mat data_segment,
const unsigned int i,
const int t,
const int tau,
const double lambda
);

double get_cval_sen(
const mat data_segment,
const unsigned int i,
const int t,
const int tau,
const double lambda
);

// Update \code{theta_hat}, \code{theta_sum}, and \code{hessian}.
//
// @param data_segment A data frame containing a segment of the data.
Expand Down

0 comments on commit 96ad16c

Please sign in to comment.