Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Absl logging fix #197

Merged
merged 10 commits on Nov 14, 2023
5 changes: 2 additions & 3 deletions apax/cli/apax_app.py
Original file line number Diff line number Diff line change
@@ -34,15 +34,14 @@ def train(
train_config_path: Path = typer.Argument(
..., help="Training configuration YAML file."
),
log_level: str = typer.Option("error", help="Sets the training logging level."),
log_file: str = typer.Option("train.log", help="Specifies the name of the log file"),
log_level: str = typer.Option("info", help="Sets the training logging level."),
):
"""
Starts the training of a model with parameters provided by a configuration file.
"""
from apax.train.run import run

run(train_config_path, log_file, log_level)
run(train_config_path, log_level)


@app.command()
1 change: 0 additions & 1 deletion apax/config/common.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,6 @@ def parse_config(config: Union[str, os.PathLike, dict], mode: str = "train") ->
config: Path to the config file or a dictionary
containing the config.
"""
log.info("Loading user config")
if isinstance(config, (str, os.PathLike)):
with open(config, "r") as stream:
config = yaml.safe_load(stream)
4 changes: 3 additions & 1 deletion apax/config/train_config.py
Original file line number Diff line number Diff line change
@@ -112,12 +112,14 @@ def validate_shift_scale_methods(self):

return self

@property
def model_version_path(self):
version_path = Path(self.directory) / self.experiment
return version_path

@property
def best_model_path(self):
return self.model_version_path() / "best"
return self.model_version_path / "best"


class ModelConfig(BaseModel, extra="forbid"):
5 changes: 2 additions & 3 deletions apax/data/initialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import dataclasses
import logging
import os
from typing import Optional

import numpy as np
@@ -19,7 +18,7 @@ class RawDataset:
additional_labels: Optional[dict] = None


def load_data_files(data_config, model_version_path):
def load_data_files(data_config):
log.info("Running Input Pipeline")
if data_config.data_path is not None:
log.info(f"Read data file {data_config.data_path}")
@@ -32,7 +31,7 @@ def load_data_files(data_config, model_version_path):
train_label_dict, val_label_dict = split_label(label_dict, train_idxs, val_idxs)

np.savez(
os.path.join(model_version_path, "train_val_idxs"),
data_config.model_version_path / "train_val_idxs",
train_idxs=train_idxs,
val_idxs=val_idxs,
)
2 changes: 1 addition & 1 deletion apax/md/nvt.py
Original file line number Diff line number Diff line change
@@ -372,7 +372,7 @@ def md_setup(model_config: Config, md_config: MDConfig):
disable_cell_list=True,
)

_, params = restore_parameters(model_config.data.model_version_path())
_, params = restore_parameters(model_config.data.model_version_path)
params = canonicalize_energy_model_parameters(params)
energy_fn = create_energy_fn(
model.apply, params, system.atomic_numbers, system.box, model_config.n_models
3 changes: 2 additions & 1 deletion apax/train/checkpoints.py
Original file line number Diff line number Diff line change
@@ -122,6 +122,7 @@ def stack_parameters(param_list: List[FrozenDict]) -> FrozenDict:


def load_params(model_version_path: Path, best=True) -> FrozenDict:
model_version_path = Path(model_version_path)
if best:
model_version_path = model_version_path / "best"
log.info(f"loading checkpoint from {model_version_path}")
@@ -142,7 +143,7 @@ def restore_single_parameters(model_dir: Path) -> Tuple[Config, FrozenDict]:
"""Load the config and parameters of a single model
"""
model_config = parse_config(Path(model_dir) / "config.yaml")
ckpt_dir = model_config.data.model_version_path()
ckpt_dir = model_config.data.model_version_path
return model_config, load_params(ckpt_dir)


31 changes: 17 additions & 14 deletions apax/train/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import os
from pathlib import Path
import sys
from typing import List

import jax
@@ -32,7 +31,15 @@ def setup_logging(log_file, log_level):
while len(logging.root.handlers) > 0:
logging.root.removeHandler(logging.root.handlers[-1])

logging.basicConfig(filename=log_file, level=log_levels[log_level])
# Remove uninformative checkpointing absl logs
logging.getLogger("absl").setLevel(logging.WARNING)

logging.basicConfig(
level=log_levels[log_level],
format="%(levelname)s | %(asctime)s | %(message)s",
datefmt="%H:%M:%S",
handlers=[logging.FileHandler(log_file), logging.StreamHandler(sys.stderr)],
)


def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection:
@@ -43,26 +50,22 @@ def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection:
return LossCollection(loss_funcs)


def run(user_config, log_file="train.log", log_level="error"):
setup_logging(log_file, log_level)
log.info("Loading user config")
def run(user_config, log_level="error"):
config = parse_config(user_config)

seed_py_np_tf(config.seed)
rng_key = jax.random.PRNGKey(config.seed)

experiment = Path(config.data.experiment)
directory = Path(config.data.directory)
model_version_path = directory / experiment
log.info("Initializing directories")
model_version_path.mkdir(parents=True, exist_ok=True)
config.dump_config(model_version_path)
config.data.model_version_path.mkdir(parents=True, exist_ok=True)
setup_logging(config.data.model_version_path / "train.log", log_level)
config.dump_config(config.data.model_version_path)

callbacks = initialize_callbacks(config.callbacks, model_version_path)
callbacks = initialize_callbacks(config.callbacks, config.data.model_version_path)
loss_fn = initialize_loss_fn(config.loss)
Metrics = initialize_metrics(config.metrics)

train_raw_ds, val_raw_ds = load_data_files(config.data, model_version_path)
train_raw_ds, val_raw_ds = load_data_files(config.data)
train_ds, ds_stats = initialize_dataset(config, train_raw_ds)
val_ds = initialize_dataset(config, val_raw_ds, calc_stats=False)

@@ -112,7 +115,7 @@ def run(user_config, log_file="train.log", log_level="error"):
Metrics,
callbacks,
n_epochs,
ckpt_dir=os.path.join(config.data.directory, config.data.experiment),
ckpt_dir=config.data.model_version_path,
ckpt_interval=config.checkpoints.ckpt_interval,
val_ds=val_ds,
sam_rho=config.optimizer.sam_rho,
4 changes: 2 additions & 2 deletions apax/train/trainer.py
Original file line number Diff line number Diff line change
@@ -31,8 +31,8 @@ def fit(
log.info("Beginning Training")
callbacks.on_train_begin()

latest_dir = ckpt_dir + "/latest"
best_dir = ckpt_dir + "/best"
latest_dir = ckpt_dir / "latest"
best_dir = ckpt_dir / "best"
ckpt_manager = CheckpointManager()

train_step, val_step = make_step_fns(
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -147,6 +147,6 @@ def load_and_dump_config(config_path, dump_path):
model_config_dict["data"]["directory"] = dump_path.as_posix()

model_config = Config.model_validate(model_config_dict)
os.makedirs(model_config.data.model_version_path(), exist_ok=True)
model_config.dump_config(model_config.data.model_version_path())
os.makedirs(model_config.data.model_version_path, exist_ok=True)
model_config.dump_config(model_config.data.model_version_path)
return model_config
4 changes: 2 additions & 2 deletions tests/integration_tests/bal/test_api.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ def test_kernel_selection(example_atoms, get_tmp_path, get_sample_input):
_, params = initialize_model(model_config, inputs)

ckpt = {"model": {"params": params}, "epoch": 0}
best_dir = model_config.data.best_model_path()
best_dir = model_config.data.best_model_path
checkpoints.save_checkpoint(
ckpt_dir=best_dir,
target=ckpt,
@@ -41,7 +41,7 @@ def test_kernel_selection(example_atoms, get_tmp_path, get_sample_input):
bs = 5

selected_indices = kernel_selection(
model_config.data.model_version_path(),
model_config.data.model_version_path,
train_atoms,
pool_atoms,
base_fm_options,
18 changes: 7 additions & 11 deletions tests/integration_tests/md/test_md.py
Original file line number Diff line number Diff line change
@@ -34,8 +34,8 @@ def test_run_md(get_tmp_path):
md_config_dict["initial_structure"] = get_tmp_path.as_posix() + "/atoms.extxyz"

model_config = Config.model_validate(model_config_dict)
os.makedirs(model_config.data.model_version_path())
model_config.dump_config(model_config.data.model_version_path())
os.makedirs(model_config.data.model_version_path)
model_config.dump_config(model_config.data.model_version_path)
md_config = MDConfig.model_validate(md_config_dict)

positions = jnp.array(
@@ -80,11 +80,8 @@ def test_run_md(get_tmp_path):
)

ckpt = {"model": {"params": params}, "epoch": 0}
best_dir = os.path.join(
model_config.data.directory, model_config.data.experiment, "best"
)
checkpoints.save_checkpoint(
ckpt_dir=best_dir,
ckpt_dir=model_config.data.best_model_path,
target=ckpt,
step=0,
overwrite=True,
@@ -106,8 +103,8 @@ def test_ase_calc(get_tmp_path):
model_config_dict["data"]["directory"] = get_tmp_path.as_posix()

model_config = Config.model_validate(model_config_dict)
os.makedirs(model_config.data.model_version_path(), exist_ok=True)
model_config.dump_config(model_config.data.model_version_path())
os.makedirs(model_config.data.model_version_path, exist_ok=True)
model_config.dump_config(model_config.data.model_version_path)

cell_size = 10.0
positions = np.array(
@@ -147,17 +144,16 @@ def test_ase_calc(get_tmp_path):
)
ckpt = {"model": {"params": params}, "epoch": 0}

best_dir = model_config.data.best_model_path()
checkpoints.save_checkpoint(
ckpt_dir=best_dir,
ckpt_dir=model_config.data.best_model_path,
target=ckpt,
step=0,
overwrite=True,
)

atoms = read(initial_structure_path.as_posix())
calc = ASECalculator(
[model_config.data.model_version_path(), model_config.data.model_version_path()]
[model_config.data.model_version_path, model_config.data.model_version_path]
)

atoms.calc = calc