From a3223d086c755b8684289f45686967fd445d7345 Mon Sep 17 00:00:00 2001 From: Wataru Ishida Date: Mon, 11 Mar 2024 08:43:20 +0000 Subject: [PATCH] feat(test): support testing optcast ring-allreduce implementation Signed-off-by: Wataru Ishida --- test/run.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/test/run.py b/test/run.py index 8345780..da9ad30 100644 --- a/test/run.py +++ b/test/run.py @@ -8,8 +8,8 @@ import os import re import argparse +import time from functools import reduce -from dateutil import parser epoch = 0 @@ -186,6 +186,8 @@ def get_time(line): def analyze_optcast_client_log(filename, output, xlim=None, no_plot=False): + from dateutil import parser + global epoch epoch = 0 @@ -281,6 +283,8 @@ def get_time(line): def analyze_server_log(filename, output, xlim=None, no_plot=False): + from dateutil import parser + global epoch epoch = 0 @@ -467,14 +471,30 @@ async def client(args): rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) if args.no_gpu: client_cmd = f"{args.shared_dir}/{SERVER_CMD}" - count = parse_chunksize(args.chunksize) // (4 if args.data_type == "f32" else 2) - cmd = f"{client_cmd} -c -a {args.reduction_servers} --count {count} --try-count 1000 --nreq 4" os.environ["RUST_LOG"] = "TRACE" if rank == 0 else "INFO" + + nreq = 4 + try_count = 1000 + count = parse_chunksize(args.chunksize) // ( + 4 if args.data_type == "f32" else 2 + ) + if args.type == "optcast": + cmd = f"{client_cmd} -c -a {args.reduction_servers} --count {count} --try-count {try_count} --nreq {nreq}" + elif args.type == "ring": + with open(args.config) as f: + config = yaml.load(f, Loader=yaml.FullLoader) + if args.nrank == 0: + args.nrank = len(config["clients"]) + clients = config["clients"][: args.nrank] + neighs = [(rank + 1 + i) % args.nrank for i in range(2)] + addrs = ",".join( + clients[j]["name"] + f":{8080+i}" for i, j in enumerate(neighs) + ) + cmd = f"{client_cmd} -a {addrs} --reduce-threads 2 --count {count} --try-count {try_count} --nrank {args.nrank} --ring-rank {rank+1} --nreq {nreq}" else: dt = "float" if args.data_type == "f32" else "half" client_cmd = f"{args.shared_dir}/{CLIENT_CMD}" cmd = f"{client_cmd} -d {dt} -e {args.size} -b {args.size} {args.nccl_test_options}" - # print(f"[{platform.node()}] client:", cmd, file=sys.stderr) os.environ["NCCL_DEBUG"] = "TRACE" if rank == 0 else "INFO" os.environ["NCCL_P2P_DISABLE"] = "1" @@ -558,7 +578,7 @@ async def run( clients = config["clients"][: args.nrank] if args.no_gpu: - if args.type not in ["optcast"]: + if args.type not in ["optcast", "ring"]: raise ValueError(f"no-gpu option doesn't work with {args.type}") reduction_servers = ",".join( @@ -574,7 +594,6 @@ async def run( f"--client --reduction-servers {reduction_servers}", ) ) - print("client:", cmd) client = await asyncio.create_subprocess_shell( cmd, stdout=subprocess.PIPE, @@ -688,7 +707,7 @@ def arguments(): parser.add_argument("--nsplit", default=1, type=int) parser.add_argument("--reduction-servers") parser.add_argument( - "--type", choices=["optcast", "sharp", "nccl"], default="optcast" + "--type", choices=["optcast", "sharp", "nccl", "ring"], default="optcast" ) parser.add_argument("--nccl-test-options", default="-c 1 -n 1 -w 1") parser.add_argument("--data-type", default="f32", choices=["f32", "f16"])