Skip to content

Commit

Permalink
Add AAclustPlot
Browse files Browse the repository at this point in the history
  • Loading branch information
breimanntools committed Oct 3, 2023
1 parent 0261505 commit 0634bcd
Show file tree
Hide file tree
Showing 62 changed files with 524 additions and 101 deletions.
4 changes: 3 additions & 1 deletion aaanalysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from aaanalysis.data_handling import load_dataset, load_scales
from aaanalysis.aaclust import AAclust
from aaanalysis.aaclust_plot import AAclustPlot
from aaanalysis.cpp import CPP, CPPPlot, SequenceFeature, SplitRange
from aaanalysis.dpulearn import dPULearn
from aaanalysis.plotting import (plot_get_clist, plot_get_cmap, plot_get_cdict,
plot_settings, plot_legend, plot_gcfs)
from aaanalysis.config import options

__all__ = ["load_dataset", "load_scales", "AAclust",
__all__ = ["load_dataset", "load_scales",
"AAclust", "AAclustPlot",
"CPP", "CPPPlot", "SequenceFeature", "SplitRange",
"dPULearn", "plot_get_clist", "plot_get_cmap", "plot_get_cdict",
"plot_settings", "plot_legend", "plot_gcfs", "options"]
Expand Down
Binary file modified aaanalysis/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file modified aaanalysis/__pycache__/utils.cpython-39.pyc
Binary file not shown.
Binary file modified aaanalysis/_utils/__pycache__/_check_data.cpython-39.pyc
Binary file not shown.
Binary file not shown.
4 changes: 2 additions & 2 deletions aaanalysis/_utils/_check_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sklearn.utils import check_array

# Helper functions
def _check_array_like(name=None, val=None, dtype=None, ensure_2d=False, allow_nan=False):
def check_array_like(name=None, val=None, dtype=None, ensure_2d=False, allow_nan=False):
"""
Check if the provided value matches the specified dtype.
If dtype is None, checks for general array-likeness.
Expand Down Expand Up @@ -35,7 +35,7 @@ def _check_array_like(name=None, val=None, dtype=None, ensure_2d=False, allow_na
# Check feature matrix and labels
def check_X(X, min_n_samples=3, min_n_features=2, ensure_2d=True, allow_nan=False):
"""Check the feature matrix X is valid."""
X = _check_array_like(name="X", val=X, dtype="float", ensure_2d=ensure_2d, allow_nan=allow_nan)
X = check_array_like(name="X", val=X, dtype="float", ensure_2d=ensure_2d, allow_nan=allow_nan)
n_samples, n_features = X.shape
if n_samples < min_n_samples:
raise ValueError(f"n_samples ({n_samples} in 'X') should be >= {min_n_samples}."
Expand Down
34 changes: 34 additions & 0 deletions aaanalysis/_utils/_check_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""This is a script for scikit-learn model specific check functions"""
import inspect
from inspect import isclass

# Helper functions

# Main functions
def check_mode_class(model_class=None):
""""""
# Check if model_class is actually a class and not an instance
if not isclass(model_class):
raise ValueError(f"'{model_class}' is not a model class. Please provide a valid model class.")
# Check if model is callable
if not callable(getattr(model_class, "__call__", None)):
raise ValueError(f"'{model_class}' is not a callable model.")
return model_class

def check_model_kwargs(model_class=None, model_kwargs=None, param_to_check="n_clusters"):
"""
Check if the provided model has 'n_clusters' as a parameter.
Filter the model_kwargs to only include keys that are valid parameters for the model.
"""
model_kwargs = model_kwargs or {}
if model_class is None:
raise ValueError("'model_class' must be provided.")
valid_args = list(inspect.signature(model_class).parameters.keys())
# Check if 'param_to_check' is a parameter of the model
if param_to_check not in valid_args:
raise ValueError(f"'n_clusters' should be an argument in the given 'model' ({model_class}).")
# Filter model_kwargs to only include valid parameters for the model
invalid_kwargs = [x for x in model_kwargs if x not in valid_args]
if len(invalid_kwargs):
raise ValueError(f"'model_kwargs' contains non valid arguments: {invalid_kwargs}")
return model_kwargs
Binary file modified aaanalysis/aaclust/__pycache__/aaclust.cpython-39.pyc
Binary file not shown.
72 changes: 26 additions & 46 deletions aaanalysis/aaclust/aaclust.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,56 +2,28 @@
This is a script for the AAclust clustering wrapper method.
"""
import numpy as np
from typing import Type
from typing import Optional, Dict, Union, List, Tuple
from typing import Optional, Dict, Union, List, Tuple, Type
import inspect
from inspect import isclass
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, calinski_harabasz_score
from sklearn.base import ClusterMixin
from sklearn.exceptions import ConvergenceWarning
import warnings
import pandas as pd

from aaanalysis.aaclust._aaclust_bic import bic_score

from aaanalysis.template_classes import Wrapper
import aaanalysis.utils as ut

from aaanalysis.aaclust._aaclust import (estimate_lower_bound_n_clusters, optimize_n_clusters, merge_clusters,
compute_centers, compute_medoids)
from aaanalysis.aaclust._aaclust_bic import bic_score
from aaanalysis.aaclust._aaclust_statics import compute_correlation, name_clusters
from aaanalysis.template_classes import Wrapper


# I Helper Functions
# Check parameter functions
def check_mode_class(model_class=None):
""""""
# Check if model_class is actually a class and not an instance
if not isclass(model_class):
raise ValueError(f"'{model_class}' is not a model class. Please provide a valid model class.")
# Check if model is callable
if not callable(getattr(model_class, "__call__", None)):
raise ValueError(f"'{model_class}' is not a callable model.")
return model_class

def check_model_kwargs(model_class=None, model_kwargs=None):
"""
Check if the provided model has 'n_clusters' as a parameter.
Filter the model_kwargs to only include keys that are valid parameters for the model.
"""
model_kwargs = model_kwargs or {}
if model_class is None:
raise ValueError("'model_class' must be provided.")
list_model_args = list(inspect.signature(model_class).parameters.keys())
# Check if 'n_clusters' is a parameter of the model
if "n_clusters" not in list_model_args:
error = f"'n_clusters' should be an argument in the given 'model' ({model_class})."
raise ValueError(error)
# Filter model_kwargs to only include valid parameters for the model
not_valid_kwargs = [x for x in model_kwargs if x not in list_model_args]
if len(not_valid_kwargs):
raise ValueError(f"'model_kwargs' contains non valid arguments: {not_valid_kwargs}")
return model_kwargs


def check_merge_metric(merge_metric=None):
""""""
if merge_metric is not None and merge_metric not in ut.LIST_METRICS:
Expand Down Expand Up @@ -122,7 +94,7 @@ class AAclust(Wrapper):
Parameters
----------
model_class
A clustering model class with ``n_clusters`` parameter. This class will be instantiated during the ``fit`` method.
A clustering model class with ``n_clusters`` parameter. This class will be instantiated by the ``fit`` method.
model_kwargs
Keyword arguments to pass to the selected clustering model.
verbose
Expand Down Expand Up @@ -164,10 +136,11 @@ def __init__(self,
model_kwargs: Optional[Dict] = None,
verbose: bool = False):
# Model parameters
model_class = check_mode_class(model_class=model_class)
model_class = ut.check_mode_class(model_class=model_class)
if model_kwargs is None and model_class is KMeans:
model_kwargs = dict(n_init="auto")
model_kwargs = check_model_kwargs(model_class=model_class, model_kwargs=model_kwargs)
model_kwargs = ut.check_model_kwargs(model_class=model_class, model_kwargs=model_kwargs,
param_to_check="n_clusters")
self.model_class = model_class
self._model_kwargs = model_kwargs
self._verbose = ut.check_verbose(verbose)
Expand Down Expand Up @@ -296,11 +269,11 @@ def fit(self,
self.medoid_names_ = [names[i] for i in medoid_ind]
return self

@staticmethod
@ut.catch_runtime_warnings()
def eval(X: ut.ArrayLike2D,
labels:ut.ArrayLike1D = None
) -> Tuple[float, float, float]:
def eval(self,
X: ut.ArrayLike2D,
labels: Optional[ut.ArrayLike1D] = None
) -> Tuple[int, float, float, float]:
"""Evaluates the quality of clustering using three established measures.
Clustering quality is quantified using:
Expand All @@ -323,12 +296,14 @@ def eval(X: ut.ArrayLike2D,
Returns
-------
n_clusters : int
Number of clusters, equal to number of medoids.
BIC : float
BIC value for clustering.
BIC value for clustering (-inf to inf).
CH : float
CH value for clustering.
CH value for clustering (0 to inf).
SC : float
SC value for clustering.
SC value for clustering (-1 to 1).
Notes
-----
Expand All @@ -343,8 +318,13 @@ def eval(X: ut.ArrayLike2D,
# Check input
X = ut.check_X(X=X)
ut.check_X_unique_samples(X=X)
if labels is None:
labels = self.labels_
labels = ut.check_labels(labels=labels)
ut.check_match_X_labels(X=X, labels=labels)

# Number of clusters (number of medoids)
n_clusters = len(set(labels))
# Bayesian Information Criterion
BIC = bic_score(X, labels)
# Calinski-Harabasz Index
Expand All @@ -357,7 +337,7 @@ def eval(X: ut.ArrayLike2D,
if np.isnan(SC):
SC = -1
warnings.warn("SC was set to -1 because sklearn.metric.silhouette_score returned NaN.", RuntimeWarning)
return BIC, CH, SC
return n_clusters, BIC, CH, SC

@staticmethod
def name_clusters(X: ut.ArrayLike2D,
Expand Down
4 changes: 2 additions & 2 deletions aaanalysis/aaclust_plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from aaanalysis.aaclust.aaclust import AAclust
from aaanalysis.aaclust_plot.aaclust_plot import AAclustPlot

__all__ = ["AAclust"]
__all__ = ["AAclustPlot"]
Binary file not shown.
Binary file not shown.
93 changes: 81 additions & 12 deletions aaanalysis/aaclust_plot/aaclust_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,101 @@
import time
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
from typing import Optional, Dict, Union, List, Tuple, Type
from sklearn.base import TransformerMixin
import matplotlib.pyplot as plt
import seaborn as sns


# Settings
pd.set_option('expand_frame_repr', False) # Single line print for pd.Dataframe

import aaanalysis as aa
import aaanalysis.utils as ut

# I Helper Functions
def _get_rank(data):
""""""
_df = data.copy()
_df['BIC_rank'] = _df['BIC'].rank(ascending=False)
_df['CH_rank'] = _df['CH'].rank(ascending=False)
_df['SC_rank'] = _df['SC'].rank(ascending=False)
return _df[['BIC_rank', 'CH_rank', 'SC_rank']].mean(axis=1).round(2)

# TODO add check functions finish other methods, testing, compression

# II Main Functions
class AAclustPlot:
"""Plot results of AAclust analysis"""
def __int__(self, model):
self.model = model
"""Plot results of AAclust analysis.
Dimensionality reduction is performed for visualization using decomposition models such as
Principal Component Analysis (PCA).
Parameters
----------
model_class
A decomposition model class with ``n_components`` parameter.
model_kwargs
Keyword arguments to pass to the selected decomposition model.
See Also
--------
* Scikit-learn `decomposition model classes <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.decomposition>`_.
"""
def __init__(self,
model_class: Type[TransformerMixin] = PCA,
model_kwargs: Optional[Dict] = None):
# Model parameters
model_class = ut.check_mode_class(model_class=model_class)
model_kwargs = ut.check_model_kwargs(model_class=model_class, model_kwargs=model_kwargs,
param_to_check="n_components")
self.model_class = model_class
self.model_kwargs = model_kwargs

@staticmethod
def eval():
"""Plot eval output of BIC, CH, SC"""
def eval(data : ut.ArrayLike2D,
names : Optional[List[str]] = None,
dict_xlims : Optional[Union[None, dict]] = None,
figsize : Optional[Tuple[int, int]] = (7, 6)):
"""Plot eval output of n_clusters, BIC, CH, SC"""
columns = ["n_clusters", "BIC", "CH", "SC"]
colors = aa.plot_get_clist(n_colors=4)

# Check input
data = ut.check_array_like(name="data", val=data)
n_samples, n_features = data.shape
if n_features != 4:
raise ValueError(f"'data' should contain the following four columns: {columns}")
if names is None:
names = [f"Model {i}" for i in range(1, n_samples+1)]
data = pd.DataFrame(data, columns=columns, index=names)
data["rank"] = _get_rank(data)
data = data.sort_values(by="rank", ascending=True)
# Plotting
fig, axes = plt.subplots(1, 4, sharey=True, figsize=figsize)
for i, col in enumerate(columns):
ax = axes[i]
sns.barplot(ax=ax, data=data, y=data.index, x=col, color=colors[i])
# Customize subplots
ax.set_ylabel("")
ax.set_xlabel(col)
ax.axvline(0, color='black') #, linewidth=aa.plot_gcfs("axes.linewidth"))
if dict_xlims and col in dict_xlims:
ax.set_xlim(dict_xlims[col])
if i == 0:
ax.set_title("Number of clusters", weight="bold")
elif i == 2:
ax.set_title("Quality measures", weight="bold")
sns.despine(ax=ax, left=True)
ax.tick_params(axis='y', which='both', left=False)
plt.tight_layout()
plt.subplots_adjust(wspace=0.25, hspace=0)


def center(self):
def center(self, data):
"""PCA plot of clustering with centers highlighted"""

def medoids(self):
def medoids(self, data):
"""PCA plot of clustering with medoids highlighted"""

@staticmethod
def correlation():
def correlation(df_corr=None):
"""Heatmap for correlation"""
Binary file modified aaanalysis/plotting/__pycache__/plot_gcfs_.cpython-39.pyc
Binary file not shown.
11 changes: 7 additions & 4 deletions aaanalysis/plotting/plot_gcfs_.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import seaborn as sns

# Main function
def plot_gcfs():
def plot_gcfs(option='font.size'):
"""
Gets current font size.
Gets current font size (or axes linewdith).
This font size can be set by :func:`plot_settings` function.
Expand All @@ -33,7 +33,10 @@ def plot_gcfs():
--------
* Our `Plotting Prelude <plotting_prelude.html>`_.
"""
allowed_options = ["font.size", "axes.linewidth"]
if option not in allowed_options:
return ValueError(f"'option' should be one of following: {allowed_options}")
# Get the current plotting context
current_context = sns.plotting_context()
font_size = current_context['font.size']
return font_size
option_value = current_context[option] # Typically font_size
return option_value
8 changes: 5 additions & 3 deletions aaanalysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
from aaanalysis.config import options

# Import utility functions explicitly
from aaanalysis._utils._check_data import (check_X, check_X_unique_samples, check_labels, check_match_X_labels,
check_array_like, check_superset_subset,
check_col_in_df)
from aaanalysis._utils._check_models import check_mode_class, check_model_kwargs
from aaanalysis._utils._check_type import (check_number_range, check_number_val, check_str, check_bool,
check_dict, check_tuple, check_list_like,
check_ax)
from aaanalysis._utils._check_data import (check_X, check_X_unique_samples, check_labels, check_match_X_labels,
check_superset_subset,
check_col_in_df)

from aaanalysis._utils.utils_cpp import (check_color, check_y_categorical, check_labels_, check_ylim,
check_args_len, check_args_len, check_list_parts,
check_split_kws, check_split,
Expand Down
Binary file modified docs/build/doctrees/api.doctree
Binary file not shown.
Binary file modified docs/build/doctrees/environment.pickle
Binary file not shown.
Binary file modified docs/build/doctrees/generated/aaanalysis.AAclust.doctree
Binary file not shown.
Binary file not shown.
Binary file modified docs/build/doctrees/generated/aaanalysis.plot_gcfs.doctree
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions docs/build/html/_sources/api.rst.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Feature Engineering
:toctree: generated/

aaanalysis.AAclust
aaanalysis.AAclustPlot
aaanalysis.SequenceFeature
aaanalysis.CPP
aaanalysis.CPPPlot
Expand Down
Loading

0 comments on commit 0634bcd

Please sign in to comment.