From 89fcde0fdc8f315f842b45dacbd7b8a6a6c56046 Mon Sep 17 00:00:00 2001 From: d-schindler <60650591+d-schindler@users.noreply.github.com> Date: Thu, 21 Nov 2024 12:28:38 +0000 Subject: [PATCH] add method to compute all mcf measures at once --- mcf/io.py | 15 ++++++++++++++ mcf/mcf_base.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++ mcf/measures.py | 2 +- setup.py | 2 +- tests/test_mcf.py | 29 ++++++++++++++++++++++++++ tests/test_mcnf.py | 30 +++++++++++++++++++++++++++ 6 files changed, 127 insertions(+), 2 deletions(-) create mode 100644 mcf/io.py diff --git a/mcf/io.py b/mcf/io.py new file mode 100644 index 0000000..2f27775 --- /dev/null +++ b/mcf/io.py @@ -0,0 +1,15 @@ +"""I/O functions.""" + +import pickle + + +def save_results(results, filename="results.pkl"): + """Save results in a pickle.""" + with open(filename, "wb") as results_file: + pickle.dump(results, results_file) + + +def load_results(filename="results.pkl"): # pragma: no cover + """Load results from a pickle.""" + with open(filename, "rb") as results_file: + return pickle.load(results_file) diff --git a/mcf/mcf_base.py b/mcf/mcf_base.py index 6daeae4..6351d4e 100644 --- a/mcf/mcf_base.py +++ b/mcf/mcf_base.py @@ -6,6 +6,7 @@ from tqdm import tqdm +from mcf.io import save_results from mcf.measures import ( compute_bettis, compute_partition_size, @@ -130,6 +131,56 @@ def compute_persistent_conflict(self): c_1, c_2, c = compute_persistent_conflict(self) return c_1, c_2, c + def compute_all_measures( + self, + file_path="mcf_results.pkl", + ): + """Construct filtration, compute PH and compute all derived measures.""" + + # build filtration + self.build_filtration() + + # compute persistent homology + self.compute_persistence() + + # obtain persistence + persistence = [ + self.filtration_gudhi.persistence_intervals_in_dimension(dim) + for dim in range(self.max_dim) + ] + + # compute Betti numbers + betti_0, betti_1, betti_2 = self.compute_bettis() + + # compute size of partitions + s_partitions = self.compute_partition_size() + + # compute persistent hierarchy + h, h_bar = self.compute_persistent_hierarchy() + + # compute persistent conflict + c_1, c_2, c = self.compute_persistent_conflict() + + # compile results dictionary + mcf_results = {} + mcf_results["filtration_indices"] = self.filtration_indices + mcf_results["max_dim"] = self.max_dim + mcf_results["persistence"] = persistence + mcf_results["betti_0"] = betti_0 + mcf_results["betti_1"] = betti_1 + mcf_results["betti_2"] = betti_2 + mcf_results["s_partitions"] = s_partitions + mcf_results["h"] = h + mcf_results["h_bar"] = h_bar + mcf_results["c_1"] = c_1 + mcf_results["c_2"] = c_2 + mcf_results["c"] = c + + # save results + save_results(mcf_results, file_path) + + return mcf_results + class MCNF(MCF): """Class to construct MCNF from a sequence of partitions using equivalent diff --git a/mcf/measures.py b/mcf/measures.py index 0c1fd17..afac1cb 100644 --- a/mcf/measures.py +++ b/mcf/measures.py @@ -17,7 +17,7 @@ def _compute_death_count(mcf, dim): death_count[i] = np.sum(all_deaths == mcf.filtration_indices[i]) # count inf - death_count[mcf.n_partitions] = np.sum(all_deaths == np.Inf) + death_count[mcf.n_partitions] = np.sum(all_deaths == np.inf) return death_count diff --git a/setup.py b/setup.py index 8ca68ff..e919f98 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import find_packages from setuptools import setup -__version__ = "0.0.4" +__version__ = "0.0.5" setup( name="MCF", diff --git a/tests/test_mcf.py b/tests/test_mcf.py index bb38b4c..3e251bb 100644 --- a/tests/test_mcf.py +++ b/tests/test_mcf.py @@ -135,3 +135,32 @@ def test_compute_persistent_conflict(): assert np.array_equal(c_1, np.array([0.0, 0.0, 0.0, 1.0, -1.0])) assert np.array_equal(c_2, np.array([0.0, 0.0, 0.0, 0.0, 0.0])) assert np.array_equal(c, np.array([0.0, 0.0, 0.0, 1.0, -1.0])) + + +def test_compute_all_measures(): + """Test for computing all MCF measures.""" + + # initialise MCF object + mcf = MCF() + mcf.load_data(partitions, filtration_indices) + + # compute all MCF measures + mcf_results = mcf.compute_all_measures() + + # check if results match + assert np.array_equal(mcf_results["filtration_indices"], filtration_indices) + assert mcf_results["max_dim"] == 3 + assert np.array_equal( + mcf_results["persistence"][0], np.array([[1.0, 2.0], [1.0, 3.0], [1.0, np.inf]]) + ) + assert np.array_equal(mcf_results["persistence"][1], np.array([[4.0, 5.0]])) + assert len(mcf_results["persistence"][2]) == 0 + assert np.array_equal(mcf_results["betti_0"], np.array([3, 2, 1, 1, 1])) + assert np.array_equal(mcf_results["betti_1"], np.array([0, 0, 0, 1, 0])) + assert np.array_equal(mcf_results["betti_2"], np.array([0, 0, 0, 0, 0])) + assert np.array_equal(mcf_results["s_partitions"], np.array([3, 2, 2, 2, 1])) + assert np.array_equal(mcf_results["h"], np.array([1.0, 1.0, 0.5, 0.5, 1.0])) + assert mcf_results["h_bar"] == 0.75 + assert np.array_equal(mcf_results["c_1"], np.array([0.0, 0.0, 0.0, 1.0, -1.0])) + assert np.array_equal(mcf_results["c_2"], np.array([0.0, 0.0, 0.0, 0.0, 0.0])) + assert np.array_equal(mcf_results["c"], np.array([0.0, 0.0, 0.0, 1.0, -1.0])) diff --git a/tests/test_mcnf.py b/tests/test_mcnf.py index c346455..c90bb50 100644 --- a/tests/test_mcnf.py +++ b/tests/test_mcnf.py @@ -115,3 +115,33 @@ def test_compute_persistent_conflict(): assert np.array_equal(c_1, np.array([0.0, 0.0, 0.0, 1.0, -1.0])) assert np.array_equal(c_2, np.array([0.0, 0.0, 0.0, 0.0, 0.0])) assert np.array_equal(c, np.array([0.0, 0.0, 0.0, 1.0, -1.0])) + + +def test_compute_all_measures(): + """Test for computing all MCF measures.""" + + # initialise MCNF object + mcnf = MCNF() + mcnf.load_data(partitions, filtration_indices) + + # compute all MCNF measures + mcnf_results = mcnf.compute_all_measures() + + # check if results match + assert np.array_equal(mcnf_results["filtration_indices"], filtration_indices) + assert mcnf_results["max_dim"] == 3 + assert np.array_equal( + mcnf_results["persistence"][0], + np.array([[1.0, 2.0], [1.0, 3.0], [1.0, np.inf]]), + ) + assert np.array_equal(mcnf_results["persistence"][1], np.array([[4.0, 5.0]])) + assert len(mcnf_results["persistence"][2]) == 0 + assert np.array_equal(mcnf_results["betti_0"], np.array([3, 2, 1, 1, 1])) + assert np.array_equal(mcnf_results["betti_1"], np.array([0, 0, 0, 1, 0])) + assert np.array_equal(mcnf_results["betti_2"], np.array([0, 0, 0, 0, 0])) + assert np.array_equal(mcnf_results["s_partitions"], np.array([3, 2, 2, 2, 1])) + assert np.array_equal(mcnf_results["h"], np.array([1.0, 1.0, 0.5, 0.5, 1.0])) + assert mcnf_results["h_bar"] == 0.75 + assert np.array_equal(mcnf_results["c_1"], np.array([0.0, 0.0, 0.0, 1.0, -1.0])) + assert np.array_equal(mcnf_results["c_2"], np.array([0.0, 0.0, 0.0, 0.0, 0.0])) + assert np.array_equal(mcnf_results["c"], np.array([0.0, 0.0, 0.0, 1.0, -1.0]))