Add prims.copy_to_out_ #1194
Conversation
Thank you @shino16. I think the overall approach is good and, to my mind, we can merge when it fully works, but I want to let @crcrpar and/or @IvanYashchuk take a look.
There seem to be a few cases to look at, though:
FAILED thunder/tests/test_inplace_functionalization.py::test_inplace_to_alias_func_args_nvfuser_cuda_thunder.dtypes.float32 - NotImplementedError: <TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(2, 2))> (the 'computed' argument of 'prims.copy_to_out_') is used outside of the nvFuser region. Copies onto <TensorProxy(name="a", dtype=thunder.dtypes.float32, shape=(2, 2))> or None in the region may propagate to <TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(2, 2))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_inplace_functionalization.py::test_inplace_copy_on_fusion_inputs_issue_791_nvfuser_cuda_None - NotImplementedError: <TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(2, 2))> (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. Copies onto <TensorProxy(name="x", dtype=thunder.dtypes.float32, shape=(2, 2))> or <TensorProxy(name="t1", dtype=thunder.dtypes.float32, shape=(2, 2))> in the region may propagate to <TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(2, 2))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_inplace_functionalization.py::test_multiple_inplace_to_multiple_args_nvfuser_cuda_None - NotImplementedError: <TensorProxy(name="t12", dtype=thunder.dtypes.float32, shape=(2, 2))> (the 'computed' argument of 'prims.copy_to_out_') is used outside of the nvFuser region. Copies onto <TensorProxy(name="t_1_1", dtype=thunder.dtypes.float32, shape=(2, 2))> or None in the region may propagate to <TensorProxy(name="t12", dtype=thunder.dtypes.float32, shape=(2, 2))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_inplace_functionalization.py::test_single_tensor_adam_like_nvfuser_cuda_None - NotImplementedError: <TensorProxy(name="t2", dtype=thunder.dtypes.float32, shape=(4,))> (the 'computed' argument of 'prims.copy_to_out_') is used outside of the nvFuser region. Copies onto <TensorProxy(name="exp_avg", dtype=thunder.dtypes.float32, shape=(4,))> or None in the region may propagate to <TensorProxy(name="t2", dtype=thunder.dtypes.float32, shape=(4,))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-long-context-like] - NotImplementedError: <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. Copies onto <TensorProxy(name="value", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> or None in the region may propagate to <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-llama1-like] - NotImplementedError: <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. Copies onto <TensorProxy(name="value", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> or None in the region may propagate to <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-llama2-like] - NotImplementedError: <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. Copies onto <TensorProxy(name="value", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> or None in the region may propagate to <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-falcon-7b-like] - NotImplementedError: <TensorProxy(name="t85", dtype=thunder.dtypes.float32, shape=(1, 1, 3, 64))> (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. Copies onto <TensorProxy(name="value", dtype=thunder.dtypes.float32, shape=(1, 1, 3, 64))> or None in the region may propagate to <TensorProxy(name="t85", dtype=thunder.dtypes.float32, shape=(1, 1, 3, 64))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-falcon-40b-like] - NotImplementedError: <TensorProxy(name="t87", dtype=thunder.dtypes.float32, shape=(1, 64, 3, 4))> (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. Copies onto <TensorProxy(name="value", dtype=thunder.dtypes.float32, shape=(1, 64, 3, 4))> or None in the region may propagate to <TensorProxy(name="t87", dtype=thunder.dtypes.float32, shape=(1, 64, 3, 4))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache[cuda-codellama2-like] - NotImplementedError: <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. Copies onto <TensorProxy(name="value", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))> or None in the region may propagate to <TensorProxy(name="t81", dtype=thunder.dtypes.float32, shape=(1, 4, 3, 16))>. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
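All of these errors point to the same escape hatch, `disable_inplace_copy_check`. A minimal usage sketch, assuming only the flag name quoted in the error messages (the function `f` below is made up for illustration):

```python
import torch
import thunder

def f(x, y):
    # a hypothetical in-place update that the conservative check might reject
    x.add_(y)
    return x * 2

# Per the error message, the check can be turned off when constructing the jitted callable:
jf = thunder.jit(f, disable_inplace_copy_check=True)

x = torch.randn(4)
y = torch.randn(4)
print(jf(x, y))
```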
thunder/torch/__init__.py (Outdated)
Could you educate me on why we would want to prefer `clang.copy_to_out_` to `prims.copy_to_out_`? To allow `out` to have a different dtype from `computed`? If so, why wouldn't the prim allow it?
Traditionally, we have resolved typing before the prim level when decomposing, so having `clang.copy_to_out_` decompose to (optionally) `prims.convert_element_type` plus `prims.copy_to_out_` seems to match the patterns we have.
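If that direction were taken, a rough sketch of the clang-level decomposition could look like the following (hypothetical wrapper; only `prims.convert_element_type` and the new `prims.copy_to_out_` come from this PR, the rest is illustrative):

```python
from thunder.core import prims  # assumed import path for the prims referenced in this PR

def clang_copy_to_out_(computed, *, out):
    # Resolve dtype mismatches above the prim level, as other clang decompositions do, ...
    if computed.dtype != out.dtype:
        computed = prims.convert_element_type(computed, out.dtype)
    # ... so the prim itself can require matching dtypes.
    return prims.copy_to_out_(computed, out=out)
```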
```python
for copy_bsyms in bsym_to_copy_bsyms[bsym]:
    functionalized_bsyms.extend(copy_bsyms)
    copy_bsym = functionalized_bsyms[-1]
    # wrap_return_value_together_with_argments places all the arguments in the return value
    # We swap these arguments in the return value with the outputs of copies onto them
    # This prevents subsequent transforms from ordering the return statement before those copies
    swap_map_for_return[variableify(copy_bsym.flat_proxy_args[0])] = copy_bsym.flat_proxy_outs[0]
```
All the changes in this file are just variable renaming, except this loop and line 547, brought in by commit 47f35d9. This fixes a bug which caused some of the test failures mentioned in #1194 (review). When bsym (key_bsym in line 547) is associated with multiple copy_bsyms, bsym_to_copy_bsyms[bsym] previously looked like [(reshape,) copy, (reshape,) copy, ...]. Now it looks like [[(reshape,) copy], [(reshape,) copy], ...], and we iterate through all copies.
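A toy sketch of the data-shape change described above (the strings stand in for bound symbols; not Thunder code):

```python
# Previously: one flat list per key, e.g. [(reshape,) copy, (reshape,) copy, ...],
# so code taking "the last element" only saw the final copy.
old_shape = ["reshape_a", "copy_a", "copy_b"]

# Now: one sub-list per in-place update, e.g. [[(reshape,) copy], [copy], ...],
# and the loop above visits the copy at the end of each sub-list.
new_shape = [["reshape_a", "copy_a"], ["copy_b"]]

functionalized = []
for copy_bsyms in new_shape:
    functionalized.extend(copy_bsyms)
    last_copy = functionalized[-1]  # the copy for this particular in-place update
    print(last_copy)                # copy_a, then copy_b
```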
As of now, the test …

Minimal reproducible example:

```python
import torch
import thunder
from functools import partial  # needed for the decorator below

@partial(thunder.jit, disable_inplace_copy_check=True)
def f(q, k, v, mask, idx, src):
    q.index_copy_(2, idx, src)
    k.index_copy_(2, idx, src)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)

q = torch.randn((1, 4, 2, 16), device='cuda', dtype=torch.float32)
k = torch.randn((1, 4, 3, 16), device='cuda', dtype=torch.float32)
v = torch.randn((1, 4, 3, 16), device='cuda', dtype=torch.float32)
mask = torch.ones((1, 1, 2, 3), device='cuda', dtype=torch.bool)
idx = torch.arange(2).to(device='cuda')
src = torch.randn((1, 4, 2, 16), device='cuda', dtype=torch.float32)

f(q, k, v, mask, idx, src)
```

Execution trace:

```python
def computation(q, k, v, mask, idx, src):
  # q: "cuda:0 f32[1, 4, 2, 16]"
  # k: "cuda:0 f32[1, 4, 3, 16]"
  # v: "cuda:0 f32[1, 4, 3, 16]"
  # mask: "cuda:0 b8[1, 1, 2, 3]"
  # idx: "cuda:0 i64[2]"
  # src: "cuda:0 f32[1, 4, 2, 16]"
  t0 = torch.index_copy(q, 2, idx, src)  # t0: "cuda:0 f32[1, 4, 2, 16]"
    # t0 = ltorch.index_copy(q, 2, idx, src)  # t0: "cuda:0 f32[1, 4, 2, 16]"
      # t0 = prims.index_copy(q, idx, src, 2)  # t0: "cuda:0 f32[1, 4, 2, 16]"
  t2 = torch.index_copy(k, 2, idx, src)  # t2: "cuda:0 f32[1, 4, 3, 16]"
    # t2 = ltorch.index_copy(k, 2, idx, src)  # t2: "cuda:0 f32[1, 4, 3, 16]"
      # t2 = prims.index_copy(k, idx, src, 2)  # t2: "cuda:0 f32[1, 4, 3, 16]"
  [t1, t3] = nvFusion0(t0, q, t2, k)
    # t1 = prims.copy_to_out_(t0, out=q)  # t1: "cuda:0 f32[1, 4, 2, 16]"
    # t3 = prims.copy_to_out_(t2, out=k)  # t3: "cuda:0 f32[1, 4, 3, 16]"
  del q, k
  (t19, _, _, _) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t0, t2, v, mask, 0.0, False, None)
  del t0, t2
  return t19
```
Note that, before passing the trace to the nvFuser executor, the trace looks as follows.

Trace just before nvFuser:

```python
def computation(q, k, v, mask, idx, src):
  # q: "cuda:0 f32[1, 4, 2, 16]"
  # k: "cuda:0 f32[1, 4, 3, 16]"
  # v: "cuda:0 f32[1, 4, 3, 16]"
  # mask: "cuda:0 b8[1, 1, 2, 3]"
  # idx: "cuda:0 i64[2]"
  # src: "cuda:0 f32[1, 4, 2, 16]"

  # Functionalized from `t1 = index_copy_(q,2,idx,src)`
  t0 = ltorch.index_copy(q, 2, idx, src)  # t0: "cuda:0 f32[1, 4, 2, 16]"
    # t0 = prims.index_copy(q, idx, src, 2)  # t0: "cuda:0 f32[1, 4, 2, 16]"

  # Functionalized from `t3 = index_copy_(k,2,idx,src)`
  t2 = ltorch.index_copy(k, 2, idx, src)  # t2: "cuda:0 f32[1, 4, 3, 16]"
    # t2 = prims.index_copy(k, idx, src, 2)  # t2: "cuda:0 f32[1, 4, 3, 16]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:60: return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)
  # ['t1', 't3'] are replaced by ['t0', 't2'], respectively
  t19 = ltorch.scaled_dot_product_attention(t0, t2, v, mask, 0.0, False, scale=None)  # t19: "cuda:0 f32[1, 4, 2, 16]"
    # subsymbols omitted

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:58: q.index_copy_(2, idx, src)
  t1 = prims.copy_to_out_(t0, out=q)  # t1: "cuda:0 f32[1, 4, 2, 16]"
  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:59: k.index_copy_(2, idx, src)
  t3 = prims.copy_to_out_(t2, out=k)  # t3: "cuda:0 f32[1, 4, 3, 16]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:60: return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)
  return {'output': t19, 'flat_args': [t1, t3, v, mask, idx, src]}
```

Possible solutions: …
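For intuition on why using `computed` after the copy is a problem once `out` is allowed to alias it, here is a plain-PyTorch sketch (illustration only, not Thunder internals):

```python
import torch

computed = torch.arange(4.0)   # the functionalized result, e.g. t0 above

# An executor exploiting the copy_to_out_ contract may make `out` an alias of
# `computed` instead of materializing a separate buffer:
out = computed.view(-1)        # a view: same storage as `computed`

# If anything later writes onto `out` (another copy, an optimizer step, ...),
out.add_(100.0)

# ... the write shows through the alias, so an op that still reads `computed`
# (like the sdpa call in the trace above) would see mutated values.
print(computed)                # tensor([100., 101., 102., 103.])
```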
I disabled the sanity check for …
Ugh. Could it be that …? I wonder if something along the lines of @IvanYashchuk's planned primitive for dataflow healing would be useful.
The same happens when we use ….

My internship period is about to end, so I can no longer spend much time on this issue. Maybe we can close this PR for now and wait for functionalization of ….
Closing for now. @crcrpar please reopen as you see fit.
Fixes #1173. This adds a new primitive `prims.copy_to_out_(computed, *, out)`, which is used instead of `prims.copy_` for the update in in-place ops. Unlike `prims.copy_`, `prims.copy_to_out_(computed, *, out)` assumes that `computed` is not used by subsequent ops, so `out` can simply alias `computed`.

Benchmark commits: main: 63887b3, #1193: 2781a20.

Benchmark plots (figures omitted): torch.compile(adam.step, backend=thunder) on main, and torch.compile(adam.step, backend=thunder) on #1193.

The rule with `prims.copy_to_out_` is (link): …

This rule comes from the fact that any copies onto `out` will be propagated to its alias, `computed`.

To prevent users from using `prims.copy_to_out_` inappropriately, I made the sanity check on `prims.copy_to_out_` rather conservative. When enabled, it raises an error when `computed` is defined or used outside of the nvFuser region that performs the copy.

See tests for examples.