Demonstrates autograd integration with NVFuser multidevice #3787

Open · wants to merge 2 commits into main
116 changes: 116 additions & 0 deletions tests/python/test_dtensor.py
@@ -116,3 +116,119 @@ def define_fusion(fd: FusionDefinition):
torch.testing.assert_close(out_dtensor.to_local(), in_dtensor.to_local() + 1)
assert out_dtensor.device_mesh == in_dtensor.device_mesh
assert out_dtensor.placements == in_dtensor.placements


@pytest.mark.mpi
def test_linear(setup_process_group):
    class FusionDefinitionArguments:
        def __init__(self, num_devices: int, batch: int, sequence: int, hidden: int):
            self.d = num_devices
            self.b = batch
            self.s = sequence
            self.e = hidden
Collaborator commented on lines +123 to +128:

from dataclasses import dataclass

@dataclass
class LinearConfig:
    d: int
    b: int
    s: int
    e: int


    class LinearForwardDefinition(FusionDefinitionArguments):
Collaborator commented:

I feel using classes and inheritance is overkill here. Functions and partials should be good enough.

def define_linear_forward(config: LinearConfig, fd: FusionDefinition) -> None:

and later

partial(define_linear_forward, config)
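To make that concrete, here is a rough sketch of the function-plus-partial variant, reusing the LinearConfig dataclass suggested above. It mirrors LinearForwardDefinition.__call__ and is illustrative only, not part of this diff:

from functools import partial

def define_linear_forward(config: LinearConfig, fd: FusionDefinition) -> None:
    # Same logic as LinearForwardDefinition.__call__, with sizes read from the config.
    inp = fd.define_tensor([config.b, config.s, config.e])
    weight = fd.define_tensor(
        [config.d, config.e, config.e], contiguity=[True, True, True]
    )
    bias = fd.define_tensor([config.d, config.e], contiguity=[True, True])
    out = fd.ops.linear(inp, weight, bias)
    fd.add_output(out)

# Later, wherever the fusion is built:
# op = FusionDefinitionWrapper(partial(define_linear_forward, config))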

        def __call__(self, fd: FusionDefinition):
            inp = fd.define_tensor([self.b, self.s, self.e])
            weight = fd.define_tensor(
                [self.d, self.e, self.e], contiguity=[True, True, True]
            )
            bias = fd.define_tensor([self.d, self.e], contiguity=[True, True])
            out = fd.ops.linear(inp, weight, bias)
            fd.add_output(out)

    class LinearBackwardDefinition(FusionDefinitionArguments):
        def __call__(self, fd: FusionDefinition):
            x = fd.define_tensor([self.b, self.s, self.e])
            x = fd.ops.reshape(x, [self.b * self.s, self.e])
            w = fd.define_tensor([self.d, self.e, self.e], contiguity=True)
            grad = fd.define_tensor([self.d, self.b, self.s, self.e], contiguity=True)
            grad = fd.ops.reshape(grad, [self.d, self.b * self.s, self.e])

            grad_x_partials = fd.ops.matmul(grad, w)
            grad_x = fd.ops.sum(grad_x_partials, [0])  # all reduce
            grad_t = fd.ops.permute(grad, [0, 2, 1])
            grad_w = fd.ops.matmul(grad_t, x)
            grad_b = fd.ops.sum(grad, [1])

            grad_x = fd.ops.reshape(grad_x, [self.b, self.s, self.e])
            fd.add_output(grad_x)
            fd.add_output(grad_w)
            fd.add_output(grad_b)

    class LinearFunction(torch.autograd.Function):
        @staticmethod
        def forward(
            ctx,
            input: DTensor,
            weight: DTensor,
            bias: DTensor,
        ):
            b, s, e = input._local_tensor.shape
            d = weight.device_mesh.size()
            op = FusionDefinitionWrapper(LinearForwardDefinition(d, b, s, e))
Collaborator commented:

Can you try to construct the op in __init__? Example: https://github.com/canqin001/PointDAN/blob/5001b38cb5506b1c6b40ad1329c1d6f4fbbdd26d/Model.py#L29. I'm worried about the overhead of constructing FusionDefinitionWrapper for each forward and backward call.
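One way to avoid that overhead, sketched here as an assumption rather than part of this diff, is to memoize wrapper construction on the shape tuple so each unique (d, b, s, e) builds a FusionDefinitionWrapper only once; the _forward_op and _backward_op helper names are hypothetical:

from functools import lru_cache

@lru_cache(maxsize=None)
def _forward_op(d: int, b: int, s: int, e: int):
    # Built once per unique shape, then reused by every forward call.
    return FusionDefinitionWrapper(LinearForwardDefinition(d, b, s, e))

@lru_cache(maxsize=None)
def _backward_op(d: int, b: int, s: int, e: int):
    return FusionDefinitionWrapper(LinearBackwardDefinition(d, b, s, e))

# Inside LinearFunction.forward:  op = _forward_op(d, b, s, e)
# Inside LinearFunction.backward: op = _backward_op(d, b, s, e)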

            outputs = op([input, weight, bias])
            ctx.save_for_backward(input, weight)
            return outputs[0]

        @staticmethod
        def backward(ctx, grad_output: DTensor):
            d, b, s, e = grad_output.shape
            op = FusionDefinitionWrapper(LinearBackwardDefinition(d, b, s, e))
            input, weight = ctx.saved_tensors
            outputs = op([input, weight, grad_output])
            return outputs[0], outputs[1], outputs[2]
Collaborator commented:

Suggested change
return outputs[0], outputs[1], outputs[2]
assert len(outputs) == 3
return outputs


    world_size = dist.get_world_size()
Collaborator commented:

Suggested change
world_size = dist.get_world_size()
d = dist.get_world_size()

    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    mesh = dist.device_mesh.init_device_mesh("cuda", [world_size])

    d = world_size
    b, s, e = 2, 1024, 768
    inp_tensor = torch.randn(b, s, e, device="cuda", requires_grad=True)
    weight_tensor = torch.randn(world_size, e, e, device="cuda", requires_grad=True)
    bias_tensor = torch.randn(world_size, e, device="cuda", requires_grad=True)

    inp_dtensor = dist.tensor.distribute_tensor(inp_tensor, mesh, [Replicate()])
    weight_dtensor = dist.tensor.distribute_tensor(weight_tensor, mesh, [Shard(0)])
    bias_dtensor = dist.tensor.distribute_tensor(bias_tensor, mesh, [Shard(0)])

    # expected forward
    unsharded_out_tensor = torch.nn.functional.linear(
        inp_tensor, weight_tensor.view([d * e, e]), bias_tensor.view([d * e])
    )
    expected_out_tensor = unsharded_out_tensor.view([b, s, d, e]).permute(2, 0, 1, 3)[
        rank : rank + 1
    ]

    # multidevice forward
    out_dtensor = LinearFunction.apply(inp_dtensor, weight_dtensor, bias_dtensor)

    # expected backward
    (expected_grad_x, expected_grad_w, expected_grad_b) = torch.autograd.grad(
        unsharded_out_tensor,
        (inp_tensor, weight_tensor, bias_tensor),
        torch.ones_like(unsharded_out_tensor),
    )

    # multidevice backward
    (grad_x, grad_w, grad_b) = torch.autograd.grad(
        out_dtensor,
        (inp_dtensor, weight_dtensor, bias_dtensor),
        torch.ones_like(out_dtensor),
    )

    torch.testing.assert_close(
        out_dtensor.to_local(), expected_out_tensor, rtol=1.3e-6, atol=1e-3
Collaborator commented:

Suggested change
out_dtensor.to_local(), expected_out_tensor, rtol=1.3e-6, atol=1e-3
expected_out_tensor, out_dtensor.to_local(), rtol=1.3e-6, atol=1e-3

to make the argument order consistent with the other assert_close calls.

Collaborator commented:

Consider

def assert_close(expected_tensor, dtensor):
  torch.testing.assert_close(expected_tensor, dtensor.to_local(), rtol=1.3e-6, atol=1e-3)

assert_close(...)
assert_close(...)
assert_close(...)
assert_close(...)

    )
    torch.testing.assert_close(
        expected_grad_x, grad_x.to_local(), rtol=1.3e-6, atol=1e-3
    )
    torch.testing.assert_close(
        expected_grad_w[rank : rank + 1], grad_w.to_local(), rtol=1.3e-6, atol=1e-3
    )
    torch.testing.assert_close(
        expected_grad_b[rank : rank + 1], grad_b.to_local(), rtol=1.3e-6, atol=1e-3
    )
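Putting the last two suggestions together, the four checks could collapse into the proposed helper, with the expected value first in every call. An illustrative sketch, not part of the diff:

def assert_close(expected_tensor, dtensor):
    torch.testing.assert_close(
        expected_tensor, dtensor.to_local(), rtol=1.3e-6, atol=1e-3
    )

assert_close(expected_out_tensor, out_dtensor)
assert_close(expected_grad_x, grad_x)
assert_close(expected_grad_w[rank : rank + 1], grad_w)
assert_close(expected_grad_b[rank : rank + 1], grad_b)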