Skip to content

Commit

Permalink
Improved domain and type query support for registerCommand (#3786)
Browse files Browse the repository at this point in the history
* add better support for interpeting domain and type queries with 'registerCommand' annotation

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

* remove TODO about domain/type query support

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>

---------

Signed-off-by: Jeremiah Corrado <jeremiah.corrado@hpe.com>
  • Loading branch information
jeremiah-corrado authored Sep 20, 2024
1 parent 1d241aa commit 8dae0c5
Showing 1 changed file with 135 additions and 36 deletions.
171 changes: 135 additions & 36 deletions src/registry/register_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,30 @@ def extract_ast_block_text(node):
return f.read(int(cstop) - int(cstart)).strip()


class FormalQuery:
def __init__(self, name):
self.name = name

def name(self):
return self.name


class FormalQueryRef:
def __init__(self, name):
self.name = name

def name(self):
return self.name


class StaticType:
def __init__(self, name):
self.name = name

def name(self):
return self.name


def get_formals(fn, require_type_annotations):
"""
Get a function's formal parameters, separating them into concrete and
Expand All @@ -115,29 +139,28 @@ def info_tuple(formal):
ten = "<array>"
extra_info = [None, None]

# TODO: support referencing a domain from another array's domain query
# (e.g., 'proc foo(x: [?d], y: [d])')
# to avoid instantiating all combinations of array ranks for the 2+ arrays

# record domain query name if any
if isinstance(te.iterand(), chapel.TypeQuery):
extra_info[0] = te.iterand().name()
extra_info[0] = FormalQuery(te.iterand().name())
else:
block_text = extract_ast_block_text(te.iterand())
extra_info[0] = FormalQueryRef(block_text)

# record element type query if any
if isinstance(te.body(), chapel.Block):
# hack: chapel-py doesn't have a way to get the text/body of a block (I think)
block_text = extract_ast_block_text(te.body())

# TODO: support referencing a type from another array's dtype query
# (e.g., 'proc foo(x: [] ?t, y: [] t)')
# to avoid instantiating all combinations of array element types
# for the 2+ arrays

if block_text[0] == "?":
extra_info[1] = block_text[1:]
extra_info[1] = FormalQuery(block_text[1:])
elif block_text in chapel_scalar_types.keys():
extra_info[1] = StaticType(block_text)
else:
extra_info[1] = FormalQueryRef(block_text)

extra_info = tuple(extra_info)
else:
# TODO: `x: []` and `x: [?d]` are currently treated as invalid formal type expressions
raise ValueError("invalid formal type expression")
elif isinstance(te, chapel.FnCall):
if ce := te.called_expression():
Expand Down Expand Up @@ -222,7 +245,9 @@ def clean_stamp_name(name):
return name.translate(str.maketrans("[](),=", "______"))


def stamp_generic_command(generic_proc_name, prefix, module_name, formals, line_num, is_user_proc):
def stamp_generic_command(
generic_proc_name, prefix, module_name, formals, line_num, is_user_proc
):
"""
Create code to stamp out and register a generic command using a generic
procedure, and a set values for its generic formals.
Expand Down Expand Up @@ -295,7 +320,9 @@ def parse_param_class_value(value):
if isinstance(value, list):
for v in value:
if not isinstance(v, (int, float, str)):
raise ValueError(f"Invalid parameter value type ({type(v)}) in list '{value}'")
raise ValueError(
f"Invalid parameter value type ({type(v)}) in list '{value}'"
)
return value
elif isinstance(value, int):
return [
Expand All @@ -311,7 +338,9 @@ def parse_param_class_value(value):
if isinstance(vals, list):
return vals
else:
raise ValueError(f"Could not create a list of parameter values from '{value}'")
raise ValueError(
f"Could not create a list of parameter values from '{value}'"
)
else:
raise ValueError(f"Invalid parameter value type ({type(value)}) for '{value}'")

Expand Down Expand Up @@ -349,7 +378,9 @@ def generic_permutations(config, gen_formals):
+ "please check the 'parameter_classes' field in the configuration file"
)

to_permute[formal_name] = parse_param_class_value(config["parameter_classes"][pclass][pname])
to_permute[formal_name] = parse_param_class_value(
config["parameter_classes"][pclass][pname]
)

return permutations(to_permute)

Expand Down Expand Up @@ -388,7 +419,7 @@ def valid_generic_command_signature(fn, con_formals):


# TODO: use var/const depending on user proc's formal intent
def unpack_array_arg(arg_name, array_count):
def unpack_array_arg(arg_name, array_count, finfo, domain_queries, dtype_queries):
"""
Generate the code to unpack an array symbol from the symbol table
Expand All @@ -404,12 +435,42 @@ def unpack_array_arg(arg_name, array_count):
Returns the chapel code, and the specifications for the
'etype' and 'dimensions' type-constructor arguments
"""
dtype_arg_name = "array_dtype_" + str(array_count)
nd_arg_name = "array_nd_" + str(array_count)

# check if the nd formal is a domain query
if (
finfo is not None
and finfo[0] is not None
and isinstance(finfo[0], FormalQueryRef)
and finfo[0].name in domain_queries
):
nd_arg_name = domain_queries[finfo[0].name]
nd_generic_formal_info = None
else:
nd_arg_name = "array_nd_" + str(array_count)
nd_generic_formal_info = (nd_arg_name, "param", "int", None)

# check if the array formal has a static type or a type-query
# if not, generate a unique name and formal info for the dtype argument
if finfo is not None and finfo[1] is not None and isinstance(finfo[1], StaticType):
dtype_arg_name = finfo[1].name
dtype_generic_formal_info = None
elif (
finfo is not None
and finfo[1] is not None
and isinstance(finfo[1], FormalQueryRef)
and finfo[1].name in dtype_queries
):
dtype_arg_name = dtype_queries[finfo[1].name]
dtype_generic_formal_info = None
else:
dtype_arg_name = "array_dtype_" + str(array_count)
dtype_generic_formal_info = (dtype_arg_name, "type", None, None)

return (
f"\tvar {arg_name}_array_sym = {SYMTAB_FORMAL_NAME}[{ARGS_FORMAL_NAME}['{arg_name}']]: {ARRAY_ENTRY_CLASS_NAME}({dtype_arg_name}, {nd_arg_name});\n"
+ f"\tref {arg_name} = {arg_name}_array_sym.a;",
[(dtype_arg_name, "type", None, None), (nd_arg_name, "param", "int", None)],
dtype_generic_formal_info,
nd_generic_formal_info,
)


Expand Down Expand Up @@ -508,7 +569,8 @@ def gen_signature(user_proc_name, generic_args=None):
if generic_args:
name = "ark_reg_" + user_proc_name + "_generic"
arg_strings = [
f"{kind} {name}: {ft}" if ft else f"{kind} {name}" for name, kind, ft, _ in generic_args
f"{kind} {name}: {ft}" if ft else f"{kind} {name}"
for name, kind, ft, _ in generic_args
]
proc = f"proc {name}(cmd: string, {ARGS_FORMAL_NAME}: {ARGS_FORMAL_TYPE}, {SYMTAB_FORMAL_NAME}: {SYMTAB_FORMAL_TYPE}, {', '.join(arg_strings)}): {RESPONSE_TYPE_NAME} throws {'{'}"
else:
Expand Down Expand Up @@ -536,19 +598,37 @@ def gen_arg_unpacking(formals):
if ftype in chapel_scalar_types:
unpack_lines.append(unpack_scalar_arg(fname, ftype))
elif ftype == "<array>":
code, array_args = unpack_array_arg(fname, array_arg_counter)
code, gen_dtype_arg, gen_nd_arg = unpack_array_arg(
fname,
array_arg_counter,
finfo,
array_domain_queries,
array_dtype_queries,
)
unpack_lines.append(code)
generic_args += array_args
if gen_dtype_arg is not None:
generic_args.append(gen_dtype_arg)
if gen_nd_arg is not None:
generic_args.append(gen_nd_arg)
array_arg_counter += 1

# when an array formal type has a domain query (e.g., '[?d]'), keep track of
# When an array formal type has a domain query (e.g., '[?d]'), keep track of
# the array's generic rank argument under the domain query's name (e.g., 'd').
# this allows homogeneous-tuple formal types to use the array's rank as a size argument
# Do the same for dtype queries
if finfo is not None:
if finfo[0] is not None:
array_domain_queries[finfo[0]] = array_args[1][0]
if finfo[1] is not None:
array_dtype_queries[finfo[1]] = array_args[0][0]
if (
finfo[0] is not None
and isinstance(finfo[0], FormalQuery)
and gen_nd_arg is not None
):
array_domain_queries[finfo[0].name] = gen_nd_arg[0]
if (
finfo[1] is not None
and isinstance(finfo[1], FormalQuery)
and gen_dtype_arg is not None
):
array_dtype_queries[finfo[1].name] = gen_dtype_arg[0]

elif "list" in ftype:
unpack_lines.append(unpack_list_arg(fname, ftype))
Expand All @@ -572,12 +652,17 @@ def gen_arg_unpacking(formals):

unpack_lines.append(unpack_tuple_arg(fname, tsize, ttype))
else:
# a scalar formal with a generic type
if ftype in array_dtype_queries.keys():

unpack_lines.append(unpack_scalar_arg(fname, array_dtype_queries[ftype]))
unpack_lines.append(
unpack_scalar_arg(fname, array_dtype_queries[ftype])
)
else:
# TODO: fully handle generic user-defined types
code, scalar_args = unpack_scalar_arg_with_generic(fname, scalar_arg_counter)
code, scalar_args = unpack_scalar_arg_with_generic(
fname, scalar_arg_counter
)
unpack_lines.append(code)
generic_args += scalar_args
scalar_arg_counter += 1
Expand Down Expand Up @@ -671,10 +756,14 @@ def gen_command_proc(name, return_type, formals, mod_name):
arg_unpack, command_formals = gen_arg_unpacking(formals)
is_generic_command = len(command_formals) > 0
signature, cmd_name = gen_signature(name, command_formals)
fn_call, result_name = gen_user_function_call(name, [f[0] for f in formals], mod_name, return_type)
fn_call, result_name = gen_user_function_call(
name, [f[0] for f in formals], mod_name, return_type
)

# get the names of the array-elt-type queries in the formals
array_etype_queries = [f[3][1] for f in formals if (f[2] == "<array>" and f[3] is not None)]
array_etype_queries = [
f[3][1].name for f in formals if (f[2] == "<array>" and f[3] is not None)
]

# assume the returned type is a symbol if it's an identifier that is not a scalar or type-query reference
# or if it is a `SymEntry` type-constructor call
Expand All @@ -693,22 +782,30 @@ def gen_command_proc(name, return_type, formals, mod_name):
)
)
returns_array = (
return_type and isinstance(return_type, chapel.BracketLoop) and return_type.is_maybe_array_type()
return_type
and isinstance(return_type, chapel.BracketLoop)
and return_type.is_maybe_array_type()
)

if returns_array:
symbol_creation, result_name = gen_symbol_creation(ARRAY_ENTRY_CLASS_NAME, result_name)
symbol_creation, result_name = gen_symbol_creation(
ARRAY_ENTRY_CLASS_NAME, result_name
)
else:
symbol_creation = ""

response = gen_response(result_name, returns_symbol or returns_array)

command_proc = "\n".join([signature, arg_unpack, fn_call, symbol_creation, response, "}"])
command_proc = "\n".join(
[signature, arg_unpack, fn_call, symbol_creation, response, "}"]
)

return (command_proc, cmd_name, is_generic_command, command_formals)


def stamp_out_command(config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc):
def stamp_out_command(
config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc
):
"""
Yield instantiations of a generic command with using the
values from the configuration file
Expand All @@ -730,7 +827,9 @@ def stamp_out_command(config, formals, name, cmd_prefix, mod_name, line_num, is_
formal_perms = generic_permutations(config, formals)

for fp in formal_perms:
stamp = stamp_generic_command(name, cmd_prefix, mod_name, fp, line_num, is_user_proc)
stamp = stamp_generic_command(
name, cmd_prefix, mod_name, fp, line_num, is_user_proc
)
yield stamp


Expand Down

0 comments on commit 8dae0c5

Please sign in to comment.