Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closes #3771: register_commands.py to handle generic scalar type #3772

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions registration-config.json
Copy link
Contributor

@jeremiah-corrado jeremiah-corrado Sep 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the .configs/ files (used for CI) will also need the new "scalar" parameter class

Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@
"bool",
"bigint"
]
},
"scalar": {
"dtype": [
"int",
"uint",
"uint(8)",
"real",
"bool",
"bigint"
]
}
}
}
10 changes: 10 additions & 0 deletions src/registry/Commands.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ param regConfig = """
"bool",
"bigint"
]
},
"scalar": {
"dtype": [
"int",
"uint",
"uint(8)",
"real",
"bool",
"bigint"
]
}
}
}
Expand Down
90 changes: 48 additions & 42 deletions src/registry/register_commands.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import chapel
import sys
import json
import itertools
import json
import sys

import chapel

DEFAULT_MODS = ["MsgProcessing", "GenSymIO"]

Expand Down Expand Up @@ -210,6 +211,7 @@ def info_tuple(formal):
gen_formals.append(formal_info)
else:
con_formals.append(formal_info)

return con_formals, gen_formals


Expand All @@ -220,9 +222,7 @@ def clean_stamp_name(name):
return name.translate(str.maketrans("[](),=", "______"))


def stamp_generic_command(
generic_proc_name, prefix, module_name, formals, line_num, is_user_proc
):
def stamp_generic_command(generic_proc_name, prefix, module_name, formals, line_num, is_user_proc):
"""
Create code to stamp out and register a generic command using a generic
procedure, and a set values for its generic formals.
Expand Down Expand Up @@ -295,9 +295,7 @@ def parse_param_class_value(value):
if isinstance(value, list):
for v in value:
if not isinstance(v, (int, float, str)):
raise ValueError(
f"Invalid parameter value type ({type(v)}) in list '{value}'"
)
raise ValueError(f"Invalid parameter value type ({type(v)}) in list '{value}'")
return value
elif isinstance(value, int):
return [
Expand All @@ -313,9 +311,7 @@ def parse_param_class_value(value):
if isinstance(vals, list):
return vals
else:
raise ValueError(
f"Could not create a list of parameter values from '{value}'"
)
raise ValueError(f"Could not create a list of parameter values from '{value}'")
else:
raise ValueError(f"Invalid parameter value type ({type(value)}) for '{value}'")

Expand Down Expand Up @@ -353,9 +349,7 @@ def generic_permutations(config, gen_formals):
+ "please check the 'parameter_classes' field in the configuration file"
)

to_permute[formal_name] = parse_param_class_value(
config["parameter_classes"][pclass][pname]
)
to_permute[formal_name] = parse_param_class_value(config["parameter_classes"][pclass][pname])

return permutations(to_permute)

Expand Down Expand Up @@ -446,6 +440,28 @@ def unpack_scalar_arg(arg_name, arg_type):
return f"\tvar {arg_name} = {ARGS_FORMAL_NAME}['{arg_name}'].toScalar({arg_type});"


def unpack_scalar_arg_with_generic(arg_name, array_count):
"""
Generate the code to unpack a scalar argument

'scalar_count' is used to generate unique names when
a procedure has multiple array-symbol formals

Comment on lines +443 to +448
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def unpack_scalar_arg_with_generic(arg_name, array_count):
"""
Generate the code to unpack a scalar argument
'scalar_count' is used to generate unique names when
a procedure has multiple array-symbol formals
def unpack_scalar_arg_with_generic(arg_name, scalar_count):
"""
Generate the code to unpack a scalar argument
'scalar_count' is used to generate unique names when
a procedure has multiple scalar-symbol formals

Example:
```
var x = msgArgs['x'].toScalar(scalar_dtype_0);
```

Returns the chapel code, and the specifications for the
'dtype' and type-constructor arguments
"""
dtype_arg_name = "scalar_dtype_" + str(array_count)
return (
unpack_scalar_arg(arg_name, dtype_arg_name),
[(dtype_arg_name, "type", None, None)],
)


def unpack_tuple_arg(arg_name, tuple_size, scalar_type):
"""
Generate the code to unpack a tuple argument
Expand Down Expand Up @@ -492,8 +508,7 @@ def gen_signature(user_proc_name, generic_args=None):
if generic_args:
name = "ark_reg_" + user_proc_name + "_generic"
arg_strings = [
f"{kind} {name}: {ft}" if ft else f"{kind} {name}"
for name, kind, ft, _ in generic_args
f"{kind} {name}: {ft}" if ft else f"{kind} {name}" for name, kind, ft, _ in generic_args
]
proc = f"proc {name}(cmd: string, {ARGS_FORMAL_NAME}: {ARGS_FORMAL_TYPE}, {SYMTAB_FORMAL_NAME}: {SYMTAB_FORMAL_TYPE}, {', '.join(arg_strings)}): {RESPONSE_TYPE_NAME} throws {'{'}"
else:
Expand All @@ -511,11 +526,13 @@ def gen_arg_unpacking(formals):
unpack_lines = []
generic_args = []
array_arg_counter = 0
scalar_arg_counter = 0

array_domain_queries = {}
array_dtype_queries = {}

for fname, fintent, ftype, finfo in formals:

if ftype in chapel_scalar_types:
unpack_lines.append(unpack_scalar_arg(fname, ftype))
elif ftype == "<array>":
Expand Down Expand Up @@ -556,12 +573,14 @@ def gen_arg_unpacking(formals):
unpack_lines.append(unpack_tuple_arg(fname, tsize, ttype))
else:
if ftype in array_dtype_queries.keys():
unpack_lines.append(
unpack_scalar_arg(fname, array_dtype_queries[ftype])
)

unpack_lines.append(unpack_scalar_arg(fname, array_dtype_queries[ftype]))
else:
# TODO: fully handle generic user-defined types
unpack_lines.append(unpack_user_symbol(fname, ftype))
code, scalar_args = unpack_scalar_arg_with_generic(fname, scalar_arg_counter)
unpack_lines.append(code)
generic_args += scalar_args
scalar_arg_counter += 1

return ("\n".join(unpack_lines), generic_args)

Expand Down Expand Up @@ -652,14 +671,10 @@ def gen_command_proc(name, return_type, formals, mod_name):
arg_unpack, command_formals = gen_arg_unpacking(formals)
is_generic_command = len(command_formals) > 0
signature, cmd_name = gen_signature(name, command_formals)
fn_call, result_name = gen_user_function_call(
name, [f[0] for f in formals], mod_name, return_type
)
fn_call, result_name = gen_user_function_call(name, [f[0] for f in formals], mod_name, return_type)

# get the names of the array-elt-type queries in the formals
array_etype_queries = [
f[3][1] for f in formals if (f[2] == "<array>" and f[3] is not None)
]
array_etype_queries = [f[3][1] for f in formals if (f[2] == "<array>" and f[3] is not None)]

# assume the returned type is a symbol if it's an identifier that is not a scalar or type-query reference
# or if it is a `SymEntry` type-constructor call
Expand All @@ -678,30 +693,22 @@ def gen_command_proc(name, return_type, formals, mod_name):
)
)
returns_array = (
return_type
and isinstance(return_type, chapel.BracketLoop)
and return_type.is_maybe_array_type()
return_type and isinstance(return_type, chapel.BracketLoop) and return_type.is_maybe_array_type()
)

if returns_array:
symbol_creation, result_name = gen_symbol_creation(
ARRAY_ENTRY_CLASS_NAME, result_name
)
symbol_creation, result_name = gen_symbol_creation(ARRAY_ENTRY_CLASS_NAME, result_name)
else:
symbol_creation = ""

response = gen_response(result_name, returns_symbol or returns_array)

command_proc = "\n".join(
[signature, arg_unpack, fn_call, symbol_creation, response, "}"]
)
command_proc = "\n".join([signature, arg_unpack, fn_call, symbol_creation, response, "}"])

return (command_proc, cmd_name, is_generic_command, command_formals)


def stamp_out_command(
config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc
):
def stamp_out_command(config, formals, name, cmd_prefix, mod_name, line_num, is_user_proc):
"""
Yield instantiations of a generic command with using the
values from the configuration file
Expand All @@ -723,9 +730,7 @@ def stamp_out_command(
formal_perms = generic_permutations(config, formals)

for fp in formal_perms:
stamp = stamp_generic_command(
name, cmd_prefix, mod_name, fp, line_num, is_user_proc
)
stamp = stamp_generic_command(name, cmd_prefix, mod_name, fp, line_num, is_user_proc)
yield stamp


Expand Down Expand Up @@ -782,6 +787,7 @@ def register_commands(config, source_files):
(cmd_proc, cmd_name, is_generic_cmd, cmd_gen_formals) = gen_command_proc(
name, fn.return_type(), con_formals, mod_name
)

file_stamps.append(cmd_proc)
count += 1

Expand Down
Loading