Skip to content

Commit

Permalink
Feat: Make some deps optional (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudon authored Nov 3, 2023
1 parent 0ab3106 commit f2b0e31
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 45 deletions.
11 changes: 7 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
["src/pygenstability/generalized_louvain/generalized_louvain.cpp"],
include_dirs=["extra", "generalizedLouvain"],
extra_compile_args=["-std=c++11"],
optional=True,
),
]
plotly_require = ["plotly>=3.6.0"]
Expand All @@ -28,17 +29,17 @@
"numpy>=1.18.1",
"scipy>=1.4.1",
"matplotlib>=3.1.3",
"networkx>=3.0",
"scikit-learn",
"cmake>=3.16.3",
"click>=7.0",
"tqdm>=4.45.0",
"pybind11>=2.10.0",
"pandas>=1.0.0",
"igraph",
"leidenalg",
"threadpoolctl",
]
leiden_install = ["igraph", "leidenalg"]
networkx_install = ["networkx>=3.0"]

setup(
name="PyGenStability",
version=__version__,
Expand All @@ -54,7 +55,9 @@
zip_safe=False,
extras_require={
"plotly": plotly_require,
"all": plotly_require + test_require,
"leiden": leiden_install,
"networkx": networkx_install,
"all": plotly_require + test_require + leiden_install + networkx_install,
},
entry_points={"console_scripts": ["pygenstability=pygenstability.app:cli"]},
packages=find_namespace_packages("src"),
Expand Down
7 changes: 6 additions & 1 deletion src/pygenstability/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

import matplotlib
import matplotlib.pyplot as plt
import networkx as nx

try:
import networkx as nx
except ImportError: # pragma: no cover
print('Please install networkx via pip install "pygenstability[networkx]" for full plotting.')

import numpy as np
from matplotlib import gridspec
from matplotlib import patches
Expand Down
69 changes: 54 additions & 15 deletions src/pygenstability/pygenstability.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@
from functools import wraps
from time import time

import igraph as ig
import leidenalg
try:
import igraph as ig
import leidenalg

_NO_LEIDEN = False
except ImportError: # pragma: no cover
_NO_LEIDEN = True

import numpy as np
import scipy.sparse as sp
from sklearn.metrics import mutual_info_score
Expand Down Expand Up @@ -101,6 +107,21 @@ def _get_constructor_data(constructor, scales, pool, tqdm_disable=False):
)


def _check_method(method): # pragma: no cover
if _NO_LEIDEN and not hasattr(generalized_louvain, "evaluate_quality"):
raise Exception("Without Louvain or Leiden solver, we cannot run PyGenStability")

if method == "louvain" and not hasattr(generalized_louvain, "evaluate_quality"):
print("Louvain is not available, we fallback to leiden.")
return "leiden"

if method == "leiden" and _NO_LEIDEN:
print("Leiden is not available, we fallback to louvain.")
return "louvain"

return method


@_timing
def run(
graph=None,
Expand Down Expand Up @@ -171,6 +192,7 @@ def run(
- 'NVI': NVI at each scale
- 'ttprime': ttprime matrix
"""
method = _check_method(method)
run_params = _get_params(locals())
graph = _graph_checks(graph)
scales = _get_scales(
Expand Down Expand Up @@ -217,7 +239,7 @@ def run(

if with_postprocessing:
L.info("Apply postprocessing...")
_apply_postprocessing(all_results, pool, constructor_data, tqdm_disable)
_apply_postprocessing(all_results, pool, constructor_data, tqdm_disable, method=method)

if with_ttprime or with_optimal_scales:
L.info("Compute ttprimes...")
Expand Down Expand Up @@ -342,18 +364,34 @@ def _optimise(_, quality_indices, quality_values, null_model, global_shift, meth
return stability + global_shift, community_id


def _evaluate_quality(partition_id, qualities_index, null_model, global_shift):
def _evaluate_quality(partition_id, qualities_index, null_model, global_shift, method="louvain"):
"""Worker for generalized Markov Stability optimisation runs."""
quality = generalized_louvain.evaluate_quality(
qualities_index[0][0],
qualities_index[0][1],
qualities_index[1],
len(qualities_index[1]),
null_model,
np.shape(null_model)[0],
1.0,
partition_id,
)
if method == "louvain":
quality = generalized_louvain.evaluate_quality(
qualities_index[0][0],
qualities_index[0][1],
qualities_index[1],
len(qualities_index[1]),
null_model,
np.shape(null_model)[0],
1.0,
partition_id,
)

if method == "leiden":
quality = np.mean(
[
leidenalg.CPMVertexPartition(
ig.Graph(edges=zip(*qualities_index[0]), directed=True),
initial_membership=partition_id,
weights=qualities_index[1],
node_sizes=null.tolist(),
correct_self_loops=True,
).quality()
for null in null_model[::2]
]
)

return quality + global_shift


Expand Down Expand Up @@ -390,7 +428,7 @@ def _compute_ttprime(all_results, pool):


@_timing
def _apply_postprocessing(all_results, pool, constructors, tqdm_disable=False):
def _apply_postprocessing(all_results, pool, constructors, tqdm_disable=False, method="louvain"):
"""Apply postprocessing."""
all_results_raw = all_results.copy()

Expand All @@ -402,6 +440,7 @@ def _apply_postprocessing(all_results, pool, constructors, tqdm_disable=False):
qualities_index=_to_indices(constructor["quality"]),
null_model=constructor["null_model"],
global_shift=constructor.get("shift", 0.0),
method=method,
)
best_quality_id = np.argmax(
pool.map(
Expand Down
50 changes: 25 additions & 25 deletions tests/data/test_run_default_leiden.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,28 @@ community_id:
- 18.0
- 19.0
- 20.0
- - 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- - 1.0
- 2.0
- 3.0
- 4.0
- 5.0
- 6.0
- 7.0
- 8.0
- 9.0
- 10.0
- 0.0
- 0.0
- 2.0
- 2.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 11.0
- 12.0
- 13.0
- 14.0
- 15.0
- 16.0
- 17.0
- 18.0
- 19.0
- 20.0
- - 0.0
- 0.0
- 0.0
Expand All @@ -90,8 +90,8 @@ community_id:
number_of_communities:
- 21.0
- 21.0
- 20.0
- 3.0
- 21.0
- 21.0
- 3.0
run_params:
constructor: linearized
Expand Down Expand Up @@ -123,6 +123,6 @@ scales:
stability:
- 0.8526881720430108
- 0.8526881720430108
- 0.6401045688749567
- 0.5087182926944845
- 0.5087182926944845
- 0.8526881720430108
- 0.8526881720430108
- 0.49930627818244916
5 changes: 5 additions & 0 deletions tests/test_pygenstability.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,8 @@ def test__evaluate_quality(graph):
qualities_index = pgs._to_indices(data["quality"])
quality = pgs._evaluate_quality(community_id, qualities_index, data["null_model"], 0)
assert_almost_equal(quality, 0.5590341906608186)

quality = pgs._evaluate_quality(
community_id, qualities_index, data["null_model"], 0, method="leiden"
)
assert_almost_equal(quality, 0.2741359784037568)

0 comments on commit f2b0e31

Please sign in to comment.