DDP as a transform (#873)
davidgonmar authored Aug 4, 2024
1 parent 510d8bf commit 0f34c06
Showing 3 changed files with 313 additions and 26 deletions.
22 changes: 22 additions & 0 deletions thunder/distributed/__init__.py
@@ -177,6 +177,13 @@ def prep_shard(
)


# When the user calls ddp(jitted_module), this function does the following:
# - Marks the original function with the appropriate attributes (use_ddp, ...)
# - Broadcasts parameters if necessary
# - Registers a transform (a callback that runs before the prologue is executed) that rewrites the
#   prologue and compute traces to insert syncs (grad syncs for the backward are handled by thunder automatically).
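#
# A minimal usage sketch (the module, rank and process-group setup are assumed to already exist):
#
#   torch.distributed.init_process_group(...)        # standard torch.distributed setup
#   jitted = thunder.jit(MyModel().to(rank))
#   ddp_model = thunder.distributed.ddp(jitted)      # registers DDPTransform on the jitted module
#
# The pre-existing order, thunder.jit(thunder.distributed.ddp(model)), is still handled by the
# non-transform code path below.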


# TODO Verify parameters are not partially initialized
# TODO Handle buffers
# TODO Improve initial broadcast logic
@@ -286,6 +293,21 @@ def main():
tdist.is_available(),
lambda: "ddp requires torch distributed to be available (but it's not)",
)
from thunder.core.module import ThunderModule

if isinstance(model, ThunderModule):
from thunder.distributed.transforms.ddp_v2 import DDPTransform
from thunder.core.transforms import add_transform

process_group = copy_default_process_group()
utils.check(process_group is not None, lambda: "The default process group is None")
# This will insert syncs for the parameters (gradient syncs in the backward pass are handled by thunder).
# Usually, other transforms will later remove the forward syncs inserted by this transform.
transform_from_trace_to_ddp_trace = DDPTransform(
process_group=process_group, bucket_size_in_mb=bucket_size_in_mb, broadcast_from=broadcast_from
)
model_new = add_transform(model, transform=transform_from_trace_to_ddp_trace)
return model_new

pg = copy_default_process_group()
utils.check(pg is not None, lambda: "The default process group is None")
182 changes: 182 additions & 0 deletions thunder/distributed/transforms/ddp_v2.py
@@ -0,0 +1,182 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING

from thunder.core import devices
from thunder.core import prims
from thunder.core import utils
from thunder.core.proxies import DistParallelType
from thunder.core.trace import from_trace
from thunder.core.trace import tracectx
from thunder.core.trace import TraceProvenance
from thunder.core.transform_common import Transform
from thunder.core.module import ThunderModule
from thunder.distributed import copy_default_process_group
import torch
from torch.utils.weak import WeakTensorKeyDictionary
import torch.distributed as tdist
import copy

if TYPE_CHECKING:
from torch.distributed import ProcessGroup
from thunder.core.trace import TraceCtx


@dataclass
class DDPTransform(Transform):
process_group: ProcessGroup
bucket_size_in_mb: float
broadcast_from: int | None

replicated_params: dict[str, torch.nn.Parameter] | None = None
shared_params_name: dict[str, str] | None = None

def transform_module(self, model: ThunderModule):
"""Transforms the ThunderModule. This is executed once on application of the transform"""
from thunder import compile_data as get_compile_data
from thunder.core.module import ThunderModule

process_group = self.process_group
cd = get_compile_data(model)
cd.use_ddp = True
cd.process_group_for_ddp = process_group
orig_module: torch.nn.Module = cd.fn
utils.check(
isinstance(orig_module, torch.nn.Module) and not isinstance(orig_module, ThunderModule),
lambda: f"CompileData.fn expected to be `nn.Module` but got {type(orig_module)}",
)
orig_module.use_ddp = True
orig_module.process_group_for_ddp = process_group
orig_module.bucket_size_in_mb = self.bucket_size_in_mb

replicated_params = {}
# We use the `shared_params` dictionary to track the shared parameters.
# The key in this dictionary is the original parameter from the user's Module.
# The values are the copied parameter for the thunder module and metadata related to its replication.
shared_params = WeakTensorKeyDictionary()

# NOTE: Shared Parameters in Trace
# Shared parameters in PyTorch eager are parameters of a module which have different names but share the underlying tensor.
# For a shared parameter, we replace all occurrences of the shared parameter with its corresponding `base` parameter.
# In our implementation the `base` parameter is the parameter (and corresponding name) that we see first while
# iterating over the parameters (see below). We track subsequent parameters that share the underlying tensor with this `base` parameter
# in the `shared_params_name` dictionary.
# Then, while transforming the trace - see `DDPTransform.transform_traces_pre_prologue` - we replace every proxy of a shared parameter
# with the proxy of its corresponding base parameter in the computation trace.

# This is used to track the shared parameters when the transform is applied.
# key - parameter name, value - `base` parameter name.
shared_params_name: dict[str, str] = {}
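# Hypothetical example: if the user's module ties weights via ``model.fc2.weight = model.fc1.weight``
# (as in test_ddp_weight_sharing), then after the loop below shared_params_name == {"fc2.weight": "fc1.weight"}
# and both names resolve to the same copied tensor in model._overrides_parameters.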
for module_name, _ in model._model.named_modules():
submodule = model.get_submodule(module_name)
# Since we are doing no sharding, we do not need to materialize the params

# Broadcast parameters if requested
if (broadcast_from := self.broadcast_from) is not None:
for pn, _ in submodule.named_parameters(recurse=False, prefix=module_name):
tdist.broadcast(model.get_parameter(pn), src=broadcast_from, group=process_group, async_op=False)
for pn, _ in submodule.named_buffers(recurse=False, prefix=module_name):
tdist.broadcast(model.get_buffer(pn), src=broadcast_from, group=process_group, async_op=False)

for pn, p in submodule.named_parameters(recurse=False, prefix=module_name):
# If there are shared params in the original user Module, we reuse the copy created from the original parameter below.
# This way we re-create parameter sharing in thunder's copy of the Module.
if p in shared_params:
# Shared param names : current param - base param
shared_params_name[pn] = shared_params[p]["param_name"]
# Re-use the previous copy of this parameter.
model._overrides_parameters[pn] = shared_params[p]["param_copy"]
replicated_params[pn] = shared_params[p]["param_meta"]
continue

model._overrides_parameters[pn] = copy.copy(p)
# we collect shapes and devices because we do not know if other transforms also change them...
shape = model._overrides_parameters[pn].shape
replicated_params[pn] = (shape, model._overrides_parameters[pn].device)
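# e.g. replicated_params["fc1.weight"] == (torch.Size([16, 16]), torch.device("cuda:0"))
# (hypothetical name and values)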

# Track param information
shared_params[p] = {
"param_copy": model._overrides_parameters[pn],
"param_meta": replicated_params[pn],
"param_name": pn,
}
self.shared_params_name = shared_params_name
self.replicated_params = replicated_params

def transform_traces_pre_prologue(
self, prologue_trace: TraceCtx, computation_trace: TraceCtx, epilogue_trace: TraceCtx, **kwargs
):
assert (
self.replicated_params is not None and self.shared_params_name is not None
), "expected transform_module to have run"

from thunder.distributed import prims as dist_prims

prologue_producers, prologue_consumers = utils.producers_and_consumers(prologue_trace)

modules_and_thunder_modules = [
(bsym.args[0], bsym.output)
for bsym in prologue_trace.bound_symbols
if bsym.sym is prims.unpack_thunder_module
]

if len(modules_and_thunder_modules) != 1:
raise NotImplementedError("cannot deal with modules other than the compiled module")

((orig_module_proxy, thunder_module_proxy),) = modules_and_thunder_modules
if prologue_producers[orig_module_proxy].sym is not prims.unpack_function_obj:
raise NotImplementedError("original module does not match the compiled module")

computation_trace.push_scope([])

synchronized_parameters = []
param_name_to_comp_trc_proxy = {}  # Track param_name to its corresponding proxy in computation_trc.
for pro_out_p, comp_inp_p in zip(prologue_trace.output[0], computation_trace.args):
bsym = prologue_producers[pro_out_p]
# if the bsym is an unpack_parameter prim, we need to mark it as REPLICATED (ddp)
# and insert a sync (then, backward pass will be handled automatically)
if bsym.sym == prims.unpack_parameter:
param_thunder_module, param_name = bsym.args
assert param_thunder_module is thunder_module_proxy
if param_name in self.replicated_params:
param_name_to_comp_trc_proxy[param_name] = comp_inp_p
shape, torch_device = self.replicated_params[param_name]
thunder_device = devices.to_device(torch_device)
pro_out_p._distparallel_type = DistParallelType.REPLICATED
pro_out_p._device = thunder_device
if comp_inp_p is not pro_out_p:
comp_inp_p._distparallel_type = DistParallelType.REPLICATED
comp_inp_p._device = thunder_device
with tracectx(computation_trace):
# we produce a new trace with syncs inserted before the weights are used
# the backward syncs are then handled automatically by thunder (all_reduce of the gradients)
# later transforms typically remove the syncs from the forward pass again, since they are not needed there
synchronized_parameters.append(dist_prims.synchronize(comp_inp_p, self.process_group))
new_scope = computation_trace.pop_scope()
# map of param -> synced param
proxies_to_replace = {id(bsym.args[0]): bsym.output for bsym in new_scope}
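# e.g. {id(p_fc1_weight): t_fc1_weight_synced, ...} where the keys are the original parameter
# proxies passed to dist_prims.synchronize above and the values are its outputs (hypothetical proxy names)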

# See NOTE: Shared Parameters in Trace
for param_name, base_param in self.shared_params_name.items():
param_proxy = param_name_to_comp_trc_proxy[param_name]
base_param_proxy = param_name_to_comp_trc_proxy[base_param]
synced_base_param_proxy = proxies_to_replace[id(base_param_proxy)]
# Update `proxies_to_replace` so we replace all usage of `param_proxy`
# with the output of the synced param on `base_param_proxy`.
proxies_to_replace[id(param_proxy)] = synced_base_param_proxy

new_computation_trace = from_trace(computation_trace)
for idx, bsym in enumerate(computation_trace.bound_symbols):
if bsym.sym != prims.unpack_trivial:
break
new_computation_trace.bound_symbols.append(bsym.from_bsym(args=bsym.args))

new_computation_trace.bound_symbols += new_scope
for bsym in computation_trace.bound_symbols[idx:]:
# replace param by synced_param
new_args = tuple(proxies_to_replace.get(id(a), a) for a in bsym.args)
new_computation_trace.bound_symbols.append(bsym.from_bsym(args=new_args))

new_computation_trace.set_provenance(TraceProvenance("ddp pass"))

return prologue_trace, new_computation_trace, epilogue_trace
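
For reference, a minimal sketch of applying the transform directly, which is what ddp() now does when it receives a ThunderModule. DDPTransform, add_transform and copy_default_process_group are the names used in this diff; the toy module, tensor shapes and process-group setup (e.g. launching with torchrun) are illustrative assumptions.

import torch
import torch.distributed as tdist

import thunder
from thunder.core.transforms import add_transform
from thunder.distributed import copy_default_process_group
from thunder.distributed.transforms.ddp_v2 import DDPTransform


def main():
    # Assumes the usual launcher environment variables (RANK, WORLD_SIZE, ...), e.g. via torchrun.
    tdist.init_process_group(backend="nccl")
    rank = tdist.get_rank()
    torch.cuda.set_device(rank)

    model = torch.nn.Linear(16, 16).to(rank)
    jitted = thunder.jit(model)

    # Same arguments that ddp() forwards to the transform.
    transform = DDPTransform(
        process_group=copy_default_process_group(),
        bucket_size_in_mb=25.0,
        broadcast_from=0,  # replicate rank 0's initial weights
    )
    ddp_model = add_transform(jitted, transform=transform)

    # Forward/backward as usual; gradient all_reduces are inserted by thunder in the backward trace.
    out = ddp_model(torch.randn(4, 16, device=f"cuda:{rank}"))
    out.sum().backward()


if __name__ == "__main__":
    main()
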
135 changes: 109 additions & 26 deletions thunder/tests/distributed/test_ddp.py
@@ -57,26 +57,32 @@
class DDPTest(DistributedParallelTestCase):
# Reference issue "Add an example of DDP(compile(model)) to tests"
def test_ddp_compile_module(self):
model = ToyModel().to(self.rank)
ddp_model = DDP(thunder.jit(model, device_ids=[self.rank]))

loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)

# Asserts that DDPing a jitted model yields the same results as raw torch DDP.
initial_model_state = ToyModel().state_dict()
ddp_fns = [
lambda model: DDP(thunder.jit(model)),
lambda model: ddp(thunder.jit(model)),
]
x, labels = torch.randn(20, 12).to(self.rank), torch.randn(20, 8).to(self.rank)

init_loss, last_loss = None, None
for i in range(3):
optimizer.zero_grad()
outputs = ddp_model(x)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
if i == 0:
init_loss = loss.detach().item()
if i == 2:
last_loss = loss.detach().item()
assert init_loss > last_loss
def _get_last_loss(fn):
model = ToyModel().to(self.rank)
model.load_state_dict(initial_model_state)
ddp_model = fn(model)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)
for i in range(3):
optimizer.zero_grad()
outputs = ddp_model(x)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
return loss

raw_ddp_loss = _get_last_loss(lambda model: DDP(model))
for fn in ddp_fns:
loss = _get_last_loss(fn)
self.assertEqual(loss, raw_ddp_loss)

# Reference issue "[tracker] Support DistributedDataParallel"
def test_compile_ddp_module(self):
@@ -199,6 +205,78 @@ def fwd_loss(m, x):
with thunder.ThunderModule.no_sync(model):
fwd_loss(model, x)

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 devices")
def test_ddp_weight_sharing(self):
# This test is to verify that weight sharing works with ddp.
device = torch.device("cuda", self.rank)

class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = torch.nn.Linear(16, 16, bias=False)
self.fc2 = torch.nn.Linear(16, 16, bias=False)

def forward(self, x):
return self.fc1(x) + self.fc2(x)

def _test_model_output_and_gradients(model, x):
output = model(x)
with device:
grad_output = torch.ones_like(output)
output.backward(grad_output)
expected_shape = (4, 16)

assert output.shape == expected_shape, f"{output.shape=} - {expected_shape=}"

# Verify that both params point to same grad tensor.
assert id(model.get_parameter("fc1.weight").grad) == id(model.get_parameter("fc2.weight").grad)

# Verify that we accumulate the gradients for the shared parameter.
actual_grad = model.get_parameter("fc1.weight").grad
# Based on the forward, the grad for each param is `(grad_output.T @ x)`. We multiply by 2 because the grads are accumulated into the shared tensor.
expected_grad = 2 * (grad_output.T @ x)
torch.testing.assert_close(actual_grad, expected_grad)

forward_exec_trace = thunder.last_traces(model)[-1]
n_synced_params_forward = 0
for bsym in forward_exec_trace.bound_symbols:
if bsym.sym.id in (thunder.distributed.prims.PrimIDs.SYNCHRONIZE,):
n_synced_params_forward += 1
assert (
n_synced_params_forward == 0
) # Assert that no params were synced on forward (they should be removed by later transforms)

backward_exec_trace = thunder.last_backward_traces(model)[-1]
allreduced_grads = 0
for bsym in backward_exec_trace.bound_symbols:
if bsym.sym.id in (
thunder.distributed.prims.PrimIDs.ALL_REDUCE,
thunder.executors.torchex.all_reduce_prim_impl.id,
):
allreduced_grads += 1

# The expected behaviour is that the gradients are accumulated (since both weights are the same) and then allreduced, so there should be only one allreduce.
assert allreduced_grads == 1

with device:
jit_ddp_model = Model()
ddp_jit_model = Model()
x = torch.ones(4, 16)

# Check `jit(ddp(model))` works
jit_ddp_model.fc1.weight = jit_ddp_model.fc2.weight

jit_ddp_model = thunder.jit(thunder.distributed.ddp(jit_ddp_model), executors=["torch"])

_test_model_output_and_gradients(jit_ddp_model, x)

# Check `ddp(jit(model))` works
ddp_jit_model.fc1.weight = ddp_jit_model.fc2.weight

ddp_jit_model = thunder.distributed.ddp(thunder.jit(ddp_jit_model, executors=["torch"]))

_test_model_output_and_gradients(ddp_jit_model, x)


common_utils.instantiate_parametrized_tests(DDPTest)

@@ -208,7 +286,7 @@ def fwd_loss(m, x):
def _test_native_ddp_helper(input_data):
init_method, world_size, rank, executor, device, dtype, kwargs = input_data
bucket_size_in_mb = kwargs.get("bucket_size_in_mb", 0)

jit_first = kwargs.get("jit_first", False)
num_samples = 2
tensor_shape = (2, 2)
sample_seed = 3456
@@ -231,11 +309,13 @@ def _test_native_ddp_helper(input_data):

# Creates, compiles, and DDPs the model
model = SmallModel(device, torch_dtype)
ddp_model = ddp(model)
cmodel = thunder.jit(
ddp_model,
executors=executor.executors_list(),
)
if jit_first:
cmodel = ddp(thunder.jit(model, executors=executor.executors_list()))
else:
cmodel = thunder.jit(
ddp(model),
executors=executor.executors_list(),
)

comparison_exceptions = []
for _ in range(num_epochs):
@@ -515,10 +595,13 @@ def _test_ddp_transformer_engine_llama_sanity(input_data):
num_devices=2,
# CPU broke around PyTorch 2.3.1, see PR #545
devicetypes=(devices.DeviceType.CUDA,),
decorators=(pytest.mark.parametrize("bucket_size_in_mb", (0, 25)),),
decorators=(
pytest.mark.parametrize("bucket_size_in_mb", (0, 25)),
pytest.mark.parametrize("jit_first", (True, False)),
),
)
@distributed_wrapper("test_native_ddp", _test_native_ddp_helper)
def test_native_ddp(executor, devices, dtype, bucket_size_in_mb):
def test_native_ddp(executor, devices, dtype, bucket_size_in_mb, jit_first):
pass

