From a633a82e4ad47a5a11d52725d4ef02257e7ee0b5 Mon Sep 17 00:00:00 2001 From: Sheiphan Joseph Date: Mon, 26 May 2025 19:00:32 +0530 Subject: [PATCH] FEAT: Basic and Advanced Evaluation: WIP --- evaluate.py | 330 +++++++++++++++++++++++++++++++++++++ evaluate_advanced.py | 379 +++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 3 +- run_evaluation.py | 169 +++++++++++++++++++ 4 files changed, 880 insertions(+), 1 deletion(-) create mode 100644 evaluate.py create mode 100644 evaluate_advanced.py create mode 100644 run_evaluation.py diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..42c145e --- /dev/null +++ b/evaluate.py @@ -0,0 +1,330 @@ +import logging +import os +from functools import partial +from typing import List, Tuple, Dict, Any +import numpy as np +from collections import defaultdict + +import torch +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import AutoProcessor, Gemma3ForConditionalGeneration + +from config import Configuration +from utils import test_collate_function, parse_paligemma_label + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def parse_model_output(output_text: str, width: int, height: int) -> List[Dict[str, Any]]: + """ + Parse the model output to extract bounding boxes and categories. + + Args: + output_text: Raw model output text + width: Image width + height: Image height + + Returns: + List of dictionaries containing bbox coordinates and category + """ + predictions = [] + + # Split by semicolon to handle multiple detections + detections = output_text.split(';') + + for detection in detections: + detection = detection.strip() + if not detection: + continue + + try: + category, bbox = parse_paligemma_label(detection, width, height) + predictions.append({ + 'bbox': bbox, # [x1, y1, x2, y2] + 'category': category.strip(), + 'confidence': 1.0 # Model doesn't output confidence scores + }) + except Exception as e: + logger.warning(f"Failed to parse detection: {detection}, Error: {e}") + continue + + return predictions + + +def calculate_iou(box1: List[float], box2: List[float]) -> float: + """ + Calculate Intersection over Union (IoU) between two bounding boxes. + + Args: + box1: [x1, y1, x2, y2] + box2: [x1, y1, x2, y2] + + Returns: + IoU value between 0 and 1 + """ + x1_inter = max(box1[0], box2[0]) + y1_inter = max(box1[1], box2[1]) + x2_inter = min(box1[2], box2[2]) + y2_inter = min(box1[3], box2[3]) + + # Calculate intersection area + inter_width = max(0, x2_inter - x1_inter) + inter_height = max(0, y2_inter - y1_inter) + intersection = inter_width * inter_height + + # Calculate union area + area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) + area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) + union = area1 + area2 - intersection + + if union == 0: + return 0.0 + + return intersection / union + + +def convert_coco_to_xyxy(bbox: List[float]) -> List[float]: + """Convert COCO format [x, y, width, height] to [x1, y1, x2, y2]""" + x, y, w, h = bbox + return [x, y, x + w, y + h] + + +def calculate_ap(precisions: List[float], recalls: List[float]) -> float: + """ + Calculate Average Precision using the 11-point interpolation method. 
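+    Precision is interpolated at the 11 recall levels 0.0, 0.1, ..., 1.0: for each
+    level t, the maximum precision over all points with recall >= t is taken, and
+    the 11 values (0.0 where no point qualifies) are averaged. For example, a
+    single point with precision 1.0 at recall 0.5 contributes 1.0 for the six
+    levels t <= 0.5 and 0.0 above, giving AP = 6 / 11, roughly 0.545.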
+ + Args: + precisions: List of precision values + recalls: List of recall values + + Returns: + Average Precision value + """ + # Sort by recall + sorted_indices = np.argsort(recalls) + sorted_recalls = np.array(recalls)[sorted_indices] + sorted_precisions = np.array(precisions)[sorted_indices] + + # Use 11-point interpolation + ap = 0.0 + for t in np.arange(0, 1.1, 0.1): + # Find precisions for recalls >= t + valid_precisions = sorted_precisions[sorted_recalls >= t] + if len(valid_precisions) > 0: + ap += np.max(valid_precisions) + + return ap / 11.0 + + +def evaluate_detections(predictions: List[Dict], ground_truths: List[Dict], + iou_threshold: float = 0.5) -> Dict[str, float]: + """ + Evaluate detections for a single image. + + Args: + predictions: List of predicted detections + ground_truths: List of ground truth detections + iou_threshold: IoU threshold for considering a detection as correct + + Returns: + Dictionary containing TP, FP, FN counts + """ + true_positives = 0 + false_positives = 0 + matched_gt = set() + + # For each prediction, find the best matching ground truth + for pred in predictions: + best_iou = 0 + best_gt_idx = -1 + + for gt_idx, gt in enumerate(ground_truths): + if gt_idx in matched_gt: + continue + + # Only match if categories are the same + if pred['category'] != gt['category']: + continue + + iou = calculate_iou(pred['bbox'], gt['bbox']) + if iou > best_iou: + best_iou = iou + best_gt_idx = gt_idx + + if best_iou >= iou_threshold and best_gt_idx != -1: + true_positives += 1 + matched_gt.add(best_gt_idx) + else: + false_positives += 1 + + false_negatives = len(ground_truths) - len(matched_gt) + + return { + 'tp': true_positives, + 'fp': false_positives, + 'fn': false_negatives + } + + +def compute_metrics(all_results: List[Dict[str, int]]) -> Dict[str, float]: + """ + Compute overall precision, recall, and F1 score. 
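+    Counts are pooled over all images (micro-averaging):
+    precision = TP / (TP + FP), recall = TP / (TP + FN), and
+    F1 = 2 * precision * recall / (precision + recall), each falling back to 0.0
+    when its denominator is zero. For instance, totals of TP=3, FP=1, FN=2 give
+    precision 0.75, recall 0.60 and an F1 score of roughly 0.667.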
+ + Args: + all_results: List of dictionaries containing TP, FP, FN for each image + + Returns: + Dictionary containing precision, recall, and F1 score + """ + total_tp = sum(result['tp'] for result in all_results) + total_fp = sum(result['fp'] for result in all_results) + total_fn = sum(result['fn'] for result in all_results) + + precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0 + recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0 + f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 + + return { + 'precision': precision, + 'recall': recall, + 'f1_score': f1_score, + 'total_tp': total_tp, + 'total_fp': total_fp, + 'total_fn': total_fn + } + + +def get_dataloader(processor, cfg): + """Create test dataloader""" + test_dataset = load_dataset(cfg.dataset_id, split="test") + test_collate_fn = partial( + test_collate_function, processor=processor, dtype=cfg.dtype + ) + test_dataloader = DataLoader( + test_dataset, batch_size=1, collate_fn=test_collate_fn # Batch size 1 for evaluation + ) + return test_dataloader, test_dataset + + +def main(): + """Main evaluation function""" + cfg = Configuration() + + logger.info("Loading model and processor...") + processor = AutoProcessor.from_pretrained(cfg.checkpoint_id) + model = Gemma3ForConditionalGeneration.from_pretrained( + cfg.checkpoint_id, + torch_dtype=cfg.dtype, + device_map="cpu", + ) + model.eval() + model.to(cfg.device) + + logger.info("Creating test dataloader...") + test_dataloader, test_dataset = get_dataloader(processor=processor, cfg=cfg) + + all_results = [] + all_ious = [] + category_results = defaultdict(list) + + logger.info(f"Evaluating on {len(test_dataset)} test samples...") + + with torch.no_grad(): + for idx, (sample, sample_images) in enumerate(test_dataloader): + if idx >= len(test_dataset): + break + + sample = sample.to(cfg.device) + + # Get predictions from model + generation = model.generate(**sample, max_new_tokens=100) + decoded = processor.batch_decode(generation, skip_special_tokens=True) + + # Process each sample in the batch (batch size is 1) + for batch_idx, (output_text, sample_image) in enumerate(zip(decoded, sample_images)): + image = sample_image[0] + width, height = image.size + + # Get ground truth from original dataset + gt_sample = test_dataset[idx] + ground_truths = [] + + # Convert ground truth bounding boxes + for bbox, category in zip(gt_sample['objects']['bbox'], gt_sample['objects']['category']): + gt_bbox = convert_coco_to_xyxy(bbox) + ground_truths.append({ + 'bbox': gt_bbox, + 'category': 'plate' # Assuming all objects are plates + }) + + # Parse model predictions + predictions = parse_model_output(output_text, width, height) + + # Evaluate detections + result = evaluate_detections(predictions, ground_truths) + all_results.append(result) + + # Calculate IoUs for matched detections + for pred in predictions: + for gt in ground_truths: + if pred['category'] == gt['category']: + iou = calculate_iou(pred['bbox'], gt['bbox']) + all_ious.append(iou) + category_results[pred['category']].append(iou) + + if (idx + 1) % 50 == 0: + logger.info(f"Processed {idx + 1}/{len(test_dataset)} samples") + + # Compute overall metrics + logger.info("Computing final metrics...") + metrics = compute_metrics(all_results) + + # Compute mAP (simplified version) + # For a more accurate mAP, we would need confidence scores and multiple IoU thresholds + map_50 = metrics['precision'] # Simplified mAP@0.5 + + # Compute average 
IoU + avg_iou = np.mean(all_ious) if all_ious else 0.0 + + # Print results + print("\n" + "="*60) + print("OBJECT DETECTION EVALUATION RESULTS") + print("="*60) + print(f"Dataset: {cfg.dataset_id}") + print(f"Model: {cfg.checkpoint_id}") + print(f"Test samples: {len(test_dataset)}") + print("-"*60) + print(f"Precision: {metrics['precision']:.4f}") + print(f"Recall: {metrics['recall']:.4f}") + print(f"F1 Score: {metrics['f1_score']:.4f}") + print(f"mAP@0.5: {map_50:.4f}") + print(f"Average IoU: {avg_iou:.4f}") + print("-"*60) + print(f"True Positives: {metrics['total_tp']}") + print(f"False Positives: {metrics['total_fp']}") + print(f"False Negatives: {metrics['total_fn']}") + print("="*60) + + # Category-wise IoU + if category_results: + print("\nCategory-wise Average IoU:") + for category, ious in category_results.items(): + avg_cat_iou = np.mean(ious) + print(f" {category}: {avg_cat_iou:.4f}") + + return { + 'precision': metrics['precision'], + 'recall': metrics['recall'], + 'f1_score': metrics['f1_score'], + 'map_50': map_50, + 'avg_iou': avg_iou, + 'detailed_metrics': metrics + } + + +if __name__ == "__main__": + results = main() \ No newline at end of file diff --git a/evaluate_advanced.py b/evaluate_advanced.py new file mode 100644 index 0000000..0728eb4 --- /dev/null +++ b/evaluate_advanced.py @@ -0,0 +1,379 @@ +import logging +import os +from functools import partial +from typing import List, Tuple, Dict, Any +import numpy as np +from collections import defaultdict + +import torch +from datasets import load_dataset +from torch.utils.data import DataLoader +from transformers import AutoProcessor, Gemma3ForConditionalGeneration + +from config import Configuration +from utils import test_collate_function, parse_paligemma_label + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def parse_model_output(output_text: str, width: int, height: int) -> List[Dict[str, Any]]: + """Parse the model output to extract bounding boxes and categories.""" + predictions = [] + + # Split by semicolon to handle multiple detections + detections = output_text.split(';') + + for detection in detections: + detection = detection.strip() + if not detection: + continue + + try: + category, bbox = parse_paligemma_label(detection, width, height) + predictions.append({ + 'bbox': bbox, # [x1, y1, x2, y2] + 'category': category.strip(), + 'confidence': 1.0 # Model doesn't output confidence scores + }) + except Exception as e: + logger.warning(f"Failed to parse detection: {detection}, Error: {e}") + continue + + return predictions + + +def calculate_iou(box1: List[float], box2: List[float]) -> float: + """Calculate Intersection over Union (IoU) between two bounding boxes.""" + x1_inter = max(box1[0], box2[0]) + y1_inter = max(box1[1], box2[1]) + x2_inter = min(box1[2], box2[2]) + y2_inter = min(box1[3], box2[3]) + + # Calculate intersection area + inter_width = max(0, x2_inter - x1_inter) + inter_height = max(0, y2_inter - y1_inter) + intersection = inter_width * inter_height + + # Calculate union area + area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) + area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) + union = area1 + area2 - intersection + + if union == 0: + return 0.0 + + return intersection / union + + +def convert_coco_to_xyxy(bbox: List[float]) -> List[float]: + """Convert COCO format [x, y, width, height] to [x1, y1, x2, y2]""" + x, y, w, h = bbox + return [x, y, x + w, y + h] + + +def 
calculate_ap_interpolated(precisions: np.ndarray, recalls: np.ndarray) -> float: + """ + Calculate Average Precision using interpolation method (COCO style). + """ + # Sort by recall + sorted_indices = np.argsort(recalls) + sorted_recalls = recalls[sorted_indices] + sorted_precisions = precisions[sorted_indices] + + # Add points at recall 0 and 1 + recalls_interp = np.concatenate(([0], sorted_recalls, [1])) + precisions_interp = np.concatenate(([0], sorted_precisions, [0])) + + # Make precision monotonically decreasing + for i in range(len(precisions_interp) - 2, -1, -1): + precisions_interp[i] = max(precisions_interp[i], precisions_interp[i + 1]) + + # Calculate AP as area under curve + ap = 0.0 + for i in range(1, len(recalls_interp)): + ap += (recalls_interp[i] - recalls_interp[i - 1]) * precisions_interp[i] + + return ap + + +def evaluate_at_iou_threshold(all_predictions: List[List[Dict]], + all_ground_truths: List[List[Dict]], + iou_threshold: float, + category: str = None) -> Dict[str, float]: + """ + Evaluate detections at a specific IoU threshold. + + Args: + all_predictions: List of prediction lists for each image + all_ground_truths: List of ground truth lists for each image + iou_threshold: IoU threshold for considering a detection as correct + category: Specific category to evaluate (None for all categories) + + Returns: + Dictionary containing evaluation metrics + """ + all_scores = [] # confidence scores + all_matches = [] # whether each detection is TP or FP + total_gt = 0 + + for predictions, ground_truths in zip(all_predictions, all_ground_truths): + # Filter by category if specified + if category: + predictions = [p for p in predictions if p['category'] == category] + ground_truths = [gt for gt in ground_truths if gt['category'] == category] + + total_gt += len(ground_truths) + + # Sort predictions by confidence (descending) + predictions_sorted = sorted(predictions, key=lambda x: x['confidence'], reverse=True) + + matched_gt = set() + + for pred in predictions_sorted: + all_scores.append(pred['confidence']) + + best_iou = 0 + best_gt_idx = -1 + + for gt_idx, gt in enumerate(ground_truths): + if gt_idx in matched_gt: + continue + + # Only match if categories are the same (or we're not filtering by category) + if category is None and pred['category'] != gt['category']: + continue + + iou = calculate_iou(pred['bbox'], gt['bbox']) + if iou > best_iou: + best_iou = iou + best_gt_idx = gt_idx + + if best_iou >= iou_threshold and best_gt_idx != -1: + all_matches.append(1) # True Positive + matched_gt.add(best_gt_idx) + else: + all_matches.append(0) # False Positive + + if not all_scores: + return {'ap': 0.0, 'precision': 0.0, 'recall': 0.0} + + # Convert to numpy arrays + scores = np.array(all_scores) + matches = np.array(all_matches) + + # Sort by confidence scores (descending) + sorted_indices = np.argsort(-scores) + matches_sorted = matches[sorted_indices] + + # Calculate cumulative TP and FP + tp_cumsum = np.cumsum(matches_sorted) + fp_cumsum = np.cumsum(1 - matches_sorted) + + # Calculate precision and recall at each point + precisions = tp_cumsum / (tp_cumsum + fp_cumsum) + recalls = tp_cumsum / total_gt if total_gt > 0 else np.zeros_like(tp_cumsum) + + # Calculate Average Precision + ap = calculate_ap_interpolated(precisions, recalls) + + # Final precision and recall + final_precision = precisions[-1] if len(precisions) > 0 else 0.0 + final_recall = recalls[-1] if len(recalls) > 0 else 0.0 + + return { + 'ap': ap, + 'precision': final_precision, + 'recall': final_recall, 
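+        # Raw counts at the final operating point, i.e. with every prediction
+        # accepted: cumulative TP/FP at the lowest confidence rank, plus the
+        # number of ground-truth boxes considered.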
+ 'total_tp': int(tp_cumsum[-1]) if len(tp_cumsum) > 0 else 0, + 'total_fp': int(fp_cumsum[-1]) if len(fp_cumsum) > 0 else 0, + 'total_gt': total_gt + } + + +def calculate_comprehensive_metrics(all_predictions: List[List[Dict]], + all_ground_truths: List[List[Dict]]) -> Dict[str, Any]: + """ + Calculate comprehensive evaluation metrics including mAP at multiple IoU thresholds. + + Args: + all_predictions: List of prediction lists for each image + all_ground_truths: List of ground truth lists for each image + + Returns: + Dictionary containing comprehensive metrics + """ + iou_thresholds = np.arange(0.5, 1.0, 0.05) # [0.5, 0.55, 0.6, ..., 0.95] + + # Get all unique categories + all_categories = set() + for predictions in all_predictions: + for pred in predictions: + all_categories.add(pred['category']) + for ground_truths in all_ground_truths: + for gt in ground_truths: + all_categories.add(gt['category']) + + results = {} + + # Calculate metrics for each IoU threshold + ap_per_iou = [] + for iou_thresh in iou_thresholds: + metrics = evaluate_at_iou_threshold(all_predictions, all_ground_truths, iou_thresh) + ap_per_iou.append(metrics['ap']) + + if iou_thresh == 0.5: # Store detailed metrics for IoU=0.5 + results['metrics_at_50'] = metrics + if iou_thresh == 0.75: # Store detailed metrics for IoU=0.75 + results['metrics_at_75'] = metrics + + # Calculate mAP (mean over IoU thresholds) + results['mAP'] = np.mean(ap_per_iou) + results['mAP_50'] = ap_per_iou[0] # AP at IoU=0.5 + results['mAP_75'] = ap_per_iou[5] if len(ap_per_iou) > 5 else 0.0 # AP at IoU=0.75 + + # Category-wise metrics at IoU=0.5 + category_metrics = {} + for category in all_categories: + cat_metrics = evaluate_at_iou_threshold(all_predictions, all_ground_truths, 0.5, category) + category_metrics[category] = cat_metrics + + results['category_metrics'] = category_metrics + results['iou_thresholds'] = iou_thresholds.tolist() + results['ap_per_iou'] = ap_per_iou + + return results + + +def get_dataloader(processor, cfg): + """Create test dataloader""" + test_dataset = load_dataset(cfg.dataset_id, split="test") + test_collate_fn = partial( + test_collate_function, processor=processor, dtype=cfg.dtype + ) + test_dataloader = DataLoader( + test_dataset, batch_size=1, collate_fn=test_collate_fn + ) + return test_dataloader, test_dataset + + +def main(): + """Main evaluation function with comprehensive metrics""" + cfg = Configuration() + + logger.info("Loading model and processor...") + processor = AutoProcessor.from_pretrained(cfg.checkpoint_id) + model = Gemma3ForConditionalGeneration.from_pretrained( + cfg.checkpoint_id, + torch_dtype=cfg.dtype, + device_map="cpu", + ) + model.eval() + model.to(cfg.device) + + logger.info("Creating test dataloader...") + test_dataloader, test_dataset = get_dataloader(processor=processor, cfg=cfg) + + all_predictions = [] + all_ground_truths = [] + all_ious = [] + + logger.info(f"Evaluating on {len(test_dataset)} test samples...") + + with torch.no_grad(): + for idx, (sample, sample_images) in enumerate(test_dataloader): + if idx >= len(test_dataset): + break + + sample = sample.to(cfg.device) + print(sample) + # Get predictions from model + generation = model.generate(**sample, max_new_tokens=100) + decoded = processor.batch_decode(generation, skip_special_tokens=True) + + # Process each sample in the batch (batch size is 1) + for batch_idx, (output_text, sample_image) in enumerate(zip(decoded, sample_images)): + image = sample_image[0] + width, height = image.size + + # Get ground truth from 
original dataset + gt_sample = test_dataset[idx] + ground_truths = [] + + # Convert ground truth bounding boxes + for bbox, category in zip(gt_sample['objects']['bbox'], gt_sample['objects']['category']): + gt_bbox = convert_coco_to_xyxy(bbox) + ground_truths.append({ + 'bbox': gt_bbox, + 'category': 'plate' # Assuming all objects are plates + }) + + # Parse model predictions + predictions = parse_model_output(output_text, width, height) + + all_predictions.append(predictions) + all_ground_truths.append(ground_truths) + + # Calculate IoUs for all predictions + for pred in predictions: + for gt in ground_truths: + if pred['category'] == gt['category']: + iou = calculate_iou(pred['bbox'], gt['bbox']) + all_ious.append(iou) + + if (idx + 1) % 50 == 0: + logger.info(f"Processed {idx + 1}/{len(test_dataset)} samples") + + # Calculate comprehensive metrics + logger.info("Computing comprehensive metrics...") + results = calculate_comprehensive_metrics(all_predictions, all_ground_truths) + + # Calculate average IoU + avg_iou = np.mean(all_ious) if all_ious else 0.0 + + # Print results + print("\n" + "="*70) + print("COMPREHENSIVE OBJECT DETECTION EVALUATION RESULTS") + print("="*70) + print(f"Dataset: {cfg.dataset_id}") + print(f"Model: {cfg.checkpoint_id}") + print(f"Test samples: {len(test_dataset)}") + print("-"*70) + print("COCO-style mAP Metrics:") + print(f" mAP (IoU 0.5:0.95): {results['mAP']:.4f}") + print(f" mAP@0.5: {results['mAP_50']:.4f}") + print(f" mAP@0.75: {results['mAP_75']:.4f}") + print("-"*70) + + if 'metrics_at_50' in results: + metrics_50 = results['metrics_at_50'] + print("Metrics at IoU=0.5:") + print(f" Precision: {metrics_50['precision']:.4f}") + print(f" Recall: {metrics_50['recall']:.4f}") + print(f" True Positives: {metrics_50['total_tp']}") + print(f" False Positives: {metrics_50['total_fp']}") + print(f" Total Ground Truth: {metrics_50['total_gt']}") + + print(f"\nAverage IoU: {avg_iou:.4f}") + print("-"*70) + + # Category-wise results + if results['category_metrics']: + print("Category-wise AP@0.5:") + for category, metrics in results['category_metrics'].items(): + print(f" {category:12}: AP={metrics['ap']:.4f}, " + f"P={metrics['precision']:.4f}, R={metrics['recall']:.4f}") + + print("="*70) + + # Save results to file + results['avg_iou'] = avg_iou + results['total_samples'] = len(test_dataset) + + return results + + +if __name__ == "__main__": + results = main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 35c7bbf..20519bf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ huggingface_hub[hf_xet] bitsandbytes wandb peft -albumentations \ No newline at end of file +albumentations +numpy \ No newline at end of file diff --git a/run_evaluation.py b/run_evaluation.py new file mode 100644 index 0000000..65f2b01 --- /dev/null +++ b/run_evaluation.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +""" +Evaluation runner script for Gemma 3 Object Detection model. 
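+
+It wraps evaluate.py (basic precision/recall/F1 metrics) and
+evaluate_advanced.py (COCO-style mAP averaged over IoU thresholds 0.50:0.95),
+and can save the returned metrics to JSON.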
+ +Usage: + python run_evaluation.py --mode basic # Basic evaluation (faster) + python run_evaluation.py --mode advanced # Comprehensive COCO-style evaluation + python run_evaluation.py --help # Show help +""" + +import argparse +import sys +import json +import time +from pathlib import Path + + +def run_basic_evaluation(): + """Run basic evaluation with simple metrics.""" + print("Running basic evaluation...") + try: + import evaluate + results = evaluate.main() + return results + except Exception as e: + print(f"Error running basic evaluation: {e}") + return None + + +def run_advanced_evaluation(): + """Run advanced evaluation with comprehensive COCO-style metrics.""" + print("Running advanced evaluation with COCO-style metrics...") + try: + import evaluate_advanced + results = evaluate_advanced.main() + return results + except Exception as e: + print(f"Error running advanced evaluation: {e}") + return None + + +def save_results(results, output_file): + """Save evaluation results to JSON file.""" + if results is None: + print("No results to save.") + return + + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Convert numpy arrays to lists for JSON serialization + json_results = {} + for key, value in results.items(): + if hasattr(value, 'tolist'): + json_results[key] = value.tolist() + elif isinstance(value, dict): + json_results[key] = {} + for k, v in value.items(): + if hasattr(v, 'tolist'): + json_results[key][k] = v.tolist() + else: + json_results[key][k] = v + else: + json_results[key] = value + + with open(output_path, 'w') as f: + json.dump(json_results, f, indent=2) + + print(f"Results saved to: {output_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate Gemma 3 Object Detection model", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python run_evaluation.py --mode basic + python run_evaluation.py --mode advanced --output results/eval_results.json + python run_evaluation.py --mode both --output results/ + """ + ) + + parser.add_argument( + '--mode', + choices=['basic', 'advanced', 'both'], + default='basic', + help='Evaluation mode: basic (fast), advanced (comprehensive), or both' + ) + + parser.add_argument( + '--output', + type=str, + default='evaluation_results.json', + help='Output file/directory for results (default: evaluation_results.json)' + ) + + parser.add_argument( + '--verbose', + action='store_true', + help='Verbose output' + ) + + args = parser.parse_args() + + print("="*60) + print("GEMMA 3 OBJECT DETECTION EVALUATION") + print("="*60) + print(f"Mode: {args.mode}") + print(f"Output: {args.output}") + print("-"*60) + + start_time = time.time() + + if args.mode == 'basic': + results = run_basic_evaluation() + if results: + save_results(results, args.output) + + elif args.mode == 'advanced': + results = run_advanced_evaluation() + if results: + save_results(results, args.output) + + elif args.mode == 'both': + output_dir = Path(args.output) + if output_dir.is_file(): + output_dir = output_dir.parent + + print("\n" + "="*60) + print("RUNNING BASIC EVALUATION") + print("="*60) + basic_results = run_basic_evaluation() + if basic_results: + save_results(basic_results, output_dir / 'basic_evaluation.json') + + print("\n" + "="*60) + print("RUNNING ADVANCED EVALUATION") + print("="*60) + advanced_results = run_advanced_evaluation() + if advanced_results: + save_results(advanced_results, output_dir / 'advanced_evaluation.json') + + # Create summary comparison + if 
basic_results and advanced_results: + summary = { + 'basic_metrics': { + 'precision': basic_results.get('precision', 0), + 'recall': basic_results.get('recall', 0), + 'f1_score': basic_results.get('f1_score', 0), + 'avg_iou': basic_results.get('avg_iou', 0) + }, + 'advanced_metrics': { + 'mAP': advanced_results.get('mAP', 0), + 'mAP_50': advanced_results.get('mAP_50', 0), + 'mAP_75': advanced_results.get('mAP_75', 0), + 'avg_iou': advanced_results.get('avg_iou', 0) + }, + 'evaluation_time': time.time() - start_time + } + save_results(summary, output_dir / 'evaluation_summary.json') + + total_time = time.time() - start_time + print(f"\nTotal evaluation time: {total_time:.2f} seconds") + print("Evaluation completed!") + + +if __name__ == "__main__": + main() \ No newline at end of file