From 8dae0c58622b59ab27be9d184c8f6d4d7d96b2cb Mon Sep 17 00:00:00 2001 From: jeremiah-corrado <62707311+jeremiah-corrado@users.noreply.github.com> Date: Fri, 20 Sep 2024 09:25:57 -0600 Subject: [PATCH] Improved domain and type query support for `registerCommand` (#3786) * add better support for interpeting domain and type queries with 'registerCommand' annotation Signed-off-by: Jeremiah Corrado * remove TODO about domain/type query support Signed-off-by: Jeremiah Corrado --------- Signed-off-by: Jeremiah Corrado --- src/registry/register_commands.py | 171 +++++++++++++++++++++++------- 1 file changed, 135 insertions(+), 36 deletions(-) diff --git a/src/registry/register_commands.py b/src/registry/register_commands.py index e80b84013..a2328730b 100644 --- a/src/registry/register_commands.py +++ b/src/registry/register_commands.py @@ -97,6 +97,30 @@ def extract_ast_block_text(node): return f.read(int(cstop) - int(cstart)).strip() +class FormalQuery: + def __init__(self, name): + self.name = name + + def name(self): + return self.name + + +class FormalQueryRef: + def __init__(self, name): + self.name = name + + def name(self): + return self.name + + +class StaticType: + def __init__(self, name): + self.name = name + + def name(self): + return self.name + + def get_formals(fn, require_type_annotations): """ Get a function's formal parameters, separating them into concrete and @@ -115,29 +139,28 @@ def info_tuple(formal): ten = "" extra_info = [None, None] - # TODO: support referencing a domain from another array's domain query - # (e.g., 'proc foo(x: [?d], y: [d])') - # to avoid instantiating all combinations of array ranks for the 2+ arrays - # record domain query name if any if isinstance(te.iterand(), chapel.TypeQuery): - extra_info[0] = te.iterand().name() + extra_info[0] = FormalQuery(te.iterand().name()) + else: + block_text = extract_ast_block_text(te.iterand()) + extra_info[0] = FormalQueryRef(block_text) # record element type query if any if isinstance(te.body(), chapel.Block): # hack: chapel-py doesn't have a way to get the text/body of a block (I think) block_text = extract_ast_block_text(te.body()) - # TODO: support referencing a type from another array's dtype query - # (e.g., 'proc foo(x: [] ?t, y: [] t)') - # to avoid instantiating all combinations of array element types - # for the 2+ arrays - if block_text[0] == "?": - extra_info[1] = block_text[1:] + extra_info[1] = FormalQuery(block_text[1:]) + elif block_text in chapel_scalar_types.keys(): + extra_info[1] = StaticType(block_text) + else: + extra_info[1] = FormalQueryRef(block_text) extra_info = tuple(extra_info) else: + # TODO: `x: []` and `x: [?d]` are currently treated as invalid formal type expressions raise ValueError("invalid formal type expression") elif isinstance(te, chapel.FnCall): if ce := te.called_expression(): @@ -222,7 +245,9 @@ def clean_stamp_name(name): return name.translate(str.maketrans("[](),=", "______")) -def stamp_generic_command(generic_proc_name, prefix, module_name, formals, line_num, is_user_proc): +def stamp_generic_command( + generic_proc_name, prefix, module_name, formals, line_num, is_user_proc +): """ Create code to stamp out and register a generic command using a generic procedure, and a set values for its generic formals. @@ -295,7 +320,9 @@ def parse_param_class_value(value): if isinstance(value, list): for v in value: if not isinstance(v, (int, float, str)): - raise ValueError(f"Invalid parameter value type ({type(v)}) in list '{value}'") + raise ValueError( + f"Invalid parameter value type ({type(v)}) in list '{value}'" + ) return value elif isinstance(value, int): return [ @@ -311,7 +338,9 @@ def parse_param_class_value(value): if isinstance(vals, list): return vals else: - raise ValueError(f"Could not create a list of parameter values from '{value}'") + raise ValueError( + f"Could not create a list of parameter values from '{value}'" + ) else: raise ValueError(f"Invalid parameter value type ({type(value)}) for '{value}'") @@ -349,7 +378,9 @@ def generic_permutations(config, gen_formals): + "please check the 'parameter_classes' field in the configuration file" ) - to_permute[formal_name] = parse_param_class_value(config["parameter_classes"][pclass][pname]) + to_permute[formal_name] = parse_param_class_value( + config["parameter_classes"][pclass][pname] + ) return permutations(to_permute) @@ -388,7 +419,7 @@ def valid_generic_command_signature(fn, con_formals): # TODO: use var/const depending on user proc's formal intent -def unpack_array_arg(arg_name, array_count): +def unpack_array_arg(arg_name, array_count, finfo, domain_queries, dtype_queries): """ Generate the code to unpack an array symbol from the symbol table @@ -404,12 +435,42 @@ def unpack_array_arg(arg_name, array_count): Returns the chapel code, and the specifications for the 'etype' and 'dimensions' type-constructor arguments """ - dtype_arg_name = "array_dtype_" + str(array_count) - nd_arg_name = "array_nd_" + str(array_count) + + # check if the nd formal is a domain query + if ( + finfo is not None + and finfo[0] is not None + and isinstance(finfo[0], FormalQueryRef) + and finfo[0].name in domain_queries + ): + nd_arg_name = domain_queries[finfo[0].name] + nd_generic_formal_info = None + else: + nd_arg_name = "array_nd_" + str(array_count) + nd_generic_formal_info = (nd_arg_name, "param", "int", None) + + # check if the array formal has a static type or a type-query + # if not, generate a unique name and formal info for the dtype argument + if finfo is not None and finfo[1] is not None and isinstance(finfo[1], StaticType): + dtype_arg_name = finfo[1].name + dtype_generic_formal_info = None + elif ( + finfo is not None + and finfo[1] is not None + and isinstance(finfo[1], FormalQueryRef) + and finfo[1].name in dtype_queries + ): + dtype_arg_name = dtype_queries[finfo[1].name] + dtype_generic_formal_info = None + else: + dtype_arg_name = "array_dtype_" + str(array_count) + dtype_generic_formal_info = (dtype_arg_name, "type", None, None) + return ( f"\tvar {arg_name}_array_sym = {SYMTAB_FORMAL_NAME}[{ARGS_FORMAL_NAME}['{arg_name}']]: {ARRAY_ENTRY_CLASS_NAME}({dtype_arg_name}, {nd_arg_name});\n" + f"\tref {arg_name} = {arg_name}_array_sym.a;", - [(dtype_arg_name, "type", None, None), (nd_arg_name, "param", "int", None)], + dtype_generic_formal_info, + nd_generic_formal_info, ) @@ -508,7 +569,8 @@ def gen_signature(user_proc_name, generic_args=None): if generic_args: name = "ark_reg_" + user_proc_name + "_generic" arg_strings = [ - f"{kind} {name}: {ft}" if ft else f"{kind} {name}" for name, kind, ft, _ in generic_args + f"{kind} {name}: {ft}" if ft else f"{kind} {name}" + for name, kind, ft, _ in generic_args ] proc = f"proc {name}(cmd: string, {ARGS_FORMAL_NAME}: {ARGS_FORMAL_TYPE}, {SYMTAB_FORMAL_NAME}: {SYMTAB_FORMAL_TYPE}, {', '.join(arg_strings)}): {RESPONSE_TYPE_NAME} throws {'{'}" else: @@ -536,19 +598,37 @@ def gen_arg_unpacking(formals): if ftype in chapel_scalar_types: unpack_lines.append(unpack_scalar_arg(fname, ftype)) elif ftype == "": - code, array_args = unpack_array_arg(fname, array_arg_counter) + code, gen_dtype_arg, gen_nd_arg = unpack_array_arg( + fname, + array_arg_counter, + finfo, + array_domain_queries, + array_dtype_queries, + ) unpack_lines.append(code) - generic_args += array_args + if gen_dtype_arg is not None: + generic_args.append(gen_dtype_arg) + if gen_nd_arg is not None: + generic_args.append(gen_nd_arg) array_arg_counter += 1 - # when an array formal type has a domain query (e.g., '[?d]'), keep track of + # When an array formal type has a domain query (e.g., '[?d]'), keep track of # the array's generic rank argument under the domain query's name (e.g., 'd'). # this allows homogeneous-tuple formal types to use the array's rank as a size argument + # Do the same for dtype queries if finfo is not None: - if finfo[0] is not None: - array_domain_queries[finfo[0]] = array_args[1][0] - if finfo[1] is not None: - array_dtype_queries[finfo[1]] = array_args[0][0] + if ( + finfo[0] is not None + and isinstance(finfo[0], FormalQuery) + and gen_nd_arg is not None + ): + array_domain_queries[finfo[0].name] = gen_nd_arg[0] + if ( + finfo[1] is not None + and isinstance(finfo[1], FormalQuery) + and gen_dtype_arg is not None + ): + array_dtype_queries[finfo[1].name] = gen_dtype_arg[0] elif "list" in ftype: unpack_lines.append(unpack_list_arg(fname, ftype)) @@ -572,12 +652,17 @@ def gen_arg_unpacking(formals): unpack_lines.append(unpack_tuple_arg(fname, tsize, ttype)) else: + # a scalar formal with a generic type if ftype in array_dtype_queries.keys(): - unpack_lines.append(unpack_scalar_arg(fname, array_dtype_queries[ftype])) + unpack_lines.append( + unpack_scalar_arg(fname, array_dtype_queries[ftype]) + ) else: # TODO: fully handle generic user-defined types - code, scalar_args = unpack_scalar_arg_with_generic(fname, scalar_arg_counter) + code, scalar_args = unpack_scalar_arg_with_generic( + fname, scalar_arg_counter + ) unpack_lines.append(code) generic_args += scalar_args scalar_arg_counter += 1 @@ -671,10 +756,14 @@ def gen_command_proc(name, return_type, formals, mod_name): arg_unpack, command_formals = gen_arg_unpacking(formals) is_generic_command = len(command_formals) > 0 signature, cmd_name = gen_signature(name, command_formals) - fn_call, result_name = gen_user_function_call(name, [f[0] for f in formals], mod_name, return_type) + fn_call, result_name = gen_user_function_call( + name, [f[0] for f in formals], mod_name, return_type + ) # get the names of the array-elt-type queries in the formals - array_etype_queries = [f[3][1] for f in formals if (f[2] == "" and f[3] is not None)] + array_etype_queries = [ + f[3][1].name for f in formals if (f[2] == "" and f[3] is not None) + ] # assume the returned type is a symbol if it's an identifier that is not a scalar or type-query reference # or if it is a `SymEntry` type-constructor call @@ -693,22 +782,30 @@ def gen_command_proc(name, return_type, formals, mod_name): ) ) returns_array = ( - return_type and isinstance(return_type, chapel.BracketLoop) and return_type.is_maybe_array_type() + return_type + and isinstance(return_type, chapel.BracketLoop) + and return_type.is_maybe_array_type() ) if returns_array: - symbol_creation, result_name = gen_symbol_creation(ARRAY_ENTRY_CLASS_NAME, result_name) + symbol_creation, result_name = gen_symbol_creation( + ARRAY_ENTRY_CLASS_NAME, result_name + ) else: symbol_creation = "" response = gen_response(result_name, returns_symbol or returns_array) - command_proc = "\n".join([signature, arg_unpack, fn_call, symbol_creation, response, "}"]) + command_proc = "\n".join( + [signature, arg_unpack, fn_call, symbol_creation, response, "}"] + ) return (command_proc, cmd_name, is_generic_command, command_formals) -def stamp_out_command(config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc): +def stamp_out_command( + config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc +): """ Yield instantiations of a generic command with using the values from the configuration file @@ -730,7 +827,9 @@ def stamp_out_command(config, formals, name, cmd_prefix, mod_name, line_num, is_ formal_perms = generic_permutations(config, formals) for fp in formal_perms: - stamp = stamp_generic_command(name, cmd_prefix, mod_name, fp, line_num, is_user_proc) + stamp = stamp_generic_command( + name, cmd_prefix, mod_name, fp, line_num, is_user_proc + ) yield stamp