
Commit 8b8fa99

Update all charts (#47)
1 parent a8222bb commit 8b8fa99

2 files changed: +119 -96 lines


analysis/avg_agreement_final.py

Lines changed: 55 additions & 70 deletions
@@ -2,91 +2,76 @@
 import matplotlib.pyplot as plt
 import numpy as np
 
+FONT_SIZES = {"small": 12, "medium": 16, "large": 18}
+COLORS = {"green": "#355145", "purple": "#d8a6e5", "orange": "#fe7759"}
+
+PLOT_PARAMS = {
+    "font.family": "serif",
+    "font.serif": ["Times New Roman", "STIX"],
+    "font.size": FONT_SIZES.get("medium"),
+    "axes.titlesize": FONT_SIZES.get("large"),
+    "axes.labelsize": FONT_SIZES.get("large"),
+    "xtick.labelsize": FONT_SIZES.get("large"),
+    "ytick.labelsize": FONT_SIZES.get("large"),
+    "legend.fontsize": FONT_SIZES.get("medium"),
+    "figure.titlesize": FONT_SIZES.get("medium"),
+    "text.usetex": False,
+}
+
+plt.rcParams.update(PLOT_PARAMS)
+
+
 data = {
-    "meta-llama/Meta-Llama-3.1-8B-Instruct": [
-        0.3533086666014079,
-        0.052422082615756406
-    ],
-    "cohere/c4ai-aya-23-35b": [
-        0.43767196047824003,
-        0.026040919354464294
-    ],
-    "cohere/c4ai-aya-23-8b": [
-        0.013483014909052663,
-        0.03363706833599835
-    ],
-    "cohere/command-r-08-2024": [
-        0.374457668650282,
-        0.02926089754079793
-    ],
-    "cohere/command-r-plus-08-2024": [
-        0.3830841816733316,
-        0.020185255968455686
-    ],
-    "google/gemma-1.1-7b-it": [
-        0.5190375637539242,
-        0.027757722654111305
-    ],
-    "google/gemma-2-9b-it": [
-        0.5181663123111222,
-        0.031090119385244894
-    ],
-    "meta-llama/Meta-Llama-3-70B-Instruct": [
-        0.5685224105896568,
-        0.04853344616275034
-    ],
-    "meta-llama/Meta-Llama-3-8B-Instruct": [
-        0.37936948540837095,
-        0.032172769265151994
-    ],
-    "meta-llama/Meta-Llama-3.1-70B-Instruct": [
-        0.603536768244583,
-        0.027191895488989915
-    ],
-    "mistralai/Mistral-7B-Instruct-v0.2": [
-        0.4071166722276529,
-        0.04577594028555328
-    ],
-    "mistralai/Mistral-7B-Instruct-v0.3": [
-        0.41195018984687265,
-        0.056184679972755454
-    ],
-    "openai/gpt-4-turbo-2024-04-09": [
-        0.6106943361444249,
-        0.02932446842558468
-    ],
-    "openai/gpt-4o-2024-05-13": [
-        0.5833874065757011,
-        0.023695391445384514
-    ]
+    "Llama 3.1 8B": [0.3533086666014079, 0.052422082615756406],
+    "Aya 23 35B": [0.43767196047824003, 0.026040919354464294],
+    # "Aya 23 8B": [0.013483014909052663, 0.03363706833599835],
+    "Command R": [0.374457668650282, 0.02926089754079793],
+    "Command R+": [0.3830841816733316, 0.020185255968455686],
+    "Gemma 1.1 7B": [0.5190375637539242, 0.027757722654111305],
+    "Gemma 2 9B": [0.5181663123111222, 0.031090119385244894],
+    "Llama 3 70B": [0.5685224105896568, 0.04853344616275034],
+    "Llama 3 8B": [0.37936948540837095, 0.032172769265151994],
+    "Llama 3.1 70B": [0.603536768244583, 0.027191895488989915],
+    "Mistral 7B v0.2": [0.4071166722276529, 0.04577594028555328],
+    "Mistral 7B v0.3": [0.41195018984687265, 0.056184679972755454],
+    "GPT-4 Turbo": [0.6106943361444249, 0.02932446842558468],
+    "GPT-4o": [0.5833874065757011, 0.023695391445384514],
 }
 
 sorted_data = dict(sorted(data.items(), key=lambda item: item[1][0]))
 labels_sorted = list(sorted_data.keys())
 means_sorted = [v[0] for v in sorted_data.values()]
 std_devs_sorted = [v[1] for v in sorted_data.values()]
 
-sns.set(style="whitegrid")
-palette = sns.color_palette("coolwarm", len(labels_sorted))
+# sns.set(style="whitegrid")
+# palette = sns.color_palette("coolwarm", len(labels_sorted))
 
-plt.figure(figsize=(10, 6))
+plt.figure(figsize=(7, 7))
 x_pos_sorted = np.arange(len(labels_sorted))
 
-ax1 = sns.barplot(x=x_pos_sorted, y=means_sorted, palette=palette, errorbar=None)
-plt.errorbar(x_pos_sorted, means_sorted, yerr=std_devs_sorted, fmt='none', c='black', capsize=5)
+ax1 = sns.barplot(
+    x=x_pos_sorted,
+    y=means_sorted,
+    errorbar=None,
+    color=COLORS.get("orange"),
+    edgecolor=COLORS.get("green"),
+)
+plt.errorbar(x_pos_sorted, means_sorted, yerr=std_devs_sorted, fmt="none", c="black", capsize=5)
 
-ax1.spines['top'].set_color('black')
-ax1.spines['right'].set_color('black')
-ax1.spines['left'].set_color('black')
-ax1.spines['bottom'].set_color('black')
-for spine in ax1.spines.values():
-    spine.set_linewidth(2) # Make the border thicker
+# ax1.spines["top"].set_color("black")
+# ax1.spines["right"].set_color("black")
+# ax1.spines["left"].set_color("black")
+# ax1.spines["bottom"].set_color("black")
+# for spine in ax1.spines.values():
+#     spine.set_linewidth(2) # Make the border thicker
+plt.grid(color="gray", axis="y", alpha=0.2)
 
 plt.ylim(0, 0.8)
+plt.gca().set_axisbelow(True)
 
-plt.xticks(x_pos_sorted, labels_sorted, rotation=90)
+plt.xticks(x_pos_sorted, labels_sorted, rotation=45, ha="right")
 plt.ylabel("Cohen's Kappa")
-plt.title('Average Inner-Model Agreement Across Languages')
+plt.title("Average Inner-Model Agreement Across Languages")
 
 plt.tight_layout()
-plt.savefig(f"./innermodel_agreement.pdf", bbox_inches='tight')
+plt.savefig("plots/innermodel_agreement_green_oracle.pdf", bbox_inches="tight")

analysis/plot_results.py

Lines changed: 64 additions & 26 deletions
@@ -13,7 +13,7 @@
 
 PLOT_PARAMS = {
     "font.family": "serif",
-    "font.serif": ["Times New Roman", "STIX"],
+    "font.serif": ["Times", "Times New Roman", "STIX"],
     "font.size": FONT_SIZES.get("medium"),
     "axes.titlesize": FONT_SIZES.get("large"),
     "axes.labelsize": FONT_SIZES.get("large"),
@@ -66,6 +66,8 @@
     "zho": "zh",
 }
 
+COLORS = {"green": "#355145", "purple": "#d8a6e5", "orange": "#fe7759"}
+
 
 def get_args():
     # fmt: off
@@ -122,6 +124,7 @@ def plot_main_heatmap(
     df = pd.read_csv(input_path)
     # Remove unnecessary column
     df.pop("eng_Latn")
+    df.pop("Family")
 
     df = df.sort_values(by="Avg_Multilingual", ascending=False).head(10).reset_index(drop=True)
     data = df[[col for col in df.columns if col not in ["Model_Type"]]].rename(columns={"Avg_Multilingual": "Avg"})
@@ -133,14 +136,39 @@
     data.pop("zho_Hant")
     data = data[sorted(data.columns)]
     data.columns = [col.split("_")[0] for col in data.columns]
+    data["Var"] = data[list(LANG_STANDARDIZATION.keys())].var(axis=1)
     data = data.rename(columns=LANG_STANDARDIZATION)
 
-    fig, ax = plt.subplots(1, 1, figsize=figsize)
-    sns.heatmap(data, ax=ax, cmap="YlGn", annot=True, annot_kws={"size": 16}, fmt=".2f", cbar=False)
-    ax.xaxis.set_ticks_position("top")
-    ax.tick_params(axis="x")
-    ax.set_ylabel("")
-    ax.set_yticklabels([f"{model} " for model in data.index])
+    lang_results = data[list(LANG_STANDARDIZATION.values())]
+    avg = data[["Avg"]]
+    var = data[["Var"]]
+
+    fig, axs = plt.subplots(ncols=3, figsize=figsize, gridspec_kw={"width_ratios": [0.5, 0.5, 9]}, sharey=True)
+    cmap = "Greys"
+    fmt = ".1f"
+
+    sns.heatmap(avg, ax=axs[0], cmap=cmap, annot=True, annot_kws={"size": 16}, fmt=fmt, cbar=False)
+    axs[0].xaxis.set_ticks_position("top")
+    axs[0].set_xticklabels(avg.columns, fontsize=20)
+    axs[0].tick_params(axis="x")
+    axs[0].set_ylabel("")
+    axs[0].set_yticklabels([f"{model} " for model in avg.index], fontsize=20)
+
+    sns.heatmap(var, ax=axs[1], cmap=cmap, annot=True, annot_kws={"size": 16}, fmt=fmt, cbar=False)
+    axs[1].xaxis.set_ticks_position("top")
+    axs[1].set_xticklabels(var.columns, fontsize=20)
+    axs[1].tick_params(axis="x")
+    axs[1].set_ylabel("")
+    axs[1].tick_params(axis="y", length=0)
+    axs[1].set_yticklabels([f"{model} " for model in var.index], fontsize=20)
+
+    sns.heatmap(lang_results, ax=axs[2], cmap=cmap, annot=True, annot_kws={"size": 16}, fmt=fmt, cbar=False)
+    axs[2].xaxis.set_ticks_position("top")
+    axs[2].set_xticklabels(lang_results.columns, fontsize=20)
+    axs[2].tick_params(axis="x")
+    axs[2].tick_params(axis="y", length=0)
+    axs[2].set_ylabel("")
+    axs[2].set_yticklabels([f"{model} " for model in lang_results.index], fontsize=20)
 
     plt.tight_layout()
     fig.savefig(output_path, bbox_inches="tight")
@@ -155,7 +183,7 @@ def plot_eng_drop_line(
     from scipy.stats import pearsonr, spearmanr
 
     df = pd.read_csv(input_path)
-    df = df[["Model", "Model_Type", "eng_Latn", "Avg_Multilingual"]]
+    df = df[["Model", "Model_Type", "Family", "eng_Latn", "Avg_Multilingual"]]
     df = df.sort_values(by="Avg_Multilingual", ascending=False).reset_index(drop=True)
     data = df.set_index("Model").dropna()
     data[data.select_dtypes(include="number").columns] = data.select_dtypes(include="number") * 100
@@ -166,11 +194,19 @@
 
     fig, ax = plt.subplots(figsize=figsize)
 
-    colors = ["red", "green", "blue"]
+    colors = [COLORS.get("green"), COLORS.get("purple"), COLORS.get("orange")]
+    markers = ["o", "*", "D"]
     for (label, group), color in zip(data.groupby("Model_Type"), colors):
         mrewardbench_scores = group["Avg_Multilingual"]
         rewardbench_scores = group["eng_Latn"]
-        ax.scatter(rewardbench_scores, mrewardbench_scores, marker="o", s=40, label=label, color=color)
+        ax.scatter(
+            rewardbench_scores,
+            mrewardbench_scores,
+            marker="o",
+            s=60,
+            label=label,
+            color=color,
+        )
 
     mrewardbench_scores = data["Avg_Multilingual"]
     rewardbench_scores = data["eng_Latn"]
@@ -188,22 +224,23 @@
     ax.set_aspect("equal")
     ax.legend(frameon=False, handletextpad=0.2, fontsize=12)
 
-    model_names = [MODEL_STANDARDIZATION[model] for model in data.index]
-    texts = [
-        ax.text(
-            rewardbench_scores[idx],
-            mrewardbench_scores[idx],
-            model_names[idx],
-            fontsize=14,
+    if top_n:
+        model_names = [MODEL_STANDARDIZATION[model] for model in data.index]
+        texts = [
+            ax.text(
+                rewardbench_scores[idx],
+                mrewardbench_scores[idx],
+                model_names[idx],
+                fontsize=14,
+            )
+            for idx in range(len(data))
+        ]
+        adjust_text(
+            texts,
+            ax=ax,
+            # force_static=0.15,
+            arrowprops=dict(arrowstyle="->", color="gray"),
         )
-        for idx in range(len(data))
-    ]
-    adjust_text(
-        texts,
-        ax=ax,
-        # force_static=0.15,
-        arrowprops=dict(arrowstyle="->", color="gray"),
-    )
 
     # ax.text(
     #     0.6,
@@ -270,7 +307,8 @@ def plot_ling_dims(
             y=dim,
            data=lingdf,
            ax=ax,
-            color="green",
+            color=COLORS.get("orange"),
+            edgecolor=COLORS.get("green"),
            width=0.4 if dim == "Resource Availability" else 0.7,
        )
        ax.set_title(dim)
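
The reworked plot_main_heatmap splits the single heatmap into three panels that share a y-axis: a narrow Avg column, a narrow Var column, and a wide per-language grid, sized via gridspec width ratios. A rough, self-contained sketch of that layout on random placeholder data (the language codes, model names, figure size, and output path below are illustrative only, not taken from the repository):

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Random placeholder scores for four models across three languages.
rng = np.random.default_rng(0)
langs = ["ar", "de", "zh"]
data = pd.DataFrame(
    rng.uniform(0.4, 0.9, size=(4, len(langs))),
    index=[f"Model {i}" for i in range(1, 5)],
    columns=langs,
)
data["Avg"] = data[langs].mean(axis=1)
data["Var"] = data[langs].var(axis=1)

# Three panels sharing the y-axis: two narrow summary columns, one wide language grid.
fig, axs = plt.subplots(
    ncols=3, figsize=(8, 4),
    gridspec_kw={"width_ratios": [0.5, 0.5, 9]}, sharey=True,
)
for ax, cols in zip(axs, [["Avg"], ["Var"], langs]):
    sns.heatmap(data[cols], ax=ax, cmap="Greys", annot=True, fmt=".1f", cbar=False)
    ax.xaxis.set_ticks_position("top")
    ax.set_ylabel("")
plt.tight_layout()
fig.savefig("example_heatmap.pdf", bbox_inches="tight")  # placeholder path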
