Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Make some deps optional #90

Merged
merged 5 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
["src/pygenstability/generalized_louvain/generalized_louvain.cpp"],
include_dirs=["extra", "generalizedLouvain"],
extra_compile_args=["-std=c++11"],
optional=True,
),
]
plotly_require = ["plotly>=3.6.0"]
Expand All @@ -28,17 +29,17 @@
"numpy>=1.18.1",
"scipy>=1.4.1",
"matplotlib>=3.1.3",
"networkx>=3.0",
"scikit-learn",
"cmake>=3.16.3",
"click>=7.0",
"tqdm>=4.45.0",
"pybind11>=2.10.0",
"pandas>=1.0.0",
"igraph",
"leidenalg",
"threadpoolctl",
]
leiden_install = ["igraph", "leidenalg"]
networkx_install = ["networkx>=3.0"]

setup(
name="PyGenStability",
version=__version__,
Expand All @@ -54,7 +55,9 @@
zip_safe=False,
extras_require={
"plotly": plotly_require,
"all": plotly_require + test_require,
"leiden": leiden_install,
"networkx": networkx_install,
"all": plotly_require + test_require + leiden_install + networkx_install,
},
entry_points={"console_scripts": ["pygenstability=pygenstability.app:cli"]},
packages=find_namespace_packages("src"),
Expand Down
7 changes: 6 additions & 1 deletion src/pygenstability/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

import matplotlib
import matplotlib.pyplot as plt
import networkx as nx

try:
import networkx as nx
except ImportError: # pragma: no cover
print('Please install networkx via pip install "pygenstability[networkx]" for full plotting.')

import numpy as np
from matplotlib import gridspec
from matplotlib import patches
Expand Down
69 changes: 54 additions & 15 deletions src/pygenstability/pygenstability.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@
from functools import wraps
from time import time

import igraph as ig
import leidenalg
try:
import igraph as ig
import leidenalg

_NO_LEIDEN = False
except ImportError: # pragma: no cover
_NO_LEIDEN = True

import numpy as np
import scipy.sparse as sp
from sklearn.metrics import mutual_info_score
Expand Down Expand Up @@ -101,6 +107,21 @@ def _get_constructor_data(constructor, scales, pool, tqdm_disable=False):
)


def _check_method(method): # pragma: no cover
if _NO_LEIDEN and not hasattr(generalized_louvain, "evaluate_quality"):
raise Exception("Without Louvain or Leiden solver, we cannot run PyGenStability")

if method == "louvain" and not hasattr(generalized_louvain, "evaluate_quality"):
print("Louvain is not available, we fallback to leiden.")
return "leiden"

if method == "leiden" and _NO_LEIDEN:
print("Leiden is not available, we fallback to louvain.")
return "louvain"

return method


@_timing
def run(
graph=None,
Expand Down Expand Up @@ -171,6 +192,7 @@ def run(
- 'NVI': NVI at each scale
- 'ttprime': ttprime matrix
"""
method = _check_method(method)
run_params = _get_params(locals())
graph = _graph_checks(graph)
scales = _get_scales(
Expand Down Expand Up @@ -217,7 +239,7 @@ def run(

if with_postprocessing:
L.info("Apply postprocessing...")
_apply_postprocessing(all_results, pool, constructor_data, tqdm_disable)
_apply_postprocessing(all_results, pool, constructor_data, tqdm_disable, method=method)

if with_ttprime or with_optimal_scales:
L.info("Compute ttprimes...")
Expand Down Expand Up @@ -342,18 +364,34 @@ def _optimise(_, quality_indices, quality_values, null_model, global_shift, meth
return stability + global_shift, community_id


def _evaluate_quality(partition_id, qualities_index, null_model, global_shift):
def _evaluate_quality(partition_id, qualities_index, null_model, global_shift, method="louvain"):
"""Worker for generalized Markov Stability optimisation runs."""
quality = generalized_louvain.evaluate_quality(
qualities_index[0][0],
qualities_index[0][1],
qualities_index[1],
len(qualities_index[1]),
null_model,
np.shape(null_model)[0],
1.0,
partition_id,
)
if method == "louvain":
quality = generalized_louvain.evaluate_quality(
qualities_index[0][0],
qualities_index[0][1],
qualities_index[1],
len(qualities_index[1]),
null_model,
np.shape(null_model)[0],
1.0,
partition_id,
)

if method == "leiden":
quality = np.mean(
[
leidenalg.CPMVertexPartition(
ig.Graph(edges=zip(*qualities_index[0]), directed=True),
initial_membership=partition_id,
weights=qualities_index[1],
node_sizes=null.tolist(),
correct_self_loops=True,
).quality()
for null in null_model[::2]
]
)

return quality + global_shift


Expand Down Expand Up @@ -390,7 +428,7 @@ def _compute_ttprime(all_results, pool):


@_timing
def _apply_postprocessing(all_results, pool, constructors, tqdm_disable=False):
def _apply_postprocessing(all_results, pool, constructors, tqdm_disable=False, method="louvain"):
"""Apply postprocessing."""
all_results_raw = all_results.copy()

Expand All @@ -402,6 +440,7 @@ def _apply_postprocessing(all_results, pool, constructors, tqdm_disable=False):
qualities_index=_to_indices(constructor["quality"]),
null_model=constructor["null_model"],
global_shift=constructor.get("shift", 0.0),
method=method,
)
best_quality_id = np.argmax(
pool.map(
Expand Down
50 changes: 25 additions & 25 deletions tests/data/test_run_default_leiden.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,28 @@ community_id:
- 18.0
- 19.0
- 20.0
- - 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- - 1.0
- 2.0
- 3.0
- 4.0
- 5.0
- 6.0
- 7.0
- 8.0
- 9.0
- 10.0
- 0.0
- 0.0
- 2.0
- 2.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 11.0
- 12.0
- 13.0
- 14.0
- 15.0
- 16.0
- 17.0
- 18.0
- 19.0
- 20.0
- - 0.0
- 0.0
- 0.0
Expand All @@ -90,8 +90,8 @@ community_id:
number_of_communities:
- 21.0
- 21.0
- 20.0
- 3.0
- 21.0
- 21.0
- 3.0
run_params:
constructor: linearized
Expand Down Expand Up @@ -123,6 +123,6 @@ scales:
stability:
- 0.8526881720430108
- 0.8526881720430108
- 0.6401045688749567
- 0.5087182926944845
- 0.5087182926944845
- 0.8526881720430108
- 0.8526881720430108
- 0.49930627818244916
5 changes: 5 additions & 0 deletions tests/test_pygenstability.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,8 @@ def test__evaluate_quality(graph):
qualities_index = pgs._to_indices(data["quality"])
quality = pgs._evaluate_quality(community_id, qualities_index, data["null_model"], 0)
assert_almost_equal(quality, 0.5590341906608186)

quality = pgs._evaluate_quality(
community_id, qualities_index, data["null_model"], 0, method="leiden"
)
assert_almost_equal(quality, 0.2741359784037568)
Loading