Skip to content

Commit

Permalink
using TUTEL_GLOBAL_TIMEOUT_SEC to make the NCCL timeout configurable (#231)
Browse files · Browse the repository at this point in the history
  • Loading branch information
ghostplant authored Apr 19, 2024
1 parent 6da6b52 commit 13c7a72
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tutel/impls/communicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
import time
import torch
import logging
import datetime

from torch import Tensor
import torch.distributed as dist

from .jit_compiler import tutel_custom_kernel

# Collective-communication timeout in seconds, overridable via the
# TUTEL_GLOBAL_TIMEOUT_SEC environment variable (default: 86400 = 24 h).
# Passed as datetime.timedelta(seconds=...) to dist.init_process_group /
# dist.new_group below so long-running NCCL jobs don't hit the framework
# default timeout.
TUTEL_GLOBAL_TIMEOUT_SEC = int(os.environ.get('TUTEL_GLOBAL_TIMEOUT_SEC', 86400))

def get_world_size(group=None):
try:
return dist.get_world_size(group)
Expand Down Expand Up @@ -48,13 +51,13 @@ def create_groups_from_world(group_count, include_init=None):
try:
if ('LOCAL_RANK' not in os.environ) and ('OMPI_COMM_WORLD_SIZE' in os.environ):
if include_init:
dist.init_process_group(backend=backend,
dist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=TUTEL_GLOBAL_TIMEOUT_SEC),
init_method='tcp://%s:%s' % (os.environ['MASTER_ADDR'], os.environ.get('MASTER_PORT', '23456')),
rank=int(os.environ['OMPI_COMM_WORLD_RANK']), world_size=int(os.environ['OMPI_COMM_WORLD_SIZE']))
dist_local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
else:
if include_init:
dist.init_process_group(backend=backend)
dist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=TUTEL_GLOBAL_TIMEOUT_SEC))
dist_local_rank = min(int(os.environ.get('LOCAL_RANK', 0)), torch.cuda.device_count() - 1)
glob_world_size, glob_world_rank = dist.get_world_size(), dist.get_rank()
is_distributed = True
Expand Down Expand Up @@ -84,15 +87,15 @@ def dist_print(*args):
groups, inner_ranks = [], []
for gr in range(dist_group_size):
group_ranks = [x for x in range(gr * dist_world_size, (gr + 1) * dist_world_size)]
groups += [dist.new_group(ranks=group_ranks)]
groups += [dist.new_group(ranks=group_ranks, timeout=datetime.timedelta(seconds=TUTEL_GLOBAL_TIMEOUT_SEC))]
inner_ranks += [group_ranks]
model_group = groups[dist_group_rank]

if dist_group_size != glob_world_size:
groups, outer_ranks = [], []
for gr in range(dist_world_size):
group_ranks = [x for x in range(gr, dist_world_size * dist_group_size, dist_world_size)]
groups += [dist.new_group(ranks=group_ranks)]
groups += [dist.new_group(ranks=group_ranks, timeout=datetime.timedelta(seconds=TUTEL_GLOBAL_TIMEOUT_SEC))]
outer_ranks += [group_ranks]
data_group = groups[dist_world_rank]
else:
Expand Down

0 comments on commit 13c7a72

Please sign in to comment.