diff --git a/CHANGELOG.md b/CHANGELOG.md index 959b8d5..4243522 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,12 @@ # Changelog +## 1.1.0 - (2023-08-01) + +**New features**: + +- New argument `forbidden_segments` (list or vector of 2-tuple) or `None` to `Control`. If not `None`, `changeforest` will not split on split points contained in segments `(a, b]` in `forbidden_segments` (rust and Python only). Thanks @enzbus! + ## 1.0.1 - (2022-06-01) **Bug fixes:** diff --git a/Cargo.toml b/Cargo.toml index 3e42c56..5fb94f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "changeforest" description = "Random Forests for Change Point Detection" authors = ["Malte Londschien "] repository = "https://github.com/mlondschien/changeforest/" -version = "1.0.1" +version = "1.1.0" edition = "2021" readme = "README.md" license = "BSD-3-Clause" diff --git a/changeforest-py/Cargo.toml b/changeforest-py/Cargo.toml index db79a90..1d51fcb 100644 --- a/changeforest-py/Cargo.toml +++ b/changeforest-py/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "changeforest_py" -version = "1.0.1" +version = "1.1.0" edition = "2021" [lib] diff --git a/changeforest-py/changeforest/control.py b/changeforest-py/changeforest/control.py index c46f227..99f2bcf 100644 --- a/changeforest-py/changeforest/control.py +++ b/changeforest-py/changeforest/control.py @@ -18,6 +18,7 @@ def __init__( random_forest_max_depth="default", random_forest_max_features="default", random_forest_n_jobs="default", + forbidden_segments="default", ): self.minimal_relative_segment_length = _to_float( minimal_relative_segment_length @@ -32,6 +33,7 @@ def __init__( self.random_forest_max_depth = _to_int(random_forest_max_depth) self.random_forest_max_features = _to_int(random_forest_max_features) self.random_forest_n_jobs = _to_int(random_forest_n_jobs) + self.forbidden_segments = _to_segments(forbidden_segments) def _to_float(value): @@ -50,3 +52,16 @@ def _to_int(value): return value else: return int(value) + + +def _to_segments(value): + if (value is None) or isinstance(value, str): + return value + else: + try: + return [(int(el1), int(el2)) for (el1, el2) in value] + except Exception: + raise SyntaxError( + "forbidden_segments must be provided as [(a,b), ...] where a and b are " + "integers." + ) diff --git a/changeforest-py/pyproject.toml b/changeforest-py/pyproject.toml index 7ceecb2..d2a5941 100644 --- a/changeforest-py/pyproject.toml +++ b/changeforest-py/pyproject.toml @@ -2,7 +2,7 @@ name = "changeforest" description = "Random Forests for Change Point Detection" readme = "README.md" -version = "1.0.1" +version = "1.1.0" requires-python = ">=3.7" author = "Malte Londschien " urls = {homepage = "https://github.com/mlondschien/changeforest/"} diff --git a/changeforest-py/src/control.rs b/changeforest-py/src/control.rs index 00a7041..7bf98b9 100644 --- a/changeforest-py/src/control.rs +++ b/changeforest-py/src/control.rs @@ -86,6 +86,12 @@ pub fn control_from_pyobj(py: Python, obj: Option) -> PyResult>>(py) { + control = control.with_forbidden_segments(value); + } + }; } Ok(control) diff --git a/changeforest-py/tests/test_changeforest.py b/changeforest-py/tests/test_changeforest.py index cb2e476..6a1c764 100644 --- a/changeforest-py/tests/test_changeforest.py +++ b/changeforest-py/tests/test_changeforest.py @@ -29,3 +29,43 @@ def test_changeforest_repr(iris_dataset): °--(100, 150] 136 -2.398 0.875\ """ ) + + +def test_changeforest_repr_segments(iris_dataset): + result = changeforest( + iris_dataset, + "random_forest", + "bs", + control=Control(forbidden_segments=[(0, 49), (101, 120)]), + ) + assert ( + result.__repr__() + == """\ + best_split max_gain p_value +(0, 150] 50 95.1 0.005 + ¦--(0, 50] + °--(50, 150] 100 52.799 0.005 + ¦--(50, 100] 53 6.892 0.315 + °--(100, 150] 136 -3.516 0.68\ +""" # noqa: W291 + ) + + +def test_changeforest_repr_segments2(iris_dataset): + result = changeforest( + iris_dataset, + "random_forest", + "bs", + control=Control(forbidden_segments=[(49, 101)]), + ) + assert ( + result.__repr__() + == """\ + best_split max_gain p_value +(0, 150] 49 87.462 0.005 + ¦--(0, 49] 2 -8.889 0.995 + °--(49, 150] 102 41.237 0.005 + ¦--(49, 102] + °--(102, 150] 136 1.114 0.36\ +""" # noqa: W291 + ) diff --git a/changeforest-py/tests/test_control.py b/changeforest-py/tests/test_control.py index 07971de..f12d896 100644 --- a/changeforest-py/tests/test_control.py +++ b/changeforest-py/tests/test_control.py @@ -129,3 +129,22 @@ def test_control_defaults(iris_dataset, key, default_value, another_value): assert str(result) == str(default_result) assert str(result) != str(another_result) + + +def test_control_segments(): + with pytest.raises(SyntaxError): + Control( + forbidden_segments=[ + (2), + ] + ) + + with pytest.raises(SyntaxError): + Control( + forbidden_segments=[ + (2, 3, 4), + ] + ) + + with pytest.raises(SyntaxError): + Control(forbidden_segments=[2, 3]) diff --git a/changeforest-r/DESCRIPTION b/changeforest-r/DESCRIPTION index b1d3d7e..f02c536 100644 --- a/changeforest-r/DESCRIPTION +++ b/changeforest-r/DESCRIPTION @@ -1,7 +1,7 @@ Package: changeforest Type: Package Title: Random Forests for Change Point Detection -Version: 1.0.1 +Version: 1.1.0 Author: Malte Londschien Maintainer: Malte Londschien Description: diff --git a/changeforest-r/src/rust/Cargo.toml b/changeforest-r/src/rust/Cargo.toml index 7aad776..dbf2382 100644 --- a/changeforest-r/src/rust/Cargo.toml +++ b/changeforest-r/src/rust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = 'changeforestr' -version = '1.0.1' +version = '1.1.0' edition = '2021' [lib] diff --git a/src/control.rs b/src/control.rs index eebd4f5..761e552 100644 --- a/src/control.rs +++ b/src/control.rs @@ -28,6 +28,8 @@ pub struct Control { pub seed: u64, /// Hyperparameters for random forests. pub random_forest_parameters: RandomForestParameters, + /// Segments of indexes were no segmentation is allowed. + pub forbidden_segments: Option>, } impl Control { @@ -45,6 +47,7 @@ impl Control { .with_max_depth(Some(8)) .with_max_features(MaxFeatures::Sqrt) .with_n_jobs(Some(-1)), + forbidden_segments: None, } } @@ -111,4 +114,20 @@ impl Control { self.random_forest_parameters = random_forest_parameters; self } + + pub fn with_forbidden_segments( + mut self, + forbidden_segments: Option>, + ) -> Self { + // check that segments are well specified + if let Some(ref _forbidden_segments) = forbidden_segments { + for el in _forbidden_segments.iter() { + if el.0 > el.1 { + panic!("Forbidden segments must be specified as [(a,b), ...] where a <= b!"); + } + } + } + self.forbidden_segments = forbidden_segments; + self + } } diff --git a/src/optimizer/grid_search.rs b/src/optimizer/grid_search.rs index 31e4375..6441aa7 100644 --- a/src/optimizer/grid_search.rs +++ b/src/optimizer/grid_search.rs @@ -93,4 +93,37 @@ mod tests { expected ); } + + #[rstest] + #[case(0, 10, Some(vec![(0, 3)]), 0.09, vec![4, 5, 6, 7, 8])] + #[case(1, 10, Some(vec![(6, 10)]), 0.15, vec![3, 4, 5, 6])] + #[case(0, 10, Some(vec![(2, 4), (5, 7)]), 0.09, vec![1, 2, 5, 8])] + #[case(1, 7, Some(vec![(2, 4), (5, 7)]), 0.09, vec![2, 5])] + fn test_split_candidates( + #[case] start: usize, + #[case] stop: usize, + #[case] forbidden_segments: Option>, + #[case] delta: f64, + #[case] expected: Vec, + ) { + let X = ndarray::array![ + [0.0], + [0.0], + [0.0], + [0.0], + [-0.0], + [-0.0], + [-0.0], + [-0.0], + [-0.0], + [-0.0] + ]; + let X_view = X.view(); + let control = Control::default() + .with_minimal_relative_segment_length(delta) + .with_forbidden_segments(forbidden_segments); + let gain = testing::ChangeInMean::new(&X_view, &control); + let grid_search = GridSearch { gain }; + assert_eq!(grid_search.split_candidates(start, stop).unwrap(), expected); + } } diff --git a/src/optimizer/optimizer.rs b/src/optimizer/optimizer.rs index 823a4d1..a13446f 100644 --- a/src/optimizer/optimizer.rs +++ b/src/optimizer/optimizer.rs @@ -23,7 +23,22 @@ pub trait Optimizer { if 2 * minimal_segment_length >= (stop - start) { Err("Segment too small.") } else { - Ok(((start + minimal_segment_length)..(stop - minimal_segment_length)).collect()) + let mut split_candidates: Vec = + ((start + minimal_segment_length)..(stop - minimal_segment_length)).collect(); + + if let Some(forbidden_segments) = &self.control().forbidden_segments { + split_candidates.retain(|x| { + forbidden_segments + .iter() + .all(|segment| x <= &segment.0 || x > &segment.1) + }); + } + + if split_candidates.is_empty() { + Err("No split_candidates left after filtering out forbidden_segments.") + } else { + Ok(split_candidates) + } } } } diff --git a/src/optimizer/two_step_search.rs b/src/optimizer/two_step_search.rs index 10f4445..b7357c9 100644 --- a/src/optimizer/two_step_search.rs +++ b/src/optimizer/two_step_search.rs @@ -51,13 +51,46 @@ where fn find_best_split(&self, start: usize, stop: usize) -> Result { let split_candidates = self.split_candidates(start, stop)?; - let guesses = vec![ - (3 * start + stop) / 4, - (start + stop) / 2, - (start + 3 * stop) / 4, - ]; + let mut guesses = vec![]; let mut results: Vec = vec![]; + // if there are forbidden segments change the heuristics + // pick middle element of split_candidates, 1/4th and 3/4th + if let Some(_forbidden_segments) = &self.control().forbidden_segments { + // there is at least one element in split_candidates + guesses.push( + split_candidates + .clone() + .into_iter() + .nth(split_candidates.len() / 4) + .unwrap(), + ); + + // we add this if it is not equal to last + let cand = split_candidates + .clone() + .into_iter() + .nth(split_candidates.len() / 2) + .unwrap(); + if cand > guesses[guesses.len() - 1] { + guesses.push(cand) + }; + + // same + let cand = split_candidates + .clone() + .into_iter() + .nth(3 * split_candidates.len() / 4) + .unwrap(); + if cand > guesses[guesses.len() - 1] { + guesses.push(cand) + }; + } else { + guesses.push((3 * start + stop) / 4); + guesses.push((start + stop) / 2); + guesses.push((start + 3 * stop) / 4); + } + // Don't use first and last guess if stop - start / 4 < delta. for guess in guesses.iter().filter(|x| split_candidates.contains(x)) { results.push(self._single_find_best_split(start, stop, *guess, &split_candidates)); diff --git a/src/segmentation.rs b/src/segmentation.rs index 3f53f7e..d62c680 100644 --- a/src/segmentation.rs +++ b/src/segmentation.rs @@ -60,7 +60,9 @@ impl<'a> Segmentation<'a> { // start + segment_length > n through floating point errors in // n_segments, e.g. for n = 20'000, alpha_k = 1/sqrt(2), k=6 stop = (start + (segment_length as f32).ceil() as usize).min(optimizer.n()); - segments.push(optimizer.find_best_split(start, stop).unwrap()); + if let Ok(optimizer_result) = optimizer.find_best_split(start, stop) { + segments.push(optimizer_result) + } } } }