diff --git a/tests/python/test_dtensor.py b/tests/python/test_dtensor.py
index de52067b500..6d976e0ee74 100644
--- a/tests/python/test_dtensor.py
+++ b/tests/python/test_dtensor.py
@@ -116,3 +116,119 @@ def define_fusion(fd: FusionDefinition):
     torch.testing.assert_close(out_dtensor.to_local(), in_dtensor.to_local() + 1)
     assert out_dtensor.device_mesh == in_dtensor.device_mesh
     assert out_dtensor.placements == in_dtensor.placements
+
+
+@pytest.mark.mpi
+def test_linear(setup_process_group):
+    class FusionDefinitionArguments:
+        def __init__(self, num_devices: int, batch: int, sequence: int, hidden: int):
+            self.d = num_devices
+            self.b = batch
+            self.s = sequence
+            self.e = hidden
+
+    class LinearForwardDefinition(FusionDefinitionArguments):
+        def __call__(self, fd: FusionDefinition):
+            inp = fd.define_tensor([self.b, self.s, self.e])
+            weight = fd.define_tensor(
+                [self.d, self.e, self.e], contiguity=[True, True, True]
+            )
+            bias = fd.define_tensor([self.d, self.e], contiguity=[True, True])
+            out = fd.ops.linear(inp, weight, bias)
+            fd.add_output(out)
+
+    class LinearBackwardDefinition(FusionDefinitionArguments):
+        def __call__(self, fd: FusionDefinition):
+            x = fd.define_tensor([self.b, self.s, self.e])
+            x = fd.ops.reshape(x, [self.b * self.s, self.e])
+            w = fd.define_tensor([self.d, self.e, self.e], contiguity=True)
+            grad = fd.define_tensor([self.d, self.b, self.s, self.e], contiguity=True)
+            grad = fd.ops.reshape(grad, [self.d, self.b * self.s, self.e])
+
+            grad_x_partials = fd.ops.matmul(grad, w)
+            grad_x = fd.ops.sum(grad_x_partials, [0])  # all reduce
+            grad_t = fd.ops.permute(grad, [0, 2, 1])
+            grad_w = fd.ops.matmul(grad_t, x)
+            grad_b = fd.ops.sum(grad, [1])
+
+            grad_x = fd.ops.reshape(grad_x, [self.b, self.s, self.e])
+            fd.add_output(grad_x)
+            fd.add_output(grad_w)
+            fd.add_output(grad_b)
+
+    class LinearFunction(torch.autograd.Function):
+        @staticmethod
+        def forward(
+            ctx,
+            input: DTensor,
+            weight: DTensor,
+            bias: DTensor,
+        ):
+            b, s, e = input._local_tensor.shape
+            d = weight.device_mesh.size()
+            op = FusionDefinitionWrapper(LinearForwardDefinition(d, b, s, e))
+            outputs = op([input, weight, bias])
+            ctx.save_for_backward(input, weight)
+            return outputs[0]
+
+        @staticmethod
+        def backward(ctx, grad_output: DTensor):
+            d, b, s, e = grad_output.shape
+            op = FusionDefinitionWrapper(LinearBackwardDefinition(d, b, s, e))
+            input, weight = ctx.saved_tensors
+            outputs = op([input, weight, grad_output])
+            return outputs[0], outputs[1], outputs[2]
+
+    world_size = dist.get_world_size()
+    rank = dist.get_rank()
+    torch.cuda.set_device(rank)
+
+    mesh = dist.device_mesh.init_device_mesh("cuda", [world_size])
+
+    d = world_size
+    b, s, e = 2, 1024, 768
+    inp_tensor = torch.randn(b, s, e, device="cuda", requires_grad=True)
+    weight_tensor = torch.randn(world_size, e, e, device="cuda", requires_grad=True)
+    bias_tensor = torch.randn(world_size, e, device="cuda", requires_grad=True)
+
+    inp_dtensor = dist.tensor.distribute_tensor(inp_tensor, mesh, [Replicate()])
+    weight_dtensor = dist.tensor.distribute_tensor(weight_tensor, mesh, [Shard(0)])
+    bias_dtensor = dist.tensor.distribute_tensor(bias_tensor, mesh, [Shard(0)])
+
+    # expected forward
+    unsharded_out_tensor = torch.nn.functional.linear(
+        inp_tensor, weight_tensor.view([d * e, e]), bias_tensor.view([d * e])
+    )
+    expected_out_tensor = unsharded_out_tensor.view([b, s, d, e]).permute(2, 0, 1, 3)[
+        rank : rank + 1
+    ]
+
+    # multidevice forward
+    out_dtensor = LinearFunction.apply(inp_dtensor, weight_dtensor, bias_dtensor)
+
+    # expected backward
+    (expected_grad_x, expected_grad_w, expected_grad_b) = torch.autograd.grad(
+        unsharded_out_tensor,
+        (inp_tensor, weight_tensor, bias_tensor),
+        torch.ones_like(unsharded_out_tensor),
+    )
+
+    # multidevice backward
+    (grad_x, grad_w, grad_b) = torch.autograd.grad(
+        out_dtensor,
+        (inp_dtensor, weight_dtensor, bias_dtensor),
+        torch.ones_like(out_dtensor),
+    )
+
+    torch.testing.assert_close(
+        out_dtensor.to_local(), expected_out_tensor, rtol=1.3e-6, atol=1e-3
+    )
+    torch.testing.assert_close(
+        expected_grad_x, grad_x.to_local(), rtol=1.3e-6, atol=1e-3
+    )
+    torch.testing.assert_close(
+        expected_grad_w[rank : rank + 1], grad_w.to_local(), rtol=1.3e-6, atol=1e-3
+    )
+    torch.testing.assert_close(
+        expected_grad_b[rank : rank + 1], grad_b.to_local(), rtol=1.3e-6, atol=1e-3
+    )
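
For reference, the following is a minimal single-device sketch (not part of the patch; all sizes and variable names are illustrative) of the math that LinearBackwardDefinition encodes for a column-parallel linear. With the weight sharded as [d, e, e] over Shard(0) of a [d*e, e] weight and the output grad sliced as [d, b*s, e], grad_x requires a sum over the shard axis (the "all reduce" in the fusion), while grad_w and grad_b stay sharded and need no communication:

import torch

# Illustrative sizes; float64 keeps the cross-check numerically tight.
d, b, s, e = 2, 2, 8, 4
x = torch.randn(b * s, e, dtype=torch.float64)        # input flattened to [b*s, e]
w = torch.randn(d, e, e, dtype=torch.float64)         # weight shards: Shard(0) of [d*e, e]
grad = torch.randn(d, b * s, e, dtype=torch.float64)  # output grad, one slice per shard

grad_x_partials = grad @ w           # [d, b*s, e]: each shard's partial product
grad_x = grad_x_partials.sum(0)      # the all-reduce in the fusion definition
grad_w = grad.transpose(1, 2) @ x    # [d, e, e]: stays sharded, no communication
grad_b = grad.sum(1)                 # [d, e]: stays sharded, no communication

# Cross-check against autograd on the equivalent unsharded linear.
x_ref = x.clone().requires_grad_()
w_ref = w.reshape(d * e, e).clone().requires_grad_()
b_ref = torch.randn(d * e, dtype=torch.float64).requires_grad_()
out = torch.nn.functional.linear(x_ref, w_ref, b_ref)      # [b*s, d*e]
out.backward(grad.permute(1, 0, 2).reshape(b * s, d * e))  # unshard the grad slices
torch.testing.assert_close(grad_x, x_ref.grad)
torch.testing.assert_close(grad_w.reshape(d * e, e), w_ref.grad)
torch.testing.assert_close(grad_b.reshape(d * e), b_ref.grad)

The test itself runs under MPI, presumably via something like `mpirun -np 2 pytest tests/python/test_dtensor.py -k test_linear`; the exact launcher and marker handling depend on this repo's conftest for the `mpi` mark.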