diff --git a/docs/source/get_started/distributed_training.rst b/docs/source/get_started/distributed_training.rst
index ed7e6960c..b9ecf3be5 100644
--- a/docs/source/get_started/distributed_training.rst
+++ b/docs/source/get_started/distributed_training.rst
@@ -121,21 +121,33 @@ In above example, you can create a new python file (e.g., `run_a.py`) on node A,
         nproc = 4,
         group_offset = 0
     )
+
+    # Optional, only needed if you want to get the result of each process.
+    queue = mp.get_context('spawn').SimpleQueue()
+
+    config_dict = {}
+    config_dict.update({
+        "world_size": args.world_size,
+        "ip": args.ip,
+        "port": args.port,
+        "nproc": args.nproc,
+        "offset": args.group_offset,
+    })
+    kwargs = {
+        "config_dict": config_dict,
+        "queue": queue,  # Optional
+    }
+
     mp.spawn(
         run_recboles,
-        args=(
-            args.model,
-            args.dataset,
-            config_file_list,
-            args.ip,
-            args.port,
-            args.world_size,
-            args.nproc,
-            args.group_offset,
-        ),
-        nprocs=args.nproc,
+        args=(args.model, args.dataset, args.config_file_list, kwargs),
+        nprocs=args.nproc,
+        join=True,
     )
+    # Normally, there should be only one item in the queue
+    res = None if queue.empty() else queue.get()
+
 
 Then run the following command:
 
@@ -159,21 +171,33 @@ Similarly, you can create a new python file (e.g., `run_b.py`) on node B, and wr
         nproc = 4,
         group_offset = 4
     )
+
+    # Optional, only needed if you want to get the result of each process.
+    queue = mp.get_context('spawn').SimpleQueue()
+
+    config_dict = {}
+    config_dict.update({
+        "world_size": args.world_size,
+        "ip": args.ip,
+        "port": args.port,
+        "nproc": args.nproc,
+        "offset": args.group_offset,
+    })
+    kwargs = {
+        "config_dict": config_dict,
+        "queue": queue,  # Optional
+    }
+
     mp.spawn(
         run_recboles,
-        args=(
-            args.model,
-            args.dataset,
-            config_file_list,
-            args.ip,
-            args.port,
-            args.world_size,
-            args.nproc,
-            args.group_offset,
-        ),
-        nprocs=args.nproc,
+        args=(args.model, args.dataset, args.config_file_list, kwargs),
+        nprocs=args.nproc,
+        join=True,
     )
+    # Normally, there should be only one item in the queue
+    res = None if queue.empty() else queue.get()
+
 
 Then run the following command:
diff --git a/recbole/quick_start/__init__.py b/recbole/quick_start/__init__.py
index 58b937d6a..2fe193a15 100644
--- a/recbole/quick_start/__init__.py
+++ b/recbole/quick_start/__init__.py
@@ -1,4 +1,5 @@
 from recbole.quick_start.quick_start import (
+    run,
     run_recbole,
     objective_function,
     load_data_and_model,
diff --git a/recbole/quick_start/quick_start.py b/recbole/quick_start/quick_start.py
index a898584e2..0300703fe 100644
--- a/recbole/quick_start/quick_start.py
+++ b/recbole/quick_start/quick_start.py
@@ -12,20 +12,17 @@
 ########################
 """
 import logging
-from logging import getLogger
-
 import sys
+import torch.distributed as dist
+from collections.abc import MutableMapping
+from logging import getLogger
 
-
-import pickle
 from ray import tune
 
 from recbole.config import Config
 from recbole.data import (
     create_dataset,
     data_preparation,
-    save_split_dataloaders,
-    load_split_dataloaders,
 )
 from recbole.data.transform import construct_transform
 from recbole.utils import (
@@ -39,8 +36,69 @@
 )
 
 
+def run(
+    model,
+    dataset,
+    config_file_list=None,
+    config_dict=None,
+    saved=True,
+    nproc=1,
+    world_size=-1,
+    ip="localhost",
+    port="5678",
+    group_offset=0,
+):
+    if nproc == 1 and world_size <= 0:
+        res = run_recbole(
+            model=model,
+            dataset=dataset,
+            config_file_list=config_file_list,
+            config_dict=config_dict,
+            saved=saved,
+        )
+    else:
+        if world_size == -1:
+            world_size = nproc
+        import torch.multiprocessing as mp
+
+        # Refer to https://discuss.pytorch.org/t/problems-with-torch-multiprocess-spawn-and-simplequeue/69674/2
+        # https://discuss.pytorch.org/t/return-from-mp-spawn/94302/2
+        queue = mp.get_context('spawn').SimpleQueue()
+
+        config_dict = config_dict or {}
+        config_dict.update(
+            {
+                "world_size": world_size,
+                "ip": ip,
+                "port": port,
+                "nproc": nproc,
+                "offset": group_offset,
+            }
+        )
+        kwargs = {
+            "config_dict": config_dict,
+            "queue": queue,
+        }
+
+        mp.spawn(
+            run_recboles,
+            args=(model, dataset, config_file_list, kwargs),
+            nprocs=nproc,
+            join=True,
+        )
+
+        # Normally, there should be only one item in the queue
+        res = None if queue.empty() else queue.get()
+    return res
+
+
 def run_recbole(
-    model=None, dataset=None, config_file_list=None, config_dict=None, saved=True
+    model=None,
+    dataset=None,
+    config_file_list=None,
+    config_dict=None,
+    saved=True,
+    queue=None,
 ):
     r"""A fast running api, which includes the complete process of
     training and testing a model on a specified dataset
@@ -51,6 +109,7 @@ def run_recbole(
         config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``.
         config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``.
         saved (bool, optional): Whether to save the model. Defaults to ``True``.
+        queue (torch.multiprocessing.Queue, optional): The queue used to pass the result to the main process. Defaults to ``None``.
     """
     # configurations initialization
     config = Config(
@@ -104,27 +163,33 @@
     logger.info(set_color("best valid ", "yellow") + f": {best_valid_result}")
     logger.info(set_color("test result", "yellow") + f": {test_result}")
 
-    return {
+    result = {
         "best_valid_score": best_valid_score,
         "valid_score_bigger": config["valid_metric_bigger"],
        "best_valid_result": best_valid_result,
         "test_result": test_result,
     }
 
+    if not config["single_spec"]:
+        dist.destroy_process_group()
+
+    if config["local_rank"] == 0 and queue is not None:
+        queue.put(result)  # for multiprocessing, e.g., mp.spawn
+
+    return result  # for the single process
+
 
 def run_recboles(rank, *args):
-    ip, port, world_size, nproc, offset = args[3:]
-    args = args[:3]
+    kwargs = args[-1]
+    if not isinstance(kwargs, MutableMapping):
+        raise ValueError(
+            f"The last argument of run_recboles should be a dict, but got {type(kwargs)}"
+        )
+    kwargs["config_dict"] = kwargs.get("config_dict", {})
+    kwargs["config_dict"]["local_rank"] = rank
     run_recbole(
-        *args,
-        config_dict={
-            "local_rank": rank,
-            "world_size": world_size,
-            "ip": ip,
-            "port": port,
-            "nproc": nproc,
-            "offset": offset,
-        },
+        *args[:3],
+        **kwargs,
     )
 
 
diff --git a/run_recbole.py b/run_recbole.py
index 09e0740e2..2badf308c 100644
--- a/run_recbole.py
+++ b/run_recbole.py
@@ -8,9 +8,8 @@
 # @Email  : chenyuwuxinn@gmail.com, houyupeng@ruc.edu.cn, zhlin@ruc.edu.cn
 
 import argparse
-from ast import arg
 
-from recbole.quick_start import run_recbole, run_recboles
+from recbole.quick_start import run
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -44,26 +43,13 @@
         args.config_files.strip().split(" ") if args.config_files else None
     )
 
-    if args.nproc == 1 and args.world_size <= 0:
-        run_recbole(
-            model=args.model, dataset=args.dataset, config_file_list=config_file_list
-        )
-    else:
-        if args.world_size == -1:
-            args.world_size = args.nproc
-        import torch.multiprocessing as mp
-
-        mp.spawn(
-            run_recboles,
-            args=(
-                args.model,
-                args.dataset,
-                config_file_list,
-                args.ip,
-                args.port,
-                args.world_size,
-                args.nproc,
-                args.group_offset,
-            ),
-            nprocs=args.nproc,
-        )
+    run(
+        args.model,
+        args.dataset,
+        config_file_list=config_file_list,
+        nproc=args.nproc,
+        world_size=args.world_size,
+        ip=args.ip,
+        port=args.port,
+        group_offset=args.group_offset,
+    )
diff --git a/run_recbole_group.py b/run_recbole_group.py
index 925f1d41a..2468b9577 100644
--- a/run_recbole_group.py
+++ b/run_recbole_group.py
@@ -4,41 +4,10 @@
 
 import argparse
-from ast import arg
 
-from recbole.quick_start import run_recbole, run_recboles
+from recbole.quick_start import run
 from recbole.utils import list_to_latex
 
-
-def run(args, model, config_file_list):
-    if args.nproc == 1 and args.world_size <= 0:
-        res = run_recbole(
-            model=model,
-            dataset=args.dataset,
-            config_file_list=config_file_list,
-        )
-    else:
-        if args.world_size == -1:
-            args.world_size = args.nproc
-        import torch.multiprocessing as mp
-
-        res = mp.spawn(
-            run_recboles,
-            args=(
-                args.model,
-                args.dataset,
-                config_file_list,
-                args.ip,
-                args.port,
-                args.world_size,
-                args.nproc,
-                args.group_offset,
-            ),
-            nprocs=args.nproc,
-        )
-    return res
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -92,7 +61,16 @@ def run(args, model, config_file_list):
         valid_res_dict = {"Model": model}
         test_res_dict = {"Model": model}
-        result = run(args, model, config_file_list)
+        result = run(
+            model,
+            args.dataset,
+            config_file_list=config_file_list,
+            nproc=args.nproc,
+            world_size=args.world_size,
+            ip=args.ip,
+            port=args.port,
+            group_offset=args.group_offset,
+        )
         valid_res_dict.update(result["best_valid_result"])
         test_res_dict.update(result["test_result"])
         bigger_flag = result["valid_score_bigger"]
diff --git a/significance_test.py b/significance_test.py
index bcd65c321..589883482 100644
--- a/significance_test.py
+++ b/significance_test.py
@@ -8,43 +8,41 @@
 # @Email   :
 import argparse
-from ast import arg
 import random
-import sys
 from collections import defaultdict
 
-from scipy import stats
-
-from recbole.quick_start import run_recbole, run_recboles
 
+from scipy import stats
 
-def run(args, seed):
-    if args.nproc == 1 and args.world_size <= 0:
-        res = run_recbole(
-            model=args.model,
-            dataset=args.dataset,
-            config_file_list=config_file_list,
+from recbole.quick_start import run
+
+
+def run_test(
+    model,
+    dataset,
+    config_files,
+    seeds,
+    nproc,
+    world_size,
+    ip,
+    port,
+    group_offset,
+):
+    results = defaultdict(list)
+    for seed in seeds:
+        res = run(
+            model,
+            dataset,
+            config_files,
             config_dict={"seed": seed},
+            nproc=nproc,
+            world_size=world_size,
+            ip=ip,
+            port=port,
+            group_offset=group_offset,
         )
-    else:
-        if args.world_size == -1:
-            args.world_size = args.nproc
-        import torch.multiprocessing as mp
-
-        res = mp.spawn(
-            run_recboles,
-            args=(
-                args.model,
-                args.dataset,
-                config_file_list,
-                args.ip,
-                args.port,
-                args.world_size,
-                args.nproc,
-                args.group_offset,
-            ),
-            nprocs=args.nproc,
-        )
-    return res
+        for _key, _value in res["test_result"].items():
+            results[_key].append(_value)
+    return results
 
 
 if __name__ == "__main__":
@@ -101,24 +99,30 @@ def run(args, seed):
     random.seed(args.st_seed)
     random_seeds = [random.randint(0, 2**32 - 1) for _ in range(args.run_times)]
 
-    result_ours = defaultdict(list)
-    result_baseline = defaultdict(list)
-
     config_file_ours, config_file_baseline = config_file_list
-    args.model = args.model_ours
-    args.config_file_list = [result_ours]
-    for seed in random_seeds:
-        res = run(args, seed)
-        for key, value in res["test_result"].items():
-            result_ours[key].append(value)
-
-    args.model = args.model_baseline
-    args.config_file_list = [config_file_baseline]
-    for seed in random_seeds:
-        res = run(args, seed)
-        for key, value in res["test_result"].items():
-            result_baseline[key].append(value)
+    result_ours = run_test(
+        args.model_ours,
+        args.dataset,
+        [config_file_ours],
+        random_seeds,
+        args.nproc,
+        args.world_size,
+        args.ip,
+        args.port,
+        args.group_offset,
+    )
+    result_baseline = run_test(
+        args.model_baseline,
+        args.dataset,
+        [config_file_baseline],
+        random_seeds,
+        args.nproc,
+        args.world_size,
+        args.ip,
+        args.port,
+        args.group_offset,
+    )
 
     final_result = {}
     for key, value in result_ours.items():