|
| 1 | +import warnings |
| 2 | +from pathlib import Path |
| 3 | + |
| 4 | +import matplotlib.pyplot as plt |
| 5 | +import pandas as pd |
| 6 | +import seaborn as sns |
| 7 | +from sklearn.metrics import cohen_kappa_score |
| 8 | +from datasets import load_dataset |
| 9 | + |
# Ignore RuntimeWarnings that originate from sklearn modules.
# NOTE(review): presumably these come from cohen_kappa_score on degenerate inputs - confirm.
warnings.filterwarnings("ignore", category=RuntimeWarning, module="sklearn")


# Named font sizes (in points) referenced by the rcParams below.
FONT_SIZES = dict(small=12, medium=16, large=18)

# Named hex colours used by the plots.
COLORS = dict(
    green="#355145",
    purple="#d8a6e5",
    orange="#fe7759",
    blue="#4c6ee6",
)

# Matplotlib rcParams overrides: serif fonts, no LaTeX rendering, and font
# sizes taken from FONT_SIZES.
PLOT_PARAMS = {
    "font.family": "serif",
    "font.serif": ["Times New Roman", "STIX"],
    "font.size": FONT_SIZES["medium"],
    "axes.titlesize": FONT_SIZES["large"],
    "axes.labelsize": FONT_SIZES["large"],
    "xtick.labelsize": FONT_SIZES["large"],
    "ytick.labelsize": FONT_SIZES["large"],
    "legend.fontsize": FONT_SIZES["medium"],
    "figure.titlesize": FONT_SIZES["medium"],
    "text.usetex": False,
}
| 33 | + |
# Maps three-letter language codes (the part of a config name such as
# "ind_Latn" before the underscore) to two-letter ISO 639-1 language codes,
# which are used for plot labels.
LANG_STANDARDIZATION = {
    "arb": "ar",
    "ces": "cs",
    "deu": "de",
    "ell": "el",
    "fra": "fr",
    "heb": "he",
    "hin": "hi",
    "ind": "id",
    "ita": "it",
    "jpn": "ja",  # ISO 639-1 language code; "jp" is the country code
    "kor": "ko",  # ISO 639-1 language code; "kr" is the country code
    "nld": "nl",
    "pes": "fa",
    "pol": "pl",
    "por": "pt",
    "ron": "ro",
    "rus": "ru",
    "spa": "es",
    "tur": "tr",
    "ukr": "uk",
    "vie": "vi",
    "zho": "zh",
}
| 58 | + |
# Category name -> list of the subset identifiers that belong to it.
# Insertion order of the categories is preserved by the dict.
SUBSET_MAPPING = {
    "Chat": [
        "alpacaeval-easy", "alpacaeval-length", "alpacaeval-hard",
        "mt-bench-easy", "mt-bench-med",
    ],
    "Chat Hard": [
        "mt-bench-hard",
        "llmbar-natural", "llmbar-adver-neighbor", "llmbar-adver-GPTInst",
        "llmbar-adver-GPTOut", "llmbar-adver-manual",
    ],
    "Safety": [
        "refusals-dangerous", "refusals-offensive",
        "xstest-should-refuse", "xstest-should-respond",
        "donotanswer",
    ],
    "Reasoning": [
        "math-prm",
        "hep-cpp", "hep-go", "hep-java", "hep-js", "hep-python", "hep-rust",
    ],
}
| 92 | + |
| 93 | + |
| 94 | +def find_key(d: dict[str, list[str]], value: str) -> str | None: |
| 95 | + for key, values in d.items(): |
| 96 | + if value in values: |
| 97 | + return key |
| 98 | + return None |
| 99 | + |
| 100 | + |
plt.rcParams.update(PLOT_PARAMS)

# Language configuration to compare against the English reference, e.g.
# "hin_Deva" or "ind_Latn". The annotation CSV, the dataset config, the axis
# label and the output filename are all derived from this single value.
lang = "ind_Latn"
annotations = Path(f"plots/{lang}_histogram.csv")
reference = Path("plots/eng_Latn_histogram.csv")

# Each CSV has a "model" column; after the transpose every row is one of the
# remaining CSV columns (used below as an instance id) and every column is a model.
annot_df = pd.read_csv(annotations).set_index("model").T
ref_df = pd.read_csv(reference).set_index("model").T

# Cohen's kappa per instance between the language annotations and the reference.
# NOTE(review): rows are paired positionally, so this assumes both CSVs list
# the same instances and models in the same order - confirm.
cohen_scores: dict[str, float] = {}
for (idx, annot), (_, ref) in zip(annot_df.iterrows(), ref_df.iterrows()):
    cohen_scores[idx] = cohen_kappa_score(annot.to_list(), ref.to_list(), labels=[0, 1, 2])

# One row per instance; instances whose kappa is NaN are dropped.
df = (
    pd.DataFrame([cohen_scores])
    .T.reset_index()
    .rename(columns={0: "cohen", "index": "instance_id"})
    .dropna()
)

# Load the instance metadata for the SAME language as the annotations
# (previously hard-coded to "ind_Latn" regardless of `lang`).
sdf = load_dataset(
    "aya-rm-multilingual/multilingual-reward-bench-gtranslate", lang, split="filtered"
).to_pandas()
sdf = sdf[["prompt", "chosen", "rejected", "subset", "id"]].rename(columns={"id": "instance_id"})
# The score index comes from CSV headers (strings), so the ids must be strings to merge.
sdf["instance_id"] = sdf["instance_id"].astype(str)
combi = df.merge(sdf, on="instance_id").sort_values(by="cohen", ascending=False).reset_index(drop=True)
combi["category"] = combi["subset"].apply(lambda x: find_key(SUBSET_MAPPING, x))

# Top panel: distribution of kappa scores. Bottom panel: per-bin share of each category.
fig, axs = plt.subplots(
    2,
    1,
    figsize=(8, 6),
    sharex=True,
    gridspec_kw={"height_ratios": [5, 2]},
)
sns.histplot(
    df["cohen"],
    ax=axs[0],
    stat="count",
    fill=True,
    color=COLORS.get("orange"),
)

# Dashed vertical guides at kappa = 0 and kappa = 0.60.
axs[0].axvline(x=0, color=COLORS.get("green"), linestyle="--", linewidth=1)
axs[0].axvline(x=0.60, color=COLORS.get("green"), linestyle="--", linewidth=1)

sns.histplot(
    data=combi,
    x="cohen",
    hue="category",
    # Pin the category order so each category keeps the same colour in every
    # language's figure instead of depending on row order in `combi`.
    hue_order=list(SUBSET_MAPPING),
    multiple="fill",
    ax=axs[1],
    palette=[
        COLORS.get("green"),
        COLORS.get("purple"),
        COLORS.get("orange"),
        COLORS.get("blue"),
    ],
)

lang_code = LANG_STANDARDIZATION[lang.split("_")[0]]
axs[1].set_xlabel(f"Cohen's Kappa (Language: {lang_code})")
axs[1].set_ylabel("Percentage")

# Draw the horizontal grid lines of the top panel behind the bars.
axs[0].set_axisbelow(True)
axs[0].grid(True, color="gray", axis="y", alpha=0.2)
plt.tight_layout()
plt.savefig(f"plots/cohen_k_histogram_{lang}.svg", bbox_inches="tight")
0 commit comments