From fa3afd7594b3a5b20a05866064dd3455281be9b8 Mon Sep 17 00:00:00 2001 From: Riddhi Singh <139997447+hi-riddhi@users.noreply.github.com> Date: Thu, 24 Jul 2025 00:29:38 +0530 Subject: [PATCH 1/3] Update test_primitives.py --- .../entry_point_tests/test_primitives.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/featuretools/tests/entry_point_tests/test_primitives.py b/featuretools/tests/entry_point_tests/test_primitives.py index e75cece144..34fec40ce5 100644 --- a/featuretools/tests/entry_point_tests/test_primitives.py +++ b/featuretools/tests/entry_point_tests/test_primitives.py @@ -1,3 +1,7 @@ +import pytest +import featuretools as ft +from featuretools.primitives import MultiplyNumericScalar, Year + from featuretools.tests.entry_point_tests.utils import ( _import_featuretools, _install_featuretools_primitives, @@ -21,3 +25,22 @@ def test_entry_point(): existing_primitive += 'ignored primitive "Sum" from "featuretools_primitives.existing_primitive" because a primitive ' existing_primitive += 'with that name already exists in "featuretools.primitives.standard.aggregation.sum_primitive"' assert existing_primitive in featuretools_log + + +primitive_test_data = [ + (MultiplyNumericScalar(2), [1, 2.5, -3], float), + (Year(), ["2020-01-01", "2019-12-31"], int) +] + +@pytest.mark.parametrize("primitive, values, expected_dtype", primitive_test_data) +def test_primitive_input_types(primitive, values, expected_dtype): + es = ft.demo.load_retail(nrows=5) + fm, features = ft.dfs( + entityset=es, + target_dataframe_name="orders", + trans_primitives=[primitive], + max_depth=1, + ) + col = features[-1].name + dtype = fm[col].dtype + assert expected_dtype in str(dtype), f"{primitive} returned dtype {dtype}" From f442a14bfd1addcfe30a2b40061e67f15f18707d Mon Sep 17 00:00:00 2001 From: Riddhi Singh <139997447+hi-riddhi@users.noreply.github.com> Date: Thu, 24 Jul 2025 00:30:22 +0530 Subject: [PATCH 2/3] Update test_primitives.py --- .../tests/entry_point_tests/test_primitives.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/featuretools/tests/entry_point_tests/test_primitives.py b/featuretools/tests/entry_point_tests/test_primitives.py index 34fec40ce5..a6d7c434ab 100644 --- a/featuretools/tests/entry_point_tests/test_primitives.py +++ b/featuretools/tests/entry_point_tests/test_primitives.py @@ -1,6 +1,4 @@ -import pytest -import featuretools as ft -from featuretools.primitives import MultiplyNumericScalar, Year + from featuretools.tests.entry_point_tests.utils import ( _import_featuretools, @@ -32,15 +30,3 @@ def test_entry_point(): (Year(), ["2020-01-01", "2019-12-31"], int) ] -@pytest.mark.parametrize("primitive, values, expected_dtype", primitive_test_data) -def test_primitive_input_types(primitive, values, expected_dtype): - es = ft.demo.load_retail(nrows=5) - fm, features = ft.dfs( - entityset=es, - target_dataframe_name="orders", - trans_primitives=[primitive], - max_depth=1, - ) - col = features[-1].name - dtype = fm[col].dtype - assert expected_dtype in str(dtype), f"{primitive} returned dtype {dtype}" From 3854a5e5a9e23105237d02c5781a53291713025b Mon Sep 17 00:00:00 2001 From: riddhi Date: Thu, 24 Jul 2025 00:58:05 +0530 Subject: [PATCH 3/3] TST: add test to confirm primitive input_types (#2086) --- .../test_primitive_input_types.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 featuretools/tests/primitive_tests/test_primitive_input_types.py diff --git a/featuretools/tests/primitive_tests/test_primitive_input_types.py b/featuretools/tests/primitive_tests/test_primitive_input_types.py new file mode 100644 index 0000000000..76f4dff6dc --- /dev/null +++ b/featuretools/tests/primitive_tests/test_primitive_input_types.py @@ -0,0 +1,19 @@ +import pytest +import featuretools as ft +from featuretools.primitives import MultiplyNumericScalar, Year + +@pytest.mark.parametrize("primitive, expected_dtype", [ + (MultiplyNumericScalar(2), "float"), + (Year(), "int") +]) +def test_primitive_input_types(primitive, expected_dtype): + es = ft.demo.load_retail(nrows=10) + fm, features = ft.dfs( + entityset=es, + target_dataframe_name="orders", + trans_primitives=[primitive], + max_depth=1, + ) + col = features[-1].get_name() + dtype = fm[col].dtype + assert expected_dtype in str(dtype), f"{primitive} returned dtype {dtype}"