From 7b34573d24823d4a872ac0506c5ffaa22b9aefdc Mon Sep 17 00:00:00 2001 From: Spencer McIntyre Date: Sat, 15 Jul 2023 12:10:54 -0400 Subject: [PATCH] Fix type checking for partial function definitions --- lib/rule_engine/ast.py | 31 ++++++++++++++++----------- tests/ast/expression/function_call.py | 25 ++++++++++++++++++--- 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/lib/rule_engine/ast.py b/lib/rule_engine/ast.py index 48141ad..356f94f 100644 --- a/lib/rule_engine/ast.py +++ b/lib/rule_engine/ast.py @@ -1066,21 +1066,28 @@ def evaluate(self, thing): def _validate_function(self, function_type, arguments): if not isinstance(function_type, DataType.FUNCTION.__class__): raise errors.EvaluationError('data type mismatch (not a callable value)') - if len(arguments) < function_type.minimum_arguments: - raise errors.FunctionCallError( - "missing {} required positional arguments".format(function_type.minimum_arguments - len(arguments)), - function_name=function_type.value_name - ) - for pos, (arg1, arg2_type) in enumerate(zip(arguments, function_type.argument_types), 1): - if isinstance(arg1, LiteralExpressionBase): - arg1_type = arg1.result_type - else: - arg1_type = DataType.from_value(arg1) - if not DataType.is_compatible(arg1_type, arg2_type): + if function_type.minimum_arguments is not DataType.UNDEFINED: + if len(arguments) < function_type.minimum_arguments: raise errors.FunctionCallError( - "data type mismatch (argument #{})".format(pos), + "expected at least {} positional arguments".format(function_type.minimum_arguments), function_name=function_type.value_name ) + if function_type.argument_types is not DataType.UNDEFINED: + if len(arguments) > len(function_type.argument_types): + raise errors.FunctionCallError( + "expected at most {} positional arguments".format(len(function_type.argument_types)), + function_name=function_type.value_name + ) + for pos, (arg1, arg2_type) in enumerate(zip(arguments, function_type.argument_types), 1): + if isinstance(arg1, LiteralExpressionBase): + arg1_type = arg1.result_type + else: + arg1_type = DataType.from_value(arg1) + if not DataType.is_compatible(arg1_type, arg2_type): + raise errors.FunctionCallError( + "data type mismatch (argument #{})".format(pos), + function_name=function_type.value_name + ) class Statement(ASTNodeBase): """A class representing the top level statement of the grammar text.""" diff --git a/tests/ast/expression/function_call.py b/tests/ast/expression/function_call.py index df26341..d5cff85 100644 --- a/tests/ast/expression/function_call.py +++ b/tests/ast/expression/function_call.py @@ -89,7 +89,7 @@ def test_ast_expression_function_call_error_on_uncallable_value(self): with self.assertRaises(errors.EvaluationError): self.assertTrue(function_call.evaluate({'function': True})) - def test_ast_expression_function_call_error_on_missing_arguments(self): + def test_ast_expression_function_call_error_on_to_few_arguments(self): context = engine.Context( type_resolver=engine.type_resolver_from_dict({ 'function': ast.DataType.FUNCTION( @@ -101,11 +101,30 @@ def test_ast_expression_function_call_error_on_missing_arguments(self): }) ) symbol = ast.SymbolExpression(context, 'function') - function_call = ast.FunctionCallExpression(context, symbol, [ast.FloatExpression(context, 1)]) # function is missing arguments with self.assertRaises(errors.FunctionCallError): - function_call = ast.FunctionCallExpression(context, symbol, []) + ast.FunctionCallExpression(context, symbol, []) + + def test_ast_expression_function_call_error_on_to_many_arguments(self): + context = engine.Context( + type_resolver=engine.type_resolver_from_dict({ + 'function': ast.DataType.FUNCTION( + 'function', + return_type=ast.DataType.FLOAT, + argument_types=(ast.DataType.FLOAT,), + minimum_arguments=1 + ) + }) + ) + symbol = ast.SymbolExpression(context, 'function') + + # function is missing arguments + with self.assertRaises(errors.FunctionCallError): + ast.FunctionCallExpression(context, symbol, [ + ast.FloatExpression(context, 1), + ast.FloatExpression(context, 1) + ]) def test_ast_expression_function_call_error_on_exception(self): symbol = ast.SymbolExpression(context, 'function')