reorg code and first implementation of the new easy API (#96)
siddharth9820 authored Oct 18, 2024
1 parent 879ae57 commit 0b6328c
Showing 31 changed files with 771 additions and 379 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/nvidia-rtx-3090-tests.yaml
@@ -29,7 +29,7 @@ jobs:
           export G_inter=${{ matrix.ginter }}
           export G_data=$(( 2 / G_inter ))
           echo "training with G_inter = ${G_inter}, G_data = $(( 2 / G_inter )) ${{ matrix.memopt }}"
-          mpirun -n 2 pytest --with-mpi ./axonn/tests/test_vit.py
+          PYTHONPATH="." mpirun -n 2 pytest --with-mpi ./axonn/tests/test_vit.py
       - name: Uninstall AxoNN
         run: |
           pip uninstall --yes axonn
@@ -46,13 +46,13 @@ jobs:
       - name: Run intra-layer FC unit tests
         run: |
           torchrun --nproc_per_node 2 --no_python python -m pytest ./axonn/tests/test_intra_layer_fc.py
-      - name: Run intra-layer Conv unit tests
-        run: |
-          torchrun --nproc_per_node 2 --no_python python -m pytest ./axonn/tests/test_intra_layer_conv.py
-      - name: Run intra-layer Embedding unit tests
-        run: |
-          torchrun --nproc_per_node 2 --no_python python -m pytest ./axonn/tests/test_intra_layer_emb.py -k bw_pass
-          torchrun --nproc_per_node 2 --no_python python -m pytest ./axonn/tests/test_intra_layer_emb.py -k fw_pass
+      #- name: Run intra-layer Conv unit tests
+      #run: |
+      #torchrun --nproc_per_node 2 --no_python python -m pytest ./axonn/tests/test_intra_layer_conv.py
+      #- name: Run intra-layer Embedding unit tests
+      #run: |
+      #torchrun --nproc_per_node 2 --no_python python -m pytest ./axonn/tests/test_intra_layer_emb.py -k bw_pass
+      #torchrun --nproc_per_node 2 --no_python python -m pytest ./axonn/tests/test_intra_layer_emb.py -k fw_pass
       - name: Uninstall AxoNN
         run: |
           pip uninstall --yes axonn
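A minimal sketch of the arithmetic this workflow encodes: CI runs on 2 GPUs, so the data-parallel degree is whatever remains after the inter-layer degree is chosen, and the two degrees must factor the GPU count. The names below are illustrative, not AxoNN API.

# Sketch (not AxoNN API): the CI matrix exercises both factorizations of 2 GPUs.
num_gpus = 2
for g_inter in (1, 2):  # the matrix.ginter values
    g_data = num_gpus // g_inter
    assert g_inter * g_data == num_gpus, "parallel degrees must factor the GPU count"
    print(f"G_inter={g_inter}, G_data={g_data}")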
3 changes: 1 addition & 2 deletions axonn/__init__.py
@@ -1,5 +1,4 @@
-# Copyright 2021 Parallel Software and Systems Group, University of Maryland.
+# Copyright 2021-2024 Parallel Software and Systems Group, University of Maryland.
# See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# from . import models  # noqa: F401
10 changes: 1 addition & 9 deletions axonn/axonn.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Parallel Software and Systems Group, University of Maryland.
+# Copyright 2021-2024 Parallel Software and Systems Group, University of Maryland.
# See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -9,14 +9,6 @@
from .communication import communication_handle
import torch

-try:
-    import mpi4py
-
-    MPI4PY = True
-    mpi4py.rc.initialize = False  # do not initialize MPI automatically
-except ImportError:
-    MPI4PY = False
-
# True when init has been called
is_initialized = False
# Communication handle for point-to-point (MPI) and collective (NCCL) communication
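The block deleted here is the guarded-import pattern for treating mpi4py as an optional dependency; per the next file, the same guard remains in axonn/communication.py, so it now lives in one place. For reference, a self-contained version of the pattern:

# Optional-dependency guard, as kept in axonn/communication.py: import
# mpi4py if available, but stop it from calling MPI_Init at import time
# so MPI is only initialized when actually needed.
try:
    import mpi4py

    mpi4py.rc.initialize = False  # do not initialize MPI automatically
    MPI4PY = True
except ImportError:
    MPI4PY = False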
2 changes: 1 addition & 1 deletion axonn/checkpoint.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Parallel Software and Systems Group, University of Maryland.
+# Copyright 2022-2024 Parallel Software and Systems Group, University of Maryland.
# See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 changes: 2 additions & 3 deletions axonn/communication.py
@@ -1,12 +1,11 @@
-# Copyright 2021 Parallel Software and Systems Group, University of Maryland.
+# Copyright 2021-2024 Parallel Software and Systems Group, University of Maryland.
# See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os

try:
-    # from mpi4py import MPI
    import mpi4py

    MPI4PY = True
@@ -112,7 +111,7 @@ def __init__(
        if not torch.distributed.is_initialized():
            init_method = "tcp://"
            master_ip = os.getenv("MASTER_ADDR", "localhost")
-            master_port = os.getenv("MASTER_PORT", "6000")
+            master_port = os.getenv("MASTER_PORT", "29500")
            init_method += master_ip + ":" + master_port
            torch.distributed.init_process_group(
                backend="nccl",
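The new default aligns AxoNN's TCP fallback with torchrun, whose default rendezvous port is 29500, so a single-node run launched without MASTER_PORT set resolves to the same endpoint either way. A simplified, self-contained sketch of the fallback (the function name is illustrative, not AxoNN API):

import os
import torch.distributed as dist

def init_process_group_fallback(rank: int, world_size: int) -> None:
    # Mirrors the fallback above: honor torchrun-style env variables,
    # defaulting to localhost:29500 (torchrun's own default port).
    master_ip = os.getenv("MASTER_ADDR", "localhost")
    master_port = os.getenv("MASTER_PORT", "29500")
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://" + master_ip + ":" + master_port,
        rank=rank,
        world_size=world_size,
    )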
2 changes: 1 addition & 1 deletion axonn/config.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Parallel Software and Systems Group, University of Maryland.
+# Copyright 2021-2024 Parallel Software and Systems Group, University of Maryland.
# See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
16 changes: 10 additions & 6 deletions axonn/inter_layer.py
@@ -1,11 +1,16 @@
# Copyright 2021-2024 Parallel Software and Systems Group, University of Maryland.
# See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# from . import models # noqa: F401


from enum import Enum
from dataclasses import dataclass
from axonn import axonn as ax
from mpi4py import MPI
-from axonn.intra_layer import (
-    sync_gradients_data_parallel,
-    sync_gradients_depth_parallel,
-)
+from axonn.intra_layer import sync_gradients

import torch
import numpy as np

@@ -418,8 +423,7 @@ def forward_backward_optimizer(
assert not eval_mode
post_bw_hook(self.model)

-        sync_gradients_depth_parallel(self.model, mean=True)
-        sync_gradients_data_parallel(self.model, mean=True)
+        sync_gradients(self.model, mean=True, expert_mode=True)
if self.computation_dtype == torch.float16:
global_overflow = self._unscale_gradients()
if not global_overflow:
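This hunk is the user-visible core of the "easy API" named in the commit title: one sync_gradients call replaces the separate depth-parallel and data-parallel gradient synchronizations. A hedged sketch of a call site using the new import, mirroring the engine's usage above (the training-loop names are placeholders, not AxoNN API):

from axonn.intra_layer import sync_gradients

def training_step(model, batch, optimizer):
    # Placeholder step; only the sync_gradients call is taken from the diff.
    loss = model(batch).mean()
    loss.backward()
    # One call covers what sync_gradients_depth_parallel +
    # sync_gradients_data_parallel did before; expert_mode=True opts into
    # the engine-level path used by forward_backward_optimizer above.
    sync_gradients(model, mean=True, expert_mode=True)
    optimizer.step()
    optimizer.zero_grad()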