-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
Copy path test_datasets_biomedical.py
139 lines (119 loc) · 4.56 KB
/
test_datasets_biomedical.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional
from flair.datasets.biomedical import (
CoNLLWriter,
Entity,
InternalBioNerDataset,
filter_nested_entities,
)
from flair.splitter import NoSentenceSplitter, SentenceSplitter
from flair.tokenization import SpaceTokenizer
# Re-enable propagation on flair's package logger so records bubble up to the
# root logger, where pytest's ``caplog`` fixture (see test_filter_nested_entities)
# can capture and assert on them.
logger = logging.getLogger("flair")
logger.propagate = True
def test_write_to_conll():
    """The CoNLL writer emits one token per line with the expected BIO tags.

    Single-token entities get a lone ``B-E``; the multi-token mention
    "a long entity3" is tagged ``B-E I-E I-E``, and the final token of the
    document carries the ``-`` (no-space-after) marker.
    """
    document = "This is entity1 entity2 and a long entity3"

    def span_of(mention: str):
        # Character offsets of *mention* inside the document text.
        start = document.find(mention)
        return (start, start + len(mention))

    dataset = InternalBioNerDataset(
        documents={"1": document},
        entities_per_document={
            "1": [
                Entity(span_of("entity1"), "E"),
                Entity(span_of("entity2"), "E"),
                Entity(span_of("a long entity3"), "E"),
            ]
        },
        entity_types=["E"],
    )

    assert_conll_writer_output(
        dataset,
        [
            "This O +",
            "is O +",
            "entity1 B-E +",
            "entity2 B-E +",
            "and O +",
            "a B-E +",
            "long I-E +",
            "entity3 I-E -",
        ],
    )
def test_conll_writer_one_token_multiple_entities1():
    """Two overlapping mentions inside one token collapse to a single B- tag.

    Both partial spans cover parts of the token "entity1"; the writer must
    label that token exactly once, and "entity2" keeps its own tag.
    """
    document = "This is entity1 entity2"
    prefix_start = document.find("entity1")
    suffix_start = document.find("tity1")
    second_start = document.find("entity2")

    dataset = InternalBioNerDataset(
        documents={"1": document},
        entities_per_document={
            "1": [
                Entity((prefix_start, prefix_start + 2), "E"),
                Entity((suffix_start, suffix_start + 5), "E"),
                Entity((second_start, second_start + len("entity2")), "E"),
            ]
        },
        entity_types=["E"],
    )

    assert_conll_writer_output(dataset, ["This O +", "is O +", "entity1 B-E +", "entity2 B-E -"])
def test_conll_writer_one_token_multiple_entities2():
    """Overlapping partial mentions tag only their own token.

    Same overlapping spans on "entity1" as the sibling test, but with no
    annotation for "entity2" — that token must stay ``O``.
    """
    document = "This is entity1 entity2"
    prefix_start = document.find("entity1")
    suffix_start = document.find("tity1")

    dataset = InternalBioNerDataset(
        documents={"1": document},
        entities_per_document={
            "1": [
                Entity((prefix_start, prefix_start + 2), "E"),
                Entity((suffix_start, suffix_start + 5), "E"),
            ]
        },
        entity_types=["E"],
    )

    assert_conll_writer_output(dataset, ["This O +", "is O +", "entity1 B-E +", "entity2 O -"])
def assert_conll_writer_output(
    dataset: InternalBioNerDataset,
    expected_output: list[str],
    sentence_splitter: Optional[SentenceSplitter] = None,
):
    """Write *dataset* to a temporary CoNLL file and assert its contents.

    Args:
        dataset: In-memory dataset to serialize.
        expected_output: Expected non-empty, whitespace-stripped lines of the
            written file, in order.
        sentence_splitter: Optional splitter; defaults to a
            ``NoSentenceSplitter`` over a ``SpaceTokenizer``.

    Raises:
        AssertionError: If the written lines differ from *expected_output*.
    """
    fd, outfile_path = tempfile.mkstemp()
    # Close the mkstemp descriptor right away: CoNLLWriter (re-)opens the
    # file by path itself, so holding the fd for the whole test both leaks
    # it until the finally block and can fail on platforms that do not
    # allow re-opening an already-open temp file (e.g. Windows).
    os.close(fd)
    try:
        if sentence_splitter is None:
            sentence_splitter = NoSentenceSplitter(tokenizer=SpaceTokenizer())
        writer = CoNLLWriter(sentence_splitter=sentence_splitter)
        writer.write_to_conll(dataset, Path(outfile_path))
        with open(outfile_path) as f:
            # Keep only non-blank lines, stripped of surrounding whitespace.
            contents = [line.strip() for line in f if line.strip()]
    finally:
        os.remove(outfile_path)
    assert contents == expected_output
def test_filter_nested_entities(caplog):
    """filter_nested_entities drops nested/overlapping mentions and warns.

    Each document exercises a different containment/overlap configuration;
    *expected* holds the mentions that must survive filtering. The function
    is also expected to log a warning about modifying the corpus, which is
    captured via pytest's ``caplog`` fixture.
    """
    entities_per_document = {
        "d0": [Entity((0, 1), "t0"), Entity((2, 3), "t1")],
        "d1": [Entity((0, 6), "t0"), Entity((2, 3), "t1"), Entity((4, 5), "t2")],
        "d2": [Entity((0, 3), "t0"), Entity((3, 5), "t1")],
        "d3": [Entity((0, 3), "t0"), Entity((2, 5), "t1"), Entity((4, 7), "t2")],
        "d4": [Entity((0, 4), "t0"), Entity((3, 5), "t1")],
        "d5": [Entity((0, 4), "t0"), Entity((3, 9), "t1")],
        "d6": [Entity((0, 4), "t0"), Entity((2, 6), "t1")],
    }

    expected = {
        "d0": [Entity((0, 1), "t0"), Entity((2, 3), "t1")],
        "d1": [Entity((2, 3), "t1"), Entity((4, 5), "t2")],
        "d2": [Entity((0, 3), "t0"), Entity((3, 5), "t1")],
        "d3": [Entity((0, 3), "t0"), Entity((4, 7), "t2")],
        "d4": [Entity((0, 4), "t0")],
        "d5": [Entity((3, 9), "t1")],
        "d6": [Entity((0, 4), "t0")],
    }

    dataset = InternalBioNerDataset(documents={}, entities_per_document=entities_per_document)
    caplog.set_level(logging.WARNING)

    filter_nested_entities(dataset)

    assert "WARNING: Corpus modified by filtering nested entities." in caplog.text

    for doc_id, remaining in dataset.entities_per_document.items():
        assert doc_id in expected
        # Entities carry no __eq__ here, so compare their string forms as
        # order-independent multisets.
        assert sorted(str(e) for e in remaining) == sorted(str(e) for e in expected[doc_id])