diff --git a/autoparallel/shardings/ordered_sharding.py b/autoparallel/shardings/ordered_sharding.py
index 5298a505..60df8909 100644
--- a/autoparallel/shardings/ordered_sharding.py
+++ b/autoparallel/shardings/ordered_sharding.py
@@ -94,7 +94,7 @@ def get_redistributed_input_placements(
         x for x in tree_flatten(node.args)[0] if isinstance(x, torch.fx.Node)
     ]
     num_input_nodes = len(all_input_nodes)
-    curr_specs: list[Union[DTensorSpec, tuple[Optional[DTensorSpec], ...]]] = [
+    curr_specs: list[Union[DTensorSpec, tuple[Optional[DTensorSpec], ...], None]] = [
        sharding_placement[n].output_specs for n in all_input_nodes
     ]  # FIXME ?
     if node.target == operator.getitem:
@@ -186,7 +186,7 @@ def compute_optimal_placement_order_for_parameters(module, sharding_placement):
     user_src_placement = list(d.values())[0][0]
     mesh_ndim = len(user_src_placement)

-    param_grad_map = {p: g for p, g in param_and_grad_nodes}
+    param_grad_map = dict(param_and_grad_nodes)
     aligned_pg = []
     for param_or_grad_node in redistribution_map.keys():
         # just allow for arbitrary execution order if both param and grad
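Below is a minimal, self-contained sketch (not part of the patch) of the reasoning behind the two hunks; Spec and output_specs_by_node are hypothetical stand-ins for the real DTensorSpec / sharding_placement objects.

from typing import Optional, Union

# Second hunk: dict() over an iterable of (key, value) pairs is equivalent to
# the removed comprehension {p: g for p, g in param_and_grad_nodes}.
param_and_grad_nodes = [("param0", "grad0"), ("param1", "grad1")]
assert dict(param_and_grad_nodes) == {p: g for p, g in param_and_grad_nodes}

# First hunk: if output_specs can itself be None for some input nodes, the
# element type of curr_specs has to admit None as well, otherwise type
# checkers reject the list construction below.
Spec = str  # hypothetical stand-in for DTensorSpec
output_specs_by_node: dict[str, Optional[Spec]] = {"a": "Shard(0)", "b": None}
curr_specs: list[Union[Spec, tuple[Optional[Spec], ...], None]] = [
    output_specs_by_node[n] for n in ("a", "b")
]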