From f121cc3ae1a8babed1a09b2e7a625d9129405eed Mon Sep 17 00:00:00 2001
From: "Ehsan K. Ardestani"
Date: Tue, 10 May 2022 17:04:52 -0700
Subject: [PATCH] execution time breakdown

Summary:
Report the per-step execution time broken down into forward (FWD), backward
(BWD), and optimizer (OPT) phases, in addition to the total, and derive
per-phase throughput from it. Also add fused Adam and Adagrad as optimizer
choices.

Differential Revision: D36151122

fbshipit-source-id: 7dbd51e9403c7f4c8597b709bd7a6c51c06ea589
---
 train/compute/pt/pytorch_linear.py | 115 ++++++++++++++++++++++-------
 1 file changed, 88 insertions(+), 27 deletions(-)

diff --git a/train/compute/pt/pytorch_linear.py b/train/compute/pt/pytorch_linear.py
index 31bfb090..110258f2 100644
--- a/train/compute/pt/pytorch_linear.py
+++ b/train/compute/pt/pytorch_linear.py
@@ -34,7 +34,9 @@ def train_cpu(
     loss_f = nn.CrossEntropyLoss()
 
     # model.train()
-    start_time = time.time()
+    events = {"start_all": [], "stop_all": [], "start_fwd": [], "stop_fwd": [], "start_bwd": [], "stop_bwd": [], "start_opt": [], "stop_opt": []}
+    times = {"all": 0, "fwd": 0, "bwd": 0, "opt": 0}
+    events["start_all"].append(time.time())
 
     for i in range(args.steps + args.warmups):
         data = torch.randn(batch_size, input_size, device=device)
@@ -45,15 +47,26 @@ def train_cpu(
 
         if data_type == "float16":
             data = data.half()
 
+        events["start_all"].append(time.time())
         optimizer.zero_grad()
+        events["start_fwd"].append(time.time())
         output = model(data).float()
+        events["stop_fwd"].append(time.time())
         loss = loss_f(output, target)
+        events["start_bwd"].append(time.time())
         loss.backward()
+        events["stop_bwd"].append(time.time())
+        events["start_opt"].append(time.time())
         optimizer.step()
+        events["stop_opt"].append(time.time())
+        events["stop_all"].append(time.time())
 
         if i < args.warmups:
-            start_time = time.time()
+            for t in events.values():
+                t.clear()
 
-    return time.time() - start_time, loss
+    for key in ["all", "fwd", "bwd", "opt"]:
+        times[key] = sum(te - ts for ts, te in zip(events["start_" + key], events["stop_" + key]))
+    return times, loss
 
 
 def train_gpu(
@@ -67,10 +80,11 @@ def train_gpu(
         model = apex.fp16_utils.network_to_half(model)
 
     # model.train()
+    times = {"all": 0, "fwd": 0, "bwd": 0, "opt": 0}
+    events = {"start_all": [], "stop_all": [], "start_fwd": [], "stop_fwd": [], "start_bwd": [], "stop_bwd": [], "start_opt": [], "stop_opt": []}
     torch.cuda.synchronize()
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-    total_time = 0.0
+    for e in events:
+        events[e] = torch.cuda.Event(enable_timing=True)
 
     for i in range(args.steps + args.warmups):
         data = torch.randn(batch_size, input_size, device=device)
@@ -81,20 +95,26 @@ def train_gpu(
 
         if data_type == "float16":
             data = data.half()
 
-        if i >= args.warmups:
-            start_event.record()
+        events["start_all"].record()
         optimizer.zero_grad()
+        events["start_fwd"].record()
         output = model(data).float()
+        events["stop_fwd"].record()
         loss = loss_f(output, target)
+        events["start_bwd"].record()
         loss.backward()
+        events["stop_bwd"].record()
+        events["start_opt"].record()
         optimizer.step()
+        events["stop_opt"].record()
 
         if i >= args.warmups:
-            end_event.record()
+            events["stop_all"].record()
             torch.cuda.synchronize()
-            total_time += start_event.elapsed_time(end_event) * 1.0e-3
+            for key in ["all", "fwd", "bwd", "opt"]:
+                times[key] += events["start_" + key].elapsed_time(events["stop_" + key]) * 1.0e-3
 
-    return total_time, loss
+    return times, loss
 
 
 def train_tpu(
@@ -105,7 +125,9 @@ def train_tpu(
     loss_f = nn.CrossEntropyLoss().to(device)
 
     # model.train()
-    start_time = time.time()
+    times = {"all": 0, "fwd": 0, "bwd": 0, "opt": 0}
+    events = {"start_all": [], "stop_all": [], "start_fwd": [], "stop_fwd": [], "start_bwd": [], "stop_bwd": [], "start_opt": [], "stop_opt": []}
"start_fwd": [], "stop_fwd": [], "start_bwd": [], "stop_bwd": [], "start_opt": [], "stop_opt":[]} + events["start_all"].append(time.time()) for i in range(args.steps + args.warmups): data = torch.randn(batch_size, input_size, device=device) @@ -114,16 +136,27 @@ def train_tpu( ) # data, target = data.to(device), target.to(device) + events["start_all"].append(time.time()) optimizer.zero_grad() + events["start_fwd"].append(time.time()) output = model(data).float() + events["stop_fwd"].append(time.time()) loss = loss_f(output, target) + events["start_bwd"].append(time.time()) loss.backward() + events["stop_bwd"].append(time.time()) + events["start_opt"].append(time.time()) optimizer.step() xm.mark_step() + events["stop_opt"].append(time.time()) + events["stop_all"].append(time.time()) if i < args.warmups: - start_time = time.time() + for t in events.values(): + t.clear() - return time.time() - start_time, loss + for key in ["all", "fwd", "bwd", "opt"]: + times[key] = sum([te-ts for ts, te in zip(events["start_"+key], events["stop_"+key])]) + return times, loss def train( @@ -131,7 +164,7 @@ def train( ): if device.type == "cpu": - elap, loss = train_cpu( + times, loss = train_cpu( model, device, optimizer, @@ -143,7 +176,7 @@ def train( ) elif device.type == "cuda": - elap, loss = train_gpu( + times, loss = train_gpu( model, device, optimizer, @@ -155,7 +188,7 @@ def train( ) elif device.type == "xla": - elap, loss = train_tpu( + times, loss = train_tpu( model, device, optimizer, @@ -166,7 +199,7 @@ def train( args, ) - return elap, loss + return times, loss def run_single(args, layers_size, batch_size): @@ -203,6 +236,14 @@ def run_single(args, layers_size, batch_size): optimizer = apex.optimizers.FusedLAMB( model.parameters(), lr=lr, set_grad_none=True ) + elif optimizer_type == "adam": + optimizer = apex.optimizers.FusedAdam( + model.parameters(), lr=lr, set_grad_none=True + ) + elif optimizer_type == "adagrad": + optimizer = apex.optimizers.FusedAdagrad( + model.parameters(), lr=lr, set_grad_none=True + ) else: assert 0, "Unsupported optimizer type" @@ -225,10 +266,10 @@ def run_single(args, layers_size, batch_size): else: assert 0, "Unsupported optimizer type" - elap, loss = train( + times, loss = train( model, dev, optimizer, data_type, layers_size[0], layers_size[-1], batch_size, args ) - return elap, loss + return times, loss def run(args, dataset): @@ -237,7 +278,7 @@ def run(args, dataset): "--------------------------------------------------------------------------------" ) print( - " #Layer Input Hidden Output Batch Time(s)/step QPS Rate(TF/s)" + " Num Layers Batch Time(s)/step: All FWD BWD OPT QPS (TF/s):FWD BWD OPT(GB/s)" ) print( "--------------------------------------------------------------------------------" @@ -245,33 +286,53 @@ def run(args, dataset): for i in range(len(dataset)): layers_size, batch_size = dataset[i] - elap, loss = run_single( + times, loss = run_single( args, layers_size, batch_size ) + elap = times["all"] + fwd_t = times["fwd"] + bwd_t = times["bwd"] + opt_t = times["opt"] + elap /= args.steps + fwd_t /= args.steps + bwd_t /= args.steps + opt_t /= args.steps flops = 0 for i in range(len(layers_size)-1): flops += layers_size[i] * layers_size[i+1] + params = flops + bytes_per_dtype = 4 if args.dtype == "float" else 2 + params *= bytes_per_dtype + # how many bytes of optimizer states? 
         flops *= batch_size # Forward 2x and Backward 4x
+        fwd_flops = flops * 2
+        bwd_flops = flops * 4
         flops *= 6
+
+        QPS = batch_size / elap
 
         # The hidden layer size could vary, but for now keeping for backward
         # compatibility
         print(
-            "{0:6}, {1:6}, {2:6}, {3:6}, {4:6}, {5:10.6f}, {6:8.1f}, {7:10.1f}".format(
+            "{0:6}, {1:6}, {2:.3f}, {3:.3f}, {4:.3f}, {5:.3f}, {6:8.1f}, {7:10.1f}, {8:.3f}, {9:.3f}, {10:.3f}".format(
                 len(layers_size),
-                layers_size[0],
-                layers_size[1],
-                layers_size[-1],
                 batch_size,
                 elap,
+                fwd_t,
+                bwd_t,
+                opt_t,
                 QPS,
                 flops / elap / 1.0e12,
+                fwd_flops / fwd_t / 1.0e12,
+                bwd_flops / bwd_t / 1.0e12,
+                params / opt_t / 1.0e9
+
             )
         )
 
@@ -304,7 +365,7 @@ def dash_separated_ints(value):
         "--optimizer-type",
         default="sgd",
         help="Optimizer: SGD",
-        choices=["sgd", "lamb"],
+        choices=["sgd", "lamb", "adam", "adagrad"],
     )
     parser.add_argument(
         "--dtype",
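
Note: the GPU path times each phase with one start/stop pair of CUDA events,
recorded around the region and reduced to seconds after a synchronize. Below
is a minimal standalone sketch of that pattern for a single step, assuming a
CUDA device is available; the toy model, sizes, and MSE loss are illustrative
only and not part of the patch.

    import torch
    import torch.nn as nn

    phases = ["all", "fwd", "bwd", "opt"]
    # One start/stop CUDA event pair per phase; enable_timing=True is
    # required for elapsed_time().
    events = {
        f"{edge}_{p}": torch.cuda.Event(enable_timing=True)
        for p in phases
        for edge in ("start", "stop")
    }
    times = {p: 0.0 for p in phases}

    model = nn.Linear(1024, 1024).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_f = nn.MSELoss()
    data = torch.randn(256, 1024, device="cuda")
    target = torch.randn(256, 1024, device="cuda")

    events["start_all"].record()
    optimizer.zero_grad()
    events["start_fwd"].record()
    output = model(data)
    events["stop_fwd"].record()
    loss = loss_f(output, target)
    events["start_bwd"].record()
    loss.backward()
    events["stop_bwd"].record()
    events["start_opt"].record()
    optimizer.step()
    events["stop_opt"].record()
    events["stop_all"].record()

    # record() is asynchronous; elapsed_time() is only valid once both
    # events have completed, hence the synchronize.
    torch.cuda.synchronize()
    for p in phases:
        # elapsed_time() returns milliseconds; convert to seconds.
        times[p] += events["start_" + p].elapsed_time(events["stop_" + p]) * 1.0e-3
    print(times)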
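The CPU and XLA paths keep the same structure but collect paired time.time()
wall-clock stamps, clearing the lists during warmup so that only steady-state
iterations contribute. A condensed sketch of just that bookkeeping for one
phase, with a sleep standing in for the real work; the step counts and the
10 ms duration are illustrative:

    import time

    steps, warmups = 3, 1
    events = {"start_fwd": [], "stop_fwd": []}

    for i in range(steps + warmups):
        events["start_fwd"].append(time.time())
        time.sleep(0.01)  # stands in for the forward pass
        events["stop_fwd"].append(time.time())
        if i < warmups:
            # Drop timestamps collected during warmup so only the
            # post-warmup iterations contribute to the sum below.
            for t in events.values():
                t.clear()

    fwd_time = sum(te - ts for ts, te in zip(events["start_fwd"], events["stop_fwd"]))
    print(f"fwd: {fwd_time:.3f}s over {steps} steps")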
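The derived columns follow the usual dense-layer accounting: each weight is
one multiply-accumulate per sample, i.e. 2 FLOPs forward and roughly 4
backward (6x total, matching the existing flops *= 6), and the optimizer
column reports the bytes of weights it touches. A worked example with
hypothetical sizes (three layers of width 1024, batch 256, float16) and
illustrative per-step timings:

    layers_size = [1024, 1024, 1024]
    batch_size = 256
    bytes_per_dtype = 2  # float16

    # Weight count = sum of products of consecutive layer widths.
    weights = sum(a * b for a, b in zip(layers_size, layers_size[1:]))
    param_bytes = weights * bytes_per_dtype  # bytes the optimizer reads/writes
    fwd_flops = 2 * weights * batch_size     # multiply + add per weight
    bwd_flops = 4 * weights * batch_size     # roughly 2x the forward cost
    total_flops = fwd_flops + bwd_flops      # the 6x used for the All column

    # With measured per-step times (seconds), the table columns become:
    t_all, t_fwd, t_opt = 0.010, 0.003, 0.002  # illustrative timings
    qps = batch_size / t_all                   # queries per second
    fwd_tflops = fwd_flops / t_fwd / 1.0e12    # TF/s, forward
    opt_gbps = param_bytes / t_opt / 1.0e9     # GB/s, optimizer
    print(qps, fwd_tflops, opt_gbps)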