Smarter thunder.jit decisions #1204

Merged · 5 commits · Mar 27, 2024
Changes from 1 commit
83 changes: 72 additions & 11 deletions extensions/thunder/strategies/thunder_fsdp.py
@@ -54,12 +54,54 @@ def __init__(
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision: Optional[Precision] = None,
jit: bool = True,
executors: Optional[Tuple[Union["Executor", str], ...]] = None,
sharding_strategy: "_FSDP_TYPE" = "ZERO3",
bucketing_strategy: "_BUCKETING_STRATEGY" = "NONE",
executors: Optional[Tuple[Union["Executor", str], ...]] = None,
state_dict_type: Literal["full", "sharded"] = "sharded",
**kwargs: Any,
):
r"""Strategy for Fully Sharded Data Parallel provided by Lightning Thunder.

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
at parity with PyTorch DDP, whilst scaling our model sizes dramatically.

Arguments:
jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually
jitting a function that includes the model.

executors: The list of Thunder executors to enable. They can be either string aliases for the executors
or the actual executor instances.

sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination
of them:

- ``"ZERO3"``: Shards model parameters, gradients, and optimizer states (default).
- ``"ZERO2"``: Shards gradients and optimizer states only. Model parameters get replicated.

Also accepts a :class:`thunder.distributed.FSDPType` enum value.

bucketing_strategy: Enables combining the collective operations for sets of layers.

- ``"NONE"``: No bucketing (default).
- ``"LAYER"``: Create buckets per layer class.
- ``"BLOCK"``: Create buckets per layer block.

Also accepts a :class:`thunder.distributed.FSDPBucketingStrategy` enum value.

state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.

- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file
(default).
- ``"sharded"``: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is
a folder with as many files as the world size.

\**kwargs: See available parameters in :func:`thunder.distributed.fsdp`.

"""
if not _TORCH_GREATER_EQUAL_2_2:
raise ImportError("Thunder's FSDP strategy requires PyTorch 2.2 or higher.")
if not _THUNDER_AVAILABLE:
@@ -77,6 +119,9 @@
if isinstance(bucketing_strategy, str)
else bucketing_strategy
)
if not jit and executors is not None:
raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`")
self.jit = jit
self.executors = _validate_executors(executors)
self._state_dict_type = state_dict_type
self._fsdp_kwargs = kwargs
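
For context, a minimal usage sketch (not part of this diff) of the options documented above. It assumes the class is exported as ThunderFSDPStrategy from this module, that lightning-thunder and PyTorch 2.2+ are installed, and that "torch" is a registered executor alias; the import path may differ in your checkout.

from extensions.thunder.strategies.thunder_fsdp import ThunderFSDPStrategy

# Default: the strategy calls `thunder.jit(model)` itself in `setup_module`,
# so `executors` is forwarded to that call.
strategy = ThunderFSDPStrategy(sharding_strategy="ZERO3", state_dict_type="sharded")

# Manual jitting: the user will call `thunder.jit` on a function that includes
# the model, so the automatic call is disabled.
strategy = ThunderFSDPStrategy(jit=False)

# Invalid: `executors` is only consumed by the strategy's own `thunder.jit`
# call, so combining it with `jit=False` raises the new ValueError.
try:
    ThunderFSDPStrategy(jit=False, executors=("torch",))
except ValueError as exc:
    print(exc)
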
@@ -115,16 +160,32 @@ def setup_environment(self) -> None:
def setup_module(self, module: Module) -> Module:
import thunder

module = thunder.distributed.fsdp(
module,
device=self.root_device,
sharding_strategy=self.sharding_strategy,
bucketing_strategy=self.bucketing_strategy,
**self._fsdp_kwargs,
)

# NOTE @IvanYashchuk says that `fsdp(jit(model))` could be supported in the future so that the user owns the `jit` call.
# we would still `jit(fsdp(undo_jit(jit(model))))` internally
if (cd := thunder.compile_data(module)) is not None:
# the module was already jitted
if thunder.compile_stats(module).last_traces is not None:
raise RuntimeError(
"You already called `thunder.jit()` and generated an execution trace. It's too late to apply the"
" FSDP transform."
)
# modify the reference
cd.fn = thunder.distributed.fsdp(
cd.fn,
device=self.root_device,
sharding_strategy=self.sharding_strategy,
bucketing_strategy=self.bucketing_strategy,
**self._fsdp_kwargs,
)
return module
else:
module = thunder.distributed.fsdp(
module,
device=self.root_device,
sharding_strategy=self.sharding_strategy,
bucketing_strategy=self.bucketing_strategy,
**self._fsdp_kwargs,
)
if not self.jit:
return module
return thunder.jit(module, executors=self.executors)

@override
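
Finally, a hedged sketch (continuing the assumptions of the sketch above, not part of this diff) of the two branches that `setup_module` now distinguishes. It is shown outside the usual Fabric launch for brevity; in practice the strategy is launched so that `root_device` and the process group exist, and the user-owned `thunder.jit` call must happen before any forward pass, otherwise the new RuntimeError fires.

import thunder
import torch.nn as nn

# Branch 1 (default, jit=True): a plain nn.Module comes in, so the strategy
# applies `thunder.distributed.fsdp(model)` and then calls `thunder.jit` on it.
strategy = ThunderFSDPStrategy()
sharded = strategy.setup_module(nn.Linear(8, 8))   # -> jit(fsdp(model))

# Branch 2 (user owns the jit call): `thunder.compile_data(module)` is not None,
# so the strategy patches the wrapped function in place (`cd.fn = fsdp(cd.fn)`)
# and does not jit again; pass executors to your own `thunder.jit` call instead.
strategy = ThunderFSDPStrategy(jit=False)
jitted = thunder.jit(nn.Linear(8, 8))              # no forward pass yet
sharded = strategy.setup_module(jitted)            # fsdp applied to the jitted fn
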