[export] Cover more cases to copy tensor conversions. (pytorch#125628)
Summary:

Previously we converted all .to() calls to to_copy in the graph; now a user reports that other methods such as .float() are not covered: pytorch/PiPPy#1104 (comment)

Fundamentally, .float() should behave the same as .to() in export, so this diff expands coverage to the other tensor conversion methods.
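
As an illustration (not part of this commit; the module and input values are made up), the expected behavior is roughly the following, assuming .float() now lowers to aten._to_copy in the exported graph, as the new test below checks:

import torch
from torch.export import export

class FloatConversion(torch.nn.Module):
    def forward(self, x):
        # .float() is shorthand for .to(dtype=torch.float32); after this change
        # it should be captured as aten._to_copy in the exported graph.
        return x.float()

ep = export(FloatConversion(), (torch.tensor(1, dtype=torch.int64),))
targets = [n.target for n in ep.graph.nodes if n.op == "call_function"]
assert torch.ops.aten._to_copy.default in targets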

Test Plan: buck run mode/opt caffe2/test:test_export -- -r float_conversion

Differential Revision: D56951634
zhxchen17 authored and facebook-github-bot committed May 8, 2024
Showing 2 changed files with 51 additions and 0 deletions.
26 changes: 26 additions & 0 deletions test/export/test_export.py
@@ -2040,6 +2040,32 @@ def forward(self, x):
        ):
            export(Module(), (torch.tensor(1, device="cpu"),))

    def test_float_conversion(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                return x.float()

        ep = export(Module(), (torch.tensor(1, dtype=torch.float),))
        ops = []
        for node in ep.graph.nodes:
            if node.op == "call_function":
                ops.append(node.target)
        self.assertGreater(len(ops), 0)
        for op in ops:
            self.assertIn(op, (torch.ops.aten._to_copy.default,))

    def test_device_to_mutation_float(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                y = x.float()
                y.add_(1)
                return y, x

        with self.assertRaisesRegex(
            RuntimeError, "cannot mutate tensors with frozen storage"
        ):
            export(Module(), (torch.tensor(1, dtype=torch.float),))

    def test_module(self):
        class MyLinear(torch.nn.Module):
            def __init__(self):
25 changes: 25 additions & 0 deletions torch/_subclasses/functional_tensor.py
@@ -17,6 +17,13 @@
not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")


def _conversion_method_template(**extra_kwargs):
    def _(self, *args, **kwargs):
        return self.to(*args, **{**kwargs, **extra_kwargs})

    return _


class FunctionalTensor(torch.Tensor):
    """
    Functional tensors represent tensors that will remove mutations
@@ -225,6 +232,24 @@ def to(self, *args, **kwargs):
            return super().to(*args, **{**kwargs, "copy": True})
        return super().to(*args, **kwargs)

    def cuda(self, device=None, *args, **kwargs):
        device = device or torch.cuda.current_device()
        if len(args) > 0:
            return self.to(device, *args, **kwargs)
        else:
            return self.to(device=device, **kwargs)

    char = _conversion_method_template(dtype=torch.int8)
    cpu = _conversion_method_template(device=torch.device("cpu"))
    bfloat16 = _conversion_method_template(dtype=torch.bfloat16)
    byte = _conversion_method_template(dtype=torch.uint8)
    double = _conversion_method_template(dtype=torch.float64)
    float = _conversion_method_template(dtype=torch.float32)
    bool = _conversion_method_template(dtype=torch.bool)
    half = _conversion_method_template(dtype=torch.float16)
    int = _conversion_method_template(dtype=torch.int32)
    long = _conversion_method_template(dtype=torch.int64)

class FunctionalTensorMode(TorchDispatchMode):
    def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False):
