Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 239 additions & 22 deletions dace/frontend/fortran/ast_desugaring.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ def _dataref_root(dref: Union[Name, Data_Ref, Data_Pointer_Object], scope_spec:
root_type = find_type_dataref(root, scope_spec, alias_map)
elif isinstance(root, Part_Ref):
root_type = find_type_dataref(root, scope_spec, alias_map)
assert root_type
assert root_type, f"cannot find type: {root} in {scope_spec}"

return root, root_type, rest

Expand Down Expand Up @@ -880,6 +880,103 @@ def find_dataref_component_spec(dref: Union[Name, Data_Ref], scope_spec: SPEC, a
return comp_spec


# TODO: Consider merging functionality with find_dataref_component_spec
def find_indexed_dataref_component_spec(dref: Union[Name, Data_Ref],
scope_spec: SPEC,
alias_map: SPEC_TABLE,
allow_variable_indices=False) -> SPEC:
"""
Generate a component spec like find_dataref_component_spec but add indices.

Data Ref: a % b[4] % c[7]
Component Spec: (a, b, c)
Indexed Component Spec: (a, (b, 4), (c, 7))
"""
# The root must have been a typed object.
root, root_type, rest = _dataref_root(dref, scope_spec, alias_map)

# Initialize idx_spec with root name
# We don't need the full spec because it'll already be available in the root spec
if isinstance(root, Data_Ref):
idx_spec = find_indexed_dataref_component_spec(root, scope_spec,
alias_map,
allow_variable_indices)
if not idx_spec:
return None
assert isinstance(root, (Name, Part_Ref))
if isinstance(root, Part_Ref):
part_name, subsc = root.children[0], root.children[1]
indices = ()
for subsc_arg in subsc.children:
idx = _const_eval_basic_type(subsc_arg, alias_map)
if not idx:
if allow_variable_indices:
if not isinstance(subsc_arg, Name):
return None
idx = subsc_arg.string
else:
# Part_Ref did not have a constant index
return None
indices += (idx,)
idx_spec = ((part_name.string, *indices),)
elif isinstance(root, Name):
idx_spec = (root.string,)

cur_type = root_type
# All component shards except for the last one must have been type objects too.
for comp in rest[:-1]:
assert isinstance(comp, (Name, Part_Ref))
if isinstance(comp, Part_Ref):
part_name, subsc = comp.children[0], comp.children[1]
indices = ()
for subsc_arg in subsc.children:
idx = _const_eval_basic_type(subsc_arg, alias_map)
if not idx:
if allow_variable_indices:
if not isinstance(subsc_arg, Name):
return None
idx = subsc_arg.string
else:
# Part_Ref did not have a constant index
return None
indices += (idx,)
comp_spec = find_real_ident_spec(part_name.string, cur_type.spec, alias_map)
idx_spec += ((comp_spec[-1], *indices),)
elif isinstance(comp, Name):
comp_spec = find_real_ident_spec(comp.string, cur_type.spec, alias_map)
idx_spec += (comp_spec[-1],)
assert comp_spec in alias_map, f"cannot find: {comp_spec} / {dref} in {scope_spec}"
# So, we get the type spec for those component shards.
cur_type = find_type_of_entity(alias_map[comp_spec], alias_map)
assert cur_type

# For the last one, we just need the component spec.
comp = rest[-1]
assert isinstance(comp, (Name, Part_Ref))
if isinstance(comp, Part_Ref):
part_name, subsc = comp.children[0], comp.children[1]
indices = ()
for subsc_arg in subsc.children:
idx = _const_eval_basic_type(subsc_arg, alias_map)
if not idx:
if allow_variable_indices:
if not isinstance(subsc_arg, Name):
return None
idx = subsc_arg.string
else:
# Part_Ref did not have a constant index
return None
indices += (idx,)
comp_spec = find_real_ident_spec(part_name.string, cur_type.spec, alias_map)
idx_spec += ((comp_spec[-1], *indices),)
elif isinstance(comp, Name):
comp_spec = find_real_ident_spec(comp.string, cur_type.spec, alias_map)
idx_spec += (comp_spec[-1],)
assert comp_spec in alias_map, f"cannot find: {comp_spec} / {dref} in {scope_spec}"

return idx_spec


def find_type_dataref(dref: Union[Name, Part_Ref, Data_Ref, Data_Pointer_Object], scope_spec: SPEC,
alias_map: SPEC_TABLE) -> TYPE_SPEC:
_, root_type, rest = _dataref_root(dref, scope_spec, alias_map)
Expand Down Expand Up @@ -2243,33 +2340,77 @@ def _track_local_consts(node: Union[Base, List[Base]], alias_map: SPEC_TABLE,
plus: Dict[Union[SPEC, Tuple[SPEC, SPEC]], LITERAL_TYPES] = copy(plus) if plus else {}
minus: Set[Union[SPEC, Tuple[SPEC, SPEC]]] = copy(minus) if minus else set()

def _root_comp(dref: (Data_Ref, Data_Pointer_Object)):
def _root_comp(dref: (Data_Ref, Data_Pointer_Object),
allow_variable_indices=False):
"""Generate a unique spec for a Data_Ref."""
scope_spec = search_scope_spec(dref)
assert scope_spec
if walk(dref, Part_Ref):
# If we are dealing with any array subscript, we cannot get a "component spec", and should take the
# pessimistic path.
# TODO: Handle the `cfg % a(1:5) % b(1:5) % c` type cases better.
return None
root, _, _ = _dataref_root(dref, scope_spec, alias_map)
for pref in walk(dref, Part_Ref):
# TODO: Handle array range subscripts.
if walk(pref, Subscript_Triplet):
return None
# Find the root spec
root = dref
while not isinstance(root, Name):
root, _, _ = _dataref_root(root, scope_spec, alias_map)
loc = search_real_local_alias_spec(root, alias_map)
assert loc
root_spec = ident_spec(alias_map[loc])
comp_spec = find_dataref_component_spec(dref, scope_spec, alias_map)
comp_spec = find_indexed_dataref_component_spec(dref, scope_spec,
alias_map,
allow_variable_indices)
if not comp_spec:
# Some part of the spec was not constant
return None
return root_spec, comp_spec

def _pref_spec(pref: Part_Ref,
allow_variable_indices=False):
"""Generate a unique spec for a Part_Ref."""
scope_spec = search_scope_spec(pref)
assert scope_spec
root, _, _ = _dataref_root(pref, scope_spec, alias_map)
loc = search_real_local_alias_spec(root, alias_map)
assert loc
root_spec = ident_spec(alias_map[loc])
pref_name = pref.children[0].string
subsc = pref.children[1]
assert isinstance(subsc, Section_Subscript_List)
# TODO: Handle array range subscripts.
if walk(subsc, Subscript_Triplet):
return None
indices = ()
for subsc_arg in subsc.children:
idx = _const_eval_basic_type(subsc_arg, alias_map)
if not idx:
if allow_variable_indices:
if not isinstance(subsc_arg, Name):
return None
idx = subsc_arg.string
else:
# Part_Ref did not have a constant index
return None
indices += (idx,)
idx_spec = (pref_name, *indices)
return root_spec, idx_spec

def _integrate_subresults(tp: Dict[SPEC, LITERAL_TYPES], tm: Set[SPEC]):
"""Update plus, minus with tp, tm."""
# There should be no overlap between tp and tm
assert not (tm & tp.keys())
# Remove tm from plus and add it to minus
for k in tm:
if k in plus:
del plus[k]
minus.add(k)
# Remove tp from minus and add it to plus
for k, v in tp.items():
if k in minus:
minus.remove(k)
plus[k] = v

def _inject_knowns(x: Base, value: bool = True, pointer: bool = True):
"""Inject known values into the tree rooted in x."""
if isinstance(x, (*LITERAL_CLASSES, Char_Literal_Constant, Write_Stmt, Close_Stmt, Goto_Stmt, Cycle_Stmt)):
pass
elif isinstance(x, Assignment_Stmt):
Expand All @@ -2292,15 +2433,75 @@ def _inject_knowns(x: Base, value: bool = True, pointer: bool = True):
if isinstance(par, (Data_Ref, Part_Ref)):
replace_node(par, Data_Ref(par.tofortran()))
elif isinstance(x, Data_Ref):
spec = _root_comp(x)
# Allow variable indices. Spec will still not be found
# in plus if it is not locally constant
spec = _root_comp(x, allow_variable_indices=True)
if spec not in plus:
for pr in x.children[1:]:
for pr in x.children:
if isinstance(pr, Part_Ref):
_, subsc = pr.children
if subsc:
subsc = subsc.children
for sc in subsc:
_inject_knowns(sc, value, pointer)
_inject_knowns(sc, True, pointer)
else:
# Raise unnecessarily nested Data_Refs
assert spec not in minus
scope_spec = find_scope_spec(x)
xtyp = find_type_dataref(x, scope_spec, alias_map)
if (pointer and xtyp.pointer) or value:
par = x.parent
replace_node(x, copy_fparser_node(plus[spec]))
if isinstance(par, (Data_Ref, Part_Ref)):
replace_node(par, Data_Ref(par.tofortran()))
# If there's a Part_Ref at the end of the Data_Ref,
# turn it into a Name and see if it matches.
last = x.children[-1]
if spec and isinstance(last, Part_Ref):
name, subsc = last.children
assert isinstance(subsc, Section_Subscript_List)
subsc_arg = subsc.children[0]
idx = _const_eval_basic_type(subsc_arg, alias_map)
if not idx:
assert isinstance(subsc_arg, Name)
idx = subsc_arg.string
# Cannot use _root_comp because a copied child would have no parent
# So instead we do surgery on the spec
spec = (spec[0], spec[1][:-1] + (spec[1][-1][0],))
if spec in plus:
assert spec not in minus
scope_spec = find_scope_spec(x)
xtyp = find_type_dataref(x, scope_spec, alias_map)
if (pointer and xtyp.pointer) or value:
par = x.parent
repl = Part_Ref(plus[spec].tofortran())

root, subsc = last.children
access = repl.children[-1]
# We cannot just chain accesses, so we need to combine them to produce a single access.
# TODO: Maybe `isinstance(c, Subscript_Triplet)` + offset manipulation?
free_comps = [(i, c) for i, c in enumerate(access.children) if c == Subscript_Triplet(':')]
assert len(free_comps) >= len(subsc.children), \
f"Free rank cannot increase, got {root}/{access} <= {subsc}"
for i, c in enumerate(subsc.children):
idx, _ = free_comps[i]
free_comps[i] = (idx, c)
free_comps = {i: c for i, c in free_comps}
set_children(access, [free_comps.get(i, c) for i, c in enumerate(access.children)])

replace_node(x, repl)
if isinstance(par, (Data_Ref, Part_Ref)):
replace_node(par, Data_Ref(par.tofortran()))
elif isinstance(x, Part_Ref):
# Try replacing the entire Part_Ref
spec = _pref_spec(x)
if spec not in plus:
# Otherwise, work on the subcomponents
par, subsc = x.children
_inject_knowns(par, value=False, pointer=True)
assert isinstance(subsc, Section_Subscript_List)
for c in subsc.children:
_inject_knowns(c)
return
assert spec not in minus
scope_spec = find_scope_spec(x)
Expand All @@ -2310,12 +2511,6 @@ def _inject_knowns(x: Base, value: bool = True, pointer: bool = True):
replace_node(x, copy_fparser_node(plus[spec]))
if isinstance(par, (Data_Ref, Part_Ref)):
replace_node(par, Data_Ref(par.tofortran()))
elif isinstance(x, Part_Ref):
par, subsc = x.children
_inject_knowns(par, value=False, pointer=True)
assert isinstance(subsc, Section_Subscript_List)
for c in subsc.children:
_inject_knowns(c)
elif isinstance(x, Subscript_Triplet):
for c in x.children:
if c:
Expand Down Expand Up @@ -2345,10 +2540,12 @@ def _inject_knowns(x: Base, value: bool = True, pointer: bool = True):
raise NotImplementedError(f"cannot handle {x} | {type(x)}")

if isinstance(node, list):
# Iterate _track_local_consts over the node's elements
for c in node:
tp, tm = _track_local_consts(c, alias_map, plus, minus)
_integrate_subresults(tp, tm)
elif isinstance(node, Execution_Part):
# We add declarations from the corresponding specification part to plus
scpart = atmost_one(children_of_type(node.parent, Specification_Part))
knowns: Dict[SPEC, LITERAL_TYPES] = {}
if scpart:
Expand All @@ -2364,6 +2561,7 @@ def _inject_knowns(x: Base, value: bool = True, pointer: bool = True):
if init and isinstance(init, LITERAL_CLASSES):
knowns[ident_spec(var)] = init
_integrate_subresults(knowns, set())
# Iterate _track_local_consts over the execution_part
for op in node.children:
# TODO: We wouldn't need the exception handling once we implement for all node types.
try:
Expand All @@ -2387,18 +2585,27 @@ def _inject_knowns(x: Base, value: bool = True, pointer: bool = True):
lspec = _root_comp(lv)
scope_spec = find_scope_spec(lv)
ltyp = find_type_dataref(lv, scope_spec, alias_map)
elif isinstance(lv, Part_Ref):
lspec = _pref_spec(lv)
scope_spec = find_scope_spec(lv)
ltyp = find_type_dataref(lv, scope_spec, alias_map)
if lspec and ltyp:
# Check if the rhs is constant
rval = _const_eval_basic_type(rv, alias_map)
if rval is None:
# We know that the lhs is not constant
_integrate_subresults({}, {lspec})
# Check if we have a scalar
elif not ltyp.shape:
# We know that the lhs is constant
plus[lspec] = numpy_type_to_literal(rval)
if lspec in minus:
minus.remove(lspec)
tp, tm = _track_local_consts(rv, alias_map)
_integrate_subresults(tp, tm)
elif isinstance(node, Pointer_Assignment_Stmt):
lv, _, rv = node.children
# Replace constants on the rhs
_inject_knowns(rv, value=False, pointer=True)
lv, _, rv = node.children
lspec, ltyp = None, None
Expand All @@ -2413,18 +2620,28 @@ def _inject_knowns(x: Base, value: bool = True, pointer: bool = True):
scope_spec = find_scope_spec(lv)
ltyp = find_type_dataref(lv, scope_spec, alias_map)
if lspec and ltyp and ltyp.pointer:
plus[lspec] = rv
if lspec in minus:
minus.remove(lspec)
tp, tm = _track_local_consts(rv, alias_map)
# Replace the pointer with whatever it's pointing to
scope_spec = search_scope_spec(rv)
root_name = _dataref_root(rv, scope_spec, alias_map)[0]
bad = walk(rv.children, (Name, Data_Ref, Part_Ref))
remove = {}
if not bad or bad == [root_name]:
plus[lspec] = rv
if lspec in minus:
minus.remove(lspec)
else:
remove = {lspec}
tp, tm = _track_local_consts(rv, alias_map, {}, remove)
_integrate_subresults(tp, tm)
elif isinstance(node, If_Stmt):
cond, body = node.children
_inject_knowns(cond)
_inject_knowns(body)
cond, body = node.children
# Condition has an effect past the scope
tp, tm = _track_local_consts(cond, alias_map)
_integrate_subresults(tp, tm)
# Because it's in the if body, nothing can be assumed to be constant
tp, tm = _track_local_consts(body, alias_map)
_integrate_subresults({}, tm | tp.keys())
elif isinstance(node, If_Construct):
Expand Down
Loading