diff --git a/extensions/thunder/strategies/thunder_ddp.py b/extensions/thunder/strategies/thunder_ddp.py
index 2afa7290e1..4efbe27c60 100644
--- a/extensions/thunder/strategies/thunder_ddp.py
+++ b/extensions/thunder/strategies/thunder_ddp.py
@@ -45,17 +45,35 @@ def __init__(
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
         precision: Optional[Precision] = None,
+        jit: bool = True,
         executors: Optional[Tuple[Union["Executor", str], ...]] = None,
         process_group_backend: Optional[str] = None,
         timeout: Optional[timedelta] = default_pg_timeout,
         **kwargs: Any,
     ):
+        r"""Strategy for Replicated Data Parallel provided by Lightning Thunder.
+
+        .. warning:: This is an :ref:`experimental ` feature.
+
+        Arguments:
+            jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually
+                jitting a function that includes the model.
+
+            executors: The list of Thunder executors to enable. They can be either string aliases for the executors
+                or the actual executor instances.
+
+            \**kwargs: See available parameters in :func:`thunder.distributed.ddp`.
+
+        """
         if not _THUNDER_AVAILABLE:
             raise ModuleNotFoundError(str(_THUNDER_AVAILABLE))
         super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision)
         self.parallel_devices = parallel_devices
         self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment
+        if not jit and executors is not None:
+            raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`")
+        self.jit = jit
         self.executors = _validate_executors(executors)
         self._num_nodes = 1
         self._process_group_backend: Optional[str] = process_group_backend
@@ -111,8 +129,25 @@ def setup_environment(self) -> None:
     def setup_module(self, module: Module) -> Module:
         import thunder
 
-        module = thunder.distributed.ddp(module, **self._ddp_kwargs)
-
+        if (cd := thunder.compile_data(module)) is not None:
+            # the module was already jitted
+            if thunder.compile_stats(module).last_traces is not None:
+                raise RuntimeError(
+                    "You already called `thunder.jit()` and generated an execution trace. It's too late to apply the"
+                    " DDP transform. Remove the `forward` call before `fabric.setup()`"
+                )
+            assert cd.is_module  # sanity check
+            ddp_module = thunder.distributed.ddp(cd.fn, **self._ddp_kwargs)
+            # update the compile data state
+            cd.fn = ddp_module
+            assert hasattr(cd, "_processed_function")  # sanity check
+            cd._processed_function = ddp_module
+            cd.process_group_for_ddp = ddp_module.process_group_for_ddp
+            return module
+        else:
+            module = thunder.distributed.ddp(module, **self._ddp_kwargs)
+            if not self.jit:
+                return module
         return thunder.jit(module, executors=self.executors)
 
     @override
diff --git a/extensions/thunder/strategies/thunder_fsdp.py b/extensions/thunder/strategies/thunder_fsdp.py
index 6fd2200d70..d4e60c0085 100644
--- a/extensions/thunder/strategies/thunder_fsdp.py
+++ b/extensions/thunder/strategies/thunder_fsdp.py
@@ -54,12 +54,54 @@ def __init__(
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
         precision: Optional[Precision] = None,
+        jit: bool = True,
+        executors: Optional[Tuple[Union["Executor", str], ...]] = None,
         sharding_strategy: "_FSDP_TYPE" = "ZERO3",
         bucketing_strategy: "_BUCKETING_STRATEGY" = "NONE",
-        executors: Optional[Tuple[Union["Executor", str], ...]] = None,
         state_dict_type: Literal["full", "sharded"] = "sharded",
         **kwargs: Any,
     ):
+        r"""Strategy for Fully Sharded Data Parallel provided by Lightning Thunder.
+
+        .. warning:: This is an :ref:`experimental ` feature.
+
+        Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
+        size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
+        at parity with PyTorch DDP, whilst scaling our model sizes dramatically.
+
+        Arguments:
+            jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually
+                jitting a function that includes the model.
+
+            executors: The list of Thunder executors to enable. They can be either string aliases for the executors
+                or the actual executor instances.
+
+            sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination
+                of them:
+
+                - ``"ZERO3"``: Shards model parameters, gradients, and optimizer states (default).
+                - ``"ZERO2"``: Shards gradients and optimizer states only. Model parameters get replicated.
+
+                Also accepts a :class:`thunder.distributed.FSDPType` enum value.
+
+            bucketing_strategy: Enables combining the collective operations for sets of layers.
+
+                - ``"NONE"``: No bucketing (default).
+                - ``"LAYER"``: Create buckets per layer class.
+                - ``"BLOCK"``: Create buckets per layer block.
+
+                Also accepts a :class:`thunder.distributed.FSDPBucketingStrategy` enum value.
+
+            state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.
+
+                - ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file
+                  (default).
+                - ``"sharded"``: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is
+                  a folder with as many files as the world size.
+
+            \**kwargs: See available parameters in :func:`thunder.distributed.fsdp`.
+ + """ if not _TORCH_GREATER_EQUAL_2_2: raise ImportError("Thunder's FSDP strategy requires PyTorch 2.2 or higher.") if not _THUNDER_AVAILABLE: @@ -77,6 +119,9 @@ def __init__( if isinstance(bucketing_strategy, str) else bucketing_strategy ) + if not jit and executors is not None: + raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`") + self.jit = jit self.executors = _validate_executors(executors) self._state_dict_type = state_dict_type self._fsdp_kwargs = kwargs @@ -115,16 +160,37 @@ def setup_environment(self) -> None: def setup_module(self, module: Module) -> Module: import thunder - module = thunder.distributed.fsdp( - module, - device=self.root_device, - sharding_strategy=self.sharding_strategy, - bucketing_strategy=self.bucketing_strategy, - **self._fsdp_kwargs, - ) - - # NOTE @IvanYaschuck says that `fsdp(jit(model))` could be supported in the future so that the user owns the `jit` call. - # we would still `jit(fsdp(undo_jit(jit(model))))` internally + if (cd := thunder.compile_data(module)) is not None: + # the module was already jitted + if thunder.compile_stats(module).last_traces is not None: + raise RuntimeError( + "You already called `thunder.jit()` and generated an execution trace. It's too late to apply the" + " FSDP transform. Remove the `forward` call before `fabric.setup()`" + ) + assert cd.is_module # sanity check + fsdp_module = thunder.distributed.fsdp( + cd.fn, + device=self.root_device, + sharding_strategy=self.sharding_strategy, + bucketing_strategy=self.bucketing_strategy, + **self._fsdp_kwargs, + ) + # update the compile data state + cd.fn = fsdp_module + assert hasattr(cd, "_processed_function") # sanity check + cd._processed_function = fsdp_module + cd.process_group_for_ddp = fsdp_module.process_group_for_ddp + return module + else: + module = thunder.distributed.fsdp( + module, + device=self.root_device, + sharding_strategy=self.sharding_strategy, + bucketing_strategy=self.bucketing_strategy, + **self._fsdp_kwargs, + ) + if not self.jit: + return module return thunder.jit(module, executors=self.executors) @override diff --git a/tests/test_thunder_ddp.py b/tests/test_thunder_ddp.py index 5ccc853eea..566e883ac3 100644 --- a/tests/test_thunder_ddp.py +++ b/tests/test_thunder_ddp.py @@ -10,13 +10,19 @@ wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) +from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy + + +@RunIf(thunder=True) +def test_thunder_strategy_input_parsing(): + with pytest.raises(ValueError, match="doesn't have an effect with `jit=False"): + ThunderDDPStrategy(jit=False, executors=("python",)) + @RunIf(min_cuda_gpus=2, thunder=True, standalone=True) @pytest.mark.parametrize("strategy", ["ddp", "thunder_ddp"]) def test_no_backward_sync(strategy): if strategy == "thunder_ddp": - from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy - strategy = ThunderDDPStrategy() fabric = Fabric(devices=2, accelerator="cuda", strategy=strategy) @@ -47,3 +53,37 @@ def test_no_backward_sync(strategy): # rank0 rank1 allreduce1 rank0 rank1 allreduce2 assert model.weight.grad.item() == (9.0 if i == 3 else 22.5) model.weight.grad = None + + +@RunIf(min_cuda_gpus=2, thunder=True, standalone=True) +@pytest.mark.parametrize("jit", (False, True)) +def test_jit_before_setup(jit): + import thunder + + fabric = Fabric(devices=2, accelerator="cuda", strategy=ThunderDDPStrategy(jit=jit)) + fabric.launch() + + x = torch.randn(1, 1, device=fabric.device) + model = torch.nn.Linear(1, 2, 
+
+    tmodel = thunder.jit(model)
+    fmodel = fabric.setup(tmodel)
+    fmodel(x)
+
+    assert "all_reduce" in thunder.last_backward_traces(tmodel)[-1].python()
+
+
+@RunIf(min_cuda_gpus=1, thunder=True)
+def test_setup_already_traced():
+    import thunder
+
+    device = torch.device("cuda")
+    x = torch.randn(1, 1, device=device)
+    model = torch.nn.Linear(1, 2, bias=False, device=device)
+
+    strategy = ThunderDDPStrategy()
+
+    tmodel = thunder.jit(model)
+    tmodel(x)
+    with pytest.raises(RuntimeError, match="already called"):
+        strategy.setup_module(tmodel)
diff --git a/tests/test_thunder_fsdp.py b/tests/test_thunder_fsdp.py
index 76dc36bae6..8b9c0f4340 100644
--- a/tests/test_thunder_fsdp.py
+++ b/tests/test_thunder_fsdp.py
@@ -28,6 +28,9 @@ def test_thunder_strategy_input_parsing():
     assert strategy.executors == (pythonex,)
     assert strategy.sharding_strategy is FSDPType.ZERO3
 
+    with pytest.raises(ValueError, match="doesn't have an effect with `jit=False"):
+        ThunderFSDPStrategy(jit=False, executors=("python",))
+
 
 @RunIf(thunder=True)
 def test_validate_executors():
@@ -309,3 +312,37 @@
     actual["buf"] = actual["buf"].to(device="cpu")
     torch.testing.assert_close(actual, expected)
     assert state["primitive"] == 123
+
+
+@RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
+@pytest.mark.parametrize("jit", (False, True))
+def test_jit_before_setup(jit):
+    import thunder
+
+    fabric = Fabric(devices=2, accelerator="cuda", strategy=ThunderFSDPStrategy(jit=jit))
+    fabric.launch()
+
+    x = torch.randn(1, 1, device=fabric.device)
+    model = torch.nn.Linear(1, 2, bias=False, device=fabric.device)
+
+    tmodel = thunder.jit(model)
+    fmodel = fabric.setup(tmodel)
+    fmodel(x)
+
+    assert "all_gather" in thunder.last_traces(tmodel)[-1].python()
+
+
+@RunIf(min_cuda_gpus=1, thunder=True)
+def test_setup_already_traced():
+    import thunder
+
+    device = torch.device("cuda")
+    x = torch.randn(1, 1, device=device)
+    model = torch.nn.Linear(1, 2, bias=False, device=device)
+
+    strategy = ThunderFSDPStrategy()
+
+    tmodel = thunder.jit(model)
+    tmodel(x)
+    with pytest.raises(RuntimeError, match="already called"):
+        strategy.setup_module(tmodel)
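
Usage sketch (illustrative, mirroring `test_jit_before_setup` above): with `ThunderDDPStrategy(jit=False)` the user owns the `thunder.jit()` call and `fabric.setup()` only applies the DDP transform to the already-jitted module. The `lightning.fabric` import path, the 2-GPU CUDA launch, and the toy `Linear` model are placeholder choices taken from the tests, not requirements.

# Minimal sketch, assuming 2 CUDA devices and `from lightning.fabric import Fabric`.
import thunder
import torch
from lightning.fabric import Fabric

from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy

fabric = Fabric(devices=2, accelerator="cuda", strategy=ThunderDDPStrategy(jit=False))
fabric.launch()

model = torch.nn.Linear(1, 2, bias=False, device=fabric.device)
tmodel = thunder.jit(model)    # the user owns the `thunder.jit()` call
fmodel = fabric.setup(tmodel)  # setup_module() detects the jitted module and applies ddp() to it
fmodel(torch.randn(1, 1, device=fabric.device))

With the default `jit=True`, the explicit `thunder.jit()` call can be dropped and the strategy jits the transformed module itself, as in the `else` branch of `setup_module`.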