
Commit 64c72c0

Add plot for NLLB vs Google Translate (#48)
1 parent 8b8fa99 · commit 64c72c0

4 files changed: +131, -66 lines

analysis/_plot_leaderboard.py
Lines changed: 4 additions & 3 deletions

@@ -3,12 +3,12 @@
 from pathlib import Path
 from typing import Optional

+import matplotlib.pyplot as plt
 import pandas as pd
 import seaborn as sns
-import matplotlib.pyplot as plt
 from huggingface_hub import snapshot_download

-from analysis.plot_utils import get_scores, PLOT_PARAMS
+from analysis.plot_utils import PLOT_PARAMS, get_scores

 logging.basicConfig(level=logging.INFO)

@@ -99,7 +99,8 @@ def main():
         output_file = output_dir / f"leaderboard-{model_type.replace(' ', '_')}.png"
         csv_output_file = output_dir / f"leaderboard-{model_type.replace(' ', '_')}.csv"
         data_to_cache = data.copy(deep=True)
-        data_to_cache["eng_Latn"] = model_type_df["eng_Latn"]
+        if "eng_Latn" in model_type_df.columns:
+            data_to_cache["eng_Latn"] = model_type_df["eng_Latn"]
         data_to_cache = data_to_cache.rename(columns={"Avg": "Avg_Multilingual"})
         data_to_cache.to_csv(csv_output_file)
         fig.savefig(output_file, dpi=120)
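The new guard in the second hunk protects the CSV caching step when a leaderboard frame has no eng_Latn column, presumably the case for the translated NLLB / Google Translate runs. A minimal sketch of the same pattern, with hypothetical frames standing in for data and model_type_df:

import pandas as pd

# Hypothetical stand-ins for `data` and `model_type_df` inside main().
data = pd.DataFrame({"Avg": [0.71, 0.65]}, index=["model-a", "model-b"])
model_type_df = pd.DataFrame({"hin_Deva": [0.62, 0.58]}, index=["model-a", "model-b"])

data_to_cache = data.copy(deep=True)
# Copy the English column only if the source frame actually has one;
# indexing a missing column would raise a KeyError.
if "eng_Latn" in model_type_df.columns:
    data_to_cache["eng_Latn"] = model_type_df["eng_Latn"]
data_to_cache = data_to_cache.rename(columns={"Avg": "Avg_Multilingual"})
print(data_to_cache)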

analysis/avg_agreement_final.py
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
-import seaborn as sns
 import matplotlib.pyplot as plt
 import numpy as np
+import seaborn as sns

 FONT_SIZES = {"small": 12, "medium": 16, "large": 18}
 COLORS = {"green": "#355145", "purple": "#d8a6e5", "orange": "#fe7759"}

analysis/maple_results.py
Lines changed: 49 additions & 56 deletions

@@ -1,92 +1,87 @@
-import json
-from pathlib import Path
-
 import argparse
+import json
 import logging
+from collections import defaultdict
+from itertools import combinations
 from pathlib import Path
 from typing import Optional

+import datasets
+import matplotlib.pyplot as plt
+import numpy as np
 import pandas as pd
 import seaborn as sns
-import matplotlib.pyplot as plt
 from huggingface_hub import snapshot_download
-import datasets
-import json
-
-import numpy as np
-import matplotlib.pyplot as plt
-from itertools import combinations
-from collections import defaultdict

-
 FONT_SIZES = {"small": 12, "medium": 16, "large": 18}

 PLOT_PARAMS = {
     "font.family": "serif",
     "font.serif": ["Times New Roman", "STIX"],
     "font.size": FONT_SIZES.get("medium"),
     "axes.titlesize": FONT_SIZES.get("large"),
     "axes.labelsize": FONT_SIZES.get("large"),
     "xtick.labelsize": FONT_SIZES.get("large"),
     "ytick.labelsize": FONT_SIZES.get("small"),
     "legend.fontsize": FONT_SIZES.get("medium"),
     "figure.titlesize": FONT_SIZES.get("medium"),
     "text.usetex": False,
 }

 logging.basicConfig(level=logging.INFO)

 plt.rcParams.update(PLOT_PARAMS)

+
 def load_json(json_file_path):
     with open(json_file_path, "r") as file:
         json_data = json.load(file)
     return json_data

-results_dir = 'data/eval-results-maple'
+
+results_dir = "data/eval-results-maple"
 results_path = Path(results_dir)

 results_all = []
 for result_file in results_path.glob("*.json"):
-    raw_results = load_json(result_file)
-    if "leaderboard" in raw_results.keys():
-        model_id = raw_results["model"]
-        subset_results = raw_results['subset']
-        overall = raw_results['scores']['accuracy']
-        remove_key = ['model', 'model_type', 'chat_template']
-        for key in remove_key:
-            del subset_results[key]
-    elif "subset_results" in raw_results.keys():
-        model_id = raw_results["model"]
-        subset_results = raw_results['subset_results']
-        overall = raw_results['accuracy']
-    else:
-        model_id = raw_results["model"]
-        subset_results = raw_results['extra_results']
-        overall = raw_results['accuracy']
-    # print(model_id, overall)
-    # print("\t", subset_results)
-    # results_all.append([model_id, overall, subset_results])
-    results_all.append({'Model': model_id, 'Avg': overall, **subset_results})
-
-# import ipdb; ipdb.set_trace()
-
-TOP = 10
+    raw_results = load_json(result_file)
+    if "leaderboard" in raw_results.keys():
+        model_id = raw_results["model"]
+        subset_results = raw_results["subset"]
+        overall = raw_results["scores"]["accuracy"]
+        remove_key = ["model", "model_type", "chat_template"]
+        for key in remove_key:
+            del subset_results[key]
+    elif "subset_results" in raw_results.keys():
+        model_id = raw_results["model"]
+        subset_results = raw_results["subset_results"]
+        overall = raw_results["accuracy"]
+    else:
+        model_id = raw_results["model"]
+        subset_results = raw_results["extra_results"]
+        overall = raw_results["accuracy"]
+    # print(model_id, overall)
+    # print("\t", subset_results)
+    # results_all.append([model_id, overall, subset_results])
+    results_all.append({"Model": model_id, "Avg": overall, **subset_results})
+
+# import ipdb; ipdb.set_trace()
+
+TOP = 10
 # results_all.sort(key=lambda x: x[1], reverse=True)
 # results_all = results_all[:TOP]
 # print(results_all)

 df_results = pd.DataFrame(results_all)
-df_results = df_results.sort_values(by='Avg', ascending=False).reset_index(drop=True)
+df_results = df_results.sort_values(by="Avg", ascending=False).reset_index(drop=True)
 df_results = df_results.head(10).reset_index(drop=True)

-df_results.columns = df_results.columns.str.replace('^maple-', '', regex=True)
+df_results.columns = df_results.columns.str.replace("^maple-", "", regex=True)
 df_results = df_results.set_index("Model")
 df_results = df_results * 100
 fig, ax = plt.subplots(1, 1, figsize=(18, 5))

-sns.heatmap(df_results, ax=ax, cmap="YlGn", annot=True, annot_kws={"size": 16},
-            fmt=".1f", cbar=False)
+sns.heatmap(df_results, ax=ax, cmap="YlGn", annot=True, annot_kws={"size": 16}, fmt=".1f", cbar=False)

 ax.xaxis.set_ticks_position("top")
 ax.tick_params(axis="x", labelrotation=45)

@@ -97,5 +92,3 @@ def load_json(json_file_path):

 plt.savefig("plots/maple.pdf", bbox_inches="tight")
 # import ipdb; ipdb.set_trace()
-
-

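maple_results.py runs top to bottom: it reads every *.json file under data/eval-results-maple, normalizes the three result layouts ("leaderboard", "subset_results", "extra_results") into one record per model, keeps the ten best by average accuracy, and saves a heatmap to plots/maple.pdf. For reference, a made-up record in the "subset_results" layout that the loop accepts; only the keys are taken from the code above, the model id and scores are illustrative:

# Illustrative only: the keys match what the "subset_results" branch reads,
# the values are invented.
example_result = {
    "model": "org/reward-model-7b",
    "accuracy": 0.78,
    "subset_results": {"maple-hin_Deva": 0.74, "maple-tam_Taml": 0.69},
}
# The loop turns this into
# {"Model": "org/reward-model-7b", "Avg": 0.78, "maple-hin_Deva": 0.74, ...}
# and the "maple-" prefix is later stripped from the column names.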
analysis/plot_results.py
Lines changed: 77 additions & 6 deletions

@@ -1,12 +1,12 @@
 import argparse
 import logging
-from pathlib import Path
 from inspect import signature
+from pathlib import Path
 from typing import Optional

+import matplotlib.pyplot as plt
 import pandas as pd
 import seaborn as sns
-import matplotlib.pyplot as plt
 from adjustText import adjust_text

 FONT_SIZES = {"small": 12, "medium": 16, "large": 18}

@@ -90,6 +90,11 @@ def get_args():
     parser_ling_dims.add_argument("--input_path", type=Path, required=True, help="Path to the results file.")
     parser_ling_dims.add_argument("--langdata", type=Path, required=True, help="Path to the language data file.")
     parser_ling_dims.add_argument("--top_n", type=int, required=False, default=None, help="Aggregate only the scores for top-n.")
+
+    parser_translate = subparsers.add_parser("translate", help="Plot translation quality.", parents=[shared_args])
+    parser_translate.add_argument("--gtrans", type=Path, required=True, help="Path to the Google Translate results file.")
+    parser_translate.add_argument("--nllb", type=Path, required=True, help="Path to the NLLB-3.3B results file.")
+
     # fmt: on
     return parser.parse_args()

@@ -101,6 +106,7 @@ def main():
         "main_heatmap": plot_main_heatmap,
         "eng_drop_line": plot_eng_drop_line,
         "ling_dims": plot_ling_dims,
+        "translate": plot_translate,
     }

     def _filter_args(func, kwargs):

@@ -254,8 +260,6 @@ def plot_eng_drop_line(
     # # bbox=dict(facecolor="white", edgecolor="black", boxstyle="round,pad=0.5"),
     # )

-    # ax.spines["right"].set_visible(False)
-    # ax.spines["top"].set_visible(False)
     plt.tight_layout()
     fig.savefig(output_path, bbox_inches="tight")

@@ -316,8 +320,75 @@ def plot_ling_dims(
     ax.set_ylabel("")
     ax.set_xlabel("M-RewardBench Score")

-    # ax.spines["right"].set_visible(False)
-    # ax.spines["top"].set_visible(False)
+    plt.tight_layout()
+    fig.savefig(output_path, bbox_inches="tight")
+
+
+def plot_translate(
+    gtrans: Path,
+    nllb: Path,
+    output_path: Path,
+    figsize: Optional[tuple[int, int]] = (18, 5),
+):
+    columns = ["Model", "Model_Type", "Avg_Multilingual"]
+    gtrans_df = pd.read_csv(gtrans)[columns].rename(columns={"Avg_Multilingual": "Avg_Gtrans"})
+    nllb_df = pd.read_csv(nllb)[columns].rename(columns={"Avg_Multilingual": "Avg_NLLB"})
+
+    combined = nllb_df.merge(gtrans_df, how="left", on="Model")
+    combined = combined[["Model", "Avg_NLLB", "Avg_Gtrans", "Model_Type_x"]].rename(
+        columns={"Model_Type_x": "Model_Type"}
+    )
+
+    print(combined.sort_values(by="Avg_NLLB", ascending=False))
+
+    colors = {
+        "Sequence Classifier": COLORS.get("green"),
+        "Generative RM": COLORS.get("purple"),
+        "DPO": COLORS.get("orange"),
+    }
+
+    labels = {
+        "Sequence Classifier": "Classifier RM",
+        "Generative RM": "Generative RM",
+        "DPO": "Implicit RM",
+    }
+
+    fig, ax = plt.subplots(figsize=figsize)
+    for _, row in combined.iterrows():
+        ax.plot(
+            [1, 2],
+            [row["Avg_NLLB"], row["Avg_Gtrans"]],
+            marker="o",
+            color=colors[row["Model_Type"]],
+            label=labels[row["Model_Type"]],
+        )
+
+    # Avoid duplicate labels in the legend
+    handles, labels = plt.gca().get_legend_handles_labels()
+    by_label = dict(zip(labels, handles))
+    ax.legend(
+        by_label.values(),
+        by_label.keys(),
+        frameon=False,
+        ncols=3,
+        loc="lower center",
+        bbox_to_anchor=(0.5, -0.2),
+    )
+
+    # ax.grid(color="gray", alpha=0.2, which="both", axis="x")
+    # ax.set_ylabel("M-RewardBench Overall Score")
+
+    ax.set_xticks([1, 2])
+    ax.set_xticklabels(["NLLB", "Google Translate"])
+    ax.yaxis.set_visible(False)
+
+    ax.spines[["top", "bottom", "left", "right"]].set_visible(False)
+    ax.vlines(
+        [1, 2],
+        ymin=combined[["Avg_NLLB", "Avg_Gtrans"]].min().min(),
+        ymax=combined[["Avg_NLLB", "Avg_Gtrans"]].max().max(),
+        colors="gray",
+    )

     plt.tight_layout()
     fig.savefig(output_path, bbox_inches="tight")
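The new subcommand wires plot_translate into the existing argparse CLI. Both input files are expected to carry Model, Model_Type, and Avg_Multilingual columns, i.e. the CSVs cached by analysis/_plot_leaderboard.py, and each model is drawn as a line from its NLLB-3.3B score to its Google Translate score, so ranking shifts between the two translation systems stand out. A rough sketch of calling the function directly; the file names are placeholders, and output_path is assumed to be supplied by the shared CLI arguments, which this diff does not show:

from pathlib import Path

from analysis.plot_results import plot_translate  # assumes the module is importable

plot_translate(
    gtrans=Path("plots/leaderboard-gtrans.csv"),   # placeholder: Google Translate leaderboard CSV
    nllb=Path("plots/leaderboard-nllb.csv"),       # placeholder: NLLB-3.3B leaderboard CSV
    output_path=Path("plots/nllb-vs-gtrans.pdf"),  # assumed shared argument
)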
