Skip to content

Commit 7f61d0d

Browse files
authored
Merge pull request #197 from apax-hub/absl_logging_fix
Absl logging fix
2 parents b90fdb2 + a2c92c9 commit 7f61d0d

File tree

11 files changed

+40
-41
lines changed

11 files changed

+40
-41
lines changed

apax/cli/apax_app.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,14 @@ def train(
3434
train_config_path: Path = typer.Argument(
3535
..., help="Training configuration YAML file."
3636
),
37-
log_level: str = typer.Option("error", help="Sets the training logging level."),
38-
log_file: str = typer.Option("train.log", help="Specifies the name of the log file"),
37+
log_level: str = typer.Option("info", help="Sets the training logging level."),
3938
):
4039
"""
4140
Starts the training of a model with parameters provided by a configuration file.
4241
"""
4342
from apax.train.run import run
4443

45-
run(train_config_path, log_file, log_level)
44+
run(train_config_path, log_level)
4645

4746

4847
@app.command()

apax/config/common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def parse_config(config: Union[str, os.PathLike, dict], mode: str = "train") ->
1818
config: Path to the config file or a dictionary
1919
containing the config.
2020
"""
21-
log.info("Loading user config")
2221
if isinstance(config, (str, os.PathLike)):
2322
with open(config, "r") as stream:
2423
config = yaml.safe_load(stream)

apax/config/train_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,14 @@ def validate_shift_scale_methods(self):
112112

113113
return self
114114

115+
@property
115116
def model_version_path(self):
116117
version_path = Path(self.directory) / self.experiment
117118
return version_path
118119

120+
@property
119121
def best_model_path(self):
120-
return self.model_version_path() / "best"
122+
return self.model_version_path / "best"
121123

122124

123125
class ModelConfig(BaseModel, extra="forbid"):

apax/data/initialization.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import dataclasses
22
import logging
3-
import os
43
from typing import Optional
54

65
import numpy as np
@@ -19,7 +18,7 @@ class RawDataset:
1918
additional_labels: Optional[dict] = None
2019

2120

22-
def load_data_files(data_config, model_version_path):
21+
def load_data_files(data_config):
2322
log.info("Running Input Pipeline")
2423
if data_config.data_path is not None:
2524
log.info(f"Read data file {data_config.data_path}")
@@ -32,7 +31,7 @@ def load_data_files(data_config, model_version_path):
3231
train_label_dict, val_label_dict = split_label(label_dict, train_idxs, val_idxs)
3332

3433
np.savez(
35-
os.path.join(model_version_path, "train_val_idxs"),
34+
data_config.model_version_path / "train_val_idxs",
3635
train_idxs=train_idxs,
3736
val_idxs=val_idxs,
3837
)

apax/md/nvt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ def md_setup(model_config: Config, md_config: MDConfig):
372372
disable_cell_list=True,
373373
)
374374

375-
_, params = restore_parameters(model_config.data.model_version_path())
375+
_, params = restore_parameters(model_config.data.model_version_path)
376376
params = canonicalize_energy_model_parameters(params)
377377
energy_fn = create_energy_fn(
378378
model.apply, params, system.atomic_numbers, system.box, model_config.n_models

apax/train/checkpoints.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def stack_parameters(param_list: List[FrozenDict]) -> FrozenDict:
122122

123123

124124
def load_params(model_version_path: Path, best=True) -> FrozenDict:
125+
model_version_path = Path(model_version_path)
125126
if best:
126127
model_version_path = model_version_path / "best"
127128
log.info(f"loading checkpoint from {model_version_path}")
@@ -142,7 +143,7 @@ def restore_single_parameters(model_dir: Path) -> Tuple[Config, FrozenDict]:
142143
"""Load the config and parameters of a single model
143144
"""
144145
model_config = parse_config(Path(model_dir) / "config.yaml")
145-
ckpt_dir = model_config.data.model_version_path()
146+
ckpt_dir = model_config.data.model_version_path
146147
return model_config, load_params(ckpt_dir)
147148

148149

apax/train/run.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
2-
import os
3-
from pathlib import Path
2+
import sys
43
from typing import List
54

65
import jax
@@ -32,7 +31,15 @@ def setup_logging(log_file, log_level):
3231
while len(logging.root.handlers) > 0:
3332
logging.root.removeHandler(logging.root.handlers[-1])
3433

35-
logging.basicConfig(filename=log_file, level=log_levels[log_level])
34+
# Remove uninformative checkpointing absl logs
35+
logging.getLogger("absl").setLevel(logging.WARNING)
36+
37+
logging.basicConfig(
38+
level=log_levels[log_level],
39+
format="%(levelname)s | %(asctime)s | %(message)s",
40+
datefmt="%H:%M:%S",
41+
handlers=[logging.FileHandler(log_file), logging.StreamHandler(sys.stderr)],
42+
)
3643

3744

3845
def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection:
@@ -43,26 +50,22 @@ def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection:
4350
return LossCollection(loss_funcs)
4451

4552

46-
def run(user_config, log_file="train.log", log_level="error"):
47-
setup_logging(log_file, log_level)
48-
log.info("Loading user config")
53+
def run(user_config, log_level="error"):
4954
config = parse_config(user_config)
5055

5156
seed_py_np_tf(config.seed)
5257
rng_key = jax.random.PRNGKey(config.seed)
5358

54-
experiment = Path(config.data.experiment)
55-
directory = Path(config.data.directory)
56-
model_version_path = directory / experiment
5759
log.info("Initializing directories")
58-
model_version_path.mkdir(parents=True, exist_ok=True)
59-
config.dump_config(model_version_path)
60+
config.data.model_version_path.mkdir(parents=True, exist_ok=True)
61+
setup_logging(config.data.model_version_path / "train.log", log_level)
62+
config.dump_config(config.data.model_version_path)
6063

61-
callbacks = initialize_callbacks(config.callbacks, model_version_path)
64+
callbacks = initialize_callbacks(config.callbacks, config.data.model_version_path)
6265
loss_fn = initialize_loss_fn(config.loss)
6366
Metrics = initialize_metrics(config.metrics)
6467

65-
train_raw_ds, val_raw_ds = load_data_files(config.data, model_version_path)
68+
train_raw_ds, val_raw_ds = load_data_files(config.data)
6669
train_ds, ds_stats = initialize_dataset(config, train_raw_ds)
6770
val_ds = initialize_dataset(config, val_raw_ds, calc_stats=False)
6871

@@ -112,7 +115,7 @@ def run(user_config, log_file="train.log", log_level="error"):
112115
Metrics,
113116
callbacks,
114117
n_epochs,
115-
ckpt_dir=os.path.join(config.data.directory, config.data.experiment),
118+
ckpt_dir=config.data.model_version_path,
116119
ckpt_interval=config.checkpoints.ckpt_interval,
117120
val_ds=val_ds,
118121
sam_rho=config.optimizer.sam_rho,

apax/train/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def fit(
3131
log.info("Beginning Training")
3232
callbacks.on_train_begin()
3333

34-
latest_dir = ckpt_dir + "/latest"
35-
best_dir = ckpt_dir + "/best"
34+
latest_dir = ckpt_dir / "latest"
35+
best_dir = ckpt_dir / "best"
3636
ckpt_manager = CheckpointManager()
3737

3838
train_step, val_step = make_step_fns(

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,6 @@ def load_and_dump_config(config_path, dump_path):
147147
model_config_dict["data"]["directory"] = dump_path.as_posix()
148148

149149
model_config = Config.model_validate(model_config_dict)
150-
os.makedirs(model_config.data.model_version_path(), exist_ok=True)
151-
model_config.dump_config(model_config.data.model_version_path())
150+
os.makedirs(model_config.data.model_version_path, exist_ok=True)
151+
model_config.dump_config(model_config.data.model_version_path)
152152
return model_config

tests/integration_tests/bal/test_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_kernel_selection(example_atoms, get_tmp_path, get_sample_input):
2323
_, params = initialize_model(model_config, inputs)
2424

2525
ckpt = {"model": {"params": params}, "epoch": 0}
26-
best_dir = model_config.data.best_model_path()
26+
best_dir = model_config.data.best_model_path
2727
checkpoints.save_checkpoint(
2828
ckpt_dir=best_dir,
2929
target=ckpt,
@@ -41,7 +41,7 @@ def test_kernel_selection(example_atoms, get_tmp_path, get_sample_input):
4141
bs = 5
4242

4343
selected_indices = kernel_selection(
44-
model_config.data.model_version_path(),
44+
model_config.data.model_version_path,
4545
train_atoms,
4646
pool_atoms,
4747
base_fm_options,

tests/integration_tests/md/test_md.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def test_run_md(get_tmp_path):
3434
md_config_dict["initial_structure"] = get_tmp_path.as_posix() + "/atoms.extxyz"
3535

3636
model_config = Config.model_validate(model_config_dict)
37-
os.makedirs(model_config.data.model_version_path())
38-
model_config.dump_config(model_config.data.model_version_path())
37+
os.makedirs(model_config.data.model_version_path)
38+
model_config.dump_config(model_config.data.model_version_path)
3939
md_config = MDConfig.model_validate(md_config_dict)
4040

4141
positions = jnp.array(
@@ -80,11 +80,8 @@ def test_run_md(get_tmp_path):
8080
)
8181

8282
ckpt = {"model": {"params": params}, "epoch": 0}
83-
best_dir = os.path.join(
84-
model_config.data.directory, model_config.data.experiment, "best"
85-
)
8683
checkpoints.save_checkpoint(
87-
ckpt_dir=best_dir,
84+
ckpt_dir=model_config.data.best_model_path,
8885
target=ckpt,
8986
step=0,
9087
overwrite=True,
@@ -106,8 +103,8 @@ def test_ase_calc(get_tmp_path):
106103
model_config_dict["data"]["directory"] = get_tmp_path.as_posix()
107104

108105
model_config = Config.model_validate(model_config_dict)
109-
os.makedirs(model_config.data.model_version_path(), exist_ok=True)
110-
model_config.dump_config(model_config.data.model_version_path())
106+
os.makedirs(model_config.data.model_version_path, exist_ok=True)
107+
model_config.dump_config(model_config.data.model_version_path)
111108

112109
cell_size = 10.0
113110
positions = np.array(
@@ -147,17 +144,16 @@ def test_ase_calc(get_tmp_path):
147144
)
148145
ckpt = {"model": {"params": params}, "epoch": 0}
149146

150-
best_dir = model_config.data.best_model_path()
151147
checkpoints.save_checkpoint(
152-
ckpt_dir=best_dir,
148+
ckpt_dir=model_config.data.best_model_path,
153149
target=ckpt,
154150
step=0,
155151
overwrite=True,
156152
)
157153

158154
atoms = read(initial_structure_path.as_posix())
159155
calc = ASECalculator(
160-
[model_config.data.model_version_path(), model_config.data.model_version_path()]
156+
[model_config.data.model_version_path, model_config.data.model_version_path]
161157
)
162158

163159
atoms.calc = calc

0 commit comments

Comments
 (0)