-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add wandb-sweep example and clean-up use of configs (#118)
* feat: added support for wandb logging (single runs, not sweeps yet) * feat: replaced hps provided as a dict by a Config() object in seh_frag_moo * feat: added config.desc * fix: added train/valid for wandb log * fix: allow JSON serialization of Enum objects * chore: tox * chore: replaced hps (dict) by Config() in all tasks. Moved qm9.py out of qm9/ * fix: changed default focus_region for frag_moo * fix: added assert to prevent inadvertently manipulating a Config rather than Config() object * removed cfg.use_wandb and simply test if wandb has been initialised in trainer * chore: adding comment * fix: typo * chore: added cfg.task.seh_moo.log_topk to de-clutter a bit * fix: added wandb to dependencies * minor: file name change for consistency * chore: centralised self.cfg.overwrite_existing_exp in GFNTrainer() (removed from all tasks to simplify mains) * feat: added hyperopt/wandb_demo * feat: removed wandb_agent_main.py to have the search and entrypoint defined in a single file * chore: tox * fix: minor in wandb_demo * fix: storage path * chore: tox
- Loading branch information
1 parent
74d6acc
commit 96dde6b
Showing
17 changed files
with
265 additions
and
183 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Everything is contained in one file; `init_wandb_sweep.py` both defines the search space of the sweep and is the entrypoint of wandb agents. | ||
|
||
To launch the search: | ||
1. `python init_wandb_sweep.py` to intialize the sweep | ||
2. `sbatch launch_wandb_agents.sh <SWEEP_ID>` to schedule a jobarray in slurm which will launch wandb agents. | ||
The number of jobs in the sbatch file should reflect the size of the hyperparameter space that is being sweeped. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import os | ||
import sys | ||
import time | ||
|
||
import wandb | ||
|
||
from gflownet.config import Config, init_empty | ||
from gflownet.tasks.seh_frag_moo import SEHMOOFragTrainer | ||
|
||
TIME = time.strftime("%m-%d-%H-%M") | ||
ENTITY = "valencelabs" | ||
PROJECT = "gflownet" | ||
SWEEP_NAME = f"{TIME}-sehFragMoo-Zlr-Zlrdecay" | ||
STORAGE_DIR = f"~/storage/wandb_sweeps/{SWEEP_NAME}" | ||
|
||
|
||
# Define the search space of the sweep | ||
sweep_config = { | ||
"name": SWEEP_NAME, | ||
"program": "init_wandb_sweep.py", | ||
"controller": { | ||
"type": "cloud", | ||
}, | ||
"method": "grid", | ||
"parameters": { | ||
"config.algo.tb.Z_learning_rate": {"values": [1e-4, 1e-3, 1e-2]}, | ||
"config.algo.tb.Z_lr_decay": {"values": [2_000, 50_000]}, | ||
}, | ||
} | ||
|
||
|
||
def wandb_config_merger(): | ||
config = init_empty(Config()) | ||
wandb_config = wandb.config | ||
|
||
# Set desired config values | ||
config.log_dir = f"{STORAGE_DIR}/{wandb.run.name}-id-{wandb.run.id}" | ||
config.print_every = 100 | ||
config.validate_every = 1000 | ||
config.num_final_gen_steps = 1000 | ||
config.num_training_steps = 40_000 | ||
config.pickle_mp_messages = True | ||
config.overwrite_existing_exp = False | ||
config.algo.sampling_tau = 0.95 | ||
config.algo.train_random_action_prob = 0.01 | ||
config.algo.tb.Z_learning_rate = 1e-3 | ||
config.task.seh_moo.objectives = ["seh", "qed"] | ||
config.cond.temperature.sample_dist = "constant" | ||
config.cond.temperature.dist_params = [60.0] | ||
config.cond.weighted_prefs.preference_type = "dirichlet" | ||
config.cond.focus_region.focus_type = None | ||
config.replay.use = False | ||
|
||
# Merge the wandb sweep config with the nested config from gflownet | ||
config.algo.tb.Z_learning_rate = wandb_config["config.algo.tb.Z_learning_rate"] | ||
config.algo.tb.Z_lr_decay = wandb_config["config.algo.tb.Z_lr_decay"] | ||
|
||
return config | ||
|
||
|
||
if __name__ == "__main__": | ||
# if there no arguments, initialize the sweep, otherwise this is a wandb agent | ||
if len(sys.argv) == 1: | ||
if os.path.exists(STORAGE_DIR): | ||
raise ValueError(f"Sweep storage directory {STORAGE_DIR} already exists.") | ||
|
||
wandb.sweep(sweep_config, entity=ENTITY, project=PROJECT) | ||
|
||
else: | ||
wandb.init(entity=ENTITY, project=PROJECT) | ||
config = wandb_config_merger() | ||
trial = SEHMOOFragTrainer(config) | ||
trial.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#!/bin/bash | ||
|
||
# Purpose: Script to allocate a node and run a wandb sweep agent on it | ||
# Usage: sbatch launch_wandb_agent.sh <SWEEP_ID> | ||
|
||
#SBATCH --job-name=wandb_sweep_agent | ||
#SBATCH --array=1-6 | ||
#SBATCH --time=23:59:00 | ||
#SBATCH --output=slurm_output_files/%x_%N_%A_%a.out | ||
#SBATCH --gpus=1 | ||
#SBATCH --cpus-per-task=16 | ||
#SBATCH --mem=16GB | ||
#SBATCH --partition compute | ||
|
||
source activate gfn-py39-torch113 | ||
echo "Using environment={$CONDA_DEFAULT_ENV}" | ||
|
||
# launch wandb agent | ||
wandb agent --count 1 --entity valencelabs --project gflownet $1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.