Skip to content

Commit

Permalink
Make backend AAclustPlot methods (center, medoids, correlation)
Browse files Browse the repository at this point in the history
  • Loading branch information
breimanntools committed Oct 9, 2023
1 parent 537b50c commit 4f7a589
Show file tree
Hide file tree
Showing 45 changed files with 282 additions and 45 deletions.
Binary file modified aaanalysis/__pycache__/utils.cpython-39.pyc
Binary file not shown.
Binary file modified aaanalysis/_utils/__pycache__/utils_ploting.cpython-39.pyc
Binary file not shown.
118 changes: 117 additions & 1 deletion aaanalysis/_utils/utils_ploting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,128 @@
This is a script for internal plotting utility functions used in the backend.
"""
import seaborn as sns
import matplotlib.patches as mpatches


# Helper functions
def _get_color_map(labels, color):
unique_labels = sorted(set(labels))
if isinstance(color, list):
if len(color) != len(unique_labels):
raise ValueError("If color is a list, it must have the same length as the number of unique labels.")
color_map = {label: color[i] for i, label in enumerate(unique_labels)}
else:
color_map = {label: color for label in unique_labels}
return color_map, unique_labels


def _get_positions_lengths_colors(labels, color_map):
positions, lengths, colors = [], [], []
current_label, start_pos = labels[0], 0
for i, label in enumerate(labels + [None]):
if label != current_label or i == len(labels):
positions.append(start_pos)
length = i - start_pos
lengths.append(length)
colors.append(color_map[current_label])
start_pos = i
current_label = label
return positions, lengths, colors

def _get_xy_wh(ax=None, position=None, pos=None, barspacing=None, length=None, bar_width=None):
x_min, x_max = ax.get_xlim()
y_min, y_max = ax.get_ylim()
if position == 'bottom':
y = y_min + barspacing
return (pos, y), (length, bar_width)
elif position == 'top':
y = y_max - barspacing - bar_width
return (pos, y), (length, bar_width)
elif position == 'left':
x = x_min - barspacing - bar_width
return (x, pos), (bar_width, length)
elif position == 'right':
x = x_max + barspacing
return (x, pos), (bar_width, length)
else:
raise ValueError("Position should be 'left', 'right', 'top', or 'bottom'.")

def _get_xy_hava(position=None, xy=None, wh=None):
bar_width = wh[1] if position in ['bottom', 'top'] else wh[0]
if position == 'bottom':
text_x = xy[0] + wh[0] / 2
text_y = xy[1] + bar_width * 1.5
ha, va = 'center', 'top'
elif position == 'top':
text_x = xy[0] + wh[0] / 2
text_y = xy[1] + wh[1] - bar_width
ha, va = 'center', 'bottom'
elif position == 'left':
text_x = xy[0] - bar_width
text_y = xy[1] + wh[1] / 2
ha, va = 'right', 'center'
else: # Assuming position == 'right'
text_x = xy[0] + wh[0] + bar_width*0.5
text_y = xy[1] + wh[1] / 2
ha, va = 'left', 'center'
return text_x, text_y, ha, va

def _add_bar_labels(ax=None, bar_labels_align=None, position=None, bar_width=None,
labels=None, positions=None, lengths=None, bar_labels=None, barspacing=None):
label_map = {label: bar_labels[i] for i, label in enumerate(sorted(set(labels)))}
rotation = 0 if bar_labels_align == 'horizontal' else 90
for pos, length in zip(positions, lengths):
xy, wh = _get_xy_wh(ax=ax, position=position, pos=pos, barspacing=barspacing, length=length, bar_width=bar_width)
text_x, text_y, ha, va = _get_xy_hava(position=position, xy=xy, wh=wh)
ax.text(text_x, text_y, label_map[labels[int(pos + length / 2)]], ha=ha, va=va, rotation=rotation,
transform=ax.transData, clip_on=False)



# Main function
def plot_add_bars(ax, labels, position='left', bar_spacing=0.05, colors='tab:gray', bar_labels=None,
bar_labels_align='horizontal', bar_width=0.1):
"""
Add colored bars along a specified axis of the plot based on label grouping.
Parameters:
ax (matplotlib.axes._axes.Axes): The axes to which bars will be added.
labels (list or array-like): Labels determining bar grouping and coloring.
position (str): The position to add the bars ('left', 'right', 'top', 'bottom').
bar_spacing (float): Spacing between plot and added bars.
colors (str or list): Either a matplotlib color string, or a list of colors for each unique label.
bar_labels (list, optional): Labels for the bars.
bar_labels_align (str): Text alignment for bar labels, either 'horizontal' or other valid matplotlib alignment.
bar_width (float): Width of the bars.
Note:
This function adds colored bars in correspondence with the provided `labels` to visualize groupings in plots.
"""

if not isinstance(labels, list):
labels = list(labels)
single_color = isinstance(colors, str) or (isinstance(colors, (list, tuple)) and len(colors) == 1)
color_map, _ = _get_color_map(labels, colors)
positions, lengths, colors = _get_positions_lengths_colors(labels, color_map)
args_get = dict(position=position, bar_width=bar_width, barspacing=bar_spacing)
if bar_labels is not None:
_add_bar_labels(ax=ax, bar_labels_align=bar_labels_align,
labels=labels, positions=positions, lengths=lengths, bar_labels=bar_labels, **args_get)
# Adding bars
args = dict(transform=ax.transData, clip_on=False)
for pos, length, bar_color in zip(positions, lengths, colors):
xy, wh = _get_xy_wh(ax=ax, pos=pos, length=length, **args_get)
# Add edgecolor if only one color is specified
edgecolor = "white" if single_color else bar_color
ax.add_patch(mpatches.Rectangle(xy=xy, width=wh[0], height=wh[1],
facecolor=bar_color, edgecolor=edgecolor, linewidth=0.5,
**args))


def plot_gco(option='font.size', show_options=False):
"""Get current option from plotting context"""
current_context = sns.plotting_context()
if show_options:
print(current_context)
option_value = current_context[option] # Typically font_size
return option_value
return option_value
Binary file not shown.
73 changes: 66 additions & 7 deletions aaanalysis/feature_engineering/_aaclust_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def check_dict_xlims(dict_xlims=None):


# TODO add check functions finish other methods, testing, compression

# II Main Functions
class AAclustPlot:
"""Plot results of AAclust analysis.
Expand Down Expand Up @@ -151,7 +150,7 @@ def eval(data_eval: ut.ArrayLike2D,
colors=colors)
return fig, axes


# TODO check functions, docstring, testing
def center(self,
X: ut.ArrayLike2D,
labels: ut.ArrayLike1D = None,
Expand Down Expand Up @@ -193,6 +192,7 @@ def center(self,
legend=legend, palette=palette)
return ax, df_components

# TODO check functions, docstring, testing
def medoids(self,
X: ut.ArrayLike2D,
labels: ut.ArrayLike1D = None,
Expand All @@ -205,6 +205,7 @@ def medoids(self,
dot_size: Optional[int] = 100,
legend: Optional[bool] = True,
palette: Optional[mpl.colors.ListedColormap] = None,
return_data : Optional[bool] = False
) -> Tuple[plt.Axes, pd.DataFrame]:
"""PCA plot of clustering with medoids highlighted"""
# Check input
Expand All @@ -225,15 +226,73 @@ def medoids(self,
ax=ax, figsize=figsize,
dot_size=dot_size, dot_alpha=dot_alpha,
legend=legend, palette=palette)
if return_data:
return df_components
return ax

return ax, df_components

# TODO check functions, docstring, testing
@staticmethod
def correlation(df_corr=None, labels_sorted=None, **kwargs_heatmap):
"""Heatmap for correlation
def correlation(df_corr: Optional[pd.DataFrame] = None,
labels: Optional[List[str]] = None,
bar_position: str = "left",
bar_width: float = 0.1,
bar_spacing: float = 0.1,
bar_colors: Union[str, List[str]] = "gray",
bar_ticklabel_pad: Optional[float] = None,
vmin: float = -1,
vmax: float = 1,
cmap: str = "viridis",
**kwargs_heatmap
) -> plt.Axes:
"""
Heatmap for correlation matrix with colored sidebar to label clusters.
Parameters
----------
df_corr : `array-like, shape (n_samples, n_clusters)`
DataFrame with correlation matrix. `Rows` typically correspond to scales and `columns` to clusters.
labels
Labels determining the grouping and coloring of the side color bar.
It should be of the same length as `df_corr` columns/rows.
Defaults to None.
bar_position
Position of the colored sidebar (``left``, ``right``, ``top``, or ``down``). If ``None``, no sidebar is added.
bar_width
Width of the sidebar.
bar_spacing
Space between the heatmap and the side color bar.
bar_colors
Either a single color or a list of colors for each unique label in `labels`.
bar_ticklabel_pad
Padding for y-axis tick labels. If ``None``, uses default padding.
vmin
Minimum value of the color scale in the ``sns.heatmap()``.
vmax
Maximum value of the color scale in the ``sns.heatmap()``.
cmap
Colormap to be used for the ``sns.heatmap()``.
**kwargs_heatmap
Additional keyword arguments passed to ``sns.heatmap()``.
Returns
-------
ax : matplotlib.axes._axes.Axes
Axes object with the correlation heatmap.
Notes
-----
- Ensure `labels` and `df_corr` are in the same order to avoid mislabeling.
See Also
--------
- :func:`seaborn.heatmap` for information on kwargs_heatmap.
sns.heatmap : Seaborn function for creating heatmaps.
"""
plot_correlation(df_corr=df_corr, labels_sorted=labels_sorted, **kwargs_heatmap)
ax = plot_correlation(df_corr=df_corr, labels_sorted=labels,
bar_position=bar_position,
bar_width=bar_width, bar_spacing=bar_spacing, bar_colors=bar_colors,
bar_ticklabel_pad=bar_ticklabel_pad,
vmin=vmin, vmax=vmax, cmap=cmap, **kwargs_heatmap)
plt.tight_layout()
return ax
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,14 @@ def remove_2nd_info(name_):
# Compute correlation
def _sort_X_labels_names(X, labels=None, names=None):
"""Sort labels"""
print(labels)
sorted_order = np.argsort(labels)
labels = np.array([labels[i] for i in sorted_order])
print(labels)
X = X[sorted_order]
if names:
names = [names[i] for i in sorted_order]
print(names)
return X, labels, names

def _get_df_corr(X=None, X_ref=None):
Expand Down
24 changes: 16 additions & 8 deletions aaanalysis/feature_engineering/_backend/aaclust/aaclust_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,28 @@ def plot_center_or_medoid(X=None, labels=None,
return ax, df_components


def plot_correlation(df_corr=None, labels_sorted=None, **kwargs):
""""""
_kwargs = dict(cmap="viridis", vmin=-1, vmax=1,
cbar_kws = {"label": "Pearson correlation"})

_kwargs.update(**kwargs)
ax = sns.heatmap(df_corr, **_kwargs)
print(labels_sorted)
def plot_correlation(df_corr=None, labels_sorted=None,
bar_position="left", bar_width=0.1, bar_spacing=0.1, bar_colors="gray", bar_ticklabel_pad=None,
vmin=-1, vmax=1, cmap="viridis", **kwargs_heatmap):
"""Plots heatmap for clustering results with rows (y-axis) corresponding to scales and columns (x-axis) to clusters."""
# Plot heatmap
_kwargs_heatmap = {"cmap": cmap, "vmin": vmin, "vmax": vmax,
"cbar_kws": {"label": "Pearson correlation"},
**kwargs_heatmap}
ax = sns.heatmap(data=df_corr, **_kwargs_heatmap)
# Adjust ticks
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
# Customizing color bart tick lines
cbar = ax.collections[0].colorbar
lw = ut.plot_gco(option="axes.linewidth")
fs = ut.plot_gco(option="font.size")
cbar.ax.tick_params(axis='y', width=lw, length=6, color='black', labelsize=fs-1)
# Add bars for highlighting clustering
if bar_position is not None:
ut.plot_add_bars(ax=ax, labels=labels_sorted, bar_spacing=bar_spacing, bar_width=bar_width,
position=bar_position, colors=bar_colors)
if bar_ticklabel_pad is not None:
ax.tick_params(axis="y", which="both", pad=bar_ticklabel_pad)
plt.tight_layout()
return ax
2 changes: 1 addition & 1 deletion aaanalysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
doc_params)

from ._utils.utils_output import (print_out, print_start_progress, print_progress, print_finished_progress)
from ._utils.utils_ploting import (plot_gco)
from ._utils.utils_ploting import plot_gco, plot_add_bars


# Folder structure
Expand Down
Binary file modified docs/build/doctrees/environment.pickle
Binary file not shown.
Binary file modified docs/build/doctrees/generated/aaanalysis.AAclust.doctree
Binary file not shown.
Binary file modified docs/build/doctrees/generated/aaanalysis.AAclustPlot.doctree
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit 4f7a589

Please sign in to comment.