Modified training example to follow hyperparameter selection format of other examples.

jikaelgagnon committed Mar 12, 2024
1 parent c42c473 commit 4c6bfcf
Showing 4 changed files with 39 additions and 64 deletions.
1 change: 0 additions & 1 deletion equiadapt/nbody/canonicalization/continuous_group.py
@@ -40,7 +40,6 @@ def canonicalize(self, nodes, loc, edges, vel, edge_attr, charges):
         self.device = nodes.device

         group_element_dict = self.get_groupelement(nodes, loc, edges, vel, edge_attr, charges)
-        rotation_matrix = group_element_dict["rotation_matrix"]
         translation_vectors = group_element_dict["translation_vectors"]
         rotation_matrix_inverse = group_element_dict["rotation_matrix_inverse"]
4 changes: 2 additions & 2 deletions examples/nbody/model.py
@@ -11,8 +11,8 @@ class NBodyPipeline(pl.LightningModule):
     def __init__(self, hyperparams: DictConfig):
         super().__init__()
         self.hyperparams = hyperparams
-        self.prediction_network = get_prediction_network(hyperparams)
-        canonicalization_network = get_canonicalization_network(hyperparams)
+        self.prediction_network = get_prediction_network(hyperparams.pred_hyperparams)
+        canonicalization_network = get_canonicalization_network(hyperparams.canon_hyperparams)

         self.canonicalizer = ContinuousGroupNBody(canonicalization_network, hyperparams)
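For context, a minimal sketch of the nested config shape the revised constructor expects. The sub-config keys come from this commit; the specific values and the `OmegaConf.create` wrapping are illustrative assumptions, not code from the repository:

```python
# Sketch only: illustrative values, not the repo's actual config.
from omegaconf import OmegaConf

hyperparams = OmegaConf.create({
    "canon_hyperparams": {"architecture": "vndeepsets", "hidden_dim": 16},
    "pred_hyperparams": {"architecture": "Transformer", "hidden_dim": 32},
})

# Attribute access on a DictConfig returns the nested sub-config, so each
# factory call receives only the hyperparameters for its own network:
canon_cfg = hyperparams.canon_hyperparams
pred_cfg = hyperparams.pred_hyperparams
assert pred_cfg.architecture == "Transformer"
```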
45 changes: 9 additions & 36 deletions examples/nbody/model_utils.py
@@ -3,54 +3,27 @@
 from omegaconf import OmegaConf
 import torch


 def get_canonicalization_network(hyperparams):
-    architecture = hyperparams.canon_model_type
-    model_hyperparams = {
-        "num_layers": hyperparams.canon_num_layers,
-        "hidden_dim": hyperparams.canon_hidden_dim,
-        "layer_pooling": hyperparams.canon_layer_pooling,
-        "final_pooling": hyperparams.canon_final_pooling,
-        "out_dim": 4,
-        "batch_size": hyperparams.batch_size,
-        "nonlinearity": hyperparams.canon_nonlinearity,
-        "canon_feature": hyperparams.canon_feature,
-        "canon_translation": hyperparams.canon_translation,
-        "angular_feature": hyperparams.canon_angular_feature,
-        "dropout": hyperparams.canon_dropout,
-    }
-    model_hyperparams = OmegaConf.create(dict(model_hyperparams))
+    architecture = hyperparams.architecture
     model_dict = {
         #"EGNN": lambda: EGNN_vel(hyperparams),
-        "vndeepsets": lambda: VNDeepSets(model_hyperparams),
+        "vndeepsets": lambda: VNDeepSets(hyperparams),
     }

     if architecture not in model_dict:
         raise ValueError(f'{architecture} is not implemented as prediction network for now.')

     return model_dict[architecture]()


 def get_prediction_network(hyperparams):
-    architecture = hyperparams.pred_model_type
-    model_hyperparams = {
-        "num_layers": hyperparams.num_layers,
-        "hidden_dim": hyperparams.hidden_dim,
-        "input_dim": hyperparams.input_dim,
-        "in_node_nf": hyperparams.in_node_nf,
-        "in_edge_nf": hyperparams.in_edge_nf,
-    }
-    model_hyperparams = OmegaConf.create(dict(model_hyperparams))
+    architecture = hyperparams.architecture
     model_dict = {
-        "GNN": lambda: GNN(model_hyperparams),
-        "EGNN": lambda: EGNN_vel(model_hyperparams),
-        "vndeepsets": lambda: VNDeepSets(model_hyperparams),
-        "Transformer": lambda: Transformer(model_hyperparams),
+        "GNN": lambda: GNN(hyperparams),
+        "EGNN": lambda: EGNN_vel(hyperparams),
+        "vndeepsets": lambda: VNDeepSets(hyperparams),
+        "Transformer": lambda: Transformer(hyperparams),
     }

     if architecture not in model_dict:
         raise ValueError(f'{architecture} is not implemented as prediction network for now.')

     return model_dict[architecture]()

 def get_edges(batch_size, n_nodes):
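Both factories now dispatch on a single `architecture` key in the sub-config they receive, via a dict of lazily constructed models. A self-contained sketch of the same pattern, with `nn.Linear` standing in for the repo's `GNN`/`Transformer` classes:

```python
# Sketch of the factory-dispatch pattern; nn.Linear is a stand-in model.
import torch.nn as nn
from omegaconf import OmegaConf

def build_network(cfg):
    # Lambdas defer construction, so only the selected model is instantiated.
    model_dict = {
        "GNN": lambda: nn.Linear(cfg.input_dim, cfg.hidden_dim),
        "Transformer": lambda: nn.Linear(cfg.input_dim, cfg.hidden_dim),
    }
    if cfg.architecture not in model_dict:
        raise ValueError(f"{cfg.architecture} is not implemented for now.")
    return model_dict[cfg.architecture]()

cfg = OmegaConf.create({"architecture": "GNN", "input_dim": 6, "hidden_dim": 32})
net = build_network(cfg)  # constructs only the "GNN" entry
```

The lambda indirection matters here: building every entry eagerly would instantiate all of the networks just to select one.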
53 changes: 28 additions & 25 deletions examples/nbody/train.py
@@ -12,48 +12,51 @@

 HYPERPARAMS = {"model": "NBodyPipeline",
                "canon_model_type": "vndeepsets",
-               "pred_model_type": "GNN",
+               "pred_model_type": "Transformer",
                "batch_size": 100,
                "dryrun": False,
                "use_wandb": False,
                "checkpoint": False,
                "num_epochs": 1000,
                "num_workers":0,
                "auto_tune":False,
-               "seed": 0}
+               "seed": 0,
+               "learning_rate": 1e-3, #1e-3
+               "weight_decay": 1e-12,
+               "patience": 1000,
+               }

-NBODY_HYPERPARAMS = {
-    "learning_rate": 1e-3, #1e-3
-    "weight_decay": 1e-12,
-    "patience": 1000,
-    "hidden_dim": 32, #32
-    "input_dim": 6,
-    "in_node_nf": 1,
-    "in_edge_nf": 2,
-    "num_layers": 4, #4
+CANON_HYPERPARAMS = {
+    "architecture": "vndeepsets",
+    "num_layers": 4,
+    "hidden_dim": 16,
+    "layer_pooling": "mean",
+    "final_pooling": "mean",
     "out_dim": 4,
-    "canon_num_layers": 4,
-    "canon_hidden_dim": 16,
-    "canon_layer_pooling": "mean",
-    "canon_final_pooling": "mean",
-    "canon_nonlinearity": "relu",
     "batch_size": 100,
+    "nonlinearity": "relu",
     "canon_feature": "p",
     "canon_translation": False,
-    "canon_angular_feature": 0,
-    "canon_dropout": 0.5,
     "freeze_canon": False,
-    "layer_pooling": "sum",
-    "final_pooling": "mean",
-    "nonlinearity": "relu",
-    "dropout": 0, #0
+    "angular_feature": "pv",
+    "dropout": 0.5,
 }

+PRED_HYPERPARAMS = {
+    "architecture": "Transformer",
+    "num_layers": 4,
+    "hidden_dim": 32,
+    "input_dim": 6,
+    "in_node_nf": 1,
+    "in_edge_nf": 2,
+    "nheads": 8,
+    "ff_hidden": 32
+}
+
+HYPERPARAMS["canon_hyperparams"] = CANON_HYPERPARAMS
+HYPERPARAMS["pred_hyperparams"] = PRED_HYPERPARAMS

 def train_nbody():
-    hyperparams = HYPERPARAMS | NBODY_HYPERPARAMS
+    hyperparams = HYPERPARAMS

     if not hyperparams["use_wandb"]:
         print('Wandb disable for logging.')
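Taken together, the new layout keeps run-level settings flat and attaches one sub-dict per network, replacing the old `HYPERPARAMS | NBODY_HYPERPARAMS` merge. A hedged sketch of how the restructured dict could reach the pipeline; the `OmegaConf.create` step is an assumption based on the `DictConfig` annotation in `model.py`, since this commit does not show it:

```python
# Assumed glue code: the commit shows the dicts, not the conversion step.
from omegaconf import OmegaConf

HYPERPARAMS = {"model": "NBodyPipeline", "batch_size": 100, "seed": 0}
HYPERPARAMS["canon_hyperparams"] = {"architecture": "vndeepsets", "hidden_dim": 16}
HYPERPARAMS["pred_hyperparams"] = {"architecture": "Transformer", "hidden_dim": 32}

hyperparams = OmegaConf.create(HYPERPARAMS)
# NBodyPipeline(hyperparams) would then route hyperparams.canon_hyperparams
# and hyperparams.pred_hyperparams to the two factory functions above.
```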
