Skip to content

Commit

Permalink
Tests for AAclustPlot().eval()
Browse files Browse the repository at this point in the history
  • Loading branch information
breimanntools committed Oct 9, 2023
1 parent 881ed77 commit d3bfe66
Show file tree
Hide file tree
Showing 367 changed files with 413 additions and 65 deletions.
Binary file modified aaanalysis/__pycache__/utils.cpython-39.pyc
Binary file not shown.
Binary file modified aaanalysis/_utils/__pycache__/check_data.cpython-39.pyc
Binary file not shown.
Binary file modified aaanalysis/_utils/__pycache__/check_models.cpython-39.pyc
Binary file not shown.
3 changes: 1 addition & 2 deletions aaanalysis/_utils/check_data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""
This is a script for ...
This is a script for data checking utility functions.
"""
import pandas as pd
import numpy as np
# Write wrapper around scikit checkers
from sklearn.utils import check_array

# Helper functions
Expand Down
2 changes: 1 addition & 1 deletion aaanalysis/_utils/utils_output.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
This is a script for adjust output (mainly for python console)
This is a script for adjusting terminal output.
"""
import numpy as np

Expand Down
3 changes: 3 additions & 0 deletions aaanalysis/_utils/utils_ploting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
This is a script for internal plotting utility functions used in the backend.
"""
Binary file modified aaanalysis/feature_engineering/__pycache__/_aaclust.cpython-39.pyc
Binary file not shown.
Binary file not shown.
3 changes: 2 additions & 1 deletion aaanalysis/feature_engineering/_aaclust.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""
This is a script for the interface of the AAclust class, used for clustering wrapper method.
This is a script for the interface of the AAclust class, a clustering wrapper object to obtain redundancy-reduced
scale subsets.
"""
import numpy as np
from typing import Optional, Dict, Union, List, Tuple, Type
Expand Down
122 changes: 102 additions & 20 deletions aaanalysis/feature_engineering/_aaclust_plot.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""
This is a script for the interface of the AAclustPlot class, used for plotting the results of AAclust.
This is a script for the frontend of the AAclustPlot class, used for plotting of the AAclust results.
"""
import pandas as pd
from sklearn.decomposition import PCA
from typing import Optional, Dict, Union, List, Tuple, Type
from sklearn.base import TransformerMixin
import matplotlib.pyplot as plt

import aaanalysis as aa
import aaanalysis.utils as ut
Expand All @@ -12,11 +14,51 @@


# I Helper Functions
def _get_components(data=None, model_class=None):
def check_match_data_names(data=None, names=None):
    """Validate evaluation data against scale-set names and return it as a DataFrame.

    Parameters
    ----------
    data : array-like or pd.DataFrame
        Evaluation results with one row per scale set. An array-like input must have
        exactly four columns; a DataFrame must contain every column listed in
        ``ut.COLS_EVAL_AACLUST``.
    names : list of str, optional
        One name per row of ``data``. If None, names are generated as 'Set 1', 'Set 2', ...

    Returns
    -------
    df_eval : pd.DataFrame
        Evaluation data indexed by ``names``. A DataFrame input is copied, so the
        caller's object keeps its original index.

    Raises
    ------
    ValueError
        If ``names`` and ``data`` differ in length, if array-like ``data`` does not have
        four columns, or if a DataFrame lacks one of the expected columns.
    """
    n_samples = len(data)
    if names is None:
        names = [f"Set {i}" for i in range(1, n_samples + 1)]
    elif len(names) != n_samples:
        raise ValueError(f"n_samples does not match for 'data' ({n_samples}) and 'names' ({len(names)}).")
    if isinstance(data, pd.DataFrame):
        # Check data for missing columns
        missing_cols = [x for x in ut.COLS_EVAL_AACLUST if x not in list(data)]
        if len(missing_cols) > 0:
            raise ValueError(f"'data' must contain the following columns: {missing_cols}")
        # Work on a copy so that re-indexing does not leak back into the caller's DataFrame
        df_eval = data.copy()
        df_eval.index = names
        return df_eval
    # 'name' is the parameter label used by the checker (presumably for its error messages),
    # so it must be the string "data", not the data object itself
    data = ut.check_array_like(name="data", val=data)
    n_samples, n_features = data.shape
    # Check matching number of features
    if n_features != 4:
        raise ValueError(f"'data' should contain the following four columns: {ut.COLS_EVAL_AACLUST}")
    return pd.DataFrame(data, columns=ut.COLS_EVAL_AACLUST, index=names)


def check_dict_xlims(dict_xlims=None):
    """Validate the optional dictionary of manual x-axis limits.

    Parameters
    ----------
    dict_xlims : dict, optional
        Maps evaluation-measure names (entries of ``ut.COLS_EVAL_AACLUST``) to a
        ``(xmin, xmax)`` pair with ``xmin < xmax``. None skips all checks.

    Raises
    ------
    ValueError
        If a key is not a known evaluation measure, if a value is not a pair
        (including values that have no length at all), or if ``xmin >= xmax``.
    """
    if dict_xlims is None:
        return
    ut.check_dict(name="dict_xlims", val=dict_xlims)
    wrong_keys = [x for x in list(dict_xlims) if x not in ut.COLS_EVAL_AACLUST]
    if len(wrong_keys) > 0:
        raise ValueError(f"'dict_xlims' should not contain the following keys: {wrong_keys}")
    for key, limits in dict_xlims.items():
        # Unsized values (e.g. a single number or None) would make len() raise a bare
        # TypeError; report them with the same ValueError as any other malformed pair
        try:
            n_limits = len(limits)
        except TypeError:
            n_limits = None
        if n_limits != 2:
            raise ValueError("'dict_xlims' values should be tuple with two numbers.")
        xmin, xmax = limits
        ut.check_number_val(name="dict_xlims:min", val=xmin, just_int=False, accept_none=False)
        ut.check_number_val(name="dict_xlims:max", val=xmax, just_int=False, accept_none=False)
        if xmin >= xmax:
            raise ValueError(f"'dict_xlims:min' ({xmin}) should be < 'dict_xlims:max' ({xmax}) for '{key}'.")


# TODO add check functions finish other methods, testing, compression

# TODO add check functions finish other methods, testing, compression
# II Main Functions
class AAclustPlot:
"""Plot results of AAclust analysis.
Expand All @@ -34,36 +76,76 @@ class AAclustPlot:
See Also
--------
* Scikit-learn `decomposition model classes <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.decomposition>`_.
"""
def __init__(self,
model_class: Type[TransformerMixin] = PCA,
model_kwargs: Optional[Dict] = None):
# Model parameters
model_class = ut.check_mode_class(model_class=model_class)
model_kwargs = ut.check_model_kwargs(model_class=model_class, model_kwargs=model_kwargs,
model_kwargs = ut.check_model_kwargs(model_class=model_class,
model_kwargs=model_kwargs,
param_to_check="n_components")
self.model_class = model_class
self.model_kwargs = model_kwargs

@staticmethod
def eval(data : ut.ArrayLike2D,
names : Optional[List[str]] = None,
dict_xlims : Optional[Union[None, dict]] = None,
figsize : Optional[Tuple[int, int]] = (7, 6)):
"""Plot eval output of n_clusters, BIC, CH, SC"""
columns = ["n_clusters", "BIC", "CH", "SC"]
colors = aa.plot_get_clist(n_colors=4)

def eval(data: ut.ArrayLike2D,
names: Optional[List[str]] = None,
dict_xlims: Optional[Union[None, dict]] = None,
figsize: Optional[Tuple[int, int]] = (7, 6)
) -> Tuple[plt.Figure, plt.Axes]:
"""
Evaluates and plots ``n_clusters`` and clustering metrics ``BIC``, ``CH``, and ``SC`` for the provided data.
The clustering evaluation metrics (BIC, CH, and SC) are ranked by the average of their independent rankings.
Parameters
----------
data : `array-like, shape (n_samples, n_features)`
Evaluation matrix or DataFrame. `Rows` correspond to scale sets and `columns` to the following
four evaluation measures:
- ``n_clusters``: Number of clusters.
- ``BIC``: Bayesian Information Criterion.
- ``CH``: Calinski-Harabasz Index.
- ``SC``: Silhouette Coefficient.
names
Names of scale sets from ``data``. If None, names are internally generated as 'Set 1', 'Set 2' etc.
dict_xlims
A dictionary containing x-axis limits (``xmin``, ``xmax``) for selected evaluation measure metric subplots.
Keys should be names of the ``evaluation measures`` (e.g., 'BIC'). If None, x-axes are auto-scaled.
figsize
Width and height of the figure in inches.
Returns
-------
fig
Figure object containing the plots.
axes
Axes object(s) containing four subplots.
Notes
-----
- The data is ranked in ascending order of the average ranking of the scale sets.
See Also
--------
* :meth:`AAclust.eval` for details on evaluation measures.
"""
# Check input
data = ut.check_array_like(name="data", val=data)
n_samples, n_features = data.shape
if n_features != 4:
raise ValueError(f"'data' should contain the following four columns: {columns}")
if names is None:
names = [f"Model {i}" for i in range(1, n_samples+1)]
ut.check_array_like(name="data", val=data)
ut.check_list_like(name="names", val=names, accept_none=True)
df_eval = check_match_data_names(data=data, names=names)
check_dict_xlims(dict_xlims=dict_xlims)
ut.check_tuple(name="figsize", val=figsize, n=2, accept_none=True)
# Plotting
fig, axes = plot_eval()
colors = aa.plot_get_clist(n_colors=4)
fig, axes = plot_eval(df_eval=df_eval,
dict_xlims=dict_xlims,
figsize=figsize,
colors=colors)
return fig, axes


def center(self, data):
Expand Down
Binary file not shown.
66 changes: 49 additions & 17 deletions aaanalysis/feature_engineering/_backend/aaclust/aaclust_plot.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,83 @@
"""
This is a script for the AAclust plot_eval method.
This is a script for the backend of the AAclustPlot object for all plotting functions.
"""
import time
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import aaanalysis.utils as ut
import matplotlib.ticker as mticker


# I Helper Functions
def _get_rank(data):
# Computation helper functions
def _get_mean_rank(data):
""""""
_df = data.copy()
_df['BIC_rank'] = _df['BIC'].rank(ascending=False)
_df['CH_rank'] = _df['CH'].rank(ascending=False)
_df['SC_rank'] = _df['SC'].rank(ascending=False)
return _df[['BIC_rank', 'CH_rank', 'SC_rank']].mean(axis=1).round(2)
_df['BIC_rank'] = _df[ut.COL_BIC].rank(ascending=False)
_df['CH_rank'] = _df[ut.COL_CH].rank(ascending=False)
_df['SC_rank'] = _df[ut.COL_SC].rank(ascending=False)
rank = _df[['BIC_rank', 'CH_rank', 'SC_rank']].mean(axis=1).round(2)
return rank

# II Main Functions
def plot_eval(data=None, names=None, dict_xlims=None, figsize=None, columns=None, colors=None):
def _get_components(data=None, model_class=None):
""""""
data = pd.DataFrame(data, columns=columns, index=names)
data["rank"] = _get_rank(data)
data = data.sort_values(by="rank", ascending=True)

# Plotting helper functions
def _adjust_spines(ax=None):
    """Despine the axis; unless the x-range is purely non-negative, centre it on zero.

    When the current x-limits start at or above 0 (and end above 0), only the default
    despine is applied. Otherwise the left spine is removed, a vertical line is drawn at
    x=0 using the bottom spine's line width, and the x-limits are made symmetric around 0.
    Returns the same ``ax`` that was passed in.
    """
    lower, upper = ax.get_xlim()
    only_non_negative = lower >= 0 and upper > 0
    if only_non_negative:
        sns.despine(ax=ax)
        return ax
    sns.despine(ax=ax, left=True)
    # Match the zero line to the existing bottom spine so it reads as an axis
    spine_width = ax.spines['bottom'].get_linewidth()
    ax.axvline(0, color='black', linewidth=spine_width)
    # Symmetric limits keep the zero line in the middle of the subplot
    half_range = max(abs(lower), abs(upper))
    ax.set_xlim(-half_range, half_range)
    return ax


def _x_ticks_0(ax):
    """Install an x-axis tick formatter that prints zero without decimals.

    Non-zero tick values are rendered with two decimal places; a zero tick is
    rendered with none.
    """
    def _format_tick(value, pos):
        """Return the label for one tick (``pos`` is required by FuncFormatter but unused)."""
        if value:
            return f'{value:.2f}'
        return f'{value:.0f}'

    formatter = mticker.FuncFormatter(_format_tick)
    ax.xaxis.set_major_formatter(formatter)




# II Main Functions
def plot_eval(df_eval=None, dict_xlims=None, figsize=None, colors=None):
"""Plot evaluation of AAclust clustering results"""
df_eval[ut.COL_RANK] = _get_mean_rank(df_eval)
df_eval = df_eval.sort_values(by=ut.COL_RANK, ascending=True)
# Plotting
fig, axes = plt.subplots(1, 4, sharey=True, figsize=figsize)
for i, col in enumerate(columns):
for i, col in enumerate(ut.COLS_EVAL_AACLUST):
ax = axes[i]
sns.barplot(ax=ax, data=data, y=data.index, x=col, color=colors[i])
sns.barplot(ax=ax, data=df_eval, y=df_eval.index, x=col, color=colors[i])
# Customize subplots
ax.set_ylabel("")
ax.set_xlabel(col)
ax.axvline(0, color='black') # , linewidth=aa.plot_gcfs("axes.linewidth"))
# Adjust spines
ax = _adjust_spines(ax=ax)
# Manual xlims, if needed
if dict_xlims and col in dict_xlims:
ax.set_xlim(dict_xlims[col])
if i == 0:
ax.set_title("Number of clusters", weight="bold")
ax.set_title("Clustering", weight="bold")
elif i == 2:
ax.set_title("Quality measures", weight="bold")
sns.despine(ax=ax, left=True)
ax.tick_params(axis='y', which='both', left=False)
_x_ticks_0(ax=ax)
plt.tight_layout()
plt.subplots_adjust(wspace=0.25, hspace=0)
return fig, axes


def _plot_pca(df_pred=None, filter_classes=None, x=None, y=None, others=True, highlight_rel=True,
figsize=(6, 6), highlight_mean=True, list_classes=None):
""""""
Expand Down
Empty file.
10 changes: 8 additions & 2 deletions aaanalysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,16 @@ def _folder_path(super_folder, folder_name):
NAMES_SCALE_SETS = [STR_SCALES, STR_SCALES_RAW, STR_SCALE_CAT,
STR_SCALES_PC, STR_TOP60, STR_TOP60_EVAL]

# Options
# AAclust
METRIC_CORRELATION = "correlation"
LIST_METRICS = [METRIC_CORRELATION, "manhattan", "euclidean", "cosine"]
STR_UNCLASSIFIED = "Unclassified"
COL_N_CLUST = "n_clusters"
COL_BIC = "BIC"
COL_CH = "CH"
COL_SC = "SC"
COL_RANK = "rank"
COLS_EVAL_AACLUST = [COL_N_CLUST, COL_BIC, COL_CH, COL_SC]

# Column names for primary df
# df_seq
Expand Down Expand Up @@ -247,7 +253,7 @@ def check_df_parts(df_parts=None, verbose=True):
print(warning)
#raise ValueError("'df_part' should not be None")
else:
if not (isinstance(df_parts, pd.DataFrame)):
if not isinstance(df_parts, pd.DataFrame):
raise ValueError(f"'df_parts' ({type(df_parts)}) must be type pd.DataFrame")
if len(list(df_parts)) == 0 or len(df_parts) == 0:
raise ValueError("'df_parts' should not be empty pd.DataFrame")
Expand Down
Binary file modified docs/build/doctrees/environment.pickle
Binary file not shown.
Binary file modified docs/build/doctrees/generated/aaanalysis.AAclust.doctree
Binary file not shown.
Binary file modified docs/build/doctrees/generated/aaanalysis.AAclustPlot.doctree
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit d3bfe66

Please sign in to comment.