Skip to content

Commit ff3b42b

Browse files
committed
Fix merge_visits
1 parent 9cd7f1a commit ff3b42b

File tree

5 files changed

+113
-25
lines changed

5 files changed

+113
-25
lines changed

changelog.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Changelog
22

3+
## Unreleased
4+
5+
### Fixed
6+
- Fix `merge_visits`: replace the `sort_values().groupby().first()` pattern with the new `sort_values_first` helper
7+
38
## v0.1.8 (2024-06-13)
49

510
### Fixed

eds_scikit/period/stays.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from eds_scikit.utils.checks import MissingConceptError, algo_checker, concept_checker
77
from eds_scikit.utils.datetime_helpers import substract_datetime
88
from eds_scikit.utils.framework import get_framework
9+
from eds_scikit.utils.sort_values_first import sort_values_first
910
from eds_scikit.utils.typing import DataFrame
1011

1112

@@ -73,10 +74,10 @@ def cleaning(
7374
@concept_checker(concepts=["STAY_ID", "CONTIGUOUS_STAY_ID"])
7475
def merge_visits(
7576
vo: DataFrame,
77+
open_stay_end_datetime: Optional[datetime],
7678
remove_deleted_visits: bool = True,
7779
long_stay_threshold: timedelta = timedelta(days=365),
7880
long_stay_filtering: Optional[str] = "all",
79-
open_stay_end_datetime: Optional[datetime] = None,
8081
max_timedelta: timedelta = timedelta(days=2),
8182
merge_different_hospitals: bool = False,
8283
merge_different_source_values: Union[bool, List[str]] = ["hospitalisés", "urgence"],
@@ -108,6 +109,11 @@ def merge_visits(
108109
- care_site_id (if ``merge_different_hospitals == True``)
109110
- visit_source_value (if ``merge_different_source_values != False``)
110111
- row_status_source_value (if ``remove_deleted_visits= True``)
112+
open_stay_end_datetime: Optional[datetime]
113+
Datetime to use in order to fill the `visit_end_datetime` of open visits. This is necessary in
114+
order to compute stay duration and to filter long stays.
115+
You might provide the extraction date of your data or datetime.now()
116+
(be aware that `datetime.now()` will produce non-deterministic outputs).
111117
remove_deleted_visits: bool
112118
Whether to remove deleted visits from the merging procedure.
113119
Deleted visits are extracted via the `row_status_source_value` column
@@ -126,10 +132,6 @@ def merge_visits(
126132
Long stays are determined by the ``long_stay_threshold`` value.
127133
long_stay_threshold : timedelta
128134
Minimum visit duration value to consider a visit as candidate for "long visits filtering"
129-
open_stay_end_datetime: Optional[datetime]
130-
Datetime to use in order to fill the `visit_end_datetime` of open visits. This is necessary in
131-
order to compute stay duration and to filter long stays. If not provided `datetime.now()` will be used.
132-
You might provide the extraction date of your data here.
133135
max_timedelta : timedelta
134136
Maximum time difference between the end of a visit and the start of another to consider
135137
them as belonging to the same stay. This duration is internally converted in seconds before
@@ -291,21 +293,18 @@ def get_first(
291293
how="inner",
292294
)
293295

294-
# Getting the corresponding first visit
295-
first_visit = (
296-
merged.sort_values(
297-
by=[flag_name, "visit_start_datetime_1"], ascending=[False, False]
298-
)
299-
.groupby("visit_occurrence_id_2")
300-
.first()["visit_occurrence_id_1"]
301-
.reset_index()
302-
.rename(
303-
columns={
304-
"visit_occurrence_id_1": f"{concept_prefix}STAY_ID",
305-
"visit_occurrence_id_2": "visit_occurrence_id",
306-
}
307-
)
296+
first_visit = sort_values_first(
297+
merged,
298+
by_cols=["visit_occurrence_id_2"],
299+
cols=[flag_name, "visit_start_datetime_1", "visit_occurrence_id_1"],
300+
)
301+
first_visit = first_visit.rename(
302+
columns={
303+
"visit_occurrence_id_1": f"{concept_prefix}STAY_ID",
304+
"visit_occurrence_id_2": "visit_occurrence_id",
305+
}
308306
)
307+
first_visit = first_visit[["visit_occurrence_id", f"{concept_prefix}STAY_ID"]]
309308

310309
return merged, first_visit
311310

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import List, Union

from eds_scikit.utils.typing import DataFrame
4+
5+
6+
def sort_values_first(
    df: DataFrame,
    by_cols: List[str],
    cols: List[str],
    ascending: Union[bool, List[bool]] = False,
) -> DataFrame:
    """
    Return, for each group, the first row of that group once it is sorted.

    Intended as a replacement for
    ``df.sort_values(cols).groupby(by_cols).first()``: here the sort is
    performed inside each group (through ``groupby(...).apply``) and the top
    row of every sorted group is kept.

    Parameters
    ----------
    df : DataFrame
        Input dataframe.
    by_cols : List[str]
        Columns to group by.
    cols : List[str]
        Columns to sort on within each group, in order of priority.
    ascending : Union[bool, List[bool]]
        Sort direction. A single boolean is applied to every column of
        ``cols``; a list must provide exactly one boolean per column.
        Defaults to ``False`` (descending).

    Returns
    -------
    DataFrame
        One row per distinct combination of ``by_cols``, with a fresh
        default index.

    Raises
    ------
    ValueError
        If ``ascending`` is a list whose length differs from ``cols``.
    """
    if isinstance(ascending, (list, tuple)):
        orders = list(ascending)
        if len(orders) != len(cols):
            raise ValueError(
                f"`ascending` has {len(orders)} values "
                f"but `cols` has {len(cols)} columns"
            )
    else:
        # A scalar direction is broadcast to every sort column.
        orders = [ascending] * len(cols)

    def _first_sorted_row(group):
        return group.sort_values(by=cols, ascending=orders).head(1)

    result = df.groupby(by_cols).apply(_first_sorted_row)

    # Some backends/versions do not hand the grouping columns to the applied
    # function; in that case they only survive as index levels, so bring them
    # back as regular columns before the index is discarded below.
    missing = [col for col in by_cols if col not in result.columns]
    if missing:
        result = result.reset_index(level=missing)[list(df.columns)]

    return result.reset_index(drop=True)

tests/test_sort_values_first.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
5+
from eds_scikit.utils import framework
6+
from eds_scikit.utils.sort_values_first import sort_values_first
7+
from eds_scikit.utils.test_utils import assert_equal_no_order
8+
9+
# Create a DataFrame
10+
np.random.seed(0)
11+
size = 10000
12+
data = {
13+
"A": np.random.choice(["X", "Y", "Z"], size),
14+
"B": np.random.randint(1, 5, size),
15+
"C": np.random.randint(1, 5, size),
16+
"D": np.random.randint(1, 5, size),
17+
"E": np.random.randint(1, 5, size),
18+
}
19+
20+
inputs = pd.DataFrame(data)
21+
inputs.loc[0, "B"] = 0
22+
inputs.loc[0, "C"] = 4
23+
24+
25+
@pytest.mark.parametrize(
    "module",
    ["pandas", "koalas"],
)
def test_sort_values_first(module):
    """
    Check ``sort_values_first`` against the plain pandas
    ``sort_values(...).groupby(...).first()`` reference, for both sort
    directions and for every parametrized backend.
    """
    group_cols = ["A"]
    sort_cols = ["B", "C"]

    inputs_fr = framework.to(module, inputs)

    # Ascending first, then descending.
    for ascending in (True, False):
        computed = framework.pandas(
            sort_values_first(inputs_fr, group_cols, sort_cols, ascending=ascending)
        )
        expected = (
            inputs.sort_values(sort_cols, ascending=ascending)
            .groupby("A", as_index=False)
            .first()
        )
        assert_equal_no_order(computed, expected)

tests/test_visit_merging.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from datetime import datetime
2+
13
import pandas as pd
24
import pytest
35

@@ -43,7 +45,10 @@
4345
]
4446

4547

46-
@pytest.mark.parametrize("module", ["pandas", "koalas"])
48+
@pytest.mark.parametrize(
49+
"module",
50+
["pandas", "koalas"],
51+
)
4752
@pytest.mark.parametrize(
4853
"params, results",
4954
[(params, results) for params, results in zip(all_params, all_results)],
@@ -53,9 +58,7 @@ def test_visit_merging(module, params, results):
5358
results = framework.to(module, results)
5459

5560
vo = framework.to(module, ds.visit_occurrence)
56-
merged = merge_visits(vo, **params)
61+
merged = merge_visits(vo, datetime(2023, 1, 1), **params)
5762
merged = framework.pandas(merged)
58-
59-
assert_equal_no_order(
60-
merged[["visit_occurrence_id", "STAY_ID", "CONTIGUOUS_STAY_ID"]], results
61-
)
63+
merged = merged[results.columns]
64+
assert_equal_no_order(merged, results, check_dtype=False)

0 commit comments

Comments
 (0)