-
Notifications
You must be signed in to change notification settings - Fork 86
/
example.py
119 lines (92 loc) · 3.35 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright (c) Meta Platforms, Inc. and affiliates
# Minimal effort to run this code:
# $ torchrun --nproc-per-node 3 example.py
import os
import torch
from pippy import pipeline, SplitPoint, ScheduleGPipe, PipelineStage
# Network widths used by the model and the driver code: the input feature
# width, the output width of each hidden block (one block per entry), and
# the width of the final projection.
in_dim, layer_dims, out_dim = 512, [512, 1024, 256], 10
# Single layer definition
class MyNetworkBlock(torch.nn.Module):
    """One hidden block of the example network: a linear layer then ReLU.

    Args:
        in_dim: number of input features of the linear layer.
        out_dim: number of output features of the linear layer.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """Return ``relu(lin(x))``."""
        return torch.relu(self.lin(x))
# Full model definition
class MyNetwork(torch.nn.Module):
    """Example MLP: a stack of ``MyNetworkBlock``s followed by a linear head.

    Widths are taken from the module-level ``in_dim``, ``layer_dims`` and
    ``out_dim``. Hidden blocks are registered as ``layer0``, ``layer1``, ...
    so they can be referred to by name (the pipeline ``split_spec`` in this
    script is keyed on those names).
    """

    def __init__(self):
        super().__init__()
        self.num_layers = len(layer_dims)
        prev_dim = in_dim
        # Register the blocks one by one under predictable "layer{i}" names.
        for i, dim in enumerate(layer_dims):
            self.add_module(f"layer{i}", MyNetworkBlock(prev_dim, dim))
            prev_dim = dim
        # Final projection to `out_dim` classes. `prev_dim` is now the width
        # of the last hidden block, or `in_dim` when `layer_dims` is empty
        # (indexing `layer_dims[-1]` would raise IndexError in that case).
        self.output_proj = torch.nn.Linear(prev_dim, out_dim)

    def forward(self, x):
        """Run ``x`` through every hidden block in order, then the head."""
        for i in range(self.num_layers):
            x = getattr(self, f"layer{i}")(x)
        return self.output_proj(x)
# This example runs as several cooperating processes, started here with
# `torchrun`. `torchrun` exports `RANK` (this process's index among all
# processes) and `WORLD_SIZE` (the total number of processes).
#
# For more on `torchrun`, see
# https://pytorch.org/docs/stable/elastic/run.html

# Every rank uses the same seed, so the weights and random inputs created
# below (in the same order on every rank) match across processes. The last
# rank relies on this when it checks the pipeline output against the
# original model.
torch.manual_seed(0)
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
# Map each rank onto one of the visible GPUs, or fall back to the CPU.
device = (
    torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
# Cut the model right after `layer0` and `layer1`, which yields three stages.
split_spec = {f"layer{idx}": SplitPoint.END for idx in (0, 1)}
batch_size = 32
chunks = 4  # number of micro-batches each batch is divided into
# Build the full model on this rank's device, then trace and split it
# using an example input with the same shape as the real batch.
mn = MyNetwork().to(device)
example_input = torch.randn(batch_size, in_dim, device=device)
pipe = pipeline(mn, chunks, example_args=(example_input,), split_spec=split_spec)
# Only the first rank dumps the traced pipeline and each stage's submodule.
if rank == 0:
    print(" pipe ".center(80, "*"), pipe, sep="\n")
    print(" stage 0 ".center(80, "*"), pipe.split_gm.submod_0, sep="\n")
    print(" stage 1 ".center(80, "*"), pipe.split_gm.submod_1, sep="\n")
    print(" stage 2 ".center(80, "*"), pipe.split_gm.submod_2, sep="\n")
# Initialize distributed environment
import torch.distributed as dist

# Each process runs exactly one pipeline stage. Every `split_spec` entry adds
# one stage boundary, so there are len(split_spec) + 1 stages. Fail fast with
# a clear message when `torchrun` was launched with a different process
# count, instead of failing later inside the pipeline runtime.
num_stages = len(split_spec) + 1
if world_size != num_stages:
    raise RuntimeError(
        f"This example needs exactly {num_stages} processes (one per pipeline "
        f"stage), but WORLD_SIZE={world_size}. Launch it with "
        f"`torchrun --nproc-per-node {num_stages} example.py`."
    )

dist.init_process_group(rank=rank, world_size=world_size)

# Pipeline stage is our main pipeline runtime. It takes in the pipe object,
# the rank of this process, and the device.
stage = PipelineStage(pipe, rank, device)
# Attach to a schedule
schedule = ScheduleGPipe(stage, chunks)

# Input data
x = torch.randn(batch_size, in_dim, device=device)

# Run the pipeline with input `x`. The schedule divides the batch into
# `chunks` micro-batches. Only rank 0 feeds the input; the other ranks
# receive activations from the previous stage.
if rank == 0:
    schedule.step(x)
else:
    output = schedule.step()

if rank == world_size - 1:
    # Run the original code and get the output for comparison
    reference_output = mn(x)
    # Compare numerics of pipeline and original model
    torch.testing.assert_close(output, reference_output)
    print(" Pipeline parallel model ran successfully! ".center(80, "*"))

# Tear down the process group so its resources are released before exit.
dist.destroy_process_group()