Add support for more complex parallelism strategies to benchmark_litgpt.py #77

Merged: 15 commits merged on Mar 27, 2024.
Changes shown are from 12 commits.
80 changes: 73 additions & 7 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -5,6 +5,7 @@
import functools
from torch.utils.data import DataLoader, IterableDataset
import torch.distributed as torch_dist
from torch.distributed.device_mesh import init_device_mesh

import thunder
from thunder.tests.litgpt_model import Config, GPT, Block
@@ -47,6 +48,8 @@ def __init__(
shard_mode: str = "zero2",
bucketing_mode: str = "none",
sharding_size: int | None = None,
ddp_bucket_size: float = 256.0,
fsdp_bucket_params: float | None = None,
n_layers: int | None = None,
profiler_start: int = 15,
profiler_stop: int = 15,
@@ -74,7 +77,46 @@ def __init__(
self.shard_mode = shard_mode
self.bucketing_mode = bucketing_mode
self.sharding_size = sharding_size
self.ddp_bucket_size = ddp_bucket_size
self.fsdp_bucket_params = fsdp_bucket_params
self.micro_batch_size = micro_batch_size

# Clarify benchmark assumptions
if self.sharding_size is not None:
assert (
"thunder" not in self.compile
), "Hybrid Sharding (FSDP/DP) using --sharding_size is not yet supported for Thunder. Coming soon."

assert self.shard_mode in [
"hybrid_zero2",
"hybrid_zero3",
], "Sharding Size is only used with Hybrid FSDP/DP style parallelism. Please "

assert (
world_size % self.sharding_size == 0
), f"World size {world_size} is not divisible by Hybrid Sharding Size {self.sharding_size}"

if self.bucketing_mode != "none" and self.distributed_mode != "fsdp":
print(
f"[WARNING] --bucketing_mode {self.bucketing_mode} will be ignored as \
it is only used for FSDP style parallelism but running {self.distributed_mode}"
)

assert not (
"thunder" in self.compile and self.bucketing_mode == "size"
), "'size' bucketing mode is not supported for Thunder. Please use 'none' or 'block'."

if self.fsdp_bucket_params is not None:
if self.distributed_mode != "fsdp":
print(
f"[WARNING] Found --fsdp_bucket_params but Distributed mode is {self.distributed_mode}. Will be ignnored"
)

if self.bucketing_mode != "size":
print(
f"[WARNING] Bucketing mode is set to {self.bucketing_mode}. --fsdp_bucket_params will be ignoted."
)

if global_batch_size is not None:
self.global_batch_size = global_batch_size
else:
@@ -153,7 +195,7 @@ def setup_distributed(self):
model = ddp(
self.model,
broadcast_from=0,
bucket_size_in_mb=256.0,
bucket_size_in_mb=self.ddp_bucket_size,
)
elif self.distributed_mode == "fsdp":
from thunder.distributed import fsdp, FSDPType, FSDPBucketingStrategy
@@ -173,26 +215,44 @@
model = torch.nn.parallel.DistributedDataParallel(
self.model,
device_ids=[local_rank],
bucket_cap_mb=256.0,
bucket_cap_mb=self.ddp_bucket_size,
)
elif self.distributed_mode == "fsdp":
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy, size_based_auto_wrap_policy

mesh = None
if self.sharding_size is not None:
mesh = init_device_mesh("cuda", (int(world_size / self.sharding_size), self.sharding_size))

litgpt_auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
size_auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.fsdp_bucket_params
)
zero_bucket_wrap_policy = lambda module, recurse, nonwrapped_numel: nonwrapped_numel >= 0

custom_wrap_policy = {
"block": litgpt_auto_wrap_policy,
"size": size_auto_wrap_policy,
"none": zero_bucket_wrap_policy,
}[self.bucketing_mode]

sharding_strategy: ShardingStrategy = {
"zero2": ShardingStrategy.SHARD_GRAD_OP,
"zero3": ShardingStrategy.FULL_SHARD,
"hybrid_zero2": ShardingStrategy._HYBRID_SHARD_ZERO2,
"hybrid_zero3": ShardingStrategy.HYBRID_SHARD,
}[self.shard_mode]

# AssertionError: Dynamo only supports FSDP with use_orig_params=True
torch.cuda.set_device(local_rank)
model = FSDP(
self.model,
sharding_strategy=sharding_strategy,
auto_wrap_policy=litgpt_auto_wrap_policy,
auto_wrap_policy=custom_wrap_policy,
device_id=local_rank,
use_orig_params=True,
device_mesh=mesh,
)
return model
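
For context, here is a minimal standalone sketch (not part of this diff) of how the 2D device mesh built above pairs with the hybrid sharding strategies: mesh dimension 0 holds the replicate (DP) groups and dimension 1 the shard (FSDP) groups, matching the (world_size / sharding_size, sharding_size) shape used in setup_distributed. The world size, sharding size, and placeholder module below are hypothetical.

# Sketch only, assuming a 16-rank job with sharding_size=8 and an already
# initialized process group (e.g. launched via torchrun).
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

world_size, sharding_size = 16, 8
assert world_size % sharding_size == 0
# (replicate groups, shard group size) -> here a (2, 8) mesh.
mesh = init_device_mesh("cuda", (world_size // sharding_size, sharding_size))

my_module = torch.nn.Linear(1024, 1024).cuda()  # placeholder model
sharded = FSDP(
    my_module,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # corresponds to shard_mode="hybrid_zero3"
    device_mesh=mesh,
    use_orig_params=True,
)
# Parameters are sharded inside each group of 8 ranks; gradients are
# additionally all-reduced across the 2 replicate groups.
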

@@ -409,9 +469,15 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
)
print(f"Distributed Mode: {benchmark.distributed_mode}")
if benchmark.distributed_mode == "fsdp":
print(
f"Sharding Mode: {benchmark.shard_mode}\nSharding Size: {benchmark.sharding_size}\nBucketing: {benchmark.bucketing_mode}"
)
print(f"Sharding Mode: {benchmark.shard_mode}\nBucketing: {benchmark.bucketing_mode}")
if benchmark.sharding_size is not None:
print(
f"Sharding Size: {benchmark.sharding_size}\nReplicate DP Groups: {int(world_size/benchmark.sharding_size)}"
)
if benchmark.bucketing_mode == "size":
print(f"Bucketing Number Params: {benchmark.fsdp_bucket_params}")
elif benchmark.distributed_mode == "ddp":
print(f"DDP Bucketing Size: {benchmark.ddp_bucket_size} MB")
print(f"Compiler: {benchmark.compile}")
print(f"Average iter time: {benchmark.perf_metrics['average_iter_time']:.2f} ms")
print(f"Memory used: {benchmark.perf_metrics['memory_used_GB']:.02f} GB")