Skip to content

Commit 4b965ab

Browse files
Add disagreement histogram (#52)
1 parent c9deca3 commit 4b965ab

File tree

2 files changed

+178
-1
lines changed

2 files changed

+178
-1
lines changed

analysis/avg_agreement_final.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import matplotlib.pyplot as plt
22
import numpy as np
3-
import seaborn as sns
43
import pandas as pd
4+
import seaborn as sns
55

66
FONT_SIZES = {"small": 12, "medium": 16, "large": 18}
77
COLORS = {"green": "#355145", "purple": "#d8a6e5", "orange": "#fe7759"}

analysis/compute_iaa.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
import warnings
2+
from pathlib import Path
3+
4+
import matplotlib.pyplot as plt
5+
import pandas as pd
6+
import seaborn as sns
7+
from sklearn.metrics import cohen_kappa_score
8+
from datasets import load_dataset
9+
10+
# Suppress RuntimeWarnings raised inside sklearn
# (NOTE(review): presumably from degenerate inputs to cohen_kappa_score — confirm).
warnings.filterwarnings("ignore", category=RuntimeWarning, module="sklearn")
11+
12+
13+
# Font sizes (pt) applied to matplotlib text elements via PLOT_PARAMS.
FONT_SIZES = {"small": 12, "medium": 16, "large": 18}
# Plot palette (hex colors); one entry per category in the stacked histogram.
COLORS = {
    "green": "#355145",
    "purple": "#d8a6e5",
    "orange": "#fe7759",
    "blue": "#4c6ee6",
}

# rcParams applied globally before plotting (serif fonts, no TeX rendering).
PLOT_PARAMS = {
    "font.family": "serif",
    "font.serif": ["Times New Roman", "STIX"],
    "font.size": FONT_SIZES.get("medium"),
    "axes.titlesize": FONT_SIZES.get("large"),
    "axes.labelsize": FONT_SIZES.get("large"),
    "xtick.labelsize": FONT_SIZES.get("large"),
    "ytick.labelsize": FONT_SIZES.get("large"),
    "legend.fontsize": FONT_SIZES.get("medium"),
    "figure.titlesize": FONT_SIZES.get("medium"),
    "text.usetex": False,
}

# ISO 639-3 language code -> short display code used in axis labels.
LANG_STANDARDIZATION = {
    "arb": "ar",
    "ces": "cs",
    "deu": "de",
    "ell": "el",
    "fra": "fr",
    "heb": "he",
    "hin": "hi",
    "ind": "id",
    "ita": "it",
    "jpn": "jp",
    "kor": "kr",
    "nld": "nl",
    "pes": "fa",
    "pol": "pl",
    "por": "pt",
    "ron": "ro",
    "rus": "ru",
    "spa": "es",
    "tur": "tr",
    "ukr": "uk",
    "vie": "vi",
    "zho": "zh",
}

# High-level category -> list of benchmark subsets belonging to it
# (looked up per-instance via find_key).
SUBSET_MAPPING = {
    "Chat": [
        "alpacaeval-easy",
        "alpacaeval-length",
        "alpacaeval-hard",
        "mt-bench-easy",
        "mt-bench-med",
    ],
    "Chat Hard": [
        "mt-bench-hard",
        "llmbar-natural",
        "llmbar-adver-neighbor",
        "llmbar-adver-GPTInst",
        "llmbar-adver-GPTOut",
        "llmbar-adver-manual",
    ],
    "Safety": [
        "refusals-dangerous",
        "refusals-offensive",
        "xstest-should-refuse",
        "xstest-should-respond",
        "donotanswer",
    ],
    "Reasoning": [
        "math-prm",
        "hep-cpp",
        "hep-go",
        "hep-java",
        "hep-js",
        "hep-python",
        "hep-rust",
    ],
}
92+
93+
94+
def find_key(d: dict[str, list[str]], value: str) -> str | None:
95+
for key, values in d.items():
96+
if value in values:
97+
return key
98+
return None
99+
100+
101+
plt.rcParams.update(PLOT_PARAMS)

# Language under study. Its per-model annotation histogram is compared
# against the English reference annotations row by row.
lang = "ind_Latn"
annotations = Path(f"plots/{lang}_histogram.csv")
reference = Path("plots/eng_Latn_histogram.csv")

# CSVs are indexed by "model"; after the transpose, rows are instances and
# columns are models (assumes both files share the same instance order —
# TODO confirm upstream export guarantees this).
annot_df = pd.read_csv(annotations).set_index("model").T
ref_df = pd.read_csv(reference).set_index("model").T

# Per-instance Cohen's kappa between the target-language annotations and
# the English reference, over the three label values {0, 1, 2}.
cohen_scores: dict[str, float] = {}
for (idx, annot), (_, ref) in zip(annot_df.iterrows(), ref_df.iterrows()):
    cohen_scores[idx] = cohen_kappa_score(annot.to_list(), ref.to_list(), labels=[0, 1, 2])

df = pd.DataFrame([cohen_scores]).T.reset_index().rename(columns={0: "cohen", "index": "instance_id"}).dropna()

# Attach prompt/subset metadata. The dataset config now follows `lang`
# (was hard-coded to "ind_Latn", which silently mismatched when `lang`
# pointed at a different language's histogram).
sdf = load_dataset(
    "aya-rm-multilingual/multilingual-reward-bench-gtranslate", lang, split="filtered"
).to_pandas()
sdf = sdf[["prompt", "chosen", "rejected", "subset", "id"]].rename(columns={"id": "instance_id"})
# Normalize ids to strings so the merge key types match df["instance_id"].
sdf["instance_id"] = sdf["instance_id"].astype(str)
combi = df.merge(sdf, on="instance_id").sort_values(by="cohen", ascending=False).reset_index(drop=True)
combi["category"] = combi["subset"].apply(lambda x: find_key(SUBSET_MAPPING, x))
combi["count"] = 1

# Top panel: raw kappa distribution. Bottom panel: per-bin category share.
fig, axs = plt.subplots(
    2,
    1,
    figsize=(8, 6),
    sharex=True,
    gridspec_kw={"height_ratios": [5, 2]},
)
sns.histplot(
    df["cohen"],
    ax=axs[0],
    stat="count",
    fill=True,
    color=COLORS.get("orange"),
)

# Reference lines at kappa = 0 (chance agreement) and 0.60
# (conventional "substantial agreement" threshold).
axs[0].axvline(x=0, color=COLORS.get("green"), linestyle="--", linewidth=1)
axs[0].axvline(x=0.60, color=COLORS.get("green"), linestyle="--", linewidth=1)

sns.histplot(
    data=combi,
    x="cohen",
    hue="category",
    multiple="fill",
    ax=axs[1],
    # One palette entry per category in SUBSET_MAPPING.
    palette=[
        COLORS.get("green"),
        COLORS.get("purple"),
        COLORS.get("orange"),
        COLORS.get("blue"),
    ],
)

lang_code = LANG_STANDARDIZATION[lang.split("_")[0]]
axs[1].set_xlabel(f"Cohen's Kappa (Language: {lang_code})")
axs[1].set_ylabel("Percentage")

axs[0].set_axisbelow(True)
axs[0].grid(True, color="gray", axis="y", alpha=0.2)
plt.tight_layout()
plt.savefig(f"plots/cohen_k_histogram_{lang}.svg", bbox_inches="tight")

0 commit comments

Comments
 (0)