Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combine overlapping missing sets #5

Merged
merged 3 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added data/src/tests/__init__.py
Empty file.
80 changes: 80 additions & 0 deletions data/src/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import pandas as pd
import pytest
from utils.utils import group_by_column_sets, merge_overlapping_df_list


class TestGroupSetHandlers:
    """Tests for group_by_column_sets and merge_overlapping_df_list.

    Each fixture returns a (name, DataFrame) pair so the parametrized tests
    can look up the matching expected output by name.
    """

    @pytest.fixture
    def df_small_overlap(self):
        """OD-style frame whose groups share only a small set of values."""
        df = pd.DataFrame(
            {
                "a": ["a", "b", "a", "b", "a", "b", "c", "c", "d", "d", "d"],
                "b": [1, 2, 1, 2, 3, 3, 2, 3, 1, 2, 3],
            }
        )
        return ("small_overlap", df)

    @pytest.fixture
    def df_big_overlap(self):
        """OD-style frame where two groups share an identical value set."""
        df = pd.DataFrame(
            {
                "a": ["a", "a", "a", "b", "b", "b", "c", "c", "c"],
                "b": [1, 2, 3, 1, 2, 3, 3, 4, 5],
            }
        )
        return ("big_overlap", df)

    @pytest.mark.parametrize(
        "fixture_name", ["df_small_overlap", "df_big_overlap"]
    )
    def test_group_by_column_sets_output(self, fixture_name, request):
        """Groups with identical value sets in column 'b' are combined."""
        data_name, data = request.getfixturevalue(fixture_name)
        result = group_by_column_sets(data, "a", "b")

        expected = {
            "small_overlap": [
                pd.DataFrame({"a": ["a"] * 2, "b": [1, 3]}),
                pd.DataFrame({"a": ["b"] * 2 + ["c"] * 2, "b": [2, 3, 2, 3]}),
                pd.DataFrame({"a": ["d"] * 3, "b": [1, 2, 3]}),
            ],
            "big_overlap": [
                pd.DataFrame({"a": ["a"] * 3 + ["b"] * 3, "b": [1, 2, 3] * 2}),
                pd.DataFrame({"a": ["c"] * 3, "b": [3, 4, 5]}),
            ],
        }

        assert len(result) == len(expected[data_name])
        for res, exp in zip(result, expected[data_name]):
            pd.testing.assert_frame_equal(
                res.reset_index(drop=True), exp.reset_index(drop=True)
            )

    @pytest.mark.parametrize(
        "fixture_name,threshold",
        [("df_small_overlap", 0.9), ("df_big_overlap", 0.2)],
    )
    def test_merge_overlapping_df_list_output(
        self, fixture_name, threshold, request
    ):
        """Frames whose column overlap meets the threshold are merged."""
        data_name, data = request.getfixturevalue(fixture_name)
        initial = group_by_column_sets(data, "a", "b")
        result = merge_overlapping_df_list(initial, threshold)

        expected = {
            "small_overlap": [
                pd.DataFrame({"a": ["b", "b", "c", "c"], "b": [2, 3, 2, 3]}),
                pd.DataFrame(
                    {"a": ["d"] * 3 + ["a"] * 2, "b": [1, 2, 3, 1, 3]}
                ),
            ],
            "big_overlap": [
                pd.DataFrame(
                    {
                        "a": ["a"] * 3 + ["b"] * 3 + ["c"] * 3,
                        "b": [1, 2, 3] * 2 + [3, 4, 5],
                    }
                )
            ],
        }

        # Check total number of rows is preserved
        assert sum(len(df) for df in initial) == sum(len(df) for df in result)
        # Guard against zip() silently truncating when the counts differ
        assert len(result) == len(expected[data_name])
        for res, exp in zip(result, expected[data_name]):
            pd.testing.assert_frame_equal(
                res.reset_index(drop=True), exp.reset_index(drop=True)
            )
26 changes: 24 additions & 2 deletions data/src/utils/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
create_empty_df,
format_time,
group_by_column_sets,
merge_overlapping_df_list,
suppress_stdout,
)

Expand Down Expand Up @@ -606,10 +607,31 @@ def many_to_many(self, second_pass: bool = True) -> pd.DataFrame:
missing_sets = group_by_column_sets(
missing.reset_index(), "origin_id", "destination_id"
)

# Merge missing sets of OD pairs that overlap significantly (think
# two origins that share 1000 destinations but not the 1001st)
missing_sets = [
df[["origin_id", "destination_id"]] for df in missing_sets
]
merged_sets = merge_overlapping_df_list(missing_sets, 0.8)

# Gut check that both sets contain the same number of rows
try:
assert sum(len(df) for df in missing_sets) == sum(
len(df) for df in merged_sets
)
except AssertionError:
raise ValueError(
"The total number of rows in missing_sets does not"
"match the total number of rows in merged_sets"
)

self.config.logger.info(
"Found %s unique missing sets", len(missing_sets)
"Found %s unique missing sets. Merged to %s sets",
len(missing_sets),
len(merged_sets),
)
for idx, missing_set in enumerate(missing_sets):
for idx, missing_set in enumerate(merged_sets):
self.config.logger.info("Routing missing set number %s", idx)
o_ids = missing_set["origin_id"].unique()
d_ids = missing_set["destination_id"].unique()
Expand Down
41 changes: 41 additions & 0 deletions data/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import sys
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path

import pandas as pd
Expand Down Expand Up @@ -65,10 +66,50 @@ def group_by_column_sets(
result = []
for unique_set in unique_sets:
group = df[df[x].isin(grouped[grouped[y] == unique_set][x])]
group = group.drop_duplicates()
group = group.reset_index(drop=True)
result.append(group)
return result


def merge_overlapping_df_list(
    df_list: list[pd.DataFrame],
    overlap_threshold: float = 0.5,
) -> list[pd.DataFrame]:
    """
    Merge DataFrames in a list whose values in any shared column overlap
    by at least ``overlap_threshold``.

    Largest frames are used as merge bases first. Two frames are merged
    when, for some column, the count of unique shared values divided by
    the row count of the smaller frame meets the threshold. Merged frames
    are de-duplicated row-wise, so total unique rows are preserved.

    :param df_list: DataFrames to combine. The input list and its frames
        are not modified.
    :param overlap_threshold: Fraction in [0, 1]; minimum overlap required
        to merge two frames.
    :return: A new list of merged, de-duplicated DataFrames.
    """

    def _overlap_fraction(df1, df2, col):
        # An empty frame overlaps nothing; this also avoids a
        # ZeroDivisionError when min(len(df1), len(df2)) == 0.
        smaller = min(len(df1), len(df2))
        if smaller == 0:
            return 0.0
        shared = pd.merge(df1[[col]], df2[[col]], how="inner", on=col)
        return len(set(shared[col])) / smaller

    # Copy the input so we don't modify it
    remaining = deepcopy(df_list)

    # Merge into largest dataframes first
    remaining.sort(key=len, reverse=True)

    merged_dfs = []
    while remaining:
        base_df = remaining.pop(0)
        merged = base_df
        # Collect every remaining frame that overlaps the base on any column
        to_merge = [
            df
            for df in remaining
            if any(
                _overlap_fraction(base_df, df, col) >= overlap_threshold
                for col in df.columns
            )
        ]
        for df in to_merge:
            # Remove by identity, not value equality: to_merge holds
            # references into `remaining`, so `is` pinpoints the exact
            # frame that was merged.
            remaining = [other for other in remaining if other is not df]
            merged = (
                pd.concat([merged, df])
                .drop_duplicates()
                .reset_index(drop=True)
            )
        merged_dfs.append(merged)
    return merged_dfs


def split_file_to_str(file: str | Path, **kwargs) -> str:
"""
Splits the contents of a Parquet file into chunks and return the chunk
Expand Down
Loading