Issue 494: Refactor normalization logic
Rohan M committed Oct 18, 2024
1 parent a63fffa commit 4aab55c
Showing 3 changed files with 128 additions and 120 deletions.
62 changes: 31 additions & 31 deletions silnlp/common/normalize_extracts.py
@@ -111,6 +111,37 @@ def write_extract_file(path: Path, sentences: List[str]) -> None:
f.write(f"{sentence}\n")


def display_warnings(normalized_summaries_with_line_numbers: List[Tuple[SentenceNormalizationSummary, int]]) -> None:
warnings_with_line_number: List[Tuple[NormalizationWarning, int]] = [
(warning, line_number)
for summary, line_number in normalized_summaries_with_line_numbers
for warning in summary.warnings
]

if len(warnings_with_line_number) > 0:
logger.warning(f"{len(warnings_with_line_number)} warnings found")
for warning, line_number in warnings_with_line_number:
            # Pretty print out all the transformations relative to the original string
# TODO This is just for debugging and will be replaced by better reporting
logger.warning(100 * "=")
sentence = warning.slice.outer
indent = 12 * " "
logger.warning(f"line: {line_number}")
num_blocks_of_10 = len(sentence) // 10 + 1
tens_row = (" " * 9).join([str(i) for i in range(0, num_blocks_of_10)])
# analysis_indent = 12 * " "
logger.warning(indent + tens_row)
logger.warning(indent + "0123456789" * num_blocks_of_10)
logger.warning(indent[0:-1] + f"'{sentence}'")
slice = warning.slice
logger.warning(indent + slice.start_index * " " + len(slice.slice) * "^")
logger.warning(indent + slice.start_index * " " + f"({slice.start_index},{slice.end_index})")
logger.warning(f">>> WARNING_CODE: {warning.warning_code}")
logger.warning(f">>> DESCRIPTION: {warning.description}")
else:
logger.info("No warnings found")


def run(cli_input: CliInput) -> None:
if cli_input.log_level is not None:
log_level = getattr(logging, cli_input.log_level.upper())
@@ -158,37 +189,6 @@ def run(cli_input: CliInput) -> None:
logger.info("Completed script")


def display_warnings(normalized_summaries_with_line_numbers: List[Tuple[SentenceNormalizationSummary, int]]) -> None:
warnings_with_line_number: List[Tuple[NormalizationWarning, int]] = [
(warning, line_number)
for summary, line_number in normalized_summaries_with_line_numbers
for warning in summary.warnings
]

if len(warnings_with_line_number) > 0:
logger.warning(f"{len(warnings_with_line_number)} warnings found")
for warning, line_number in warnings_with_line_number:
            # Pretty print out all the transformations relative to the original string
# TODO This is just for debugging and will be replaced by better reporting
logger.warning(100 * "=")
sentence = warning.slice.outer
indent = 12 * " "
logger.warning(f"line: {line_number}")
num_blocks_of_10 = len(sentence) // 10 + 1
tens_row = (" " * 9).join([str(i) for i in range(0, num_blocks_of_10)])
# analysis_indent = 12 * " "
logger.warning(indent + tens_row)
logger.warning(indent + "0123456789" * num_blocks_of_10)
logger.warning(indent[0:-1] + f"'{sentence}'")
slice = warning.slice
logger.warning(indent + slice.start_index * " " + len(slice.slice) * "^")
logger.warning(indent + slice.start_index * " " + f"({slice.start_index},{slice.end_index})")
logger.warning(f">>> WARNING_CODE: {warning.warning_code}")
logger.warning(f">>> DESCRIPTION: {warning.description}")
else:
logger.info("No warnings found")


def main() -> None:
parser = argparse.ArgumentParser(description="Normalizes extract files")

13 changes: 13 additions & 0 deletions silnlp/common/normalize_tests.py
@@ -130,6 +130,19 @@ def test_consecutive_punctuation_doesnt_prevent_normalizing_single_punctuation_a
expected_normalized="Hello, . there! How are things !?",
)

def test_warnings_generated_for_multiple_consecutive_punctuation(self):
sentence = "Hello, . there ! How , are things !?"
summary = standard_normalizer.normalize(sentence)
consecutive_punctuation_warnings = sorted(
filter(lambda warning: warning.warning_code == WarningCode.MULTIPLE_PUNCTUATION, summary.warnings),
key=lambda warning: warning.slice.start_index,
)
self.assertEqual(len(consecutive_punctuation_warnings), 2)
self.assertEqual(consecutive_punctuation_warnings[0].slice.start_index, 5)
self.assertEqual(consecutive_punctuation_warnings[0].slice.end_index, 8)
self.assertEqual(consecutive_punctuation_warnings[1].slice.start_index, 35)
self.assertEqual(consecutive_punctuation_warnings[1].slice.end_index, 37)

def test_consecutive_punctuation_doesnt_prevent_shrinking_of_consecutive_whitespace_around_it(self):
self.run_test(unnormalized="Hello ,. \t there", expected_normalized="Hello ,. there")

173 changes: 84 additions & 89 deletions silnlp/common/normalizer.py
@@ -62,6 +62,14 @@ def build_slice(start_index: int, end_index: int, outer: str) -> StringSlice:
return StringSlice(start_index=start_index, end_index=end_index, slice=outer[start_index:end_index], outer=outer)


def slice_contains(outer: StringSlice, inner: StringSlice) -> bool:
"""
Returns whether the alleged outer slice contains the inner.
It's implicitly assumed the slices correspond to the same string
"""
return (outer.start_index <= inner.start_index) and (outer.end_index >= inner.end_index)
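
As an illustrative, hypothetical usage (both slices built over the same string):

    s = "Hello ,. there"
    outer = build_slice(5, 9, s)  # " ,. "
    inner = build_slice(6, 8, s)  # ",."
    slice_contains(outer, inner)  # True: 5 <= 6 and 9 >= 8
    slice_contains(inner, outer)  # False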


# TODO - delete when you're confident it's not going to be used
def pretty_print_slice(slice: StringSlice) -> None:
logger.debug(slice.outer)
@@ -123,7 +131,6 @@ def unicode_hex(character: str) -> str:
return "U+" + code.upper()



class Normalizer:
"""
Encapsulates the state required to normalize sentences
@@ -138,16 +145,24 @@ def __init__(self, punctuation_normalization_rules: List[PunctuationNormalizatio
# TODO - ensure no whitespace

self.punctuation_char_2_normalization_rule: Dict[str, PunctuationNormalizationRule] = {
rule.character: rule
for rule in punctuation_normalization_rules
rule.character: rule for rule in punctuation_normalization_rules
}
self.supported_punctuation: Set[str] = set(self.punctuation_char_2_normalization_rule.keys())

self.consecutive_spaces = regex.compile("\\s+")
self.single_whitespace = regex.compile("\\s")

        # Regex-escape all the punctuation characters defined in the rules so that they can be embedded in a regex
escaped_punctuation_chars = "".join(regex.escape(rule.character) for rule in punctuation_normalization_rules)
self.punctuation_regex = regex.compile(f"[{escaped_punctuation_chars}\s]+")
# Matches a single punctuation character with optional whitespace before and after
        # Negative lookbehind and lookahead are used to stop it matching punctuation that's part of a multiple-punctuation group
self.single_punctuation_with_optional_whitespace_regex = regex.compile(
f"(?<![{escaped_punctuation_chars}\s])\s*[{escaped_punctuation_chars}]\s*(?![{escaped_punctuation_chars}\s])"
)
# Matches a starting punctuation char, then any number of punctuation/whitespace, then a closing punctuation character
self.multiple_punctuation_regex = regex.compile(
f"[{escaped_punctuation_chars}][{escaped_punctuation_chars}\s]*[{escaped_punctuation_chars}]"
)

self.not_letters_or_numbers_or_whitespace_regex = regex.compile("""[^\p{N}\p{L}\s]""")
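
As a rough sketch of how the two punctuation patterns above behave — assuming, hypothetically, a rule set covering only the comma, so the escaped character class is just [,]:

    import regex

    single_punct = regex.compile(r"(?<![,\s])\s*[,]\s*(?![,\s])")
    multi_punct = regex.compile(r"[,][,\s]*[,]")

    # A lone comma is matched together with its surrounding whitespace
    print([m.span() for m in single_punct.finditer("Hello  ,  there")])  # [(5, 10)]
    # Groups of 2+ punctuation characters are matched from the first to the last
    # punctuation character, excluding any boundary whitespace
    print([m.span() for m in multi_punct.finditer("a ,, b , , c")])  # [(2, 4), (7, 10)]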

Expand All @@ -157,97 +172,24 @@ def normalize(self, sentence: str) -> SentenceNormalizationSummary:
"""
logger.debug(f"Normalizing '{sentence}'")

# Boundary whitespace is dealt with specially first and later logic analyses the trimmed input
boundary_trim_transformations, sentence_trimmed, trim_offset = self.compute_boundary_transformations(sentence)

# Find groups of punctuation within the trimmed string
# Any results have to be shifted back to the coordinate system of the original sentence
punctuation_slices = find_slices(self.punctuation_regex, sentence_trimmed)

# Categorize each slice found
consecutive_spaces_slices: List[StringSlice] = []
single_punctuation_slices: List[StringSlice] = []
multiple_punctuation_warnings: List[NormalizationWarning] = []
for slice in punctuation_slices:
whitespace_removed = regex.sub(self.consecutive_spaces, "", slice.slice)
# match is completely whitespace
if len(whitespace_removed) == 0:
consecutive_spaces_slices.append(slice)
# match has one punctuation character
elif len(whitespace_removed) == 1:
single_punctuation_slices.append(slice)
            # match has 2+ punctuation characters
else:
# For this case, we don't transform the punctuation, but there's still potentially
# consecutive spaces that can be shrunk down to a single space
# We search within the current slice for consecutive spaces - the coordinate systems
# within those slices need to be shifted back to the main slice
consecutive_spaces_sub_slices = [
shift_slice(spaces_slice, slice.start_index, slice.outer)
for spaces_slice in find_slices(self.consecutive_spaces, slice.slice)
]
consecutive_spaces_slices.extend(consecutive_spaces_sub_slices)

# Generate a warning that spans just the punctuation characters in the slice,
# and not boundary whitespace
left_trimmed = slice.slice.lstrip()
slice_trim_offset = len(slice.slice) - len(left_trimmed)
all_punctuation_trimmed = left_trimmed.rstrip()
warning_start_index = slice.start_index + trim_offset + slice_trim_offset
multiple_punctuation_warnings.append(
NormalizationWarning(
slice=build_slice(
start_index=warning_start_index,
end_index=warning_start_index + len(all_punctuation_trimmed),
outer=sentence,
),
warning_code=WarningCode.MULTIPLE_PUNCTUATION,
description="Multiple consecutive punctuation characters (ignoring whitespace) - currently this is not normalized",
)
)
logger.debug(f" #consecutive space slices={len(consecutive_spaces_slices)}")
logger.debug(f" #single punctuation slices={len(single_punctuation_slices)}")
logger.debug(f"#multiple punctuation slices={len(multiple_punctuation_warnings)}")
all_transformations: List[SentenceTransformation] = self.find_transformations_sorted(sentence)

# Convert consecutive space slices into transformations
consecutive_spaces_transformations = [
SentenceTransformation(
slice=shift_slice(slice, trim_offset, sentence),
replacement=" ",
description="Consecutive whitespace found",
multiple_punctuation_warnings: List[NormalizationWarning] = [
NormalizationWarning(
slice=slice,
warning_code=WarningCode.MULTIPLE_PUNCTUATION,
description="Multiple consecutive punctuation characters (ignoring whitespace) - currently this is not normalized",
)
for slice in consecutive_spaces_slices
# Don't generate transformations for single spaces
if slice.slice != " "
for slice in find_slices(self.multiple_punctuation_regex, sentence)
]

# Convert single punctuation slices into transformations
single_punctuation_transformations: List[SentenceTransformation] = []
for slice in single_punctuation_slices:

# Figure out the punctuation character that was found in the slice and find the associated normalization rule for it
punctuation_char = regex.sub(self.consecutive_spaces, "", slice.slice)
rule = self.punctuation_char_2_normalization_rule[punctuation_char]
normalized = self.normalize_single_punctuation_slice(rule, slice)
# Some normalizations couldn't be applied or don't actually transform the text
if normalized is not None and normalized != slice.slice:
single_punctuation_transformations.append(
SentenceTransformation(
slice=shift_slice(slice, trim_offset, sentence),
replacement=normalized,
description=f"Punctuation ({punctuation_char}) normalized by rule {rule.category}",
)
)

false_negative_warnings: List[NormalizationWarning] = self.search_false_negatives(sentence)

# TODO - add other kinds of warnings
all_warnings = multiple_punctuation_warnings + false_negative_warnings

all_transformations = sorted(
consecutive_spaces_transformations + single_punctuation_transformations + boundary_trim_transformations,
key=lambda transformation: transformation.slice.start_index,
)
logger.debug(f"#transformations={len(all_transformations)}")
logger.debug(f"#warnings={len(all_warnings)}")

        # Pretty print out all the transformations relative to the original string
# TODO This is just for debugging and will be replaced by better reporting
Expand All @@ -269,8 +211,6 @@ def normalize(self, sentence: str) -> SentenceNormalizationSummary:
+ transformation.description
)

# TODO - check the transformations aren't overlapping

# Rebuild the string by applying all the transformations
# We extract the parts unaffected by normalization, then rebuild by interleaving them with the normalized parts
# Example:
@@ -301,6 +241,57 @@ def normalize(self, sentence: str) -> SentenceNormalizationSummary:
warnings=all_warnings,
)
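
The rebuild described above can be sketched as follows (a minimal sketch, not the committed code: apply_transformations is a hypothetical name, the module's SentenceTransformation/StringSlice types are assumed, and the transformations are assumed sorted and non-overlapping):

    def apply_transformations(sentence: str, transformations: List[SentenceTransformation]) -> str:
        parts: List[str] = []
        cursor = 0
        for transformation in transformations:
            # Copy the untouched span before this transformation, then its replacement
            parts.append(sentence[cursor : transformation.slice.start_index])
            parts.append(transformation.replacement)
            cursor = transformation.slice.end_index
        # Copy whatever remains after the last transformation
        parts.append(sentence[cursor:])
        return "".join(parts)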

def find_transformations_sorted(self, sentence: str) -> List[SentenceTransformation]:
# Boundary whitespace is dealt with specially first and later logic analyses the trimmed input
boundary_trim_transformations, sentence_trimmed, trim_offset = self.compute_boundary_transformations(sentence)

single_punctuation_transformations: List[SentenceTransformation] = []
for slice in find_slices(self.single_punctuation_with_optional_whitespace_regex, sentence_trimmed):
# Figure out the punctuation character that was found in the slice and find the associated normalization rule for it
punctuation_char = regex.sub(self.consecutive_spaces, "", slice.slice)
rule = self.punctuation_char_2_normalization_rule[punctuation_char]
normalized = self.normalize_single_punctuation_slice(rule, slice)
# Some normalizations couldn't be applied or don't actually transform the text
if normalized is not None and normalized != slice.slice:
single_punctuation_transformations.append(
SentenceTransformation(
slice=shift_slice(slice, trim_offset, sentence),
replacement=normalized,
description=f"Punctuation ({punctuation_char}) normalized by rule {rule.category}",
)
)

consecutive_spaces_transformations = [
SentenceTransformation(
slice=shift_slice(slice, trim_offset, sentence),
replacement=" ",
description="Whitespace normalized to a single space",
)
for slice in find_slices(self.consecutive_spaces, sentence_trimmed)
# Don't create transformations for a single space as the before and after are identical
if slice.slice != " "
]
# Some of these will overlap with the single punctuation transformations which are already going to normalize whitespace
# so those need to be knocked out
# Note this needs to be done _after_ shifting the slice
consecutive_spaces_transformations = list(
filter(
lambda consecutive_space_transformation: not any(
slice_contains(
outer=single_punctuation_transformation.slice, inner=consecutive_space_transformation.slice
)
for single_punctuation_transformation in single_punctuation_transformations
),
consecutive_spaces_transformations,
)
)

# TODO - put in a general check that the transformations aren't overlapping
return sorted(
boundary_trim_transformations + consecutive_spaces_transformations + single_punctuation_transformations,
key=lambda transformation: transformation.slice.start_index,
)
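
A hypothetical walk-through of the containment filter, again assuming a rule set covering only the comma:

    # sentence = "Hello  ,  there"
    #   single-punctuation slice:  (5, 10) -> "  ,  "
    #   consecutive-space slices:  (5, 7) and (8, 10)
    # Both space slices satisfy slice_contains(punctuation_slice, space_slice),
    # so they are dropped and only the punctuation transformation rewrites (5, 10).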

def normalize_single_punctuation_slice(
self, punctuation_rule: PunctuationNormalizationRule, slice: StringSlice
) -> Optional[str]:
@@ -391,7 +382,11 @@ def search_false_negatives(self, sentence: str) -> List[NormalizationWarning]:
"""
potential_false_negatives = find_slices(self.not_letters_or_numbers_or_whitespace_regex, sentence)
return [
NormalizationWarning(slice, WarningCode.FALSE_NEGATIVE_CANDIDATE, f"Character '{slice.slice}' ({unicode_hex(slice.slice)}) is not a letter or digit or whitespace and is not listed as punctuation. Potential false negative.")
NormalizationWarning(
slice,
WarningCode.FALSE_NEGATIVE_CANDIDATE,
f"Character '{slice.slice}' ({unicode_hex(slice.slice)}) is not a letter or digit or whitespace and is not listed as punctuation. Potential false negative.",
)
for slice in potential_false_negatives
if slice.slice not in self.supported_punctuation
]
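
A small illustration of the candidate search, assuming, hypothetically, that only "," and "." are listed as punctuation:

    import regex

    not_letters_numbers_ws = regex.compile(r"[^\p{N}\p{L}\s]")
    supported = {",", "."}
    candidates = [m.group() for m in not_letters_numbers_ws.finditer("«Hello, world!»")]
    print(candidates)                                     # ['«', ',', '!', '»']
    print([c for c in candidates if c not in supported])  # ['«', '!', '»'] -> flagged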
