@@ -26,6 +26,7 @@ def compute_wer_metrics(pred: EvalPrediction, processor: Processor) -> dict[str,
             dictionary with 'wer' as the key and the word error rate as the value.
     """
     wer_metric = load_metric("wer")
+    cer_metric = load_metric("cer")
     tokenizer: PreTrainedTokenizerBase = getattr(processor, "tokenizer")
     pad_token = tokenizer.pad_token_id
 
@@ -84,13 +85,26 @@ def compute_wer_metrics(pred: EvalPrediction, processor: Processor) -> dict[str,
         logger.info(f"Sample document: {labels_str[random_idx]}")
         logger.info(f"Predicted: {predictions_str[random_idx]}")
 
-    # Compute the word error rate
-    computed = wer_metric.compute(predictions=predictions_str, references=labels_str)
-    assert computed is not None
+    metrics: dict[str, float] = dict()
 
-    # Ensure that `wer` is a dict, as metrics in the `evaluate` library can either be
-    # dicts or floats
-    if not isinstance(computed, dict):
-        return dict(wer=computed)
+    # Compute the word error rate
+    wer_computed = wer_metric.compute(
+        predictions=predictions_str, references=labels_str
+    )
+    assert wer_computed is not None
+    if not isinstance(wer_computed, dict):
+        metrics = metrics | dict(wer=wer_computed)
     else:
-        return computed
+        metrics = metrics | wer_computed
+
+    # Compute the character error rate
+    cer_computed = cer_metric.compute(
+        predictions=predictions_str, references=labels_str
+    )
+    assert cer_computed is not None
+    if not isinstance(cer_computed, dict):
+        metrics = metrics | dict(cer=cer_computed)
+    else:
+        metrics = metrics | cer_computed
+
+    return metrics
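For context, here is a minimal standalone sketch of the metric logic this diff implements. It assumes that `load_metric` wraps Hugging Face's `evaluate.load` (the import is not visible in these hunks, but the deleted comment refers to the `evaluate` library) and that the `jiwer` package is installed, which the WER and CER metrics require. The prediction and reference strings are made up for illustration.

```python
# Standalone sketch of the WER + CER computation above. Assumption:
# `load_metric` resolves to `evaluate.load` (pip install evaluate jiwer).
import evaluate

wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

# Made-up predictions and references, purely for illustration
predictions_str = ["hello wrold", "the cat sat on mat"]
labels_str = ["hello world", "the cat sat on the mat"]

metrics: dict[str, float] = dict()

# WER/CER return a bare float, but other `evaluate` metrics return dicts,
# which is why the diff branches on isinstance(..., dict)
wer_computed = wer_metric.compute(predictions=predictions_str, references=labels_str)
assert wer_computed is not None
if not isinstance(wer_computed, dict):
    metrics = metrics | dict(wer=wer_computed)
else:
    metrics = metrics | wer_computed

cer_computed = cer_metric.compute(predictions=predictions_str, references=labels_str)
assert cer_computed is not None
if not isinstance(cer_computed, dict):
    metrics = metrics | dict(cer=cer_computed)
else:
    metrics = metrics | cer_computed

print(metrics)  # -> {'wer': <float>, 'cer': <float>}
```

Note that the `metrics | ...` dict-merge operator requires Python 3.9+ (PEP 584), as does the built-in generic annotation `dict[str, float]` (PEP 585). Since the function takes a `transformers.EvalPrediction`, it is presumably passed to a `Trainer` as `compute_metrics`, e.g. via `functools.partial(compute_wer_metrics, processor=processor)`.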