Skip to content

Commit

Permalink
Fix type checking for partial function definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroSteiner committed Jul 15, 2023
1 parent 71190d2 commit 7b34573
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 15 deletions.
31 changes: 19 additions & 12 deletions lib/rule_engine/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,21 +1066,28 @@ def evaluate(self, thing):
def _validate_function(self, function_type, arguments):
if not isinstance(function_type, DataType.FUNCTION.__class__):
raise errors.EvaluationError('data type mismatch (not a callable value)')
if len(arguments) < function_type.minimum_arguments:
raise errors.FunctionCallError(
"missing {} required positional arguments".format(function_type.minimum_arguments - len(arguments)),
function_name=function_type.value_name
)
for pos, (arg1, arg2_type) in enumerate(zip(arguments, function_type.argument_types), 1):
if isinstance(arg1, LiteralExpressionBase):
arg1_type = arg1.result_type
else:
arg1_type = DataType.from_value(arg1)
if not DataType.is_compatible(arg1_type, arg2_type):
if function_type.minimum_arguments is not DataType.UNDEFINED:
if len(arguments) < function_type.minimum_arguments:
raise errors.FunctionCallError(
"data type mismatch (argument #{})".format(pos),
"expected at least {} positional arguments".format(function_type.minimum_arguments),
function_name=function_type.value_name
)
if function_type.argument_types is not DataType.UNDEFINED:
if len(arguments) > len(function_type.argument_types):
raise errors.FunctionCallError(
"expected at most {} positional arguments".format(len(function_type.argument_types)),
function_name=function_type.value_name
)
for pos, (arg1, arg2_type) in enumerate(zip(arguments, function_type.argument_types), 1):
if isinstance(arg1, LiteralExpressionBase):
arg1_type = arg1.result_type
else:
arg1_type = DataType.from_value(arg1)
if not DataType.is_compatible(arg1_type, arg2_type):
raise errors.FunctionCallError(
"data type mismatch (argument #{})".format(pos),
function_name=function_type.value_name
)

class Statement(ASTNodeBase):
"""A class representing the top level statement of the grammar text."""
Expand Down
25 changes: 22 additions & 3 deletions tests/ast/expression/function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_ast_expression_function_call_error_on_uncallable_value(self):
with self.assertRaises(errors.EvaluationError):
self.assertTrue(function_call.evaluate({'function': True}))

def test_ast_expression_function_call_error_on_missing_arguments(self):
def test_ast_expression_function_call_error_on_to_few_arguments(self):
context = engine.Context(
type_resolver=engine.type_resolver_from_dict({
'function': ast.DataType.FUNCTION(
Expand All @@ -101,11 +101,30 @@ def test_ast_expression_function_call_error_on_missing_arguments(self):
})
)
symbol = ast.SymbolExpression(context, 'function')
function_call = ast.FunctionCallExpression(context, symbol, [ast.FloatExpression(context, 1)])

# function is missing arguments
with self.assertRaises(errors.FunctionCallError):
function_call = ast.FunctionCallExpression(context, symbol, [])
ast.FunctionCallExpression(context, symbol, [])

def test_ast_expression_function_call_error_on_to_many_arguments(self):
context = engine.Context(
type_resolver=engine.type_resolver_from_dict({
'function': ast.DataType.FUNCTION(
'function',
return_type=ast.DataType.FLOAT,
argument_types=(ast.DataType.FLOAT,),
minimum_arguments=1
)
})
)
symbol = ast.SymbolExpression(context, 'function')

# function is missing arguments
with self.assertRaises(errors.FunctionCallError):
ast.FunctionCallExpression(context, symbol, [
ast.FloatExpression(context, 1),
ast.FloatExpression(context, 1)
])

def test_ast_expression_function_call_error_on_exception(self):
symbol = ast.SymbolExpression(context, 'function')
Expand Down

0 comments on commit 7b34573

Please sign in to comment.