Skip to content

Commit

Permalink
Smarter thunder.jit decisions (#1204)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Mar 27, 2024
1 parent 84a73fd commit a67dd5c
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 15 deletions.
39 changes: 37 additions & 2 deletions extensions/thunder/strategies/thunder_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,35 @@ def __init__(
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision: Optional[Precision] = None,
jit: bool = True,
executors: Optional[Tuple[Union["Executor", str], ...]] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
**kwargs: Any,
):
r"""Strategy for Replicated Data Parallel provided by Lightning Thunder.
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
Arguments:
jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually
jitting a function that includes the model.
executors: The list of Thunder executors to enable. They can be either string aliases for the executors
or the actual executor instances.
\**kwargs: See available parameters in :func:`thunder.distributed.ddp`.
"""
if not _THUNDER_AVAILABLE:
raise ModuleNotFoundError(str(_THUNDER_AVAILABLE))
super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision)
self.parallel_devices = parallel_devices
self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment

if not jit and executors is not None:
raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`")
self.jit = jit
self.executors = _validate_executors(executors)
self._num_nodes = 1
self._process_group_backend: Optional[str] = process_group_backend
Expand Down Expand Up @@ -111,8 +129,25 @@ def setup_environment(self) -> None:
def setup_module(self, module: Module) -> Module:
import thunder

module = thunder.distributed.ddp(module, **self._ddp_kwargs)

if (cd := thunder.compile_data(module)) is not None:
# the module was already jitted
if thunder.compile_stats(module).last_traces is not None:
raise RuntimeError(
"You already called `thunder.jit()` and generated an execution trace. It's too late to apply the"
" DDP transform. Remove the `forward` call before `fabric.setup()`"
)
assert cd.is_module # sanity check
ddp_module = thunder.distributed.ddp(cd.fn, **self._ddp_kwargs)
# update the compile data state
cd.fn = ddp_module
assert hasattr(cd, "_processed_function") # sanity check
cd._processed_function = ddp_module
cd.process_group_for_ddp = ddp_module.process_group_for_ddp
return module
else:
module = thunder.distributed.ddp(module, **self._ddp_kwargs)
if not self.jit:
return module
return thunder.jit(module, executors=self.executors)

@override
Expand Down
88 changes: 77 additions & 11 deletions extensions/thunder/strategies/thunder_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,54 @@ def __init__(
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision: Optional[Precision] = None,
jit: bool = True,
executors: Optional[Tuple[Union["Executor", str], ...]] = None,
sharding_strategy: "_FSDP_TYPE" = "ZERO3",
bucketing_strategy: "_BUCKETING_STRATEGY" = "NONE",
executors: Optional[Tuple[Union["Executor", str], ...]] = None,
state_dict_type: Literal["full", "sharded"] = "sharded",
**kwargs: Any,
):
r"""Strategy for Fully Sharded Data Parallel provided by Lightning Thunder.
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
at parity with PyTorch DDP, whilst scaling our model sizes dramatically.
Arguments:
jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually
jitting a function that includes the model.
executors: The list of Thunder executors to enable. They can be either string aliases for the executors
or the actual executor instances.
sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination
of them:
- ``"ZERO3"``: Shards model parameters, gradients, and optimizer states (default).
- ``"ZERO2"``: Shards gradients and optimizer states only. Model parameters get replicated.
Also accepts a :class:`thunder.distributed.FSDPType` enum value.
bucketing_strategy: Enables combining the collective operations for sets of layers.
- ``"NONE"``: No bucketing (default).
- ``"LAYER"``: Create buckets per layer class.
- ``"BLOCK"``: Create buckets per layer block.
Also accepts a :class:`thunder.distributed.FSDPBucketingStrategy` enum value.
state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.
- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file
(default).
- ``"sharded"``: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is
a folder with as many files as the world size.
\**kwargs: See available parameters in :func:`thunder.distributed.fsdp`.
"""
if not _TORCH_GREATER_EQUAL_2_2:
raise ImportError("Thunder's FSDP strategy requires PyTorch 2.2 or higher.")
if not _THUNDER_AVAILABLE:
Expand All @@ -77,6 +119,9 @@ def __init__(
if isinstance(bucketing_strategy, str)
else bucketing_strategy
)
if not jit and executors is not None:
raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`")
self.jit = jit
self.executors = _validate_executors(executors)
self._state_dict_type = state_dict_type
self._fsdp_kwargs = kwargs
Expand Down Expand Up @@ -115,16 +160,37 @@ def setup_environment(self) -> None:
def setup_module(self, module: Module) -> Module:
import thunder

module = thunder.distributed.fsdp(
module,
device=self.root_device,
sharding_strategy=self.sharding_strategy,
bucketing_strategy=self.bucketing_strategy,
**self._fsdp_kwargs,
)

# NOTE @IvanYaschuck says that `fsdp(jit(model))` could be supported in the future so that the user owns the `jit` call.
# we would still `jit(fsdp(undo_jit(jit(model))))` internally
if (cd := thunder.compile_data(module)) is not None:
# the module was already jitted
if thunder.compile_stats(module).last_traces is not None:
raise RuntimeError(
"You already called `thunder.jit()` and generated an execution trace. It's too late to apply the"
" FSDP transform. Remove the `forward` call before `fabric.setup()`"
)
assert cd.is_module # sanity check
fsdp_module = thunder.distributed.fsdp(
cd.fn,
device=self.root_device,
sharding_strategy=self.sharding_strategy,
bucketing_strategy=self.bucketing_strategy,
**self._fsdp_kwargs,
)
# update the compile data state
cd.fn = fsdp_module
assert hasattr(cd, "_processed_function") # sanity check
cd._processed_function = fsdp_module
cd.process_group_for_ddp = fsdp_module.process_group_for_ddp
return module
else:
module = thunder.distributed.fsdp(
module,
device=self.root_device,
sharding_strategy=self.sharding_strategy,
bucketing_strategy=self.bucketing_strategy,
**self._fsdp_kwargs,
)
if not self.jit:
return module
return thunder.jit(module, executors=self.executors)

@override
Expand Down
44 changes: 42 additions & 2 deletions tests/test_thunder_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,19 @@
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy


@RunIf(thunder=True)
def test_thunder_strategy_input_parsing():
with pytest.raises(ValueError, match="doesn't have an effect with `jit=False"):
ThunderDDPStrategy(jit=False, executors=("python",))


@RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
@pytest.mark.parametrize("strategy", ["ddp", "thunder_ddp"])
def test_no_backward_sync(strategy):
if strategy == "thunder_ddp":
from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy

strategy = ThunderDDPStrategy()

fabric = Fabric(devices=2, accelerator="cuda", strategy=strategy)
Expand Down Expand Up @@ -47,3 +53,37 @@ def test_no_backward_sync(strategy):
# rank0 rank1 allreduce1 rank0 rank1 allreduce2
assert model.weight.grad.item() == (9.0 if i == 3 else 22.5)
model.weight.grad = None


@RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
@pytest.mark.parametrize("jit", (False, True))
def test_jit_before_setup(jit):
import thunder

fabric = Fabric(devices=2, accelerator="cuda", strategy=ThunderDDPStrategy(jit=jit))
fabric.launch()

x = torch.randn(1, 1, device=fabric.device)
model = torch.nn.Linear(1, 2, bias=False, device=fabric.device)

tmodel = thunder.jit(model)
fmodel = fabric.setup(tmodel)
fmodel(x)

assert "all_reduce" in thunder.last_backward_traces(tmodel)[-1].python()


@RunIf(min_cuda_gpus=1, thunder=True)
def test_setup_already_traced():
import thunder

device = torch.device("cuda")
x = torch.randn(1, 1, device=device)
model = torch.nn.Linear(1, 2, bias=False, device=device)

strategy = ThunderDDPStrategy()

tmodel = thunder.jit(model)
tmodel(x)
with pytest.raises(RuntimeError, match="already called"):
strategy.setup_module(tmodel)
37 changes: 37 additions & 0 deletions tests/test_thunder_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def test_thunder_strategy_input_parsing():
assert strategy.executors == (pythonex,)
assert strategy.sharding_strategy is FSDPType.ZERO3

with pytest.raises(ValueError, match="doesn't have an effect with `jit=False"):
ThunderFSDPStrategy(jit=False, executors=("python",))


@RunIf(thunder=True)
def test_validate_executors():
Expand Down Expand Up @@ -309,3 +312,37 @@ def test_save_load_sharded_checkpoint(tmp_path):
actual["buf"] = actual["buf"].to(device="cpu")
torch.testing.assert_close(actual, expected)
assert state["primitive"] == 123


@RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
@pytest.mark.parametrize("jit", (False, True))
def test_jit_before_setup(jit):
import thunder

fabric = Fabric(devices=2, accelerator="cuda", strategy=ThunderFSDPStrategy(jit=jit))
fabric.launch()

x = torch.randn(1, 1, device=fabric.device)
model = torch.nn.Linear(1, 2, bias=False, device=fabric.device)

tmodel = thunder.jit(model)
fmodel = fabric.setup(tmodel)
fmodel(x)

assert "all_gather" in thunder.last_traces(tmodel)[-1].python()


@RunIf(min_cuda_gpus=1, thunder=True)
def test_setup_already_traced():
import thunder

device = torch.device("cuda")
x = torch.randn(1, 1, device=device)
model = torch.nn.Linear(1, 2, bias=False, device=device)

strategy = ThunderFSDPStrategy()

tmodel = thunder.jit(model)
tmodel(x)
with pytest.raises(RuntimeError, match="already called"):
strategy.setup_module(tmodel)

0 comments on commit a67dd5c

Please sign in to comment.