@@ -54,12 +54,54 @@ def __init__(
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
         precision: Optional[Precision] = None,
+        jit: bool = True,
+        executors: Optional[Tuple[Union["Executor", str], ...]] = None,
         sharding_strategy: "_FSDP_TYPE" = "ZERO3",
         bucketing_strategy: "_BUCKETING_STRATEGY" = "NONE",
-        executors: Optional[Tuple[Union["Executor", str], ...]] = None,
         state_dict_type: Literal["full", "sharded"] = "sharded",
         **kwargs: Any,
     ):
+        r"""Strategy for Fully Sharded Data Parallel provided by Lightning Thunder.
+
+        .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
+
+        Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
+        size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
+        at parity with PyTorch DDP, whilst scaling our model sizes dramatically.
+
+        Arguments:
+            jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually
+                jitting a function that includes the model.
+
+            executors: The list of Thunder executors to enable. They can be either string aliases for the executors
+                or the actual executor instances.
+
+            sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination
+                of them:
+
+                - ``"ZERO3"``: Shards model parameters, gradients, and optimizer states (default).
+                - ``"ZERO2"``: Shards gradients and optimizer states only. Model parameters get replicated.
+
+                Also accepts a :class:`thunder.distributed.FSDPType` enum value.
+
+            bucketing_strategy: Enables combining the collective operations for sets of layers.
+
+                - ``"NONE"``: No bucketing (default).
+                - ``"LAYER"``: Create buckets per layer class.
+                - ``"BLOCK"``: Create buckets per layer block.
+
+                Also accepts a :class:`thunder.distributed.FSDPBucketingStrategy` enum value.
+
+            state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.
+
+                - ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file
+                  (default).
+                - ``"sharded"``: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is
+                  a folder with as many files as the world size.
+
+            \**kwargs: See available parameters in :func:`thunder.distributed.fsdp`.
+
+        """
         if not _TORCH_GREATER_EQUAL_2_2:
             raise ImportError("Thunder's FSDP strategy requires PyTorch 2.2 or higher.")
         if not _THUNDER_AVAILABLE:
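
For context on how the reordered constructor arguments and the new `jit` flag are meant to be used, here is a minimal usage sketch. It assumes the strategy class is exposed as `ThunderFSDPStrategy` and launched through Fabric; neither the class name nor the import paths appear in this diff, so treat them as assumptions.

```python
# Minimal usage sketch (class name and import paths are assumptions, not shown in this diff).
from lightning.fabric import Fabric                           # assumed entry point
from lightning.fabric.strategies import ThunderFSDPStrategy   # assumed location of this class

strategy = ThunderFSDPStrategy(
    jit=True,                    # let the strategy call `thunder.jit(model)` during setup
    executors=("torch",),        # string aliases or `Executor` instances
    sharding_strategy="ZERO3",   # shard parameters, gradients, and optimizer states
    bucketing_strategy="BLOCK",  # bucket collectives per layer block
    state_dict_type="sharded",   # one checkpoint file per rank
)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
```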
@@ -77,6 +119,9 @@ def __init__(
             if isinstance(bucketing_strategy, str)
             else bucketing_strategy
         )
+        if not jit and executors is not None:
+            raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`")
+        self.jit = jit
         self.executors = _validate_executors(executors)
         self._state_dict_type = state_dict_type
         self._fsdp_kwargs = kwargs
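
The new guard only matters when the two arguments are combined: `executors` is forwarded to the strategy's own `thunder.jit` call, so it is rejected when `jit=False` hands that call back to the user. A hedged sketch of the three configurations, reusing the assumed `ThunderFSDPStrategy` name from the earlier sketch:

```python
# Valid: the strategy owns the jit call and forwards the executors to it.
ThunderFSDPStrategy(jit=True, executors=("torch",))

# Valid: the user will call `thunder.jit(...)` themselves and pass executors there.
ThunderFSDPStrategy(jit=False)

# Invalid: the executors would have no effect, so the constructor raises ValueError.
# ThunderFSDPStrategy(jit=False, executors=("torch",))
```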
@@ -115,16 +160,37 @@ def setup_environment(self) -> None:
     def setup_module(self, module: Module) -> Module:
         import thunder

-        module = thunder.distributed.fsdp(
-            module,
-            device=self.root_device,
-            sharding_strategy=self.sharding_strategy,
-            bucketing_strategy=self.bucketing_strategy,
-            **self._fsdp_kwargs,
-        )
-
-        # NOTE @IvanYaschuck says that `fsdp(jit(model))` could be supported in the future so that the user owns the `jit` call.
-        # we would still `jit(fsdp(undo_jit(jit(model))))` internally
+        if (cd := thunder.compile_data(module)) is not None:
+            # the module was already jitted
+            if thunder.compile_stats(module).last_traces is not None:
+                raise RuntimeError(
+                    "You already called `thunder.jit()` and generated an execution trace. It's too late to apply the"
+                    " FSDP transform. Remove the `forward` call before `fabric.setup()`"
+                )
+            assert cd.is_module  # sanity check
+            fsdp_module = thunder.distributed.fsdp(
+                cd.fn,
+                device=self.root_device,
+                sharding_strategy=self.sharding_strategy,
+                bucketing_strategy=self.bucketing_strategy,
+                **self._fsdp_kwargs,
+            )
+            # update the compile data state
+            cd.fn = fsdp_module
+            assert hasattr(cd, "_processed_function")  # sanity check
+            cd._processed_function = fsdp_module
+            cd.process_group_for_ddp = fsdp_module.process_group_for_ddp
+            return module
+        else:
+            module = thunder.distributed.fsdp(
+                module,
+                device=self.root_device,
+                sharding_strategy=self.sharding_strategy,
+                bucketing_strategy=self.bucketing_strategy,
+                **self._fsdp_kwargs,
+            )
+            if not self.jit:
+                return module
         return thunder.jit(module, executors=self.executors)

     @override
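
The new branch in `setup_module` distinguishes a plain module from one that was already passed through `thunder.jit`, applying the FSDP transform to the jitted module's underlying function in the latter case. A hedged sketch of the two orderings this enables; `MyModel` and the Fabric wiring are illustrative only:

```python
import thunder

# Ordering 1: let the strategy jit (default `jit=True`).
model = fabric.setup(MyModel())     # fsdp() then thunder.jit() inside setup_module

# Ordering 2: jit first, then set up. The strategy detects the compile data and
# applies fsdp() to its wrapped function instead of the outer module.
model = thunder.jit(MyModel())      # do not call model(x) yet...
model = fabric.setup(model)         # ...a generated trace would make setup raise RuntimeError
```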