Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patching tensor proxy shape in trace #1260

Merged
merged 21 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
3fb1302
quick fix on reshape
jjsjann123 Oct 4, 2024
7c5677c
patching tracectx for printing
jjsjann123 Oct 4, 2024
60f9f17
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 4, 2024
3d3700d
switch to _shape to avoid trace
jjsjann123 Oct 4, 2024
218ae1d
Merge remote-tracking branch 'origin/patching_TensorProxyShape_in_tra…
jjsjann123 Oct 4, 2024
27a7130
reverting this first and follow up to check the print again
jjsjann123 Oct 4, 2024
db34394
TensorProxy string print shouldn't trace its shape access
jjsjann123 Oct 4, 2024
cb79133
fixing tests
jjsjann123 Oct 4, 2024
fefea9e
Merge remote-tracking branch 'origin/main' into patching_TensorProxyS…
jjsjann123 Oct 4, 2024
d8e2d97
fixing dce to handle multiple producers; adding tests
jjsjann123 Oct 5, 2024
100c9f8
let's not set the whole world on fire
jjsjann123 Oct 5, 2024
81ae5c5
Merge branch 'main' into patching_TensorProxyShape_in_trace
jjsjann123 Oct 7, 2024
cf42387
Merge remote-tracking branch 'origin/main' into patching_TensorProxyS…
jjsjann123 Oct 7, 2024
851ad77
quick patch on tests
jjsjann123 Oct 7, 2024
1ad3d37
Merge remote-tracking branch 'origin/patching_TensorProxyShape_in_tra…
jjsjann123 Oct 7, 2024
8995ed5
updating torch issue
jjsjann123 Oct 7, 2024
98c5f92
refactor the fix
jjsjann123 Oct 7, 2024
0520e42
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2024
30fa4b4
typo
jjsjann123 Oct 7, 2024
cdc98b7
Merge remote-tracking branch 'origin/patching_TensorProxyShape_in_tra…
jjsjann123 Oct 7, 2024
0843c06
addressing review comments
jjsjann123 Oct 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion thunder/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def translate(x: Any, *, name: str | None = None) -> Any:
if isinstance(x, Proxy):
# register proxy name used by NumberProxies in TensorProxy.shape
if isinstance(x, TensorProxy):
for s_p in filter(lambda s: isinstance(s, Proxy), x.shape):
for s_p in filter(lambda s: isinstance(s, Proxy), x._shape):
# TODO need to avoid name conflict here, since s_p.name
# could have conflicted with something defined earlier in
# the trace.
Expand Down
20 changes: 15 additions & 5 deletions thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

from thunder.core.compile_data import using_symbolic_values, using_jit
from thunder.core.interpreter import is_jitting, ProvenanceRecord, PseudoInst
from thunder.core.trace import VariableInterface, get_tracectx, TraceCtx
from thunder.core.trace import (
VariableInterface,
get_tracectx,
is_tracing,
TraceCtx,
)
from thunder.core.baseutils import (
ProxyInterface,
NumberProxyInterface,
Expand Down Expand Up @@ -1251,7 +1256,7 @@ def _infer_tensor_properties(
else:
# deferred computation of numel
# TODO: similar to how `shape` is handled, this should be CSE or lifted for efficiency
_numel = lambda tp: reduce(operator.mul, tp.shape, 1)
_numel = lambda *args: reduce(operator.mul, _shape, 1)
t-vi marked this conversation as resolved.
Show resolved Hide resolved

# TODO Alias rank to ndim?
_ndim = len(_shape)
t-vi marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -1459,7 +1464,12 @@ def __init__(
# outside of a trace or language context
@property
def shape(self):
return self._shape
if not using_symbolic_values() or not is_tracing():
return self._shape
else:
from thunder.core.prims import shape

return shape(self)

@property
def ndim(self):
Expand Down Expand Up @@ -1548,10 +1558,10 @@ def replace(self, **changes):
)

def __repr__(self):
return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self.shape})>'
return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self._shape})>'

def type_string(self):
return f"{self.device.device_str()} {self.dtype.shortname()}{list(self.shape)}"
return f"{self.device.device_str()} {self.dtype.shortname()}{list(self._shape)}"

# NOTE __getattr__ is overridden to support language-specific methods
def __getattr__(self, attr: str, /):
Expand Down
4 changes: 2 additions & 2 deletions thunder/core/trace_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ def add_to_swap_map(old, new):
# (FSDP, tensor parallel) transforms do "break" shape metadata
new_trace.names.remove(old.name) # taken by the .replace proxy
if isinstance(new, VJPDual):
old = old.replace(shape=new.primal.shape)
old = old.replace(shape=new.primal._shape)
else:
old = old.replace(shape=new.shape)
old = old.replace(shape=new._shape)

if isinstance(new, VJPDual):
swap_map[variableify(new.primal)] = old
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2811,7 +2811,7 @@ def _vjp(primals, cotangents, **kwargs):
# If the argument is a CPU scalar tensor, its gradient needs to be summed into a scalar tensor.
vjp_result = tuple(
(
sum_to(grad, arg.shape)
sum_to(grad, arg._shape)
if (grad is not None and isinstance(arg, TensorProxy) and arg.device.type == "cpu")
else grad
)
Expand Down
9 changes: 7 additions & 2 deletions thunder/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,8 +1000,9 @@ def __repr__(self) -> str:
return str(self._dict)


# NOTE That this pass does not assume that the bound symbols are in a reasonable order,
# but it does assume that each proxy is uniquely constructed once
# NOTE That this pass does not assume that the bound symbols are in a reasonable order.
# For proxies with multiple producers, this pass returns the first producer
# in the order of the presented bound symbols
# Returns a proxy -> producer mapping
# If _map_to_numbers is True then producers are represented by their position in the trace (their "line number")
def producers(trace_or_bsyms: TraceCtx | list[BoundSymbolInterface], *, _map_to_numbers: bool = False) -> ProxyDict:
Expand All @@ -1021,6 +1022,10 @@ def producers(trace_or_bsyms: TraceCtx | list[BoundSymbolInterface], *, _map_to_
continue

for out in bsym.flat_proxy_outs:
# if a producer has already been traversed, skip
if producers.get(out, None) != None:
continue

vout = variableify(out)

# Checks if the proxy was also an input (in which case this is not its producers)
Expand Down
2 changes: 2 additions & 0 deletions thunder/executors/pythonex.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def _elementwise_binary_checker(a: NumberLike | TensorProxy, b: NumberLike | Ten
pythonex_pow = ex.register_operator("pow", like=prims.pow, module=operator)
sub = ex.register_operator("sub", like=prims.sub, module=operator)
div = ex.register_operator("div", like=prims.div, fn=_div_prim_impl)
shape = ex.register_operator("shape", like=prims.shape, fn=lambda x: x.shape)

# TODO: Restore truediv once we find it...
# truediv = ex.register_operator("truediv", like=prims.truediv, module=operator)
Expand All @@ -367,6 +368,7 @@ def _elementwise_binary_checker(a: NumberLike | TensorProxy, b: NumberLike | Ten
ex.register_implementation(prims.pow, pythonex_pow, checker=_elementwise_binary_checker)
ex.register_implementation(prims.sub, sub, checker=_elementwise_binary_checker)
ex.register_implementation(prims.div, div, checker=_elementwise_binary_checker)
ex.register_implementation(prims.shape, shape, checker=_always_executable)


def _sink(*args, **kwargs):
Expand Down
6 changes: 5 additions & 1 deletion thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2246,7 +2246,11 @@ def pow_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs_numbe
),
# NOTE: PyTorch occasionally fails with RuntimeError: "reciprocal_cuda" not implemented for 'Long' when the exponent is a CPU scalar tensor
# e.g.: x=torch.tensor([[ 6, 5, 1, -8],], device='cuda:0');y=torch.tensor(-1);torch.pow(x,y)
DecorateInfo(pytest.mark.xfail, "test_core_vs_torch_consistency", dtypes=(datatypes.int32, datatypes.int64)),
DecorateInfo(
pytest.mark.xfail,
"test_core_vs_torch_consistency",
dtypes=(datatypes.int8, datatypes.int16, datatypes.int32, datatypes.int64),
t-vi marked this conversation as resolved.
Show resolved Hide resolved
),
),
)
elementwise_binary_ops.append(pow_opinfo)
Expand Down
46 changes: 45 additions & 1 deletion thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,7 @@ def forward(self, x):
("cpu", "cuda"),
)
def test_cache_symbolic_values_reshape(device):
if not torch.cuda.is_available():
if device == "cuda" and not torch.cuda.is_available():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we get anything out of running this test on multiple devices?

I am wondering if it makes more sense to just not parameterize the test and run it once on the CPU.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good call!
At this point, since nvfuser isn't taking shape operations at all, the GPU test doesn't do anything. Let me clean it up.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jjsjann123 Is this still pending?

pytest.skip("CUDA not available")

a = torch.randn((4, 8, 6), device=device)
Expand Down Expand Up @@ -1439,3 +1439,47 @@ def foo(a):
assert_close(actual, expected)
assert thunder.cache_misses(jfoo) == 1
assert thunder.cache_hits(jfoo) == 1


@pytest.mark.parametrize(
"device",
("cpu", "cuda"),
)
def test_cache_symbolic_values_reshape_numel(device):
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA not available")

def foo(a):
a = torch.reshape(a, [a.numel()])
return a.relu()

jfoo = thunder.jit(foo, cache="symbolic values")

a = torch.randn(2, 3, 8, requires_grad=True, device=device)
tfogal marked this conversation as resolved.
Show resolved Hide resolved

actual = jfoo(a)
expected = foo(a)

assert_close(actual, expected)


@pytest.mark.parametrize(
"device",
("cpu", "cuda"),
)
def test_cache_symbolic_values_slice(device):
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA not available")

def foo(a):
a = a[..., : a.shape[-1]]
return a.relu()
tfogal marked this conversation as resolved.
Show resolved Hide resolved

jfoo = thunder.jit(foo, cache="symbolic values")

a = torch.randn(2, 3, 8, requires_grad=True, device=device)

actual = jfoo(a)
expected = foo(a)

assert_close(actual, expected)
Loading