Skip to content

Commit

Permalink
Rename AugmentedSentence to PrefixedSentence
Browse files Browse the repository at this point in the history
  • Loading branch information
alanakbik committed Feb 9, 2024
1 parent d501f58 commit 037862e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
24 changes: 12 additions & 12 deletions flair/models/prefixed_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from flair.models import SequenceTagger


class AugmentedSentence(Sentence):
"""An AugmentedSentence expresses that a sentence is augmented and compatible with the AugmentedSentenceSequenceTagger.
class PrefixedSentence(Sentence):
"""An AugmentedSentence expresses that a sentence is augmented and compatible with the PrefixedSequenceTagger.
For inference, i.e. `predict` and `evaluate`, the AugmentedSentenceSequenceTagger internally encodes the sentences.
Therefore, these functions work with the regular flair sentence objects.
Expand All @@ -26,7 +26,7 @@ class SentenceAugmentationStrategy(ABC):
@abstractmethod
def augment_sentence(
self, sentence: Sentence, annotation_layers: Optional[Union[str, List[str]]] = None
) -> AugmentedSentence:
) -> PrefixedSentence:
"""Augments the given sentence text with additional instructions for working / predicting the task on the given annotations.
Args:
Expand Down Expand Up @@ -64,7 +64,7 @@ def _init_strategy_with_state_dict(cls, state, **kwargs):

def augment_dataset(
self, dataset: Dataset[Sentence], annotation_layers: Optional[Union[str, List[str]]] = None
) -> FlairDatapointDataset[AugmentedSentence]:
) -> FlairDatapointDataset[PrefixedSentence]:
"""Transforms a dataset into a dataset containing augmented sentences specific to the `AugmentedSentenceSequenceTagger`.
The returned dataset is stored in memory. For more information on the internal sentence transformation
Expand All @@ -85,7 +85,7 @@ def augment_dataset(

def augment_corpus(
self, corpus: Corpus[Sentence], annotation_layers: Optional[Union[str, List[str]]] = None
) -> Corpus[AugmentedSentence]:
) -> Corpus[PrefixedSentence]:
"""Transforms a corpus into a corpus containing augmented sentences specific to the `AugmentedSentenceSequenceTagger`.
The splits of the returned corpus are stored in memory. For more information on the internal
Expand Down Expand Up @@ -128,9 +128,9 @@ def __init__(self, entity_types: List[str]):

def augment_sentence(
self, sentence: Sentence, annotation_layers: Optional[Union[str, List[str]]] = None
) -> AugmentedSentence:
) -> PrefixedSentence:
# Prepend the task description prompt to the sentence text
augmented_sentence = AugmentedSentence(
augmented_sentence = PrefixedSentence(
text=self.task_prompt + [t.text for t in sentence.tokens],
use_tokenizer=False,
language_code=sentence.language_code,
Expand Down Expand Up @@ -223,14 +223,14 @@ def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "PrefixedSequence

return cast("AugmentedSentenceSequenceTagger", super().load(model_path=model_path))

def forward_loss(self, sentences: Union[List[Sentence], List[AugmentedSentence]]) -> Tuple[torch.Tensor, int]:
def forward_loss(self, sentences: Union[List[Sentence], List[PrefixedSentence]]) -> Tuple[torch.Tensor, int]:
# If all sentences are not augmented -> augment them
if all(isinstance(sentence, Sentence) for sentence in sentences):
# mypy does not infer the type of "sentences" restricted by the if statement
sentences = cast(List[Sentence], sentences)

sentences = self.augment_sentences(sentences=sentences, annotation_layers=self.tag_type)
elif not all(isinstance(sentence, AugmentedSentence) for sentence in sentences):
elif not all(isinstance(sentence, PrefixedSentence) for sentence in sentences):
raise ValueError("All passed sentences must be either uniformly augmented or not.")

# mypy does not infer the type of "sentences" restricted by code above
Expand All @@ -240,7 +240,7 @@ def forward_loss(self, sentences: Union[List[Sentence], List[AugmentedSentence]]

def predict(
self,
sentences: Union[List[Sentence], Sentence, List[AugmentedSentence], AugmentedSentence],
sentences: Union[List[Sentence], Sentence, List[PrefixedSentence], PrefixedSentence],
mini_batch_size: int = 32,
return_probabilities_for_all_classes: bool = False,
verbose: bool = False,
Expand All @@ -257,7 +257,7 @@ def predict(
sentences = [sentences]

# If all sentences are already augmented (i.e. compatible with this class), just forward the sentences
if all(isinstance(sentence, AugmentedSentence) for sentence in sentences):
if all(isinstance(sentence, PrefixedSentence) for sentence in sentences):
# mypy does not infer the type of "sentences" restricted by the if statement
sentences = cast(List[Sentence], sentences)

Expand Down Expand Up @@ -312,7 +312,7 @@ def predict(

def augment_sentences(
self, sentences: Union[Sentence, List[Sentence]], annotation_layers: Optional[Union[str, List[str]]] = None
) -> List[AugmentedSentence]:
) -> List[PrefixedSentence]:
if not isinstance(sentences, list) and not isinstance(sentences, flair.data.Dataset):
sentences = [sentences]

Expand Down
10 changes: 5 additions & 5 deletions tests/test_augmentation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from flair.data import Sentence
from flair.models.prefixed_tagger import AugmentedSentence, EntityTypeTaskPromptAugmentationStrategy
from flair.models.prefixed_tagger import EntityTypeTaskPromptAugmentationStrategy, PrefixedSentence


def test_entity_type_task_prompt_augmentation_single_type():
Expand All @@ -8,7 +8,7 @@ def test_entity_type_task_prompt_augmentation_single_type():
sent = Sentence("This is a test sentence.")
aug_sent = strategy.augment_sentence(sent)

assert isinstance(aug_sent, AugmentedSentence)
assert isinstance(aug_sent, PrefixedSentence)
assert aug_sent.text.startswith("[ Tag genes ] ")
assert len(aug_sent) == 10

Expand All @@ -19,7 +19,7 @@ def test_entity_type_task_prompt_augmentation_two_types():
sent = Sentence("This is a test sentence.")
aug_sent = strategy.augment_sentence(sent)

assert isinstance(aug_sent, AugmentedSentence)
assert isinstance(aug_sent, PrefixedSentence)
assert aug_sent.text.startswith("[ Tag genes and diseases ] ")
assert len(aug_sent) == 12

Expand All @@ -30,7 +30,7 @@ def test_entity_type_task_prompt_augmentation_multiple_types():
sent = Sentence("This is a test sentence.")
aug_sent = strategy.augment_sentence(sent)

assert isinstance(aug_sent, AugmentedSentence)
assert isinstance(aug_sent, PrefixedSentence)
assert aug_sent.text.startswith("[ Tag genes, diseases and chemicals ] ")
assert len(aug_sent) == 13

Expand All @@ -44,7 +44,7 @@ def test_entity_type_task_prompt_augmentation_label_transfer():

aug_sent = strategy.augment_sentence(sent, "ner")

assert isinstance(aug_sent, AugmentedSentence)
assert isinstance(aug_sent, PrefixedSentence)
assert aug_sent.text.startswith("[ Tag genes ] ")
assert len(aug_sent.get_labels("foo")) == 0

Expand Down

0 comments on commit 037862e

Please sign in to comment.