Skip to content

Commit

Permalink
Merge pull request #75 from Project-Resilience/torch-config
Browse files Browse the repository at this point in the history
Updated training script for torch prescriptors to allow config
  • Loading branch information
danyoungday authored Mar 22, 2024
2 parents d5619d7 + c3b5b07 commit d0d3f41
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 12 deletions.
15 changes: 15 additions & 0 deletions use_cases/eluc/prescriptors/nsga2/configs/test.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"predictor_path": "predictors/neural_network/trained_models/no_overlap_nn",
"evolution_params": {
"pop_size": 100,
"n_generations": 100,
"p_mutation": 0.2,
"candidate_params": {
"in_size": 12,
"hidden_size": 16,
"out_size": 5
},
"seed_dir": "prescriptors/nsga2/seeds/small_sample"
},
"save_path": "prescriptors/nsga2/trained_prescriptors/test"
}
34 changes: 22 additions & 12 deletions use_cases/eluc/prescriptors/nsga2/train_prescriptors.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,45 @@
"""
Script to train the NSGA-II prescriptors.
Script used to train NSGA-II prescriptors.
Requires a config file with the same fields as shown in the
test.json file in prescriptors/nsga2/configs
"""
import argparse
import json
from pathlib import Path

from data import constants
from data.eluc_data import ELUCData
from prescriptors.nsga2.torch_prescriptor import TorchPrescriptor
from predictors.neural_network.neural_net_predictor import NeuralNetPredictor

if __name__ == "__main__":

# Load config
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True)
args = parser.parse_args()
with open(Path(args.config_path), "r", encoding="utf-8") as f:
config = json.load(f)

print("Loading dataset...")
dataset = ELUCData()

print("Loading predictor...")
# TODO: We need to make it so you can load any predictor here
nnp = NeuralNetPredictor()
nnp.load("predictors/neural_network/trained_models/no_overlap_nn")
nnp_path = Path(config["predictor_path"])
nnp.load(nnp_path)

print("Initializing prescription...")
candidate_params = {"in_size": len(constants.CAO_MAPPING["context"]),
"hidden_size": 16,
"out_size": len(constants.RECO_COLS)}
if "seed_dir" in config["evolution_params"].keys():
config["evolution_params"]["seed_dir"] = Path(config["evolution_params"]["seed_dir"])
tp = TorchPrescriptor(
pop_size=100,
n_generations=100,
p_mutation=0.2,
eval_df=dataset.train_df.sample(frac=0.001, random_state=42),
encoder=dataset.encoder,
predictor=nnp,
batch_size=4096,
candidate_params=candidate_params,
seed_dir=Path("prescriptors/nsga2/seeds/small_sample")
**config["evolution_params"]
)
print("Training prescriptors...")
save_path = Path("prescriptors/nsga2/trained_prescriptors/test")
save_path = Path(config["save_path"])
final_pop = tp.neuroevolution(save_path)
print("Done!")

0 comments on commit d0d3f41

Please sign in to comment.