Skip to content

Commit

Permalink
Add tests for visualizations (#170)
Browse files Browse the repository at this point in the history
Clean up visualization code, make usable without GUI, add tests.
  • Loading branch information
dweindl authored Feb 18, 2025
1 parent a488a4c commit caa10d2
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 166 deletions.
137 changes: 18 additions & 119 deletions src/ccompass/RP.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
import os
from io import BytesIO
from pathlib import Path

import FreeSimpleGUI as sg
Expand All @@ -11,17 +10,18 @@
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image, ImageFile
from scipy.cluster.hierarchy import leaves_list, linkage
from scipy.stats import zscore

from ccompass.core import ResultsModel
from ccompass.visualize import fig_to_bytes, fract_heatmap

logger = logging.getLogger(__package__)


def RP_gradient_heatmap(fract_data):
"""Create a GUI to display a heatmap with hierarchical clustering for each condition."""
"""Create a GUI to display a heatmap with hierarchical clustering for each
condition."""

conditions = list(fract_data["vis"])

Expand Down Expand Up @@ -55,52 +55,14 @@ def plot_heatmap(
condition_name: str,
save_as_pdf=False,
folder_path=None,
) -> ImageFile:
# Perform hierarchical clustering on the rows
linkage_matrix = linkage(dataframe, method="ward")
clustered_rows = leaves_list(
linkage_matrix
) # Order of rows after clustering

# Reorder the DataFrame rows based on hierarchical clustering
df_clustered = dataframe.iloc[clustered_rows, :]

# Custom colormap: from #f2f2f2 (for value 0) to #6d6e71 (for value 1)
cmap = mcolors.LinearSegmentedColormap.from_list(
"custom_gray", ["#f2f2f2", "#6d6e71"], N=256
)

# Plot the heatmap using seaborn with the custom color gradient
plt.figure(figsize=(8, 6))
sns.heatmap(
df_clustered,
cmap=cmap,
cbar=True,
xticklabels=False,
yticklabels=False,
vmin=0,
vmax=1,
)

# Add the condition name as the title of the plot
plt.title(f"Condition: {condition_name}", fontsize=16)

plt.tight_layout()
):
fract_heatmap(dataframe, title=f"Condition: {condition_name}")

# If we need to save the plot as a PDF file
if save_as_pdf and folder_path:
pdf_filename = os.path.join(
folder_path, f"{condition_name}_heatmap.pdf"
)
pdf_filename = Path(folder_path, f"{condition_name}_heatmap.pdf")
plt.savefig(pdf_filename, format="pdf")

# Save the plot to a bytes buffer
bio = BytesIO()
plt.savefig(bio, format="PNG")
plt.close()
bio.seek(0)
return Image.open(bio)

def export_results(fract_data, folder_path):
"""Export dataframes to Excel and heatmaps to PDFs."""
Path(folder_path).mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -137,12 +99,8 @@ def export_results(fract_data, folder_path):
df = fract_data["vis"][selected_condition]

# Generate the heatmap with hierarchical clustering and condition name as the title
heatmap_image = plot_heatmap(df, selected_condition)

# Convert the heatmap image to PNG and update the window
bio = BytesIO()
heatmap_image.save(bio, format="PNG")
window["-HEATMAP-"].update(data=bio.getvalue())
plot_heatmap(df, selected_condition)
window["-HEATMAP-"].update(data=fig_to_bytes())

# If the Export button is clicked
elif event == "Export":
Expand Down Expand Up @@ -241,13 +199,6 @@ def plot_heatmap(
)
plt.savefig(pdf_filename, format="pdf")

# Save the plot to a bytes buffer
bio = BytesIO()
plt.savefig(bio, format="PNG")
plt.close()
bio.seek(0)
return Image.open(bio)

def export_results(results: dict[str, ResultsModel], folder_path):
"""Export dataframes to Excel and heatmaps to PDFs"""
Path(folder_path).mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -285,12 +236,8 @@ def export_results(results: dict[str, ResultsModel], folder_path):
df = results[selected_condition].metrics

# Generate the heatmap with hierarchical clustering and condition name as the title
heatmap_image = plot_heatmap(df, selected_condition)

# Convert the heatmap image to PNG and update the window
bio = BytesIO()
heatmap_image.save(bio, format="PNG")
window["-HEATMAP-"].update(data=bio.getvalue())
plot_heatmap(df, selected_condition)
window["-HEATMAP-"].update(data=fig_to_bytes())

# If the Export button is clicked
elif event == "Export":
Expand Down Expand Up @@ -363,13 +310,6 @@ def plot_pie_chart(
)
plt.savefig(pdf_filename, format="pdf")

# Save the plot to a bytes buffer for display in the GUI
bio = BytesIO()
plt.savefig(bio, format="PNG")
plt.close()
bio.seek(0)
return Image.open(bio)

# Function to export pie charts and summary to Excel and PDFs
def export_pie_charts(results: dict[str, ResultsModel], folder_path):
Path(folder_path).mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -418,12 +358,8 @@ def export_pie_charts(results: dict[str, ResultsModel], folder_path):
df = results[selected_condition].metrics

# Generate the pie chart for the class distribution
pie_chart_image = plot_pie_chart(df, selected_condition)

# Convert the pie chart image to PNG and update the window
bio = BytesIO()
pie_chart_image.save(bio, format="PNG")
window["-PIECHART-"].update(data=bio.getvalue())
plot_pie_chart(df, selected_condition)
window["-PIECHART-"].update(data=fig_to_bytes())

# If the Export button is clicked
if event == "Export":
Expand Down Expand Up @@ -532,13 +468,6 @@ def plot_heatmap(
)
plt.savefig(pdf_filename, format="pdf")

# Save the plot to a bytes buffer
bio = BytesIO()
plt.savefig(bio, format="PNG")
plt.close()
bio.seek(0)
return Image.open(bio)

# Function to export filtered and renamed dataframes to Excel and heatmaps to PDFs
def export_heatmaps(comparison, folder_path):
Path(folder_path).mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -587,14 +516,8 @@ def export_heatmaps(comparison, folder_path):
df_filtered_for_plot = filter_and_prepare_data(df)

# Generate the heatmap with hierarchical clustering and comparison name as the title
heatmap_image = plot_heatmap(
df_filtered_for_plot, selected_comparison
)

# Convert the heatmap image to PNG and update the window
bio = BytesIO()
heatmap_image.save(bio, format="PNG")
window["-HEATMAP-"].update(data=bio.getvalue())
plot_heatmap(df_filtered_for_plot, selected_comparison)
window["-HEATMAP-"].update(data=fig_to_bytes())

# If the Export button is clicked
if event == "Export":
Expand Down Expand Up @@ -689,13 +612,6 @@ def plot_scatter(
)
plt.savefig(pdf_filename, format="pdf")

# Save the plot to a bytes buffer for display in the GUI
bio = BytesIO()
plt.savefig(bio, format="PNG")
plt.close()
bio.seek(0)
return Image.open(bio)

# Function to export scatter plot data and save to Excel and PDFs
def export_scatter_data(comparison, folder_path):
Path(folder_path).mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -743,12 +659,8 @@ def export_scatter_data(comparison, folder_path):
df = comparison[selected_comparison].metrics

# Generate the scatter plot with the selected comparison name
scatter_image = plot_scatter(df, selected_comparison)

# Convert the scatter plot image to PNG and update the window
bio = BytesIO()
scatter_image.save(bio, format="PNG")
window["-SCATTERPLOT-"].update(data=bio.getvalue())
plot_scatter(df, selected_comparison)
window["-SCATTERPLOT-"].update(data=fig_to_bytes())

# If the Export button is clicked
if event == "Export":
Expand Down Expand Up @@ -854,13 +766,6 @@ def plot_heatmap(
)
plt.savefig(pdf_filename, format="pdf")

# Save the plot to a bytes buffer
bio = BytesIO()
plt.savefig(bio, format="PNG")
plt.close()
bio.seek(0)
return Image.open(bio)

except ValueError:
logger.exception("Error during clustering")
return None
Expand Down Expand Up @@ -940,15 +845,9 @@ def export_heatmaps(
df_zscore = compute_rowwise_zscore(df_filtered)

# Generate the heatmap for the selected classname, using condition names as column labels
heatmap_image = plot_heatmap(
df_zscore, selected_classname, conditions
)
plot_heatmap(df_zscore, selected_classname, conditions)

# Convert the heatmap image to PNG and update the window
if heatmap_image:
bio = BytesIO()
heatmap_image.save(bio, format="PNG")
window["-HEATMAP-"].update(data=bio.getvalue())
window["-HEATMAP-"].update(data=fig_to_bytes())

# If the Export button is clicked
if event == "Export":
Expand Down
29 changes: 7 additions & 22 deletions src/ccompass/marker_correlation_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import pandas as pd
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

from .visualize import marker_correlation_heatmap


def draw_figure(canvas: sg.Canvas, figure: plt.Figure) -> FigureCanvasTkAgg:
"""Draw a figure on a canvas."""
Expand All @@ -27,25 +29,6 @@ def update_figure(
return draw_figure(canvas, figure)


def create_heatmap(dataframe: pd.DataFrame, title: str = None) -> plt.Figure:
"""Create correlation heatmap."""
fig, ax = plt.subplots(figsize=(8, 8)) # Adjust the figure size as needed
cax = ax.matshow(dataframe, cmap="coolwarm", vmin=-1, vmax=1)
fig.colorbar(cax)
ax.set_xticks(range(len(dataframe.columns)))
ax.set_xticklabels(
dataframe.columns, rotation=90, fontsize=8
) # Rotate x-axis labels 90 degrees
ax.set_yticks(range(len(dataframe.index)))
ax.set_yticklabels(dataframe.index, fontsize=8)
plt.subplots_adjust(
top=0.8, bottom=0.1, left=0.2
) # Adjust the top and bottom margins
if title:
plt.title(title)
return fig


def update_class_info(
marker_list: pd.DataFrame, classnames: list[str], data: pd.DataFrame
) -> list[tuple[str, int]]:
Expand Down Expand Up @@ -165,7 +148,9 @@ def show_marker_correlation_dialog(
window = _create_window(condition, correlation_matrices, class_info_dict)

# Initial drawing
fig = create_heatmap(correlation_matrices[condition], title=condition)
fig = marker_correlation_heatmap(
correlation_matrices[condition], title=condition
)
figure_agg = draw_figure(window["-CANVAS-"].TKCanvas, fig)

while True:
Expand All @@ -176,7 +161,7 @@ def show_marker_correlation_dialog(

if event == "-condition-":
condition = values["-condition-"]
fig = create_heatmap(
fig = marker_correlation_heatmap(
correlation_matrices[condition], title=condition
)
figure_agg = update_figure(
Expand All @@ -190,7 +175,7 @@ def show_marker_correlation_dialog(

for cond, df in correlation_matrices.items():
# Save the plot
fig = create_heatmap(df, title=cond)
fig = marker_correlation_heatmap(df, title=cond)
fig.savefig(
os.path.join(folder_path, f"{cond}.pdf"), format="pdf"
)
Expand Down
31 changes: 6 additions & 25 deletions src/ccompass/marker_profiles_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,7 @@
update_class_info,
update_figure,
)


def create_line_plot(data: pd.DataFrame, title=None) -> plt.Figure:
fig, ax = plt.subplots(figsize=(8, 6)) # Adjust the figure size as needed
for column in data.columns:
ax.plot(data.index, data[column], label=column)
ax.legend(
loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=3
) # Place the legend below the plot
ax.set_xlabel("fractions")
ax.set_ylabel("normalized intensity")
ax.set_xticks([]) # Remove x-tick labels
ax.set_yticks([0, 1])
ax.set_yticklabels(["0", "1"])
ax.set_xlim(0, len(data.index) - 1) # Set x-axis limits
if title:
plt.title(title)
plt.ylim(0, 1)
fig.tight_layout(
rect=(0, 0, 1, 0.95)
) # Adjust layout to make room for the legend
return fig
from .visualize import plot_marker_profiles


def show_marker_profiles_dialog(fract_data, fract_info, marker_list, key):
Expand Down Expand Up @@ -106,7 +85,7 @@ def show_marker_profiles_dialog(fract_data, fract_info, marker_list, key):
)

# Initial drawing
fig = create_line_plot(profiles_dict[condition], title=condition)
fig = plot_marker_profiles(profiles_dict[condition], title=condition)
figure_agg = draw_figure(window["-CANVAS-"].TKCanvas, fig)

while True:
Expand All @@ -117,7 +96,9 @@ def show_marker_profiles_dialog(fract_data, fract_info, marker_list, key):

if event == "-condition-":
condition = values["-condition-"]
fig = create_line_plot(profiles_dict[condition], title=condition)
fig = plot_marker_profiles(
profiles_dict[condition], title=condition
)
figure_agg = update_figure(
window["-CANVAS-"].TKCanvas, figure_agg, fig
)
Expand Down Expand Up @@ -145,7 +126,7 @@ def show_marker_profiles_dialog(fract_data, fract_info, marker_list, key):

# Save the plot
for cond, df in profiles_dict.items():
fig = create_line_plot(df, title=cond)
fig = plot_marker_profiles(df, title=cond)
fig.savefig(
os.path.join(
folder_path, f"markerprofiles_{cond}.pdf"
Expand Down
Loading

0 comments on commit caa10d2

Please sign in to comment.