Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions src/muse/evaluation/evaluate_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
Generate CSV containing MT metric scores for machine translation corpus.

This script processes a machine translation corpus (JSONL format) and computes
evaluation metrics (ChrF and COMET) for each translation. The output is a CSV
file with columns: tr_id, chrf, comet.
evaluation metrics (ChrF, COMET, and CometKiwi) for each translation. The output
is a CSV file with columns: tr_id, chrf, comet, cometkiwi.

Usage:
evaluate_corpus.py INPUT OUTPUT [--verbose]
Expand All @@ -18,7 +18,7 @@
import orjsonl
from tqdm import tqdm

from muse.evaluation.metrics import compute_chrf, compute_comet
from muse.evaluation.metrics import compute_chrf, compute_comet, compute_cometkiwi

logger = logging.getLogger(__name__)

Expand All @@ -31,17 +31,19 @@ def evaluate_corpus(
"""
Compute evaluation metrics for machine translation corpus and save to CSV.

Reads machine translation records from input JSONL file, computes ChrF and
COMET scores for each translation, and writes results to output CSV file
with columns: tr_id, chrf, comet.
Reads machine translation records from input JSONL file, computes ChrF,
COMET, and CometKiwi scores for each translation, and writes results to
output CSV file with columns: tr_id, chrf, comet, cometkiwi.
"""
# Count total records for progress bar
total_records = sum(1 for _ in orjsonl.stream(input_path))
logger.info(f"Found {total_records} translations to evaluate")

# Open output CSV file
with output_path.open("w", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=["tr_id", "chrf", "comet"])
writer = csv.DictWriter(
csvfile, fieldnames=["tr_id", "chrf", "comet", "cometkiwi"]
)
writer.writeheader()

# Process each translation record
Expand All @@ -62,13 +64,18 @@ def evaluate_corpus(
src_text=record["src_text"],
ref_text=record["ref_text"],
)
cometkiwi_score = compute_cometkiwi(
tr_text=record["tr_text"],
src_text=record["src_text"],
)

# Write to CSV
writer.writerow(
{
"tr_id": record["tr_id"],
"chrf": chrf_score,
"comet": comet_score,
"cometkiwi": cometkiwi_score,
}
)

Expand Down
49 changes: 48 additions & 1 deletion src/muse/evaluation/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import io
import logging
import os
from typing import Any

import evaluate
import torch
from comet import download_model, load_from_checkpoint

# Environment variable configuration for PyTorch and HuggingFace libraries
os.environ["TOKENIZERS_PARALLELISM"] = (
Expand All @@ -27,9 +29,10 @@

# Cache for loaded metrics to avoid reloading models
# Note: Caching COMET model requires ~2GB RAM for the wmt22-comet-da model
LOADED_METRICS = {
LOADED_METRICS: dict[str, Any] = {
"chrf": None,
"comet": None,
"cometkiwi": None,
}


Expand Down Expand Up @@ -95,3 +98,47 @@ def compute_comet(
score = result["mean_score"]

return score


def compute_cometkiwi(
    tr_text: str,
    src_text: str,
) -> float:
    """
    Compute CometKiwi score for a translation using the comet package.

    CometKiwi is a reference-free quality estimation metric that combines COMET
    with OpenKiwi. Unlike COMET, it does not require a reference translation and
    evaluates translation quality based only on the source text and machine
    translation.

    Parameters:
        tr_text: Machine-translated text to evaluate.
        src_text: Source-language text the translation was produced from.

    Returns:
        A float in the range [0, 1], where 0 indicates a poor translation
        and 1 indicates a perfect translation.

    Raises:
        RuntimeError: If the gated HuggingFace model cannot be downloaded
            (typically missing authentication / unaccepted license).
    """
    # Load the model once and cache it module-wide; the checkpoint is large,
    # so reloading per call would dominate runtime.
    if LOADED_METRICS["cometkiwi"] is None:
        try:
            model_path = download_model("Unbabel/wmt22-cometkiwi-da")
            LOADED_METRICS["cometkiwi"] = load_from_checkpoint(model_path)
        except KeyError as e:
            # NOTE(review): download_model reportedly wraps download failures
            # (including missing HF auth for this gated model) in KeyError —
            # confirm against the installed comet version.
            msg = (
                "Authentication required for CometKiwi model. "
                "Please:\n"
                "1. Visit https://huggingface.co/Unbabel/wmt22-cometkiwi-da and accept the license\n"
                "2. Run: hf auth login\n"
                "3. Enter your HuggingFace token when prompted"
            )
            raise RuntimeError(msg) from e

    model = LOADED_METRICS["cometkiwi"]
    # Use a hardware accelerator when one is available (CUDA or Apple MPS).
    gpus = 1 if (torch.cuda.is_available() or torch.backends.mps.is_available()) else 0

    # Prepare data in the format expected by CometKiwi (no "ref" key needed).
    data = [{"src": src_text, "mt": tr_text}]

    model_output = model.predict(data, batch_size=1, gpus=gpus)
    # BUG FIX: predict() returns a dict-like Prediction keyed by "scores" and
    # "system_score"; integer indexing (model_output[0]) does not return the
    # segment score. Use the documented `scores` list and take the single
    # segment-level score, coercing to float to match the annotated return.
    score = float(model_output.scores[0])

    return score
Loading