Commit 5a4204c

GH-3496: Add OneClassClassifier model

1 parent c6a2643 commit 5a4204c

File tree

2 files changed: +270 −0 lines changed
Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import numpy as np
import torch
from torch.utils.data import Dataset

import flair
from flair.data import Dictionary, Sentence, _iter_dataset
from flair.embeddings import DocumentEmbeddings
from flair.training_utils import store_embeddings


class OneClassClassifier(flair.nn.Classifier[Sentence]):
    """One Class Classification Model for tasks such as Anomaly Detection.

    Task
    ----
    One Class Classification (OCC) tries to identify objects of a specific class amongst all objects, in contrast to
    distinguishing between two or more classes.

    Example
    -------
    The model expects to be trained on a dataset in which every element has the same label_value, e.g. movie reviews
    with the label POSITIVE.
    During inference, one of two label_values will be added:
    - In-class (e.g. another movie review) -> label_value="POSITIVE"
    - Anything else (e.g. a wiki page) -> label_value="<unk>"

    Architecture
    ------------
    Reconstruction with an autoencoder. The score is the reconstruction error from compressing and decompressing the
    document embedding. A LOWER score indicates a HIGHER probability of being in-class. The threshold is
    calculated as a high percentile of the score distribution of in-class elements from the dev set.

    You must set the threshold after training by running `model.threshold = model.calculate_threshold(corpus.dev)`.
    """

    def __init__(
        self,
        embeddings: DocumentEmbeddings,
        label_dictionary: Dictionary,
        label_type: str,
        encoding_dim: int = 128,
        threshold: Optional[float] = None,
    ) -> None:
        """Initializes a OneClassClassifier.

        Args:
            embeddings: Embeddings to use during training and prediction
            label_dictionary: The label to predict. Must contain exactly one class.
            label_type: name of the annotation layer to be predicted, in case a corpus has multiple annotation layers
            encoding_dim: The size of the compressed embedding
            threshold: The score that separates in-class from out-of-class
        """
        super().__init__()
        self.embeddings = embeddings
        if len(label_dictionary) != 1:
            raise ValueError(f"label_dictionary must have exactly 1 element: {label_dictionary}")
        self.label_dictionary = label_dictionary
        self.label_value = label_dictionary.get_items()[0]
        self._label_type = label_type
        self.encoding_dim = encoding_dim
        self.threshold = threshold

        embedding_dim = embeddings.embedding_length
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, encoding_dim * 4),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(encoding_dim * 4, encoding_dim * 2),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(encoding_dim * 2, encoding_dim),
            torch.nn.LeakyReLU(inplace=True),
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(encoding_dim, encoding_dim * 2),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(encoding_dim * 2, encoding_dim * 4),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(encoding_dim * 4, embedding_dim),
            torch.nn.LeakyReLU(inplace=True),
        )

        self.cosine_sim = torch.nn.CosineSimilarity(dim=1)
        self.to(flair.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]:
        """Returns Tuple[scalar tensor, num examples]."""
        if len(sentences) == 0:
            return torch.tensor(0.0, dtype=torch.float, device=flair.device, requires_grad=True), 0
        sentence_tensor = self._sentences_to_tensor(sentences)
        reconstructed_sentence_tensor = self.forward(sentence_tensor)
        return self._loss(reconstructed_sentence_tensor, sentence_tensor).sum(), len(sentences)

    def predict(
        self,
        sentences: Union[List[Sentence], Sentence],
        mini_batch_size: int = 32,
        return_probabilities_for_all_classes: bool = False,
        verbose: bool = False,
        label_name: Optional[str] = None,
        return_loss=False,
        embedding_storage_mode="none",
    ) -> Optional[torch.Tensor]:
        """Predicts the class labels for the given sentences. The labels are directly added to the sentences.

        Args:
            sentences: list of sentences to predict
            mini_batch_size: the number of sentences predicted within one batch (unimplemented)
            return_probabilities_for_all_classes: return probabilities for all classes instead of only the best prediction (unimplemented)
            verbose: set to True to display a progress bar (unimplemented)
            return_loss: set to True to return loss
            label_name: set this to change the name of the label type that is predicted
            embedding_storage_mode: default is 'none', which is best in most cases.
                Only set to 'cpu' or 'gpu' if you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory, respectively.

        Returns: None. If return_loss is set, returns a scalar tensor.
        """
        if label_name is None:
            label_name = self.label_type

        with torch.no_grad():
            # make sure it's a list
            if not isinstance(sentences, list):
                sentences = [sentences]

            Sentence.set_context_for_sentences(cast(List[Sentence], sentences))

            # filter empty sentences
            sentences = [sentence for sentence in sentences if len(sentence) > 0]
            if len(sentences) == 0:
                return torch.tensor(0.0, requires_grad=True, device=flair.device) if return_loss else None

            sentence_tensor = self._sentences_to_tensor(sentences)
            reconstructed = self.forward(sentence_tensor)
            loss_tensor = self._loss(reconstructed, sentence_tensor)

            for sentence, loss in zip(sentences, loss_tensor.tolist()):
                sentence.remove_labels(label_name)
                label_value = self.label_value if self.threshold is not None and loss < self.threshold else "<unk>"
                sentence.add_label(typename=label_name, value=label_value, score=loss)

            store_embeddings(sentences, storage_mode=embedding_storage_mode)

            return loss_tensor.sum() if return_loss else None

    @property
    def label_type(self) -> str:
        return self._label_type

    def _sentences_to_tensor(self, sentences: List[Sentence]) -> torch.Tensor:
        self.embeddings.embed(sentences)
        return torch.stack([sentence.embedding for sentence in sentences])

    def _loss(self, predicted: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the cosine-distance loss (1 - cosine similarity) per example.

        Args:
            predicted: tensor of shape (batch_size, embedding_size)
            labels: tensor of shape (batch_size, embedding_size)

        Returns:
            tensor of shape (batch_size)
        """
        if labels.size(0) == 0:
            return torch.tensor(0.0, requires_grad=True, device=flair.device)

        return 1 - self.cosine_sim(predicted, labels)

    def _get_state_dict(self):
        """Returns the state dictionary for this model."""
        model_state = {
            **super()._get_state_dict(),
            "embeddings": self.embeddings.save_embeddings(use_state_dict=False),
            "label_dictionary": self.label_dictionary,
            "label_type": self.label_type,
            "encoding_dim": self.encoding_dim,
            "threshold": self.threshold,
        }

        return model_state

    @classmethod
    def _init_model_with_state_dict(cls, state, **kwargs):
        return super()._init_model_with_state_dict(
            state,
            embeddings=state.get("embeddings"),
            label_dictionary=state.get("label_dictionary"),
            label_type=state.get("label_type"),
            encoding_dim=state.get("encoding_dim"),
            threshold=state.get("threshold"),
            **kwargs,
        )

    @classmethod
    def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "OneClassClassifier":
        return cast("OneClassClassifier", super().load(model_path=model_path))

    def calculate_threshold(self, dataset: Dataset[Sentence], quantile=0.995) -> float:
        """Determine the score threshold below which a Sentence is considered in-class.

        This implementation returns the score at which `quantile` of `dataset` will be considered in-class.
        Intended for use cases that favor high recall.
        """

        def score(sentence: Sentence) -> float:
            with torch.no_grad():
                sentence_tensor = self._sentences_to_tensor([sentence])
                reconstructed = self.forward(sentence_tensor)
                loss_tensor = self._loss(reconstructed, sentence_tensor)
            return loss_tensor.tolist()[0]

        scores = [
            score(sentence)
            for sentence in _iter_dataset(dataset)
            if sentence.get_labels(self.label_type)[0].value == self.label_value
        ]
        return float(np.quantile(scores, quantile))
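
For orientation, here is a minimal end-to-end usage sketch for the new model (an illustration only, not part of the commit). It assumes a `ClassificationCorpus` whose train split carries a single label value under label type "topic", and the paths `data/imdb_positive_only` and `resources/occ` are hypothetical:

    from flair.data import Sentence
    from flair.datasets import ClassificationCorpus
    from flair.embeddings import TransformerDocumentEmbeddings
    from flair.models.one_class_classification_model import OneClassClassifier
    from flair.trainers import ModelTrainer

    # Assumption: every training sentence carries the same label value ("POSITIVE"),
    # as the class docstring requires.
    corpus = ClassificationCorpus("data/imdb_positive_only", label_type="topic")
    label_dict = corpus.make_label_dictionary(label_type="topic")  # exactly one element

    model = OneClassClassifier(
        embeddings=TransformerDocumentEmbeddings("distilbert-base-uncased", fine_tune=True),
        label_dictionary=label_dict,
        label_type="topic",
    )

    trainer = ModelTrainer(model, corpus)
    trainer.train("resources/occ", max_epochs=2)

    # Per the class docstring, the threshold must be set after training:
    model.threshold = model.calculate_threshold(corpus.dev)

    sentence = Sentence("A heartfelt movie with a terrific cast.")
    model.predict(sentence)
    print(sentence.get_labels("topic"))  # "POSITIVE" if score < threshold, else "<unk>"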
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import pytest

import flair
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models.one_class_classification_model import OneClassClassifier
from flair.trainers import ModelTrainer
from tests.model_test_utils import BaseModelTest


class TestOneClassClassifier(BaseModelTest):
    model_cls = OneClassClassifier
    train_label_type = "topic"
    training_args = {
        "max_epochs": 2,
    }

    @pytest.fixture()
    def corpus(self, tasks_base_path):
        label_type = "topic"
        corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb", label_type=label_type)
        corpus._train = [x for x in corpus.train if x.get_labels(label_type)[0].value == "POSITIVE"]
        return corpus

    @pytest.fixture()
    def embeddings(self):
        return TransformerDocumentEmbeddings(model="distilbert-base-uncased", layers="-1", fine_tune=True)

    @pytest.mark.integration()
    def test_train_load_use_one_class_classifier(self, results_base_path, corpus, example_sentence, embeddings):
        label_dict = corpus.make_label_dictionary(label_type=self.train_label_type)

        model = self.model_cls(embeddings=embeddings, label_dictionary=label_dict, label_type=self.train_label_type)
        trainer = ModelTrainer(model, corpus)

        trainer.train(results_base_path, shuffle=False, **self.training_args)

        del trainer, model, label_dict, corpus
        loaded_model = self.model_cls.load(results_base_path / "final-model.pt")

        loaded_model.predict(example_sentence)
        loaded_model.predict([example_sentence, self.empty_sentence])
        loaded_model.predict([self.empty_sentence])

        assert example_sentence.get_labels(self.train_label_type)[0].value in {"POSITIVE", "<unk>"}
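
As a side note on the thresholding rule: `calculate_threshold` picks the `quantile`-th score of the in-class dev examples, so with the default `quantile=0.995` roughly 99.5% of in-class elements score below the threshold and would be labeled in-class. A quick illustration with made-up scores (not part of the commit):

    import numpy as np

    # Hypothetical reconstruction scores (1 - cosine similarity) for in-class dev sentences.
    rng = np.random.default_rng(0)
    in_class_scores = rng.normal(loc=0.05, scale=0.02, size=1000).clip(min=0.0)

    threshold = float(np.quantile(in_class_scores, 0.995))
    print(f"threshold = {threshold:.4f}")

    # predict() labels a sentence in-class only when its score is below the threshold.
    print("fraction labeled in-class:", (in_class_scores < threshold).mean())  # ~0.995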
