import ast
import re

# NOTE: This is a custom implementation of the FlowGraph class from the Metaflow client
# which can parse a graph out of a flow_name and a source code string, instead of relying on
# importing the source code as a module.


def deindent_docstring(doc):
    """Strip the common leading indentation from a docstring.

    Parameters
    ----------
    doc : str or None
        Raw docstring text.

    Returns
    -------
    str
        ``""`` for a falsy ``doc``; the de-indented and stripped text when an
        indent was detected; otherwise ``doc`` unchanged.
    """
    if not doc:
        return ""

    # Find the indent to remove from the docstring. We consider the following possibilities:
    # Option 1:
    #  """This is the first line
    #    This is the second line
    #  """
    # Option 2:
    #  """
    # This is the first line
    # This is the second line
    # """
    # Option 3:
    #  """
    #     This is the first line
    #     This is the second line
    #  """
    #
    # In all cases, we can find the indent to remove by doing the following:
    #  - Check the first non-empty line, if it has an indent, use that as the base indent
    #  - If it does not have an indent and there is a second line, check the indent of the
    #    second line and use that
    saw_first_line = False
    matched_indent = None
    for line in doc.splitlines():
        if not line:
            continue
        matched_indent = re.match("[\t ]+", line)
        # Stop at the first indented line, or at the second non-empty line
        # regardless of whether it is indented.
        if matched_indent is not None or saw_first_line:
            break
        saw_first_line = True

    if matched_indent:
        # Remove the detected indent after every newline, then trim the ends.
        return re.sub(r"\n" + matched_indent.group(), "\n", doc).strip()
    return doc


class DAGNode(object):
    """A single step of a flow, described from the AST of its function.

    Parameters
    ----------
    func_ast : ast.FunctionDef
        AST node of the step function; its name, line number, arguments and
        last statement are inspected.
    decos : iterable
        Decorators attached to the step. The node is flagged as a parallel
        step if any of them has a truthy ``IS_PARALLEL`` attribute.
    doc : str or None
        Raw docstring of the step; stored de-indented.
    """

    def __init__(self, func_ast, decos, doc):
        self.name = func_ast.name
        self.func_lineno = func_ast.lineno
        self.decorators = decos
        self.doc = deindent_docstring(doc)
        self.parallel_step = any(getattr(deco, "IS_PARALLEL", False) for deco in decos)

        # these attributes are populated by _parse
        self.tail_next_lineno = 0
        self.type = None
        self.out_funcs = []
        self.has_tail_next = False
        self.invalid_tail_next = False
        self.num_args = 0
        self.foreach_param = None
        self.num_parallel = 0
        self.parallel_foreach = False
        self._parse(func_ast)

        # these attributes are populated by _traverse_graph
        self.in_funcs = set()
        self.split_parents = []
        self.matching_join = None
        # these attributes are populated by _postprocess
        self.is_inside_foreach = False

    def _expr_str(self, expr):
        """Render an attribute access such as ``self.next`` as ``"self.next"``."""
        return "%s.%s" % (expr.value.id, expr.attr)

    @staticmethod
    def _literal_value(node):
        """Return the literal carried by a keyword-argument AST node, else None.

        ``ast.Constant.value`` is read directly; the ``s`` attribute is only a
        deprecated alias on constants and is not available on every Python
        version. The ``s`` lookup is kept as a fallback for legacy string
        nodes that are not ``ast.Constant`` instances.
        """
        if isinstance(node, ast.Constant):
            return node.value
        return getattr(node, "s", None)

    def _parse(self, func_ast):
        """Classify the step from its signature and its trailing ``self.next`` call.

        Populates ``num_args``, ``type``, ``out_funcs``, ``has_tail_next``,
        ``invalid_tail_next``, ``tail_next_lineno``, ``foreach_param``,
        ``num_parallel`` and ``parallel_foreach``. Any ``AttributeError`` raised
        while inspecting the tail statement ends parsing and leaves the
        attributes as they were set up to that point.
        """
        self.num_args = len(func_ast.args.args)
        tail = func_ast.body[-1]

        # end doesn't need a transition
        if self.name == "end":
            # TYPE: end
            self.type = "end"

        # ensure that the tail is an expression
        if not isinstance(tail, ast.Expr):
            return

        # determine the type of self.next transition
        try:
            if not self._expr_str(tail.value.func) == "self.next":
                return

            self.has_tail_next = True
            # Assume the transition is invalid until a known shape is recognised.
            self.invalid_tail_next = True
            self.tail_next_lineno = tail.lineno
            self.out_funcs = [e.attr for e in tail.value.args]

            keywords = dict(
                (k.arg, self._literal_value(k.value)) for k in tail.value.keywords
            )
            if len(keywords) == 1:
                if "foreach" in keywords:
                    # TYPE: foreach
                    self.type = "foreach"
                    if len(self.out_funcs) == 1:
                        self.foreach_param = keywords["foreach"]
                        self.invalid_tail_next = False
                elif "num_parallel" in keywords:
                    # TYPE: foreach (parallel variant)
                    self.type = "foreach"
                    self.parallel_foreach = True
                    if len(self.out_funcs) == 1:
                        self.num_parallel = keywords["num_parallel"]
                        self.invalid_tail_next = False
            elif len(keywords) == 0:
                if len(self.out_funcs) > 1:
                    # TYPE: split
                    self.type = "split"
                    self.invalid_tail_next = False
                elif len(self.out_funcs) == 1:
                    # TYPE: linear
                    if self.name == "start":
                        self.type = "start"
                    elif self.num_args > 1:
                        self.type = "join"
                    else:
                        self.type = "linear"
                    self.invalid_tail_next = False
        except AttributeError:
            return