-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_arlbench.py
54 lines (40 loc) · 1.33 KB
/
run_arlbench.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Console script for arlbench."""
from __future__ import annotations

import warnings

# Silence every warning process-wide. This deliberately runs *before* the
# heavy third-party imports below (hydra, jax, arlbench), so that warnings
# emitted while those modules import are suppressed as well.
warnings.filterwarnings("ignore")

import logging
import sys
import traceback
from typing import TYPE_CHECKING

import hydra
import jax

from arlbench.arlbench import run_arlbench

if TYPE_CHECKING:
    # Imported for type annotations only: `from __future__ import annotations`
    # keeps annotations unevaluated, so this import never runs at runtime.
    from omegaconf import DictConfig
@hydra.main(version_base=None, config_path="examples/configs", config_name="base")
def execute(cfg: DictConfig):
    """Hydra entry point: set up logging, optionally enable JAX x64, run ARLBench.

    Root-logger records at INFO level and above are written to ``job.log``
    (truncated on every call) with the format ``"%(asctime)s %(message)s"``,
    in addition to any handlers that are already attached to the root logger.
    The file handler is removed and closed again before this function returns
    or raises.

    Args:
        cfg: Hydra configuration. ``cfg.jax_enable_x64`` is read here; the whole
            config is forwarded to :func:`run`.

    Returns:
        Whatever :func:`run` returns.

    Raises:
        Any exception raised by :func:`run` is re-raised unchanged after its
        full traceback has been printed to stderr.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # `logging.basicConfig` is a silent no-op when the root logger already has
    # handlers -- which is the case under Hydra's default `job_logging` config --
    # so attach the file handler explicitly to guarantee that job.log is written.
    file_handler = logging.FileHandler("job.log", mode="w")
    file_handler.setFormatter(logging.Formatter("%(asctime)s %(message)s"))
    logger.addHandler(file_handler)
    try:
        if cfg.jax_enable_x64:
            logger.info("Enabling x64 support for JAX.")
            jax.config.update("jax_enable_x64", True)
        try:
            return run(cfg, logger)
        except Exception:
            # Print the complete traceback to stderr, then propagate unchanged.
            traceback.print_exc(file=sys.stderr)
            raise
    finally:
        # Detach and close so repeated calls in one process (e.g. sequential
        # multirun jobs) neither accumulate handlers nor leak file descriptors.
        logger.removeHandler(file_handler)
        file_handler.close()
def run(cfg: DictConfig, logger: logging.Logger):
    """Run ARLBench for ``cfg`` and persist the outcome in the working directory.

    After logging the result, ``str(result)`` is written to
    ``./performance.csv`` and the text ``yes`` to ``./done.txt``; both files
    are created or truncated. ``done.txt`` is written last, presumably as a
    completion flag for external tooling -- confirm with its consumers.

    Args:
        cfg: Configuration forwarded verbatim to ``run_arlbench``.
        logger: Logger forwarded to ``run_arlbench`` and used to report the result.

    Returns:
        The value returned by ``run_arlbench``, unchanged.
    """
    result = run_arlbench(cfg, logger=logger)
    logger.info(f"Returned objectives: {result}")

    # Written in this order: the performance file first, then done.txt.
    artifacts = (
        ("./performance.csv", str(result)),
        ("./done.txt", "yes"),
    )
    for target, payload in artifacts:
        with open(target, "w+") as stream:
            stream.write(payload)

    return result
if __name__ == "__main__":
    # Hand whatever execute() returns to SystemExit, exactly as sys.exit() would.
    raise SystemExit(execute())  # pragma: no cover