From 2789b43f0ad7934149d7ee123cac6cb516503a41 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Tue, 16 Apr 2024 14:22:05 +0200 Subject: [PATCH] Add trace transform to replace uniform with stateless uniform(uniform_philox) and RNG state query/updating (#114) --- thunder/core/prims.py | 86 +++++++++++++++++++++++++++++ thunder/executors/passes.py | 57 +++++++++++++++++++ thunder/executors/torch_autograd.py | 7 +++ thunder/executors/torchex.py | 80 +++++++++++++++++++++++++++ thunder/tests/test_randomness.py | 58 +++++++++++++++++++ 5 files changed, 288 insertions(+) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index a08ae1e02c..1a53eba004 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -251,6 +251,14 @@ class PrimIDs(Enum): # Memory access methods ITEM = auto() COPY_ = auto() + SET_SEED = auto() + SET_OFFSET=auto() + GET_SEED =auto() + GET_OFFSET = auto() + SET_RNG_STATE = auto() + GET_RNG_STATE = auto() + UNPACK_RNG_STATE = auto() + PACK_RNG_STATE = auto() class OpTags(Enum): @@ -2451,6 +2459,84 @@ def _uniform_meta( tags=(OpTags.RANDOM_OP,), ) +def _set_seed_meta(s): + return None +def _get_seed_meta(): + # return numberproxy(int, 0) + return TensorProxy(shape=(), device=devices.cpu, dtype=dtypes.int64) +def _set_offset_meta(s): + return None +def _get_offset_meta(): + return TensorProxy(shape=(), device=devices.cpu, dtype=dtypes.int64) + # return numberproxy(int, 0) +def _set_rng_state_meta(new_state): # TODO, device + return None +def _get_rng_state_meta(): #TODO device + # static const size_t seed_size = sizeof(uint64_t); + # static const size_t offset_size = sizeof(int64_t); + # static const size_t total_size = seed_size + offset_size; + + # auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); + state_shape = dtypes.int64.bytes//dtypes.uint8.bytes * 2 + return TensorProxy(shape=(state_shape,), dtype=dtypes.uint8, device=devices.cpu) + 
set_rng_state = make_prim(
    PrimIDs.SET_RNG_STATE,
    "set_rng_state",
    meta=_set_rng_state_meta,
    tags=(OpTags.RANDOM_OP, OpTags.DONT_DCE),
)
get_rng_state = make_prim(
    PrimIDs.GET_RNG_STATE,
    "get_rng_state",
    meta=_get_rng_state_meta,
    tags=(OpTags.RANDOM_OP,),
)


def _unpack_rng_state_meta(state):
    """Meta function for the UNPACK_RNG_STATE prim.

    Splits a packed state tensor into a (seed, offset) pair of int proxies.
    """
    return numberproxy(int, 0), numberproxy(int, 0)


unpack_rng_state = make_prim(
    PrimIDs.UNPACK_RNG_STATE,
    "unpack_rng_state",
    meta=_unpack_rng_state_meta,
    tags=(OpTags.RANDOM_OP,),
)


def _pack_rng_state_meta(seed, offset):
    """Meta function for the PACK_RNG_STATE prim.

    Inverse of unpack: produces the flat uint8 CPU state tensor packing the
    int64 seed followed by the int64 offset.
    """
    state_shape = dtypes.int64.bytes // dtypes.uint8.bytes * 2
    return TensorProxy(shape=(state_shape,), dtype=dtypes.uint8, device=devices.cpu)


pack_rng_state = make_prim(
    PrimIDs.PACK_RNG_STATE,
    "pack_rng_state",
    meta=_pack_rng_state_meta,
    tags=(OpTags.RANDOM_OP,),
)
set_seed = make_prim(
    PrimIDs.SET_SEED,
    "set_seed",
    meta=_set_seed_meta,
    tags=(OpTags.RANDOM_OP,),
)
get_seed = make_prim(
    PrimIDs.GET_SEED,
    "get_seed",
    meta=_get_seed_meta,
    tags=(OpTags.RANDOM_OP,),
)
set_offset = make_prim(
    PrimIDs.SET_OFFSET,
    "set_offset",
    meta=_set_offset_meta,
    tags=(OpTags.RANDOM_OP, OpTags.DONT_DCE),
)
get_offset = make_prim(
    PrimIDs.GET_OFFSET,
    "get_offset",
    meta=_get_offset_meta,
    tags=(OpTags.RANDOM_OP,),
)


def replace_uniform(trace: TraceCtx) -> TraceCtx:
    """Replace every ``prims.uniform`` with the stateless ``uniform_philox``.

    Each uniform is rewritten to: query the generator state, unpack it into
    (seed, offset), call uniform_philox with those, then advance the offset
    and write the updated state back — making the randomness explicit in the
    trace instead of hidden generator mutation.

    Returns a new trace with provenance recorded; the input trace is not mutated.
    """
    start_time_ns = time.time_ns()
    from thunder.torch import uniform_philox

    swapmap: dict[Variable, Proxy] = {}

    def update_swapmap(o: Any, no: Any) -> None:
        # Record the mapping from replacement outputs back to the original
        # proxies so downstream consumers keep their original names.
        if isinstance(o, Proxy):
            check(
                isinstance(no, Proxy),
                lambda: f"Expected an execution transform to produce outputs with the same type, but found {type(o)} and {type(no)}",
            )

            vo = variableify(o)
            vno = variableify(no)
            if vo == vno:
                return
            swapmap[vno] = o

    def visit_(bsym: BoundSymbol) -> transforms.VISIT_TYPE:
        if bsym.sym.id != prims.PrimIDs.UNIFORM:
            return transforms.VISIT_TYPE.NO_OP

        rng_state = prims.get_rng_state()
        seed, offset = prims.unpack_rng_state(rng_state)
        out = uniform_philox(*bsym.args, **bsym.kwargs, seed=seed, offset=offset)
        # Advance the Philox offset past the numbers this call consumed.
        # NOTE(review): 4 matches the per-thread Philox4_32 output width
        # PyTorch uses for uniform — confirm against the kernel's increment.
        adv_offs = 4
        new_offset = prims.add(offset, adv_offs)
        new_state = prims.pack_rng_state(seed, new_offset)
        prims.set_rng_state(new_state)
        safe_map_flat(update_swapmap, bsym.output, out)
        return transforms.VISIT_TYPE.REPLACE

    extrace = transforms.visitor_transform(trace, visit_)

    # Restores original variables
    extrace.bound_symbols = [bsym.from_bsym_swap_proxies(swapmap) for bsym in extrace.bound_symbols]

    elapsed_time_millis = (time.time_ns() - start_time_ns) // 1000000
    extrace.set_provenance(
        TraceProvenance(f"Transform for replace uniform (took {elapsed_time_millis} milliseconds)")
    )
    return extrace
def _get_seed_prim_impl():
    """Return the current CUDA generator's seed as a scalar CPU tensor."""
    cuda_generator = torch.cuda.default_generators[torch.cuda.current_device()]
    return torch.tensor(cuda_generator.initial_seed())


get_seed_prim_impl = ex.register_operator(
    "get_seed_prim_impl", meta=prims.get_seed.meta, fn=_get_seed_prim_impl
)
_register_implementation(prims.get_seed, get_seed_prim_impl, checker=_always_executable)


def _get_offset_prim_impl():
    """Return the current CUDA generator's Philox offset as a scalar CPU tensor."""
    cuda_generator = torch.cuda.default_generators[torch.cuda.current_device()]
    return torch.tensor(cuda_generator.get_offset())


get_offset_prim_impl = ex.register_operator(
    "get_offset_prim_impl", meta=prims.get_offset.meta, fn=_get_offset_prim_impl
)
_register_implementation(prims.get_offset, get_offset_prim_impl, checker=_always_executable)


def _set_seed_prim_impl(s):
    """Seed the current device's CUDA generator.

    Returns None to match the prim's meta (manual_seed's Generator return
    value is deliberately discarded).
    """
    cuda_generator = torch.cuda.default_generators[torch.cuda.current_device()]
    cuda_generator.manual_seed(s)


set_seed_prim_impl = ex.register_operator(
    "set_seed_prim_impl",
    meta=prims.set_seed.meta,
    fn=_set_seed_prim_impl,
    tags=(prims.OpTags.RANDOM_OP, prims.OpTags.DONT_DCE),
)
_register_implementation(prims.set_seed, set_seed_prim_impl, checker=_always_executable)


def _set_offset_prim_impl(s):
    """Set the Philox offset of the current device's CUDA generator from a scalar tensor."""
    cuda_generator = torch.cuda.default_generators[torch.cuda.current_device()]
    cuda_generator.set_offset(s.item())


set_offset_prim_impl = ex.register_operator(
    "set_offset_prim_impl",
    meta=prims.set_offset.meta,
    fn=_set_offset_prim_impl,
    tags=(prims.OpTags.RANDOM_OP, prims.OpTags.DONT_DCE),
)
_register_implementation(prims.set_offset, set_offset_prim_impl, checker=_always_executable)


def _set_rng_state_prim_impl(s):
    """Restore the current device's CUDA generator from a packed uint8 CPU state tensor."""
    cuda_generator = torch.cuda.default_generators[torch.cuda.current_device()]
    cuda_generator.set_state(s)


set_rng_state_prim_impl = ex.register_operator(
    "set_rng_state_prim_impl",
    meta=prims.set_rng_state.meta,
    fn=_set_rng_state_prim_impl,
    tags=(prims.OpTags.RANDOM_OP, prims.OpTags.DONT_DCE),
)
_register_implementation(prims.set_rng_state, set_rng_state_prim_impl, checker=_always_executable)


def _get_rng_state_prim_impl():
    """Snapshot the current device's CUDA generator state as a uint8 CPU tensor."""
    cuda_generator = torch.cuda.default_generators[torch.cuda.current_device()]
    return cuda_generator.get_state()


get_rng_state_prim_impl = ex.register_operator(
    "get_rng_state_prim_impl",
    meta=prims.get_rng_state.meta,
    fn=_get_rng_state_prim_impl,
    tags=(prims.OpTags.RANDOM_OP, prims.OpTags.DONT_DCE),
)
_register_implementation(prims.get_rng_state, get_rng_state_prim_impl, checker=_always_executable)


def _unpack_rng_state_prim_impl(s):
    """Split a packed generator state into Python ints ``(seed, offset)``.

    The state tensor lays out 8 bytes of seed followed by 8 bytes of offset,
    so chunking in half and reinterpreting each half as int64 recovers them.
    """
    seed, offset = torch.chunk(s, 2)
    return seed.view(torch.int64).item(), offset.view(torch.int64).item()


unpack_rng_state_prim_impl = ex.register_operator(
    "unpack_rng_state_prim_impl",
    meta=prims.unpack_rng_state.meta,
    fn=_unpack_rng_state_prim_impl,
    tags=(prims.OpTags.RANDOM_OP, prims.OpTags.DONT_DCE),
)
_register_implementation(prims.unpack_rng_state, unpack_rng_state_prim_impl, checker=_always_executable)


def _pack_rng_state_prim_impl(seed, offset):
    """Inverse of unpack: build the packed uint8 state tensor from seed and offset ints."""
    seed_portion = torch.tensor(seed).reshape([1]).view(torch.uint8)
    offset_portion = torch.tensor(offset).reshape([1]).view(torch.uint8)
    return torch.cat([seed_portion, offset_portion])


pack_rng_state_prim_impl = ex.register_operator(
    "pack_rng_state_prim_impl",
    meta=prims.pack_rng_state.meta,
    fn=_pack_rng_state_prim_impl,
    tags=(prims.OpTags.RANDOM_OP, prims.OpTags.DONT_DCE),
)
_register_implementation(prims.pack_rng_state, pack_rng_state_prim_impl, checker=_always_executable)


_register_implementation(prims.full, checker=_always_executable, execution_transform=_full_transform)
_register_implementation(prims.iota, checker=_always_executable, execution_transform=_iota_transform)
_register_implementation(prims.uniform, checker=_always_executable, execution_transform=_uniform_transform)
@instantiate(
    dtypes=(dtypes.float32, dtypes.float16, dtypes.float64),
    devicetypes=(devices.DeviceType.CUDA,),
)
def test_rng_state_uniform_philox(executor, device: str, dtype: dtypes.dtype):
    """Two identically seeded runs of a jitted random function must agree.

    Exercises the replace-uniform transform: outputs and gradients collected
    over several iterations after ``manual_seed(20)`` must be reproduced
    exactly by a second run restarted from the same seed.
    """
    import torch

    import thunder

    def func(a):
        b = thunder.torch.uniform_like(a, device=a.device, dtype=a.dtype)
        c = thunder.torch.uniform_like(a, device=a.device, dtype=a.dtype)
        return c * b

    cuda_generator = torch.cuda.default_generators[torch.cuda.current_device()]

    a = torch.randn(2, 2, device="cuda", requires_grad=True)
    a1 = a.detach().clone()
    a1.requires_grad_()

    jfunc = thunder.jit(func)

    # First run: record outputs and (accumulated) gradients per iteration.
    cuda_generator.manual_seed(20)
    expects = []
    for _ in range(4):
        out = jfunc(a)
        out.sum().backward()
        expects.append(out)
        expects.append(a.grad)

    # Second run from the same seed must replay the same random sequence.
    cuda_generator.manual_seed(20)
    results = []
    for _ in range(4):
        out = jfunc(a1)
        out.sum().backward()
        results.append(out)
        results.append(a1.grad)

    for expected, result in zip(expects, results):
        assert_close(expected, result)