@@ -611,6 +611,108 @@ def reduce_scatter_tensor_coalesced(
             )
 
 
+class _ParallelWork(Work):
+    def __init__(self, works: List[Work]) -> None:
+        super().__init__()
+        self._works = works
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        for work in self._works:
+            if timeout is not None:
+                work.wait(timeout=timeout)
+            else:
+                work.wait()
+        return True
+
+    def get_future(self) -> torch.futures.Future[object]:
+        futures = [work.get_future() for work in self._works]
+        return torch.futures.collect_all(futures)
+
+
+class ParallelProcessGroup(ProcessGroupWrapper):
+    def __init__(
+        self,
+        base: ProcessGroupWrapper,
+        timeout: timedelta = timedelta(seconds=60),
+        count: int = 10,
+    ) -> None:
+        super().__init__(timeout=timeout)
+
+        self._count = count
+        self._pgs = []
+
+        self._create_pg = base._create_pg
+
+    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        # abort if already initialized
+        self.abort()
+
+        for i in range(self._count):
+            store = create_store_client(
+                f"{store_addr}/parallel{i}", timeout=self._timeout
+            )
+
+            self._pgs.append(self._create_pg(store, rank, world_size))
+
+        self._pg = self._pgs[0]
+
+    def _split_tensors(self, tensors: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+        if not isinstance(tensors, (list, tuple)):
+            tensors = [tensors]
+
+        tensor_lists = [[] for _ in range(self._count)]
+        for t in tensors:
+            chunks = torch.tensor_split(t.view(-1), self._count, dim=0)
+            for i, chunk in enumerate(chunks):
+                tensor_lists[i].append(chunk)
+
+        return tensor_lists
+
+    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].allreduce(tensor_lists[i], self._opts_hook(opts))
+                )
+
+            return self._wrap_work(_ParallelWork(works), opts)
+
+    def reduce(self, tensors: List[torch.Tensor], dst: int, opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].reduce(tensor_lists[i], dst, self._opts_hook(opts))
+                )
+
+            return self._wrap_work(_ParallelWork(works), opts)
+
+    def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].send(tensor_lists[i], dst_rank, tag))
+
+            return self._wrap_work(_ParallelWork(works), None)
+
+    def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].recv(tensor_lists[i], src_rank, tag))
+
+            return self._wrap_work(_ParallelWork(works), None)
+
+
 class _WorkCUDATimeout(Work):
     def __init__(self, pg: ProcessGroup, work: Work, timeout: timedelta) -> None:
         super().__init__()
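
The splitting scheme above relies on torch.tensor_split returning views: each chunk of the flattened tensor shares storage with the original, so running an in-place collective on every chunk updates the full tensor. A quick standalone check of that property (plain PyTorch, independent of this patch):

    import torch

    t = torch.arange(8, dtype=torch.float32)
    # Flatten and split into 4 chunks, mirroring _split_tensors above.
    chunks = torch.tensor_split(t.view(-1), 4, dim=0)
    # Mutating a chunk in place mutates the original tensor, because the
    # chunks are views over the same storage.
    chunks[0].add_(100)
    assert t[0].item() == 100.0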
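For reference, a minimal usage sketch of the new class. The import path, the ProcessGroupGloo base, the store address, and the ReduceOp opts value are assumptions for illustration and are not part of this diff; substitute whatever ProcessGroupWrapper subclass and store endpoint your setup provides.

    from datetime import timedelta

    import torch
    from torch.distributed import ReduceOp

    # Assumed import path and base wrapper, for illustration only.
    from torchft.process_group import ParallelProcessGroup, ProcessGroupGloo

    base = ProcessGroupGloo(timeout=timedelta(seconds=60))
    pg = ParallelProcessGroup(base, timeout=timedelta(seconds=60), count=4)

    # configure() creates `count` sub-process-groups, each with its own store
    # prefix ("{store_addr}/parallel{i}"); rank/world_size come from the launcher.
    pg.configure(store_addr="localhost:29500/run_0", rank=0, world_size=2)

    # allreduce() flattens the tensor, splits it into `count` chunks, issues one
    # allreduce per sub-group, and returns a _ParallelWork that waits on all of them.
    t = torch.ones(1024)
    pg.allreduce([t], ReduceOp.SUM).wait()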