@@ -1,15 +1,13 @@
 # type: ignore
 # pylint: disable=no-value-for-parameter,import-outside-toplevel,import-error
 
-from kfp.dsl import Artifact, Output, component
+from kfp.dsl import Artifact, Input, Metrics, Output, component
 
-from utils.consts import RHELAI_IMAGE
+from utils.consts import PYTHON_IMAGE, RHELAI_IMAGE
 
 
 @component(base_image=RHELAI_IMAGE, install_kfp_package=False)
 def run_final_eval_op(
-    mmlu_branch_output: Output[Artifact],
-    mt_bench_branch_output: Output[Artifact],
     base_model_dir: str,
     base_branch: str,
     candidate_branch: str,
@@ -20,6 +18,8 @@ def run_final_eval_op(
     candidate_model: str = None,
     taxonomy_path: str = "/input/taxonomy",
     sdg_path: str = "/input/sdg",
+    mmlu_branch_output_path: str = "/output/mmlu_branch",
+    mt_bench_branch_output_path: str = "/output/mt_bench_branch",
 ):
     import json
     import os
@@ -326,8 +326,13 @@ def find_node_dataset_directories(base_dir: str):
             "summary": summary,
         }
 
-        with open(mmlu_branch_output.path, "w", encoding="utf-8") as f:
+        if not os.path.exists(mmlu_branch_output_path):
+            os.makedirs(mmlu_branch_output_path)
+        with open(
+            f"{mmlu_branch_output_path}/mmlu_branch_data.json", "w", encoding="utf-8"
+        ) as f:
             json.dump(mmlu_branch_data, f, indent=4)
+
     else:
         print("No MMLU tasks directories found, skipping MMLU_branch evaluation.")
 
@@ -470,5 +475,41 @@ def find_node_dataset_directories(base_dir: str):
         "summary": summary,
     }
 
-    with open(mt_bench_branch_output.path, "w", encoding="utf-8") as f:
+    if not os.path.exists(mt_bench_branch_output_path):
+        os.makedirs(mt_bench_branch_output_path)
+    with open(
+        f"{mt_bench_branch_output_path}/mt_bench_branch_data.json",
+        "w",
+        encoding="utf-8",
+    ) as f:
         json.dump(mt_bench_branch_data, f, indent=4)
+
+
+@component(base_image=PYTHON_IMAGE, install_kfp_package=False)
+def generate_metrics_report_op(
+    metrics: Output[Metrics],
+):
+    import ast
+    import json
+
+    with open("/output/mt_bench_data.json", "r") as f:
+        mt_bench_data = f.read()
+    mt_bench_data = ast.literal_eval(mt_bench_data)[0]
+
+    metrics.log_metric("mt_bench_best_model", mt_bench_data["model"])
+    metrics.log_metric("mt_bench_best_score", mt_bench_data["overall_score"])
+    metrics.log_metric("mt_bench_best_model_error_rate", mt_bench_data["error_rate"])
+
+    with open("/output/mt_bench_branch/mt_bench_branch_data.json", "r") as f:
+        mt_bench_branch_data = json.loads(f.read())
+
+    metrics.log_metric("mt_bench_branch_score", mt_bench_branch_data["overall_score"])
+    metrics.log_metric(
+        "mt_bench_branch_base_score", mt_bench_branch_data["base_overall_score"]
+    )
+
+    with open("/output/mmlu_branch/mmlu_branch_data.json", "r") as f:
+        mmlu_branch_data = json.loads(f.read())
+
+    metrics.log_metric("mmlu_branch_score", mmlu_branch_data["model_score"])
+    metrics.log_metric("mmlu_branch_base_score", mmlu_branch_data["base_model_score"])