Skip to content

Commit

Permalink
figure changes
Browse files Browse the repository at this point in the history
  • Loading branch information
timonmerk committed Nov 10, 2024
1 parent d2228bf commit 37d1442
Show file tree
Hide file tree
Showing 6 changed files with 513 additions and 25 deletions.
112 changes: 88 additions & 24 deletions figure_33_joint_plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import seaborn as sb
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pandas as pd
import pickle
import os
Expand Down Expand Up @@ -116,26 +117,31 @@ def get_dur_per_relation(label):

return df

def plot_boxplot(df, x_label, y_label="Balanced accuracy", order_ = None):
def plot_boxplot(df, x_label, y_label="Balanced accuracy", order_ = None, plt_txt = False, hide_ylabel=False):
sns.boxplot(x=x_label, y="per", data=df, showmeans=True, showfliers=False, palette="viridis", order=order_)
#sns.swarmplot(x="norm_window", y="ba", data=df_all, color="black", alpha=0.5, palette="viridis")
# put the mean values as text on top of the boxplot
means = df.groupby(x_label)["per"].mean()
if order_ is not None:
for i, x_label_ in enumerate(df.groupby(x_label)["per"].mean().sort_values(ascending=True).index):
mean = df[df[x_label] == x_label_]["per"].mean()
plt.text(i, mean, f"{np.round(mean, 2)}", ha="center", va="bottom")
else:
for i, mean in enumerate(means):
plt.text(i, mean, f"{mean:.2f}", ha="center", va="bottom")
if plt_txt:
if order_ is not None:
for i, x_label_ in enumerate(df.groupby(x_label)["per"].mean().sort_values(ascending=True).index):
mean = df[df[x_label] == x_label_]["per"].mean()
plt.text(i, mean, f"{np.round(mean, 2)}", ha="center", va="bottom")
else:
for i, mean in enumerate(means):
plt.text(i, mean, f"{mean:.2f}", ha="center", va="bottom")

plt.xlabel(x_label)
plt.xlabel("")
plt.ylabel(y_label)
plt.xticks(rotation=90)
plt.tight_layout()
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
if hide_ylabel:
plt.ylabel("")
#plt.tight_layout()
#plt.show(block=True)

def plot_per_train_time_relation(df, label):
def plot_per_train_time_relation(df, label, plt_txt=False, hide_ylabel=False):
#plt.figure(figsize=(10, 5), dpi=300)
durations = np.sort(df["dur"].unique())
sub_per = []
Expand All @@ -146,24 +152,76 @@ def plot_per_train_time_relation(df, label):
plt.plot(durations / 60, df_sub["per"], color="gray", alpha=0.2)
sub_per.append(df_sub["per"].values)
plt.xlabel("Duration [h]")
if label == "bk":
if label == "pkg_bk":
plt.ylabel("Correlation coefficient")
else:
plt.ylabel("Balanced accuracy")
if hide_ylabel:
plt.ylabel("")
# plot the mean accuracy for each duration
plt.plot(durations / 60, np.array(sub_per).mean(axis=0), marker="o", linestyle="-", color="black")
# write the mean accuracy on top of the line
for i, dur in enumerate(durations):
plt.text(durations[i] / 60, np.array(sub_per).mean(axis=0)[i], f"{np.round(np.array(sub_per).mean(axis=0)[i], 2)}", ha="center", va="bottom")

if plt_txt:
for i, dur in enumerate(durations):
plt.text(durations[i] / 60, np.array(sub_per).mean(axis=0)[i], f"{np.round(np.array(sub_per).mean(axis=0)[i], 2)}", ha="center", va="bottom")

plt.xscale('log')
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)

#plt.title(f"LOHO PKG BK CV different training duration")
#plt.tight_layout()
#plt.savefig(os.path.join(PATH_FIGURE, f"LOHO_different_training_duration_sub_{label}.pdf"))

def read_columns_and_importances():
    """Load the decoding feature names and the pickled LOHO results.

    The merged, normalised feature table is read from disk and cleaned the
    same way on every call: columns containing NaN or +/-inf values are
    dropped, the ``sub`` column is removed, an ``hour`` column is derived
    from ``pkg_dt``, and finally every column whose name starts with
    ``pkg`` (including ``pkg_dt`` itself) is discarded.

    Returns:
        tuple: ``(columns_, d_out)`` where ``columns_`` is the column index
        of the cleaned feature table and ``d_out`` is the object unpickled
        from ``LOHO_ALL_LABELS_ALL_GROUPS_exludehour_False.pkl``.
    """
    path_features = "/Users/Timon/Library/CloudStorage/OneDrive-Charité-UniversitätsmedizinBerlin/Shared Documents - ICN Data World/General/Data/UCSF_OLARU/features/merged_normalized_10s_window_length/480"
    path_per_dir = "/Users/Timon/Library/CloudStorage/OneDrive-Charité-UniversitätsmedizinBerlin/Shared Documents - ICN Data World/General/Data/UCSF_OLARU/out_per"
    path_per_file = os.path.join(path_per_dir, "LOHO_ALL_LABELS_ALL_GROUPS_exludehour_False.pkl")

    # Drop NaN columns, turn +/-inf into NaN and drop again, so that columns
    # holding infinities are removed as well; then remove the subject column.
    df_all = (
        pd.read_csv(os.path.join(path_features, "all_merged_normed.csv"), index_col=0)
        .dropna(axis=1)
        .replace([np.inf, -np.inf], np.nan)
        .dropna(axis=1)
        .drop(columns=["sub"])
    )
    df_all["pkg_dt"] = pd.to_datetime(df_all["pkg_dt"])
    df_all["hour"] = df_all["pkg_dt"].dt.hour

    # remove columns that start with pkg
    df_all = df_all[[c for c in df_all.columns if not c.startswith("pkg")]]
    columns_ = df_all.columns

    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at result files produced by the project's own decoding runs.
    with open(path_per_file, "rb") as f:
        d_out = pickle.load(f)
    return columns_, d_out

def plot_best_features(columns_, d_out, pkg_decode_label, cols_show=10):
    """Draw a horizontal bar chart of the highest mean feature importances.

    Args:
        columns_: feature names, indexed in the same order as the
            ``feature_importances`` entries (assumed - TODO confirm against
            the decoding script).
        d_out: nested results, accessed as
            ``d_out[bool][label]["ecog_stn"][sub]["feature_importances"]``.
        pkg_decode_label: label whose results are plotted; ``"pkg_bk"`` is
            read from ``d_out[False]``, every other label from ``d_out[True]``.
        cols_show: number of top-ranked features to display.
    """
    is_classification = pkg_decode_label != "pkg_bk"
    per_subject = d_out[is_classification][pkg_decode_label]["ecog_stn"]

    # One row of importances per subject, averaged across subjects.
    importances = np.array(
        [result["feature_importances"] for result in per_subject.values()]
    )
    mean_importance = importances.mean(axis=0)

    # Indices of the features from most to least important.
    ranking = np.argsort(mean_importance)[::-1]
    top_names = np.array(columns_)[ranking][:cols_show]
    top_values = mean_importance[ranking][:cols_show]
    bar_colors = cm.viridis_r(np.linspace(0, 1, cols_show))

    plt.barh(top_names, top_values, color=bar_colors)
    axis = plt.gca()
    # Put the most important feature at the top of the chart.
    axis.invert_yaxis()
    #plt.title(pkg_decode_label)
    axis.spines['right'].set_visible(False)
    axis.spines['top'].set_visible(False)
    plt.xlabel("Feature importance - Prediction Value Change")

if __name__ == "__main__":

plt.figure(figsize=(10, 7))
columns_, d_out = read_columns_and_importances()

plt.figure(figsize=(12, 9))
for idx_, label_name in enumerate(["pkg_bk", "pkg_dk", "pkg_tremor"]):

l_models = []
Expand Down Expand Up @@ -216,21 +274,27 @@ def plot_per_train_time_relation(df, label):
y_label = "Correlation coefficient"
else:
y_label = "Balanced accuracy"
plt.subplot(3, 4, 4*idx_+1)
plot_boxplot(df_norm, "norm_window", y_label)
#plt.subplot(3, 4, 4*idx_+1)
#plot_boxplot(df_norm, "norm_window", y_label)

plt.subplot(3, 4, 4*idx_+2)
plt.subplot(3, 4, 4*idx_+1)
plot_boxplot(df_features_comb, "feature_mod", y_label,
order_=df_features_comb.groupby("feature_mod")["per"].mean().sort_values(ascending=True).index)
order_=df_features_comb.groupby("feature_mod")["per"].mean().sort_values(ascending=True).index,
hide_ylabel=False)

plt.subplot(3, 4, 4*idx_+3)
plt.subplot(3, 4, 4*idx_+2)
plot_boxplot(df_models, "model", y_label,
order_=df_models.groupby("model")["per"].mean().sort_values(ascending=True).index)
order_=df_models.groupby("model")["per"].mean().sort_values(ascending=True).index,
hide_ylabel=True)

plt.subplot(3, 4, 4*idx_+3)
plot_per_train_time_relation(df_per_dur_rel, label_name, hide_ylabel=True)

plt.subplot(3, 4, 4*idx_+4)
plot_per_train_time_relation(df_per_dur_rel, label_name)
plot_best_features(columns_, d_out, label_name)

#plt.savefig(os.path.join(PATH_FIGURES, "figure_33_joint_plot.pdf"))
plt.tight_layout()
plt.savefig(os.path.join(PATH_FIGURES, "figure_33_joint_plot_1011.pdf"))
plt.show(block=True)

print("df")
55 changes: 55 additions & 0 deletions figure_34_viz_feature_importances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from scipy import stats
import seaborn as sns
import pickle

# Script: plot the mean feature importances (averaged over subjects) of the
# top `cols_show` features for each PKG label and save the figure as a PDF.

# Directory holding the pickled decoding results; the second assignment
# narrows the directory down to the concrete result file.
PATH_PER = "/Users/Timon/Library/CloudStorage/OneDrive-Charité-UniversitätsmedizinBerlin/Shared Documents - ICN Data World/General/Data/UCSF_OLARU/out_per"
PATH_PER = os.path.join(PATH_PER, "LOHO_ALL_LABELS_ALL_GROUPS_exludehour_False.pkl")

# Folder the resulting PDF is written to.
PATH_FIGURES = "/Users/Timon/Library/CloudStorage/OneDrive-Charité-UniversitätsmedizinBerlin/Shared Documents - ICN Data World/General/Data/UCSF_OLARU/figures_ucsf"

# NOTE(review): pickle.load can execute arbitrary code; only load trusted result files.
with open(PATH_PER, "rb") as f:
d_out = pickle.load(f)

# Rebuild the list of feature names by applying the cleaning steps below to
# the merged feature table.
# NOTE(review): presumably this mirrors the preprocessing of the decoding run
# so that names line up with the stored importances - TODO confirm.
PATH_FEATURES = "/Users/Timon/Library/CloudStorage/OneDrive-Charité-UniversitätsmedizinBerlin/Shared Documents - ICN Data World/General/Data/UCSF_OLARU/features/merged_normalized_10s_window_length/480"
df_all = pd.read_csv(os.path.join(PATH_FEATURES, "all_merged_normed.csv"), index_col=0)
# Drop columns containing NaN, then convert +/-inf to NaN and drop again so
# columns containing infinities are removed too.
df_all = df_all.dropna(axis=1)
df_all = df_all.replace([np.inf, -np.inf], np.nan)
df_all = df_all.dropna(axis=1)
df_all = df_all.drop(columns=["sub",])
# Derive the hour of day from the pkg_dt timestamp column.
df_all["pkg_dt"] = pd.to_datetime(df_all["pkg_dt"])
df_all["hour"] = df_all["pkg_dt"].dt.hour

# remove columns that start with pkg
df_all = df_all[[c for c in df_all.columns if not c.startswith("pkg")]]
columns_ = df_all.columns

plt.figure(figsize=(10, 10))
# Number of top-ranked features shown per subplot.
cols_show = 50
for idx_, pkg_decode_label in enumerate(["pkg_dk", "pkg_bk", "pkg_tremor"]):

data = []
# "pkg_bk" results are stored under the False key of d_out, the other labels
# under True.
# NOTE(review): presumably False = regression, True = classification - confirm.
if pkg_decode_label == "pkg_bk":
CLASS_ = False
else:
CLASS_ = True
d_out_ = d_out[CLASS_][pkg_decode_label]["ecog_stn"]
# Collect one feature-importance vector per subject.
for sub in d_out_.keys():
data.append(d_out_[sub]["feature_importances"])
fimp = np.array(data)
# Average across subjects (axis 0 indexes the collected subjects).
mean_fimp = fimp.mean(axis=0)
# Feature names ordered from highest to lowest mean importance.
cols_sorted = np.array(columns_)[np.argsort(mean_fimp)[::-1]]

plt.subplot(3, 1, idx_+1)
plt.barh(cols_sorted[:cols_show], mean_fimp[np.argsort(mean_fimp)[::-1]][:cols_show])
# Flip the y axis so the most important feature is drawn at the top.
plt.gca().invert_yaxis()
plt.title(pkg_decode_label)
plt.xlabel("Feature importance - Prediction Value Change")
#plt.xticks(rotation=90)
plt.tight_layout()
plt.savefig(os.path.join(PATH_FIGURES, "feature_importance_plt_bar.pdf"))
plt.show(block=True)

78 changes: 78 additions & 0 deletions figure_35_time_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import matplotlib as mpl
import pickle
import os
import seaborn as sns


def read_per(d_out):
    """Flatten the nested results dict into one row per subject and label.

    Args:
        d_out: mapping accessed as ``d_out[flag][pkg_label]["ecog_stn"][sub]``.
            For the key ``True`` the ``"ba"`` entry of each subject is read,
            for every other key the ``"corr_coeff"`` entry.

    Returns:
        pandas.DataFrame with the columns ``sub``, ``pkg_label``,
        ``CLASSIFICATION`` and ``per``.
    """
    rows = []
    for classification, per_label in d_out.items():
        metric = "ba" if classification is True else "corr_coeff"
        for pkg_label, per_location in per_label.items():
            rows.extend(
                {
                    "sub": sub,
                    "pkg_label": pkg_label,
                    "CLASSIFICATION": classification,
                    "per": result[metric],
                }
                for sub, result in per_location["ecog_stn"].items()
            )
    return pd.DataFrame(rows)

# Folder with the pickled decoding results and folder for the output figures.
PATH_PER = "/Users/Timon/Library/CloudStorage/OneDrive-Charité-UniversitätsmedizinBerlin/Shared Documents - ICN Data World/General/Data/UCSF_OLARU/out_per"
PATH_FIGURES = "/Users/Timon/Library/CloudStorage/OneDrive-Charité-UniversitätsmedizinBerlin/Shared Documents - ICN Data World/General/Data/UCSF_OLARU/figures_ucsf"


def _load_per_pair(file_prefix, flag_column):
    """Load the ``..._True.pkl`` / ``..._False.pkl`` result pair as one table.

    Each pickle is flattened with ``read_per``; ``flag_column`` is set to the
    negation of the exclusion flag encoded in the file name.
    """
    frames = []
    for excluded in (True, False):
        file_name = f"{file_prefix}{excluded}.pkl"
        # NOTE(review): pickle.load can execute arbitrary code; trusted files only.
        with open(os.path.join(PATH_PER, file_name), "rb") as handle:
            frame = read_per(pickle.load(handle))
        frame[flag_column] = not excluded
        frames.append(frame)
    return pd.concat(frames, axis=0)


# Performances with / without the hour-of-day feature.
df_h = _load_per_pair("LOHO_ALL_LABELS_ALL_GROUPS_exludehour_", "hour_feature")

# Performances with / without night-time data.
df_n = _load_per_pair("LOHO_ALL_LABELS_ALL_GROUPS_exludenight_", "include_night")

def set_box_alpha(ax, alpha=0.5):
    """Replace the alpha channel of every patch face colour on *ax*.

    Args:
        ax: object exposing ``patches``; each patch must provide
            ``get_facecolor()`` returning four components and
            ``set_facecolor()``.
        alpha: new alpha value written into each face colour.
    """
    for box in ax.patches:
        red, green, blue, _ = box.get_facecolor()
        box.set_facecolor((red, green, blue, alpha))

def _box_and_swarm_panel(position, df, query, hue, y_label):
    """Draw one panel of the 2x2 grid: translucent boxplot plus swarm overlay."""
    subset = df.query(query)
    plt.subplot(2, 2, position)
    ax = sns.boxplot(data=subset, x="pkg_label", y="per", hue=hue, palette="viridis", showmeans=True, showfliers=False)
    # Fade the boxes before the individual points are drawn on top of them.
    set_box_alpha(ax)
    sns.swarmplot(data=subset, x="pkg_label", y="per", hue=hue, dodge=True, palette="viridis", alpha=0.9, s=2)
    plt.ylabel(y_label)


plt.figure(figsize=(10, 7), dpi=300)

# Top row: effect of the hour-of-day feature.
_box_and_swarm_panel(1, df_h, "CLASSIFICATION == True", "hour_feature", "Balanced accuracy")
_box_and_swarm_panel(2, df_h, "CLASSIFICATION == False", "hour_feature", "Correlation coefficient")
plt.tight_layout()
# This first PDF is written while only the top row has been drawn.
plt.savefig(os.path.join(PATH_FIGURES, "figure_35_per_exlude_hour_feature.pdf"))

# Bottom row: effect of including night-time data.
_box_and_swarm_panel(3, df_n, "CLASSIFICATION == True", "include_night", "Balanced accuracy")
_box_and_swarm_panel(4, df_n, "CLASSIFICATION == False", "include_night", "Correlation coefficient")
plt.tight_layout()
plt.savefig(os.path.join(PATH_FIGURES, "figure_35_per_exclude_analysis.pdf"))
plt.show(block=True)
2 changes: 1 addition & 1 deletion run_decoding_ucsf_across_patients.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
df_all = df_all.drop(columns=df_all.columns[df_all.isnull().all()])
df_all["pkg_dt"] = pd.to_datetime(df_all["pkg_dt"])
if EXCLUDE_NIGHT_TIME:
df_all = df_all[(df_all["pkg_dt"].dt.hour >= 9) & (df_all["pkg_dt"].dt.hour <= 18)]
df_all = df_all[(df_all["pkg_dt"].dt.hour >= 8) & (df_all["pkg_dt"].dt.hour <= 20)]

d_out = {}

Expand Down
Loading

0 comments on commit 37d1442

Please sign in to comment.