Skip to content

Commit

Permalink
feat(buzzwords): [0.3.1] silhouette score (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
lewisExternal authored Dec 20, 2022
1 parent 70df9c4 commit 9d4b5be
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 3 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ _site/*
buzzwords.egg-info/*
Gemfile.lock

.DS_Store
.DS_Store
__pycache__
27 changes: 27 additions & 0 deletions buzzwords/buzzwords.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from cuml.cluster import approximate_predict
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import silhouette_score
from tqdm import tqdm

from .models.clip_encoder import CLIPEncoder
Expand Down Expand Up @@ -158,6 +159,10 @@ def __init__(self,
self.topic_descriptions = None
self.topic_alterations = {}

silhouette_params = self.model_parameters['Silhouette']
self.silhouette_random_state = silhouette_params['random_state']
self.run_silhouette_score = silhouette_params['run_silhouette_score']

def fit(self, docs: List[str], recursions: int = 1) -> None:
"""
Fit model based on given data
Expand Down Expand Up @@ -241,6 +246,11 @@ def fit_transform(self, docs: List[str], recursions: int = 1) -> List[int]:
min_cluster_size=int(self.model_parameters['HDBSCAN']['min_cluster_size'])
)

# Silhouette score is a metric used to calculate the goodness of a clustering technique
if self.run_silhouette_score:
self.silhouette_score = self.get_silhouette_score(embeddings,topics)
print(f"Silhouette score: {self.silhouette_score}")

# Lemmatise words to avoid similar words in top n keywords
if self.lemmatise:
docs = [
Expand Down Expand Up @@ -594,3 +604,20 @@ def load(self, destination: str) -> None:

with open(destination, 'rb') as file:
self.__dict__ = pickle.load(file)

def get_silhouette_score(self, X: np.ndarray, labels: np.ndarray) -> float:
    """
    Compute the silhouette score of a clustering, ignoring every point
    whose label is -1

    The silhouette score is a metric used to judge the quality of a clustering:
        1: Means clusters are well apart from each other and clearly distinguished
        0: Means clusters are indifferent, i.e. the distance between clusters is not significant
        -1: Means clusters are assigned in the wrong way

    Parameters
    ----------
    X : np.ndarray
        Embeddings array, one row per entry in `labels`
    labels : np.ndarray
        Labels as predicted by the model. Any array-like of ints is accepted.
        Entries equal to -1 (presumably the clusterer's noise/outlier label)
        are dropped, together with their rows in X, before scoring

    Returns
    -------
    float
        Silhouette score computed over the points whose label is not -1

    Notes
    -----
    The scoring itself is delegated to sklearn's `silhouette_score`, which
    raises a ValueError when fewer than 2 distinct labels remain after the
    -1 entries are removed. That error is deliberately left to propagate.
    """
    # Coerce to an ndarray first: comparing a plain Python list with -1 gives
    # a single bool instead of an element-wise mask
    labels = np.asarray(labels)

    # Build the mask once and reuse it for both the rows and the labels
    clustered = labels != -1

    return silhouette_score(
        X[clustered],
        labels[clustered],
        random_state=self.silhouette_random_state
    )

3 changes: 3 additions & 0 deletions buzzwords/model_parameters.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Default parameters for Buzzwords model
Embedding:
model_name_or_path: 'all-mpnet-base-v2'
Silhouette:
random_state: 42
run_silhouette_score: False
UMAP:
n_neighbors: 10
n_components: 5
Expand Down
2 changes: 1 addition & 1 deletion install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ conda create -y -n $env_name \
source activate $env_name;

pip3 install \
sentence-transformers==2.1.0 \
sentence-transformers==2.2.2 \
keybert==0.5.1 \
pytest~=7.0.0 \
clip-by-openai==1.1;
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

setuptools.setup(
name='buzzwords',
version='0.3.0',
version='0.3.1',
packages=setuptools.find_packages()
)

0 comments on commit 9d4b5be

Please sign in to comment.