diff --git a/docs/source/api/morphoclass.console.cmd_extract_features_and_predict.rst b/docs/source/api/morphoclass.console.cmd_extract_features_and_predict.rst new file mode 100644 index 0000000..e144ac5 --- /dev/null +++ b/docs/source/api/morphoclass.console.cmd_extract_features_and_predict.rst @@ -0,0 +1,7 @@ +morphoclass.console.cmd\_extract\_features\_and\_predict module +=============================================================== + +.. automodule:: morphoclass.console.cmd_extract_features_and_predict + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/morphoclass.console.cmd_predict.rst b/docs/source/api/morphoclass.console.cmd_predict.rst new file mode 100644 index 0000000..9dcbf34 --- /dev/null +++ b/docs/source/api/morphoclass.console.cmd_predict.rst @@ -0,0 +1,7 @@ +morphoclass.console.cmd\_predict module +======================================= + +.. automodule:: morphoclass.console.cmd_predict + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/morphoclass.console.rst b/docs/source/api/morphoclass.console.rst index 67415de..54ab684 100644 --- a/docs/source/api/morphoclass.console.rst +++ b/docs/source/api/morphoclass.console.rst @@ -9,10 +9,12 @@ Submodules morphoclass.console.cmd_evaluate morphoclass.console.cmd_extract_features + morphoclass.console.cmd_extract_features_and_predict morphoclass.console.cmd_morphometrics morphoclass.console.cmd_organise_dataset morphoclass.console.cmd_performance_table morphoclass.console.cmd_plot_dataset_stats + morphoclass.console.cmd_predict morphoclass.console.cmd_preprocess_dataset morphoclass.console.cmd_train morphoclass.console.cmd_xai diff --git a/requirements.txt b/requirements.txt index e91ca17..af2be82 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ scipy==1.7.3 seaborn==0.11.0 shap[plots]==0.39.0 tmd==2.1.0 -torch==1.7.1 +torch==1.9.0 tqdm==4.53.0 umap-learn==0.5.1 xgboost==1.4.2 diff --git a/setup.cfg b/setup.cfg index 6a0083a..3f00e1c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,8 +36,8 @@ install_requires = imbalanced-learn jinja2 matplotlib - morphio - morphology-workflows>=0.3.0 + #morphio + #morphology-workflows>=0.3.0 networkx neurom>=3 NeuroR>=1.6.1 diff --git a/src/morphoclass/console/cmd_extract_features.py b/src/morphoclass/console/cmd_extract_features.py index c952be5..93eda5e 100644 --- a/src/morphoclass/console/cmd_extract_features.py +++ b/src/morphoclass/console/cmd_extract_features.py @@ -113,6 +113,29 @@ def cli( no_simplify_graph: bool, keep_diagram: bool, force: bool, +) -> None: + """Extract morphology features.""" + return extract_features( + csv_path, + neurite_type, + feature, + output_dir, + orient, + no_simplify_graph, + keep_diagram, + force, + ) + + +def extract_features( + csv_path: StrPath, + neurite_type: str, + feature: str, + output_dir: StrPath, + orient: bool, + no_simplify_graph: bool, + keep_diagram: bool, + force: bool, ) -> None: """Extract morphology features.""" output_dir = pathlib.Path(output_dir) diff --git a/src/morphoclass/console/cmd_extract_features_and_predict.py b/src/morphoclass/console/cmd_extract_features_and_predict.py new file mode 100644 index 0000000..30d19e6 --- /dev/null +++ b/src/morphoclass/console/cmd_extract_features_and_predict.py @@ -0,0 +1,147 @@ +# Copyright © 2022-2022 Blue Brain Project/EPFL +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of the `morphoclass predict` CLI command.""" +from __future__ import annotations + +import logging +import textwrap + +import click + +logger = logging.getLogger(__name__) + + +@click.command( + name="predict", + help="Run inference.", +) +@click.help_option("-h", "--help") +@click.option( + "-i", + "--input_csv", + required=True, + type=click.Path(exists=True, dir_okay=True), + help=textwrap.dedent( + """ + The CSV path with the path to all the morphologies to classify. + """ + ).strip(), +) +@click.option( + "-c", + "--checkpoint", + "checkpoint_file", + required=True, + type=click.Path(exists=True, file_okay=True, dir_okay=False), + help=textwrap.dedent( + """ + The path to the pre-trained model checkpoint. + """ + ).strip(), +) +@click.option( + "-o", + "--output-dir", + required=True, + type=click.Path(exists=False, file_okay=False, writable=True), + help="Output directory for the results.", +) +@click.option( + "-n", + "--results-name", + required=False, + type=click.STRING, + help="The filename of the results file", +) +def cli(input_csv, checkpoint_file, output_dir, results_name): + """Run the `morphoclass predict` CLI command. + + Parameters + ---------- + input_csv + The CSV with all the morphologies path. + checkpoint_file + The path to the checkpoint file. + output_dir + The path to the output directory. + results_name + File prefix for results output files. + """ + import pathlib + from datetime import datetime + + input_csv = pathlib.Path(input_csv).resolve() + output_dir = pathlib.Path(output_dir).resolve() + checkpoint_file = pathlib.Path(checkpoint_file).resolve() + if results_name is None: + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + results_name = f"results_{timestamp}" + results_path = output_dir / (results_name + ".json") + click.secho(f"Input CSV : {input_csv}", fg="yellow") + click.secho(f"Output file : {results_path}", fg="yellow") + click.secho(f"Checkpoint : {checkpoint_file}", fg="yellow") + if results_path.exists(): + msg = f'Results file "{results_path}" exists, overwrite? (y/[n]) ' + click.secho(msg, fg="red", bold=True, nl=False) + response = input() + if response.strip().lower() != "y": + click.secho("Stopping.", fg="red") + return + else: + click.secho("You chose to overwrite, proceeding...", fg="red") + + click.secho("✔ Loading checkpoint...", fg="green", bold=True) + import torch + + from morphoclass.console.cmd_extract_features import extract_features + from morphoclass.console.cmd_predict import predict + + checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu")) + neurites = ["apical", "axon", "basal", "all"] + neurite_type = [ + neurite for neurite in neurites if neurite in str(checkpoint["features_dir"]) + ] + features_type = [ + "graph-rd", + "graph-proj", + "diagram-tmd-rd", + "diagram-tmd-proj", + "diagram-deepwalk", + "image-tmd-rd", + "image-tmd-proj", + "image-deepwalk", + ] + feature = [ + feature + for feature in features_type + if feature in str(checkpoint["features_dir"]) + ] + + extract_features( + input_csv, + neurite_type[0], + feature[0], + output_dir / "features", + False, + False, + False, + False, + ) + + predict( + features_dir=output_dir / "features", + checkpoint_file=checkpoint_file, + output_dir=output_dir, + results_name=results_name, + ) diff --git a/src/morphoclass/console/cmd_predict.py b/src/morphoclass/console/cmd_predict.py new file mode 100644 index 0000000..dbf461b --- /dev/null +++ b/src/morphoclass/console/cmd_predict.py @@ -0,0 +1,263 @@ +# Copyright © 2022-2022 Blue Brain Project/EPFL +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of the `morphoclass predict-after-extraction` CLI command.""" +from __future__ import annotations + +import logging +import textwrap + +import click + +logger = logging.getLogger(__name__) + + +@click.command( + name="predict-after-extraction", + help="Run inference from features directory.", +) +@click.help_option("-h", "--help") +@click.option( + "-f", + "--features-dir", + required=True, + type=click.Path(exists=True, dir_okay=True), + help=textwrap.dedent( + """ + The path to the extracted features of the morphologies to classify + """ + ).strip(), +) +@click.option( + "-c", + "--checkpoint", + "checkpoint_file", + required=True, + type=click.Path(exists=True, file_okay=True, dir_okay=False), + help=textwrap.dedent( + """ + The path to the pre-trained model checkpoint. + """ + ).strip(), +) +@click.option( + "-o", + "--output-dir", + required=True, + type=click.Path(exists=False, file_okay=False, writable=True), + help="Output directory for the results.", +) +@click.option( + "-n", + "--results-name", + required=False, + type=click.STRING, + help="The filename of the results file", +) +def cli(features_dir, checkpoint_file, output_dir, results_name): + """Run the `morphoclass predict-after-extraction` CLI command. + + Parameters + ---------- + features_dir + The path to the features of the morphologies. + checkpoint_file + The path to the checkpoint file. + output_dir + The path to the output directory. + results_name + File prefix for results output files. + """ + return predict(features_dir, checkpoint_file, output_dir, results_name) + + +def predict(features_dir, checkpoint_file, output_dir, results_name): + """Run the predict command. + + Parameters + ---------- + features_dir + The path to the features of the morphologies. + checkpoint_file + The path to the checkpoint file. + output_dir + The path to the output directory. + results_name + File prefix for results output files. + """ + import json + import pathlib + from datetime import datetime + + features_dir = pathlib.Path(features_dir).resolve() + output_dir = pathlib.Path(output_dir).resolve() + checkpoint_file = pathlib.Path(checkpoint_file).resolve() + if results_name is None: + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + results_name = f"results_{timestamp}" + results_path = output_dir / (results_name + ".json") + click.secho(f"Features Dir : {features_dir}", fg="yellow") + click.secho(f"Output file : {results_path}", fg="yellow") + click.secho(f"Checkpoint : {checkpoint_file}", fg="yellow") + if results_path.exists(): + msg = f'Results file "{results_path}" exists, overwrite? (y/[n]) ' + click.secho(msg, fg="red", bold=True, nl=False) + response = input() + if response.strip().lower() != "y": + click.secho("Stopping.", fg="red") + return + else: + click.secho("You chose to overwrite, proceeding...", fg="red") + + click.secho("✔ Loading checkpoint...", fg="green", bold=True) + import torch + + from morphoclass.data import MorphologyDataset + from morphoclass.data.morphology_data import MorphologyData + + checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu")) + model_class = checkpoint["model_class"] + click.secho(f"Model : {model_class}", fg="yellow") + if "metadata" in checkpoint: + timestamp = checkpoint["metadata"]["timestamp"] + click.secho(f"Created on : {timestamp}", fg="yellow") + + click.secho("✔ Loading data...", fg="green", bold=True) + data = [] + for path in sorted(features_dir.glob("*.features")): + data.append(MorphologyData.load(path)) + dataset = MorphologyDataset(data) + click.echo(f"> Dataset length: {len(dataset)}") + + click.secho("✔ Computing predictions...", fg="green", bold=True) + if "ManNet" in model_class: + logits = predict_gnn(dataset, checkpoint) + predictions = logits.argmax(axis=1) + elif "CNN" in model_class: + logits = predict_cnn(dataset, checkpoint) + predictions = logits.argmax(axis=1) + elif "XGB" in model_class: + predictions = predict_xgb(dataset, checkpoint) + else: + click.secho( + f"Model not recognized: {model_class}. Stopping.", + fg="red", + bold=True, + nl=False, + ) + return + + click.secho("✔ Exporting results...", fg="green", bold=True) + prediction_labels = {} + for sample, sample_pred in zip(dataset.data, predictions): + sample_path = str(sample.path) + pred_label = dataset.y_to_label[sample_pred] + prediction_labels[sample_path] = pred_label + + results = { + "predictions": prediction_labels, + "checkpoint_path": str(checkpoint_file), + "model": model_class, + } + with open(results_path, "w") as fp: + json.dump(results, fp) + + click.secho("✔ Done.", fg="green", bold=True) + + +def predict_gnn(dataset, checkpoint): + """Compute predictions with a GNN (ManNet) classifier. + + Parameters + ---------- + dataset + The morphology dataset. + checkpoint + The model checkpoint. + + Returns + ------- + logits + The predictions logits. + """ + import numpy as np + + import morphoclass.models + + model_cls = getattr( + morphoclass.models, checkpoint["model_class"].rpartition(".")[2] + ) + model = model_cls(**checkpoint["model_params"]) + model.load_state_dict(checkpoint["all"]["model"]) + model.eval() + logits = [model(sample) for sample in dataset] + + return np.array(logits) + + +def predict_cnn(dataset, checkpoint): + """Compute predictions with a CNN classifier. + + Parameters + ---------- + dataset + The persistence image dataset. + checkpoint + The model checkpoint. + + Returns + ------- + logits + The predictions logits. + """ + import numpy as np + + import morphoclass.models + + # Model + model_cls = getattr( + morphoclass.models, checkpoint["model_class"].rpartition(".")[2] + ) + model = model_cls(**checkpoint["model_params"]) + model.load_state_dict(checkpoint["all"]["model"]) + + # Evaluation + logits = [model(sample.image).detach().numpy() for sample in dataset] + if len(logits) > 0: + logits = np.concatenate(logits) + else: + logits = np.array(logits) + + return logits + + +def predict_xgb(dataset, checkpoint): + """Compute predictions with XGBoost classifier. + + Parameters + ---------- + dataset + The morphology persistence image dataset. + checkpoint + The model checkpoint. + + Returns + ------- + predictions + The predictions. + """ + model = checkpoint["all"]["model"] + predictions = [ + model.predict(sample.image.numpy().reshape(1, 10000))[0] for sample in dataset + ] + return predictions diff --git a/src/morphoclass/console/main.py b/src/morphoclass/console/main.py index 886d2f8..cdb27f9 100644 --- a/src/morphoclass/console/main.py +++ b/src/morphoclass/console/main.py @@ -22,10 +22,12 @@ import morphoclass from morphoclass.console import cmd_evaluate from morphoclass.console import cmd_extract_features +from morphoclass.console import cmd_extract_features_and_predict from morphoclass.console import cmd_morphometrics from morphoclass.console import cmd_organise_dataset from morphoclass.console import cmd_performance_table from morphoclass.console import cmd_plot_dataset_stats +from morphoclass.console import cmd_predict from morphoclass.console import cmd_preprocess_dataset from morphoclass.console import cmd_train from morphoclass.console import cmd_xai @@ -137,9 +139,11 @@ def cli(verbose: int, log_file_path: pathlib.Path | None) -> None: cli.add_command(cmd_xai.cli) cli.add_command(cmd_organise_dataset.cli) cli.add_command(cmd_plot_dataset_stats.cli) +cli.add_command(cmd_predict.cli) cli.add_command(cmd_preprocess_dataset.cli) cli.add_command(cmd_train.cli) cli.add_command(cmd_evaluate.cli) cli.add_command(cmd_performance_table.cli) cli.add_command(cmd_extract_features.cli) +cli.add_command(cmd_extract_features_and_predict.cli) cli.add_command(cmd_morphometrics.cli) diff --git a/src/morphoclass/data/tns_dataset.py b/src/morphoclass/data/tns_dataset.py index 6be3bb2..c5a190c 100644 --- a/src/morphoclass/data/tns_dataset.py +++ b/src/morphoclass/data/tns_dataset.py @@ -135,9 +135,7 @@ def __init__( f"No data corresponding to layer {layer} found in data_path" ) else: - self.class_dict = { - n: m_type for n, m_type in enumerate(sorted(self.m_types)) - } + self.class_dict = dict(enumerate(sorted(self.m_types))) self.class_dict_inv = {v: k for k, v in self.class_dict.items()} self.distributions = input_distributions diff --git a/src/morphoclass/models/concatenet.py b/src/morphoclass/models/concatenet.py index 6ab610b..9953f16 100644 --- a/src/morphoclass/models/concatenet.py +++ b/src/morphoclass/models/concatenet.py @@ -45,7 +45,6 @@ class ConcateNet(nn.Module): """ def __init__(self, n_node_features, n_classes, n_features_perslay, bn=False): - super().__init__() self.n_node_features = n_node_features self.n_classes = n_classes diff --git a/src/morphoclass/transforms/augmentors/add_random_points_to_reduction_mask.py b/src/morphoclass/transforms/augmentors/add_random_points_to_reduction_mask.py index 54178cd..8279c1e 100644 --- a/src/morphoclass/transforms/augmentors/add_random_points_to_reduction_mask.py +++ b/src/morphoclass/transforms/augmentors/add_random_points_to_reduction_mask.py @@ -32,7 +32,6 @@ class AddRandomPointsToReductionMask: """ def __init__(self, n_points): - self.n_points = n_points @require_field("tmd_neurites_masks") diff --git a/src/morphoclass/vis.py b/src/morphoclass/vis.py index 008afef..2b7e45c 100644 --- a/src/morphoclass/vis.py +++ b/src/morphoclass/vis.py @@ -820,7 +820,7 @@ def plot_neurite( nx.draw( g, ax=ax, - pos={n: xy for n, xy in enumerate(zip(px, py))}, + pos=dict(enumerate(zip(px, py))), nodelist=[0], node_color="red", node_size=soma_size, diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index 06043b4..dfb987d 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -158,7 +158,7 @@ def test_inter_rater_score(targets, predictions, kind, score): def test_inter_rater_score_fail(targets, predictions): - with pytest.raises(Exception): + with pytest.raises(ValueError): morphoclass.metrics.inter_rater_score(targets, predictions, kind="invalid")