From c8f4cce5be7871627b7e2627b377727434348c0b Mon Sep 17 00:00:00 2001 From: SvenKlaassen Date: Wed, 28 Jan 2026 14:46:34 +0100 Subject: [PATCH 1/5] fix error msg tests for pandas dataframes --- doubleml/irm/tests/test_apo_exceptions.py | 2 +- doubleml/irm/tests/test_ssm_exceptions.py | 2 +- doubleml/plm/tests/test_lplr_exceptions.py | 2 +- doubleml/plm/tests/test_plpr_exceptions.py | 2 +- doubleml/utils/tests/test_policytree.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/doubleml/irm/tests/test_apo_exceptions.py b/doubleml/irm/tests/test_apo_exceptions.py index f428de6b2..619308131 100644 --- a/doubleml/irm/tests/test_apo_exceptions.py +++ b/doubleml/irm/tests/test_apo_exceptions.py @@ -25,7 +25,7 @@ def test_apo_exception_data(): msg = ( r"The data must be of DoubleMLData or DoubleMLClusterData or DoubleMLDIDData or DoubleMLSSMData or " r"DoubleMLRDDData type\. Empty DataFrame\nColumns: \[\]\nIndex: \[\] of type " - r" was passed\." + r" was passed\." ) with pytest.raises(TypeError, match=msg): _ = DoubleMLAPO(pd.DataFrame(), ml_g, ml_m, treatment_level=0) diff --git a/doubleml/irm/tests/test_ssm_exceptions.py b/doubleml/irm/tests/test_ssm_exceptions.py index 039ed9219..8409b60b0 100644 --- a/doubleml/irm/tests/test_ssm_exceptions.py +++ b/doubleml/irm/tests/test_ssm_exceptions.py @@ -35,7 +35,7 @@ def test_ssm_exception_data(): msg = ( r"The data must be of DoubleMLData or DoubleMLClusterData or DoubleMLDIDData or DoubleMLSSMData or " r"DoubleMLRDDData type\. Empty DataFrame\nColumns: \[\]\nIndex: \[\] of type " - r" was passed\." + r" was passed\." ) with pytest.raises(TypeError, match=msg): _ = DoubleMLSSM(pd.DataFrame(), ml_g, ml_pi, ml_m) diff --git a/doubleml/plm/tests/test_lplr_exceptions.py b/doubleml/plm/tests/test_lplr_exceptions.py index c58d7aa02..7a49d0438 100644 --- a/doubleml/plm/tests/test_lplr_exceptions.py +++ b/doubleml/plm/tests/test_lplr_exceptions.py @@ -23,7 +23,7 @@ @pytest.mark.ci def test_lplr_exception_data(): - msg = r"The data must be of DoubleMLData.* type\.[\s\S]* of type " r" was passed\." + msg = r"The data must be of DoubleMLData.*type\." with pytest.raises(TypeError, match=msg): _ = DoubleMLLPLR(pd.DataFrame(), ml_M, ml_t, ml_m) diff --git a/doubleml/plm/tests/test_plpr_exceptions.py b/doubleml/plm/tests/test_plpr_exceptions.py index 9bf7697c7..ca4dbe387 100644 --- a/doubleml/plm/tests/test_plpr_exceptions.py +++ b/doubleml/plm/tests/test_plpr_exceptions.py @@ -61,7 +61,7 @@ @pytest.mark.ci def test_plpr_exception_data(): - msg = "The data must be of DoubleMLPanelData type. was passed." + msg = "The data must be of DoubleMLPanelData type. was passed." with pytest.raises(TypeError, match=msg): _ = dml.DoubleMLPLPR(pd.DataFrame(), ml_l, ml_m) # not a panel data object diff --git a/doubleml/utils/tests/test_policytree.py b/doubleml/utils/tests/test_policytree.py index 28c2ab7c2..055a87536 100644 --- a/doubleml/utils/tests/test_policytree.py +++ b/doubleml/utils/tests/test_policytree.py @@ -98,8 +98,8 @@ def test_doubleml_exception_policytree(): with pytest.raises(TypeError, match=msg): dml_policytree_predict.predict(features=1) msg = ( - r"The features must have the keys Index\(\[\'a\', \'b\', \'c\'\], dtype\=\'object\'\). " - r"Features with keys Index\(\[\'d\'\], dtype=\'object\'\) were passed." + r"The features must have the keys Index\(\[\'a\', \'b\', \'c\'\], dtype\=\'str\'\). " + r"Features with keys Index\(\[\'d\'\], dtype=\'str\'\) were passed." ) with pytest.raises(KeyError, match=msg): dml_policytree_predict.predict(features=pd.DataFrame({"d": [3, 4]})) From 6a859dddf3a649a655135fdb6a1efa0772f18920 Mon Sep 17 00:00:00 2001 From: SvenKlaassen Date: Wed, 28 Jan 2026 15:14:53 +0100 Subject: [PATCH 2/5] catch other datatype errors for time variable --- doubleml/data/panel_data.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/doubleml/data/panel_data.py b/doubleml/data/panel_data.py index 1fac7caca..82a539526 100644 --- a/doubleml/data/panel_data.py +++ b/doubleml/data/panel_data.py @@ -394,7 +394,11 @@ def _set_time_var(self): if hasattr(self, "_data") and self.t_col in self.data.columns: t_values = self.data.loc[:, self.t_col] expected_dtypes = (np.integer, np.floating, np.datetime64) - if not any(np.issubdtype(t_values.dtype, dt) for dt in expected_dtypes): + try: + valid_type = any(np.issubdtype(t_values.dtype, dt) for dt in expected_dtypes) + except TypeError: + valid_type = False + if not valid_type: raise ValueError(f"Invalid data type for time variable: expected one of {expected_dtypes}.") else: self._t = t_values From 4b318f638a747ab5306f7879b44291e659f67a0a Mon Sep 17 00:00:00 2001 From: SvenKlaassen Date: Wed, 28 Jan 2026 15:15:01 +0100 Subject: [PATCH 3/5] Fix: Update jitter initialization logic and improve test assertions for add_jitter function --- doubleml/did/utils/_plot.py | 5 ++++- doubleml/did/utils/tests/test_add_jitter.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/doubleml/did/utils/_plot.py b/doubleml/did/utils/_plot.py index 9a3b3aab4..546f77eb3 100644 --- a/doubleml/did/utils/_plot.py +++ b/doubleml/did/utils/_plot.py @@ -25,7 +25,10 @@ def add_jitter(data, x_col, is_datetime=None, jitter_value=None): is_datetime = pd.api.types.is_datetime64_any_dtype(data[x_col]) # Initialize jittered_x with original values - data["jittered_x"] = data[x_col] + if is_datetime: + data["jittered_x"] = data[x_col] + else: + data["jittered_x"] = data[x_col].astype(float) for x_val in data[x_col].unique(): mask = data[x_col] == x_val diff --git a/doubleml/did/utils/tests/test_add_jitter.py b/doubleml/did/utils/tests/test_add_jitter.py index c66cb8bdd..715064d01 100644 --- a/doubleml/did/utils/tests/test_add_jitter.py +++ b/doubleml/did/utils/tests/test_add_jitter.py @@ -1,5 +1,6 @@ from datetime import datetime, timedelta +import numpy as np import pandas as pd import pytest @@ -41,8 +42,7 @@ def test_add_jitter_numeric_no_duplicates(numeric_df_no_duplicates): """Test that no jitter is added when there are no duplicates.""" result = add_jitter(numeric_df_no_duplicates, "x") # No jitter should be added when there are no duplicates - pd.testing.assert_series_equal(result["jittered_x"], result["x"], check_names=False) - + np.testing.assert_allclose(result["jittered_x"], result["x"]) @pytest.mark.ci def test_add_jitter_numeric_with_duplicates(numeric_df_with_duplicates): @@ -121,7 +121,7 @@ def test_add_jitter_explicit_datetime_flag(): df = pd.DataFrame({"x": ["2023-01-01", "2023-01-01", "2023-01-02"], "y": [10, 15, 20]}) # Without specifying is_datetime, it would treat as strings - with pytest.raises(TypeError): + with pytest.raises(ValueError): _ = add_jitter(df, "x") # With is_datetime=True, it should convert and jitter as datetimes From 2f3883e2bc473569609c08a9bc82ad8c33450e06 Mon Sep 17 00:00:00 2001 From: SvenKlaassen Date: Wed, 28 Jan 2026 15:20:46 +0100 Subject: [PATCH 4/5] formatting --- doubleml/did/utils/tests/test_add_jitter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/doubleml/did/utils/tests/test_add_jitter.py b/doubleml/did/utils/tests/test_add_jitter.py index 715064d01..92386af79 100644 --- a/doubleml/did/utils/tests/test_add_jitter.py +++ b/doubleml/did/utils/tests/test_add_jitter.py @@ -44,6 +44,7 @@ def test_add_jitter_numeric_no_duplicates(numeric_df_no_duplicates): # No jitter should be added when there are no duplicates np.testing.assert_allclose(result["jittered_x"], result["x"]) + @pytest.mark.ci def test_add_jitter_numeric_with_duplicates(numeric_df_with_duplicates): """Test that jitter is added correctly to numeric values with duplicates.""" From 4a5ea0d2217e882e603b80692044ce7ca11fef22 Mon Sep 17 00:00:00 2001 From: SvenKlaassen Date: Thu, 29 Jan 2026 10:59:11 +0100 Subject: [PATCH 5/5] Fix: Update error messages in exception tests to allow for varying pandas DataFrame representations --- doubleml/irm/tests/test_apo_exceptions.py | 2 +- doubleml/irm/tests/test_ssm_exceptions.py | 2 +- doubleml/plm/tests/test_plpr_exceptions.py | 2 +- doubleml/utils/tests/test_policytree.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/doubleml/irm/tests/test_apo_exceptions.py b/doubleml/irm/tests/test_apo_exceptions.py index 619308131..8800cf86b 100644 --- a/doubleml/irm/tests/test_apo_exceptions.py +++ b/doubleml/irm/tests/test_apo_exceptions.py @@ -25,7 +25,7 @@ def test_apo_exception_data(): msg = ( r"The data must be of DoubleMLData or DoubleMLClusterData or DoubleMLDIDData or DoubleMLSSMData or " r"DoubleMLRDDData type\. Empty DataFrame\nColumns: \[\]\nIndex: \[\] of type " - r" was passed\." + r" was passed\." ) with pytest.raises(TypeError, match=msg): _ = DoubleMLAPO(pd.DataFrame(), ml_g, ml_m, treatment_level=0) diff --git a/doubleml/irm/tests/test_ssm_exceptions.py b/doubleml/irm/tests/test_ssm_exceptions.py index 8409b60b0..1aab50b74 100644 --- a/doubleml/irm/tests/test_ssm_exceptions.py +++ b/doubleml/irm/tests/test_ssm_exceptions.py @@ -35,7 +35,7 @@ def test_ssm_exception_data(): msg = ( r"The data must be of DoubleMLData or DoubleMLClusterData or DoubleMLDIDData or DoubleMLSSMData or " r"DoubleMLRDDData type\. Empty DataFrame\nColumns: \[\]\nIndex: \[\] of type " - r" was passed\." + r" was passed\." ) with pytest.raises(TypeError, match=msg): _ = DoubleMLSSM(pd.DataFrame(), ml_g, ml_pi, ml_m) diff --git a/doubleml/plm/tests/test_plpr_exceptions.py b/doubleml/plm/tests/test_plpr_exceptions.py index ca4dbe387..b3b291a5b 100644 --- a/doubleml/plm/tests/test_plpr_exceptions.py +++ b/doubleml/plm/tests/test_plpr_exceptions.py @@ -61,7 +61,7 @@ @pytest.mark.ci def test_plpr_exception_data(): - msg = "The data must be of DoubleMLPanelData type. was passed." + msg = r"The data must be of DoubleMLPanelData type. was passed." with pytest.raises(TypeError, match=msg): _ = dml.DoubleMLPLPR(pd.DataFrame(), ml_l, ml_m) # not a panel data object diff --git a/doubleml/utils/tests/test_policytree.py b/doubleml/utils/tests/test_policytree.py index 055a87536..a44f34295 100644 --- a/doubleml/utils/tests/test_policytree.py +++ b/doubleml/utils/tests/test_policytree.py @@ -98,8 +98,8 @@ def test_doubleml_exception_policytree(): with pytest.raises(TypeError, match=msg): dml_policytree_predict.predict(features=1) msg = ( - r"The features must have the keys Index\(\[\'a\', \'b\', \'c\'\], dtype\=\'str\'\). " - r"Features with keys Index\(\[\'d\'\], dtype=\'str\'\) were passed." + r"The features must have the keys Index\(\[\'a\', \'b\', \'c\'\], dtype=.*?\)\. " + r"Features with keys Index\(\[\'d\'\], dtype=.*?\) were passed\." ) with pytest.raises(KeyError, match=msg): dml_policytree_predict.predict(features=pd.DataFrame({"d": [3, 4]}))