diff --git a/featuretools/tests/entry_point_tests/test_primitives.py b/featuretools/tests/entry_point_tests/test_primitives.py index e75cece144..a6d7c434ab 100644 --- a/featuretools/tests/entry_point_tests/test_primitives.py +++ b/featuretools/tests/entry_point_tests/test_primitives.py @@ -1,3 +1,5 @@ + + from featuretools.tests.entry_point_tests.utils import ( _import_featuretools, _install_featuretools_primitives, @@ -21,3 +23,10 @@ def test_entry_point(): existing_primitive += 'ignored primitive "Sum" from "featuretools_primitives.existing_primitive" because a primitive ' existing_primitive += 'with that name already exists in "featuretools.primitives.standard.aggregation.sum_primitive"' assert existing_primitive in featuretools_log + + +primitive_test_data = [ + (MultiplyNumericScalar(2), [1, 2.5, -3], float), + (Year(), ["2020-01-01", "2019-12-31"], int) +] + diff --git a/featuretools/tests/primitive_tests/test_primitive_input_types.py b/featuretools/tests/primitive_tests/test_primitive_input_types.py new file mode 100644 index 0000000000..76f4dff6dc --- /dev/null +++ b/featuretools/tests/primitive_tests/test_primitive_input_types.py @@ -0,0 +1,19 @@ +import pytest +import featuretools as ft +from featuretools.primitives import MultiplyNumericScalar, Year + +@pytest.mark.parametrize("primitive, expected_dtype", [ + (MultiplyNumericScalar(2), "float"), + (Year(), "int") +]) +def test_primitive_input_types(primitive, expected_dtype): + es = ft.demo.load_retail(nrows=10) + fm, features = ft.dfs( + entityset=es, + target_dataframe_name="orders", + trans_primitives=[primitive], + max_depth=1, + ) + col = features[-1].get_name() + dtype = fm[col].dtype + assert expected_dtype in str(dtype), f"{primitive} returned dtype {dtype}"