-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Demonstrates autograd integration with NVFuser multidevice #3787
base: main
Are you sure you want to change the base?
Conversation
Cool -- add me to reviewers when it's ready! |
@wujingyue To review. |
Oops I can't add to the reviewers list. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM otherwise
class FusionDefintionArguments: | ||
def __init__(self, num_devices: int, batch: int, sequence: int, hidden: int): | ||
self.d = num_devices | ||
self.b = batch | ||
self.s = sequence | ||
self.e = hidden |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from dataclasses import dataclass
@dataclass
class LinearConfig:
d: int
b: int
s: int
e: int
self.s = sequence | ||
self.e = hidden | ||
|
||
class LinearForwardDefinition(FusionDefintionArguments): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel using class and inheritance is an overkill. Functions and partials should be good enough.
def define_linear_forward(config: LinearConfig, fd: FusionDefinition) -> None:
and later
partial(define_linear_forward, config)
): | ||
b, s, e = input._local_tensor.shape | ||
d = weight.device_mesh.size() | ||
op = FusionDefinitionWrapper(LinearForwardDefinition(d, b, s, e)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you try to construct the op in __init__
? Example: https://github.com/canqin001/PointDAN/blob/5001b38cb5506b1c6b40ad1329c1d6f4fbbdd26d/Model.py#L29. I'm worried about the overhead of constructing FusionDefinitionWrapper for each forward and backward call.
outputs = op([input, weight, grad_output]) | ||
return outputs[0], outputs[1], outputs[2] | ||
|
||
world_size = dist.get_world_size() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
world_size = dist.get_world_size() | |
d = dist.get_world_size() |
op = FusionDefinitionWrapper(LinearBackwardDefinition(d, b, s, e)) | ||
input, weight = ctx.saved_tensors | ||
outputs = op([input, weight, grad_output]) | ||
return outputs[0], outputs[1], outputs[2] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return outputs[0], outputs[1], outputs[2] | |
assert len(outputs) == 3 | |
return outputs |
) | ||
|
||
torch.testing.assert_close( | ||
out_dtensor.to_local(), expected_out_tensor, rtol=1.3e-6, atol=1e-3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
out_dtensor.to_local(), expected_out_tensor, rtol=1.3e-6, atol=1e-3 | |
expected_out_tensor, out_dtensor.to_local(), rtol=1.3e-6, atol=1e-3 |
to make the order consistent with other assert_closes.
) | ||
|
||
torch.testing.assert_close( | ||
out_dtensor.to_local(), expected_out_tensor, rtol=1.3e-6, atol=1e-3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider
def assert_close(expected_tensor, dtensor):
torch.testing.assert_close(expected_tensor, dtensor.to_local(), rtol=1.3e-6, atol=1e-3)
assert_close(...)
assert_close(...)
assert_close(...)
assert_close(...)
This PR demonstrates how to wrap a forward and a backward fusion definition in a
torch.autograd.Function
that takes PyTorch DTensors as input and outputs PyTorch DTensors.