
Commit

make input and output buffer dtypes same
siddharth9820 committed Mar 3, 2024
1 parent 9dd4845 commit e212829
Showing 3 changed files with 20 additions and 7 deletions.
10 changes: 9 additions & 1 deletion axonn/intra_layer/__init__.py
@@ -43,11 +43,19 @@ def gather(
 OVERLAP_REDUCE_SCATTER = False
 OVERLAP_ALL_REDUCE = False
 ALL_GATHER_ITERATOR = None
-ALL_GATHER_DTYPE = torch.bfloat16
+ALL_GATHER_DTYPE = torch.float32
+REDUCE_SCATTER_DTYPE = torch.bfloat16
 handles = []
 pending_grad_accumulations = []
 weights_cache = {}
 
+def set_all_gather_dtype(dtype):
+    global ALL_GATHER_DTYPE
+    ALL_GATHER_DTYPE = dtype
+
+def set_reduce_scatter_dtype(dtype):
+    global REDUCE_SCATTER_DTYPE
+    REDUCE_SCATTER_DTYPE = dtype
 
 def register_handle(handle):
     # ToDo: This might be unnecesary since
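For reference, the two new setters simply rebind the module-level dtype globals that the communication routines read at call time. A minimal usage sketch (the call site below is illustrative, not part of this commit):

    import torch
    from axonn.intra_layer import set_all_gather_dtype, set_reduce_scatter_dtype

    # Gather weights in full precision, but reduce-scatter gradients in bfloat16
    # (an assumed configuration; pick whatever matches your training precision).
    set_all_gather_dtype(torch.float32)
    set_reduce_scatter_dtype(torch.bfloat16)
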
16 changes: 10 additions & 6 deletions axonn/intra_layer/communication.py
@@ -44,12 +44,13 @@ def _gather(input_, dim, process_group=None, cache=False):
     input_ = input_.contiguous()
     # Size and dimension.
     rank = dist.get_rank(process_group)
 
+    from axonn.intra_layer import ALL_GATHER_DTYPE
+
     tensor_list = [
-        torch.empty_like(input_) for _ in range(dist.get_world_size(process_group))
+        torch.empty_like(input_, dtype=ALL_GATHER_DTYPE) for _ in range(dist.get_world_size(process_group))
     ]
-    tensor_list[rank] = input_
-    dist.all_gather(tensor_list, input_, group=process_group)
+    dist.all_gather(tensor_list, input_.to(ALL_GATHER_DTYPE), group=process_group)
 
     # Note: torch.cat already creates a contiguous tensor.
     output = torch.cat(tensor_list, dim=dim).contiguous()
@@ -70,17 +71,20 @@ def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False):
     assert input_.shape[dim] % total_chunks == 0
     tensor_shape = list(input_.shape)
     tensor_shape[dim] //= total_chunks
+
+    from axonn.intra_layer import REDUCE_SCATTER_DTYPE
+
     output = torch.empty(
-        tensor_shape, dtype=input_.dtype, device=torch.cuda.current_device()
+        tensor_shape, dtype=REDUCE_SCATTER_DTYPE, device=torch.cuda.current_device()
     )
 
     if hasattr(torch.distributed, "reduce_scatter_tensor"):
         handle = torch.distributed.reduce_scatter_tensor(
-            output, input_, group=process_group, async_op=overlap_comm
+            output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
         )
     else:
         handle = torch.distributed._reduce_scatter_base(
-            output, input_, group=process_group, async_op=overlap_comm
+            output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
         )
 
     if overlap_comm:
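Both hunks follow the same pattern: the buffer handed to the collective and the buffer it writes into now share one explicitly chosen dtype, with the input cast up front rather than relying on whatever dtype it arrives in. A stripped-down, single-process sketch of the all-gather side (the gloo setup, tensor sizes, and variable names are illustrative, not from the repository):

    import os
    import torch
    import torch.distributed as dist

    ALL_GATHER_DTYPE = torch.float32  # stand-in for the module-level global

    # One-process "gloo" group, just enough to exercise the call path.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    x = torch.randn(4, 4).to(torch.bfloat16)  # e.g. bf16 activations/weights
    # Output buffers are allocated in the communication dtype, and the input is
    # cast to that same dtype, so both sides of all_gather agree.
    tensor_list = [
        torch.empty_like(x, dtype=ALL_GATHER_DTYPE)
        for _ in range(dist.get_world_size())
    ]
    dist.all_gather(tensor_list, x.to(ALL_GATHER_DTYPE))
    print(torch.cat(tensor_list, dim=0).dtype)  # torch.float32

    dist.destroy_process_group()
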
1 change: 1 addition & 0 deletions axonn/intra_layer/fully_connected.py
@@ -79,6 +79,7 @@ def forward(
         ctx.backward_comm_async = backward_comm_async
         if not forward_comm_async:
             output = input_.matmul(weight.t())
+
             dist.all_reduce(output, group=forward_all_reduce_group, async_op=False)
         else:
             assert input_.shape[0] % 2 == 0
