Add torch.nn.Dropout recomputation support during the backward pass to Thunder #114
Comments
fyi @ptrblck
Here's a simple test case:

import torch
import thunder

def func(a):
    t1 = torch.nn.functional.dropout(a, p=0.5)
    return t1 @ t1

a = torch.randn(2, 2, device="cuda", requires_grad=True)
jfunc = thunder.jit(func)
out = jfunc(a)

The forward trace shows that the dropout mask is saved for backward:

print(thunder.last_traces(jfunc)[-1])
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  [t1, t4] = nvFusion0(a)
    # t0 = prims.uniform((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.lt(t0, 0.5) # t1: "cuda:0 b8[2, 2]"
    # t2 = prims.convert_element_type(t1, dtypes.float32) # t2: "cuda:0 f32[2, 2]"
    # t3 = prims.mul(a, t2) # t3: "cuda:0 f32[2, 2]"
    # t4 = prims.mul(t3, 2.0) # t4: "cuda:0 f32[2, 2]"
  t5 = torch.matmul(t4, t4) # t5: "cuda:0 f32[2, 2]"
    # t5 = ltorch.matmul(t4, t4) # t5: "cuda:0 f32[2, 2]"
      # t5 = prims.matmul(t4, t4) # t5: "cuda:0 f32[2, 2]"
  return {'output': t5, 'flat_args': [a], 'flat_output': (t5,)}, ((t1, t4), (2.0,))

Backward trace:

print(thunder.last_backward_traces(jfunc)[-1])
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t6, = cotangents
  clear_collection(cotangents)
  del cotangents
  t1, t4, = C0
  clear_collection(C0)
  del C0
  f7, = C1
  clear_collection(C1)
  del C1
  t35 = torch.permute(t4, (1, 0)) # t35: "cuda:0 f32[2, 2]"
    # t35 = ltorch.permute(t4, (1, 0)) # t35: "cuda:0 f32[2, 2]"
      # t35 = prims.transpose(t4, (1, 0)) # t35: "cuda:0 f32[2, 2]"
  del t4
  t36 = torch.matmul(t35, t6) # t36: "cuda:0 f32[2, 2]"
    # t36 = ltorch.matmul(t35, t6) # t36: "cuda:0 f32[2, 2]"
      # t36 = prims.matmul(t35, t6) # t36: "cuda:0 f32[2, 2]"
  t34 = torch.matmul(t6, t35) # t34: "cuda:0 f32[2, 2]"
    # t34 = ltorch.matmul(t6, t33) # t34: "cuda:0 f32[2, 2]"
      # t34 = prims.matmul(t6, t33) # t34: "cuda:0 f32[2, 2]"
  del t6, t35
  [t39] = nvFusion0(f7, t1, t34, t36)
    # t2 = prims.convert_element_type(t1, dtypes.float32) # t2: "cuda:0 f32[2, 2]"
    # t37 = prims.add(t34, t36) # t37: "cuda:0 f32[2, 2]"
    # t38 = prims.mul(f7, t37) # t38: "cuda:0 f32[2, 2]"
    # t39 = prims.mul(t2, t38) # t39: "cuda:0 f32[2, 2]"
  del f7, t1, t34, t36
  return (t39,)

We should implement a trace transformation that replaces uniform with uniform_philox, so that only the seed and offset need to be saved for backward instead of the boolean mask.
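To illustrate what such a transformation buys, here is a rough eager PyTorch sketch (not Thunder code; RecomputedDropout and the explicit seed argument are made up for illustration, whereas Thunder would capture the Philox seed/offset from the RNG state automatically) of dropout that stashes only a scalar seed and regenerates the mask during backward:

import torch

class RecomputedDropout(torch.autograd.Function):
    """Sketch: dropout that saves only a scalar seed, never the mask tensor."""

    @staticmethod
    def forward(ctx, x, p, seed):
        # Stash two Python scalars; nothing of size O(x.numel()) is saved.
        ctx.p = p
        ctx.seed = seed
        gen = torch.Generator(device=x.device).manual_seed(seed)
        keep = torch.rand(x.shape, device=x.device, dtype=x.dtype, generator=gen) < (1.0 - p)
        return x * keep * (1.0 / (1.0 - p))

    @staticmethod
    def backward(ctx, grad_out):
        # Regenerate the identical mask from the saved seed instead of loading it from memory.
        gen = torch.Generator(device=grad_out.device).manual_seed(ctx.seed)
        keep = torch.rand(grad_out.shape, device=grad_out.device, dtype=grad_out.dtype, generator=gen) < (1.0 - ctx.p)
        return grad_out * keep * (1.0 / (1.0 - ctx.p)), None, None

# Usage sketch:
#   a = torch.randn(2, 2, device="cuda", requires_grad=True)
#   out = RecomputedDropout.apply(a, 0.5, 1234)
#   out.sum().backward()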
Triage review:
…_philox) and RNG state query/updating (#114)
@kiya00 @IvanYashchuk is this closed with #481 or not yet? What is our plan?
@t-vi this can be closed. The testing result is in #481 (comment).
Thank you @kiya00 |
🚀 Feature
I would like to have Thunder save the seed and offset from random number generation to allow for the recomputation of Dropout in the backward pass.
There are two pieces needed to make it work (a small PyTorch demonstration of the underlying idea follows the list):

1. thunder.prims.uniform_philox, a uniform primitive that takes an explicit seed and offset.
2. A trace transformation that finds the uniform call, replacing uniform with uniform_philox and incrementing the PRNG state properly. This is not implemented.
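As a rough illustration of why a seed/offset pair is enough (this is plain PyTorch, not the proposed prim or transformation), capturing the generator state before the uniform draw lets the identical values be regenerated later at the cost of a few bytes:

import torch

# Assumes a CUDA device is available; the same idea works with a CPU generator.
gen = torch.Generator(device="cuda").manual_seed(1234)
state = gen.get_state()  # small snapshot of the generator (Philox seed/offset) state

u_fwd = torch.rand(2, 2, device="cuda", generator=gen)  # draw used in forward

gen.set_state(state)                                     # rewind the generator
u_bwd = torch.rand(2, 2, device="cuda", generator=gen)   # recomputed draw in backward

assert torch.equal(u_fwd, u_bwd)  # bit-exact regeneration, so the mask need not be saved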
Motivation

Multihead Attention modules in LLMs often use dropout on tensors whose memory scales with the square of the sequence length, so saving the dropout mask for backward is expensive.
cc @apaz-cli