diff --git a/flair/models/__init__.py b/flair/models/__init__.py index 8357cc47ea..e75daf074b 100644 --- a/flair/models/__init__.py +++ b/flair/models/__init__.py @@ -1,4 +1,3 @@ -from .clustering import ClusteringModel from .entity_linker_model import SpanClassifier from .entity_mention_linking import EntityMentionLinker from .language_model import LanguageModel @@ -37,6 +36,5 @@ "TARSTagger", "TextClassifier", "TextRegressor", - "ClusteringModel", "MultitaskModel", ] diff --git a/flair/models/clustering.py b/flair/models/clustering.py deleted file mode 100644 index e9902f6f67..0000000000 --- a/flair/models/clustering.py +++ /dev/null @@ -1,120 +0,0 @@ -import logging -import pickle -from collections import OrderedDict -from pathlib import Path -from typing import Optional, Union - -import joblib -from sklearn.base import BaseEstimator, ClusterMixin -from sklearn.metrics import normalized_mutual_info_score -from tqdm import tqdm - -from flair.data import Corpus, _iter_dataset -from flair.datasets import DataLoader -from flair.embeddings import DocumentEmbeddings - -log = logging.getLogger("flair") - - -class ClusteringModel: - """A wrapper class to apply sklearn clustering models on DocumentEmbeddings.""" - - def __init__(self, model: Union[ClusterMixin, BaseEstimator], embeddings: DocumentEmbeddings) -> None: - """Instantiate the ClusteringModel. - - Args: - model: the clustering algorithm from sklearn this wrapper will use. - embeddings: the flair DocumentEmbedding this wrapper uses to calculate a vector for each sentence. - """ - self.model = model - self.embeddings = embeddings - - def fit(self, corpus: Corpus, **kwargs): - """Trains the model. - - Args: - corpus: the flair corpus this wrapper will use for fitting the model. - **kwargs: parameters propagated to the models `.fit()` method. - """ - X = self._convert_dataset(corpus) - - log.info("Start clustering " + str(self.model) + " with " + str(len(X)) + " Datapoints.") - self.model.fit(X, **kwargs) - log.info("Finished clustering.") - - def predict(self, corpus: Corpus): - """Predict labels given a list of sentences and returns the respective class indices. - - Args: - corpus: the flair corpus this wrapper will use for predicting the labels. - """ - X = self._convert_dataset(corpus) - log.info("Start the prediction " + str(self.model) + " with " + str(len(X)) + " Datapoints.") - predict = self.model.predict(X) - - for idx, sentence in enumerate(_iter_dataset(corpus.get_all_sentences())): - sentence.set_label("cluster", str(predict[idx])) - - log.info("Finished prediction and labeled all sentences.") - return predict - - def save(self, model_file: Union[str, Path]): - """Saves current model. - - Args: - model_file: path where to save the model. - """ - joblib.dump(pickle.dumps(self), str(model_file)) - - log.info("Saved the model to: " + str(model_file)) - - @staticmethod - def load(model_file: Union[str, Path]): - """Loads a model from a given path. - - Args: - model_file: path to the file where the model is saved. - """ - log.info("Loading model from: " + str(model_file)) - return pickle.loads(joblib.load(str(model_file))) - - def _convert_dataset( - self, corpus, label_type: Optional[str] = None, batch_size: int = 32, return_label_dict: bool = False - ): - """Makes a flair-corpus sklearn compatible. - - Turns the corpora into X, y datasets as required for most sklearn clustering models. - Ref.: https://scikit-learn.org/stable/modules/classes.html#module-sklearn.cluster - """ - log.info("Embed sentences...") - sentences = [] - for batch in tqdm(DataLoader(corpus.get_all_sentences(), batch_size=batch_size)): - self.embeddings.embed(batch) - sentences.extend(batch) - - X = [sentence.embedding.cpu().detach().numpy() for sentence in sentences] - - if label_type is None: - return X - - labels = [sentence.get_labels(label_type)[0].value for sentence in sentences] - label_dict = {v: k for k, v in enumerate(OrderedDict.fromkeys(labels))} - y = [label_dict.get(label) for label in labels] - - if return_label_dict: - return X, y, label_dict - - return X, y - - def evaluate(self, corpus: Corpus, label_type: str): - """This method calculates some evaluation metrics for the clustering. - - Also, the result of the evaluation is logged. - - Args: - corpus: the flair corpus this wrapper will use for evaluation. - label_type: the label from the sentence will be used for the evaluation. - """ - X, Y = self._convert_dataset(corpus, label_type=label_type) - predict = self.model.predict(X) - log.info("NMI - Score: " + str(normalized_mutual_info_score(predict, Y))) diff --git a/resources/docs/TUTORIAL_12_CLUSTERING.md b/resources/docs/TUTORIAL_12_CLUSTERING.md deleted file mode 100644 index 376e5d5639..0000000000 --- a/resources/docs/TUTORIAL_12_CLUSTERING.md +++ /dev/null @@ -1,180 +0,0 @@ -Text Clustering in flair ----------- - -In this package text clustering is implemented. This module has the following -clustering algorithms implemented: -- k-Means -- BIRCH -- Expectation Maximization - -Each of the implemented algorithm needs to have an instanced DocumentEmbedding. This embedding will -transform each text/document to a vector. With these vectors the clustering algorithm can be performed. - ---------------------------- - -k-Means ------- -k-Means is a classical and well known clustering algorithm. k-Means is a partitioning-based Clustering algorithm. -The user defines with the parameter *k* how many clusters the given data has. -So the choice of *k* is very important. -More about k-Means can be read on the official [sklearn](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html). - - -```python -from flair.models import ClusteringModel -from flair.datasets import TREC_6 -from flair.embeddings import SentenceTransformerDocumentEmbeddings -from sklearn.cluster import KMeans - -embeddings = SentenceTransformerDocumentEmbeddings() - -# store all embeddings in memory which is required to perform clustering -corpus = TREC_6(memory_mode='full').downsample(0.05) - -model = KMeans(n_clusters=6) - -clustering_model = ClusteringModel( - model=model, - embeddings=embeddings -) - -# fit the model on a corpus -clustering_model.fit(corpus) - -# evaluate the model on a corpus with the given label -clustering_model.evaluate(corpus, label_type="question_class") -``` - -BIRCH ---------- -BIRCH (Balanced Iterative Reducing and Clustering using Hierarchies) is a hierarchical clustering algorithm. -BIRCH is specialized to handle large amounts of data. BIRCH scans the data a single time and builds an internal data -structure. This data structure contains the data but in a compressed way. -More about BIRCH can be read on the official [sklearn](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html). - -```python -from sklearn.cluster import Birch -from flair.datasets import TREC_6 -from flair.embeddings import SentenceTransformerDocumentEmbeddings -from flair.models import ClusteringModel - -embeddings = SentenceTransformerDocumentEmbeddings() - -# store all embeddings in memory which is required to perform clustering -corpus = TREC_6(memory_mode='full').downsample(0.05) - -model = Birch(n_clusters=6) - -clustering_model = ClusteringModel( - model=model, - embeddings=embeddings -) - -# fit the model on a corpus -clustering_model.fit(corpus) - -# evaluate the model on a corpus with the given label -clustering_model.evaluate(corpus, label_type="question_class") -``` - - -Expectation Maximization --------------------------- -Expectation Maximization (EM) is a different class of clustering algorithms called soft clustering algorithms. -Here each point isn't directly assigned to a cluster by a hard decision. -Each data point has a probability to which cluster the data point belongs. The Expectation Maximization (EM) -algorithm is a soft clustering algorithm. -More about EM can be read on the official [sklearn](https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html). - - -```python -from sklearn.mixture import GaussianMixture -from flair.datasets import TREC_6 -from flair.embeddings import SentenceTransformerDocumentEmbeddings -from flair.models import ClusteringModel - -embeddings = SentenceTransformerDocumentEmbeddings() - -# store all embeddings in memory which is required to perform clustering -corpus = TREC_6(memory_mode='full').downsample(0.05) - -model = GaussianMixture(n_components=6) - -clustering_model = ClusteringModel( - model=model, - embeddings=embeddings -) - -# fit the model on a corpus -clustering_model.fit(corpus) - -# evaluate the model on a corpus with the given label -clustering_model.evaluate(corpus, label_type="question_class") -``` - ---------------------------- - -Loading/Saving the model ------------ - -The model can be saved and loaded. The code below shows how to save a model. -```python -from flair.models import ClusteringModel -from flair.datasets import TREC_6 -from flair.embeddings import SentenceTransformerDocumentEmbeddings -from sklearn.cluster import KMeans - -embeddings = SentenceTransformerDocumentEmbeddings() - -# store all embeddings in memory which is required to perform clustering -corpus = TREC_6(memory_mode='full').downsample(0.05) - -model = KMeans(n_clusters=6) - -clustering_model = ClusteringModel( - model=model, - embeddings=embeddings -) - -# fit the model on a corpus -clustering_model.fit(corpus) - -# save the model -clustering_model.save(model_file="clustering_model.pt") -``` - -The code for loading a model. - -````python -# load saved clustering model -model = ClusteringModel.load(model_file="clustering_model.pt") - -# load a corpus -corpus = TREC_6(memory_mode='full').downsample(0.05) - -# predict the corpus -model.predict(corpus) -```` - ---------------------- - -Evaluation ---------- -The result of the clustering can be evaluated. For this we will use the -[NMI](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html). -(Normalized Mutual Info) score. - -````python -# need to fit() the model first -# evaluate the model on a corpus with the given label -clustering_model.evaluate(corpus, label_type="question_class") -```` - -The result of the evaluation can be seen below with the SentenceTransformerDocumentEmbeddings: - - -| Clustering Algorithm | Dataset | NMI | -|--------------------------|:-------------:|--------:| -| k Means | StackOverflow | ~0.2122 | -| BIRCH | StackOverflow | ~0,2424 | -| Expectation Maximization | 20News group | ~0,2222 |