"""Tests for creating labeled Sentences from character-level entity offsets (test_sentence_labeling.py)."""

from typing import cast

import pytest

from flair.data import Sentence
from flair.training_utils import CharEntity, TokenEntity, create_labeled_sentence_from_entity_offsets


@pytest.fixture(params=["resume1.txt"])
def resume(request, resources_path) -> str:
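    """Read the raw text of the requested resume file from the test resources directory."""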
    filepath = resources_path / "text_sequences" / request.param
    with open(filepath, encoding="utf8") as file:
        text_content = file.read()
    return text_content


@pytest.fixture
def parsed_resume_dict(resume) -> dict:
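    """Pair the raw resume text with a set of dummy character-offset entities."""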
    return {
        "raw_text": resume,
        "entities": [
            CharEntity(20, 40, "dummy_label1", "Dummy Text 1"),
            CharEntity(250, 300, "dummy_label2", "Dummy Text 2"),
            CharEntity(700, 810, "dummy_label3", "Dummy Text 3"),
            CharEntity(3900, 4000, "dummy_label4", "Dummy Text 4"),
        ],
    }


@pytest.fixture
def small_token_limit_resume() -> dict:
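    """Short synthetic resume with character-offset entities, used for small-token-limit chunking tests."""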
    return {
        "raw_text": "Professional Clown June 2020 - August 2021 Entertaining students of all ages. Blah Blah Blah "
        "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah "
        "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah "
        "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah "
        "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Gained "
        "proficiency in juggling and scaring children.",
        "entities": [
            CharEntity(0, 18, "EXPERIENCE.TITLE", ""),
            CharEntity(19, 29, "DATE.START_DATE", ""),
            CharEntity(31, 42, "DATE.END_DATE", ""),
            CharEntity(450, 510, "EXPERIENCE.DESCRIPTION", ""),
        ],
    }


@pytest.fixture
def small_token_limit_response() -> list[Sentence]:
    """Recreates expected response Sentences."""
    chunk0 = Sentence("Professional Clown June 2020 - August 2021 Entertaining students of")
    chunk0[0:2].add_label("Professional Clown", "EXPERIENCE.TITLE")
    chunk0[2:4].add_label("June 2020", "DATE.START_DATE")
    chunk0[5:7].add_label("August 2021", "DATE.END_DATE")

    chunk1 = Sentence("Blah Blah Blah Blah Blah Blah Blah Bl")

    chunk2 = Sentence("ah Blah Gained proficiency in juggling and scaring children .")
    chunk2[0:10].add_label("ah Blah Gained proficiency in juggling and scaring children .", "EXPERIENCE.DESCRIPTION")

    return [chunk0, chunk1, chunk2]


class TestChunking:
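    """Tests for create_labeled_sentence_from_entity_offsets: tokenization, entity labeling, and chunking behaviour."""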

    def test_empty_string(self):
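        """An empty input string should yield a Sentence with no tokens."""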
        sentence = create_labeled_sentence_from_entity_offsets("", [])
        assert len(sentence) == 0

    def check_tokens(self, sentence: Sentence, expected_tokens: list[str]):
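        """Assert that the sentence's token texts match the expected token strings."""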
        assert len(sentence.tokens) == len(expected_tokens)
        assert [token.text for token in sentence.tokens] == expected_tokens
        for token, expected_token in zip(sentence.tokens, expected_tokens):
            assert token.text == expected_token

    def check_token_entities(self, sentence: Sentence, expected_labels: list[TokenEntity]):
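        """Assert that each label's value and token span match the expected TokenEntity."""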
        assert len(sentence.labels) == len(expected_labels)
        for label, expected_label in zip(sentence.labels, expected_labels):
            assert label.value == expected_label.label
            span = cast(Sentence, label.data_point)
            assert span.tokens[0]._internal_index is not None
            assert span.tokens[0]._internal_index - 1 == expected_label.start_token_idx
            assert span.tokens[-1]._internal_index is not None
            assert span.tokens[-1]._internal_index - 1 == expected_label.end_token_idx

    def check_split_entities(self, entity_labels, sentence: Sentence):
        """Ensure that no entities are split over chunks (except entities longer than the token limit)."""
        for entity in entity_labels:
            entity_start, entity_end = entity.start_char_idx, entity.end_char_idx
            assert entity_start >= 0 and entity_end <= len(
                sentence
            ), f"Entity {entity} is not within a single chunk interval"

    @pytest.mark.parametrize(
        "test_text, expected_text",
        [
            ("test text", "test text"),
            ("a", "a"),
            ("this ", "this"),
        ],
    )
    def test_short_text(self, test_text: str, expected_text: str):
        """Short texts that should fit nicely into a single chunk."""
        sentence = create_labeled_sentence_from_entity_offsets(test_text, [])
        assert sentence.text == expected_text

    def test_create_labeled_sentence(self, parsed_resume_dict: dict):
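        """Smoke test: labeling the full resume text with its entities should not raise."""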
        create_labeled_sentence_from_entity_offsets(parsed_resume_dict["raw_text"], parsed_resume_dict["entities"])

    @pytest.mark.parametrize(
        "test_text, entities, expected_tokens, expected_labels",
        [
            (
                "Led a team of five engineers. It's important to note the project's success. We've implemented state-of-the-art technologies. Co-ordinated efforts with cross-functional teams.",
                [
                    CharEntity(0, 28, "RESPONSIBILITY", "Led a team of five engineers"),
                    CharEntity(30, 74, "ACHIEVEMENT", "It's important to note the project's success"),
                    CharEntity(76, 123, "ACHIEVEMENT", "We've implemented state-of-the-art technologies"),
                    CharEntity(125, 173, "RESPONSIBILITY", "Co-ordinated efforts with cross-functional teams"),
                ],
                [
                    "Led",
                    "a",
                    "team",
                    "of",
                    "five",
                    "engineers",
                    ".",
                    "It",
                    "'s",
                    "important",
                    "to",
                    "note",
                    "the",
                    "project",
                    "'s",
                    "success",
                    ".",
                    "We",
                    "'ve",
                    "implemented",
                    "state-of-the-art",
                    "technologies",
                    ".",
                    "Co-ordinated",
                    "efforts",
                    "with",
                    "cross-functional",
                    "teams",
                    ".",
                ],
                [
                    TokenEntity(0, 5, "RESPONSIBILITY"),
                    TokenEntity(7, 15, "ACHIEVEMENT"),
                    TokenEntity(17, 21, "ACHIEVEMENT"),
                    TokenEntity(23, 27, "RESPONSIBILITY"),
                ],
            ),
        ],
    )
    def test_contractions_and_hyphens(
        self, test_text: str, entities: list[CharEntity], expected_tokens: list[str], expected_labels: list[TokenEntity]
    ):
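        """Contractions and hyphenated words should be tokenized and labeled at the expected token offsets."""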
        sentence = create_labeled_sentence_from_entity_offsets(test_text, entities)
        self.check_tokens(sentence, expected_tokens)
        self.check_token_entities(sentence, expected_labels)

    @pytest.mark.parametrize(
        "test_text, entities",
        [
            (
                "This is a long text. " * 100,
                [CharEntity(0, 1000, "dummy_label1", "Dummy Text 1")],
            )
        ],
    )
    def test_long_text(self, test_text: str, entities: list[CharEntity]):
        """Test for handling long texts that should be split into multiple chunks."""
        create_labeled_sentence_from_entity_offsets(test_text, entities)

    @pytest.mark.parametrize(
        "test_text, entities, expected_labels",
        [
            (
                "Hello! Is your company hiring? I am available for employment. Contact me at 5:00 p.m.",
                [
                    CharEntity(0, 6, "LABEL", "Hello!"),
                    CharEntity(7, 30, "LABEL", "Is your company hiring?"),
                    CharEntity(31, 61, "LABEL", "I am available for employment."),
                    CharEntity(62, 85, "LABEL", "Contact me at 5:00 p.m."),
                ],
                [
                    TokenEntity(0, 1, "LABEL"),
                    TokenEntity(2, 6, "LABEL"),
                    TokenEntity(7, 12, "LABEL"),
                    TokenEntity(13, 18, "LABEL"),
                ],
            )
        ],
    )
    def test_text_with_punctuation(
        self, test_text: str, entities: list[CharEntity], expected_labels: list[TokenEntity]
    ):
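        """Entity spans containing punctuation should map to the expected token index ranges."""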
        sentence = create_labeled_sentence_from_entity_offsets(test_text, entities)
        self.check_token_entities(sentence, expected_labels)