Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Absl logging fix #197

Merged
merged 10 commits (source and target branch names not captured in this page scrape) on
Nov 14, 2023
5 changes: 2 additions & 3 deletions apax/cli/apax_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,14 @@ def train(
train_config_path: Path = typer.Argument(
..., help="Training configuration YAML file."
),
log_level: str = typer.Option("error", help="Sets the training logging level."),
log_file: str = typer.Option("train.log", help="Specifies the name of the log file"),
log_level: str = typer.Option("info", help="Sets the training logging level."),
):
"""
Starts the training of a model with parameters provided by a configuration file.
"""
from apax.train.run import run

run(train_config_path, log_file, log_level)
run(train_config_path, log_level)


@app.command()
Expand Down
1 change: 0 additions & 1 deletion apax/config/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def parse_config(config: Union[str, os.PathLike, dict], mode: str = "train") ->
config: Path to the config file or a dictionary
containing the config.
"""
log.info("Loading user config")
if isinstance(config, (str, os.PathLike)):
with open(config, "r") as stream:
config = yaml.safe_load(stream)
Expand Down
4 changes: 3 additions & 1 deletion apax/config/train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,14 @@ def validate_shift_scale_methods(self):

return self

@property
def model_version_path(self):
version_path = Path(self.directory) / self.experiment
return version_path

@property
def best_model_path(self):
return self.model_version_path() / "best"
return self.model_version_path / "best"


class ModelConfig(BaseModel, extra="forbid"):
Expand Down
5 changes: 2 additions & 3 deletions apax/data/initialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import dataclasses
import logging
import os
from typing import Optional

import numpy as np
Expand All @@ -19,7 +18,7 @@ class RawDataset:
additional_labels: Optional[dict] = None


def load_data_files(data_config, model_version_path):
def load_data_files(data_config):
log.info("Running Input Pipeline")
if data_config.data_path is not None:
log.info(f"Read data file {data_config.data_path}")
Expand All @@ -32,7 +31,7 @@ def load_data_files(data_config, model_version_path):
train_label_dict, val_label_dict = split_label(label_dict, train_idxs, val_idxs)

np.savez(
os.path.join(model_version_path, "train_val_idxs"),
data_config.model_version_path / "train_val_idxs",
train_idxs=train_idxs,
val_idxs=val_idxs,
)
Expand Down
2 changes: 1 addition & 1 deletion apax/md/nvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def md_setup(model_config: Config, md_config: MDConfig):
disable_cell_list=True,
)

_, params = restore_parameters(model_config.data.model_version_path())
_, params = restore_parameters(model_config.data.model_version_path)
params = canonicalize_energy_model_parameters(params)
energy_fn = create_energy_fn(
model.apply, params, system.atomic_numbers, system.box, model_config.n_models
Expand Down
3 changes: 2 additions & 1 deletion apax/train/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def stack_parameters(param_list: List[FrozenDict]) -> FrozenDict:


def load_params(model_version_path: Path, best=True) -> FrozenDict:
model_version_path = Path(model_version_path)
if best:
model_version_path = model_version_path / "best"
log.info(f"loading checkpoint from {model_version_path}")
Expand All @@ -142,7 +143,7 @@ def restore_single_parameters(model_dir: Path) -> Tuple[Config, FrozenDict]:
"""Load the config and parameters of a single model
"""
model_config = parse_config(Path(model_dir) / "config.yaml")
ckpt_dir = model_config.data.model_version_path()
ckpt_dir = model_config.data.model_version_path
return model_config, load_params(ckpt_dir)


Expand Down
31 changes: 17 additions & 14 deletions apax/train/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import os
from pathlib import Path
import sys
from typing import List

import jax
Expand Down Expand Up @@ -32,7 +31,15 @@ def setup_logging(log_file, log_level):
while len(logging.root.handlers) > 0:
logging.root.removeHandler(logging.root.handlers[-1])

logging.basicConfig(filename=log_file, level=log_levels[log_level])
# Remove uninformative checkpointing absl logs
logging.getLogger("absl").setLevel(logging.WARNING)

logging.basicConfig(
level=log_levels[log_level],
format="%(levelname)s | %(asctime)s | %(message)s",
datefmt="%H:%M:%S",
handlers=[logging.FileHandler(log_file), logging.StreamHandler(sys.stderr)],
)


def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection:
Expand All @@ -43,26 +50,22 @@ def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection:
return LossCollection(loss_funcs)


def run(user_config, log_file="train.log", log_level="error"):
setup_logging(log_file, log_level)
log.info("Loading user config")
def run(user_config, log_level="error"):
config = parse_config(user_config)

seed_py_np_tf(config.seed)
rng_key = jax.random.PRNGKey(config.seed)

experiment = Path(config.data.experiment)
directory = Path(config.data.directory)
model_version_path = directory / experiment
log.info("Initializing directories")
model_version_path.mkdir(parents=True, exist_ok=True)
config.dump_config(model_version_path)
config.data.model_version_path.mkdir(parents=True, exist_ok=True)
setup_logging(config.data.model_version_path / "train.log", log_level)
config.dump_config(config.data.model_version_path)

callbacks = initialize_callbacks(config.callbacks, model_version_path)
callbacks = initialize_callbacks(config.callbacks, config.data.model_version_path)
loss_fn = initialize_loss_fn(config.loss)
Metrics = initialize_metrics(config.metrics)

train_raw_ds, val_raw_ds = load_data_files(config.data, model_version_path)
train_raw_ds, val_raw_ds = load_data_files(config.data)
train_ds, ds_stats = initialize_dataset(config, train_raw_ds)
val_ds = initialize_dataset(config, val_raw_ds, calc_stats=False)

Expand Down Expand Up @@ -112,7 +115,7 @@ def run(user_config, log_file="train.log", log_level="error"):
Metrics,
callbacks,
n_epochs,
ckpt_dir=os.path.join(config.data.directory, config.data.experiment),
ckpt_dir=config.data.model_version_path,
ckpt_interval=config.checkpoints.ckpt_interval,
val_ds=val_ds,
sam_rho=config.optimizer.sam_rho,
Expand Down
4 changes: 2 additions & 2 deletions apax/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def fit(
log.info("Beginning Training")
callbacks.on_train_begin()

latest_dir = ckpt_dir + "/latest"
best_dir = ckpt_dir + "/best"
latest_dir = ckpt_dir / "latest"
best_dir = ckpt_dir / "best"
ckpt_manager = CheckpointManager()

train_step, val_step = make_step_fns(
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,6 @@ def load_and_dump_config(config_path, dump_path):
model_config_dict["data"]["directory"] = dump_path.as_posix()

model_config = Config.model_validate(model_config_dict)
os.makedirs(model_config.data.model_version_path(), exist_ok=True)
model_config.dump_config(model_config.data.model_version_path())
os.makedirs(model_config.data.model_version_path, exist_ok=True)
model_config.dump_config(model_config.data.model_version_path)
return model_config
4 changes: 2 additions & 2 deletions tests/integration_tests/bal/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_kernel_selection(example_atoms, get_tmp_path, get_sample_input):
_, params = initialize_model(model_config, inputs)

ckpt = {"model": {"params": params}, "epoch": 0}
best_dir = model_config.data.best_model_path()
best_dir = model_config.data.best_model_path
checkpoints.save_checkpoint(
ckpt_dir=best_dir,
target=ckpt,
Expand All @@ -41,7 +41,7 @@ def test_kernel_selection(example_atoms, get_tmp_path, get_sample_input):
bs = 5

selected_indices = kernel_selection(
model_config.data.model_version_path(),
model_config.data.model_version_path,
train_atoms,
pool_atoms,
base_fm_options,
Expand Down
18 changes: 7 additions & 11 deletions tests/integration_tests/md/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def test_run_md(get_tmp_path):
md_config_dict["initial_structure"] = get_tmp_path.as_posix() + "/atoms.extxyz"

model_config = Config.model_validate(model_config_dict)
os.makedirs(model_config.data.model_version_path())
model_config.dump_config(model_config.data.model_version_path())
os.makedirs(model_config.data.model_version_path)
model_config.dump_config(model_config.data.model_version_path)
md_config = MDConfig.model_validate(md_config_dict)

positions = jnp.array(
Expand Down Expand Up @@ -80,11 +80,8 @@ def test_run_md(get_tmp_path):
)

ckpt = {"model": {"params": params}, "epoch": 0}
best_dir = os.path.join(
model_config.data.directory, model_config.data.experiment, "best"
)
checkpoints.save_checkpoint(
ckpt_dir=best_dir,
ckpt_dir=model_config.data.best_model_path,
target=ckpt,
step=0,
overwrite=True,
Expand All @@ -106,8 +103,8 @@ def test_ase_calc(get_tmp_path):
model_config_dict["data"]["directory"] = get_tmp_path.as_posix()

model_config = Config.model_validate(model_config_dict)
os.makedirs(model_config.data.model_version_path(), exist_ok=True)
model_config.dump_config(model_config.data.model_version_path())
os.makedirs(model_config.data.model_version_path, exist_ok=True)
model_config.dump_config(model_config.data.model_version_path)

cell_size = 10.0
positions = np.array(
Expand Down Expand Up @@ -147,17 +144,16 @@ def test_ase_calc(get_tmp_path):
)
ckpt = {"model": {"params": params}, "epoch": 0}

best_dir = model_config.data.best_model_path()
checkpoints.save_checkpoint(
ckpt_dir=best_dir,
ckpt_dir=model_config.data.best_model_path,
target=ckpt,
step=0,
overwrite=True,
)

atoms = read(initial_structure_path.as_posix())
calc = ASECalculator(
[model_config.data.model_version_path(), model_config.data.model_version_path()]
[model_config.data.model_version_path, model_config.data.model_version_path]
)

atoms.calc = calc
Expand Down