From 320cb0df599fb60e5ec5cb05709b109855596554 Mon Sep 17 00:00:00 2001 From: Spencer McIntyre Date: Mon, 3 Jul 2023 16:52:33 -0400 Subject: [PATCH] Add the parse_float builtin function --- lib/rule_engine/_utils.py | 14 ++++++++++++++ lib/rule_engine/ast.py | 9 ++++++--- lib/rule_engine/builtins.py | 4 +++- lib/rule_engine/errors.py | 2 +- lib/rule_engine/parser.py | 23 ++++++----------------- tests/builtins.py | 13 +++++++++++++ 6 files changed, 43 insertions(+), 22 deletions(-) diff --git a/lib/rule_engine/_utils.py b/lib/rule_engine/_utils.py index 1eb0171..0a552f5 100644 --- a/lib/rule_engine/_utils.py +++ b/lib/rule_engine/_utils.py @@ -1,4 +1,6 @@ +import ast as pyast import datetime +import decimal import re from . import errors @@ -26,6 +28,18 @@ def parse_datetime(string, default_timezone): dt = dt.replace(tzinfo=default_timezone) return dt +def parse_float(string): + if re.match('^0[0-9]', string): + raise errors.RuleSyntaxError('invalid floating point literal: ' + string + ' (leading zeros in decimal literals are not permitted)') + try: + if re.match('^0[box]', string): + val = decimal.Decimal(pyast.literal_eval(string)) + else: + val = decimal.Decimal(string) + except Exception: + raise errors.RuleSyntaxError('invalid floating point literal: ' + string) from None + return val + def parse_timedelta(periodstring): if periodstring == "P": raise errors.TimedeltaSyntaxError('empty timedelta string', periodstring) diff --git a/lib/rule_engine/ast.py b/lib/rule_engine/ast.py index 0830f32..48141ad 100644 --- a/lib/rule_engine/ast.py +++ b/lib/rule_engine/ast.py @@ -38,7 +38,7 @@ import re from . import errors -from ._utils import parse_datetime, parse_timedelta +from ._utils import parse_datetime, parse_float, parse_timedelta from .suggestions import suggest_symbol from .types import * @@ -251,8 +251,7 @@ class TimedeltaExpression(LiteralExpressionBase): result_type = DataType.TIMEDELTA @classmethod def from_string(cls, context, string): - dt = parse_timedelta(string) - return cls(context, dt) + return cls(context, parse_timedelta(string)) class FloatExpression(LiteralExpressionBase): """Literal float expressions representing numerical values.""" @@ -261,6 +260,10 @@ def __init__(self, context, value, **kwargs): value = coerce_value(value) super(FloatExpression, self).__init__(context, value, **kwargs) + @classmethod + def from_string(cls, context, string): + return cls(context, parse_float(string)) + class MappingExpression(LiteralExpressionBase): """Literal mapping expression representing a set of associations between keys and values.""" result_type = DataType.MAPPING diff --git a/lib/rule_engine/builtins.py b/lib/rule_engine/builtins.py index dcb4d04..48b20c6 100644 --- a/lib/rule_engine/builtins.py +++ b/lib/rule_engine/builtins.py @@ -38,7 +38,7 @@ import math import random -from ._utils import parse_datetime, parse_timedelta +from ._utils import parse_datetime, parse_float, parse_timedelta from . import ast from . import errors from . import types @@ -161,6 +161,7 @@ def from_defaults(cls, values=None, **kwargs): 'min': min, 'filter': _builtin_filter, 'parse_datetime': BuiltinValueGenerator(lambda builtins: functools.partial(_builtin_parse_datetime, builtins)), + 'parse_float': parse_float, 'parse_timedelta': parse_timedelta, 'random': _builtin_random, 'split': _builtins_split @@ -182,6 +183,7 @@ def from_defaults(cls, values=None, **kwargs): 'min': ast.DataType.FUNCTION('min', return_type=ast.DataType.FLOAT, argument_types=(ast.DataType.ARRAY(ast.DataType.FLOAT),)), 'filter': ast.DataType.FUNCTION('filter', argument_types=(ast.DataType.FUNCTION, ast.DataType.ARRAY)), 'parse_datetime': ast.DataType.FUNCTION('parse_datetime', return_type=ast.DataType.DATETIME, argument_types=(ast.DataType.STRING,)), + 'parse_float': ast.DataType.FUNCTION('parse_float', return_type=ast.DataType.FLOAT, argument_types=(ast.DataType.STRING,)), 'parse_timedelta': ast.DataType.FUNCTION('parse_timedelta', return_type=ast.DataType.TIMEDELTA, argument_types=(ast.DataType.STRING,)), 'random': ast.DataType.FUNCTION('random', return_type=ast.DataType.FLOAT, argument_types=(ast.DataType.FLOAT,), minimum_arguments=0), 'split': ast.DataType.FUNCTION( diff --git a/lib/rule_engine/errors.py b/lib/rule_engine/errors.py index 13561b5..d4a6340 100644 --- a/lib/rule_engine/errors.py +++ b/lib/rule_engine/errors.py @@ -105,7 +105,7 @@ def __init__(self, message, error, value): """The regular expression value which contains the syntax error which caused this exception to be raised.""" class RuleSyntaxError(SyntaxError): - """An error raised for issues identified in while parsing the grammar of the rule text.""" + """An error raised for issues identified while parsing the grammar of the rule text.""" def __init__(self, message, token=None): """ :param str message: A text description of what error occurred. diff --git a/lib/rule_engine/parser.py b/lib/rule_engine/parser.py index dc449dc..8044c53 100644 --- a/lib/rule_engine/parser.py +++ b/lib/rule_engine/parser.py @@ -32,7 +32,6 @@ import ast as pyast import collections -import re import threading import types as pytypes @@ -461,23 +460,13 @@ def p_expression_timedelta(self, p): p[0] = _DeferredAstNode(ast.TimedeltaExpression, args=(self.context, p[1]), method='from_string') def p_expression_float(self, p): - 'expression : FLOAT' + """ + expression : FLOAT + | FLOAT_NAN + | FLOAT_INF + """ str_val = p[1] - if re.match('^0[0-9]', str_val): - raise errors.RuleSyntaxError('invalid floating point literal: ' + str_val + ' (leading zeros in decimal literals are not permitted)') - try: - val = literal_eval(str_val) - except SyntaxError: - raise errors.RuleSyntaxError('invalid floating point literal: ' + str_val) - p[0] = _DeferredAstNode(ast.FloatExpression, args=(self.context, float(val))) - - def p_expression_float_nan(self, p): - 'expression : FLOAT_NAN' - p[0] = _DeferredAstNode(ast.FloatExpression, args=(self.context, float('nan'))) - - def p_expression_float_inf(self, p): - 'expression : FLOAT_INF' - p[0] = _DeferredAstNode(ast.FloatExpression, args=(self.context, float('inf'))) + p[0] = _DeferredAstNode(ast.FloatExpression, args=(self.context, str_val), method='from_string') def p_expression_null(self, p): 'object : NULL' diff --git a/tests/builtins.py b/tests/builtins.py index 2044e1d..d8389af 100644 --- a/tests/builtins.py +++ b/tests/builtins.py @@ -32,6 +32,7 @@ import contextlib import datetime +import decimal import random import string import unittest @@ -158,6 +159,18 @@ def test_engine_builtins_function_parse_datetime(self): with self.assertRaises(errors.DatetimeSyntaxError): self.assertBuiltinFunction('parse_datetime', now, '') + def test_engine_builtins_function_parse_float(self): + self.assertBuiltinFunction('parse_float', 1, '1') + self.assertBuiltinFunction('parse_float', 0b10, '0b10') + self.assertBuiltinFunction('parse_float', 0o10, '0o10') + self.assertBuiltinFunction('parse_float', 0x10, '0x10') + self.assertBuiltinFunction('parse_float', decimal.Decimal('1.1'), '1.1') + self.assertBuiltinFunction('parse_float', 1e1, '1e1') + self.assertBuiltinFunction('parse_float', float('inf'), 'inf') + self.assertBuiltinFunction('parse_float', -1, '-1') + with self.assertRaises(errors.RuleSyntaxError): + self.assertBuiltinFunction('parse_float', 1, 'f00d') + def test_engine_builtins_function_parse_timedelta(self): self.assertBuiltinFunction('parse_timedelta', datetime.timedelta(days=1), 'P1D') with self.assertRaises(errors.TimedeltaSyntaxError):