Skip to content

Commit

Permalink
removing the fuse distributed ops lowering pass for tegra platforms
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Feb 24, 2025
1 parent 0a46392 commit 0d13926
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,22 @@
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .view_to_reshape import view_to_reshape

ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
[
remove_input_alias_fixing_clones,
constant_fold,
repair_input_as_output,
fuse_prims_broadcast,
fuse_distributed_ops,
replace_max_pool_with_indices,
lower_scaled_dot_product_attention,
view_to_reshape,
remove_assert_scalar,
accumulate_fp32_matmul,
]
)
pass_list = [
remove_input_alias_fixing_clones,
constant_fold,
repair_input_as_output,
fuse_prims_broadcast,
replace_max_pool_with_indices,
lower_scaled_dot_product_attention,
view_to_reshape,
remove_assert_scalar,
accumulate_fp32_matmul,
]

if torch.cuda.get_device_capability() not in [(8, 7), (7, 2)]:
pass_list.append(fuse_distributed_ops)

ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list)

ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
[
Expand Down

0 comments on commit 0d13926

Please sign in to comment.