Commit

Merge branch 'main' into main
lbluque authored Jun 25, 2024
2 parents 3fa5775 + c001d79 commit 8e3de09
Showing 19 changed files with 168 additions and 63 deletions.
45 changes: 45 additions & 0 deletions .github/workflows/integration-test.yml
@@ -0,0 +1,45 @@
name: integration-test

on:
pull_request:
paths: # this will only trigger on changes to the demo directory, add more here if required
- 'src/fairchem/demo/**'
workflow_call:

jobs:
test:
runs-on: ubuntu-latest # TODO add macos tests too
strategy:
max-parallel: 10
matrix:
python_version: ['3.9', '3.11']

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python_version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}

- name: Cache pip
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
${{ runner.os }}-
- name: Install core dependencies and package
run: |
python -m pip install --upgrade pip
if [ -f packages/requirements.txt ]; then pip install -r packages/requirements.txt; fi
if [ -f packages/requirements-optional.txt ]; then pip install -r packages/requirements-optional.txt; fi
pip install -e packages/fairchem-core[dev]
pip install -e packages/fairchem-data-oc[dev]
pip install -e packages/fairchem-demo-ocpapi[dev]
pip install -e packages/fairchem-applications-cattsunami
- name: Integration tests
run: | # skip-ocpapi-integration skips expensive tests with the tag "@pytest.mark.ocpapi_integration_test"
pytest tests/demo/ocpapi/tests/integration/ --skip-ocpapi-integration -c ./packages/fairchem-core/pyproject.toml
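The --skip-ocpapi-integration flag used in the step above deselects tests carrying the @pytest.mark.ocpapi_integration_test marker. A minimal sketch of how such a flag can be wired up in a conftest.py is shown below; the hooks are standard pytest, but the wiring itself is an assumption for illustration, not necessarily the repository's actual implementation.

    # conftest.py (illustrative sketch, not the repository's actual file)
    import pytest

    def pytest_addoption(parser):
        # Register the command-line flag passed by the workflow above.
        parser.addoption(
            "--skip-ocpapi-integration",
            action="store_true",
            default=False,
            help="Skip expensive ocpapi integration tests.",
        )

    def pytest_collection_modifyitems(config, items):
        if not config.getoption("--skip-ocpapi-integration"):
            return
        skip = pytest.mark.skip(reason="--skip-ocpapi-integration was passed")
        for item in items:
            if "ocpapi_integration_test" in item.keywords:
                item.add_marker(skip)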
6 changes: 4 additions & 2 deletions .github/workflows/test.yml
@@ -1,8 +1,10 @@
name: tests

on:
push:
workflow_call:
pull_request:
branches: [main]
workflow_dispatch:

jobs:
test:
@@ -40,7 +42,7 @@ jobs:
- name: Test core with pytest
run: |
pytest tests -vv --skip-ocpapi-integration --cov-report=xml --cov=fairchem -c ./packages/fairchem-core/pyproject.toml
pytest tests -vv --ignore=tests/demo/ocpapi/tests/integration/ --cov-report=xml --cov=fairchem -c ./packages/fairchem-core/pyproject.toml
- if: ${{ matrix.python_version == '3.11' }}
name: codecov-report
2 changes: 1 addition & 1 deletion packages/env.cpu.yml
@@ -12,7 +12,7 @@ dependencies:
- pytorch-cluster
- ase
- e3nn>=0.5
- numpy>=1.25.0
- numpy >=1.25.0,<2.0.0
- pymatgen>=2023.10.3
- numba
- orjson
2 changes: 1 addition & 1 deletion packages/env.gpu.yml
@@ -13,7 +13,7 @@ dependencies:
- pyg
- ase
- e3nn>=0.5
- numpy>=1.25.0
- numpy >=1.25.0,<2.0.0
- pymatgen>=2023.10.3
- numba
- orjson
1 change: 0 additions & 1 deletion packages/fairchem-applications-cattsunami/pyproject.toml
@@ -9,7 +9,6 @@ description = "Accelerating Transition State Energy Calculations with Pre-traine
license = {text = "MIT License"}
dependencies = [
"torch>=2.2",
"numpy>=1.25.0",
"scipy",
"ase",
"networkx",
2 changes: 1 addition & 1 deletion packages/fairchem-core/pyproject.toml
@@ -10,7 +10,7 @@ dynamic = ["version", "readme"]
requires-python = ">=3.9, <3.13"
dependencies = [
"torch>=2.2",
"numpy>=1.25.0",
"numpy >=1.25.0, <2.0.0",
"lmdb",
"ase",
"pymatgen>=2023.10.3",
2 changes: 1 addition & 1 deletion packages/fairchem-data-oc/pyproject.toml
@@ -8,7 +8,7 @@ dynamic = ["version", "readme"]
description = "Code for generating adsorbate-catalyst input configurations"
license = {text = "MIT License"}
dependencies = [
"numpy>=1.25.0",
"numpy >=1.25.0, <2.0.0",
"scipy",
"matplotlib",
"ase", # this was pinned to 3.22.1
2 changes: 1 addition & 1 deletion packages/fairchem-demo-ocpapi/pyproject.toml
@@ -32,7 +32,7 @@ documentation = "https://fair-chem.github.io/"

[project.optional-dependencies]
dev = [
"ase == 3.22.1",
"ase",
"readchar == 4.0.5",
]

6 changes: 3 additions & 3 deletions src/fairchem/applications/cattsunami/core/ocpneb.py
@@ -96,11 +96,11 @@ def __init__(
del config["task"]["relax_dataset"]

self.trainer = registry.get_trainer_class(config.get("trainer", "ocp"))(
task=config["task"],
task=config.get("task", {}),
model=config["model"],
outputs={},
loss_fns={},
eval_metrics={},
loss_functions={},
evaluation_metrics={},
dataset=[config["dataset"]],
optimizer=config["optim"],
identifier="",
26 changes: 24 additions & 2 deletions src/fairchem/core/_cli.py
@@ -13,6 +13,7 @@

from submitit import AutoExecutor
from submitit.helpers import Checkpointable, DelayedSubmission
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from fairchem.core.common.flags import flags
from fairchem.core.common.utils import (
@@ -50,6 +51,10 @@ def checkpoint(self, *args, **kwargs):
return DelayedSubmission(new_runner, self.config)


def runner_wrapper(distributed: bool, config: dict):
Runner(distributed=distributed)(config)


def main():
"""Run the main fairchem program."""
setup_logging()
@@ -85,5 +90,22 @@ def main():
log_file = save_experiment_log(args, jobs, configs)
logging.info(f"Experiment log saved to: {log_file}")

else: # Run locally
Runner()(config)
else: # Run locally on a single node, n-processes
if args.distributed:
logging.info(f"Running in distributed local mode with {args.num_gpus} ranks")
# HACK to disable multiprocess dataloading in local mode
# there is an open issue where LMDB's environment cannot be pickled and used
# during torch multiprocessing https://github.com/pytorch/examples/issues/526
if "optim" in config and "num_workers" in config["optim"]:
config["optim"]["num_workers"] = 0
logging.info("WARNING: running in local mode, setting dataloading num_workers to 0, see https://github.com/pytorch/examples/issues/526")

launch_config = LaunchConfig(min_nodes=1, max_nodes=1, nproc_per_node=args.num_gpus, rdzv_backend="c10d", max_restarts=0)
elastic_launch(launch_config, runner_wrapper)(args.distributed, config)
else:
logging.info("Running in non-distributed local mode")
assert args.num_gpus == 1, "Can only run with a single gpu in non distributed local mode, use --distributed flag instead if using >1 gpu"
runner_wrapper(args.distributed, config)

if __name__ == "__main__":
main()
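The new local-launch path hands runner_wrapper to torch's elastic_launch, which spawns nproc_per_node worker processes; the entrypoint therefore has to be a picklable module-level callable rather than a lambda or bound method, which is why runner_wrapper was added. Below is a self-contained sketch of the same pattern, with a hypothetical train_one_rank entrypoint standing in for Runner; it mirrors the LaunchConfig arguments from the diff above.

    # Illustrative sketch of the elastic_launch pattern used in _cli.py above.
    import os

    from torch.distributed.launcher.api import LaunchConfig, elastic_launch

    def train_one_rank(message: str) -> None:
        # Each spawned worker gets its own LOCAL_RANK / RANK environment variables.
        print(f"{message}: local rank {os.environ.get('LOCAL_RANK')}")

    if __name__ == "__main__":
        launch_config = LaunchConfig(
            min_nodes=1,
            max_nodes=1,
            nproc_per_node=2,        # number of local workers, e.g. GPUs
            rdzv_backend="c10d",     # same rendezvous backend as in the change above
            max_restarts=0,
        )
        # elastic_launch(config, entrypoint) returns a callable; its arguments are
        # forwarded to every spawned copy of the entrypoint.
        elastic_launch(launch_config, train_one_rank)("hello from")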
9 changes: 2 additions & 7 deletions src/fairchem/core/common/distutils.py
@@ -97,13 +97,8 @@ def setup(config) -> None:
init_method="env://",
)
else:
# try to read local rank from environment for newer torchrun
# otherwise use local-rank arg for torch.distributed
config["local_rank"] = os.environ.get("LOCAL_RANK", config["local_rank"])
dist.init_process_group(
backend=config["distributed_backend"], init_method="env://"
)
# TODO: SLURM
config["local_rank"] = int(os.environ.get("LOCAL_RANK", config["local_rank"]))
dist.init_process_group(backend="nccl")


def cleanup() -> None:
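The simplified setup() above relies on the launcher (torchrun, or the elastic_launch path added in _cli.py) having already exported the rendezvous environment variables that the default env:// init method reads. A short sketch of what each worker process sees, with illustrative values only:

    # What a worker relies on when dist.init_process_group(backend="nccl") is called
    # with the default env:// init method. These variables are set by torchrun /
    # elastic_launch; this snippet only inspects them and requires a CUDA device.
    import os

    import torch
    import torch.distributed as dist

    print({k: os.environ.get(k) for k in
           ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE", "LOCAL_RANK"]})

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)         # bind this process to one GPU
    dist.init_process_group(backend="nccl")   # reads MASTER_ADDR/PORT, RANK, WORLD_SIZE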
18 changes: 5 additions & 13 deletions src/fairchem/core/common/relaxation/ase_utils.py
@@ -161,33 +161,27 @@ def __init__(

# for checkpoints with relaxation datasets defined, remove to avoid
# unnecesarily trying to load that dataset
if "relax_dataset" in config["task"]:
if "relax_dataset" in config.get("task", {}):
del config["task"]["relax_dataset"]

# Calculate the edge indices on the fly
config["model"]["otf_graph"] = True

### backwards compatability with OCP v<2.0
### TODO: better format check for older configs
### Taken from base_trainer
if not config.get("loss_fns"):
logging.warning(
"Detected old config, converting to new format. Consider updating to avoid potential incompatibilities."
)
config = update_config(config)
config = update_config(config)

# Save config so obj can be transported over network (pkl)
self.config = copy.deepcopy(config)
self.config["checkpoint"] = checkpoint_path
del config["dataset"]["src"]

self.trainer = registry.get_trainer_class(config["trainer"])(
task=config["task"],
task=config.get("task", {}),
model=config["model"],
dataset=[config["dataset"]],
outputs=config["outputs"],
loss_fns=config["loss_fns"],
eval_metrics=config["eval_metrics"],
loss_functions=config["loss_functions"],
evaluation_metrics=config["evaluation_metrics"],
optimizer=config["optim"],
identifier="",
slurm=config.get("slurm", {}),
@@ -228,8 +222,6 @@ def load_checkpoint(
checkpoint_path: string
Path to trained model
"""
if checkpoint is None:
checkpoint = {}
try:
self.trainer.load_checkpoint(checkpoint_path, checkpoint)
except NotImplementedError:
28 changes: 23 additions & 5 deletions src/fairchem/core/common/utils.py
@@ -995,8 +995,8 @@ class _TrainingContext:
outputs=config.get("outputs", {}),
dataset=config["dataset"],
optimizer=config["optim"],
loss_fns=config.get("loss_functions", {}),
eval_metrics=config.get("evaluation_metrics", {}),
loss_functions=config.get("loss_functions", {}),
evaluation_metrics=config.get("evaluation_metrics", {}),
identifier=config["identifier"],
timestamp_id=config.get("timestamp_id", None),
run_dir=config.get("run_dir", "./"),
@@ -1176,11 +1176,29 @@ def irreps_sum(ang_mom: int) -> int:

def update_config(base_config):
"""
Configs created prior to OCP 2.0 are organized a little different than they
Configs created prior to FAIRChem/OCP 2.0 are organized a little different than they
are now. Update old configs to fit the new expected structure.
"""
### TODO: better format check for older configs
# some configs have a loss_functions key with an empty dictionary, those need to be updated as well
if len(base_config.get("loss_functions", {})) > 0:
return base_config

logging.warning(
"Detected old config, converting to new format. Consider updating to avoid potential incompatibilities."
)

# do we need a copy?
config = copy.deepcopy(base_config)

# initial fairchem/ocp 2.0 configs renamed loss_functions -> loss_fns and evaluation_metrics -> eval_metrics
# so some checkpoints may have configs in new format with the exception of renamed loss_funs and eval_metrics
if "loss_fns" in config:
config["loss_functions"] = config.pop("loss_fns")
if "eval_metrics" in config:
config["evaluation_metrics"] = config.pop("eval_metrics")
return config

# If config["dataset"]["format"] is missing, get it from the task (legacy location).
# If it is not there either, default to LMDB.
config["dataset"]["format"] = config["dataset"].get(
@@ -1289,8 +1307,8 @@ def update_config(base_config):
config["dataset"]["transforms"] = transforms

### Update config
config.update({"loss_fns": _loss_fns})
config.update({"eval_metrics": _eval_metrics})
config.update({"loss_functions": _loss_fns})
config.update({"evaluation_metrics": _eval_metrics})
config.update({"outputs": _outputs})

return config
23 changes: 9 additions & 14 deletions src/fairchem/core/trainers/base_trainer.py
@@ -59,8 +59,8 @@ def __init__(
outputs,
dataset,
optimizer,
loss_fns,
eval_metrics,
loss_functions,
evaluation_metrics,
identifier: str,
timestamp_id: str | None = None,
run_dir: str | None = None,
@@ -109,8 +109,8 @@ def __init__(
"model_attributes": model,
"outputs": outputs,
"optim": optimizer,
"loss_fns": loss_fns,
"eval_metrics": eval_metrics,
"loss_functions": loss_functions,
"evaluation_metrics": evaluation_metrics,
"logger": logger,
"amp": amp,
"gpus": distutils.get_world_size() if not self.cpu else 0,
@@ -169,12 +169,7 @@ def __init__(
os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True)

### backwards compatability with OCP v<2.0
### TODO: better format check for older configs
if not self.config.get("loss_fns"):
logging.warning(
"Detected old config, converting to new format. Consider updating to avoid potential incompatibilities."
)
self.config = update_config(self.config)
self.config = update_config(self.config)

if distutils.is_master():
logging.info(yaml.dump(self.config, default_flow_style=False))
@@ -398,7 +393,7 @@ def load_task(self):
)

# TODO: Assert that all targets, loss fn, metrics defined are consistent
self.evaluation_metrics = self.config.get("eval_metrics", {})
self.evaluation_metrics = self.config.get("evaluation_metrics", {})
self.evaluator = Evaluator(
task=self.name,
eval_metrics=self.evaluation_metrics.get(
@@ -523,8 +518,8 @@ def load_checkpoint(
self.scaler.load_state_dict(checkpoint["amp"])

def load_loss(self) -> None:
self.loss_fns = []
for _idx, loss in enumerate(self.config["loss_fns"]):
self.loss_functions = []
for _idx, loss in enumerate(self.config["loss_functions"]):
for target in loss:
loss_name = loss[target].get("fn", "mae")
coefficient = loss[target].get("coefficient", 1)
@@ -539,7 +534,7 @@ def load_loss(self) -> None:

loss_fn = DDPLoss(loss_fn, loss_name, loss_reduction)

self.loss_fns.append(
self.loss_functions.append(
(target, {"fn": loss_fn, "coefficient": coefficient})
)

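For reference, the loop in load_loss above implies that the loss_functions entry is a list of single-target dicts of the form {target: {"fn": ..., "coefficient": ...}}. A hypothetical example with made-up targets and coefficients:

    # Illustrative shape of config["loss_functions"] as consumed by load_loss above.
    loss_functions = [
        {"energy": {"fn": "mae", "coefficient": 1}},
        {"forces": {"fn": "l2mae", "coefficient": 100}},
    ]

    for loss in loss_functions:
        for target, spec in loss.items():
            print(target, spec.get("fn", "mae"), spec.get("coefficient", 1))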