diff --git a/doubleml/data/panel_data.py b/doubleml/data/panel_data.py index 1fac7cac..82a53952 100644 --- a/doubleml/data/panel_data.py +++ b/doubleml/data/panel_data.py @@ -394,7 +394,11 @@ def _set_time_var(self): if hasattr(self, "_data") and self.t_col in self.data.columns: t_values = self.data.loc[:, self.t_col] expected_dtypes = (np.integer, np.floating, np.datetime64) - if not any(np.issubdtype(t_values.dtype, dt) for dt in expected_dtypes): + try: + valid_type = any(np.issubdtype(t_values.dtype, dt) for dt in expected_dtypes) + except TypeError: + valid_type = False + if not valid_type: raise ValueError(f"Invalid data type for time variable: expected one of {expected_dtypes}.") else: self._t = t_values diff --git a/doubleml/did/utils/_plot.py b/doubleml/did/utils/_plot.py index 9a3b3aab..546f77eb 100644 --- a/doubleml/did/utils/_plot.py +++ b/doubleml/did/utils/_plot.py @@ -25,7 +25,10 @@ def add_jitter(data, x_col, is_datetime=None, jitter_value=None): is_datetime = pd.api.types.is_datetime64_any_dtype(data[x_col]) # Initialize jittered_x with original values - data["jittered_x"] = data[x_col] + if is_datetime: + data["jittered_x"] = data[x_col] + else: + data["jittered_x"] = data[x_col].astype(float) for x_val in data[x_col].unique(): mask = data[x_col] == x_val diff --git a/doubleml/did/utils/tests/test_add_jitter.py b/doubleml/did/utils/tests/test_add_jitter.py index c66cb8bd..92386af7 100644 --- a/doubleml/did/utils/tests/test_add_jitter.py +++ b/doubleml/did/utils/tests/test_add_jitter.py @@ -1,5 +1,6 @@ from datetime import datetime, timedelta +import numpy as np import pandas as pd import pytest @@ -41,7 +42,7 @@ def test_add_jitter_numeric_no_duplicates(numeric_df_no_duplicates): """Test that no jitter is added when there are no duplicates.""" result = add_jitter(numeric_df_no_duplicates, "x") # No jitter should be added when there are no duplicates - pd.testing.assert_series_equal(result["jittered_x"], result["x"], check_names=False) + np.testing.assert_allclose(result["jittered_x"], result["x"]) @pytest.mark.ci @@ -121,7 +122,7 @@ def test_add_jitter_explicit_datetime_flag(): df = pd.DataFrame({"x": ["2023-01-01", "2023-01-01", "2023-01-02"], "y": [10, 15, 20]}) # Without specifying is_datetime, it would treat as strings - with pytest.raises(TypeError): + with pytest.raises(ValueError): _ = add_jitter(df, "x") # With is_datetime=True, it should convert and jitter as datetimes diff --git a/doubleml/irm/tests/test_apo_exceptions.py b/doubleml/irm/tests/test_apo_exceptions.py index f428de6b..8800cf86 100644 --- a/doubleml/irm/tests/test_apo_exceptions.py +++ b/doubleml/irm/tests/test_apo_exceptions.py @@ -25,7 +25,7 @@ def test_apo_exception_data(): msg = ( r"The data must be of DoubleMLData or DoubleMLClusterData or DoubleMLDIDData or DoubleMLSSMData or " r"DoubleMLRDDData type\. Empty DataFrame\nColumns: \[\]\nIndex: \[\] of type " - r" was passed\." + r" was passed\." ) with pytest.raises(TypeError, match=msg): _ = DoubleMLAPO(pd.DataFrame(), ml_g, ml_m, treatment_level=0) diff --git a/doubleml/irm/tests/test_ssm_exceptions.py b/doubleml/irm/tests/test_ssm_exceptions.py index 039ed921..1aab50b7 100644 --- a/doubleml/irm/tests/test_ssm_exceptions.py +++ b/doubleml/irm/tests/test_ssm_exceptions.py @@ -35,7 +35,7 @@ def test_ssm_exception_data(): msg = ( r"The data must be of DoubleMLData or DoubleMLClusterData or DoubleMLDIDData or DoubleMLSSMData or " r"DoubleMLRDDData type\. Empty DataFrame\nColumns: \[\]\nIndex: \[\] of type " - r" was passed\." + r" was passed\." ) with pytest.raises(TypeError, match=msg): _ = DoubleMLSSM(pd.DataFrame(), ml_g, ml_pi, ml_m) diff --git a/doubleml/plm/tests/test_lplr_exceptions.py b/doubleml/plm/tests/test_lplr_exceptions.py index c58d7aa0..7a49d043 100644 --- a/doubleml/plm/tests/test_lplr_exceptions.py +++ b/doubleml/plm/tests/test_lplr_exceptions.py @@ -23,7 +23,7 @@ @pytest.mark.ci def test_lplr_exception_data(): - msg = r"The data must be of DoubleMLData.* type\.[\s\S]* of type " r" was passed\." + msg = r"The data must be of DoubleMLData.*type\." with pytest.raises(TypeError, match=msg): _ = DoubleMLLPLR(pd.DataFrame(), ml_M, ml_t, ml_m) diff --git a/doubleml/plm/tests/test_plpr_exceptions.py b/doubleml/plm/tests/test_plpr_exceptions.py index 9bf7697c..b3b291a5 100644 --- a/doubleml/plm/tests/test_plpr_exceptions.py +++ b/doubleml/plm/tests/test_plpr_exceptions.py @@ -61,7 +61,7 @@ @pytest.mark.ci def test_plpr_exception_data(): - msg = "The data must be of DoubleMLPanelData type. was passed." + msg = r"The data must be of DoubleMLPanelData type. was passed." with pytest.raises(TypeError, match=msg): _ = dml.DoubleMLPLPR(pd.DataFrame(), ml_l, ml_m) # not a panel data object diff --git a/doubleml/utils/tests/test_policytree.py b/doubleml/utils/tests/test_policytree.py index 28c2ab7c..a44f3429 100644 --- a/doubleml/utils/tests/test_policytree.py +++ b/doubleml/utils/tests/test_policytree.py @@ -98,8 +98,8 @@ def test_doubleml_exception_policytree(): with pytest.raises(TypeError, match=msg): dml_policytree_predict.predict(features=1) msg = ( - r"The features must have the keys Index\(\[\'a\', \'b\', \'c\'\], dtype\=\'object\'\). " - r"Features with keys Index\(\[\'d\'\], dtype=\'object\'\) were passed." + r"The features must have the keys Index\(\[\'a\', \'b\', \'c\'\], dtype=.*?\)\. " + r"Features with keys Index\(\[\'d\'\], dtype=.*?\) were passed\." ) with pytest.raises(KeyError, match=msg): dml_policytree_predict.predict(features=pd.DataFrame({"d": [3, 4]}))