Skip to content

Commit

Permalink
[WIP] Update
Browse files Browse the repository at this point in the history
  • Loading branch information
ljvmiranda921 committed Oct 13, 2024
1 parent 9c9f42f commit 72e2852
Showing 1 changed file with 76 additions and 0 deletions.
76 changes: 76 additions & 0 deletions analysis/plot_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def get_args():
parser_ling_dims.add_argument("--input_path", type=Path, required=True, help="Path to the results file.")
parser_ling_dims.add_argument("--langdata", type=Path, required=True, help="Path to the language data file.")
parser_ling_dims.add_argument("--top_n", type=int, required=False, default=None, help="Aggregate only the scores for top-n.")

parser_translate = subparsers.add_parser("translate", help="Plot translation quality.", parents=[shared_args])
parser_translate.add_argument("--gtrans", type=Path, required=True, help="Path to the Google Translate results file.")
parser_translate.add_argument("--nllb", type=Path, required=True, help="Path to the NLLB-3.3B results file.")

# fmt: on
return parser.parse_args()

Expand All @@ -101,6 +106,7 @@ def main():
"main_heatmap": plot_main_heatmap,
"eng_drop_line": plot_eng_drop_line,
"ling_dims": plot_ling_dims,
"translate": plot_translate,
}

def _filter_args(func, kwargs):
Expand Down Expand Up @@ -318,5 +324,75 @@ def plot_ling_dims(
fig.savefig(output_path, bbox_inches="tight")


def plot_translate(
gtrans: Path,
nllb: Path,
output_path: Path,
figsize: Optional[tuple[int, int]] = (18, 5),
):
columns = ["Model", "Model_Type", "Avg_Multilingual"]
gtrans_df = pd.read_csv(gtrans)[columns].rename(columns={"Avg_Multilingual": "Avg_Gtrans"})
nllb_df = pd.read_csv(nllb)[columns].rename(columns={"Avg_Multilingual": "Avg_NLLB"})

combined = nllb_df.merge(gtrans_df, how="left", on="Model")
combined = combined[["Model", "Avg_NLLB", "Avg_Gtrans", "Model_Type_x"]].rename(
columns={"Model_Type_x": "Model_Type"}
)

print(combined.sort_values(by="Avg_NLLB", ascending=False))

colors = {
"Sequence Classifier": COLORS.get("green"),
"Generative RM": COLORS.get("purple"),
"DPO": COLORS.get("orange"),
}

labels = {
"Sequence Classifier": "Classifier RM",
"Generative RM": "Generative RM",
"DPO": "Implicit RM",
}

fig, ax = plt.subplots(figsize=figsize)
for _, row in combined.iterrows():
ax.plot(
[1, 2],
[row["Avg_NLLB"], row["Avg_Gtrans"]],
marker="o",
color=colors[row["Model_Type"]],
label=labels[row["Model_Type"]],
)

# Avoid duplicate labels in the legend
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
ax.legend(
by_label.values(),
by_label.keys(),
frameon=False,
ncols=3,
loc="lower center",
bbox_to_anchor=(0.5, -0.2),
)

# ax.grid(color="gray", alpha=0.2, which="both", axis="x")
# ax.set_ylabel("M-RewardBench Overall Score")

ax.set_xticks([1, 2])
ax.set_xticklabels(["NLLB", "Google Translate"])
ax.yaxis.set_visible(False)

ax.spines[["top", "bottom", "left", "right"]].set_visible(False)
ax.vlines(
[1, 2],
ymin=combined[["Avg_NLLB", "Avg_Gtrans"]].min().min(),
ymax=combined[["Avg_NLLB", "Avg_Gtrans"]].max().max(),
colors="gray",
)

plt.tight_layout()
fig.savefig(output_path, bbox_inches="tight")


if __name__ == "__main__":
main()

0 comments on commit 72e2852

Please sign in to comment.