From e6b327e8c731db6a41440b8fbdb7a93709e14b18 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Mon, 29 Sep 2025 15:30:40 +0000
Subject: [PATCH] Remove decomposition from softmax

Taken from https://github.com/meta-pytorch/autoparallel/pull/3 and
https://github.com/meta-pytorch/autoparallel/pull/29.

Decomposing softmax_backward produces prims.fma, which has no sharding
rule, so Replicate ends up as the only possible sharding.
---
 autoparallel/api.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/autoparallel/api.py b/autoparallel/api.py
index e5da5d67..70e65dbe 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -60,6 +60,8 @@ def _get_decomp_table():
     decomp_table.pop(torch.ops.aten.native_layer_norm.default)
     decomp_table.pop(torch.ops.aten.embedding_dense_backward.default)
     decomp_table.pop(torch.ops.aten.native_layer_norm_backward.default)
+    decomp_table.pop(torch.ops.aten._softmax_backward_data.default)
+    decomp_table.pop(torch.ops.aten._softmax.default)
     # decompose addmm to allow for TP on mm
     decomp_table.pop(torch.ops.aten.addmm.default)
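
For illustration, a minimal, hypothetical sketch of what these two pops do: the
softmax forward and backward ops are kept as single aten ops in the traced graph
so sharding rules can match them, instead of being decomposed into prims.fma.
This sketch assumes the table is built from torch._decomp.core_aten_decompositions();
the actual _get_decomp_table in autoparallel/api.py may construct it differently.

    # Hypothetical sketch, not the autoparallel implementation.
    import torch
    from torch._decomp import core_aten_decompositions

    decomp_table = core_aten_decompositions()

    # Keep softmax and its backward un-decomposed; decomposing the backward
    # would introduce prims.fma, which has no sharding rule.
    for op in (
        torch.ops.aten._softmax.default,
        torch.ops.aten._softmax_backward_data.default,
    ):
        # pop with a default so this is a no-op if the table never decomposed the op
        decomp_table.pop(op, None)

    assert torch.ops.aten._softmax.default not in decomp_table
    assert torch.ops.aten._softmax_backward_data.default not in decomp_table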