Skip to content

Commit

Permalink
Handle the case of dict-valued inputs in PipelineBlock
Browse files Browse the repository at this point in the history
  • Loading branch information
haeggee committed Aug 5, 2024
1 parent 3967bee commit bcb94cc
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/nanotron/parallel/pipeline_parallel/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,25 @@ def forward(self, **kwargs):
if isinstance(tensor, TensorPointer):
# Current rank is neither the rank holding the data nor the rank responsible for computing block
continue
elif isinstance(tensor, dict):
for k, v in tensor.items():
if isinstance(v, torch.Tensor):
# We need to send the tensor to the rank that actually runs the compute
if self.pipeline_state is not None:
send_to_pipeline_state_buffer(
v,
to_rank=self.rank,
p2p=self.p2p,
pipeline_state=self.pipeline_state,
)
continue

if v.requires_grad is True:
raise ValueError(
f"Pipeline engine is None and tensor requires grad. Tried sending a tensor to {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine."
)

batch_send_recv.add_send(tensor=v, to_rank=self.rank)
else:
assert isinstance(tensor, torch.Tensor)
# We need to send the tensor to the rank that actually runs the compute
Expand Down

0 comments on commit bcb94cc

Please sign in to comment.