
Commit

Different initialization of Span, adapted unit tests
alanakbik committed Mar 22, 2023
1 parent d16a47b commit b5d221c
Showing 2 changed files with 60 additions and 46 deletions.
102 changes: 58 additions & 44 deletions flair/data.py
@@ -433,12 +433,6 @@ def __init__(self, sentence):
         super().__init__()
         self.sentence: Sentence = sentence

-    def _init_labels(self):
-        if self.unlabeled_identifier in self.sentence._known_spans:
-            self.annotation_layers = self.sentence._known_spans[self.unlabeled_identifier].annotation_layers
-            del self.sentence._known_spans[self.unlabeled_identifier]
-        self.sentence._known_spans[self.unlabeled_identifier] = self
-
     def add_label(self, typename: str, value: str, score: float = 1.0):
         super().add_label(typename, value, score)
         self.sentence.annotation_layers.setdefault(typename, []).append(Label(self, value, score))
@@ -469,12 +463,12 @@ class Token(_PartOfSentence):
     """

     def __init__(
-        self,
-        text: str,
-        head_id: int = None,
-        whitespace_after: int = 1,
-        start_position: int = 0,
-        sentence=None,
+        self,
+        text: str,
+        head_id: int = None,
+        whitespace_after: int = 1,
+        start_position: int = 0,
+        sentence=None,
     ):
         super().__init__(sentence=sentence)

@@ -562,10 +556,26 @@ class Span(_PartOfSentence):
     This class represents one textual span consisting of Tokens.
     """

+    def __new__(self, tokens: List[Token]):
+
+        # check if the span already exists. If so, return it
+        unlabeled_identifier = self._make_unlabeled_identifier(tokens)
+        if unlabeled_identifier in tokens[0].sentence._known_spans:
+            span = tokens[0].sentence._known_spans[unlabeled_identifier]
+            return span
+
+        # else make a new span
+        else:
+            span = super(Span, self).__new__(self)
+            span.initialized = False
+            tokens[0].sentence._known_spans[unlabeled_identifier] = span
+            return span
+
     def __init__(self, tokens: List[Token]):
-        super().__init__(tokens[0].sentence)
-        self.tokens = tokens
-        super()._init_labels()
+        if not self.initialized:
+            super().__init__(tokens[0].sentence)
+            self.tokens = tokens
+            self.initialized = True

     @property
     def start_position(self) -> int:
@@ -579,9 +589,14 @@ def end_position(self) -> int:
     def text(self) -> str:
         return "".join([t.text + t.whitespace_after * " " for t in self.tokens]).strip()

+    @staticmethod
+    def _make_unlabeled_identifier(tokens: List[Token]):
+        text = "".join([t.text + t.whitespace_after * " " for t in tokens]).strip()
+        return f'Span[{tokens[0].idx - 1}:{tokens[-1].idx}]: "{text}"'
+
     @property
     def unlabeled_identifier(self) -> str:
-        return f'Span[{self.tokens[0].idx - 1}:{self.tokens[-1].idx}]: "{self.text}"'
+        return self._make_unlabeled_identifier(self.tokens)

     def __repr__(self):
         return self.__str__()
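
The net effect of the two hunks above: Span construction is now memoized per sentence. __new__ computes the unlabeled identifier for the requested tokens and returns the span already cached in the sentence's _known_spans if one exists; only otherwise does it allocate a fresh object, which __init__ then fills in exactly once, guarded by the initialized flag. A minimal sketch of the resulting behavior (illustrative only, assuming a flair build that includes this commit):

from flair.data import Sentence, Span

sentence = Sentence("George Washington went to Washington.")
span_a = Span(sentence.tokens[0:2])  # first call: creates and caches the span
span_b = Span(sentence.tokens[0:2])  # same tokens: __new__ returns the cached object
assert span_a is span_b
assert span_a.unlabeled_identifier == 'Span[0:2]: "George Washington"'

This also removes the need for the _init_labels reconciliation deleted in the first hunk: labels attached through either reference land on the same object from the start.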
@@ -646,11 +661,11 @@ class Sentence(DataPoint):
     """

     def __init__(
-        self,
-        text: Union[str, List[str], List[Token]],
-        use_tokenizer: Union[bool, Tokenizer] = True,
-        language_code: str = None,
-        start_position: int = 0,
+        self,
+        text: Union[str, List[str], List[Token]],
+        use_tokenizer: Union[bool, Tokenizer] = True,
+        language_code: str = None,
+        start_position: int = 0,
     ):
         """
         Class to hold all meta related to a text (tokens, predictions, language code, ...)
@@ -1038,10 +1053,10 @@ def is_context_set(self) -> bool:
         :return: True if context is set, else False
         """
         return (
-            self._has_context
-            or self._previous_sentence is not None
-            or self._next_sentence is not None
-            or self._position_in_dataset is not None
+            self._has_context
+            or self._previous_sentence is not None
+            or self._next_sentence is not None
+            or self._position_in_dataset is not None
         )

     def copy_context_from_sentence(self, sentence: "Sentence") -> None:
@@ -1168,12 +1183,12 @@ def unlabeled_identifier(self):

 class Corpus(typing.Generic[T_co]):
     def __init__(
-        self,
-        train: Optional[Dataset[T_co]] = None,
-        dev: Optional[Dataset[T_co]] = None,
-        test: Optional[Dataset[T_co]] = None,
-        name: str = "corpus",
-        sample_missing_splits: Union[bool, str] = True,
+        self,
+        train: Optional[Dataset[T_co]] = None,
+        dev: Optional[Dataset[T_co]] = None,
+        test: Optional[Dataset[T_co]] = None,
+        name: str = "corpus",
+        sample_missing_splits: Union[bool, str] = True,
     ):
         # set name
         self.name: str = name
@@ -1212,11 +1227,11 @@ def test(self) -> Optional[Dataset[T_co]]:
         return self._test

     def downsample(
-        self,
-        percentage: float = 0.1,
-        downsample_train=True,
-        downsample_dev=True,
-        downsample_test=True,
+        self,
+        percentage: float = 0.1,
+        downsample_train=True,
+        downsample_dev=True,
+        downsample_test=True,
     ):
         if downsample_train and self._train is not None:
             self._train = self._downsample_to_proportion(self._train, percentage)
@@ -1449,7 +1464,7 @@ def make_label_dictionary(self, label_type: str, min_count: int = -1, add_unk: b
                 unked_count += count

         if len(label_dictionary.idx2item) == 0 or (
-            len(label_dictionary.idx2item) == 1 and "<unk>" in label_dictionary.get_items()
+            len(label_dictionary.idx2item) == 1 and "<unk>" in label_dictionary.get_items()
         ):
             log.error(f"ERROR: You specified label_type='{label_type}' which is not in this dataset!")
             contained_labels = ", ".join(
@@ -1520,7 +1535,7 @@ def _corrupt_labels(self, noise_share: float, label_type: str, labels: List[str]
                 corrupted_count += 1

         log.info(
-            f"Total labels corrupted: {corrupted_count}. Resulting noise share: {round((corrupted_count/total_label_count)*100, 2)}%."
+            f"Total labels corrupted: {corrupted_count}. Resulting noise share: {round((corrupted_count / total_label_count) * 100, 2)}%."
         )

     def get_label_distribution(self):
@@ -1555,11 +1570,11 @@ def make_tag_dictionary(self, tag_type: str) -> Dictionary:

 class MultiCorpus(Corpus):
     def __init__(
-        self,
-        corpora: List[Corpus],
-        task_ids: Optional[List[str]] = None,
-        name: str = "multicorpus",
-        **corpusargs,
+        self,
+        corpora: List[Corpus],
+        task_ids: Optional[List[str]] = None,
+        name: str = "multicorpus",
+        **corpusargs,
     ):
         self.corpora: List[Corpus] = corpora

@@ -1568,7 +1583,6 @@ def __init__(
         train_parts = []
         dev_parts = []
         test_parts = []
-        print(self.corpora)
         for corpus in self.corpora:
             if corpus.train:
                 train_parts.append(corpus.train)
4 changes: 2 additions & 2 deletions tests/test_sentence.py
@@ -14,8 +14,8 @@ def test_equality():
     assert Sentence("Guten Tag!") != Sentence("Good day!")
     assert Sentence("Guten Tag!", use_tokenizer=True) != Sentence("Guten Tag!", use_tokenizer=False)

-    # TODO: is this desirable? Or should two sentences with same text still be considered different objects?
-    assert Sentence("Guten Tag!") == Sentence("Guten Tag!")
+    # TODO: is this desirable? Or should two sentences with same text be considered same objects?
+    assert Sentence("Guten Tag!") != Sentence("Guten Tag!")


 def test_token_labeling():
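
The adapted test pins down the equality semantics that fall out of this commit: two independently constructed Sentence objects stay distinct even for identical text, whereas spans over the same tokens within one sentence are deduplicated. An illustrative sketch (assuming a flair build that includes this commit):

from flair.data import Sentence, Span

s1, s2 = Sentence("Guten Tag!"), Sentence("Guten Tag!")
assert s1 != s2  # same text, but two distinct Sentence objects

first_two = s1.tokens[0:2]
assert Span(first_two) is Span(first_two)  # spans are cached per sentence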
