Skip to content

Commit

Permalink
fix pool + add github run tox (#103)
Browse files Browse the repository at this point in the history
* fix pool

* better fix

* some fix

* remove extra tqdm

* revert fast

* Apply suggestions from code review
  • Loading branch information
arnaudon authored Jan 11, 2022
1 parent b185118 commit 92fbefd
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 26 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/run-tox.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Run tox on every push and pull request, once per Python version in the matrix.
name: Run all tox python3

on:
  push:
  pull_request:

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted so YAML keeps each version as a string instead of a float.
        python-version: ["3.7", "3.8", "3.9"]

    steps:
    - uses: actions/checkout@v2
      with:
        # Check out git submodules together with the repository.
        submodules: 'true'
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip setuptools
        pip install tox-gh-actions
    - name: Run tox
      run: |
        tox
8 changes: 0 additions & 8 deletions .travis.yml

This file was deleted.

27 changes: 10 additions & 17 deletions hcga/extraction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Functions necessary for the extraction of graph features."""
import logging
import time
from collections import defaultdict
from functools import partial
from importlib import import_module
from pathlib import Path
Expand Down Expand Up @@ -101,18 +100,16 @@ def extract(

def _print_runtimes(all_features_df):
    """Log the mean and std of each recorded runtime, slowest first.

    Args:
        all_features_df (DataFrame): features dataframe whose ``"runtimes"``
            column group holds one runtime column per feature class and one
            row per graph.
    """
    # Select the runtime columns once instead of re-indexing on every use.
    runtimes = all_features_df["runtimes"]
    mean = runtimes.mean(axis=0).to_list()
    std = runtimes.std(axis=0).to_list()

    # argsort is ascending; reverse it so the most expensive entries come first.
    for i in np.argsort(mean)[::-1]:
        L.info(
            "Runtime of %s is %s ( std = %s ) seconds per graph.",
            runtimes.columns[i],
            np.round(mean[i], 3),
            np.round(std[i], 3),
        )


Expand Down Expand Up @@ -190,9 +187,6 @@ def feature_extraction(graph, list_feature_classes, with_runtimes=False):
Returns:
(DataFrame): dataframe of calculated features for a given graph.
"""
if with_runtimes:
runtimes = {}

column_indexes = pd.MultiIndex(
levels=[[], []], codes=[[], []], names=["feature_class", "feature_name"]
)
Expand All @@ -205,14 +199,13 @@ def feature_extraction(graph, list_feature_classes, with_runtimes=False):
features = pd.DataFrame(feat_class_inst.get_features(), index=[graph.id])
columns = [(feat_class_inst.shortname, col) for col in features.columns]
features_df[columns] = features
feat_class_inst.pool.close()
del feat_class_inst

if with_runtimes:
runtimes[feature_class.shortname] = time.time() - start_time
features_df[("runtimes", feature_class.name)] = time.time() - start_time

if with_runtimes:
return graph.id, [features_df, runtimes]
return features_df

return features_df

Expand Down
8 changes: 7 additions & 1 deletion hcga/feature_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __init__(self, graph=None):
Args:
graph (Graph): graph for initialisation, converted to given encoding
"""
self.pool = multiprocessing.Pool(processes=1)
self.pool = multiprocessing.Pool(processes=1, maxtasksperchild=1)
if graph is not None:
self.graph = graph.get_graph(self.__class__.encoding)
self.graph_id = graph.id
Expand All @@ -115,6 +115,12 @@ def __init__(self, graph=None):
self.graph = None
self.features = {}

def __del__(self):
    """Close and terminate this instance's worker pool, if it has one."""
    if not hasattr(self, "pool"):
        # No pool attribute on this instance, so there is nothing to release.
        return

    self.pool.close()
    self.pool.terminate()
    # Drop the attribute so the pool is no longer referenced from this object.
    del self.pool

@classmethod
def setup_class(
cls,
Expand Down

0 comments on commit 92fbefd

Please sign in to comment.