Skip to content

Commit

Permalink
fix xcom pull for tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
Forthoney committed Jan 28, 2024
1 parent bce7578 commit 400dbf8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 17 deletions.
6 changes: 6 additions & 0 deletions compiler/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,12 @@ def add_experimental_args(self):
action="store_true",
default=False,
)
self.add_argument(
"--airflow",
help="(experimental) airflowify the script",
action="store_true",
default=False,
)


class RunnerParser(BaseParser):
Expand Down
37 changes: 20 additions & 17 deletions compiler/shell_ast/transformation_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,40 +213,43 @@ def _ast_to_airflow(self, ast):
elif isinstance(ast, IfNode):
return dedent(
f"""
cond_{id} = BashOperator(task_id='cond_{id}', bash_command='{ast.cond.pretty()}'), xcom_push=True
cond_{id} = BashOperator(task_id='cond_{id}', bash_command='{ast.cond.pretty()}', xcom_push=True)
@task.branch(task_id='branch_{id}')
def branch_func(ti=None):
xcom_value = bool(ti.xcom_pull(task_ids='cond_{id}'))
if xcom_value:
return 'then_{id}'
else:
return 'else_{id}'
def branch_func_{id}(input):
if input:
return 'then_{id}'
else:
return 'else_{id}'
then_{id} = BashOperator(task_id='then_{id}', bash_command='{ast.then_b.pretty()}')
else_{id} = BashOperator(task_id='else_{id}', bash_command='{ast.else_b.pretty()}')
branch_func_{id}(cond_{id}.output)
"""
)
elif isinstance(ast, OrNode):
return dedent(
f"""
cond_{id}= BashOperator(task_id='cond_task', bash_command='{ast.left_operand.pretty()}'), xcom_push=True
cond_{id}= BashOperator(task_id='cond_task', bash_command='{ast.left_operand.pretty()}', xcom_push=True)
@task.branch(task_id='branch_{id}')
def branch_func(ti=None):
xcom_value = bool(ti.xcom_pull(task_ids='cond_{id}'))
if not xcom_value:
return 'else_{id}'
def branch_func_{id}(input):
if not input:
return 'else_{id}'
else_{id}= BashOperator(task_id='else_{id}', bash_command='{ast.right_operand.pretty()}')
branch_func_{id}(cond_{id}.output)
"""
)
elif isinstance(ast, AndNode):
return dedent(
f"""
cond_{id}= BashOperator(task_id='cond_task', bash_command='{ast.left_operand.pretty()}'), xcom_push=True
cond_{id}= BashOperator(task_id='cond_task', bash_command='{ast.left_operand.pretty()}', xcom_push=True)
@task.branch(task_id='branch_{id}')
def branch_func(ti=None):
xcom_value = bool(ti.xcom_pull(task_ids='cond_{id}'))
if xcom_value:
return 'then_{id}'
def branch_func_{id}(input):
if input:
return 'then_{id}'
then_{id}= BashOperator(task_id='then_{id}', bash_command='{ast.right_operand.pretty()}')
branch_func_{id}(cond_{id}.output)
"""
)
elif isinstance(ast, PipeNode):
Expand Down

0 comments on commit 400dbf8

Please sign in to comment.