From 082e84569841625ecf467c05bfc6e9c12464ad3c Mon Sep 17 00:00:00 2001
From: Matt Buchovecky
Date: Mon, 3 Feb 2025 20:28:07 -0800
Subject: [PATCH] fix: doc strings and function name

---
 flair/training_utils.py         | 11 ++++++-----
 tests/test_sentence_labeling.py | 14 +++++++-------
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/flair/training_utils.py b/flair/training_utils.py
index ce15bdb6e5..915df609a8 100644
--- a/flair/training_utils.py
+++ b/flair/training_utils.py
@@ -463,24 +463,25 @@ def create_labeled_sentence_from_tokens(
     return sentence
 
 
-def create_labeled_sentence(
+def create_labeled_sentence_from_entity_offsets(
     text: str,
     entities: list[CharEntity],
     token_limit: float = inf,
 ) -> Sentence:
-    """Chunks and labels a text from a list of entity annotations.
+    """Creates a labeled sentence from a text and a list of entity annotations.
 
     The function explicitly tokenizes the text and labels separately, ensuring entity labels are
-    not partially split across tokens.
+    not partially split across tokens. The sentence is truncated if a token limit is set.
 
     Args:
         text (str): The full text to be tokenized and labeled.
         entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the
             format (start_char_index, end_char_index, entity_class, entity_text).
-        token_limit: numerical value that determines the maximum size of a chunk. use inf to not perform chunking
+        token_limit: numerical value that determines the maximum token length of the sentence.
+            Use inf to disable truncation.
 
     Returns:
-        A list of labeled Sentence objects representing the chunks of the original text
+        A labeled Sentence object representing the text and entity annotations.
     """
     tokens: list[Token] = []
     current_index = 0
diff --git a/tests/test_sentence_labeling.py b/tests/test_sentence_labeling.py
index 56742da4df..0bfb6dce94 100644
--- a/tests/test_sentence_labeling.py
+++ b/tests/test_sentence_labeling.py
@@ -3,7 +3,7 @@
 import pytest
 
 from flair.data import Sentence
-from flair.training_utils import CharEntity, TokenEntity, create_labeled_sentence
+from flair.training_utils import CharEntity, TokenEntity, create_labeled_sentence_from_entity_offsets
 
 
 @pytest.fixture(params=["resume1.txt"])
@@ -63,7 +63,7 @@ def small_token_limit_response() -> list[Sentence]:
 
 class TestChunking:
     def test_empty_string(self):
-        sentences = create_labeled_sentence("", [])
+        sentences = create_labeled_sentence_from_entity_offsets("", [])
         assert len(sentences) == 0
 
     def check_tokens(self, sentence: Sentence, expected_tokens: list[str]):
@@ -101,11 +101,11 @@ def check_split_entities(self, entity_labels, sentence: Sentence):
     )
     def test_short_text(self, test_text: str, expected_text: str):
         """Short texts that should fit nicely into a single chunk."""
-        chunks = create_labeled_sentence(test_text, [])
+        chunks = create_labeled_sentence_from_entity_offsets(test_text, [])
         assert chunks.text == expected_text
 
     def test_create_labeled_sentence(self, parsed_resume_dict: dict):
-        create_labeled_sentence(parsed_resume_dict["raw_text"], parsed_resume_dict["entities"])
+        create_labeled_sentence_from_entity_offsets(parsed_resume_dict["raw_text"], parsed_resume_dict["entities"])
 
     @pytest.mark.parametrize(
         "test_text, entities, expected_tokens, expected_labels",
@@ -161,7 +161,7 @@ def test_create_labeled_sentence(self, parsed_resume_dict: dict):
     def test_contractions_and_hyphens(
         self, test_text: str, entities: list[CharEntity], expected_tokens: list[str], expected_labels: list[TokenEntity]
     ):
-        sentence = create_labeled_sentence(test_text, entities)
+        sentence = create_labeled_sentence_from_entity_offsets(test_text, entities)
         self.check_tokens(sentence, expected_tokens)
         self.check_token_entities(sentence, expected_labels)
 
@@ -176,7 +176,7 @@ def test_contractions_and_hyphens(
     )
     def test_long_text(self, test_text: str, entities: list[CharEntity]):
         """Test for handling long texts that should be split into multiple chunks."""
-        create_labeled_sentence(test_text, entities)
+        create_labeled_sentence_from_entity_offsets(test_text, entities)
 
     @pytest.mark.parametrize(
         "test_text, entities, expected_labels",
@@ -201,5 +201,5 @@ def test_long_text(self, test_text: str, entities: list[CharEntity]):
     def test_text_with_punctuation(
         self, test_text: str, entities: list[CharEntity], expected_labels: list[TokenEntity]
     ):
-        sentence = create_labeled_sentence(test_text, entities)
+        sentence = create_labeled_sentence_from_entity_offsets(test_text, entities)
         self.check_token_entities(sentence, expected_labels)
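For context, a minimal usage sketch of the renamed helper, based only on the signature and
docstring in this patch. The example text and character offsets are made up for illustration;
per the docstring, each entity is an ordered non-overlapping tuple of
(start_char_index, end_char_index, entity_class, entity_text).

```python
# Hypothetical example, assuming flair with this patch applied.
from flair.training_utils import CharEntity, create_labeled_sentence_from_entity_offsets

text = "Jane Doe works at Acme Corp."
entities = [
    CharEntity(0, 8, "PERSON", "Jane Doe"),   # "Jane Doe" spans chars 0-8
    CharEntity(18, 27, "ORG", "Acme Corp"),   # "Acme Corp" spans chars 18-27
]

# Returns a single labeled Sentence (not a list of chunks); pass a numeric
# token_limit to truncate, or leave the default inf to keep all tokens.
sentence = create_labeled_sentence_from_entity_offsets(text, entities)
for label in sentence.get_labels():
    print(label)  # token-span labels for PERSON and ORG
```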