-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Demonstrates autograd integration with NVFuser multidevice #3787
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -116,3 +116,119 @@ def define_fusion(fd: FusionDefinition): | |||||||
torch.testing.assert_close(out_dtensor.to_local(), in_dtensor.to_local() + 1) | ||||||||
assert out_dtensor.device_mesh == in_dtensor.device_mesh | ||||||||
assert out_dtensor.placements == in_dtensor.placements | ||||||||
|
||||||||
|
||||||||
@pytest.mark.mpi | ||||||||
def test_linear(setup_process_group): | ||||||||
class FusionDefintionArguments: | ||||||||
def __init__(self, num_devices: int, batch: int, sequence: int, hidden: int): | ||||||||
self.d = num_devices | ||||||||
self.b = batch | ||||||||
self.s = sequence | ||||||||
self.e = hidden | ||||||||
|
||||||||
class LinearForwardDefinition(FusionDefintionArguments): | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel using class and inheritance is an overkill. Functions and partials should be good enough.
and later
|
||||||||
def __call__(self, fd: FusionDefinition): | ||||||||
inp = fd.define_tensor([self.b, self.s, self.e]) | ||||||||
weight = fd.define_tensor( | ||||||||
[self.d, self.e, self.e], contiguity=[True, True, True] | ||||||||
) | ||||||||
bias = fd.define_tensor([self.d, self.e], contiguity=[True, True]) | ||||||||
out = fd.ops.linear(inp, weight, bias) | ||||||||
fd.add_output(out) | ||||||||
|
||||||||
class LinearBackwardDefinition(FusionDefintionArguments): | ||||||||
def __call__(self, fd: FusionDefinition): | ||||||||
x = fd.define_tensor([self.b, self.s, self.e]) | ||||||||
x = fd.ops.reshape(x, [self.b * self.s, self.e]) | ||||||||
w = fd.define_tensor([self.d, self.e, self.e], contiguity=True) | ||||||||
grad = fd.define_tensor([self.d, self.b, self.s, self.e], contiguity=True) | ||||||||
grad = fd.ops.reshape(grad, [self.d, self.b * self.s, self.e]) | ||||||||
|
||||||||
grad_x_partials = fd.ops.matmul(grad, w) | ||||||||
grad_x = fd.ops.sum(grad_x_partials, [0]) # all reduce | ||||||||
grad_t = fd.ops.permute(grad, [0, 2, 1]) | ||||||||
grad_w = fd.ops.matmul(grad_t, x) | ||||||||
grad_b = fd.ops.sum(grad, [1]) | ||||||||
|
||||||||
grad_x = fd.ops.reshape(grad_x, [self.b, self.s, self.e]) | ||||||||
fd.add_output(grad_x) | ||||||||
fd.add_output(grad_w) | ||||||||
fd.add_output(grad_b) | ||||||||
|
||||||||
class LinearFunction(torch.autograd.Function): | ||||||||
@staticmethod | ||||||||
def forward( | ||||||||
ctx, | ||||||||
input: DTensor, | ||||||||
weight: DTensor, | ||||||||
bias: DTensor, | ||||||||
): | ||||||||
b, s, e = input._local_tensor.shape | ||||||||
d = weight.device_mesh.size() | ||||||||
op = FusionDefinitionWrapper(LinearForwardDefinition(d, b, s, e)) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you try to construct the op in |
||||||||
outputs = op([input, weight, bias]) | ||||||||
ctx.save_for_backward(input, weight) | ||||||||
return outputs[0] | ||||||||
|
||||||||
@staticmethod | ||||||||
def backward(ctx, grad_output: DTensor): | ||||||||
d, b, s, e = grad_output.shape | ||||||||
op = FusionDefinitionWrapper(LinearBackwardDefinition(d, b, s, e)) | ||||||||
input, weight = ctx.saved_tensors | ||||||||
outputs = op([input, weight, grad_output]) | ||||||||
return outputs[0], outputs[1], outputs[2] | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
world_size = dist.get_world_size() | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
rank = dist.get_rank() | ||||||||
torch.cuda.set_device(rank) | ||||||||
|
||||||||
mesh = dist.device_mesh.init_device_mesh("cuda", [world_size]) | ||||||||
|
||||||||
d = world_size | ||||||||
b, s, e = 2, 1024, 768 | ||||||||
inp_tensor = torch.randn(b, s, e, device="cuda", requires_grad=True) | ||||||||
weight_tensor = torch.randn(world_size, e, e, device="cuda", requires_grad=True) | ||||||||
bias_tensor = torch.randn(world_size, e, device="cuda", requires_grad=True) | ||||||||
|
||||||||
inp_dtensor = dist.tensor.distribute_tensor(inp_tensor, mesh, [Replicate()]) | ||||||||
weight_dtensor = dist.tensor.distribute_tensor(weight_tensor, mesh, [Shard(0)]) | ||||||||
bias_dtensor = dist.tensor.distribute_tensor(bias_tensor, mesh, [Shard(0)]) | ||||||||
|
||||||||
# expected forward | ||||||||
unsharded_out_tensor = torch.nn.functional.linear( | ||||||||
inp_tensor, weight_tensor.view([d * e, e]), bias_tensor.view([d * e]) | ||||||||
) | ||||||||
expected_out_tensor = unsharded_out_tensor.view([b, s, d, e]).permute(2, 0, 1, 3)[ | ||||||||
rank : rank + 1 | ||||||||
] | ||||||||
|
||||||||
# multidevice forward | ||||||||
out_dtensor = LinearFunction.apply(inp_dtensor, weight_dtensor, bias_dtensor) | ||||||||
|
||||||||
# expected backward | ||||||||
(expected_grad_x, expected_grad_w, expected_grad_b) = torch.autograd.grad( | ||||||||
unsharded_out_tensor, | ||||||||
(inp_tensor, weight_tensor, bias_tensor), | ||||||||
torch.ones_like(unsharded_out_tensor), | ||||||||
) | ||||||||
|
||||||||
# multidevice backward | ||||||||
(grad_x, grad_w, grad_b) = torch.autograd.grad( | ||||||||
out_dtensor, | ||||||||
(inp_dtensor, weight_dtensor, bias_dtensor), | ||||||||
torch.ones_like(out_dtensor), | ||||||||
) | ||||||||
|
||||||||
torch.testing.assert_close( | ||||||||
out_dtensor.to_local(), expected_out_tensor, rtol=1.3e-6, atol=1e-3 | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
to make the order consistent with other assert_closes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider
|
||||||||
) | ||||||||
torch.testing.assert_close( | ||||||||
expected_grad_x, grad_x.to_local(), rtol=1.3e-6, atol=1e-3 | ||||||||
) | ||||||||
torch.testing.assert_close( | ||||||||
expected_grad_w[rank : rank + 1], grad_w.to_local(), rtol=1.3e-6, atol=1e-3 | ||||||||
) | ||||||||
torch.testing.assert_close( | ||||||||
expected_grad_b[rank : rank + 1], grad_b.to_local(), rtol=1.3e-6, atol=1e-3 | ||||||||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.