Skip to content

Commit 2461f48

Browse files
Merge pull request #36 from for-ai/updates
Add new plots
2 parents d729fe5 + 5f99e2d commit 2461f48

File tree

2 files changed

+112
-1
lines changed

2 files changed

+112
-1
lines changed

analysis/plot_leaderboard.py renamed to analysis/_plot_leaderboard.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ def get_args():
2727

2828

2929
def main():
30+
logging.warning(
31+
"This is a deprecated script. You can still run this to get the most up-to-date leaderboard results."
32+
"However, it is preferable and way easier to use the output CSVs from the earlier runs and start from there."
33+
"I highly recommend checking the analysis.plot_figures script for more information."
34+
)
3035
args = get_args()
3136
output_dir = Path(args.output_dir)
3237
if not output_dir.exists():
@@ -83,7 +88,10 @@ def main():
8388
fig.tight_layout()
8489
output_file = output_dir / f"leaderboard-{model_type.replace(' ', '_')}.png"
8590
csv_output_file = output_dir / f"leaderboard-{model_type.replace(' ', '_')}.csv"
86-
data.to_csv(csv_output_file)
91+
data_to_cache = data.copy(deep=True)
92+
data_to_cache["eng_Latn"] = model_type_df["eng_Latn"]
93+
data_to_cache = data_to_cache.rename(columns={"Avg": "Avg_Multilingual"})
94+
data_to_cache.to_csv(csv_output_file)
8795
fig.savefig(output_file, dpi=120)
8896
logging.info(f"Saved to {output_file}")
8997

analysis/plot_results.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import argparse
2+
import logging
3+
from pathlib import Path
4+
from inspect import signature
5+
from typing import Optional
6+
7+
import pandas as pd
8+
import seaborn as sns
9+
import numpy as np
10+
import matplotlib.patches as patches
11+
import matplotlib.pyplot as plt
12+
13+
# Named font sizes referenced by the matplotlib settings below.
FONT_SIZES = {"small": 12, "medium": 16, "large": 18}

# rcParams overrides: serif fonts, LaTeX text rendering disabled, and the
# "large" size for axis titles/labels/ticks with "medium" everywhere else.
PLOT_PARAMS = {
    "font.family": "serif",
    "font.serif": ["Times New Roman", "STIX"],
    "font.size": FONT_SIZES["medium"],
    "axes.titlesize": FONT_SIZES["large"],
    "axes.labelsize": FONT_SIZES["large"],
    "xtick.labelsize": FONT_SIZES["large"],
    "ytick.labelsize": FONT_SIZES["large"],
    "legend.fontsize": FONT_SIZES["medium"],
    "figure.titlesize": FONT_SIZES["medium"],
    "text.usetex": False,
}

# Apply the overrides globally at import time.
plt.rcParams.update(PLOT_PARAMS)

logging.basicConfig(level=logging.INFO)
31+
32+
33+
def get_args():
    """Parse the command-line arguments for the plotting sub-commands.

    Two sub-commands are registered, ``main_heatmap`` and ``eng_drop_line``.
    The selected one is stored in ``args.command`` (``None`` when no
    sub-command is given). Both inherit ``--output_path`` and ``--figsize``
    from a shared parent parser.

    Returns:
        argparse.Namespace: The parsed arguments read from ``sys.argv``.
    """
    # fmt: off
    # A single top-level parser; a previously duplicated, immediately
    # overwritten ArgumentParser() construction has been dropped.
    parser = argparse.ArgumentParser(description="Plotting utilities", formatter_class=argparse.RawDescriptionHelpFormatter)
    subparsers = parser.add_subparsers(dest="command")

    # Options common to every sub-command (add_help=False so that "-h" is
    # not registered twice when used as a parent parser).
    shared_args = argparse.ArgumentParser(add_help=False)
    shared_args.add_argument("--output_path", type=Path, required=True, help="Path to save the PDF plot.")
    shared_args.add_argument("--figsize", type=int, nargs=2, default=[10, 10], help="Matplotlib figure size.")

    parser_main_results = subparsers.add_parser("main_heatmap", help="Plot results as a heatmap.", parents=[shared_args])
    parser_main_results.add_argument("--input_path", type=Path, required=True, help="Path to the results file.")
    parser_main_results.add_argument("--top_ten_only", action="store_true", help="If set, will only show the top-10 of all models.")
    parser_main_results.add_argument("--print_latex", action="store_true", help="If set, print LaTeX table.")

    parser_eng_drop = subparsers.add_parser("eng_drop_line", help="Plot english drop as a line chart.", parents=[shared_args])
    parser_eng_drop.add_argument("--input_path", type=Path, required=True, help="Path to the results file.")
    # fmt: on
    return parser.parse_args()
52+
53+
54+
def main():
    """Dispatch the parsed CLI arguments to the matching plotting function.

    The sub-command stored in ``args.command`` selects the handler. Only
    the parsed arguments whose names appear in the handler's signature are
    forwarded, so each handler receives exactly the options it declares.
    An unknown command, or one whose handler is not defined in this module,
    is reported through ``logging.error`` and nothing is plotted.
    """
    args = get_args()

    # Map each sub-command to the *name* of its handler and resolve it at
    # dispatch time. Referencing the function objects directly here would
    # raise NameError for every command as soon as one handler (e.g.
    # ``plot_eng_drop_line``) is missing from the module.
    handler_names = {
        "main_heatmap": "plot_main_heatmap",
        "eng_drop_line": "plot_eng_drop_line",
    }

    if args.command not in handler_names:
        logging.error(f"Unknown plotting command: {args.command}")
        return

    plot_fn = globals().get(handler_names[args.command])
    if plot_fn is None:
        logging.error(f"Plotting command is not implemented: {args.command}")
        return

    # Keep only the CLI options that the handler actually accepts.
    accepted = signature(plot_fn).parameters
    kwargs = {name: value for name, value in vars(args).items() if name in accepted}
    plot_fn(**kwargs)
72+
73+
74+
def plot_main_heatmap(
    input_path: Path,
    output_path: Optional[Path] = None,
    top_ten_only: bool = False,
    print_latex: bool = False,
    figsize: Optional[tuple[int, int]] = (18, 5),
):
    """Plot per-column model scores from a results CSV as an annotated heatmap.

    The CSV is expected to contain a ``Model`` column (used as the row
    labels) and an ``Avg_Multilingual`` column (used for sorting, highest
    first). ``Model_Type``, ``Avg_Multilingual`` and ``eng_Latn`` are
    excluded from the heatmap when present; all remaining columns are
    multiplied by 100 before plotting.

    Args:
        input_path: Path to the results CSV file.
        output_path: Where to save the figure. When ``None`` the figure is
            built but not written to disk.
        top_ten_only: If set, keep only the ten best rows by
            ``Avg_Multilingual``; otherwise every row is plotted.
        print_latex: If set, print the plotted table as LaTeX to stdout.
        figsize: Matplotlib figure size passed to ``plt.subplots``.

    Raises:
        KeyError: If ``Avg_Multilingual`` or ``Model`` is missing from the CSV.
    """
    df = pd.read_csv(input_path)
    # The English column is not shown in the heatmap; tolerate CSVs that
    # do not contain it instead of failing with a KeyError.
    df = df.drop(columns=["eng_Latn"], errors="ignore")

    df = df.sort_values(by="Avg_Multilingual", ascending=False)
    if top_ten_only:
        # Honour the flag rather than truncating unconditionally.
        df = df.head(10)
    df = df.reset_index(drop=True)

    data = df.drop(columns=["Model_Type", "Avg_Multilingual"], errors="ignore")
    data = data.set_index("Model")
    data = data * 100

    if print_latex:
        # NOTE(review): recent pandas versions may need jinja2 installed for
        # DataFrame.to_latex -- confirm against the project's environment.
        print(data.to_latex(float_format="%.2f"))

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    sns.heatmap(data, ax=ax, cmap="YlGn", annot=True, annot_kws={"size": 14}, fmt=".2f", cbar=False)
    ax.xaxis.set_ticks_position("top")
    ax.tick_params(axis="x", rotation=45)
    ax.set_ylabel("")
    # The trailing space pads the labels away from the heatmap cells.
    ax.set_yticklabels([f"{model} " for model in data.index])

    plt.tight_layout()
    if output_path is not None:
        fig.savefig(output_path, bbox_inches="tight")
100+
101+
102+
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)