Fix mypy and unit tests
alanakbik committed Feb 22, 2022
1 parent e644914 commit 1d0a21b
Showing 22 changed files with 225 additions and 239 deletions.
102 changes: 75 additions & 27 deletions flair/data.py
@@ -253,6 +253,10 @@ def __lt__(self, other):
def labeled_identifier(self):
return f"{self.data_point.unlabeled_identifier}/{self.value}"

+    @property
+    def unlabeled_identifier(self):
+        return f"{self.data_point.unlabeled_identifier}"
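For illustration (not part of the diff): with this addition, a Label exposes both identifiers. A minimal sketch, assuming this commit of flair is installed and using get_labels/add_label as defined elsewhere in this file:

    from flair.data import Sentence

    sentence = Sentence("Berlin is a city")
    sentence.add_label("topic", "geography")
    label = sentence.get_labels("topic")[0]
    print(label.labeled_identifier)    # the data point's identifier plus "/geography"
    print(label.unlabeled_identifier)  # the data point's identifier alone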


class DataPoint:
"""
@@ -426,12 +430,12 @@ class Token(_PartOfSentence):
"""

def __init__(
-        self,
-        text: str,
-        head_id: int = None,
-        whitespace_after: bool = True,
-        start_position: int = 0,
-        sentence=None,
+        self,
+        text: str,
+        head_id: int = None,
+        whitespace_after: bool = True,
+        start_position: int = 0,
+        sentence=None,
):
super().__init__(sentence=sentence)

@@ -508,7 +512,6 @@ class Span(_PartOfSentence):
"""
This class represents one textual span consisting of Tokens.
"""

def __init__(self, tokens: List[Token]):
super().__init__(tokens[0].sentence)
self.tokens = tokens
@@ -552,8 +555,12 @@ def __iter__(self):
def __len__(self) -> int:
return len(self.tokens)

+    @property
+    def embedding(self):
+        pass

class Relation(_PartOfSentence):

def __init__(self, first: Span, second: Span):
super().__init__(sentence=first.sentence)
self.first: Span = first
@@ -590,6 +597,10 @@ def start_position(self) -> int:
def end_position(self) -> int:
return max(self.first.end_position, self.second.end_position)

+    @property
+    def embedding(self):
+        pass


class Tokenizer(ABC):
r"""An abstract class representing a :class:`Tokenizer`.
@@ -616,11 +627,11 @@ class Sentence(DataPoint):
"""

def __init__(
-        self,
-        text: Union[str, List[str]],
-        use_tokenizer: Union[bool, Tokenizer] = True,
-        language_code: str = None,
-        start_position: int = 0,
+        self,
+        text: Union[str, List[str]],
+        use_tokenizer: Union[bool, Tokenizer] = True,
+        language_code: str = None,
+        start_position: int = 0,
):
"""
Class to hold all meta related to a text (tokens, predictions, language code, ...)
@@ -638,7 +649,7 @@ def __init__(
self.tokens: List[Token] = []

# private field for all known spans
-        self._known_spans = {}
+        self._known_spans: Dict[str, _PartOfSentence] = {}

self.language_code: Optional[str] = language_code

@@ -687,6 +698,10 @@ def __init__(
previous_word_offset = current_offset - 1
previous_token = token

+        # the last token has no whitespace after
+        if len(self) > 0:
+            self.tokens[-1].whitespace_after = False

# log a warning if the dataset is empty
if text == "":
log.warning("Warning: An empty Sentence was created! Are there empty strings in your dataset?")
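As an aside, if the new rule above behaves as its comment states, a check along these lines should pass. A sketch, assuming this commit is installed and default tokenization:

    from flair.data import Sentence

    sentence = Sentence("The grass is green .")
    # the final token no longer reports trailing whitespace
    assert sentence.tokens[-1].whitespace_after is False
    # earlier tokens keep their whitespace flags
    assert sentence.tokens[0].whitespace_after is True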
@@ -741,10 +756,11 @@ def add_token(self, token: Union[Token, str]):

# set token idx and sentence
token.sentence = self
-        token.idx = len(self.tokens)
+        token.idx = len(self.tokens) + 1
if token.start_position == 0 and len(self) > 0:
-            token.start_pos = len(self.to_original_text()) + 1 if self[-1].whitespace_after \
-                else len(self.to_original_text())
+            token.start_pos = (
+                len(self.to_original_text()) + 1 if self[-1].whitespace_after else len(self.to_original_text())
+            )
token.end_pos = token.start_pos + len(token.text)

# append token to sentence
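The idx change aligns manually added tokens with flair's one-based token numbering. A quick sketch of the expected effect, assuming this commit:

    from flair.data import Sentence

    sentence = Sentence("Berlin is nice")
    sentence.add_token("!")
    # token indices are one-based: the first token has idx 1, not 0
    print([(token.idx, token.text) for token in sentence])
    # expected: [(1, 'Berlin'), (2, 'is'), (3, 'nice'), (4, '!')]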
@@ -1071,6 +1087,10 @@ def get_labels(self, label_type: str = None):

def remove_labels(self, typename: str):

+        # labels also need to be deleted at all tokens
+        for token in self:
+            token.remove_labels(typename)

# labels also need to be deleted at all known spans
for span in self._known_spans.values():
span.remove_labels(typename)
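With the added loop, removing a label type now clears it from the sentence, its known spans, and every token. A sketch under the commit's API ("ner" and "B-PER" are illustrative values):

    from flair.data import Sentence

    sentence = Sentence("George Washington went to Washington")
    sentence[0].add_label("ner", "B-PER")  # label a single token
    sentence.remove_labels("ner")          # cascades to tokens and known spans
    assert len(sentence[0].get_labels("ner")) == 0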
@@ -1116,6 +1136,18 @@ def __len__(self):
def unlabeled_identifier(self):
return f"{self.first.unlabeled_identifier} || {self.second.unlabeled_identifier}"

+    @property
+    def start_position(self) -> int:
+        return self.first.start_position
+
+    @property
+    def end_position(self) -> int:
+        return self.first.end_position
+
+    @property
+    def text(self):
+        return self.first.text + " || " + self.second.text


TextPair = DataPair[Sentence, Sentence]
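Together with the new DataPair properties above, the TextPair alias can be used roughly like this. A sketch, assuming this commit's API and that Sentence exposes a text property:

    from flair.data import Sentence, TextPair

    pair = TextPair(Sentence("How old are you?"), Sentence("What is your age?"))
    print(pair.text)                  # both texts joined by " || "
    print(pair.unlabeled_identifier)  # both identifiers joined by " || "
    print(pair.start_position)        # delegates to the first element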

@@ -1138,6 +1170,22 @@ def __str__(self):

return f"Image: {image_repr} {image_url}"

+    @property
+    def start_position(self) -> int:
+        pass
+
+    @property
+    def end_position(self) -> int:
+        pass
+
+    @property
+    def text(self):
+        pass
+
+    @property
+    def unlabeled_identifier(self):
+        pass


class FlairDataset(Dataset):
@abstractmethod
@@ -1147,12 +1195,12 @@ def is_in_memory(self) -> bool:

class Corpus:
def __init__(
-        self,
-        train: Dataset = None,
-        dev: Dataset = None,
-        test: Dataset = None,
-        name: str = "corpus",
-        sample_missing_splits: Union[bool, str] = True,
+        self,
+        train: Dataset = None,
+        dev: Dataset = None,
+        test: Dataset = None,
+        name: str = "corpus",
+        sample_missing_splits: Union[bool, str] = True,
):
# set name
self.name: str = name
@@ -1191,11 +1239,11 @@ def test(self) -> Optional[Dataset]:
return self._test

def downsample(
-        self,
-        percentage: float = 0.1,
-        downsample_train=True,
-        downsample_dev=True,
-        downsample_test=True,
+        self,
+        percentage: float = 0.1,
+        downsample_train=True,
+        downsample_dev=True,
+        downsample_test=True,
):

if downsample_train and self._train is not None:
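Corpus.downsample is only reformatted here; typical usage is unchanged. A sketch using one of the corpora touched by this commit:

    from flair.datasets import TREC_6

    # keep 10% of train and dev, but leave the test split intact
    corpus = TREC_6().downsample(0.1, downsample_test=False)
    print(corpus)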
4 changes: 2 additions & 2 deletions flair/datasets/base.py
@@ -105,7 +105,7 @@ class StringDataset(FlairDataset):
def __init__(
self,
texts: Union[str, List[str]],
-        use_tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
+        use_tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
):
"""
Instantiate StringDataset
@@ -225,7 +225,7 @@ def _parse_document_to_sentence(
self,
text: str,
labels: List[str],
-        tokenizer: Union[Callable[[str], List[Token]], Tokenizer],
+        tokenizer: Union[bool, Tokenizer],
):
if self.max_chars_per_doc > 0:
text = text[: self.max_chars_per_doc]
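The tokenizer parameters in this file (and throughout document_classification.py below) are narrowed from accepting arbitrary callables to Union[bool, Tokenizer]. Under that typing, callers pass a Tokenizer instance or a bool; a sketch with a hypothetical data folder:

    from flair.datasets import ClassificationCorpus
    from flair.tokenization import SegtokTokenizer

    # "resources/my_corpus" is a placeholder path; the tokenizer argument
    # must now be a Tokenizer instance (or True/False), not a raw function
    corpus = ClassificationCorpus("resources/my_corpus", tokenizer=SegtokTokenizer())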
34 changes: 16 additions & 18 deletions flair/datasets/document_classification.py
@@ -3,7 +3,7 @@
import logging
import os
from pathlib import Path
-from typing import Callable, Dict, List, Union
+from typing import Dict, List, Union

import flair
from flair.data import (
@@ -37,7 +37,7 @@ def __init__(
truncate_to_max_tokens: int = -1,
truncate_to_max_chars: int = -1,
filter_if_longer_than: int = -1,
-        tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SegtokTokenizer(),
+        tokenizer: Union[bool, Tokenizer] = SegtokTokenizer(),
memory_mode: str = "partial",
label_name_map: Dict[str, str] = None,
skip_labels: List[str] = None,
@@ -140,7 +140,7 @@ def __init__(
truncate_to_max_tokens=-1,
truncate_to_max_chars=-1,
filter_if_longer_than: int = -1,
-        tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SegtokTokenizer(),
+        tokenizer: Union[bool, Tokenizer] = SegtokTokenizer(),
memory_mode: str = "partial",
label_name_map: Dict[str, str] = None,
skip_labels: List[str] = None,
@@ -253,9 +253,7 @@ def __init__(
position = f.tell()
line = f.readline()

-    def _parse_line_to_sentence(
-        self, line: str, label_prefix: str, tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer]
-    ):
+    def _parse_line_to_sentence(self, line: str, label_prefix: str, tokenizer: Union[bool, Tokenizer]):
words = line.split()

labels = []
@@ -1106,7 +1104,7 @@ class SENTEVAL_CR(ClassificationCorpus):

def __init__(
self,
-        tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
+        tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode: str = "full",
**corpusargs,
):
@@ -1160,7 +1158,7 @@ class SENTEVAL_MR(ClassificationCorpus):

def __init__(
self,
-        tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
+        tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode: str = "full",
**corpusargs,
):
@@ -1214,7 +1212,7 @@ class SENTEVAL_SUBJ(ClassificationCorpus):

def __init__(
self,
-        tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
+        tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode: str = "full",
**corpusargs,
):
@@ -1268,7 +1266,7 @@ class SENTEVAL_MPQA(ClassificationCorpus):

def __init__(
self,
-        tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
+        tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode: str = "full",
**corpusargs,
):
@@ -1322,7 +1320,7 @@ class SENTEVAL_SST_BINARY(ClassificationCorpus):

def __init__(
self,
-        tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
+        tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode: str = "full",
**corpusargs,
):
@@ -1382,7 +1380,7 @@ class SENTEVAL_SST_GRANULAR(ClassificationCorpus):

def __init__(
self,
-        tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
+        tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode: str = "full",
**corpusargs,
):
@@ -1535,7 +1533,7 @@ class GO_EMOTIONS(ClassificationCorpus):
def __init__(
self,
base_path: Union[str, Path] = None,
-        tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SegtokTokenizer(),
+        tokenizer: Union[bool, Tokenizer] = SegtokTokenizer(),
memory_mode: str = "partial",
**corpusargs,
):
@@ -1642,7 +1640,7 @@ class TREC_50(ClassificationCorpus):
def __init__(
self,
base_path: Union[str, Path] = None,
-        tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
+        tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode="full",
**corpusargs,
):
@@ -1704,7 +1702,7 @@ class TREC_6(ClassificationCorpus):
def __init__(
self,
base_path: Union[str, Path] = None,
-        tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
+        tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode="full",
**corpusargs,
):
@@ -1767,7 +1765,7 @@ class YAHOO_ANSWERS(ClassificationCorpus):
def __init__(
self,
base_path: Union[str, Path] = None,
-        tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
+        tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode="partial",
**corpusargs,
):
@@ -1846,7 +1844,7 @@ class GERMEVAL_2018_OFFENSIVE_LANGUAGE(ClassificationCorpus):
def __init__(
self,
base_path: Union[str, Path] = None,
-        tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SegtokTokenizer(),
+        tokenizer: Union[bool, Tokenizer] = SegtokTokenizer(),
memory_mode: str = "full",
fine_grained_classes: bool = False,
**corpusargs,
@@ -1919,7 +1917,7 @@ def __init__(
self,
base_path: Union[str, Path] = None,
memory_mode: str = "full",
-        tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
+        tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
**corpusargs,
):
"""
6 changes: 3 additions & 3 deletions flair/datasets/text_text.py
@@ -4,7 +4,7 @@
from typing import List, Optional, Union

import flair
-from flair.data import Corpus, DataPair, FlairDataset, Sentence, _iter_dataset
+from flair.data import Corpus, DataPair, FlairDataset, Sentence, TextPair, _iter_dataset
from flair.datasets.base import find_train_dev_test_files
from flair.file_utils import cached_path, unpack_file, unzip_file

@@ -180,7 +180,7 @@ def _make_bi_sentence(self, source_line: str, target_line: str):
source_sentence.tokens = source_sentence.tokens[: self.max_tokens_per_doc]
target_sentence.tokens = target_sentence.tokens[: self.max_tokens_per_doc]

-        return DataPair(source_sentence, target_sentence)
+        return TextPair(source_sentence, target_sentence)

def __len__(self):
return self.total_sentence_count
@@ -416,7 +416,7 @@ def _make_data_pair(self, first_element: str, second_element: str, label: str =
first_sentence.tokens = first_sentence.tokens[: self.max_tokens_per_doc]
second_sentence.tokens = second_sentence.tokens[: self.max_tokens_per_doc]

-        data_pair = DataPair(first_sentence, second_sentence)
+        data_pair = TextPair(first_sentence, second_sentence)

if label:
data_pair.add_label(typename=self.label_type, value=label)
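Both call sites above now build TextPair instead of the bare DataPair; labeling is unchanged. A sketch mirroring _make_data_pair, with illustrative label names:

    from flair.data import Sentence, TextPair

    pair = TextPair(Sentence("A man is playing a guitar."), Sentence("Someone plays music."))
    pair.add_label(typename="entailment", value="ENTAILMENT")  # illustrative values
    print(pair.get_labels("entailment"))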