Skip to content

Commit

Permalink
refactor internal interface of change point detection
Browse files Browse the repository at this point in the history
  • Loading branch information
yuuki committed Dec 2, 2023
1 parent 65dec36 commit ebf03c1
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 30 deletions.
23 changes: 19 additions & 4 deletions metricsifter/algo/detection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import warnings
from typing import cast
from collections import defaultdict
from typing import Final

import numpy as np
import numpy.typing as npt
import pandas as pd
import ruptures as rpt
from joblib import Parallel, delayed

NO_CHANGE_POINTS: Final[int] = -1


def _detect_changepoints_with_missing_values(x: np.ndarray) -> npt.ArrayLike:
"""
Expand Down Expand Up @@ -59,9 +62,21 @@ def detect_multi_changepoints(
penalty: str | float,
penalty_adjust: float,
n_jobs: int = -1,
) -> list[list[int]]:
) -> tuple[list[int], dict[int, list[str]], dict[str, list[int]]]:
metrics: list[str] = X.columns.tolist()
multi_change_points = Parallel(n_jobs=n_jobs)(
delayed(detect_univariate_changepoints)(X[metric].to_numpy(), search_method, cost_model, penalty, penalty_adjust)
for metric in X.columns.tolist()
for metric in metrics
)
return cast(list[list[int]], multi_change_points)
cp_to_metrics: dict[int, list[str]] = defaultdict(list)
for metric, change_points in zip(metrics, multi_change_points):
if change_points is None or len(change_points) < 1:
cp_to_metrics[NO_CHANGE_POINTS].append(metric) # cp == -1 means no change point
continue
for cp in change_points:
cp_to_metrics[cp].append(metric)

flatten_change_points: list[int] = sum(multi_change_points, [])
metric_to_cps = {metric: cps for metric, cps in zip(metrics, multi_change_points) if cps is not None}

return flatten_change_points, cp_to_metrics, metric_to_cps
28 changes: 7 additions & 21 deletions metricsifter/algo/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,22 @@
from collections import defaultdict
from typing import Final

import numpy as np
import scipy.signal
from statsmodels.nonparametric.kde import KDEUnivariate

NO_CHANGE_POINTS: Final[int] = -1
from metricsifter.algo.detection import NO_CHANGE_POINTS


def segment_nested_changepoints(
multi_change_points: list[list[int]],
metrics: list[str],
flatten_change_points: list[int],
cp_to_metrics: dict[int, list[str]],
time_series_length: int,
kde_bandwidth: float | str = 2.5,
kde_bandwidth_adjust: float = 1.,
) -> tuple[dict[int, set[str]], dict[int, np.ndarray]]:
cp_to_metrics: dict[int, list[str]] = defaultdict(list)
for metric, change_points in zip(metrics, multi_change_points):
if len(change_points) < 1:
cp_to_metrics[NO_CHANGE_POINTS].append(metric) # cp == -1 means no change point
continue
for cp in change_points:
cp_to_metrics[cp].append(metric)

flatten_change_points: list[int] = sum(multi_change_points, [])
if len(flatten_change_points) == 0:
return {}, {}

_, label_to_change_points = segment_changepoints_with_kde(
flatten_change_points, time_series_length=time_series_length,
flatten_change_points,
time_series_length=time_series_length,
kde_bandwidth=kde_bandwidth,
kde_bandwidth_adjust=kde_bandwidth_adjust,
unique_values=True,
)

Expand All @@ -47,7 +34,6 @@ def segment_changepoints_with_kde(
change_points: list[int],
time_series_length: int,
kde_bandwidth: str | float,
kde_bandwidth_adjust: float = 1.,
unique_values: bool = True,
) -> tuple[np.ndarray, dict[int, np.ndarray]]:
assert len(change_points) > 0, "change_points should not be empty"
Expand All @@ -57,7 +43,7 @@ def segment_changepoints_with_kde(
return np.zeros(len(x), dtype=int), {0: np.unique(x) if unique_values else x} # the all change points belongs to cluster 0.

dens = KDEUnivariate(x)
dens.fit(kernel="gau", bw=kde_bandwidth, fft=True, adjust=kde_bandwidth_adjust)
dens.fit(kernel="gau", bw=kde_bandwidth, fft=True)
s = np.linspace(start=0, stop=time_series_length - 1, num=time_series_length)
e = dens.evaluate(s)

Expand Down
9 changes: 4 additions & 5 deletions metricsifter/sifter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,21 @@ def run(self, data: pd.DataFrame, without_simple_filter: bool = False) -> pd.Dat
metrics: list[str] = X.columns.tolist()

# STEP1: detect change points
change_point_indexes = detection.detect_multi_changepoints(
flatten_change_points, cp_to_metrics, metric_to_cps = detection.detect_multi_changepoints(
X,
search_method=self.search_method,
cost_model=self.cost_model,
penalty=self.penalty,
penalty_adjust=self.penalty_adjust,
n_jobs=self.n_jobs,
)
if not change_point_indexes:
if not flatten_change_points:
return pd.DataFrame()
metric_to_cps = {metric: cps for metric, cps in zip(metrics, change_point_indexes)}

# STEP2: segment change points
cluster_label_to_metrics, _ = segmentation.segment_nested_changepoints(
multi_change_points=change_point_indexes,
metrics=metrics,
flatten_change_points=flatten_change_points,
cp_to_metrics=cp_to_metrics,
time_series_length=X.shape[0],
kde_bandwidth=self.bandwidth,
)
Expand Down

0 comments on commit ebf03c1

Please sign in to comment.