Skip to content

Commit

Permalink
made mixture_and_plot provide a unique labeling
Browse files Browse the repository at this point in the history
  • Loading branch information
rsexton2 committed Feb 15, 2024
1 parent e494218 commit 88b227e
Showing 1 changed file with 7 additions and 31 deletions.
38 changes: 7 additions & 31 deletions basicrta/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,16 +670,13 @@ def mixture_and_plot(gibbs, method, **kwargs):
leg_labels = np.array([f'{i}' for i in uniq_labels])
predict_labels = r.predict(np.log(predict_data))

# sorts = r.precisions_.argsort()[::-1]
# tinds = np.array([np.where(labels == i)[0] for i in uniq_labels],
# dtype=object)
# pinds = np.array([np.where(predict_labels == i)[0] for i in uniq_labels],
# dtype=object)
# for i in uniq_labels:
# labels[tinds[i]] = sorts[i]
# predict_labels[pinds[i]] = sorts[i]
tinds = [np.where(labels == i)[0] for i in uniq_labels]
pinds = [np.where(predict_labels == i)[0] for i in uniq_labels]
sorts = r.precisions_.argsort()[::-1]
tinds = np.array([np.where(labels == i)[0] for i in uniq_labels],
dtype=object)
pinds = np.array([np.where(predict_labels == i)[0] for i in uniq_labels],
dtype=object)
tinds = tinds[sorts]
pinds = pinds[sorts]

train_data_inds = np.array([np.where(data == col)[0][0] for col in
train_data])
Expand Down Expand Up @@ -749,33 +746,12 @@ def mixture_and_plot(gibbs, method, **kwargs):
ax[1, 0].set_ylim(1e-4, 1)

handles, plot_labels = ax[0, 0].get_legend_handles_labels()
# sorts = np.argsort([int(i) for i in plot_labels])
# handles = np.array(handles)[sorts]
# plot_labels = np.array(plot_labels)[sorts]
[handle.set_color(cmap(get_color(int(i)))) for i, handle in
zip(plot_labels, handles)]
[handle.set_edgecolor('k') for i, handle in zip(plot_labels, handles)]
fig.legend(handles, plot_labels, loc='lower center',
ncols=len(plot_labels)/2, title='cluster')
fig.suptitle(f'{method} '+' '.join(keyvalpairs), fontsize=16)
plt.tight_layout(rect=(0, 0.05, 1, 1))
plt.savefig(f"{gibbs.residue}/results_{method}_{kwarg_str}.png",
bbox_inches='tight')
plt.show()
# tparams, pparams = [], []
# for i in uniq_labels:
# tinds = np.where(labels == i)[0]
# pinds = np.where(predict_labels == i)[0]
# tparams.append(trates[tinds].mean())
# pparams.append(prates[pinds].mean())
# tindex = np.where(tparams == np.min(tparams))[0]
# pindex = np.where(pparams == np.min(pparams))[0]
# clu_rates = np.concatenate([trates[labels == tindex],
# prates[predict_labels == pindex]])
# all_results = [(1/clu_rates).mean(), confidence_interval(1/clu_rates)]
# train_results = [(1/trates[labels == tindex]).mean(),
# confidence_interval(1/trates[labels == tindex])]
# predict_results = [(1/prates[predict_labels == pindex]).mean(),
# confidence_interval(1/prates[predict_labels == pindex])]
return r, all_labels

0 comments on commit 88b227e

Please sign in to comment.