Commit

Issue 494: Add basic filtering and cleaning to extract_xri script

Rohan M committed Sep 13, 2024
1 parent 3f3a287 commit f5386d6
Showing 1 changed file with 144 additions and 33 deletions.
177 changes: 144 additions & 33 deletions silnlp/common/extract_xri.py
@@ -48,12 +48,14 @@

from dataclasses import dataclass
from enum import Enum
from logging import Logger
from pathlib import Path
from typing import List, Optional


logger = logging.getLogger(__package__ + ".extract_xri")
repair_logger = logging.getLogger(logger.name + ".repair")
clean_logger = logging.getLogger(logger.name + ".clean")
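
A note on the logger wiring: the dotted names make the repair and clean loggers children of the main script logger, so their records propagate to any handlers attached to the parent. A minimal sketch of the resulting hierarchy, assuming __package__ resolves to "silnlp.common" (inferred from the file path, not stated in the commit):

# Illustrative sketch only, not part of the commit.
import logging

parent_logger = logging.getLogger("silnlp.common.extract_xri")        # assumed name
repair_child = logging.getLogger("silnlp.common.extract_xri.repair")
clean_child = logging.getLogger("silnlp.common.extract_xri.clean")

# Dotted names alone establish the parent/child relationship.
assert repair_child.parent is parent_logger
assert clean_child.parent is parent_logger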


class Split(Enum):
@@ -68,6 +70,21 @@ class SentencePair:
source: str
target: str
split: Split
transformation_applied: bool


@dataclass(frozen=True)
class ColumnSchema:
"""
Represents the mapping of column indexes to data in the tsv.
This is calculated on the fly from the indices of the column names in the header row.
Most of the time it will be 0, 1, 2, 3 respectively for the fields below.
"""

id_column_index: int
source_column_index: int
target_column_index: int
split_column_index: int

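To make the mapping concrete, a minimal sketch of how a typical header row yields the usual 0-3 schema (hypothetical header data, not part of the commit):

# Illustrative sketch only: the schema is computed from the header,
# so any column order works, not just the typical one shown here.
header = ["id", "source", "target", "split"]
schema = ColumnSchema(
    id_column_index=header.index("id"),          # 0
    source_column_index=header.index("source"),  # 1
    target_column_index=header.index("target"),  # 2
    split_column_index=header.index("split"),    # 3
)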

@dataclass(frozen=True)
@@ -102,12 +119,16 @@ def get_column_index(column_name: str) -> int:
else:
raise Exception(f"Unable to find expected column '{column_name}' in input file")

id_column_index = get_column_index("id")
source_column_index = get_column_index("source")
target_column_index = get_column_index("target")
split_column_index = get_column_index("split")
column_schema = ColumnSchema(
id_column_index=get_column_index("id"),
source_column_index=get_column_index("source"),
target_column_index=get_column_index("target"),
split_column_index=get_column_index("split"),
)

logger.debug(
f"Column indexes: id={id_column_index} source={source_column_index} target={target_column_index} split={split_column_index}"
f"Column indexes: id={column_schema.id_column_index} source={column_schema.source_column_index} "
+ f"target={column_schema.target_column_index} split={column_schema.split_column_index}"
)

logger.debug("Checking all rows contain 4 cells")
@@ -119,21 +140,23 @@ def get_column_index(column_name: str) -> int:
# - trailing tab characters added causing extra rows
logger.warning("Not all rows contain 4 cells")

repaired = repair_if_necessary(id_column_index, data_rows)
repaired = repair_if_necessary(column_schema.id_column_index, data_rows)

def parse_and_log_id(row_index: int, row: List[str]) -> int:
logger.debug(f"Loading row {row_index} into sentence pair structure: {row}")
return int(row[id_column_index])
return filter_and_clean(repaired, column_schema)

return [
SentencePair(
id=parse_and_log_id(row_index, row),
source=row[source_column_index],
target=row[target_column_index],
split=Split[row[split_column_index]],
)
for row_index, row in enumerate(repaired)
]

def try_extract_id(row: List[str], id_column_index: int, log: Logger) -> Optional[int]:
"""Tries to find a numerical id in the row passed at the expected position"""
if id_column_index >= len(row):
log.debug("Short row")
return None
else:
id_str = row[id_column_index]
if id_str.isdigit():
return int(id_str)
else:
log.debug(f"Can't parse id cell: '{id_str}' - assuming row is broken")
return None

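The branch behaviour, as a quick illustrative sketch (demo values and logger name are invented, not part of the commit):

# Illustrative sketch only: an int comes back when the id cell parses;
# short rows and non-numeric cells yield None with a debug log.
import logging

demo_log = logging.getLogger("demo")  # hypothetical logger for the sketch
assert try_extract_id(["42", "src", "tgt", "train"], 0, demo_log) == 42
assert try_extract_id(["src", "tgt"], 3, demo_log) is None  # short row
assert try_extract_id(["x1", "src"], 0, demo_log) is None   # "x1".isdigit() is False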

def repair_if_necessary(id_column_index: int, rows: List[List[str]]) -> List[List[str]]:
@@ -163,19 +186,6 @@ def repair_if_necessary(id_column_index: int, rows: List[List[str]]) -> List[List[str]]:
# then the number would split to the next line and be mistaken for an id.
# This is very unlikely, and solving it would increase complexity, so we don't repair that case.

def try_extract_id(row: List[str]) -> Optional[int]:
"""Tries to find a numerical id in the row passed at the expected position"""
if id_column_index >= len(row):
repair_logger.debug("Short row")
return None
else:
id_str = row[id_column_index]
if id_str.isdigit():
return int(id_str)
else:
repair_logger.debug(f"Can't parse id cell: '{id_str}' - assuming row is broken")
return None

repaired: List[List[str]] = []
# Represents the accumulated row data that is gradually populated when a sentence pair is split across many lines
current_row: List[str] = []
Expand All @@ -184,7 +194,7 @@ def try_extract_id(row: List[str]) -> Optional[int]:
if not next_row:
repair_logger.debug(f"Empty line detected at row index={row_index}")
continue
id_opt = try_extract_id(next_row)
id_opt = try_extract_id(next_row, id_column_index, repair_logger)
repair_logger.debug(f"Examining row index={row_index} id={id_opt}: {next_row}")
if id_opt is not None:
# We are starting a new row
@@ -209,6 +219,104 @@
return repaired

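For a concrete picture of the repair pass, consider a record whose target field contained a stray newline, so the export split it over two physical lines. A hypothetical sketch (invented data; the exact cell-merging behaviour lives in the elided part of the hunk above):

# Illustrative sketch only: the second physical line has no numeric id
# at index 0, so it is treated as a continuation of the id=1 record
# rather than as a record of its own.
broken = [
    ["1", "source text", "target text, first half", "train"],
    ["target text, second half"],  # continuation line, no id
    ["2", "next source", "next target", "dev"],
]
fixed = repair_if_necessary(0, broken)
assert len(fixed) == 2  # the continuation was folded into the first record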

def filter_and_clean(
rows: List[List[str]],
column_schema: ColumnSchema,
) -> List[SentencePair]:
"""
Applies basic checking and cleaning to the data.
Rows of data that can be processed are cleaned and transformed to a structured SentencePair.
Rows of data that can't be meaningfully processed are excluded.
"""
sentence_pairs: List[SentencePair] = []

clean_logger.info("Starting filtering and cleaning stage")
required_columns = (
max(
[
column_schema.id_column_index,
column_schema.source_column_index,
column_schema.target_column_index,
column_schema.split_column_index,
]
)
+ 1
)
clean_logger.debug(f"Required number of columns for each stage is {required_columns}")

for row_index, row in enumerate(rows):
transformation_applied = False
clean_logger.debug(f"Processing row index={row_index}: {row}")

if len(row) < required_columns:
clean_logger.warning(
f"Row found with only {len(row)} cells, at least {required_columns} expected. Ignoring: {row}"
)
continue

id = try_extract_id(row, column_schema.id_column_index, clean_logger)

if id is None:
# This case should be virtually impossible based on how the repair logic works
clean_logger.warning(
f"Unable to identify id in row - potentially badly formatted or a bug in the repair logic. Ignoring: {row}"
)
continue

def trim(sentence: str, description: str) -> str:
trimmed = sentence.strip()
if trimmed != sentence:
clean_logger.debug(
f"Boundary whitespace trimmed off '{description}' field. "
+ f"Number of trimmed chars: {len(sentence) - len(trimmed)}"
)
nonlocal transformation_applied
transformation_applied = True
return trimmed

source = trim(row[column_schema.source_column_index], "source")
target = trim(row[column_schema.target_column_index], "target")

if target == "!":
clean_logger.debug("Target sentence is '!' indicating it is not translated. Ignoring.")
continue

split_text = trim(row[column_schema.split_column_index], "split").lower()
if split_text not in ["train", "dev", "test"]:
clean_logger.warning(
f"Split value '{split_text}' is not a recognized value. Keeping sentence but assigning to training data"
)
transformation_applied = True
split = Split.train
else:
split = Split[split_text]

clean_logger.debug(
f"Successfully parsed and cleaned row index={row_index} id={id}. "
+ f"Transformations applied? {transformation_applied}"
)

sentence_pairs.append(
SentencePair(
id=id,
source=source,
target=target,
split=split,
transformation_applied=transformation_applied,
)
)

num_modified = len([sentence_pair for sentence_pair in sentence_pairs if sentence_pair.transformation_applied])
clean_logger.info(
"Finished filtering and cleaning stage. "
+ f"{len(rows)} rows ingested. "
+ f"{len(sentence_pairs)} survived. "
+ f"{len(rows) - len(sentence_pairs)} removed. "
+ f"{num_modified} survivors were transformed in some way."
)
return sentence_pairs

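Putting the whole cleaning stage together, a minimal end-to-end sketch with invented rows (this assumes the Split enum has lowercase train/dev/test members, consistent with the Split[split_text] lookup above):

# Illustrative sketch only: one already-clean row, one row needing
# transformations, and one untranslated row ("!") that is filtered out.
schema = ColumnSchema(0, 1, 2, 3)
rows = [
    ["1", "clean source", "clean target", "train"],
    ["2", " padded source ", "padded target", "DEV"],  # whitespace + casing
    ["3", "some source", "!", "test"],                 # untranslated target
]
pairs = filter_and_clean(rows, schema)
assert len(pairs) == 2                          # the "!" row was dropped
assert pairs[1].source == "padded source"       # boundary whitespace trimmed
assert pairs[1].split is Split.dev              # "DEV" normalized to "dev"
assert pairs[1].transformation_applied is True  # the trim marked the row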

def write_output_file(filepath: Path, sentences: List[str]) -> None:
logger.debug(f"Writing {len(sentences)} sentences to file: {filepath}")
with open(filepath, "w", encoding="utf-8") as f:
@@ -267,6 +375,7 @@ def run(cli_input: CliInput) -> None:
log_level = getattr(logging, cli_input.log_level.upper())
logger.setLevel(log_level)
repair_logger.setLevel(log_level)
clean_logger.setLevel(log_level)
logger.info("Starting script")
sentence_pairs = load_sentence_pairs(cli_input.input_file_path)
create_extract_files(cli_input, sentence_pairs)
@@ -283,7 +392,9 @@ def main() -> None:
parser.add_argument(
"-output", help="Optional path to the output directory where extract files are generated", type=str
)
parser.add_argument("-log_level", help="Optional parameter to override the default logging level for this script", type=str)
parser.add_argument(
"-log_level", help="Optional parameter to override the default logging level for this script", type=str
)
args = parser.parse_args()

cli_input = CliInput(