Skip to content

Commit

Permalink
Issue 494: Add deduplication logic to extract_xri script
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohan M committed Sep 13, 2024
1 parent f5386d6 commit 99dc80c
Showing 1 changed file with 58 additions and 4 deletions.
62 changes: 58 additions & 4 deletions silnlp/common/extract_xri.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

from dataclasses import dataclass
from enum import Enum
from itertools import groupby
from logging import Logger
from pathlib import Path
from typing import List, Optional
Expand All @@ -56,6 +57,8 @@
logger = logging.getLogger(__package__ + ".extract_xri")
repair_logger = logging.getLogger(logger.name + ".repair")
clean_logger = logging.getLogger(logger.name + ".clean")
deduplication_logger = logging.getLogger(logger.name + ".deduplication")
all_loggers = [logger, repair_logger, clean_logger, deduplication_logger]


class Split(Enum):
Expand Down Expand Up @@ -142,7 +145,9 @@ def get_column_index(column_name: str) -> int:

repaired = repair_if_necessary(column_schema.id_column_index, data_rows)

return filter_and_clean(repaired, column_schema)
cleaned = filter_and_clean(repaired, column_schema)

return remove_duplicates(cleaned)


def try_extract_id(row: List[str], id_column_index: int, log: Logger) -> Optional[int]:
Expand Down Expand Up @@ -317,6 +322,56 @@ def trim(sentence: str, description: str) -> str:
return sentence_pairs


def remove_duplicates(sentence_pairs: List[SentencePair]) -> List[SentencePair]:
"""
Removes duplicate translations from the sentence pairs, choosing the one it thinks is most likely to be the best translation.
Duplicates are defined as sentence pairs with the same source string (case-sensitive currently)
after basic cleaning has been applied (like trimming).
For example
- "this is a source" and "this is a source " would be considered duplicates as they differ only on boundary whitespace
- "This is a source" and "this is a source" currently wouldn't be considered duplicates
When duplicates are detected, the target text with length closest to the source string is chosen,
and other sentences are excluded.
In the case of a tie, the earlier one from the input is taken.
"""
deduplication_logger.info("Starting deduplication stage")
# groupby requires data to be contiguous on the grouping key otherwise the groups get fragmented
sentence_pairs_sorted_by_source = sorted(sentence_pairs, key=lambda sentence_pair: sentence_pair.source)

grouped_by_source = groupby(sentence_pairs_sorted_by_source, key=lambda sentence_pair: sentence_pair.source)

def choose_best(source: str, duplicates: List[SentencePair]) -> SentencePair:
"""Given a group of sentence pairs with the same source text, this picks the representative from the group to keep"""
if len(duplicates) == 1:
# No duplicates - most common case
return duplicates[0]
else:
# Duplictes found, choose the best match
closest_match = min(
duplicates, key=lambda sentence_pair: abs(len(sentence_pair.target) - len(source))
)
deduplication_logger.error(
f"{len(duplicates)} duplicate sentence pairs found with source: '{source}'. "
+ f"Id's are {[sentence_pair.id for sentence_pair in duplicates]}. "
+ f"Id {closest_match.id} chosen with target: '{closest_match.target}'"
)
return closest_match

deduplicated = [choose_best(source, list(duplicates)) for source, duplicates in grouped_by_source]

deduplication_logger.info(
"Finished deduplication stage. "
+ f"{len(sentence_pairs)} pairs passed in. "
+ f"{len(deduplicated)} pairs returned. "
+ f"{len(sentence_pairs) - len(deduplicated)} duplicate pairs removed."
)

return deduplicated


def write_output_file(filepath: Path, sentences: List[str]) -> None:
logger.debug(f"Writing {len(sentences)} sentences to file: {filepath}")
with open(filepath, "w", encoding="utf-8") as f:
Expand Down Expand Up @@ -373,9 +428,8 @@ def build_output_path(iso: str) -> Path:
def run(cli_input: CliInput) -> None:
if cli_input.log_level is not None:
log_level = getattr(logging, cli_input.log_level.upper())
logger.setLevel(log_level)
repair_logger.setLevel(log_level)
clean_logger.setLevel(log_level)
for log in all_loggers:
log.setLevel(log_level)
logger.info("Starting script")
sentence_pairs = load_sentence_pairs(cli_input.input_file_path)
create_extract_files(cli_input, sentence_pairs)
Expand Down

0 comments on commit 99dc80c

Please sign in to comment.