13
13
14
14
PLOT_PARAMS = {
15
15
"font.family" : "serif" ,
16
- "font.serif" : ["Times New Roman" , "STIX" ],
16
+ "font.serif" : ["Times" , "Times New Roman" , "STIX" ],
17
17
"font.size" : FONT_SIZES .get ("medium" ),
18
18
"axes.titlesize" : FONT_SIZES .get ("large" ),
19
19
"axes.labelsize" : FONT_SIZES .get ("large" ),
66
66
"zho" : "zh" ,
67
67
}
68
68
69
+ COLORS = {"green" : "#355145" , "purple" : "#d8a6e5" , "orange" : "#fe7759" }  # named hex colours; looked up with COLORS.get(<name>) by the scatter and bar plots in this file
70
+
69
71
70
72
def get_args ():
71
73
# fmt: off
@@ -122,6 +124,7 @@ def plot_main_heatmap(
122
124
df = pd .read_csv (input_path )
123
125
# Remove unnecessary column
124
126
df .pop ("eng_Latn" )
127
+ df .pop ("Family" )
125
128
126
129
df = df .sort_values (by = "Avg_Multilingual" , ascending = False ).head (10 ).reset_index (drop = True )
127
130
data = df [[col for col in df .columns if col not in ["Model_Type" ]]].rename (columns = {"Avg_Multilingual" : "Avg" })
@@ -133,14 +136,39 @@ def plot_main_heatmap(
133
136
data .pop ("zho_Hant" )
134
137
data = data [sorted (data .columns )]
135
138
data .columns = [col .split ("_" )[0 ] for col in data .columns ]
139
+ data ["Var" ] = data [list (LANG_STANDARDIZATION .keys ())].var (axis = 1 )
136
140
data = data .rename (columns = LANG_STANDARDIZATION )
137
141
138
- fig , ax = plt .subplots (1 , 1 , figsize = figsize )
139
- sns .heatmap (data , ax = ax , cmap = "YlGn" , annot = True , annot_kws = {"size" : 16 }, fmt = ".2f" , cbar = False )
140
- ax .xaxis .set_ticks_position ("top" )
141
- ax .tick_params (axis = "x" )
142
- ax .set_ylabel ("" )
143
- ax .set_yticklabels ([f"{ model } " for model in data .index ])
142
+ lang_results = data [list (LANG_STANDARDIZATION .values ())]
143
+ avg = data [["Avg" ]]
144
+ var = data [["Var" ]]
145
+
146
+ fig , axs = plt .subplots (ncols = 3 , figsize = figsize , gridspec_kw = {"width_ratios" : [0.5 , 0.5 , 9 ]}, sharey = True )
147
+ cmap = "Greys"
148
+ fmt = ".1f"
149
+
150
+ sns .heatmap (avg , ax = axs [0 ], cmap = cmap , annot = True , annot_kws = {"size" : 16 }, fmt = fmt , cbar = False )
151
+ axs [0 ].xaxis .set_ticks_position ("top" )
152
+ axs [0 ].set_xticklabels (avg .columns , fontsize = 20 )
153
+ axs [0 ].tick_params (axis = "x" )
154
+ axs [0 ].set_ylabel ("" )
155
+ axs [0 ].set_yticklabels ([f"{ model } " for model in avg .index ], fontsize = 20 )
156
+
157
+ sns .heatmap (var , ax = axs [1 ], cmap = cmap , annot = True , annot_kws = {"size" : 16 }, fmt = fmt , cbar = False )
158
+ axs [1 ].xaxis .set_ticks_position ("top" )
159
+ axs [1 ].set_xticklabels (var .columns , fontsize = 20 )
160
+ axs [1 ].tick_params (axis = "x" )
161
+ axs [1 ].set_ylabel ("" )
162
+ axs [1 ].tick_params (axis = "y" , length = 0 )
163
+ axs [1 ].set_yticklabels ([f"{ model } " for model in var .index ], fontsize = 20 )
164
+
165
+ sns .heatmap (lang_results , ax = axs [2 ], cmap = cmap , annot = True , annot_kws = {"size" : 16 }, fmt = fmt , cbar = False )
166
+ axs [2 ].xaxis .set_ticks_position ("top" )
167
+ axs [2 ].set_xticklabels (lang_results .columns , fontsize = 20 )
168
+ axs [2 ].tick_params (axis = "x" )
169
+ axs [2 ].tick_params (axis = "y" , length = 0 )
170
+ axs [2 ].set_ylabel ("" )
171
+ axs [2 ].set_yticklabels ([f"{ model } " for model in lang_results .index ], fontsize = 20 )
144
172
145
173
plt .tight_layout ()
146
174
fig .savefig (output_path , bbox_inches = "tight" )
@@ -155,7 +183,7 @@ def plot_eng_drop_line(
155
183
from scipy .stats import pearsonr , spearmanr
156
184
157
185
df = pd .read_csv (input_path )
158
- df = df [["Model" , "Model_Type" , "eng_Latn" , "Avg_Multilingual" ]]
186
+ df = df [["Model" , "Model_Type" , "Family" , " eng_Latn" , "Avg_Multilingual" ]]
159
187
df = df .sort_values (by = "Avg_Multilingual" , ascending = False ).reset_index (drop = True )
160
188
data = df .set_index ("Model" ).dropna ()
161
189
data [data .select_dtypes (include = "number" ).columns ] = data .select_dtypes (include = "number" ) * 100
@@ -166,11 +194,19 @@ def plot_eng_drop_line(
166
194
167
195
fig , ax = plt .subplots (figsize = figsize )
168
196
169
- colors = ["red" , "green" , "blue" ]
197
+ colors = [COLORS .get ("green" ), COLORS .get ("purple" ), COLORS .get ("orange" )]
198
+ markers = ["o" , "*" , "D" ]
170
199
for (label , group ), color in zip (data .groupby ("Model_Type" ), colors ):
171
200
mrewardbench_scores = group ["Avg_Multilingual" ]
172
201
rewardbench_scores = group ["eng_Latn" ]
173
- ax .scatter (rewardbench_scores , mrewardbench_scores , marker = "o" , s = 40 , label = label , color = color )
202
+ ax .scatter (
203
+ rewardbench_scores ,
204
+ mrewardbench_scores ,
205
+ marker = "o" ,
206
+ s = 60 ,
207
+ label = label ,
208
+ color = color ,
209
+ )
174
210
175
211
mrewardbench_scores = data ["Avg_Multilingual" ]
176
212
rewardbench_scores = data ["eng_Latn" ]
@@ -188,22 +224,23 @@ def plot_eng_drop_line(
188
224
ax .set_aspect ("equal" )
189
225
ax .legend (frameon = False , handletextpad = 0.2 , fontsize = 12 )
190
226
191
- model_names = [MODEL_STANDARDIZATION [model ] for model in data .index ]
192
- texts = [
193
- ax .text (
194
- rewardbench_scores [idx ],
195
- mrewardbench_scores [idx ],
196
- model_names [idx ],
197
- fontsize = 14 ,
227
+ if top_n :
228
+ model_names = [MODEL_STANDARDIZATION [model ] for model in data .index ]
229
+ texts = [
230
+ ax .text (
231
+ rewardbench_scores [idx ],
232
+ mrewardbench_scores [idx ],
233
+ model_names [idx ],
234
+ fontsize = 14 ,
235
+ )
236
+ for idx in range (len (data ))
237
+ ]
238
+ adjust_text (
239
+ texts ,
240
+ ax = ax ,
241
+ # force_static=0.15,
242
+ arrowprops = dict (arrowstyle = "->" , color = "gray" ),
198
243
)
199
- for idx in range (len (data ))
200
- ]
201
- adjust_text (
202
- texts ,
203
- ax = ax ,
204
- # force_static=0.15,
205
- arrowprops = dict (arrowstyle = "->" , color = "gray" ),
206
- )
207
244
208
245
# ax.text(
209
246
# 0.6,
@@ -270,7 +307,8 @@ def plot_ling_dims(
270
307
y = dim ,
271
308
data = lingdf ,
272
309
ax = ax ,
273
- color = "green" ,
310
+ color = COLORS .get ("orange" ),
311
+ edgecolor = COLORS .get ("green" ),
274
312
width = 0.4 if dim == "Resource Availability" else 0.7 ,
275
313
)
276
314
ax .set_title (dim )
0 commit comments