
Commit 6a019b7

carmocca authored and rasbt committed
Smarter thunder.jit decisions (#1204)
1 parent 5638c8d · commit 6a019b7

4 files changed (+193 -15 lines)

extensions/thunder/strategies/thunder_ddp.py

Lines changed: 37 additions & 2 deletions
@@ -45,17 +45,35 @@ def __init__(
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
         precision: Optional[Precision] = None,
+        jit: bool = True,
         executors: Optional[Tuple[Union["Executor", str], ...]] = None,
         process_group_backend: Optional[str] = None,
         timeout: Optional[timedelta] = default_pg_timeout,
         **kwargs: Any,
     ):
+        r"""Strategy for Replicated Data Parallel provided by Lightning Thunder.
+
+        .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
+
+        Arguments:
+            jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually
+                jitting a function that includes the model.
+
+            executors: The list of Thunder executors to enable. They can be either string aliases for the executors
+                or the actual executor instances.
+
+            \**kwargs: See available parameters in :func:`thunder.distributed.ddp`.
+
+        """
         if not _THUNDER_AVAILABLE:
             raise ModuleNotFoundError(str(_THUNDER_AVAILABLE))
         super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision)
         self.parallel_devices = parallel_devices
         self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment
 
+        if not jit and executors is not None:
+            raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`")
+        self.jit = jit
         self.executors = _validate_executors(executors)
         self._num_nodes = 1
         self._process_group_backend: Optional[str] = process_group_backend
@@ -111,8 +129,25 @@ def setup_environment(self) -> None:
     def setup_module(self, module: Module) -> Module:
         import thunder
 
-        module = thunder.distributed.ddp(module, **self._ddp_kwargs)
-
+        if (cd := thunder.compile_data(module)) is not None:
+            # the module was already jitted
+            if thunder.compile_stats(module).last_traces is not None:
+                raise RuntimeError(
+                    "You already called `thunder.jit()` and generated an execution trace. It's too late to apply the"
+                    " DDP transform. Remove the `forward` call before `fabric.setup()`"
+                )
+            assert cd.is_module  # sanity check
+            ddp_module = thunder.distributed.ddp(cd.fn, **self._ddp_kwargs)
+            # update the compile data state
+            cd.fn = ddp_module
+            assert hasattr(cd, "_processed_function")  # sanity check
+            cd._processed_function = ddp_module
+            cd.process_group_for_ddp = ddp_module.process_group_for_ddp
+            return module
+        else:
+            module = thunder.distributed.ddp(module, **self._ddp_kwargs)
+        if not self.jit:
+            return module
         return thunder.jit(module, executors=self.executors)
 
     @override
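
The net effect for users: `thunder.jit` can now be called either by the strategy or ahead of time by the user. The sketch below is distilled from the new `test_jit_before_setup` and `test_setup_already_traced` tests rather than from any documented API; the toy `torch.nn.Linear` model, the tensor shapes, and the `from lightning.fabric import Fabric` import are illustrative assumptions, and it requires two CUDA GPUs with Thunder installed.

import thunder
import torch
from lightning.fabric import Fabric

from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy

fabric = Fabric(devices=2, accelerator="cuda", strategy=ThunderDDPStrategy())
fabric.launch()

model = torch.nn.Linear(1, 2, bias=False, device=fabric.device)

# The user owns the `thunder.jit` call. `setup_module` now detects the jitted
# module through `thunder.compile_data` and applies `thunder.distributed.ddp`
# to the wrapped function instead of jitting it a second time.
tmodel = thunder.jit(model)
fmodel = fabric.setup(tmodel)

# Calling the jitted model *before* `fabric.setup` would have generated an
# execution trace, and `setup_module` would raise the RuntimeError shown above.
fmodel(torch.randn(1, 1, device=fabric.device))

With the default `jit=True` and no prior `thunder.jit` call, `fabric.setup(model)` keeps the old behaviour: it applies the DDP transform and jits the module itself. Passing `jit=False` returns the transformed module un-jitted, which is why combining it with `executors` now raises a ValueError.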

extensions/thunder/strategies/thunder_fsdp.py

Lines changed: 77 additions & 11 deletions
@@ -54,12 +54,54 @@ def __init__(
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
         precision: Optional[Precision] = None,
+        jit: bool = True,
+        executors: Optional[Tuple[Union["Executor", str], ...]] = None,
         sharding_strategy: "_FSDP_TYPE" = "ZERO3",
         bucketing_strategy: "_BUCKETING_STRATEGY" = "NONE",
-        executors: Optional[Tuple[Union["Executor", str], ...]] = None,
         state_dict_type: Literal["full", "sharded"] = "sharded",
         **kwargs: Any,
     ):
+        r"""Strategy for Fully Sharded Data Parallel provided by Lightning Thunder.
+
+        .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
+
+        Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
+        size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
+        at parity with PyTorch DDP, whilst scaling our model sizes dramatically.
+
+        Arguments:
+            jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually
+                jitting a function that includes the model.
+
+            executors: The list of Thunder executors to enable. They can be either string aliases for the executors
+                or the actual executor instances.
+
+            sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination
+                of them:
+
+                - ``"ZERO3"``: Shards model parameters, gradients, and optimizer states (default).
+                - ``"ZERO2"``: Shards gradients and optimizer states only. Model parameters get replicated.
+
+                Also accepts a :class:`thunder.distributed.FSDPType` enum value.
+
+            bucketing_strategy: Enables combining the collective operations for sets of layers.
+
+                - ``"NONE"``: No bucketing (default).
+                - ``"LAYER"``: Create buckets per layer class.
+                - ``"BLOCK"``: Create buckets per layer block.
+
+                Also accepts a :class:`thunder.distributed.FSDPBucketingStrategy` enum value.
+
+            state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.
+
+                - ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file
+                  (default).
+                - ``"sharded"``: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is
+                  a folder with as many files as the world size.
+
+            \**kwargs: See available parameters in :func:`thunder.distributed.fsdp`.
+
+        """
         if not _TORCH_GREATER_EQUAL_2_2:
             raise ImportError("Thunder's FSDP strategy requires PyTorch 2.2 or higher.")
         if not _THUNDER_AVAILABLE:
@@ -77,6 +119,9 @@ def __init__(
             if isinstance(bucketing_strategy, str)
            else bucketing_strategy
         )
+        if not jit and executors is not None:
+            raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`")
+        self.jit = jit
         self.executors = _validate_executors(executors)
         self._state_dict_type = state_dict_type
         self._fsdp_kwargs = kwargs
@@ -115,16 +160,37 @@ def setup_environment(self) -> None:
     def setup_module(self, module: Module) -> Module:
         import thunder
 
-        module = thunder.distributed.fsdp(
-            module,
-            device=self.root_device,
-            sharding_strategy=self.sharding_strategy,
-            bucketing_strategy=self.bucketing_strategy,
-            **self._fsdp_kwargs,
-        )
-
-        # NOTE @IvanYaschuck says that `fsdp(jit(model))` could be supported in the future so that the user owns the `jit` call.
-        # we would still `jit(fsdp(undo_jit(jit(model))))` internally
+        if (cd := thunder.compile_data(module)) is not None:
+            # the module was already jitted
+            if thunder.compile_stats(module).last_traces is not None:
+                raise RuntimeError(
+                    "You already called `thunder.jit()` and generated an execution trace. It's too late to apply the"
+                    " FSDP transform. Remove the `forward` call before `fabric.setup()`"
+                )
+            assert cd.is_module  # sanity check
+            fsdp_module = thunder.distributed.fsdp(
+                cd.fn,
+                device=self.root_device,
+                sharding_strategy=self.sharding_strategy,
+                bucketing_strategy=self.bucketing_strategy,
+                **self._fsdp_kwargs,
+            )
+            # update the compile data state
+            cd.fn = fsdp_module
+            assert hasattr(cd, "_processed_function")  # sanity check
+            cd._processed_function = fsdp_module
+            cd.process_group_for_ddp = fsdp_module.process_group_for_ddp
+            return module
+        else:
+            module = thunder.distributed.fsdp(
+                module,
+                device=self.root_device,
+                sharding_strategy=self.sharding_strategy,
+                bucketing_strategy=self.bucketing_strategy,
+                **self._fsdp_kwargs,
+            )
+        if not self.jit:
+            return module
         return thunder.jit(module, executors=self.executors)
 
     @override
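
To show how the reordered signature and the newly documented options fit together, here is a configuration sketch. It is not taken from the repository; the argument values are just one plausible combination (the `"python"` executor alias is the one exercised by the tests), and it assumes two CUDA GPUs, PyTorch 2.2+, Thunder installed, and the usual `from lightning.fabric import Fabric` import.

import torch
from lightning.fabric import Fabric

from extensions.thunder.strategies.thunder_fsdp import ThunderFSDPStrategy

strategy = ThunderFSDPStrategy(
    jit=True,                   # let the strategy call thunder.jit after the FSDP transform
    executors=("python",),      # string alias, resolved by _validate_executors
    sharding_strategy="ZERO3",  # shard parameters, gradients, and optimizer states
    bucketing_strategy="BLOCK", # combine collectives per layer block
    state_dict_type="sharded",  # each rank saves its own shard at checkpoint time
)

fabric = Fabric(devices=2, accelerator="cuda", strategy=strategy)
fabric.launch()

model = torch.nn.Linear(1, 2, bias=False, device=fabric.device)  # toy model for illustration
fmodel = fabric.setup(model)  # applies thunder.distributed.fsdp, then thunder.jit
fmodel(torch.randn(1, 1, device=fabric.device))

As with the DDP strategy, a model that was already passed through `thunder.jit` can be handed to `fabric.setup` directly, as long as it has not been called yet.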

tests/test_thunder_ddp.py

Lines changed: 42 additions & 2 deletions
@@ -10,13 +10,19 @@
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
+from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy
+
+
+@RunIf(thunder=True)
+def test_thunder_strategy_input_parsing():
+    with pytest.raises(ValueError, match="doesn't have an effect with `jit=False"):
+        ThunderDDPStrategy(jit=False, executors=("python",))
+
 
 @RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
 @pytest.mark.parametrize("strategy", ["ddp", "thunder_ddp"])
 def test_no_backward_sync(strategy):
     if strategy == "thunder_ddp":
-        from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy
-
         strategy = ThunderDDPStrategy()
 
     fabric = Fabric(devices=2, accelerator="cuda", strategy=strategy)
@@ -47,3 +53,37 @@ def test_no_backward_sync(strategy):
         # rank0 rank1 allreduce1 rank0 rank1 allreduce2
         assert model.weight.grad.item() == (9.0 if i == 3 else 22.5)
         model.weight.grad = None
+
+
+@RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
+@pytest.mark.parametrize("jit", (False, True))
+def test_jit_before_setup(jit):
+    import thunder
+
+    fabric = Fabric(devices=2, accelerator="cuda", strategy=ThunderDDPStrategy(jit=jit))
+    fabric.launch()
+
+    x = torch.randn(1, 1, device=fabric.device)
+    model = torch.nn.Linear(1, 2, bias=False, device=fabric.device)
+
+    tmodel = thunder.jit(model)
+    fmodel = fabric.setup(tmodel)
+    fmodel(x)
+
+    assert "all_reduce" in thunder.last_backward_traces(tmodel)[-1].python()
+
+
+@RunIf(min_cuda_gpus=1, thunder=True)
+def test_setup_already_traced():
+    import thunder
+
+    device = torch.device("cuda")
+    x = torch.randn(1, 1, device=device)
+    model = torch.nn.Linear(1, 2, bias=False, device=device)
+
+    strategy = ThunderDDPStrategy()
+
+    tmodel = thunder.jit(model)
+    tmodel(x)
+    with pytest.raises(RuntimeError, match="already called"):
+        strategy.setup_module(tmodel)

tests/test_thunder_fsdp.py

Lines changed: 37 additions & 0 deletions
@@ -28,6 +28,9 @@ def test_thunder_strategy_input_parsing():
     assert strategy.executors == (pythonex,)
     assert strategy.sharding_strategy is FSDPType.ZERO3
 
+    with pytest.raises(ValueError, match="doesn't have an effect with `jit=False"):
+        ThunderFSDPStrategy(jit=False, executors=("python",))
+
 
 @RunIf(thunder=True)
 def test_validate_executors():
@@ -309,3 +312,37 @@ def test_save_load_sharded_checkpoint(tmp_path):
     actual["buf"] = actual["buf"].to(device="cpu")
     torch.testing.assert_close(actual, expected)
     assert state["primitive"] == 123
+
+
+@RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
+@pytest.mark.parametrize("jit", (False, True))
+def test_jit_before_setup(jit):
+    import thunder
+
+    fabric = Fabric(devices=2, accelerator="cuda", strategy=ThunderFSDPStrategy(jit=jit))
+    fabric.launch()
+
+    x = torch.randn(1, 1, device=fabric.device)
+    model = torch.nn.Linear(1, 2, bias=False, device=fabric.device)
+
+    tmodel = thunder.jit(model)
+    fmodel = fabric.setup(tmodel)
+    fmodel(x)
+
+    assert "all_gather" in thunder.last_traces(tmodel)[-1].python()
+
+
+@RunIf(min_cuda_gpus=1, thunder=True)
+def test_setup_already_traced():
+    import thunder
+
+    device = torch.device("cuda")
+    x = torch.randn(1, 1, device=device)
+    model = torch.nn.Linear(1, 2, bias=False, device=device)
+
+    strategy = ThunderFSDPStrategy()
+
+    tmodel = thunder.jit(model)
+    tmodel(x)
+    with pytest.raises(RuntimeError, match="already called"):
+        strategy.setup_module(tmodel)
