From 8d994e65f72c35a6d48225d9aa4940090d60788f Mon Sep 17 00:00:00 2001 From: ajpotts Date: Tue, 17 Sep 2024 08:24:17 -0400 Subject: [PATCH] Closes #3771: register_commands.py to handle generic scalar type (#3772) Co-authored-by: Amanda Potts --- registration-config.json | 10 ++++ src/registry/Commands.chpl | 10 ++++ src/registry/register_commands.py | 90 ++++++++++++++++--------------- 3 files changed, 68 insertions(+), 42 deletions(-) diff --git a/registration-config.json b/registration-config.json index 44afced3bf..44a28371fa 100644 --- a/registration-config.json +++ b/registration-config.json @@ -10,6 +10,16 @@ "bool", "bigint" ] + }, + "scalar": { + "dtype": [ + "int", + "uint", + "uint(8)", + "real", + "bool", + "bigint" + ] } } } diff --git a/src/registry/Commands.chpl b/src/registry/Commands.chpl index b4d63e1498..97a32a8f26 100644 --- a/src/registry/Commands.chpl +++ b/src/registry/Commands.chpl @@ -19,6 +19,16 @@ param regConfig = """ "bool", "bigint" ] + }, + "scalar": { + "dtype": [ + "int", + "uint", + "uint(8)", + "real", + "bool", + "bigint" + ] } } } diff --git a/src/registry/register_commands.py b/src/registry/register_commands.py index 6834d9f3c5..103291bd2c 100644 --- a/src/registry/register_commands.py +++ b/src/registry/register_commands.py @@ -1,7 +1,8 @@ -import chapel -import sys -import json import itertools +import json +import sys + +import chapel DEFAULT_MODS = ["MsgProcessing", "GenSymIO"] @@ -210,6 +211,7 @@ def info_tuple(formal): gen_formals.append(formal_info) else: con_formals.append(formal_info) + return con_formals, gen_formals @@ -220,9 +222,7 @@ def clean_stamp_name(name): return name.translate(str.maketrans("[](),=", "______")) -def stamp_generic_command( - generic_proc_name, prefix, module_name, formals, line_num, is_user_proc -): +def stamp_generic_command(generic_proc_name, prefix, module_name, formals, line_num, is_user_proc): """ Create code to stamp out and register a generic command using a generic procedure, and a set values for its generic formals. @@ -295,9 +295,7 @@ def parse_param_class_value(value): if isinstance(value, list): for v in value: if not isinstance(v, (int, float, str)): - raise ValueError( - f"Invalid parameter value type ({type(v)}) in list '{value}'" - ) + raise ValueError(f"Invalid parameter value type ({type(v)}) in list '{value}'") return value elif isinstance(value, int): return [ @@ -313,9 +311,7 @@ def parse_param_class_value(value): if isinstance(vals, list): return vals else: - raise ValueError( - f"Could not create a list of parameter values from '{value}'" - ) + raise ValueError(f"Could not create a list of parameter values from '{value}'") else: raise ValueError(f"Invalid parameter value type ({type(value)}) for '{value}'") @@ -353,9 +349,7 @@ def generic_permutations(config, gen_formals): + "please check the 'parameter_classes' field in the configuration file" ) - to_permute[formal_name] = parse_param_class_value( - config["parameter_classes"][pclass][pname] - ) + to_permute[formal_name] = parse_param_class_value(config["parameter_classes"][pclass][pname]) return permutations(to_permute) @@ -446,6 +440,28 @@ def unpack_scalar_arg(arg_name, arg_type): return f"\tvar {arg_name} = {ARGS_FORMAL_NAME}['{arg_name}'].toScalar({arg_type});" +def unpack_scalar_arg_with_generic(arg_name, array_count): + """ + Generate the code to unpack a scalar argument + + 'scalar_count' is used to generate unique names when + a procedure has multiple array-symbol formals + + Example: + ``` + var x = msgArgs['x'].toScalar(scalar_dtype_0); + ``` + + Returns the chapel code, and the specifications for the + 'dtype' and type-constructor arguments + """ + dtype_arg_name = "scalar_dtype_" + str(array_count) + return ( + unpack_scalar_arg(arg_name, dtype_arg_name), + [(dtype_arg_name, "type", None, None)], + ) + + def unpack_tuple_arg(arg_name, tuple_size, scalar_type): """ Generate the code to unpack a tuple argument @@ -492,8 +508,7 @@ def gen_signature(user_proc_name, generic_args=None): if generic_args: name = "ark_reg_" + user_proc_name + "_generic" arg_strings = [ - f"{kind} {name}: {ft}" if ft else f"{kind} {name}" - for name, kind, ft, _ in generic_args + f"{kind} {name}: {ft}" if ft else f"{kind} {name}" for name, kind, ft, _ in generic_args ] proc = f"proc {name}(cmd: string, {ARGS_FORMAL_NAME}: {ARGS_FORMAL_TYPE}, {SYMTAB_FORMAL_NAME}: {SYMTAB_FORMAL_TYPE}, {', '.join(arg_strings)}): {RESPONSE_TYPE_NAME} throws {'{'}" else: @@ -511,11 +526,13 @@ def gen_arg_unpacking(formals): unpack_lines = [] generic_args = [] array_arg_counter = 0 + scalar_arg_counter = 0 array_domain_queries = {} array_dtype_queries = {} for fname, fintent, ftype, finfo in formals: + if ftype in chapel_scalar_types: unpack_lines.append(unpack_scalar_arg(fname, ftype)) elif ftype == "": @@ -556,12 +573,14 @@ def gen_arg_unpacking(formals): unpack_lines.append(unpack_tuple_arg(fname, tsize, ttype)) else: if ftype in array_dtype_queries.keys(): - unpack_lines.append( - unpack_scalar_arg(fname, array_dtype_queries[ftype]) - ) + + unpack_lines.append(unpack_scalar_arg(fname, array_dtype_queries[ftype])) else: # TODO: fully handle generic user-defined types - unpack_lines.append(unpack_user_symbol(fname, ftype)) + code, scalar_args = unpack_scalar_arg_with_generic(fname, scalar_arg_counter) + unpack_lines.append(code) + generic_args += scalar_args + scalar_arg_counter += 1 return ("\n".join(unpack_lines), generic_args) @@ -652,14 +671,10 @@ def gen_command_proc(name, return_type, formals, mod_name): arg_unpack, command_formals = gen_arg_unpacking(formals) is_generic_command = len(command_formals) > 0 signature, cmd_name = gen_signature(name, command_formals) - fn_call, result_name = gen_user_function_call( - name, [f[0] for f in formals], mod_name, return_type - ) + fn_call, result_name = gen_user_function_call(name, [f[0] for f in formals], mod_name, return_type) # get the names of the array-elt-type queries in the formals - array_etype_queries = [ - f[3][1] for f in formals if (f[2] == "" and f[3] is not None) - ] + array_etype_queries = [f[3][1] for f in formals if (f[2] == "" and f[3] is not None)] # assume the returned type is a symbol if it's an identifier that is not a scalar or type-query reference # or if it is a `SymEntry` type-constructor call @@ -678,30 +693,22 @@ def gen_command_proc(name, return_type, formals, mod_name): ) ) returns_array = ( - return_type - and isinstance(return_type, chapel.BracketLoop) - and return_type.is_maybe_array_type() + return_type and isinstance(return_type, chapel.BracketLoop) and return_type.is_maybe_array_type() ) if returns_array: - symbol_creation, result_name = gen_symbol_creation( - ARRAY_ENTRY_CLASS_NAME, result_name - ) + symbol_creation, result_name = gen_symbol_creation(ARRAY_ENTRY_CLASS_NAME, result_name) else: symbol_creation = "" response = gen_response(result_name, returns_symbol or returns_array) - command_proc = "\n".join( - [signature, arg_unpack, fn_call, symbol_creation, response, "}"] - ) + command_proc = "\n".join([signature, arg_unpack, fn_call, symbol_creation, response, "}"]) return (command_proc, cmd_name, is_generic_command, command_formals) -def stamp_out_command( - config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc -): +def stamp_out_command(config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc): """ Yield instantiations of a generic command with using the values from the configuration file @@ -723,9 +730,7 @@ def stamp_out_command( formal_perms = generic_permutations(config, formals) for fp in formal_perms: - stamp = stamp_generic_command( - name, cmd_prefix, mod_name, fp, line_num, is_user_proc - ) + stamp = stamp_generic_command(name, cmd_prefix, mod_name, fp, line_num, is_user_proc) yield stamp @@ -782,6 +787,7 @@ def register_commands(config, source_files): (cmd_proc, cmd_name, is_generic_cmd, cmd_gen_formals) = gen_command_proc( name, fn.return_type(), con_formals, mod_name ) + file_stamps.append(cmd_proc) count += 1