Commit 4fd761f
[DTensor] Wrap sharding prop error with contextual exception (pytorch#161574)
Mainly, this gives the user more information about the operator that failed to run when the failure occurs during sharding propagation.
Previously, only this exception would be raised:
```
RuntimeError: ('Attempted to flatten sharded dimension 1, ', 'but only the leftmost dim of a Flatten can be sharded.')
```
Now you get both the above exception and the following:
```
The above exception was the direct cause of the following exception:

<stacktrace omitted>

RuntimeError: Sharding propagation failed for Op(op=aten.view.default, args_schema=Spec((Replicate(), Shard(dim=0), Shard(dim=1), Shard(dim=2)) on (8, 8, 4)), [64, 4] @ mesh: (1, 2, 2, 2))
```
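The wrapping itself is ordinary exception chaining. Below is a minimal sketch of that pattern, not the exact PyTorch implementation; the `op_info` and `sharding_propagator` names are borrowed from the traceback and used here as assumptions:

```python
# Minimal sketch of wrapping a sharding-propagation failure with op context.
# Not the exact code in torch/distributed/tensor/_dispatch.py; names are
# assumptions based on the traceback shown in this commit message.
def dispatch_with_context(sharding_propagator, op_info):
    try:
        sharding_propagator.propagate(op_info)
    except Exception as e:
        # `raise ... from e` chains the original error, which produces the
        # "The above exception was the direct cause of ..." output above.
        raise RuntimeError(
            f"Sharding propagation failed for {op_info.schema}"
        ) from e
```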
<details><summary>detailed error</summary>
```
======================================================================
ERROR: test_linear (__main__.TestDTensor)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 668, in wrapper
self._join_processes(fn)
File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 932, in _join_processes
self._check_return_codes(fn, elapsed_time)
File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 972, in _check_return_codes
raise RuntimeError(error)
RuntimeError: Process 4 exited with error code 10 and exception:
Traceback (most recent call last):
File "/data/users/whc/pytorch/torch/distributed/tensor/_dispatch.py", line 150, in dispatch
self.sharding_propagator.propagate(op_info)
File "/data/users/whc/pytorch/torch/distributed/tensor/_sharding_prop.py", line 309, in propagate
OutputSharding, self.propagate_op_sharding(op_info.schema)
File "/data/users/whc/pytorch/torch/distributed/tensor/_sharding_prop.py", line 45, in __call__
return self.cache(*args, **kwargs)
File "/data/users/whc/pytorch/torch/distributed/tensor/_sharding_prop.py", line 329, in propagate_op_sharding_non_cached
op_strategy = self.op_strategy_funcs[op_schema.op](strategy_schema)
File "/data/users/whc/pytorch/torch/distributed/tensor/_ops/_view_ops.py", line 673, in reshape_strategy
input_tgt_placements, output_placements = propagate_shape_and_sharding(
File "/data/users/whc/pytorch/torch/distributed/tensor/_ops/_view_ops.py", line 601, in propagate_shape_and_sharding
in_dim = get_in_dim_to_shard(cmd)
File "/data/users/whc/pytorch/torch/distributed/tensor/_ops/_view_ops.py", line 537, in get_in_dim_to_shard
raise RuntimeError(
RuntimeError: ('Attempted to flatten sharded dimension 1, ', 'but only the leftmost dim of a Flatten can be sharded.')
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 816, in run_test
getattr(self, test_name)()
File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 670, in wrapper
fn()
File "/data/users/whc/pytorch/torch/testing/_internal/common_utils.py", line 3224, in wrapper
method(*args, **kwargs)
File "/data/users/whc/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 490, in wrapper
raise e
File "/data/users/whc/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 487, in wrapper
func(self, *args, **kwargs) # type: ignore[misc]
File "/data/users/whc/pytorch/test.py", line 60, in test_linear
print("results: ", distributed_linear(distributed_input))
File "/data/users/whc/pytorch/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/users/whc/pytorch/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
File "/data/users/whc/pytorch/torch/nn/modules/linear.py", line 134, in forward
return F.linear(input, self.weight, self.bias)
File "/data/users/whc/pytorch/torch/_compile.py", line 53, in inner
return disable_fn(*args, **kwargs)
File "/data/users/whc/pytorch/torch/_dynamo/eval_frame.py", line 1005, in _fn
return fn(*args, **kwargs)
File "/data/users/whc/pytorch/torch/distributed/tensor/_api.py", line 358, in __torch_dispatch__
return DTensor._op_dispatcher.dispatch(
File "/data/users/whc/pytorch/torch/distributed/tensor/_dispatch.py", line 163, in dispatch
raise RuntimeError(
RuntimeError: Sharding propagation failed for Op(op=aten.view.default, args_schema=Spec((Replicate(), Shard(dim=0), Shard(dim=1), Shard(dim=2)) on (8, 8, 4)), [64, 4] @ mesh: (1, 2, 2, 2))
```
</details>
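As a hedged illustration (not part of this PR), a standalone repro of the failing case in the traceback might look like the sketch below, assuming an 8-rank job launched with `torchrun` and a build that exposes the public `torch.distributed.tensor` API:

```python
# Hypothetical repro sketch (assumption, not from the PR): flattening a
# sharded dimension via aten.view triggers the error wrapped above.
# Run with: torchrun --nproc-per-node=8 repro.py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

mesh = init_device_mesh("cpu", (1, 2, 2, 2))  # 8 ranks total
x = torch.randn(8, 8, 4)
dx = distribute_tensor(x, mesh, [Replicate(), Shard(0), Shard(1), Shard(2)])
# view(64, 4) flattens dims 0 and 1; dim 1 is sharded, so previously only the
# "Attempted to flatten sharded dimension 1" error surfaced. With this change
# the dispatcher re-raises it with the op and placement context shown above.
dx.view(64, 4)
```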
Pull Request resolved: pytorch#161574
Approved by: https://github.com/zpcore, https://github.com/XilunWu
File tree: 3 files changed (+8, -10 lines), under test/distributed/tensor and torch/distributed/tensor. (Per-file diff omitted.)