Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions dace/frontend/fortran/ast_desugaring.py
Original file line number Diff line number Diff line change
Expand Up @@ -2372,9 +2372,7 @@ def _inject_knowns(x: Base, value: bool = True, pointer: bool = True):
except NotImplementedError:
plus, minus = {}, set()
elif isinstance(node, Assignment_Stmt):
lv, op, rv = node.children
_inject_knowns(lv, value=False, pointer=True)
_inject_knowns(rv)
_inject_knowns(node)
lv, op, rv = node.children
lspec, ltyp = None, None
if isinstance(lv, Name):
Expand Down Expand Up @@ -2469,13 +2467,14 @@ def _inject_knowns(x: Base, value: bool = True, pointer: bool = True):
do_ops = node.children[1:-1]
has_pointer_asgns = bool(walk(node, Pointer_Assignment_Stmt))

# Everything updated in the body will not be constant outside
net_tpm = set()
for op in do_ops:
tp, tm = _track_local_consts(op, alias_map, {}, set())
net_tpm.update(tp.keys())
net_tpm.update(tm)
tp, tm = _track_local_consts(do_ops, alias_map, {}, set())
net_tpm.update(tp.keys())
net_tpm.update(tm)
loop_control = singular(children_of_type(do_stmt, Loop_Control))
_, cntexpr, _, _ = loop_control.children
# The loop variable is not constant
if cntexpr:
loopvar, _ = cntexpr
loopvar_spec = _find_real_ident_spec(loopvar, alias_map)
Expand All @@ -2494,6 +2493,8 @@ def _inject_knowns(x: Base, value: bool = True, pointer: bool = True):
net_tpm | minus)
_integrate_subresults(tp, tm)

# Loop var cannot be assumed to be generally constant after the loop
# TODO: if it's strictly a for loop, we know the value after the loop
_, loop_ctl = do_stmt.children
_, loop_var, _, _ = loop_ctl.children
if loop_var:
Expand Down
15 changes: 15 additions & 0 deletions tests/fortran/ast_desugaring_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2292,6 +2292,15 @@ def test_exploit_locally_constant_variables():
i = i + 1
end do

! A simple do loop with `i` as a loop variable, cond is known.
! The first reference to out is known.
do i=1, 5
out = 14.4
if (cond) then
out = out * i
end if
end do

! Just making sure that `cond` is still known after all the loops.
if (cond) out = out + 1.

Expand Down Expand Up @@ -2361,6 +2370,12 @@ def test_exploit_locally_constant_variables():
out = out + 1
i = i + 1
END DO
DO i = 1, 5
out = 14.4
IF (.TRUE.) THEN
out = 14.4 * i
END IF
END DO
IF (.TRUE.) out = out + 1.
IF (.TRUE.) THEN
cond = .TRUE.
Expand Down