-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_reality_check.py
34 lines (23 loc) · 1.11 KB
/
test_reality_check.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from pathlib import Path

from matplotlib import pyplot as plt

from src.ranking_metrics.embed_metrics import ClassBasedEmbeddingMetrics
def _example_helper(d, c, example_id, ks=(1, 10)) -> None:
    """Score one Musgrave et al. (2020) toy example and save a scatter plot of it.

    Prints the sample count and computes the metrics returned by
    ``ClassBasedEmbeddingMetrics.forward`` (with ``ignore_self=True``). It then
    writes a scatter plot, titled with those scores, to
    ``figures/reality_check_example{example_id}.png``. The path is relative to
    the current working directory.

    Args:
        d: Embeddings. ``d.shape[0]`` is reported as the sample count,
            ``d.shape[-1]`` is passed as the embedding dimension, and columns
            0 and 1 are plotted. Presumably a 2-D (n_samples, 2) array/tensor
            for these examples; TODO confirm against the fixtures.
        c: Per-sample class ids. Each is converted with ``int`` and used to
            index a two-entry colour palette, so it must be a valid index
            into that palette (0 or 1).
        example_id: Identifier used in the printed message, the plot title
            and the output filename.
        ks: Cut-offs passed to ``ClassBasedEmbeddingMetrics``.
    """
    print(f"Number of samples used in example {example_id}: {d.shape[0]}")
    metrics = ClassBasedEmbeddingMetrics(d.shape[-1], ks)
    scores = metrics.forward(d, c, ignore_self=True)

    colours = ["#219ebcff", "#ffb703ff"]
    # Render every score as a percentage with one decimal place for the title.
    print_scores = ", ".join(f"{k}: {(float(v) * 100):.1f}%" for k, v in scores.items())

    out_dir = Path("figures")
    # savefig does not create missing directories; make sure the target exists
    # so the test does not fail on a fresh checkout.
    out_dir.mkdir(parents=True, exist_ok=True)

    fig = plt.figure(figsize=(10, 7))
    try:
        plt.title(f"Example {example_id} of Musgrave et al. 2020. \n {print_scores}")
        plt.scatter(d[:, 0], d[:, 1], c=[colours[int(c_id)] for c_id in c])
        plt.savefig(out_dir / f"reality_check_example{example_id}.png")
    finally:
        # Close this figure even if plotting/saving raised, so figures do not
        # accumulate across tests.
        plt.close(fig)
def test_class_reality_check_example_1(musgrave_example_1):
    """Run the reality-check helper on the first Musgrave et al. (2020) example."""
    embeddings, classes = musgrave_example_1
    _example_helper(embeddings, classes, example_id=1)
def test_class_reality_check_example_2(musgrave_example_2):
    """Run the reality-check helper on the second Musgrave et al. (2020) example."""
    embeddings, classes = musgrave_example_2
    _example_helper(embeddings, classes, example_id=2)
def test_class_reality_check_example_3(musgrave_example_3):
    """Run the reality-check helper on the third Musgrave et al. (2020) example."""
    embeddings, classes = musgrave_example_3
    _example_helper(embeddings, classes, example_id=3)