Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support testing optcast ring-allreduce implementation #6

Merged
merged 4 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion reduction_server/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,14 @@ pub(crate) fn client(args: Args) {
.split(',')
.map(|addr| {
info!("connecting to {}", addr);
let mut stream = TcpStream::connect(addr).expect("Could not connect to server");
let mut stream = loop {
let res = TcpStream::connect(&addr);
if res.is_ok() {
break res.unwrap();
}
// sleep 1s
std::thread::sleep(std::time::Duration::from_secs(1));
};

let comms = (0..args.nchannel)
.map(|_| {
Expand Down
226 changes: 187 additions & 39 deletions test/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import re
import argparse
import datetime
import time
from functools import reduce

epoch = 0
Expand Down Expand Up @@ -79,7 +79,14 @@ def plot(data, classes, colormapping, xlim, output):
plt.savefig(output)


def analyze_client_log(filename, output, xlim=None, no_plot=False):
def analyze_client_log(filename, output, no_gpu, xlim=None, no_plot=False):
    """Analyze a client log with the analyzer matching the client kind.

    When ``no_gpu`` is truthy the log is handed to
    ``analyze_optcast_client_log``; otherwise it is handed to
    ``analyze_nccl_tests_log``. ``filename``, ``output``, ``xlim`` and
    ``no_plot`` are forwarded unchanged, and the selected analyzer's return
    value is returned as-is.
    """
    analyzer = analyze_optcast_client_log if no_gpu else analyze_nccl_tests_log
    return analyzer(filename, output, xlim, no_plot)


def analyze_nccl_tests_log(filename, output, xlim=None, no_plot=False):
global epoch
epoch = 0

Expand Down Expand Up @@ -178,7 +185,106 @@ def get_time(line):
return avgBusBw


def analyze_optcast_client_log(filename, output, xlim=None, no_plot=False):
    """Analyze the log written by a no-GPU optcast client run.

    The log is scanned line by line. Lines containing "start" or "done"
    together with one of "reduce", "recv" or "send" are paired up by a name
    built from the ``task_id``/``idx``/``j`` fields found on the line, giving
    one (start, end) span per completed operation. Per-kind duration
    statistics are printed via ``show_stats`` and, unless ``no_plot`` is set,
    the spans are drawn with ``plot``.

    Args:
        filename: path of the client log to read.
        output: path handed to ``plot`` for the rendered figure.
        xlim: forwarded unchanged to ``plot``.
        no_plot: when truthy, skip plotting.

    Returns:
        The value parsed from the last "Avg bus bandwidth" line, or 0 if the
        log contains no such line.

    Side effects:
        Resets and updates the module-level ``epoch``; prints to stdout.
    """
    from dateutil import parser

    global epoch
    epoch = 0

    with open(filename) as f:
        log = f.read()

    data = []
    classes = {}
    stats = {}

    def get_time(line):
        """Return the line's timestamp in ms relative to the first one seen."""
        global epoch

        # The second space-separated field is the timestamp, preceded by one
        # extra character that is stripped before parsing.
        t = parser.parse(line.split(" ")[1][1:])
        # timestamp() is in seconds; scale to milliseconds
        ts = t.timestamp() * 1000

        if epoch == 0:
            epoch = ts
        return float(ts - epoch)

    avgBusBw = 0

    for line in log.split("\n"):
        if line.startswith("#") or line.endswith("#") or "bandwidth" in line:
            print(line)
        if "Avg bus bandwidth" in line:
            avgBusBw = float(line.strip().split(" ")[-1])

        # "start" takes precedence over "done" when both appear on a line.
        is_start = "start" in line
        is_end = not is_start and "done" in line
        if not (is_start or is_end):
            continue

        r = re.search(r"task_id: (?P<tid>\d+)", line)
        tid = int(r.group("tid")) if r else None

        r = re.search(r"idx: (?P<idx>\d+)", line)
        jid = int(r.group("idx")) if r else None

        r = re.search(r"j: (?P<j>\d+)", line)
        j = int(r.group("j")) if r else None

        if "reduce" in line:
            name = f"reduce({tid}/{jid})" if tid is not None else f"reduce({jid})"
            statname = "reduce"
        elif "recv" in line:
            name = f"recv({jid}/{tid}/{j})" if j is not None else f"recv({jid})"
            statname = "recv"
        elif "send" in line:
            name = f"send({jid}/{tid}/{j})" if j is not None else f"send({jid})"
            statname = "send"
        else:
            # A start/done line for some other event: without this guard the
            # name from a previous line would be reused (or be unbound).
            continue

        if is_start:
            classes[name] = get_time(line)
        else:
            if name not in classes:
                # "done" without a recorded "start" (e.g. truncated log):
                # there is no span to record.
                continue
            e = get_time(line)
            data.append((classes[name], e, name, name))
            stats.setdefault(statname, []).append(e - classes[name])

    # cut off first-half of the data to remove warmup phase
    for k, v in stats.items():
        stats[k] = v[len(v) // 2 :]

    print("client stats:")
    for k, v in stats.items():
        show_stats(" " + k, v)
    print("")

    # NOTE(review): "recv(0)" and "send(11)" are hard-coded span names;
    # presumably the first and last step of one round — TODO confirm.
    starts = (v[0] for v in data if v[-1] == "recv(0)")
    ends = (v[1] for v in data if v[-1] == "send(11)")

    # Pair the i-th recv(0) start with the i-th send(11) end.
    latency = sorted(e - s for (s, e) in zip(starts, ends))
    if latency:
        # Average the faster half; keep at least one sample so that a single
        # measurement does not divide by zero.
        half = max(len(latency) // 2, 1)
        latency = sum(latency[:half]) / half
        print(f"avg latency: {latency:.2f} ms")

    if no_plot:
        return avgBusBw

    colormapping = {k: f"C{i}" for i, k in enumerate(sorted(classes))}
    plot(data, classes, colormapping, xlim, output)

    return avgBusBw


def analyze_server_log(filename, output, xlim=None, no_plot=False):
from dateutil import parser

global epoch
epoch = 0

Expand All @@ -195,7 +301,7 @@ def analyze_server_log(filename, output, xlim=None, no_plot=False):
def get_time(line):
global epoch

t = datetime.datetime.fromisoformat(line.split(" ")[1][1:])
t = parser.parse(line.split(" ")[1][1:])
# datetime to unix time in milliseconds
ts = t.timestamp() * 1000

Expand Down Expand Up @@ -279,6 +385,20 @@ def get_shared_dir():
return os.path.dirname(os.path.dirname(path))


def gen_args(args):
    """Render a parsed argument namespace back into a command-line string.

    Each attribute ``some_name`` of ``args`` becomes a ``--some-name`` option.
    Falsy values (``None``, ``False``, ``0``, ``""``) are omitted entirely,
    ``True`` becomes a bare flag, and any other value is emitted as
    ``--some-name VALUE``. String values are shell-quoted so that the result
    can be embedded safely in a command line run through a shell.

    Args:
        args: an object exposing ``_get_kwargs()`` that yields
            ``(name, value)`` pairs, such as ``argparse.Namespace``.

    Returns:
        The options joined by single spaces; an empty string if none apply.
    """
    import shlex  # local import so the block does not depend on file-level imports

    ret = []
    for k, v in args._get_kwargs():
        if not v:
            continue
        flag = f"--{k.replace('_', '-')}"
        if isinstance(v, bool):
            # Only True reaches this point; store_true options take no value.
            ret.append(flag)
        else:
            if isinstance(v, str):
                # shlex.quote leaves plain tokens untouched and single-quotes
                # anything containing whitespace or shell metacharacters.
                v = shlex.quote(v)
            ret.append(f"{flag} {v}")
    return " ".join(ret)


def parse_chunksize(chunksize):
if chunksize.endswith("M"):
return int(chunksize[:-1]) * 1024 * 1024
Expand All @@ -301,7 +421,7 @@ async def server(args):
// args.nsplit
)

env = server["env"]
env = server.get("env", {})
if "NCCL_DEBUG" not in env:
env["NCCL_DEBUG"] = "TRACE" if rank == 0 else "INFO"
if "RUST_LOG" not in env:
Expand Down Expand Up @@ -349,27 +469,49 @@ async def server(args):

async def client(args):
rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
dt = "float" if args.data_type == "f32" else "half"
client_cmd = f"{args.shared_dir}/{CLIENT_CMD}"
cmd = f"{client_cmd} -d {dt} -e {args.size} -b {args.size} {args.nccl_test_options}"
# print(f"[{platform.node()}] client:", cmd, file=sys.stderr)

os.environ["NCCL_DEBUG"] = "TRACE" if rank == 0 else "INFO"
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_SHM_DISABLE"] = "1"

if args.type == "optcast":
os.environ["NCCL_COLLNET_ENABLE"] = "1"
os.environ[
"LD_LIBRARY_PATH"
] = f"{args.shared_dir}/{OPTCAST_PLUGIN_DIR}:{os.environ['LD_LIBRARY_PATH']}"
os.environ["OPTCAST_REDUCTION_SERVERS"] = args.reduction_servers
os.environ["NCCL_BUFFSIZE"] = str(64 * 1024 * 1024)
chunksize = parse_chunksize(args.chunksize) // 2
os.environ["NCCL_COLLNET_CHUNKSIZE"] = str(chunksize)
os.environ["OPTCAST_SPLIT"] = str(args.nsplit)
elif args.type == "sharp":
os.environ["NCCL_COLLNET_ENABLE"] = "1"
if args.no_gpu:
client_cmd = f"{args.shared_dir}/{SERVER_CMD}"
os.environ["RUST_LOG"] = "TRACE" if rank == 0 else "INFO"

nreq = 4
try_count = 1000
count = parse_chunksize(args.chunksize) // (
4 if args.data_type == "f32" else 2
)
if args.type == "optcast":
cmd = f"{client_cmd} -c -a {args.reduction_servers} --count {count} --try-count {try_count} --nreq {nreq}"
elif args.type == "ring":
with open(args.config) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
if args.nrank == 0:
args.nrank = len(config["clients"])
clients = config["clients"][: args.nrank]
neighs = [(rank + 1 + i) % args.nrank for i in range(2)]
addrs = ",".join(
clients[j]["name"] + f":{8080+i}" for i, j in enumerate(neighs)
)
cmd = f"{client_cmd} -a {addrs} --reduce-threads 2 --count {count} --try-count {try_count} --nrank {args.nrank} --ring-rank {rank+1} --nreq {nreq}"
else:
dt = "float" if args.data_type == "f32" else "half"
client_cmd = f"{args.shared_dir}/{CLIENT_CMD}"
cmd = f"{client_cmd} -d {dt} -e {args.size} -b {args.size} {args.nccl_test_options}"

os.environ["NCCL_DEBUG"] = "TRACE" if rank == 0 else "INFO"
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_SHM_DISABLE"] = "1"

if args.type == "optcast":
os.environ["NCCL_COLLNET_ENABLE"] = "1"
os.environ["LD_LIBRARY_PATH"] = (
f"{args.shared_dir}/{OPTCAST_PLUGIN_DIR}:{os.environ['LD_LIBRARY_PATH']}"
)
os.environ["OPTCAST_REDUCTION_SERVERS"] = args.reduction_servers
os.environ["NCCL_BUFFSIZE"] = str(64 * 1024 * 1024)
chunksize = parse_chunksize(args.chunksize) // 2
os.environ["NCCL_COLLNET_CHUNKSIZE"] = str(chunksize)
os.environ["OPTCAST_SPLIT"] = str(args.nsplit)
elif args.type == "sharp":
os.environ["NCCL_COLLNET_ENABLE"] = "1"

proc = await asyncio.create_subprocess_shell(
cmd,
Expand All @@ -381,7 +523,7 @@ async def client(args):
asyncio.create_task(
read_stream(
proc.stdout,
None,
platform.node(),
None,
False,
sys.stdout,
Expand Down Expand Up @@ -435,21 +577,23 @@ async def run(
servers = config["servers"][: args.nservers]
clients = config["clients"][: args.nrank]

reduction_servers = ",".join(f"{s['ipaddr']}:{s['port']}" for s in servers)
if args.no_gpu:
if args.type not in ["optcast", "ring"]:
raise ValueError(f"no-gpu option doesn't work with {args.type}")

reduction_servers = ",".join(
f"{s['ipaddr'] if 'ipaddr' in s else s['name']}:{s['port']}" for s in servers
)
cmd = " ".join(
(
args.mpirun,
f"-np {args.nrank} -H {','.join(c['name'] for c in clients)}",
"-x LD_LIBRARY_PATH",
f"{args.python} {args.shared_dir}/test/run.py",
f"--shared-dir {args.shared_dir}",
f"--client --size {args.size} --chunksize {args.chunksize} --nsplit {args.nsplit}",
f"--reduction-servers {reduction_servers}",
f"--nccl-test-options '{args.nccl_test_options}'",
f"--type {args.type} --data-type {args.data_type}",
gen_args(args),
f"--client --reduction-servers {reduction_servers}",
)
)
# print("client:", cmd)
client = await asyncio.create_subprocess_shell(
cmd,
stdout=subprocess.PIPE,
Expand Down Expand Up @@ -487,12 +631,10 @@ async def run(
args.mpirun,
f"-bind-to none -np {args.nservers} -H {','.join(s['name'] for s in servers)}",
f"{args.python} {args.shared_dir}/test/run.py",
f"--shared-dir {args.shared_dir}",
f"--server --num-jobs {args.num_jobs} --num-threads {args.num_threads} --num-recvs {args.num_recvs} --num-sends {args.num_sends}",
f"--nrank {args.nrank} --chunksize {args.chunksize} --nsplit {args.nsplit} --data-type {args.data_type}",
gen_args(args),
"--server",
)
)
# print("server:", cmd)
server = await asyncio.create_subprocess_shell(
cmd,
stdout=subprocess.PIPE,
Expand Down Expand Up @@ -525,6 +667,7 @@ async def run(
return analyze_client_log(
args.log_dir + "/client.log",
args.log_dir + "/client.png",
args.no_gpu,
args.xlim,
args.no_plot,
)
Expand All @@ -542,6 +685,7 @@ async def run(
return analyze_client_log(
args.log_dir + "/client.log",
args.log_dir + "/client.png",
args.no_gpu,
args.xlim,
args.no_plot,
)
Expand All @@ -563,7 +707,7 @@ def arguments():
parser.add_argument("--nsplit", default=1, type=int)
parser.add_argument("--reduction-servers")
parser.add_argument(
"--type", choices=["optcast", "sharp", "nccl"], default="optcast"
"--type", choices=["optcast", "sharp", "nccl", "ring"], default="optcast"
)
parser.add_argument("--nccl-test-options", default="-c 1 -n 1 -w 1")
parser.add_argument("--data-type", default="f32", choices=["f32", "f16"])
Expand All @@ -578,6 +722,7 @@ def arguments():
)
parser.add_argument("--analyze", "-a", action="store_true")
parser.add_argument("--no-plot", "-p", action="store_true")
parser.add_argument("--no-gpu", action="store_true")
parser.add_argument("--xlim", "-x")

return parser.parse_args()
Expand All @@ -594,7 +739,10 @@ def main():
if args.analyze:
if os.stat(args.log_dir + "/client.log"):
analyze_client_log(
args.log_dir + "/client.log", args.log_dir + "/client.png", args.xlim
args.log_dir + "/client.log",
args.log_dir + "/client.png",
args.no_gpu,
args.xlim,
)

if os.stat(args.log_dir + "/server.log"):
Expand Down