+from contextlib import contextmanager
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -52,6 +53,19 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 aten = torch.ops.aten


+@contextmanager
+def preserve_rng_state():
+    rng_state = torch.clone(torch.random.get_rng_state())
+    if torch.cuda.is_available():
+        cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
+    try:
+        yield
+    finally:
+        torch.random.set_rng_state(rng_state)
+        if torch.cuda.is_available():
+            torch.cuda.set_rng_state(cuda_rng_state)
+
+
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
         primals: List[Any], tangents: List[Any]
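The new `preserve_rng_state` context manager snapshots the CPU RNG state (and the CUDA state when a GPU is available) on entry and restores it in the `finally` block, so any random numbers consumed inside the block leave the global RNG stream untouched. A minimal standalone sketch of the behavior (illustration only, not part of this diff):

```python
import torch
from contextlib import contextmanager


@contextmanager
def preserve_rng_state():
    # Snapshot the current CPU RNG state (and CUDA state if present).
    rng_state = torch.clone(torch.random.get_rng_state())
    if torch.cuda.is_available():
        cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
    try:
        yield
    finally:
        # Restore the snapshots, undoing any RNG consumption inside the block.
        torch.random.set_rng_state(rng_state)
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(cuda_rng_state)


torch.manual_seed(0)
with preserve_rng_state():
    _ = torch.randn(3)       # consumes random numbers, e.g. during a tracing run
after = torch.randn(3)       # drawn from the restored stream

torch.manual_seed(0)
assert torch.equal(after, torch.randn(3))  # same values as if nothing was drawn
```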
@@ -147,27 +161,29 @@ class CompiledFunction(torch.autograd.Function):
         def forward(ctx, *flat_tensor_args):
             nonlocal compiled_fw, compiled_bw, num_outs
             if compiled_fw is None:
-                # Set input tensors that require grad to leaves
-                flat_tensor_args = pytree.tree_map(
-                    lambda x: x.detach().requires_grad_(x.requires_grad), flat_tensor_args
-                )
-                with torch.set_grad_enabled(grad_state):
-                    out = flat_fn(*flat_tensor_args)
-                out = pytree.tree_map(
-                    lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
-                )
+                with preserve_rng_state():
+                    # Set input tensors that require grad to leaves
+                    flat_tensor_args = pytree.tree_map(
+                        lambda x: x.detach().requires_grad_(x.requires_grad), flat_tensor_args
+                    )
+                    with torch.set_grad_enabled(grad_state):
+                        out = flat_fn(*flat_tensor_args)
+                    out = pytree.tree_map(
+                        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
+                    )

-                if isinstance(out, (list, tuple)):
-                    num_outs = len(out)
-                else:
-                    num_outs = 1
+                    if isinstance(out, (list, tuple)):
+                        num_outs = len(out)
+                    else:
+                        num_outs = 1
+
+                    joint_inputs = (flat_tensor_args, out)
+                    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                    with torch.set_grad_enabled(grad_state):
+                        fx_g = make_fx(joint_forward_backward, aot_decompositions)(
+                            *joint_inputs
+                        )

-                joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
-                with torch.set_grad_enabled(grad_state):
-                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(
-                        *joint_inputs
-                    )
                 fw_module, bw_module = partition_fn(fx_g, joint_inputs)
                 # print(fw_module.code, bw_module.code)
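This second hunk moves the trial execution of `flat_fn` and the `make_fx` trace of the joint forward/backward graph under `preserve_rng_state()`, so the extra execution done only for tracing does not advance the random stream that the first real run will consume. A hedged illustration of the effect, reusing the `preserve_rng_state` sketch above, with `nn.Dropout` standing in for a stochastic model being traced:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)   # stand-in for a model with randomness
x = torch.ones(8)

torch.manual_seed(0)
with preserve_rng_state():
    _ = drop(x)            # "tracing" run: consumes RNG, then state is restored
compiled_run = drop(x)     # the run the user actually observes

torch.manual_seed(0)
plain_run = drop(x)        # what eager mode would produce with no tracing run
assert torch.equal(compiled_run, plain_run)
```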