From f2b0e31d3b9a53a089ffb7b57953ace40a78842e Mon Sep 17 00:00:00 2001 From: Alexis Arnaudon Date: Fri, 3 Nov 2023 12:33:52 +0100 Subject: [PATCH] Feat: Make some deps optional (#90) --- setup.py | 11 ++-- src/pygenstability/plotting.py | 7 ++- src/pygenstability/pygenstability.py | 69 +++++++++++++++++++------ tests/data/test_run_default_leiden.yaml | 50 +++++++++--------- tests/test_pygenstability.py | 5 ++ 5 files changed, 97 insertions(+), 45 deletions(-) diff --git a/setup.py b/setup.py index 5d6dd6b..7f06b11 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,7 @@ ["src/pygenstability/generalized_louvain/generalized_louvain.cpp"], include_dirs=["extra", "generalizedLouvain"], extra_compile_args=["-std=c++11"], + optional=True, ), ] plotly_require = ["plotly>=3.6.0"] @@ -28,17 +29,17 @@ "numpy>=1.18.1", "scipy>=1.4.1", "matplotlib>=3.1.3", - "networkx>=3.0", "scikit-learn", "cmake>=3.16.3", "click>=7.0", "tqdm>=4.45.0", "pybind11>=2.10.0", "pandas>=1.0.0", - "igraph", - "leidenalg", "threadpoolctl", ] +leiden_install = ["igraph", "leidenalg"] +networkx_install = ["networkx>=3.0"] + setup( name="PyGenStability", version=__version__, @@ -54,7 +55,9 @@ zip_safe=False, extras_require={ "plotly": plotly_require, - "all": plotly_require + test_require, + "leiden": leiden_install, + "networkx": networkx_install, + "all": plotly_require + test_require + leiden_install + networkx_install, }, entry_points={"console_scripts": ["pygenstability=pygenstability.app:cli"]}, packages=find_namespace_packages("src"), diff --git a/src/pygenstability/plotting.py b/src/pygenstability/plotting.py index 0ef15b1..ee4f19f 100644 --- a/src/pygenstability/plotting.py +++ b/src/pygenstability/plotting.py @@ -4,7 +4,12 @@ import matplotlib import matplotlib.pyplot as plt -import networkx as nx + +try: + import networkx as nx +except ImportError: # pragma: no cover + print('Please install networkx via pip install "pygenstability[networkx]" for full plotting.') + import numpy as np from matplotlib import gridspec from matplotlib import patches diff --git a/src/pygenstability/pygenstability.py b/src/pygenstability/pygenstability.py index d672ea0..6e13a32 100644 --- a/src/pygenstability/pygenstability.py +++ b/src/pygenstability/pygenstability.py @@ -19,8 +19,14 @@ from functools import wraps from time import time -import igraph as ig -import leidenalg +try: + import igraph as ig + import leidenalg + + _NO_LEIDEN = False +except ImportError: # pragma: no cover + _NO_LEIDEN = True + import numpy as np import scipy.sparse as sp from sklearn.metrics import mutual_info_score @@ -101,6 +107,21 @@ def _get_constructor_data(constructor, scales, pool, tqdm_disable=False): ) +def _check_method(method): # pragma: no cover + if _NO_LEIDEN and not hasattr(generalized_louvain, "evaluate_quality"): + raise Exception("Without Louvain or Leiden solver, we cannot run PyGenStability") + + if method == "louvain" and not hasattr(generalized_louvain, "evaluate_quality"): + print("Louvain is not available, we fallback to leiden.") + return "leiden" + + if method == "leiden" and _NO_LEIDEN: + print("Leiden is not available, we fallback to louvain.") + return "louvain" + + return method + + @_timing def run( graph=None, @@ -171,6 +192,7 @@ def run( - 'NVI': NVI at each scale - 'ttprime': ttprime matrix """ + method = _check_method(method) run_params = _get_params(locals()) graph = _graph_checks(graph) scales = _get_scales( @@ -217,7 +239,7 @@ def run( if with_postprocessing: L.info("Apply postprocessing...") - _apply_postprocessing(all_results, pool, constructor_data, tqdm_disable) + _apply_postprocessing(all_results, pool, constructor_data, tqdm_disable, method=method) if with_ttprime or with_optimal_scales: L.info("Compute ttprimes...") @@ -342,18 +364,34 @@ def _optimise(_, quality_indices, quality_values, null_model, global_shift, meth return stability + global_shift, community_id -def _evaluate_quality(partition_id, qualities_index, null_model, global_shift): +def _evaluate_quality(partition_id, qualities_index, null_model, global_shift, method="louvain"): """Worker for generalized Markov Stability optimisation runs.""" - quality = generalized_louvain.evaluate_quality( - qualities_index[0][0], - qualities_index[0][1], - qualities_index[1], - len(qualities_index[1]), - null_model, - np.shape(null_model)[0], - 1.0, - partition_id, - ) + if method == "louvain": + quality = generalized_louvain.evaluate_quality( + qualities_index[0][0], + qualities_index[0][1], + qualities_index[1], + len(qualities_index[1]), + null_model, + np.shape(null_model)[0], + 1.0, + partition_id, + ) + + if method == "leiden": + quality = np.mean( + [ + leidenalg.CPMVertexPartition( + ig.Graph(edges=zip(*qualities_index[0]), directed=True), + initial_membership=partition_id, + weights=qualities_index[1], + node_sizes=null.tolist(), + correct_self_loops=True, + ).quality() + for null in null_model[::2] + ] + ) + return quality + global_shift @@ -390,7 +428,7 @@ def _compute_ttprime(all_results, pool): @_timing -def _apply_postprocessing(all_results, pool, constructors, tqdm_disable=False): +def _apply_postprocessing(all_results, pool, constructors, tqdm_disable=False, method="louvain"): """Apply postprocessing.""" all_results_raw = all_results.copy() @@ -402,6 +440,7 @@ def _apply_postprocessing(all_results, pool, constructors, tqdm_disable=False): qualities_index=_to_indices(constructor["quality"]), null_model=constructor["null_model"], global_shift=constructor.get("shift", 0.0), + method=method, ) best_quality_id = np.argmax( pool.map( diff --git a/tests/data/test_run_default_leiden.yaml b/tests/data/test_run_default_leiden.yaml index 92285a5..f0b0230 100644 --- a/tests/data/test_run_default_leiden.yaml +++ b/tests/data/test_run_default_leiden.yaml @@ -43,28 +43,28 @@ community_id: - 18.0 - 19.0 - 20.0 -- - 0.0 - - 0.0 - - 0.0 - - 0.0 - - 0.0 - - 0.0 - - 0.0 - - 0.0 +- - 1.0 + - 2.0 + - 3.0 + - 4.0 + - 5.0 + - 6.0 + - 7.0 + - 8.0 + - 9.0 + - 10.0 - 0.0 - 0.0 - - 2.0 - - 2.0 - - 1.0 - - 1.0 - - 1.0 - - 1.0 - - 1.0 - - 1.0 - - 1.0 - - 1.0 - - 1.0 - - 1.0 + - 11.0 + - 12.0 + - 13.0 + - 14.0 + - 15.0 + - 16.0 + - 17.0 + - 18.0 + - 19.0 + - 20.0 - - 0.0 - 0.0 - 0.0 @@ -90,8 +90,8 @@ community_id: number_of_communities: - 21.0 - 21.0 -- 20.0 -- 3.0 +- 21.0 +- 21.0 - 3.0 run_params: constructor: linearized @@ -123,6 +123,6 @@ scales: stability: - 0.8526881720430108 - 0.8526881720430108 -- 0.6401045688749567 -- 0.5087182926944845 -- 0.5087182926944845 +- 0.8526881720430108 +- 0.8526881720430108 +- 0.49930627818244916 diff --git a/tests/test_pygenstability.py b/tests/test_pygenstability.py index 6de21fd..abdc00d 100644 --- a/tests/test_pygenstability.py +++ b/tests/test_pygenstability.py @@ -168,3 +168,8 @@ def test__evaluate_quality(graph): qualities_index = pgs._to_indices(data["quality"]) quality = pgs._evaluate_quality(community_id, qualities_index, data["null_model"], 0) assert_almost_equal(quality, 0.5590341906608186) + + quality = pgs._evaluate_quality( + community_id, qualities_index, data["null_model"], 0, method="leiden" + ) + assert_almost_equal(quality, 0.2741359784037568)