Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion doubleml/data/panel_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,11 @@ def _set_time_var(self):
if hasattr(self, "_data") and self.t_col in self.data.columns:
t_values = self.data.loc[:, self.t_col]
expected_dtypes = (np.integer, np.floating, np.datetime64)
if not any(np.issubdtype(t_values.dtype, dt) for dt in expected_dtypes):
try:
valid_type = any(np.issubdtype(t_values.dtype, dt) for dt in expected_dtypes)
except TypeError:
valid_type = False
if not valid_type:
raise ValueError(f"Invalid data type for time variable: expected one of {expected_dtypes}.")
else:
self._t = t_values
5 changes: 4 additions & 1 deletion doubleml/did/utils/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ def add_jitter(data, x_col, is_datetime=None, jitter_value=None):
is_datetime = pd.api.types.is_datetime64_any_dtype(data[x_col])

# Initialize jittered_x with original values
data["jittered_x"] = data[x_col]
if is_datetime:
data["jittered_x"] = data[x_col]
else:
data["jittered_x"] = data[x_col].astype(float)

for x_val in data[x_col].unique():
mask = data[x_col] == x_val
Expand Down
5 changes: 3 additions & 2 deletions doubleml/did/utils/tests/test_add_jitter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import pytest

Expand Down Expand Up @@ -41,7 +42,7 @@ def test_add_jitter_numeric_no_duplicates(numeric_df_no_duplicates):
"""Test that no jitter is added when there are no duplicates."""
result = add_jitter(numeric_df_no_duplicates, "x")
# No jitter should be added when there are no duplicates
pd.testing.assert_series_equal(result["jittered_x"], result["x"], check_names=False)
np.testing.assert_allclose(result["jittered_x"], result["x"])


@pytest.mark.ci
Expand Down Expand Up @@ -121,7 +122,7 @@ def test_add_jitter_explicit_datetime_flag():
df = pd.DataFrame({"x": ["2023-01-01", "2023-01-01", "2023-01-02"], "y": [10, 15, 20]})

# Without specifying is_datetime, it would treat as strings
with pytest.raises(TypeError):
with pytest.raises(ValueError):
_ = add_jitter(df, "x")

# With is_datetime=True, it should convert and jitter as datetimes
Expand Down
2 changes: 1 addition & 1 deletion doubleml/irm/tests/test_apo_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_apo_exception_data():
msg = (
r"The data must be of DoubleMLData or DoubleMLClusterData or DoubleMLDIDData or DoubleMLSSMData or "
r"DoubleMLRDDData type\. Empty DataFrame\nColumns: \[\]\nIndex: \[\] of type "
r"<class 'pandas\.core\.frame\.DataFrame'> was passed\."
r"<class 'pandas\..*DataFrame'> was passed\."
)
with pytest.raises(TypeError, match=msg):
_ = DoubleMLAPO(pd.DataFrame(), ml_g, ml_m, treatment_level=0)
Expand Down
2 changes: 1 addition & 1 deletion doubleml/irm/tests/test_ssm_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_ssm_exception_data():
msg = (
r"The data must be of DoubleMLData or DoubleMLClusterData or DoubleMLDIDData or DoubleMLSSMData or "
r"DoubleMLRDDData type\. Empty DataFrame\nColumns: \[\]\nIndex: \[\] of type "
r"<class 'pandas\.core\.frame\.DataFrame'> was passed\."
r"<class 'pandas\..*DataFrame'> was passed\."
)
with pytest.raises(TypeError, match=msg):
_ = DoubleMLSSM(pd.DataFrame(), ml_g, ml_pi, ml_m)
Expand Down
2 changes: 1 addition & 1 deletion doubleml/plm/tests/test_lplr_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

@pytest.mark.ci
def test_lplr_exception_data():
msg = r"The data must be of DoubleMLData.* type\.[\s\S]* of type " r"<class 'pandas\.core\.frame\.DataFrame'> was passed\."
msg = r"The data must be of DoubleMLData.*type\."
with pytest.raises(TypeError, match=msg):
_ = DoubleMLLPLR(pd.DataFrame(), ml_M, ml_t, ml_m)

Expand Down
2 changes: 1 addition & 1 deletion doubleml/plm/tests/test_plpr_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@

@pytest.mark.ci
def test_plpr_exception_data():
msg = "The data must be of DoubleMLPanelData type. <class 'pandas.core.frame.DataFrame'> was passed."
msg = r"The data must be of DoubleMLPanelData type. <class 'pandas\..*DataFrame'> was passed."
with pytest.raises(TypeError, match=msg):
_ = dml.DoubleMLPLPR(pd.DataFrame(), ml_l, ml_m)
# not a panel data object
Expand Down
4 changes: 2 additions & 2 deletions doubleml/utils/tests/test_policytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def test_doubleml_exception_policytree():
with pytest.raises(TypeError, match=msg):
dml_policytree_predict.predict(features=1)
msg = (
r"The features must have the keys Index\(\[\'a\', \'b\', \'c\'\], dtype\=\'object\'\). "
r"Features with keys Index\(\[\'d\'\], dtype=\'object\'\) were passed."
r"The features must have the keys Index\(\[\'a\', \'b\', \'c\'\], dtype=.*?\)\. "
r"Features with keys Index\(\[\'d\'\], dtype=.*?\) were passed\."
)
with pytest.raises(KeyError, match=msg):
dml_policytree_predict.predict(features=pd.DataFrame({"d": [3, 4]}))
Expand Down
Loading