Skip to content

Commit 07f6083

Browse files
Add English drop plots (#37)
1 parent 2461f48 commit 07f6083

File tree

1 file changed

+65
-7
lines changed

1 file changed

+65
-7
lines changed

analysis/plot_results.py

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66

77
import pandas as pd
88
import seaborn as sns
9-
import numpy as np
10-
import matplotlib.patches as patches
119
import matplotlib.pyplot as plt
10+
from adjustText import adjust_text
1211

1312
FONT_SIZES = {"small": 12, "medium": 16, "large": 18}
1413

@@ -42,11 +41,10 @@ def get_args():
4241

4342
parser_main_results = subparsers.add_parser("main_heatmap", help="Plot results as a heatmap.", parents=[shared_args])
4443
parser_main_results.add_argument("--input_path", type=Path, required=True, help="Path to the results file.")
45-
parser_main_results.add_argument("--top_ten_only", action="store_true", help="If set, will only show the top-10 of all models.")
46-
parser_main_results.add_argument("--print_latex", action="store_true", help="If set, print LaTeX table.")
4744

4845
parser_eng_drop = subparsers.add_parser("eng_drop_line", help="Plot english drop as a line chart.", parents=[shared_args])
4946
parser_eng_drop.add_argument("--input_path", type=Path, required=True, help="Path to the results file.")
47+
parser_eng_drop.add_argument("--top_n", default=None, type=int, help="If set, will only show the top-n models.")
5048
# fmt: on
5149
return parser.parse_args()
5250

@@ -73,9 +71,7 @@ def _filter_args(func, kwargs):
7371

7472
def plot_main_heatmap(
7573
input_path: Path,
76-
output_path: Optional[Path] = None,
77-
top_ten_only: bool = False,
78-
print_latex: bool = False,
74+
output_path: Path,
7975
figsize: Optional[tuple[int, int]] = (18, 5),
8076
):
8177

@@ -99,5 +95,67 @@ def plot_main_heatmap(
9995
fig.savefig(output_path, bbox_inches="tight")
10096

10197

98+
def plot_eng_drop_line(
    input_path: Path,
    output_path: Path,
    figsize: Optional[tuple[int, int]] = (18, 5),
    top_n: Optional[int] = None,
):
    """Plot the English-vs-multilingual score drop as an annotated scatter.

    Reads a results CSV, scatters each model's ``eng_Latn`` (RewardBench)
    score against its ``Avg_Multilingual`` (M-RewardBench) score with a
    ``y = x`` reference line, labels every point with the model name,
    saves the figure, and prints a LaTeX table of models sorted by the
    English-minus-multilingual delta.

    Args:
        input_path: CSV containing at least the columns ``Model``,
            ``eng_Latn`` and ``Avg_Multilingual`` (scores as fractions).
        output_path: file the figure is written to.
        figsize: matplotlib figure size passed to ``plt.subplots``.
        top_n: if set, only plot the top-n models ranked by
            ``Avg_Multilingual``.
    """
    # Local import so scipy is only required for this subcommand.
    from scipy.stats import pearsonr

    df = pd.read_csv(input_path)
    df = df[["Model", "eng_Latn", "Avg_Multilingual"]]
    df = df.sort_values(by="Avg_Multilingual", ascending=False).reset_index(drop=True)
    # Scores come in as fractions; convert to percentages for display.
    data = df.set_index("Model").dropna() * 100
    if top_n:
        logging.info(f"Showing top {top_n}")
        data = data.head(top_n)

    fig, ax = plt.subplots(figsize=figsize)
    mrewardbench_scores = data["Avg_Multilingual"]
    rewardbench_scores = data["eng_Latn"]
    r, _ = pearsonr(mrewardbench_scores, rewardbench_scores)
    ax.scatter(mrewardbench_scores, rewardbench_scores, marker="o", s=30, color="black")

    # y = x reference line: points below it scored worse multilingually
    # than in English.
    min_val = min(mrewardbench_scores.min(), rewardbench_scores.min())
    max_val = max(mrewardbench_scores.max(), rewardbench_scores.max())
    ax.plot(
        [min_val, max_val],
        [min_val, max_val],
        linestyle="--",
        color="black",
    )
    ax.set_xlabel(f"M-RewardBench (Pearson r: {r:.2f})")
    ax.set_ylabel("RewardBench (Lambert et al., 2024)")
    ax.set_aspect("equal")

    # The score Series are indexed by model name, so integer-key []
    # lookup relies on deprecated positional fallback (removed in
    # pandas 3.0) — use .iloc for explicit positional access.
    texts = [
        ax.text(
            mrewardbench_scores.iloc[idx],
            rewardbench_scores.iloc[idx],
            data.index[idx],
            fontsize=11,
        )
        for idx in range(len(data))
    ]
    # Nudge overlapping labels apart and connect them to their points.
    adjust_text(
        texts,
        ax=ax,
        arrowprops=dict(arrowstyle="->", color="gray"),
    )

    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    plt.tight_layout()
    fig.savefig(output_path, bbox_inches="tight")

    logging.info("Showing top-10 models with biggest deltas")
    delta_df = data.copy(deep=True)
    delta_df["delta"] = delta_df["eng_Latn"] - delta_df["Avg_Multilingual"]
    delta_df = delta_df.sort_values(by="delta", ascending=False)
    print(delta_df.to_latex())
158+
159+
102160
if __name__ == "__main__":
103161
main()

0 commit comments

Comments
 (0)