Skip to content

Commit

Permalink
Improve plot logs
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Oct 27, 2023
1 parent 7120d67 commit f7cdb34
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 20 deletions.
4 changes: 2 additions & 2 deletions dwi_ml/data/hdf5/hdf5_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,8 +685,8 @@ def _process_one_streamline_group(
logging.debug(' *Total: {:,.0f} streamlines. Now removing '
'invalid streamlines.'.format(len(final_sft)))
final_sft.remove_invalid_streamlines()
logging.debug(" *Remaining: {:,.0f} streamlines."
"".format(len(final_sft)))
logging.info(" Final number of streamlines: {:,.0f}."
.format(len(final_sft)))

conn_matrix = None
conn_info = None
Expand Down
69 changes: 53 additions & 16 deletions scripts_python/dwiml_visualize_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,26 @@
# -*- coding: utf-8 -*-
import argparse
from argparse import RawTextHelpFormatter
import csv
import itertools
import logging
import os
import pathlib
from typing import Dict, List
from typing import Dict

import matplotlib.pyplot as plt
from matplotlib import colors
import matplotlib.cm as cmx
import numpy as np

from scilpy.io.utils import assert_outputs_exist, add_overwrite_arg

log_styles = ['-', '--', '_.', ':']
nb_styles = 4
exp_colors = ['b', 'r', 'g', 'k', 'o']
nb_colors = 5


def parse_args():
parser = argparse.ArgumentParser(description=str(parse_args.__doc__),
def _build_arg_parser():
parser = argparse.ArgumentParser(description=__doc__,
formatter_class=RawTextHelpFormatter)
parser.add_argument("paths", type=str, nargs='+',
help="Path to the experiment folder (s). If more than "
Expand All @@ -30,13 +33,27 @@ def parse_args():
"If not set, all logs are shown separately.")
parser.add_argument("--nb_plots_per_fig", type=int, default=3,
help="Number of (rows) of plot per figure.")
parser.add_argument("--save_to_csv", metavar='my_file.csv',
help="If set, save the resulting logs as a csv file.")
parser.add_argument('--xlim', type=int,
help="Graph's xlim. Makes little sense with more than "
"one graph. Format: max_epoch ")
parser.add_argument('--ylims', type=float, nargs=2,
help="Graph's ylim. Makes little sense with more than "
"one graph. Format: ymin ymax ")

args = parser.parse_args()
return args
add_overwrite_arg(parser)
return parser


def visualize_logs(logs: Dict[str, Dict[str, np.ndarray]], graphs, nb_rows):
def visualize_logs(logs: Dict[str, Dict[str, np.ndarray]], graphs, nb_rows,
writer=None, xlim=None, ylims=None):
exp_names = list(logs.keys())
writer.writerow(['Experiment name', "Log name", "Epochs..."])

jet = plt.get_cmap('jet')
c_norm = colors.Normalize(vmin=0, vmax=len(exp_names))
scalar_map = cmx.ScalarMappable(norm=c_norm, cmap=jet)

nb_plots_left = len(graphs)
current_graph = -1
Expand All @@ -60,26 +77,35 @@ def visualize_logs(logs: Dict[str, Dict[str, np.ndarray]], graphs, nb_rows):

# For each experiment to show:
for exp, exp_name in enumerate(exp_names):
color = exp_colors[exp % nb_colors]
color_val = scalar_map.to_rgba(exp)
legend = exp_name
if len(keys) > 1:
legend += ', ' + key
if key in logs[exp_name]:
axs[i].plot(logs[exp_name][key], linestyle=style,
color=color, label=legend)
label=legend, color=color_val)

if writer is not None:
writer.writerow([exp_name, key] +
list(logs[exp_name][key]))

axs[i].legend()
if xlim is not None:
axs[i].set_xlim([0, xlim])
if ylims is not None:
axs[i].set_ylim(ylims)

nb_plots_left -= next_nb_plots

plt.tight_layout()
plt.show()


def main():
args = parse_args()
parser = _build_arg_parser()
args = parser.parse_args()

logging.getLogger().setLevel(level=logging.INFO)

assert_outputs_exist(parser, args, args.save_to_csv)

# One element per experiment
loaded_logs = {} # exp: dict of logs
user_required_names = set(itertools.chain(*args.graph)
Expand Down Expand Up @@ -110,7 +136,7 @@ def main():
if os.path.isfile(nn):
files_to_load.append(nn)
else:
print("File {} not found in path {}. Skipping."
print("File {}.npy not found in path {}. Skipping."
.format(n, log_path))
names_to_load = [n.stem for n in files_to_load]

Expand All @@ -127,7 +153,18 @@ def main():
graphs = list(graphs)
graphs = [[g] for g in graphs]

visualize_logs(loaded_logs, graphs, args.nb_plots_per_fig)
if args.save_to_csv:
print("Will save results in file {}".format(args.save_to_csv))
with open(args.save_to_csv, 'w', newline='') as file:
writer = csv.writer(file)
visualize_logs(loaded_logs, graphs, args.nb_plots_per_fig,
writer=writer, xlim=args.xlim, ylims=args.ylims)
else:
visualize_logs(loaded_logs, graphs, args.nb_plots_per_fig,
xlim=args.xlim, ylims=args.ylims)

plt.tight_layout()
plt.show()


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion scripts_python/tto_visualize_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def main():
argv = sys.argv

parser = build_argparser_transformer_visu()
args = parser.parse_args()
args = parser._build_arg_parser()

assert_inputs_exist(parser, [args.hdf5_file, args.input_streamlines],
args.reference)
Expand Down
2 changes: 1 addition & 1 deletion scripts_python/ttst_visualize_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def main():
argv = sys.argv

parser = build_argparser_transformer_visu()
args = parser.parse_args()
args = parser._build_arg_parser()

assert_inputs_exist(parser, [args.hdf5_file, args.input_streamlines],
args.reference)
Expand Down

0 comments on commit f7cdb34

Please sign in to comment.