-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathranking_test.py
126 lines (95 loc) · 5.11 KB
/
ranking_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import matplotlib.pyplot as plt
import numpy as np
import json
import argparse
from scipy.stats import spearmanr
def plot_spearman_correlations(correlations, save_path=None, xlabel=''):
    """
    Plot a boxplot of Spearman correlations and overlay the individual data points
    with horizontal jitter and transparency so overlapping points stay visible.

    Non-finite values (NaN/inf) are excluded before plotting and reported on stdout,
    because a single NaN makes the boxplot statistics NaN and the box is not drawn.

    Parameters:
    correlations (list or numpy array): One Spearman correlation coefficient per cluster.
    save_path (str, optional): If given, the figure is written to this path at 300 dpi.
    xlabel (str): Tick label shown below the single box.
    """
    values = np.asarray(correlations, dtype=float)
    finite_values = values[np.isfinite(values)]
    n_dropped = values.size - finite_values.size
    if n_dropped:
        print(f"Warning: {n_dropped} non-finite correlation value(s) excluded from the plot.")

    # Create figure and axis
    fig, ax = plt.subplots(figsize=(8, 6))

    # Create boxplot (vertical is the default orientation, so it is not passed explicitly)
    ax.boxplot(finite_values, patch_artist=True,
               boxprops=dict(facecolor='lightblue', color='black'),
               medianprops=dict(color='black'))

    # Add jittered individual points: small uniform noise in [-0.02, 0.02) around x = 1,
    # which is where the single box is drawn
    jitter = 0.04 * (np.random.rand(finite_values.size) - 0.5)
    x_values = np.ones(finite_values.size) + jitter

    # Scatter the points with transparency (alpha)
    ax.scatter(x_values, finite_values, color='black', alpha=0.5)

    # Label the axes
    ax.set_ylabel("Spearman Correlation Coefficient", fontsize=12)
    ax.set_xticklabels([xlabel], fontsize=12)
    # Report the number of clusters actually plotted instead of a hard-coded count
    ax.set_title(f"Distribution of Spearman Correlations across {finite_values.size} Clusters", fontsize=14)

    # Save the plot at 300 dpi if a path is provided
    if save_path:
        fig.savefig(save_path, dpi=300)
def compute_pearson_correlations_in_clusters(casf2016_predictions):
    """
    Compute one Spearman rank correlation per cluster listed in 'clusters_casf2016.json'.

    NOTE: despite its name, this function computes Spearman (not Pearson) correlations.

    The cluster file is read from the current working directory. It is expected to map
    each cluster name to a list of entries whose first element is the complex id and
    whose second element is the true score.

    Parameters:
    casf2016_predictions (dict): Maps complex id to its prediction. A list/tuple value is
        read at index 1 (the predicted score); any other value is used directly as the
        predicted score.

    Returns:
    list: One correlation coefficient per cluster, in the order of the cluster file.
        Clusters with fewer than two matched complexes yield NaN.
    """
    with open('clusters_casf2016.json') as f:
        clusters = json.load(f)

    spearman_correlations = []
    for members in clusters.values():
        # Collect true and predicted scores together so they always stay aligned
        true_scores = []
        predicted_scores = []
        for entry in members:
            complex_id = entry[0]
            # Membership is tested on the dict itself (constant time per lookup)
            if complex_id not in casf2016_predictions:
                print(f"Warning: {complex_id} not found in predictions.")
                continue
            prediction = casf2016_predictions[complex_id]
            if isinstance(prediction, (list, tuple)):
                prediction = prediction[1]
            true_scores.append(entry[1])
            predicted_scores.append(prediction)

        # A rank correlation needs at least two points; keep a NaN placeholder so the
        # result still has one entry per cluster
        if len(true_scores) < 2:
            print("Warning: fewer than two complexes with predictions in a cluster; correlation is NaN.")
            spearman_correlations.append(float('nan'))
            continue

        # Calculate the Spearman correlation
        spearman_correlation, _ = spearmanr(true_scores, predicted_scores)
        spearman_correlations.append(spearman_correlation)
    return spearman_correlations
# Initialize the argument parser
parser = argparse.ArgumentParser(description="Plot Spearman correlations for CASF-2016 dataset")
parser.add_argument("model_path",
                    type=str,
                    help="Either the path to the folder containing the model prediction files for all random seeds "
                         "or the path to the predictions file of a specific model (json).")
args = parser.parse_args()
model_path = args.model_path

# LOAD THE PREDICTIONS FROM THE MODEL ###
# -------------------------------------------------------------------------------------------------------------

# If the model path is a specific predictions file of a specific model, load the predictions
if model_path.endswith('.json'):
    with open(model_path) as f:
        casf2016_predictions = json.load(f)
    # Strip only the trailing extension; str.replace would also rewrite any other
    # '.json' occurring earlier in the path
    output_prefix = model_path[:-len('.json')]

# Otherwise model_path is used as a prefix: predictions are read from
# '<model_path>_<seed>/dataset_casf2016_predictions.json' for seeds 0 to 4
else:
    per_seed_predictions = {}
    for random_seed in range(0, 5):
        predictions_path = f'{model_path}_{random_seed}/dataset_casf2016_predictions.json'
        with open(predictions_path) as f:
            fold_predictions = json.load(f)
        for complex_id, prediction in fold_predictions.items():
            per_seed_predictions.setdefault(complex_id, []).append(prediction)

    # Summarize the saved predictions into average values for each complex.
    # List-valued predictions are averaged element by element (calling sum() on a list
    # of lists raises TypeError); scalar predictions are averaged directly.
    # NOTE(review): assumes list-valued predictions hold only numbers - confirm against
    # the code that writes dataset_casf2016_predictions.json
    casf2016_predictions = {}
    for complex_id, seed_values in per_seed_predictions.items():
        if isinstance(seed_values[0], (list, tuple)):
            casf2016_predictions[complex_id] = [sum(column) / len(column) for column in zip(*seed_values)]
        else:
            casf2016_predictions[complex_id] = sum(seed_values) / len(seed_values)
    output_prefix = model_path

spearman_correlations = compute_pearson_correlations_in_clusters(casf2016_predictions)

# SAVE SPEARMAN CORRELATIONS TO A FILE NEXT TO THE MODEL PATH
with open(f'{output_prefix}_spearman_correlations.json', 'w') as f:
    json.dump(spearman_correlations, f)

# Plot the Spearman correlations and save the plot next to the correlations file
save_path = f'{output_prefix}_spearman_correlations.png'
plot_spearman_correlations(spearman_correlations, save_path, xlabel=model_path.split('/')[-1])