Skip to content

Commit 62098ca

Browse files
committed
Fix merge_visits
1 parent 9cd7f1a commit 62098ca

File tree

3 files changed

+62
-24
lines changed

3 files changed

+62
-24
lines changed

changelog.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Changelog
22

3+
## Unreleased
4+
5+
### Fixed
6+
- Fix non-deterministic `merge_visits` results in Koalas by replacing the `sort_values(...).groupby(...).first()` first-visit selection
7+
38
## v0.1.8 (2024-06-13)
49

510
### Fixed

eds_scikit/period/stays.py

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ def cleaning(
7373
@concept_checker(concepts=["STAY_ID", "CONTIGUOUS_STAY_ID"])
7474
def merge_visits(
7575
vo: DataFrame,
76+
open_stay_end_datetime: Optional[datetime],
7677
remove_deleted_visits: bool = True,
7778
long_stay_threshold: timedelta = timedelta(days=365),
7879
long_stay_filtering: Optional[str] = "all",
79-
open_stay_end_datetime: Optional[datetime] = None,
8080
max_timedelta: timedelta = timedelta(days=2),
8181
merge_different_hospitals: bool = False,
8282
merge_different_source_values: Union[bool, List[str]] = ["hospitalisés", "urgence"],
@@ -108,6 +108,11 @@ def merge_visits(
108108
- care_site_id (if ``merge_different_hospitals == True``)
109109
- visit_source_value (if ``merge_different_source_values != False``)
110110
- row_status_source_value (if ``remove_deleted_visits= True``)
111+
open_stay_end_datetime: Optional[datetime]
112+
Datetime to use in order to fill the `visit_end_datetime` of open visits. This is necessary in
113+
order to compute stay duration and to filter long stays.
114+
You might provide the extraction date of your data or datetime.now()
115+
(be aware that using `datetime.now()` will produce non-deterministic outputs).
111116
remove_deleted_visits: bool
112117
Whether to remove deleted visits from the merging procedure.
113118
Deleted visits are extracted via the `row_status_source_value` column
@@ -126,10 +131,6 @@ def merge_visits(
126131
Long stays are determined by the ``long_stay_threshold`` value.
127132
long_stay_threshold : timedelta
128133
Minimum visit duration value to consider a visit as candidate for "long visits filtering"
129-
open_stay_end_datetime: Optional[datetime]
130-
Datetime to use in order to fill the `visit_end_datetime` of open visits. This is necessary in
131-
order to compute stay duration and to filter long stays. If not provided `datetime.now()` will be used.
132-
You might provide the extraction date of your data here.
133134
max_timedelta : timedelta
134135
Maximum time difference between the end of a visit and the start of another to consider
135136
them as belonging to the same stay. This duration is internally converted in seconds before
@@ -292,20 +293,50 @@ def get_first(
292293
)
293294

294295
# Getting the corresponding first visit
295-
first_visit = (
296-
merged.sort_values(
297-
by=[flag_name, "visit_start_datetime_1"], ascending=[False, False]
298-
)
299-
.groupby("visit_occurrence_id_2")
300-
.first()["visit_occurrence_id_1"]
301-
.reset_index()
302-
.rename(
303-
columns={
304-
"visit_occurrence_id_1": f"{concept_prefix}STAY_ID",
305-
"visit_occurrence_id_2": "visit_occurrence_id",
306-
}
307-
)
296+
# Replacement for :
297+
# first_visit = merged.sort_values(by=[flag_name, "visit_start_datetime_1"],
298+
# ascending=[False, False])
299+
# .groupby(visit_occurrence_id_2).first()["visit_occurrence_id_1"]
300+
# which is not deterministic in Koalas
301+
302+
flagged = (
303+
merged[merged[flag_name]]
304+
.groupby("visit_occurrence_id_2", as_index=False)[
305+
["visit_start_datetime_1"]
306+
]
307+
.max()
308+
)
309+
flagged = merged[merged[flag_name]].merge(
310+
flagged, on=["visit_occurrence_id_2", "visit_start_datetime_1"], how="right"
311+
)
312+
flagged["flagged"] = True
313+
unflagged = (
314+
merged[~merged[flag_name]]
315+
.groupby("visit_occurrence_id_2", as_index=False)[
316+
["visit_start_datetime_1"]
317+
]
318+
.max()
319+
)
320+
unflagged = merged[~merged[flag_name]].merge(
321+
unflagged,
322+
on=["visit_occurrence_id_2", "visit_start_datetime_1"],
323+
how="right",
324+
)
325+
unflagged = unflagged.merge(
326+
flagged[["visit_occurrence_id_2", "flagged"]],
327+
on="visit_occurrence_id_2",
328+
how="left",
329+
)
330+
unflagged = unflagged[unflagged.flagged.isna()]
331+
first_visit = fw.concat((flagged, unflagged), axis=0)
332+
333+
first_visit = first_visit.rename(
334+
columns={
335+
"visit_occurrence_id_1": f"{concept_prefix}STAY_ID",
336+
"visit_occurrence_id_2": "visit_occurrence_id",
337+
}
308338
)
339+
first_visit = first_visit[["visit_occurrence_id", f"{concept_prefix}STAY_ID"]]
309340

310341
return merged, first_visit
311342

tests/test_visit_merging.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from eds_scikit.period.stays import merge_visits
66
from eds_scikit.utils import framework
77
from eds_scikit.utils.test_utils import assert_equal_no_order
8+
from datetime import datetime
89

910
ds = load_visit_merging()
1011

@@ -43,7 +44,10 @@
4344
]
4445

4546

46-
@pytest.mark.parametrize("module", ["pandas", "koalas"])
47+
@pytest.mark.parametrize(
48+
"module",
49+
["pandas", "koalas"],
50+
)
4751
@pytest.mark.parametrize(
4852
"params, results",
4953
[(params, results) for params, results in zip(all_params, all_results)],
@@ -53,9 +57,7 @@ def test_visit_merging(module, params, results):
5357
results = framework.to(module, results)
5458

5559
vo = framework.to(module, ds.visit_occurrence)
56-
merged = merge_visits(vo, **params)
60+
merged = merge_visits(vo, datetime(2023, 1, 1), **params)
5761
merged = framework.pandas(merged)
58-
59-
assert_equal_no_order(
60-
merged[["visit_occurrence_id", "STAY_ID", "CONTIGUOUS_STAY_ID"]], results
61-
)
62+
merged = merged[results.columns]
63+
assert_equal_no_order(merged, results, check_dtype=False)

0 commit comments

Comments
 (0)