Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved domain and type query support for registerCommand #3786

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 135 additions & 36 deletions src/registry/register_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,30 @@ def extract_ast_block_text(node):
return f.read(int(cstop) - int(cstart)).strip()


class FormalQuery:
def __init__(self, name):
self.name = name

def name(self):
return self.name


class FormalQueryRef:
def __init__(self, name):
self.name = name

def name(self):
return self.name


class StaticType:
def __init__(self, name):
self.name = name

def name(self):
return self.name


def get_formals(fn, require_type_annotations):
"""
Get a function's formal parameters, separating them into concrete and
Expand All @@ -115,29 +139,28 @@ def info_tuple(formal):
ten = "<array>"
extra_info = [None, None]

# TODO: support referencing a domain from another array's domain query
# (e.g., 'proc foo(x: [?d], y: [d])')
# to avoid instantiating all combinations of array ranks for the 2+ arrays

# record domain query name if any
if isinstance(te.iterand(), chapel.TypeQuery):
extra_info[0] = te.iterand().name()
extra_info[0] = FormalQuery(te.iterand().name())
else:
block_text = extract_ast_block_text(te.iterand())
extra_info[0] = FormalQueryRef(block_text)

# record element type query if any
if isinstance(te.body(), chapel.Block):
# hack: chapel-py doesn't have a way to get the text/body of a block (I think)
block_text = extract_ast_block_text(te.body())

# TODO: support referencing a type from another array's dtype query
# (e.g., 'proc foo(x: [] ?t, y: [] t)')
# to avoid instantiating all combinations of array element types
# for the 2+ arrays

if block_text[0] == "?":
extra_info[1] = block_text[1:]
extra_info[1] = FormalQuery(block_text[1:])
elif block_text in chapel_scalar_types.keys():
extra_info[1] = StaticType(block_text)
else:
extra_info[1] = FormalQueryRef(block_text)

extra_info = tuple(extra_info)
else:
# TODO: `x: []` and `x: [?d]` are currently treated as invalid formal type expressions
raise ValueError("invalid formal type expression")
elif isinstance(te, chapel.FnCall):
if ce := te.called_expression():
Expand Down Expand Up @@ -222,7 +245,9 @@ def clean_stamp_name(name):
return name.translate(str.maketrans("[](),=", "______"))


def stamp_generic_command(generic_proc_name, prefix, module_name, formals, line_num, is_user_proc):
def stamp_generic_command(
generic_proc_name, prefix, module_name, formals, line_num, is_user_proc
):
"""
Create code to stamp out and register a generic command using a generic
procedure, and a set values for its generic formals.
Expand Down Expand Up @@ -295,7 +320,9 @@ def parse_param_class_value(value):
if isinstance(value, list):
for v in value:
if not isinstance(v, (int, float, str)):
raise ValueError(f"Invalid parameter value type ({type(v)}) in list '{value}'")
raise ValueError(
f"Invalid parameter value type ({type(v)}) in list '{value}'"
)
return value
elif isinstance(value, int):
return [
Expand All @@ -311,7 +338,9 @@ def parse_param_class_value(value):
if isinstance(vals, list):
return vals
else:
raise ValueError(f"Could not create a list of parameter values from '{value}'")
raise ValueError(
f"Could not create a list of parameter values from '{value}'"
)
else:
raise ValueError(f"Invalid parameter value type ({type(value)}) for '{value}'")

Expand Down Expand Up @@ -349,7 +378,9 @@ def generic_permutations(config, gen_formals):
+ "please check the 'parameter_classes' field in the configuration file"
)

to_permute[formal_name] = parse_param_class_value(config["parameter_classes"][pclass][pname])
to_permute[formal_name] = parse_param_class_value(
config["parameter_classes"][pclass][pname]
)

return permutations(to_permute)

Expand Down Expand Up @@ -388,7 +419,7 @@ def valid_generic_command_signature(fn, con_formals):


# TODO: use var/const depending on user proc's formal intent
def unpack_array_arg(arg_name, array_count):
def unpack_array_arg(arg_name, array_count, finfo, domain_queries, dtype_queries):
"""
Generate the code to unpack an array symbol from the symbol table

Expand All @@ -404,12 +435,42 @@ def unpack_array_arg(arg_name, array_count):
Returns the chapel code, and the specifications for the
'etype' and 'dimensions' type-constructor arguments
"""
dtype_arg_name = "array_dtype_" + str(array_count)
nd_arg_name = "array_nd_" + str(array_count)

# check if the nd formal is a domain query
if (
finfo is not None
and finfo[0] is not None
and isinstance(finfo[0], FormalQueryRef)
and finfo[0].name in domain_queries
):
nd_arg_name = domain_queries[finfo[0].name]
nd_generic_formal_info = None
else:
nd_arg_name = "array_nd_" + str(array_count)
nd_generic_formal_info = (nd_arg_name, "param", "int", None)

# check if the array formal has a static type or a type-query
# if not, generate a unique name and formal info for the dtype argument
if finfo is not None and finfo[1] is not None and isinstance(finfo[1], StaticType):
dtype_arg_name = finfo[1].name
dtype_generic_formal_info = None
elif (
finfo is not None
and finfo[1] is not None
and isinstance(finfo[1], FormalQueryRef)
and finfo[1].name in dtype_queries
):
dtype_arg_name = dtype_queries[finfo[1].name]
dtype_generic_formal_info = None
else:
dtype_arg_name = "array_dtype_" + str(array_count)
dtype_generic_formal_info = (dtype_arg_name, "type", None, None)

return (
f"\tvar {arg_name}_array_sym = {SYMTAB_FORMAL_NAME}[{ARGS_FORMAL_NAME}['{arg_name}']]: {ARRAY_ENTRY_CLASS_NAME}({dtype_arg_name}, {nd_arg_name});\n"
+ f"\tref {arg_name} = {arg_name}_array_sym.a;",
[(dtype_arg_name, "type", None, None), (nd_arg_name, "param", "int", None)],
dtype_generic_formal_info,
nd_generic_formal_info,
)


Expand Down Expand Up @@ -508,7 +569,8 @@ def gen_signature(user_proc_name, generic_args=None):
if generic_args:
name = "ark_reg_" + user_proc_name + "_generic"
arg_strings = [
f"{kind} {name}: {ft}" if ft else f"{kind} {name}" for name, kind, ft, _ in generic_args
f"{kind} {name}: {ft}" if ft else f"{kind} {name}"
for name, kind, ft, _ in generic_args
]
proc = f"proc {name}(cmd: string, {ARGS_FORMAL_NAME}: {ARGS_FORMAL_TYPE}, {SYMTAB_FORMAL_NAME}: {SYMTAB_FORMAL_TYPE}, {', '.join(arg_strings)}): {RESPONSE_TYPE_NAME} throws {'{'}"
else:
Expand Down Expand Up @@ -536,19 +598,37 @@ def gen_arg_unpacking(formals):
if ftype in chapel_scalar_types:
unpack_lines.append(unpack_scalar_arg(fname, ftype))
elif ftype == "<array>":
code, array_args = unpack_array_arg(fname, array_arg_counter)
code, gen_dtype_arg, gen_nd_arg = unpack_array_arg(
fname,
array_arg_counter,
finfo,
array_domain_queries,
array_dtype_queries,
)
unpack_lines.append(code)
generic_args += array_args
if gen_dtype_arg is not None:
generic_args.append(gen_dtype_arg)
if gen_nd_arg is not None:
generic_args.append(gen_nd_arg)
array_arg_counter += 1

# when an array formal type has a domain query (e.g., '[?d]'), keep track of
# When an array formal type has a domain query (e.g., '[?d]'), keep track of
# the array's generic rank argument under the domain query's name (e.g., 'd').
# this allows homogeneous-tuple formal types to use the array's rank as a size argument
# Do the same for dtype queries
if finfo is not None:
if finfo[0] is not None:
array_domain_queries[finfo[0]] = array_args[1][0]
if finfo[1] is not None:
array_dtype_queries[finfo[1]] = array_args[0][0]
if (
finfo[0] is not None
and isinstance(finfo[0], FormalQuery)
and gen_nd_arg is not None
):
array_domain_queries[finfo[0].name] = gen_nd_arg[0]
if (
finfo[1] is not None
and isinstance(finfo[1], FormalQuery)
and gen_dtype_arg is not None
):
array_dtype_queries[finfo[1].name] = gen_dtype_arg[0]

elif "list" in ftype:
unpack_lines.append(unpack_list_arg(fname, ftype))
Expand All @@ -572,12 +652,17 @@ def gen_arg_unpacking(formals):

unpack_lines.append(unpack_tuple_arg(fname, tsize, ttype))
else:
# a scalar formal with a generic type
if ftype in array_dtype_queries.keys():

unpack_lines.append(unpack_scalar_arg(fname, array_dtype_queries[ftype]))
unpack_lines.append(
unpack_scalar_arg(fname, array_dtype_queries[ftype])
)
else:
# TODO: fully handle generic user-defined types
code, scalar_args = unpack_scalar_arg_with_generic(fname, scalar_arg_counter)
code, scalar_args = unpack_scalar_arg_with_generic(
fname, scalar_arg_counter
)
unpack_lines.append(code)
generic_args += scalar_args
scalar_arg_counter += 1
Expand Down Expand Up @@ -671,10 +756,14 @@ def gen_command_proc(name, return_type, formals, mod_name):
arg_unpack, command_formals = gen_arg_unpacking(formals)
is_generic_command = len(command_formals) > 0
signature, cmd_name = gen_signature(name, command_formals)
fn_call, result_name = gen_user_function_call(name, [f[0] for f in formals], mod_name, return_type)
fn_call, result_name = gen_user_function_call(
name, [f[0] for f in formals], mod_name, return_type
)

# get the names of the array-elt-type queries in the formals
array_etype_queries = [f[3][1] for f in formals if (f[2] == "<array>" and f[3] is not None)]
array_etype_queries = [
f[3][1].name for f in formals if (f[2] == "<array>" and f[3] is not None)
]

# assume the returned type is a symbol if it's an identifier that is not a scalar or type-query reference
# or if it is a `SymEntry` type-constructor call
Expand All @@ -693,22 +782,30 @@ def gen_command_proc(name, return_type, formals, mod_name):
)
)
returns_array = (
return_type and isinstance(return_type, chapel.BracketLoop) and return_type.is_maybe_array_type()
return_type
and isinstance(return_type, chapel.BracketLoop)
and return_type.is_maybe_array_type()
)

if returns_array:
symbol_creation, result_name = gen_symbol_creation(ARRAY_ENTRY_CLASS_NAME, result_name)
symbol_creation, result_name = gen_symbol_creation(
ARRAY_ENTRY_CLASS_NAME, result_name
)
else:
symbol_creation = ""

response = gen_response(result_name, returns_symbol or returns_array)

command_proc = "\n".join([signature, arg_unpack, fn_call, symbol_creation, response, "}"])
command_proc = "\n".join(
[signature, arg_unpack, fn_call, symbol_creation, response, "}"]
)

return (command_proc, cmd_name, is_generic_command, command_formals)


def stamp_out_command(config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc):
def stamp_out_command(
config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc
):
"""
Yield instantiations of a generic command with using the
values from the configuration file
Expand All @@ -730,7 +827,9 @@ def stamp_out_command(config, formals, name, cmd_prefix, mod_name, line_num, is_
formal_perms = generic_permutations(config, formals)

for fp in formal_perms:
stamp = stamp_generic_command(name, cmd_prefix, mod_name, fp, line_num, is_user_proc)
stamp = stamp_generic_command(
name, cmd_prefix, mod_name, fp, line_num, is_user_proc
)
yield stamp


Expand Down
Loading