Skip to content

Commit b30adb9

Browse files
maple/agreement codes
1 parent 454c66a commit b30adb9

File tree

2 files changed

+193
-0
lines changed

2 files changed

+193
-0
lines changed

analysis/avg_agreement_final.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import seaborn as sns
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
5+
# Per-model agreement statistics, keyed by model identifier.
# Each value is a two-element list: index 0 is plotted below as the bar height
# (mean Cohen's kappa) and index 1 as the error-bar half-width (std. dev.).
data = {
    "meta-llama/Meta-Llama-3.1-8B-Instruct": [
        0.3533086666014079,
        0.052422082615756406,
    ],
    "cohere/c4ai-aya-23-35b": [
        0.43767196047824003,
        0.026040919354464294,
    ],
    "cohere/c4ai-aya-23-8b": [
        0.013483014909052663,
        0.03363706833599835,
    ],
    "cohere/command-r-08-2024": [
        0.374457668650282,
        0.02926089754079793,
    ],
    "cohere/command-r-plus-08-2024": [
        0.3830841816733316,
        0.020185255968455686,
    ],
    "google/gemma-1.1-7b-it": [
        0.5190375637539242,
        0.027757722654111305,
    ],
    "google/gemma-2-9b-it": [
        0.5181663123111222,
        0.031090119385244894,
    ],
    "meta-llama/Meta-Llama-3-70B-Instruct": [
        0.5685224105896568,
        0.04853344616275034,
    ],
    "meta-llama/Meta-Llama-3-8B-Instruct": [
        0.37936948540837095,
        0.032172769265151994,
    ],
    "meta-llama/Meta-Llama-3.1-70B-Instruct": [
        0.603536768244583,
        0.027191895488989915,
    ],
    "mistralai/Mistral-7B-Instruct-v0.2": [
        0.4071166722276529,
        0.04577594028555328,
    ],
    "mistralai/Mistral-7B-Instruct-v0.3": [
        0.41195018984687265,
        0.056184679972755454,
    ],
    "openai/gpt-4-turbo-2024-04-09": [
        0.6106943361444249,
        0.02932446842558468,
    ],
    "openai/gpt-4o-2024-05-13": [
        0.5833874065757011,
        0.023695391445384514,
    ],
}
63+
64+
# Order the models by ascending mean agreement so the bars rise left to right.
sorted_data = dict(sorted(data.items(), key=lambda item: item[1][0]))
labels_sorted = list(sorted_data)
means_sorted = [stats[0] for stats in sorted_data.values()]
std_devs_sorted = [stats[1] for stats in sorted_data.values()]

sns.set(style="whitegrid")
# One colour per bar, sampled along the "coolwarm" colormap.
palette = sns.color_palette("coolwarm", len(labels_sorted))

plt.figure(figsize=(10, 6))
x_pos_sorted = np.arange(len(labels_sorted))

# seaborn draws the bars only (errorbar=None); the error bars are added
# explicitly afterwards from the precomputed standard deviations.
# NOTE(review): newer seaborn releases reportedly warn when `palette` is passed
# without `hue` -- confirm against the installed version.
ax1 = sns.barplot(x=x_pos_sorted, y=means_sorted, palette=palette, errorbar=None)
plt.errorbar(x_pos_sorted, means_sorted, yerr=std_devs_sorted, fmt='none', c='black', capsize=5)

# Draw a thick black frame around the axes.
for spine in ax1.spines.values():
    spine.set_color('black')
    spine.set_linewidth(2)

plt.ylim(0, 0.8)

# Label each bar with its model identifier, rotated to avoid overlap.
plt.xticks(x_pos_sorted, labels_sorted, rotation=90)
plt.ylabel("Cohen's Kappa")
plt.title('Average Inner-Model Agreement Across Languages')

plt.tight_layout()
plt.savefig("./innermodel_agreement.pdf", bbox_inches='tight')

analysis/maple_results.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import json
2+
from pathlib import Path
3+
4+
import argparse
5+
import logging
6+
from pathlib import Path
7+
from typing import Optional
8+
9+
import pandas as pd
10+
import seaborn as sns
11+
import matplotlib.pyplot as plt
12+
from huggingface_hub import snapshot_download
13+
import datasets
14+
import json
15+
16+
import numpy as np
17+
import matplotlib.pyplot as plt
18+
from itertools import combinations
19+
from collections import defaultdict
20+
21+
22+
# Point sizes for the three text tiers referenced by PLOT_PARAMS.
FONT_SIZES = {"small": 12, "medium": 16, "large": 18}

# Matplotlib rcParams overrides shared by the figures in this script.
# Sizes are looked up with [] rather than .get() so that a mistyped tier name
# raises KeyError here instead of silently passing None to matplotlib.
PLOT_PARAMS = {
    "font.family": "serif",
    "font.serif": ["Times New Roman", "STIX"],
    "font.size": FONT_SIZES["medium"],
    "axes.titlesize": FONT_SIZES["large"],
    "axes.labelsize": FONT_SIZES["large"],
    "xtick.labelsize": FONT_SIZES["large"],
    "ytick.labelsize": FONT_SIZES["small"],
    "legend.fontsize": FONT_SIZES["medium"],
    "figure.titlesize": FONT_SIZES["medium"],
    "text.usetex": False,
}
36+
37+
# Configure the root logger to emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)

# Apply the shared typography settings to every figure created afterwards.
plt.rcParams.update(PLOT_PARAMS)
40+
41+
def load_json(json_file_path):
    """Read a JSON file and return the decoded object.

    Args:
        json_file_path: Path to the JSON file (a string or path-like object).

    Returns:
        Whatever ``json.load`` produces for the file's contents.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(json_file_path, "r", encoding="utf-8") as file:
        return json.load(file)
45+
46+
def extract_result_row(raw_results):
    """Flatten one decoded result file into a single table row.

    Three layouts are recognised, selected by which top-level key is present:

    * ``"leaderboard"``: per-subset scores live under ``"subset"`` (alongside
      some metadata entries that are dropped) and the overall score under
      ``["scores"]["accuracy"]``.
    * ``"subset_results"``: per-subset scores live under that key and the
      overall score under ``"accuracy"``.
    * otherwise: per-subset scores live under ``"extra_results"`` and the
      overall score under ``"accuracy"``.

    Args:
        raw_results: The dict decoded from one result JSON file.

    Returns:
        A dict with ``"Model"``, ``"Avg"`` and one entry per subset score.

    Raises:
        KeyError: If a key required by the detected layout is missing.
    """
    model_id = raw_results["model"]
    if "leaderboard" in raw_results:
        # Work on a copy so the caller's dict is left untouched.
        subset_results = dict(raw_results["subset"])
        overall = raw_results["scores"]["accuracy"]
        # Metadata entries are not scores; tolerate files that omit some of them.
        for key in ("model", "model_type", "chat_template"):
            subset_results.pop(key, None)
    elif "subset_results" in raw_results:
        subset_results = raw_results["subset_results"]
        overall = raw_results["accuracy"]
    else:
        subset_results = raw_results["extra_results"]
        overall = raw_results["accuracy"]
    return {"Model": model_id, "Avg": overall, **subset_results}


results_dir = 'data/eval-results-maple'
results_path = Path(results_dir)

# Sort the file list so the row order does not depend on directory listing order.
results_all = [
    extract_result_row(load_json(result_file))
    for result_file in sorted(results_path.glob("*.json"))
]
if not results_all:
    logging.warning("No result files (*.json) found in %s", results_path)
71+
72+
# import ipdb; ipdb.set_trace()
73+
74+
# Number of best-scoring models kept in the table.
TOP = 10

# Rank models by their overall score and keep the TOP best.
df_results = (
    pd.DataFrame(results_all)
    .sort_values(by='Avg', ascending=False)
    .head(TOP)
)

# Strip the "maple-" prefix from column names for more compact headers.
df_results.columns = df_results.columns.str.replace('^maple-', '', regex=True)

# Index rows by model name, then scale every remaining column by 100.
df_results = df_results.set_index("Model")
df_results = df_results * 100
86+
fig, ax = plt.subplots(1, 1, figsize=(18, 5))

# Annotated heatmap of the results table: one row per model, one column per
# score, values printed with one decimal place and no colour bar.
sns.heatmap(df_results, ax=ax, cmap="YlGn", annot=True, annot_kws={"size": 16},
            fmt=".1f", cbar=False)

# Put the column labels above the heatmap and tilt them for readability.
ax.xaxis.set_ticks_position("top")
ax.tick_params(axis="x", labelrotation=45)
ax.set_ylabel("")
# The trailing space pads each model name away from the heatmap edge.
ax.set_yticklabels([f"{model} " for model in df_results.index])

plt.tight_layout()

# Create the output directory first; savefig fails if it does not exist.
Path("plots").mkdir(parents=True, exist_ok=True)
plt.savefig("plots/maple.pdf", bbox_inches="tight")
100+
101+

0 commit comments

Comments
 (0)