diff --git a/tutel/impls/communicate.py b/tutel/impls/communicate.py index 51ba737..4a326f3 100644 --- a/tutel/impls/communicate.py +++ b/tutel/impls/communicate.py @@ -8,12 +8,15 @@ import time import torch import logging +import datetime from torch import Tensor import torch.distributed as dist from .jit_compiler import tutel_custom_kernel +TUTEL_GLOBAL_TIMEOUT_SEC = int(os.environ.get('TUTEL_GLOBAL_TIMEOUT_SEC', 86400)) + def get_world_size(group=None): try: return dist.get_world_size(group) @@ -48,13 +51,13 @@ def create_groups_from_world(group_count, include_init=None): try: if ('LOCAL_RANK' not in os.environ) and ('OMPI_COMM_WORLD_SIZE' in os.environ): if include_init: - dist.init_process_group(backend=backend, + dist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=TUTEL_GLOBAL_TIMEOUT_SEC), init_method='tcp://%s:%s' % (os.environ['MASTER_ADDR'], os.environ.get('MASTER_PORT', '23456')), rank=int(os.environ['OMPI_COMM_WORLD_RANK']), world_size=int(os.environ['OMPI_COMM_WORLD_SIZE'])) dist_local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) else: if include_init: - dist.init_process_group(backend=backend) + dist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=TUTEL_GLOBAL_TIMEOUT_SEC)) dist_local_rank = min(int(os.environ.get('LOCAL_RANK', 0)), torch.cuda.device_count() - 1) glob_world_size, glob_world_rank = dist.get_world_size(), dist.get_rank() is_distributed = True @@ -84,7 +87,7 @@ def dist_print(*args): groups, inner_ranks = [], [] for gr in range(dist_group_size): group_ranks = [x for x in range(gr * dist_world_size, (gr + 1) * dist_world_size)] - groups += [dist.new_group(ranks=group_ranks)] + groups += [dist.new_group(ranks=group_ranks, timeout=datetime.timedelta(seconds=TUTEL_GLOBAL_TIMEOUT_SEC))] inner_ranks += [group_ranks] model_group = groups[dist_group_rank] @@ -92,7 +95,7 @@ def dist_print(*args): groups, outer_ranks = [], [] for gr in range(dist_world_size): group_ranks = [x for x in range(gr, dist_world_size * dist_group_size, dist_world_size)] - groups += [dist.new_group(ranks=group_ranks)] + groups += [dist.new_group(ranks=group_ranks, timeout=datetime.timedelta(seconds=TUTEL_GLOBAL_TIMEOUT_SEC))] outer_ranks += [group_ranks] data_group = groups[dist_world_rank] else: