Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions graph_net/paddle/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def translate_pir_program_to_sample_codes(
model_dump_path,
split_positions=None,
group_head_and_tail=True,
use_all_inputs=False,
):
ir_programs_path = self.get_ir_programs_path(model_dump_path)
example_inputs_path = self.get_example_inputs_path(model_dump_path)
Expand All @@ -218,6 +219,7 @@ def translate_pir_program_to_sample_codes(
op_example_inputs=op_example_inputs_path,
split_positions=split_positions,
group_head_and_tail=group_head_and_tail,
use_all_inputs=use_all_inputs,
eval_mode=True,
tmp_dir=model_dump_path,
)
Expand Down
23 changes: 13 additions & 10 deletions graph_net/paddle/graph_decomposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def make_config(
self,
split_positions=None,
group_head_and_tail=False,
use_all_inputs=False,
chain_style=False,
output_dir="./tmp/naive_decomposer_dir",
post_extract_process_path=None,
Expand All @@ -42,6 +43,7 @@ def make_config(
return {
"split_positions": split_positions,
"group_head_and_tail": group_head_and_tail,
"use_all_inputs": use_all_inputs,
"chain_style": chain_style,
"output_dir": output_dir,
"post_extract_process_path": post_extract_process_path,
Expand Down Expand Up @@ -82,6 +84,7 @@ def __init__(
)
self.split_positions = self.config["split_positions"]
self.group_head_and_tail = self.config["group_head_and_tail"]
self.use_all_inputs = self.config["use_all_inputs"]
self.post_extract_process = self.make_post_extract_process(self.config)

def do_extract(self, **input_dict):
Expand All @@ -98,21 +101,22 @@ def do_extract(self, **input_dict):
model_dump_path,
split_positions=self.split_positions,
group_head_and_tail=self.group_head_and_tail,
use_all_inputs=self.use_all_inputs,
)

# 3. Save to model_path
self.subgraph_path_list = []
self.subgraph_path2subgraph_range = {}
model_path = os.path.join(
self.builtin_extractor.workspace_path, self.builtin_extractor.name
)
assert len(self.builtin_extractor.subgraph_idx2samples) == 1

samples = self.builtin_extractor.subgraph_idx2samples[0]
for seq_idx in range(len(samples)):
for seq_idx, sample in enumerate(samples):
subgraph_path = f"{model_path}_{seq_idx}"
self.subgraph_path_list.append(subgraph_path)
self.builtin_extractor.write_sample_to_file(subgraph_path, samples[seq_idx])
print(f"Save to {subgraph_path}")
self.subgraph_path2subgraph_range[subgraph_path] = sample.subgraph_range
self.builtin_extractor.write_sample_to_file(subgraph_path, sample)
print(f"[NaiveDecomposerExtractor] Save to {subgraph_path}")
return static_model

def __call__(self, **input_dict):
Expand All @@ -121,13 +125,12 @@ def __call__(self, **input_dict):
extracted_model = self.do_extract(**input_dict)
self.extracted = True

for subgraph_path in self.subgraph_path_list:
self._post_extract_process(subgraph_path)
for subgraph_path, subgraph_range in self.subgraph_path2subgraph_range.items():
return self.post_extract_process(
subgraph_path, subgraph_range, self.use_all_inputs
)
return extracted_model

def _post_extract_process(self, subgraph_path):
return self.post_extract_process(subgraph_path)

def make_post_extract_process(self, config):
if config.get("post_extract_process_path") is None:
return lambda *args, **kwargs: None
Expand Down
84 changes: 64 additions & 20 deletions graph_net/paddle/graph_meta_restorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,23 @@ def __init__(self, config, parent_model_path):
parent_input_meta_classes
)

def __call__(self, model_path):
def __call__(self, model_path, subgraph_range=None, use_all_inputs=False):
assert path_utils.is_single_model_dir(
model_path
), f"{model_path=} is not a graphnet sample."
if isinstance(subgraph_range, (tuple, list)) and len(subgraph_range) == 2:
use_all_inputs = subgraph_range[0] == 0 and use_all_inputs
else:
use_all_inputs = False

(
weight_meta_classes,
input_meta_classes,
) = self._load_weight_and_input_meta_classes(model_path)

assert self.config["update_inplace"]

# Restore weight_meta according to original_name.
(
is_weight_meta_fully_updated,
weight_meta_classes,
Expand All @@ -42,11 +49,19 @@ def __call__(self, model_path):
assert is_weight_meta_fully_updated
self._rewrite_meta_codes(model_path, weight_meta_classes, "weight_meta.py")

is_input_meta_fully_updated = self._update_by_tensor_spec(
input_meta_classes, self.original_name2parent_input_meta_class
)
# Restore input_meta according to name order or tensor spec (dtype and shape),
# because ordinary paddle.Tensor does not support user-defined names.
is_input_meta_fully_updated = False
if use_all_inputs:
is_input_meta_fully_updated = self._update_by_name_order(
input_meta_classes, self.original_name2parent_input_meta_class
)
if not is_input_meta_fully_updated:
is_input_meta_fully_updated = self._update_by_tensor_spec(
input_meta_classes, self.original_name2parent_input_meta_class
)
if (
not self.config["input_meta_allow_partial_update"]
self.config["input_meta_allow_partial_update"]
or is_input_meta_fully_updated
):
self._rewrite_meta_codes(model_path, input_meta_classes, "input_meta.py")
Expand All @@ -73,20 +88,25 @@ def _convert_to_dict(self, meta_classes):
original_name2meta_class[meta_class.original_name] = meta_class
return original_name2meta_class

def _update_tensor_meta(self, meta_class, parent_meta_class):
if (
parent_meta_class
and meta_class.dtype == parent_meta_class.dtype
def _has_same_tensor_spec(self, meta_class, parent_meta_class):
if meta_class is None or parent_meta_class is None:
return False
return (
meta_class.dtype == parent_meta_class.dtype
and meta_class.shape == parent_meta_class.shape
):
for attr_name in ["max_val", "min_val", "mean", "std", "data"]:
if hasattr(meta_class, attr_name) or hasattr(
parent_meta_class, attr_name
):
attr_value = getattr(parent_meta_class, attr_name, None)
setattr(meta_class, attr_name, attr_value)
return True
return False
)

def _update_tensor_meta(self, meta_class, parent_meta_class):
    """Sync statistic attributes from the parent meta onto the child meta.

    Only applies when both metas share the same tensor spec (dtype/shape).
    For each statistic attribute, the parent's value wins; a statistic the
    parent does not carry is removed from the child so no stale value
    survives the restore.

    Returns:
        bool: True when the specs matched and the attributes were synced.
    """
    if not self._has_same_tensor_spec(meta_class, parent_meta_class):
        return False

    missing = object()  # sentinel: distinguishes "absent" from any real value
    for stat_attr in ("max_val", "min_val", "mean", "std", "data"):
        parent_value = getattr(parent_meta_class, stat_attr, missing)
        if parent_value is not missing:
            setattr(meta_class, stat_attr, parent_value)
        elif hasattr(meta_class, stat_attr):
            delattr(meta_class, stat_attr)
    return True

def _update_by_original_name(self, meta_classes, original_name2parent_meta_class):
updated_class_names = set()
Expand Down Expand Up @@ -116,14 +136,38 @@ def _reorder_by_original_name(self, meta_classes, original_names):
)
return sorted_meta_classess

def _update_by_name_order(self, meta_classes, original_name2parent_meta_class):
    """Restore input metas by one-to-one name matching against parent metas.

    Succeeds only when the child and parent meta lists have equal length and
    every parent meta has a same-named child meta with an identical tensor
    spec. On success, `meta_classes` is rewritten in place in the parent's
    order with synced statistics.

    Args:
        meta_classes: list of child-graph meta classes; mutated in place on
            a full match.
        original_name2parent_meta_class: parent metas keyed by original name
            (insertion order defines the name order).

    Returns:
        bool: True when every meta class was restored.
    """
    parent_meta_classes = list(original_name2parent_meta_class.values())
    # Capture the size before any in-place mutation: the summary below must
    # report restored/original, not restored/restored (which is always N/N).
    total = len(meta_classes)
    if total != len(parent_meta_classes):
        return False

    updated_meta_classes = []
    name2meta_class = {meta_class.name: meta_class for meta_class in meta_classes}
    same_in_order = all(
        self._has_same_tensor_spec(
            name2meta_class.get(parent_meta_class.name, None), parent_meta_class
        )
        for parent_meta_class in parent_meta_classes
    )
    if same_in_order:
        for parent_meta_class in parent_meta_classes:
            meta_class = name2meta_class[parent_meta_class.name]
            if self._update_tensor_meta(meta_class, parent_meta_class):
                updated_meta_classes.append(meta_class)
        # Reorder the caller's list to follow the parent's input order.
        meta_classes[:] = updated_meta_classes

    print(
        f"[GraphMetaRestorer] {len(updated_meta_classes)}/{total} classes can be restored."
    )
    return total == len(updated_meta_classes)

def _update_by_tensor_spec(self, meta_classes, original_name2parent_meta_class):
updated_class_names = set()
for meta_class in meta_classes:
matched_parent_meta_class = [
parent_meta_class
for parent_meta_class in original_name2parent_meta_class.values()
if meta_class.dtype == parent_meta_class.dtype
and meta_class.shape == parent_meta_class.shape
if self._has_same_tensor_spec(meta_class, parent_meta_class)
]
if len(matched_parent_meta_class) == 1:
self._update_tensor_meta(meta_class, matched_parent_meta_class[0])
Expand Down
Loading