From ccd7f7ba2b85d3d3be3e99c72afe357a0103b4c1 Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Tue, 19 Mar 2024 15:28:03 -0700 Subject: [PATCH 1/3] Updated training script for torch prescriptors to allow config --- .../eluc/prescriptors/nsga2/configs/test.json | 15 ++++++++ .../prescriptors/nsga2/torch_prescriptor.py | 2 ++ .../prescriptors/nsga2/train_prescriptors.py | 34 +++++++++++++------ 3 files changed, 40 insertions(+), 11 deletions(-) create mode 100644 use_cases/eluc/prescriptors/nsga2/configs/test.json diff --git a/use_cases/eluc/prescriptors/nsga2/configs/test.json b/use_cases/eluc/prescriptors/nsga2/configs/test.json new file mode 100644 index 0000000..454e89b --- /dev/null +++ b/use_cases/eluc/prescriptors/nsga2/configs/test.json @@ -0,0 +1,15 @@ +{ + "predictor_path": "predictors/neural_network/trained_models/no_overlap_nn", + "evolution_params": { + "pop_size": 100, + "n_generations": 100, + "p_mutation": 0.2, + "candidate_params": { + "in_size": 12, + "hidden_size": 16, + "out_size": 5 + }, + "seed_dir": "prescriptors/nsga2/seeds/small_sample" + }, + "save_path": "prescriptors/nsga2/trained_prescriptors/test" +} \ No newline at end of file diff --git a/use_cases/eluc/prescriptors/nsga2/torch_prescriptor.py b/use_cases/eluc/prescriptors/nsga2/torch_prescriptor.py index c698bd3..194e7ac 100644 --- a/use_cases/eluc/prescriptors/nsga2/torch_prescriptor.py +++ b/use_cases/eluc/prescriptors/nsga2/torch_prescriptor.py @@ -39,6 +39,8 @@ def __init__(self, self.n_generations = n_generations self.p_mutation = p_mutation self.seed_dir=seed_dir + if isinstance(self.seed_dir, str): + self.seed_dir = Path(self.seed_dir) self.eval_df = eval_df self.encoded_eval_df = encoder.encode_as_df(eval_df) diff --git a/use_cases/eluc/prescriptors/nsga2/train_prescriptors.py b/use_cases/eluc/prescriptors/nsga2/train_prescriptors.py index 7d647b3..34ddcee 100644 --- a/use_cases/eluc/prescriptors/nsga2/train_prescriptors.py +++ b/use_cases/eluc/prescriptors/nsga2/train_prescriptors.py @@ -1,32 +1,44 @@ +""" +Script used to train NSGA-II prescriptors. +Requires a config file with the same fields as shown in the +test.json file in prescriptors/nsga2/configs +""" + +import argparse +import json from pathlib import Path -from data import constants from data.eluc_data import ELUCData from prescriptors.nsga2.torch_prescriptor import TorchPrescriptor from predictors.neural_network.neural_net_predictor import NeuralNetPredictor if __name__ == "__main__": + + # Load config + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, required=True) + args = parser.parse_args() + with open(Path(args.config_path), "r", encoding="utf-8") as f: + config = json.load(f) + print("Loading dataset...") dataset = ELUCData() + print("Loading predictor...") + # TODO: We need to make it so you can load any predictor here nnp = NeuralNetPredictor() - nnp.load("predictors/neural_network/trained_models/no_overlap_nn") + nnp_path = Path(config["predictor_path"]) + nnp.load(nnp_path) + print("Initializing prescription...") - candidate_params = {"in_size": len(constants.CAO_MAPPING["context"]), - "hidden_size": 16, - "out_size": len(constants.RECO_COLS)} tp = TorchPrescriptor( - pop_size=100, - n_generations=100, - p_mutation=0.2, eval_df=dataset.train_df.sample(frac=0.001, random_state=42), encoder=dataset.encoder, predictor=nnp, batch_size=4096, - candidate_params=candidate_params, - seed_dir=Path("prescriptors/nsga2/seeds/small_sample") + **config["evolution_params"] ) print("Training prescriptors...") - save_path = Path("prescriptors/nsga2/trained_prescriptors/test") + save_path = Path(config["save_path"]) final_pop = tp.neuroevolution(save_path) print("Done!") \ No newline at end of file From 5fb8c328bbb1862ea78d039653a6d3c2fd3bab99 Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Tue, 19 Mar 2024 15:31:18 -0700 Subject: [PATCH 2/3] Moved handling of converting seed dir to path in train script --- use_cases/eluc/prescriptors/nsga2/torch_prescriptor.py | 2 -- use_cases/eluc/prescriptors/nsga2/train_prescriptors.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/use_cases/eluc/prescriptors/nsga2/torch_prescriptor.py b/use_cases/eluc/prescriptors/nsga2/torch_prescriptor.py index 194e7ac..c698bd3 100644 --- a/use_cases/eluc/prescriptors/nsga2/torch_prescriptor.py +++ b/use_cases/eluc/prescriptors/nsga2/torch_prescriptor.py @@ -39,8 +39,6 @@ def __init__(self, self.n_generations = n_generations self.p_mutation = p_mutation self.seed_dir=seed_dir - if isinstance(self.seed_dir, str): - self.seed_dir = Path(self.seed_dir) self.eval_df = eval_df self.encoded_eval_df = encoder.encode_as_df(eval_df) diff --git a/use_cases/eluc/prescriptors/nsga2/train_prescriptors.py b/use_cases/eluc/prescriptors/nsga2/train_prescriptors.py index 34ddcee..e1453aa 100644 --- a/use_cases/eluc/prescriptors/nsga2/train_prescriptors.py +++ b/use_cases/eluc/prescriptors/nsga2/train_prescriptors.py @@ -31,6 +31,8 @@ nnp.load(nnp_path) print("Initializing prescription...") + if "seed_dir" in config.keys(): + config["seed_dir"] = Path(config["seed_dir"]) tp = TorchPrescriptor( eval_df=dataset.train_df.sample(frac=0.001, random_state=42), encoder=dataset.encoder, From 7f5bdf924381c9fd02fe3c936244fa6aa2d9f338 Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Tue, 19 Mar 2024 15:42:43 -0700 Subject: [PATCH 3/3] Fixed error in config seed dir Path conversion --- use_cases/eluc/prescriptors/nsga2/train_prescriptors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/use_cases/eluc/prescriptors/nsga2/train_prescriptors.py b/use_cases/eluc/prescriptors/nsga2/train_prescriptors.py index e1453aa..98c2947 100644 --- a/use_cases/eluc/prescriptors/nsga2/train_prescriptors.py +++ b/use_cases/eluc/prescriptors/nsga2/train_prescriptors.py @@ -31,8 +31,8 @@ nnp.load(nnp_path) print("Initializing prescription...") - if "seed_dir" in config.keys(): - config["seed_dir"] = Path(config["seed_dir"]) + if "seed_dir" in config["evolution_params"].keys(): + config["evolution_params"]["seed_dir"] = Path(config["evolution_params"]["seed_dir"]) tp = TorchPrescriptor( eval_df=dataset.train_df.sample(frac=0.001, random_state=42), encoder=dataset.encoder,