Skip to content

Commit

Permalink
Change compile for pipeline module torch.compile (#6478)
Browse files Browse the repository at this point in the history
We have encountered an issue with torch.compile and the pipeline
module.
Modifying a member of the module (micro_offset) during the forward
function causes torch.compile to restart its analysis and treat the
module as dynamic.
To bypass this issue without significantly changing the way the
pipeline module works, we propose compiling only the layers in the
pipeline module instead of the forward function of the pipeline module.
This bypasses the issue and should still give most of the benefit of
torch.compile for the pipeline module.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
  • Loading branch information
NirSonnenschein and loadams authored Dec 30, 2024
1 parent cc03c76 commit 3573858
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
8 changes: 8 additions & 0 deletions deepspeed/runtime/pipe/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,11 @@ def get_additional_losses(self):
Return a dictionary of {"loss name": loss_value} or None if no additional losses.
"""
return None

def compile(self, *args, **kwargs):
    """Compile every entry of ``self.forward_funcs`` individually.

    Entries that are ``nn.Module`` instances are compiled in place through
    their own ``compile`` method. Any other entry is replaced in
    ``self.forward_funcs`` by the result of ``torch.compile`` on it.

    Args:
        *args: Positional arguments forwarded unchanged to each
            ``compile`` / ``torch.compile`` call.
        **kwargs: Keyword arguments forwarded unchanged to each
            ``compile`` / ``torch.compile`` call.
    """
    for position, stage_fn in enumerate(self.forward_funcs):
        if not isinstance(stage_fn, nn.Module):
            # Not a module, so there is nothing to compile in place:
            # swap the slot for the wrapper that torch.compile returns.
            self.forward_funcs[position] = torch.compile(stage_fn, *args, **kwargs)
            continue
        # Modules keep their identity; compilation is attached to the module.
        stage_fn.compile(*args, **kwargs)
8 changes: 6 additions & 2 deletions tests/unit/pipe/test_pipe_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,12 @@ def batch_input():

class TestPipeModuleSequential(DistributedTest):
world_size = 2
# needs to be set for torch.compile: running torch.compile with daemonic process causes an error
non_daemonic_procs = True

@pytest.mark.parametrize("activation_checkpoints", [False, True])
def test(self, sequential_model, simple_config, batch_input, activation_checkpoints):
@pytest.mark.parametrize("use_compile", [False, True])
def test(self, sequential_model, simple_config, batch_input, activation_checkpoints, use_compile):
base_model = copy.deepcopy(sequential_model)
base_input = batch_input.clone().detach()
base_output = base_model(base_input)
Expand All @@ -71,7 +74,8 @@ def test(self, sequential_model, simple_config, batch_input, activation_checkpoi

pipe_model = copy.deepcopy(sequential_model)
pipe_model = PipelineModule(layers=pipe_model, num_stages=2)

if (use_compile):
pipe_model.compile()
# Ensure all parameters are accounted for.
my_params = sum(p.numel() for p in pipe_model.parameters())
total_pipe_params = torch.LongTensor([my_params]).to(get_accelerator().device_name())
Expand Down

0 comments on commit 3573858

Please sign in to comment.