def plot_translate(
    gtrans: Path,
    nllb: Path,
    output_path: Path,
    figsize: Optional[tuple[int, int]] = (18, 5),
):
    """Plot a slope chart comparing each model's score under two translators.

    Both input files are read as CSV and must contain the columns ``Model``,
    ``Model_Type`` and ``Avg_Multilingual``. The NLLB table is left-joined
    with the Google Translate table on ``Model``, so every NLLB row is kept
    and the model type is taken from the NLLB file. The combined table is
    printed to stdout sorted by the NLLB score (descending) before plotting.

    Args:
        gtrans: Path to the Google Translate results CSV.
        nllb: Path to the NLLB results CSV.
        output_path: Destination the figure is saved to.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    columns = ["Model", "Model_Type", "Avg_Multilingual"]
    gtrans_df = pd.read_csv(gtrans)[columns].rename(columns={"Avg_Multilingual": "Avg_Gtrans"})
    nllb_df = pd.read_csv(nllb)[columns].rename(columns={"Avg_Multilingual": "Avg_NLLB"})

    # Both frames carry Model_Type, so the merge suffixes them with _x (NLLB)
    # and _y (Google Translate); only the NLLB one is kept.
    combined = nllb_df.merge(gtrans_df, how="left", on="Model")
    combined = combined[["Model", "Avg_NLLB", "Avg_Gtrans", "Model_Type_x"]].rename(
        columns={"Model_Type_x": "Model_Type"}
    )

    print(combined.sort_values(by="Avg_NLLB", ascending=False))

    type_colors = {
        "Sequence Classifier": COLORS.get("green"),
        "Generative RM": COLORS.get("purple"),
        "DPO": COLORS.get("orange"),
    }
    type_labels = {
        "Sequence Classifier": "Classifier RM",
        "Generative RM": "Generative RM",
        "DPO": "Implicit RM",
    }
    # Model types missing from the tables above are still drawn (in gray,
    # labelled with their raw name) instead of aborting with a KeyError.
    fallback_color = "gray"

    fig, ax = plt.subplots(figsize=figsize)
    for _, row in combined.iterrows():
        model_type = row["Model_Type"]
        ax.plot(
            [1, 2],
            [row["Avg_NLLB"], row["Avg_Gtrans"]],
            marker="o",
            color=type_colors.get(model_type, fallback_color),
            label=type_labels.get(model_type, model_type),
        )

    # Every line carries a label, so collapse duplicates to one legend entry
    # per model type (the dict keeps the last handle seen for each label).
    handles, legend_labels = ax.get_legend_handles_labels()
    unique_entries = dict(zip(legend_labels, handles))
    ax.legend(
        unique_entries.values(),
        unique_entries.keys(),
        frameon=False,
        ncols=3,
        loc="lower center",
        bbox_to_anchor=(0.5, -0.2),
    )

    ax.set_xticks([1, 2])
    ax.set_xticklabels(["NLLB", "Google Translate"])
    ax.yaxis.set_visible(False)

    ax.spines[["top", "bottom", "left", "right"]].set_visible(False)

    # Vertical guides at both x positions spanning the full range of scores.
    scores = combined[["Avg_NLLB", "Avg_Gtrans"]]
    ax.vlines(
        [1, 2],
        ymin=scores.min().min(),
        ymax=scores.max().max(),
        colors="gray",
    )

    fig.tight_layout()
    fig.savefig(output_path, bbox_inches="tight")