
Commit

Merge branch 'stats' into publication1
Vinay Jayaram committed May 16, 2018
2 parents fd020c0 + 67951c7 commit 579bbe7
Showing 6 changed files with 80 additions and 29 deletions.
examples/plot_filterbank_csp_vs_csp.py (3 changes: 0 additions & 3 deletions)

@@ -74,9 +74,6 @@
                                    suffix='examples', overwrite=overwrite)
results = evaluation.process(pipelines)

-# cashed results might return other pipelines
-results = results[results.pipeline == 'CSP + LDA']
-
# bank of 6 filter, by 4 Hz increment
filters = [[8, 12], [12, 16], [16, 20], [20, 24], [24, 28], [28, 35]]
paradigm = FilterBankLeftRightImagery(filters=filters)
moabb/analysis/plotting.py (68 changes: 49 additions & 19 deletions)

@@ -11,8 +11,7 @@


PIPELINE_PALETTE = sea.color_palette("husl", 6)
-sea.set_palette(PIPELINE_PALETTE)
-sea.set(font='serif', style='whitegrid')
+sea.set(font='serif', style='whitegrid', palette=PIPELINE_PALETTE)

log = logging.getLogger()

@@ -38,15 +37,16 @@ def score_plot(data, pipelines=None):
        data = data[data.pipeline.isin(pipelines)]
    fig = plt.figure(figsize=(8.5, 11))
    ax = fig.add_subplot(111)
-    sea.stripplot(data=data, y="dataset", x="score", jitter=True, dodge=True,
-                  hue="pipeline", zorder=1, alpha=0.7, ax=ax)
-    # sea.pointplot(data=data, y="score", x="dataset",
-    #               hue="pipeline", zorder=1, ax=ax)
    # sometimes the score is lower than 0.5 (...not sure how to deal with that)
    # markers = ['o', '8', 's', 'p', '+', 'x', 'D', 'd', '>', '<', '^']
+    sea.stripplot(data=data, y="dataset", x="score", jitter=0.15,
+                  palette=PIPELINE_PALETTE, hue='pipeline', dodge=True, ax=ax,
+                  alpha=0.7)
    ax.set_xlim([0, 1])
    ax.axvline(0.5, linestyle='--', color='k', linewidth=2)
    ax.set_title('Scores per dataset and algorithm')
+    handles, labels = ax.get_legend_handles_labels()
+    color_dict = {l: h.get_facecolor()[0] for l, h in zip(labels, handles)}
    plt.tight_layout()
-    return fig
+    return fig, color_dict
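score_plot now also returns a dictionary mapping each pipeline name to its marker color, so downstream figures can reuse a consistent palette. A minimal usage sketch with toy data; in practice the DataFrame comes from an evaluation:

import pandas as pd
from moabb.analysis.plotting import score_plot

# toy results table; the real one is returned by evaluation.process(pipelines)
results = pd.DataFrame({'dataset': ['001-2014'] * 4,
                        'score': [0.71, 0.65, 0.80, 0.76],
                        'pipeline': ['CSP + LDA', 'CSP + LDA',
                                     'RG + LR', 'RG + LR']})
fig, color_dict = score_plot(results)
print(color_dict)  # maps each pipeline name to its RGBA marker color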


@@ -81,15 +81,27 @@ def ordering_heatmap(sig_df, effect_df, p_threshold=0.05):
    '''
    effect_df.columns = effect_df.columns.map(_simplify_names)
    sig_df.columns = sig_df.columns.map(_simplify_names)
+    annot_df = effect_df.copy()
+    for row in annot_df.index:
+        for col in annot_df.columns:
+            if effect_df.loc[row,col] > 0:
+                txt = '{:.2f}\np={:1.0e}'.format(effect_df.loc[row,col],
+                                                 sig_df.loc[row,col])
+            else:
+                txt = ''
+            annot_df.loc[row, col] = txt
    fig = plt.figure()
    ax = fig.add_subplot(111)
-    sea.heatmap(data=effect_df, center=0, annot=True,
-                mask=(sig_df > p_threshold),
-                fmt="2.2f", cbar_kws={'label': 'Meta-effect'},
-                cmap=sea.light_palette("green"))
+    palette = sea.light_palette("green", as_cmap=True)
+    palette.set_under(color=[1,1,1])
+    sea.heatmap(data=-np.log(sig_df), annot=annot_df,
+                fmt='', cmap=palette, linewidths=1,
+                linecolor='0.8', annot_kws={'size': 10}, cbar=False,
+                vmin=-np.log(0.05))
    for l in ax.get_xticklabels():
        l.set_rotation(45)
+    ax.tick_params(axis='y', rotation=0.9)
    ax.set_title("Algorithm comparison")
    plt.tight_layout()
    return fig
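With the rewrite, cell color now encodes -log(p): anything short of p = 0.05 falls below vmin and is painted white through set_under, while the annotation carries the meta-effect and its p-value. A hedged usage sketch with illustrative toy values:

import pandas as pd
from moabb.analysis.plotting import ordering_heatmap

pipes = ['CSP + LDA', 'RG + LR']
# pairwise one-sided p-values and standardized effects (toy numbers)
sig_df = pd.DataFrame([[1.0, 0.01], [0.99, 1.0]], index=pipes, columns=pipes)
effect_df = pd.DataFrame([[0.0, 0.8], [-0.8, 0.0]], index=pipes, columns=pipes)
fig = ordering_heatmap(sig_df, effect_df)
fig.savefig('ordering.pdf')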

@@ -98,6 +110,15 @@ def meta_analysis_plot(stats_df, alg1, alg2):
    '''A meta-analysis style plot that shows the standardized effect with
    confidence intervals over all datasets for two algorithms.
    Hypothesis is that alg1 is larger than alg2'''
+    def _marker(pval):
+        if pval < 0.001:
+            return '$***$', 100
+        elif pval < 0.01:
+            return '$**$', 70
+        elif pval < 0.05:
+            return '$*$', 30
+        else:
+            raise ValueError('insignificant pval {}'.format(pval))
    assert (alg1 in stats_df.pipe1.unique())
    assert (alg2 in stats_df.pipe1.unique())
    df_fw = stats_df.loc[(stats_df.pipe1 == alg1) & (stats_df.pipe2 == alg2)]
@@ -144,14 +165,15 @@ def meta_analysis_plot(stats_df, alg1, alg2):
               s=np.array([50] + [30]*len(dsets)),
               marker='D',
               c=['k'] + ['tab:grey']*len(dsets))
-    sig_ind = np.array(sig_ind)
-    ax.scatter(df_fw['smd'].iloc[sig_ind],
-               sig_ind + 1.4, s=20,
-               marker='*', c='r')
+    for i, p in zip(sig_ind, pvals):
+        m, s = _marker(p)
+        ax.scatter(df_fw['smd'].iloc[i],
+                   i + 1.4, s=s,
+                   marker=m, color='r')
    # pvalues axis stuf
    pval_ax.set_xlim([-0.1, 0.1])
    pval_ax.grid(False)
-    pval_ax.set_title('p-value')
+    pval_ax.set_title('p-value', fontdict={'fontsize': 10})
    pval_ax.set_xticks([])
    for spine in pval_ax.spines.values():
        spine.set_visible(False)
@@ -162,14 +184,18 @@ def meta_analysis_plot(stats_df, alg1, alg2):
    if final_effect > 0:
        p = combine_pvalues(df_fw['p'], df_fw['nsub'])
        if p < 0.05:
-            ax.scatter([final_effect], [-0.4], s=20, marker='*', c='r')
+            m, s = _marker(p)
+            ax.scatter([final_effect], [-0.4], s=s,
+                       marker=m, c='r')
            pval_ax.text(0, 0, horizontalalignment='center',
                         verticalalignment='center',
                         s='{:.2e}'.format(p), fontsize=8)
    else:
        p = combine_pvalues(df_bk['p'], df_bk['nsub'])
        if p < 0.05:
-            ax.scatter([final_effect], [-0.4], s=20, marker='*', c='r')
+            m, s = _marker(p)
+            ax.scatter([final_effect], [-0.4], s=s,
+                       marker=m, c='r')
            pval_ax.text(0, 0, horizontalalignment='center',
                         verticalalignment='center',
                         s='{:.2e}'.format(p), fontsize=8)
@@ -179,7 +205,11 @@ def meta_analysis_plot(stats_df, alg1, alg2):
    ax.spines['right'].set_visible(False)
    ax.axvline(0, linestyle='--', c='k')
    ax.axhline(0.5, linestyle='-', linewidth=3, c='k')
-    ax.set_title('{} vs {}'.format(alg2, alg1))
+    title = '< {} better{}\n{}{} better >'.format(alg2,
+                                                  ' '*(45-len(alg2)),
+                                                  ' '*(45 - len(alg1)),
+                                                  alg1)
+    ax.set_title(title, ha='left', ma='right', loc='left')
    ax.set_xlabel('Standardized Mean Difference')
    fig.tight_layout()
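_marker grades significance with one to three stars (p < 0.05, 0.01, 0.001) and sizes the glyph to match, both for the per-dataset effects and for the summary effect pooled by combine_pvalues, which weights each dataset's one-sided p-value by its size (nsub). A usage sketch; stats_df stands in for the pairwise statistics table produced by moabb's analysis step:

import matplotlib.pyplot as plt
from moabb.analysis.plotting import meta_analysis_plot

# stats_df is a placeholder: one row per dataset and ordered pipeline pair,
# with at least pipe1, pipe2, smd, p and nsub columns
meta_analysis_plot(stats_df, 'CSP + LDA', 'RG + LR')
plt.savefig('csp_vs_rg.pdf')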

moabb/analysis/results.py (15 changes: 13 additions & 2 deletions)

@@ -118,10 +118,21 @@ def to_list(res):
                           d['time'],
                           d['n_samples']])

-    def to_dataframe(self):
+    def to_dataframe(self, pipelines=None):
        df_list = []
+
+        # get the list of pipeline hash
+        digests = []
+        if pipelines is not None:
+            digests = [get_digest(pipelines[name]) for name in pipelines]
+
        with h5py.File(self.filepath, 'r') as f:
-            for _, p_group in f.items():
+            for digest, p_group in f.items():
+
+                # skip if not in pipeline list
+                if (pipelines is not None) & (digest not in digests):
+                    continue
+
                name = p_group.attrs['name']
                for dname, dset in p_group.items():
                    array = np.array(dset['data'])
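Pipelines are stored in the HDF5 file under the digest of their estimator, so passing the pipeline dict through lets to_dataframe return only the requested pipelines instead of everything ever cached; this is what makes the manual post-filtering removed from the filter-bank example unnecessary. A sketch of the intended behaviour; the pipeline definition and evaluation object are illustrative:

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline
from mne.decoding import CSP

# `evaluation` is assumed to be any configured moabb evaluation
pipelines = {'CSP + LDA': make_pipeline(CSP(n_components=8), LDA())}
results = evaluation.process(pipelines)
# only 'CSP + LDA' rows come back, even if other pipelines are cached on disk
assert set(results.pipeline.unique()) <= set(pipelines)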
moabb/evaluations/base.py (5 changes: 4 additions & 1 deletion)

@@ -116,7 +116,7 @@ def process(self, pipelines):
        for res in results:
            self.push_result(res, pipelines)

-        return self.results.to_dataframe()
+        return self.results.to_dataframe(pipelines=pipelines)

    def push_result(self, res, pipelines):
        message = '{} | '.format(res['pipeline'])
@@ -126,6 +126,9 @@ def push_result(self, res, pipelines):
        log.info(message)
        self.results.add({res['pipeline']: res}, pipelines=pipelines)

+    def get_results(self):
+        return self.results.to_dataframe()
+
    @abstractmethod
    def evaluate(self, dataset, pipelines):
        '''Evaluate results on a single dataset.
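get_results complements the filtered return of process: it hands back the full cached DataFrame, every pipeline included. A one-line contrast, with the same assumed evaluation object as above:

filtered = evaluation.process(pipelines)   # only the pipelines passed in
everything = evaluation.get_results()      # all pipelines cached on disk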
moabb/pipelines/features.py (15 changes: 12 additions & 3 deletions)

@@ -17,12 +17,21 @@ def transform(self, X):

class FM(BaseEstimator, TransformerMixin):

+    def __init__(self, freq=128):
+        '''instantaneous frequencies require a sampling frequency to be
+        properly scaled, which is helpful for some algorithms. This assumes
+        128 if not told otherwise.
+        '''
+        self.freq = freq
+
    def fit(self, X, y):
        """fit."""
        return self

    def transform(self, X):
-        """transform. Note however that without the
-        sampling rate these values are unnormalized."""
+        """transform. """
        xphase = np.unwrap(np.angle(signal.hilbert(X, axis=-1)))
-        return np.median(np.diff(xphase, axis=-1) / (2 * np.pi), axis=-1)
+        return np.median(self.freq * np.diff(xphase, axis=-1) / (2 * np.pi),
+                         axis=-1)
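Scaling by the sampling frequency converts the median phase increment from cycles per sample into Hz. A quick sanity check, assuming epochs shaped (n_trials, n_channels, n_samples):

import numpy as np
from moabb.pipelines.features import FM

fs = 128
t = np.arange(256) / fs
X = np.sin(2 * np.pi * 10 * t)[np.newaxis, np.newaxis, :]  # one 10 Hz epoch
fm = FM(freq=fs)
print(fm.fit(X, None).transform(X))  # ~10.0: the tone frequency in Hz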
moabb/run.py (3 changes: 2 additions & 1 deletion)

@@ -195,7 +195,8 @@ def generate_paradigms(pipeline_configs, context={}):
        log.debug('{}: {}'.format(paradigm, context_params[paradigm]))
        p = getattr(moabb_paradigms, paradigm)(**context_params[paradigm])
        context = WithinSessionEvaluation(paradigm=p, random_state=42,
-                                          n_jobs=options.threads)
+                                          n_jobs=options.threads,
+                                          overwrite=options.force)
        results = context.process(pipelines=paradigms[paradigm])
        all_results.append(results)
    analyze(pd.concat(all_results, ignore_index=True), options.output,
