diff --git a/equiadapt/nbody/canonicalization/continuous_group.py b/equiadapt/nbody/canonicalization/continuous_group.py
index 50abb5d..37190e6 100644
--- a/equiadapt/nbody/canonicalization/continuous_group.py
+++ b/equiadapt/nbody/canonicalization/continuous_group.py
@@ -40,7 +40,6 @@ def canonicalize(self, nodes, loc, edges, vel, edge_attr, charges):
         self.device = nodes.device
 
         group_element_dict = self.get_groupelement(nodes, loc, edges, vel, edge_attr, charges)
-        rotation_matrix = group_element_dict["rotation_matrix"]
         translation_vectors = group_element_dict["translation_vectors"]
         rotation_matrix_inverse = group_element_dict["rotation_matrix_inverse"]
 
diff --git a/examples/nbody/model.py b/examples/nbody/model.py
index b5a64b2..bb1b4c4 100644
--- a/examples/nbody/model.py
+++ b/examples/nbody/model.py
@@ -11,8 +11,8 @@ class NBodyPipeline(pl.LightningModule):
     def __init__(self, hyperparams: DictConfig):
         super().__init__()
         self.hyperparams = hyperparams
-        self.prediction_network = get_prediction_network(hyperparams)
-        canonicalization_network = get_canonicalization_network(hyperparams)
+        self.prediction_network = get_prediction_network(hyperparams.pred_hyperparams)
+        canonicalization_network = get_canonicalization_network(hyperparams.canon_hyperparams)
 
         self.canonicalizer = ContinuousGroupNBody(canonicalization_network, hyperparams)
 
diff --git a/examples/nbody/model_utils.py b/examples/nbody/model_utils.py
index 9c27643..db0b6de 100644
--- a/examples/nbody/model_utils.py
+++ b/examples/nbody/model_utils.py
@@ -3,54 +3,27 @@
 from omegaconf import OmegaConf
 import torch
 
+
 def get_canonicalization_network(hyperparams):
-    architecture = hyperparams.canon_model_type
-    model_hyperparams = {
-        "num_layers": hyperparams.canon_num_layers,
-        "hidden_dim": hyperparams.canon_hidden_dim,
-        "layer_pooling": hyperparams.canon_layer_pooling,
-        "final_pooling": hyperparams.canon_final_pooling,
-        "out_dim": 4,
-        "batch_size": hyperparams.batch_size,
-        "nonlinearity": hyperparams.canon_nonlinearity,
-        "canon_feature": hyperparams.canon_feature,
-        "canon_translation": hyperparams.canon_translation,
-        "angular_feature": hyperparams.canon_angular_feature,
-        "dropout": hyperparams.canon_dropout,
-    }
-    model_hyperparams = OmegaConf.create(dict(model_hyperparams))
-
+    architecture = hyperparams.architecture
     model_dict = {
-        #"EGNN": lambda: EGNN_vel(hyperparams),
-        "vndeepsets": lambda: VNDeepSets(model_hyperparams),
+        "vndeepsets": lambda: VNDeepSets(hyperparams),
     }
-    if architecture not in model_dict:
-        raise ValueError(f'{architecture} is not implemented as prediction network for now.')
-
     return model_dict[architecture]()
 
 
 def get_prediction_network(hyperparams):
-    architecture = hyperparams.pred_model_type
-    model_hyperparams = {
-        "num_layers": hyperparams.num_layers,
-        "hidden_dim": hyperparams.hidden_dim,
-        "input_dim": hyperparams.input_dim,
-        "in_node_nf": hyperparams.in_node_nf,
-        "in_edge_nf": hyperparams.in_edge_nf,
-    }
-    model_hyperparams = OmegaConf.create(dict(model_hyperparams))
+    architecture = hyperparams.architecture
 
     model_dict = {
-        "GNN": lambda: GNN(model_hyperparams),
-        "EGNN": lambda: EGNN_vel(model_hyperparams),
-        "vndeepsets": lambda: VNDeepSets(model_hyperparams),
-        "Transformer": lambda: Transformer(model_hyperparams),
+        "GNN": lambda: GNN(hyperparams),
+        "EGNN": lambda: EGNN_vel(hyperparams),
+        "vndeepsets": lambda: VNDeepSets(hyperparams),
+        "Transformer": lambda: Transformer(hyperparams),
     }
-
     if architecture not in model_dict:
         raise ValueError(f'{architecture} is not implemented as prediction network for now.')
-
+
     return model_dict[architecture]()
 
 
 def get_edges(batch_size, n_nodes):
diff --git a/examples/nbody/train.py b/examples/nbody/train.py
index 8471888..3c9b003 100644
--- a/examples/nbody/train.py
+++ b/examples/nbody/train.py
@@ -12,7 +12,7 @@
 HYPERPARAMS = {"model": "NBodyPipeline",
                "canon_model_type": "vndeepsets",
-               "pred_model_type": "GNN",
+               "pred_model_type": "Transformer",
                "batch_size": 100,
                "dryrun": False,
                "use_wandb": False,
@@ -20,40 +20,43 @@
                "num_epochs": 1000,
                "num_workers":0,
                "auto_tune":False,
-               "seed": 0}
-
+               "seed": 0,
+               "learning_rate": 1e-3, #1e-3
+               "weight_decay": 1e-12,
+               "patience": 1000,
+}
 
-NBODY_HYPERPARAMS = {
-    "learning_rate": 1e-3, #1e-3
-    "weight_decay": 1e-12,
-    "patience": 1000,
-    "hidden_dim": 32, #32
-    "input_dim": 6,
-    "in_node_nf": 1,
-    "in_edge_nf": 2,
-    "num_layers": 4, #4
+CANON_HYPERPARAMS = {
+    "architecture": "vndeepsets",
+    "num_layers": 4,
+    "hidden_dim": 16,
+    "layer_pooling": "mean",
+    "final_pooling": "mean",
     "out_dim": 4,
-    "canon_num_layers": 4,
-    "canon_hidden_dim": 16,
-    "canon_layer_pooling": "mean",
-    "canon_final_pooling": "mean",
-    "canon_nonlinearity": "relu",
+    "batch_size": 100,
+    "nonlinearity": "relu",
     "canon_feature": "p",
     "canon_translation": False,
-    "canon_angular_feature": 0,
-    "canon_dropout": 0.5,
-    "freeze_canon": False,
-    "layer_pooling": "sum",
-    "final_pooling": "mean",
-    "nonlinearity": "relu",
     "angular_feature": "pv",
-    "dropout": 0, #0
+    "dropout": 0.5,
+}
+
+PRED_HYPERPARAMS = {
+    "architecture": "Transformer",
+    "num_layers": 4,
+    "hidden_dim": 32,
+    "input_dim": 6,
+    "in_node_nf": 1,
+    "in_edge_nf": 2,
     "nheads": 8,
     "ff_hidden": 32
 }
 
+HYPERPARAMS["canon_hyperparams"] = CANON_HYPERPARAMS
+HYPERPARAMS["pred_hyperparams"] = PRED_HYPERPARAMS
+
 def train_nbody():
-    hyperparams = HYPERPARAMS | NBODY_HYPERPARAMS
+    hyperparams = HYPERPARAMS
 
     if not hyperparams["use_wandb"]:
         print('Wandb disable for logging.')
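
Note (not part of the patch): below is a minimal sketch of how the restructured configuration is consumed after this change. The values are taken from the dicts in train.py above; the import path is an assumption (it presumes examples/nbody is on the import path), and the snippet is an illustration of the new nested-config flow, not code from the repository.

    # Sketch only: assumes model_utils is importable, as when running
    # from within examples/nbody.
    from omegaconf import OmegaConf
    from model_utils import get_canonicalization_network, get_prediction_network

    # Hyperparameters are now nested: each factory receives only its own
    # sub-config and dispatches on its "architecture" key, instead of
    # re-packing "canon_"/"pred_"-prefixed entries out of one flat dict.
    hyperparams = OmegaConf.create({
        "canon_hyperparams": {"architecture": "vndeepsets", "num_layers": 4,
                              "hidden_dim": 16, "layer_pooling": "mean",
                              "final_pooling": "mean", "out_dim": 4,
                              "batch_size": 100, "nonlinearity": "relu",
                              "canon_feature": "p", "canon_translation": False,
                              "angular_feature": "pv", "dropout": 0.5},
        "pred_hyperparams": {"architecture": "Transformer", "num_layers": 4,
                             "hidden_dim": 32, "input_dim": 6, "in_node_nf": 1,
                             "in_edge_nf": 2, "nheads": 8, "ff_hidden": 32},
    })

    # Mirrors the updated NBodyPipeline.__init__ in examples/nbody/model.py.
    canonicalization_network = get_canonicalization_network(hyperparams.canon_hyperparams)
    prediction_network = get_prediction_network(hyperparams.pred_hyperparams)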