diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index f241fe227..84c961c77 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -64,7 +64,7 @@ from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall from pytato.tags import FunctionIdentifier -from pytato.transform import ArrayOrNames, CachedMapper, InputGatherer +from pytato.transform import ArrayOrNames, CachedMapper, CachedWalkMapper, InputGatherer if TYPE_CHECKING: @@ -178,6 +178,33 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" +class CodeCollectorMapper(CachedWalkMapper): + def __init__(self) -> None: + super().__init__() + from pytools import Table + self.code = [] + self.table = Table() + + def get_cache_key(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> Any: + return expr + + def post_visit(self, expr): + non_equality_tags = expr.non_equality_tags + + code = _stringify_created_at(non_equality_tags).split("\n") + + # If traceback tag is not enabled + if len(code) < 2: + code.append("") + + # Deduplicate consecutive lines + if len(self.code) > 0 and code == self.code[-1]: + return + + self.code.append(code) + self.table.add_row((code[0][-10:], code[1])) + + class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]): def __init__(self) -> None: super().__init__() @@ -519,7 +546,9 @@ def _gather_partition_node_information( for part in partition.parts.values(): mapper = ArrayToDotNodeInfoMapper() + ccm = CodeCollectorMapper() for out_name in part.output_names: + ccm(partition.name_to_output[out_name]) mapper(partition.name_to_output[out_name]) part_id_func_to_node_info[part.pid, None] = mapper.node_to_dot @@ -561,7 +590,7 @@ def gather_function_info(f: FunctionDefinition) -> None: part_id_to_func_to_id.setdefault(part.pid, {})[f] = fid - return part_id_to_func_to_id, part_id_func_to_node_info + return part_id_to_func_to_id, part_id_func_to_node_info, ccm.table # }}} @@ -610,7 +639,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: # The "None" function is the body of the partition. - part_id_to_func_to_id, part_id_func_to_node_info = \ + part_id_to_func_to_id, part_id_func_to_node_info, code = \ _gather_partition_node_information(id_gen, partition) # }}} @@ -620,7 +649,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: emitted_placeholders = set() - emit_root("node [shape=rectangle]") + emit_root(f'node [shape=rectangle] tooltip="{dot_escape_leave_space(code.raw())}"') placeholder_to_id: dict[ArrayOrNames, str] = {} part_id_to_array_to_id: dict[PartId, dict[ArrayOrNames, str]] = {} @@ -683,7 +712,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: # }}} - # {{{ emit receives nodes + # {{{ emit receive nodes part_dist_recv_var_name_to_node_id = {} for name, recv in (