Skip to content

Commit

Permalink
Merge branch 'main' into add_celu
Browse files Browse the repository at this point in the history
  • Loading branch information
beverlylytle committed Oct 31, 2024
2 parents 381916d + efbfc8a commit 842b207
Showing 1 changed file with 30 additions and 3 deletions.
33 changes: 30 additions & 3 deletions thunder/benchmarks/benchmark_litgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def __init__(
global_batch_size: int | None = None,
model_name: str = "Llama-2-7b-hf",
shard_mode: str = "zero2",
bucketing_mode: str = "none",
bucketing_mode: str | None = None,
sharding_size: int | None = None,
ddp_bucket_size: float = 256.0,
fsdp_bucket_params: float | None = None,
Expand Down Expand Up @@ -303,7 +303,7 @@ def __init__(
world_size % self.sharding_size == 0
), f"World size {world_size} is not divisible by the sharding size {self.sharding_size}"

if self.bucketing_mode != "none" and self.distributed_mode not in FSDP_MODES:
if self.bucketing_mode is not None and self.distributed_mode not in FSDP_MODES:
warnings.warn(
f"--bucketing_mode {self.bucketing_mode} will be ignored as "
f" it is only used for FSDP style parallelism but running {self.distributed_mode}"
Expand Down Expand Up @@ -430,6 +430,7 @@ def setup_distributed(self, model):
from thunder.distributed import fsdp, FSDPType, FSDPBucketingStrategy

sharding_strategy = {"zero2": FSDPType.ZERO2, "zero3": FSDPType.ZERO3}[self.shard_mode]
self.bucketing_mode = self.bucketing_mode or "none"
bucketing_strategy = {
"none": FSDPBucketingStrategy.NONE,
"block": FSDPBucketingStrategy.BLOCK,
Expand Down Expand Up @@ -458,7 +459,7 @@ def setup_distributed(self, model):
from functools import partial
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

if self.bucketing_mode != "none":
if self.bucketing_mode is not None:
warnings.warn(f"fsdp2 ignores {self.bucketing_mode=}")

torch.cuda.set_device(local_rank)
Expand Down Expand Up @@ -506,6 +507,7 @@ def setup_distributed(self, model):
)
zero_bucket_wrap_policy = lambda module, recurse, nonwrapped_numel: nonwrapped_numel >= 0

self.bucketing_mode = self.bucketing_mode or "block"
custom_wrap_policy = {
"block": litgpt_auto_wrap_policy,
"size": size_auto_wrap_policy,
Expand Down Expand Up @@ -685,6 +687,23 @@ def train(self):
logits = self.model(input_ids)
else:
logits = self.model(input_ids)
# This information is accurate only in the case when torch.compile
# uses a single graph for the entire forward pass In the case of
# torch.compile using multiple graphs, the saved_tensors will be
# from the last graph For Thunder, the saved_tensors will be from
# its only graph There's no easy way to get the saved_tensors from
# the entire forward pass in the case of multiple graphs or PyTorch
# Eager mode. It's still useful for single-GPU and single-graph
# cases.
saved_tensors = getattr(logits.grad_fn, "saved_tensors", None)
saved_tensors_len = None
saved_tensors_size_in_mib = None
if saved_tensors:
saved_tensors_len = len([t for t in saved_tensors if t is not None])
saved_tensors_size_in_mib = (
sum(t.numel() * t.element_size() for t in saved_tensors if t is not None) / 1024**2
)
del saved_tensors
logits = logits.reshape(-1, logits.size(-1))
targets = targets.reshape(-1)
loss = (
Expand Down Expand Up @@ -725,6 +744,8 @@ def train(self):
if global_rank in [0, None]:
# print(f"Total time: {(t1 - t0):.2f}s")
self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iters)
self.perf_metrics["saved_for_backward_tensor_size_mib"] = saved_tensors_size_in_mib
self.perf_metrics["saved_for_backward_number_of_tensors"] = saved_tensors_len

def add_perf_metrics(self):
if self.throughput:
Expand Down Expand Up @@ -817,6 +838,12 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None

print(f"Average iter time: {benchmark.perf_metrics['average_iter_time']:.2f} ms")
print(f"Memory used: {benchmark.perf_metrics['memory_used_GB']:.02f} GB")
if benchmark.perf_metrics["saved_for_backward_tensor_size_mib"] is not None:
print(f"Saved for backward size: {benchmark.perf_metrics['saved_for_backward_tensor_size_mib']:.02f} MiB")
print(
f"Saved for backward number of tensors: {benchmark.perf_metrics['saved_for_backward_number_of_tensors']}"
)

if "tokens_per_sec" in benchmark.perf_metrics:
print(f"Tokens/s: {benchmark.perf_metrics.get['tokens_per_sec']:.02f}")
print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}")
Expand Down

0 comments on commit 842b207

Please sign in to comment.