Skip to content

Commit

Permalink
Added 'best_chi2_worst_phi2' loss type
Browse files Browse the repository at this point in the history
  • Loading branch information
Cmurilochem committed Mar 4, 2024
1 parent 1fbabf9 commit fe55608
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 20 deletions.
25 changes: 22 additions & 3 deletions validphys2/src/validphys/hyperoptplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def parse_statistics(trial):
# dict_out["std"] = std
dict_out["hlosses"] = convert_string_to_numpy(results["kfold_meta"]["hyper_losses"])
dict_out["vlosses"] = convert_string_to_numpy(results["kfold_meta"]["validation_losses"])
dict_out["hlosses_phi2"] = convert_string_to_numpy(results["kfold_meta"]["hyper_losses_phi2"])
return dict_out


Expand Down Expand Up @@ -383,6 +384,9 @@ def evaluate_trial(trial_dict, validation_multiplier, fail_threshold, loss_targe
test_loss = np.array(trial_dict["hlosses"]).max()
elif loss_target == "std":
test_loss = np.array(trial_dict["hlosses"]).std()
elif loss_target == "min_chi2_max_phi2":
test_loss = np.array(trial_dict["hlosses"]).mean()
phi2 = np.array(trial_dict["hlosses_phi2"]).mean()
loss = val_loss * validation_multiplier + test_loss * test_f

if (
Expand All @@ -396,6 +400,8 @@ def evaluate_trial(trial_dict, validation_multiplier, fail_threshold, loss_targe
loss *= 10

trial_dict["loss"] = loss
if loss_target == "min_chi2_max_phi2":
trial_dict["loss_inverse_phi2"] = np.reciprocal(phi2)


def generate_dictionary(
Expand Down Expand Up @@ -570,9 +576,22 @@ def hyperopt_dataframe(commandline_args):

# Now select the best one
best_idx = dataframe.loss.idxmin()
best_trial_series = dataframe.loc[best_idx]
# Make into a dataframe and transpose or the plotting code will complain
best_trial = best_trial_series.to_frame().T

if args.loss_target == "min_chi2_max_phi2":
minimum = dataframe.loss[best_idx]
std = np.std(dataframe.loss)
lim_max = dataframe.loss[best_idx] + std
# select rows with chi2 losses within the best point and lim_max
selected_chi2 = dataframe[(dataframe.loss >= minimum) & (dataframe.loss <= lim_max)]
# among the selected points, select the nth lowest in 1/phi2
selected_phi2 = selected_chi2.loss_inverse_phi2.nsmallest(args.max_phi2_n_models)
# find the location of these points in the dataframe
indices = dataframe[dataframe['loss_inverse_phi2'].isin(selected_phi2)].index
best_trial = dataframe.loc[indices]
else:
best_trial_series = dataframe.loc[best_idx]
# Make into a dataframe and transpose or the plotting code will complain
best_trial = best_trial_series.to_frame().T

log.info("Best setup:")
with pd.option_context("display.max_rows", None, "display.max_columns", None):
Expand Down
36 changes: 19 additions & 17 deletions validphys2/src/validphys/scripts/vp_hyperoptplot.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,35 @@
from validphys.app import App
from validphys.loader import Loader, HyperscanNotFound
from validphys import hyperplottemplates
from reportengine.compat import yaml
import pwd
import logging
import os
import pwd

import logging
from reportengine.compat import yaml
from validphys import hyperplottemplates
from validphys.app import App
from validphys.loader import HyperscanNotFound, Loader

log = logging.getLogger(__name__)


class HyperoptPlotApp(App):
def add_positional_arguments(self, parser):
""" Wrapper around argumentparser """
"""Wrapper around argumentparser"""
# Hyperopt settings
parser.add_argument(
"hyperopt_name",
help="Folder of the hyperopt fit to generate the report for",
"hyperopt_name", help="Folder of the hyperopt fit to generate the report for"
)
parser.add_argument(
"-l",
"--loss_target",
help="Choice for the definition of target loss",
choices=['average', 'best_worst', 'std'],
choices=['average', 'best_worst', 'std', 'min_chi2_max_phi2'],
default='average',
)
parser.add_argument(
"--max_phi2_n_models",
help="If --loss_target=best_chi2_worst_phi2, outputs n models with the highest phi2.",
type=int,
default=1,
)
parser.add_argument(
"-v",
"--val_multiplier",
Expand Down Expand Up @@ -73,16 +78,12 @@ def add_positional_arguments(self, parser):
type=str,
default=pwd.getpwuid(os.getuid())[4].replace(",", ""),
)
parser.add_argument(
"--title",
help="Add custom title to the report's meta data",
type=str,
)
parser.add_argument("--title", help="Add custom title to the report's meta data", type=str)
parser.add_argument(
"--keywords",
help="Add keywords to the report's meta data. The keywords must be provided as a list",
type=list,
default=[]
default=[],
)
args = parser.parse_args()

Expand Down Expand Up @@ -127,7 +128,8 @@ def complete_mapping(self):
"combine": args["combine"],
"autofilter": args["autofilter"],
"debug": args["debug"],
"loss_target": args["loss_target"]
"loss_target": args["loss_target"],
"max_phi2_n_models": args["max_phi2_n_models"],
}

try:
Expand Down

0 comments on commit fe55608

Please sign in to comment.