From 1ed1dce749dfe43bc0e70a6daf328a0d95aa9022 Mon Sep 17 00:00:00 2001
From: Dan Snow
Date: Sun, 15 Dec 2024 22:29:34 -0600
Subject: [PATCH 1/3] Add function to combine overlapping missing sets

---
 data/src/utils/times.py | 23 +++++++++++++++++++++--
 data/src/utils/utils.py | 41 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 62 insertions(+), 2 deletions(-)

diff --git a/data/src/utils/times.py b/data/src/utils/times.py
index 8c2e88a..b22bd37 100644
--- a/data/src/utils/times.py
+++ b/data/src/utils/times.py
@@ -14,6 +14,7 @@
     create_empty_df,
     format_time,
     group_by_column_sets,
+    merge_overlapping_df_list,
     suppress_stdout,
 )
 
@@ -606,10 +607,28 @@ def many_to_many(self, second_pass: bool = True) -> pd.DataFrame:
         missing_sets = group_by_column_sets(
             missing.reset_index(), "origin_id", "destination_id"
         )
+
+        # Merge missing sets that overlap significantly (think two origins
+        # that share 1000 destinations but not the 1001st)
+        merged_sets = merge_overlapping_df_list(missing_sets, 0.8)
+
+        # Gut check that both sets contain the same number of rows
+        try:
+            assert sum(len(df) for df in missing_sets) == sum(
+                len(df) for df in merged_sets
+            )
+        except AssertionError:
+            raise ValueError(
+                "The total number of rows in missing_sets does not "
+                "match the total number of rows in merged_sets"
+            )
         self.config.logger.info(
-            "Found %s unique missing sets", len(missing_sets)
+            "Found %s unique missing sets. Merged to %s sets",
+            len(missing_sets),
+            len(merged_sets),
         )
-        for idx, missing_set in enumerate(missing_sets):
+
+        for idx, missing_set in enumerate(merged_sets):
             self.config.logger.info("Routing missing set number %s", idx)
             o_ids = missing_set["origin_id"].unique()
             d_ids = missing_set["destination_id"].unique()
diff --git a/data/src/utils/utils.py b/data/src/utils/utils.py
index 32a506a..9139531 100644
--- a/data/src/utils/utils.py
+++ b/data/src/utils/utils.py
@@ -3,6 +3,7 @@
 import os
 import sys
 from contextlib import contextmanager
+from copy import deepcopy
 from pathlib import Path
 
 import pandas as pd
@@ -65,10 +66,50 @@ def group_by_column_sets(
     result = []
     for unique_set in unique_sets:
         group = df[df[x].isin(grouped[grouped[y] == unique_set][x])]
+        group = group.drop_duplicates()
+        group = group.reset_index(drop=True)
         result.append(group)
 
     return result
 
 
+def merge_overlapping_df_list(
+    df_list: list[pd.DataFrame],
+    overlap_threshold: float = 0.5,
+) -> list[pd.DataFrame]:
+    def overlap_percentage(df1, df2, col):
+        overlap = pd.merge(df1[[col]], df2[[col]], how="inner", on=col)
+        return len(set(overlap[col])) / min(len(df1[col]), len(df2[col]))
+
+    # Copy the input so we don't modify it
+    df_list_c = deepcopy(df_list)
+
+    # Merge into largest dataframes first
+    df_list_c.sort(key=len, reverse=True)
+
+    merged_dfs = []
+    while df_list_c:
+        base_df = df_list_c.pop(0)
+        merged = base_df
+        to_merge = []
+        for df in df_list_c:
+            for col in df.columns:
+                if overlap_percentage(base_df, df, col) >= overlap_threshold:
+                    to_merge.append((df, col))
+                    break
+        for df, col in to_merge:
+            # Remove the dataframe from the main list if it's been merged
+            for i in range(len(df_list_c) - 1, -1, -1):
+                if df_list_c[i][df_list_c[i].columns].equals(df[df.columns]):
+                    df_list_c.pop(i)
+            merged = (
+                pd.concat([merged, df])
+                .drop_duplicates()
+                .reset_index(drop=True)
+            )
+        merged_dfs.append(merged)
+    return merged_dfs
+
+
 def split_file_to_str(file: str | Path, **kwargs) -> str:
     """
     Splits the contents of a Parquet file into chunks and return the chunk
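A quick sketch of what the new `merge_overlapping_df_list` helper does at the 0.8 threshold `many_to_many` passes it. This is illustrative only, not part of the patch: the toy frames are invented, and it assumes `utils.utils` is importable (e.g., with `data/src` on `sys.path`).

```python
import pandas as pd

from utils.utils import merge_overlapping_df_list

# Two origins that share every destination, plus one unrelated OD pair
a = pd.DataFrame({"origin_id": [1, 1, 2, 2], "destination_id": [10, 11, 10, 11]})
b = pd.DataFrame({"origin_id": [3, 3], "destination_id": [10, 11]})
c = pd.DataFrame({"origin_id": [9], "destination_id": [99]})

merged = merge_overlapping_df_list([a, b, c], overlap_threshold=0.8)

# `a` and `b` overlap 100% on destination_id, so they collapse into one
# frame; `c` overlaps on nothing and stays separate. No rows are lost.
assert len(merged) == 2
assert sum(len(df) for df in merged) == len(a) + len(b) + len(c)
```

Note the helper checks every shared column for overlap, not just the OD pair; patch 3 below tightens the call site accordingly.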
From 01c90f49ccb97a12014679dab1db217b3af4040d Mon Sep 17 00:00:00 2001
From: Dan Snow
Date: Sun, 15 Dec 2024 22:29:49 -0600
Subject: [PATCH 2/3] Add unit tests for set functions

---
 data/src/tests/__init__.py   |  0
 data/src/tests/test_utils.py | 80 ++++++++++++++++++++++++++++++++++++
 2 files changed, 80 insertions(+)
 create mode 100644 data/src/tests/__init__.py
 create mode 100644 data/src/tests/test_utils.py

diff --git a/data/src/tests/__init__.py b/data/src/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/data/src/tests/test_utils.py b/data/src/tests/test_utils.py
new file mode 100644
index 0000000..d91ce94
--- /dev/null
+++ b/data/src/tests/test_utils.py
@@ -0,0 +1,80 @@
+import pandas as pd
+import pytest
+from utils.utils import group_by_column_sets, merge_overlapping_df_list
+
+
+class TestGroupSetHandlers:
+    @pytest.fixture
+    def df_small_overlap(self):
+        df = pd.DataFrame(
+            {
+                "a": ["a", "b", "a", "b", "a", "b", "c", "c", "d", "d", "d"],
+                "b": [1, 2, 1, 2, 3, 3, 2, 3, 1, 2, 3],
+            }
+        )
+        return ("small_overlap", df)
+
+    @pytest.fixture
+    def df_big_overlap(self):
+        df = pd.DataFrame(
+            {
+                "a": ["a", "a", "a", "b", "b", "b", "c", "c", "c"],
+                "b": [1, 2, 3, 1, 2, 3, 3, 4, 5],
+            }
+        )
+        return ("big_overlap", df)
+
+    @pytest.mark.parametrize("input", ["df_small_overlap", "df_big_overlap"])
+    def test_group_by_column_sets_output(self, input, request):
+        data_name, data = request.getfixturevalue(input)
+        result = group_by_column_sets(data, "a", "b")
+
+        expected = {
+            "small_overlap": [
+                pd.DataFrame({"a": ["a"] * 2, "b": [1, 3]}),
+                pd.DataFrame({"a": ["b"] * 2 + ["c"] * 2, "b": [2, 3, 2, 3]}),
+                pd.DataFrame({"a": ["d"] * 3, "b": [1, 2, 3]}),
+            ],
+            "big_overlap": [
+                pd.DataFrame({"a": ["a"] * 3 + ["b"] * 3, "b": [1, 2, 3] * 2}),
+                pd.DataFrame({"a": ["c"] * 3, "b": [3, 4, 5]}),
+            ],
+        }
+
+        assert len(result) == len(expected[data_name])
+        for res, exp in zip(result, expected[data_name]):
+            pd.testing.assert_frame_equal(
+                res.reset_index(drop=True), exp.reset_index(drop=True)
+            )
+
+    @pytest.mark.parametrize(
+        "input,threshold", [("df_small_overlap", 0.9), ("df_big_overlap", 0.2)]
+    )
+    def test_merge_overlapping_df_list_output(self, input, threshold, request):
+        data_name, data = request.getfixturevalue(input)
+        initial = group_by_column_sets(data, "a", "b")
+        result = merge_overlapping_df_list(initial, threshold)
+
+        expected = {
+            "small_overlap": [
+                pd.DataFrame({"a": ["b", "b", "c", "c"], "b": [2, 3, 2, 3]}),
+                pd.DataFrame(
+                    {"a": ["d"] * 3 + ["a"] * 2, "b": [1, 2, 3, 1, 3]}
+                ),
+            ],
+            "big_overlap": [
+                pd.DataFrame(
+                    {
+                        "a": ["a"] * 3 + ["b"] * 3 + ["c"] * 3,
+                        "b": [1, 2, 3] * 2 + [3, 4, 5],
+                    }
+                )
+            ],
+        }
+
+        # Check total number of rows is preserved
+        assert sum(len(df) for df in initial) == sum(len(df) for df in result)
+        for res, exp in zip(result, expected[data_name]):
+            pd.testing.assert_frame_equal(
+                res.reset_index(drop=True), exp.reset_index(drop=True)
+            )
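To see why the `small_overlap` fixture yields the three expected frames, note that `group_by_column_sets` buckets each value of column `a` by the exact set of `b` values it maps to: `"a"` maps to `{1, 3}`, `"b"` and `"c"` both map to `{2, 3}`, and `"d"` maps to `{1, 2, 3}`. A standalone sketch of the first test case, under the same importability assumption as above:

```python
import pandas as pd

from utils.utils import group_by_column_sets

df = pd.DataFrame(
    {
        "a": ["a", "b", "a", "b", "a", "b", "c", "c", "d", "d", "d"],
        "b": [1, 2, 1, 2, 3, 3, 2, 3, 1, 2, 3],
    }
)

# Three distinct b-value sets, so three frames come back. The duplicated
# ("a", 1) pair is dropped by the drop_duplicates() call added in patch 1,
# which is why the first frame has two rows rather than three.
sets = group_by_column_sets(df, "a", "b")
assert len(sets) == 3
assert [len(s) for s in sets] == [2, 4, 3]
```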
From eed94a43c1bf68ffbb03da887dd3abba585f0798 Mon Sep 17 00:00:00 2001
From: Dan Snow
Date: Sun, 15 Dec 2024 22:44:28 -0600
Subject: [PATCH 3/3] Merge only on OD cols

---
 data/src/utils/times.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/data/src/utils/times.py b/data/src/utils/times.py
index b22bd37..71b1208 100644
--- a/data/src/utils/times.py
+++ b/data/src/utils/times.py
@@ -608,26 +608,29 @@ def many_to_many(self, second_pass: bool = True) -> pd.DataFrame:
             missing.reset_index(), "origin_id", "destination_id"
         )
 
-        # Merge missing sets that overlap significantly (think two origins
-        # that share 1000 destinations but not the 1001st)
+        # Merge missing sets of OD pairs that overlap significantly (think
+        # two origins that share 1000 destinations but not the 1001st)
+        missing_sets = [
+            df[["origin_id", "destination_id"]] for df in missing_sets
+        ]
         merged_sets = merge_overlapping_df_list(missing_sets, 0.8)
 
         # Gut check that both sets contain the same number of rows
         try:
             assert sum(len(df) for df in missing_sets) == sum(
                 len(df) for df in merged_sets
             )
         except AssertionError:
             raise ValueError(
                 "The total number of rows in missing_sets does not "
                 "match the total number of rows in merged_sets"
             )
+
         self.config.logger.info(
             "Found %s unique missing sets. Merged to %s sets",
             len(missing_sets),
             len(merged_sets),
         )
-
         for idx, missing_set in enumerate(merged_sets):
             self.config.logger.info("Routing missing set number %s", idx)
             o_ids = missing_set["origin_id"].unique()
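For context on why this last patch narrows the frames to the OD columns: `merge_overlapping_df_list` checks `overlap_percentage` against every column of each candidate frame, so any extra column carried along by `missing.reset_index()` could trigger a merge on a coincidental overlap. A hypothetical sketch (the `chunk` column is invented purely for illustration):

```python
import pandas as pd

from utils.utils import merge_overlapping_df_list

# Disjoint OD pairs, but an extra bookkeeping column that overlaps 100%
x = pd.DataFrame(
    {"origin_id": [1, 2], "destination_id": [10, 20], "chunk": [0, 1]}
)
y = pd.DataFrame(
    {"origin_id": [3, 4], "destination_id": [30, 40], "chunk": [0, 1]}
)

# With "chunk" present, the frames merge despite sharing no OD pairs
assert len(merge_overlapping_df_list([x, y], 0.8)) == 1

# Dropping to the OD columns first, as the patch does, keeps them apart
od_only = [df[["origin_id", "destination_id"]] for df in (x, y)]
assert len(merge_overlapping_df_list(od_only, 0.8)) == 2
```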