diff --git a/configs/default_config.yaml b/configs/default_config.yaml
index 5873222..a672abe 100644
--- a/configs/default_config.yaml
+++ b/configs/default_config.yaml
@@ -25,7 +25,7 @@ ball_tracking:
   data: # shared between dataset and datamodule
     frame_size: [1280, 720] # [width, height]
     num_frame: 3
-    frame_stride: 1 # Stride between frames within a sequence
+    frame_stride: 3 # Stride between frames within a sequence
     sequence_stride: 30 # Stride between different sequences
     mag: 1.0
     sigma: 2.5
@@ -73,11 +73,12 @@ ball_tracking:
     epochs: 100
     learning_rate: 0.001
     tolerance: 4
-    save_dir: "exp-stride=1-weighted_msee"
+    save_dir: "exp-stride=3-weighted_msee"
     debug: false
-    data_dir: "${interim_data_path}/ball_tracking_dataset-stride-1"
+    data_dir: "${interim_data_path}/ball_tracking_dataset"
     devices: [0]
-    early_stop_patience: 30
+
+    early_stop_patience: 25
 
     # Loss configuration
     main_loss: "weighted_mse" # choices: weighted_bce, focal_wbce, kl_div, bce
@@ -161,8 +162,8 @@ create_ground_truth_mot:
 
 create_yolo_dataset:
   video_path: "/data/share/teamtrack/teamtrack-mot/soccer_side/full/combined.mp4"
-  mot_path: "/data/share/teamtrack/teamtrack-mot/soccer_side/full/gt.txt"
-  output_dir: "/home/atom/SoccerTrack-v2/data/v1"
+  mot_path: "/data/share/teamtrack/teamtrack-mot/soccer_side/full/combined.txt"
+  output_dir: "/home/nakamura/desktop/playbox/ball_detection/TrackNetV3/SoccerTrack-v2/data/interim/yolo_dataset"
   frame_interval: 1000
   train_split: 0.8
   val_split: 0.1
diff --git a/src/ball_tracking/tracknetx/data_module.py b/src/ball_tracking/tracknetx/data_module.py
index 4d5ae54..66cd5e9 100644
--- a/src/ball_tracking/tracknetx/data_module.py
+++ b/src/ball_tracking/tracknetx/data_module.py
@@ -9,8 +9,8 @@ from dataclasses import dataclass
 
 from omegaconf import DictConfig
 
-from src.ball_tracking.tracknetx.dataset import TrackNetX_Dataset
-from src.ball_tracking.tracknetx.data_transforms import RandomCrop, RandomHorizontalFlip, Resize
+from dataset import TrackNetX_Dataset
+from data_transforms import RandomCrop, RandomHorizontalFlip, Resize
 
 
 def collate_fn(batch):
@@ -276,4 +276,4 @@ def test_dataloader(self):
         # Save each sample to a separate file
         plt.savefig("dataset_visualization_dm.jpg", bbox_inches="tight", dpi=300)
 
-    logger.info(f"Saved to: {Path.cwd() / 'dataset_visualization_dm.jpg'}")
+    logger.info(f"Saved to: {Path.cwd() / 'dataset_visualization_dm.jpg'}")
\ No newline at end of file
diff --git a/src/ball_tracking/tracknetx/data_transforms.py b/src/ball_tracking/tracknetx/data_transforms.py
index e832e4a..c4379d1 100644
--- a/src/ball_tracking/tracknetx/data_transforms.py
+++ b/src/ball_tracking/tracknetx/data_transforms.py
@@ -194,4 +194,4 @@ def __call__(self, frames, heatmaps, coordinates):
         resized_coordinates[:, 0] *= width_ratio
         resized_coordinates[:, 1] *= height_ratio
 
-        return resized_frames, resized_heatmaps, resized_coordinates
+        return resized_frames, resized_heatmaps, resized_coordinates
\ No newline at end of file
diff --git a/src/ball_tracking/tracknetx/evaluate.py b/src/ball_tracking/tracknetx/evaluate.py
new file mode 100644
index 0000000..88b33bc
--- /dev/null
+++ b/src/ball_tracking/tracknetx/evaluate.py
@@ -0,0 +1,212 @@
+import numpy as np
+import torch
+import cv2
+from pathlib import Path
+from tqdm import tqdm
+from sklearn.metrics import precision_score, recall_score, average_precision_score
+
+from model import TrackNetXModel
+
+def load_model(checkpoint_path: str, device: str = "cpu"):
+    """Load the trained model from a checkpoint."""
+    model = TrackNetXModel.load_from_checkpoint(checkpoint_path)
+    model.eval()
+    model.to(device)
+    return model
+
+def preprocess_frames(frames):
+    """Preprocess frames to match the input format of the model."""
+    processed = []
+    for frame in frames:
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        frame = frame.astype(np.float32) / 255.0
+        frame = frame.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
+        processed.append(frame)
+    input_tensor = np.concatenate(processed, axis=0)  # (9, H, W) for 3 RGB frames
+    return torch.tensor(input_tensor, dtype=torch.float32).unsqueeze(0)  # (1, 9, H, W)
+
+def draw_predictions(image, prediction, ground_truth, visibility, save_path):
+    """
+    Draw predictions, ground truth, and visibility on the image and save it.
+
+    Args:
+        image: Original image.
+        prediction: Predicted (x, y) coordinates.
+        ground_truth: Ground truth (x, y) coordinates.
+        visibility: Whether the ball is visible in the frame.
+        save_path: Path to save the annotated image.
+    """
+    img_with_overlay = image.copy()
+    font = cv2.FONT_HERSHEY_SIMPLEX
+    font_scale = 0.7
+    thickness = 2
+
+    # Draw predicted position
+    pred_x, pred_y = prediction
+    cv2.circle(img_with_overlay, (int(pred_x), int(pred_y)), 5, (0, 255, 0), -1)  # Green circle
+    cv2.putText(
+        img_with_overlay,
+        f"Pred: ({pred_x:.1f}, {pred_y:.1f})",
+        (10, 30),
+        font,
+        font_scale,
+        (0, 255, 0),
+        thickness,
+    )
+
+    # Draw ground truth position
+    gt_x, gt_y = ground_truth
+    cv2.circle(img_with_overlay, (int(gt_x), int(gt_y)), 5, (0, 0, 255), -1)  # Red circle
+    cv2.putText(
+        img_with_overlay,
+        f"GT: ({gt_x:.1f}, {gt_y:.1f})",
+        (10, 60),
+        font,
+        font_scale,
+        (0, 0, 255),
+        thickness,
+    )
+
+    # Draw visibility status
+    vis_text = "Visible" if visibility else "Not Visible"
+    vis_color = (0, 255, 0) if visibility else (0, 0, 255)
+    cv2.putText(
+        img_with_overlay,
+        f"Visibility: {vis_text}",
+        (10, 90),
+        font,
+        font_scale,
+        vis_color,
+        thickness,
+    )
+
+    # Save the image
+    cv2.imwrite(str(save_path), img_with_overlay)
+
+def evaluate_model(sequences, coordinates, visibility, model, output_dir, device):
+    """Evaluate the model and calculate metrics."""
+    predictions = []
+    pred_visibilities = []
+    ground_truths = []
+    visibilities = []
+    ap_scores_50 = []
+    ap_scores_50_95 = []
+
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Score thresholds used as a stand-in for IoU thresholds: single-pixel targets
+    # have no box IoU, so the heatmap is binarized at each score level instead
+    score_thresholds = np.linspace(0.5, 0.95, 10)
+    detection_threshold = 0.5  # Assumed cutoff: peak score above this counts as "ball detected"
+
+    for seq_idx, sequence in tqdm(enumerate(sequences), total=len(sequences), desc="Evaluating sequences"):
+        frames = [cv2.imread(str(frame_path)) for frame_path in sequence]
+        input_tensor = preprocess_frames(frames).to(device)
+
+        # Run prediction
+        with torch.no_grad():
+            output = model(input_tensor)
+            output_prob = torch.sigmoid(output).squeeze(0).cpu().numpy()
+
+        # Get predicted coordinates (argmax of heatmap)
+        pred_heatmap = output_prob[1]  # Use the middle frame's heatmap
+        pred_y, pred_x = np.unravel_index(np.argmax(pred_heatmap), pred_heatmap.shape)
+        predictions.append((pred_x, pred_y))
+        pred_visibilities.append(pred_heatmap.max() >= detection_threshold)
+
+        # Add ground truth and visibility
+        ground_truths.append(tuple(coordinates[seq_idx, 1]))  # Middle frame's ground truth
+        visibilities.append(visibility[seq_idx, 1])  # Middle frame's visibility
+
+        # Prepare binary ground truth heatmap
+        gt_heatmap = np.zeros_like(pred_heatmap)
+        if visibility[seq_idx, 1]:  # Only mark visible frames
+            gt_x, gt_y = coordinates[seq_idx, 1]
+            gt_x, gt_y = int(gt_x), int(gt_y)
+            if 0 <= gt_x < gt_heatmap.shape[1] and 0 <= gt_y < gt_heatmap.shape[0]:
+                gt_heatmap[gt_y, gt_x] = 1  # Set ground truth pixel to 1
+
+        # AP is undefined when y_true has no positives, so skip frames without
+        # a visible, in-bounds ground truth pixel
+        if gt_heatmap.any():
+            # Flatten ground truth and predicted heatmaps for mAP calculation
+            y_true = gt_heatmap.flatten()
+            y_score = pred_heatmap.flatten()
+
+            # Calculate AP at each score threshold
+            ap_per_threshold = []
+            for score_thresh in score_thresholds:
+                ap_score = average_precision_score(y_true, y_score >= score_thresh)
+                ap_per_threshold.append(ap_score)
+            ap_scores_50.append(ap_per_threshold[0])  # AP at threshold 0.5
+            ap_scores_50_95.append(np.mean(ap_per_threshold))  # Mean AP over thresholds 0.5-0.95
+
+        # Draw and save prediction
+        save_path = output_dir / f"sequence_{seq_idx}_frame_1.jpg"
+        draw_predictions(frames[1], (pred_x, pred_y), coordinates[seq_idx, 1], visibility[seq_idx, 1], save_path)
+
+    # Calculate overall metrics: compare predicted visibility (peak-score threshold)
+    # against the ground-truth labels
+    y_true_vis = [1 if v else 0 for v in visibilities]
+    y_pred_vis = [1 if v else 0 for v in pred_visibilities]
+    precision = precision_score(y_true_vis, y_pred_vis, zero_division=0)
+    recall = recall_score(y_true_vis, y_pred_vis, zero_division=0)
+    map_50 = np.mean(ap_scores_50) if ap_scores_50 else 0.0
+    map_50_95 = np.mean(ap_scores_50_95) if ap_scores_50_95 else 0.0
+
+    # Euclidean distance and MSE
+    distances = []
+    squared_errors = []
+    for (pred_x, pred_y), (gt_x, gt_y), vis in zip(predictions, ground_truths, visibilities):
+        if vis:  # Only evaluate visible frames
+            dist = np.sqrt((pred_x - gt_x) ** 2 + (pred_y - gt_y) ** 2)
+            mse = (pred_x - gt_x) ** 2 + (pred_y - gt_y) ** 2
+            distances.append(dist)
+            squared_errors.append(mse)
+
+    mean_distance = np.mean(distances) if distances else float('nan')
+    mean_squared_error = np.mean(squared_errors) if squared_errors else float('nan')
+
+    return {
+        "precision": precision,
+        "recall": recall,
+        "mAP@0.5": map_50,
+        "mAP@0.5:0.95": map_50_95,
+        "mean_distance": mean_distance,
+        "mean_squared_error": mean_squared_error,
+    }
+
+def main():
+    # Paths
+    sequences_path = Path("/home/nakamura/desktop/playbox/ball_detection/TrackNetV3/SoccerTrack-v2/data/interim/ball_tracking_dataset/test/sequences.npy")
+    coordinates_path = Path("/home/nakamura/desktop/playbox/ball_detection/TrackNetV3/SoccerTrack-v2/data/interim/ball_tracking_dataset/test/coordinates.npy")
+    visibility_path = Path("/home/nakamura/desktop/playbox/ball_detection/TrackNetV3/SoccerTrack-v2/data/interim/ball_tracking_dataset/test/visibility.npy")
+    checkpoint_path = "/home/nakamura/desktop/playbox/ball_detection/TrackNetV3/tracknetx/exp-stride=3-weighted_msee/model-epoch=89-val_total_loss=0.00.ckpt"
+    output_dir = "/home/nakamura/desktop/playbox/ball_detection/TrackNetV3/tracknetx/exp"
+
+    # Device
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Load data
+    sequences = np.load(sequences_path, allow_pickle=True)
+    coordinates = np.load(coordinates_path, allow_pickle=True)
+    visibility = np.load(visibility_path, allow_pickle=True)
+
+    # Load model
+    model = load_model(checkpoint_path, device)
+
+    # Evaluate model
+    metrics = evaluate_model(sequences, coordinates, visibility, model, output_dir, device)
+
+    # Print metrics
+    print("Evaluation Metrics:")
+    print(f"Mean Euclidean Distance: {metrics['mean_distance']:.2f}")
+    print(f"Mean Squared Error: {metrics['mean_squared_error']:.2f}")
+    print(f"Precision: {metrics['precision']:.4f}")
+    print(f"Recall: {metrics['recall']:.4f}")
+    print(f"mAP@0.5: {metrics['mAP@0.5']:.4f}")
+    print(f"mAP@0.5:0.95: {metrics['mAP@0.5:0.95']:.4f}")
+
+if __name__ == "__main__":
+    main()
diff --git a/src/ball_tracking/tracknetx/inference.py b/src/ball_tracking/tracknetx/inference.py
index 7e14ea9..5c2fd07 100644
--- a/src/ball_tracking/tracknetx/inference.py
+++ b/src/ball_tracking/tracknetx/inference.py
@@ -8,7 +8,7 @@
 import torch.nn.functional as F
 
 # Replace with the actual import of your TrackNetXModel
-from src.ball_tracking.tracknetx.model import TrackNetXModel
+from model import TrackNetXModel
 
 
 def load_model(checkpoint_path: str, device: str = "cpu"):
@@ -244,4 +244,4 @@
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
diff --git a/src/ball_tracking/tracknetx/losses.py b/src/ball_tracking/tracknetx/losses.py
index d6ab85b..bc3e631 100644
--- a/src/ball_tracking/tracknetx/losses.py
+++ b/src/ball_tracking/tracknetx/losses.py
@@ -369,4 +369,4 @@ def forward(self, pred_heatmaps, gt_heatmaps, gt_coords):
         # 4. Combine losses
         total_loss = coord_loss + self.beta * heatmap_loss
 
-        return total_loss
+        return total_loss
\ No newline at end of file
diff --git a/src/ball_tracking/tracknetx/train.py b/src/ball_tracking/tracknetx/train.py
index 3bdec56..20322ca 100644
--- a/src/ball_tracking/tracknetx/train.py
+++ b/src/ball_tracking/tracknetx/train.py
@@ -4,6 +4,8 @@
 import sys
 from pathlib import Path
 from typing import List
+import random
+import numpy as np  # used below by seed_everything
 
 import torch
 import pytorch_lightning as pl
@@ -13,12 +15,28 @@
 from src.ball_tracking.tracknetx.evaluate_callback import EvaluateAndLogCallback  # Use internal evaluate_tracknet_model
 from omegaconf import OmegaConf
 
-from src.ball_tracking.tracknetx.data_module import TrackNetXDataModule
-from src.ball_tracking.tracknetx.model import TrackNetXModel
-from src.ball_tracking.tracknetx.utils import model_summary, evaluation, plot_result
+from data_module import TrackNetXDataModule
+from model import TrackNetXModel
+from utils import model_summary, evaluation, plot_result
+
+def seed_everything(seed: int = 42):
+    """
+    Set seed for reproducibility.
+
+    Args:
+        seed (int): Random seed value.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)  # For multi-GPU setups
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
 
 def main():
     """Main training function."""
+    seed_everything(42)
     # Load configuration
     config = OmegaConf.load("/home/nakamura/desktop/playbox/ball_detection/TrackNetV3/tracknetx/SoccerTrack-v2/configs/default_config.yaml")
     logger.remove()
@@ -110,13 +128,14 @@
         callbacks=callbacks,
     )
 
-    # Train the model from scratch
-    logger.info("Starting training from scratch...")
-    trainer.fit(model, data_module)
+
+    logger.info("Starting training with validation...")
+    trainer.fit(model, data_module)  # , ckpt_path=checkpoint_path
+
     # Test the model
     logger.info("Starting testing...")
     trainer.test(model, data_module)
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
diff --git a/src/ball_tracking/tracknetx/utils.py b/src/ball_tracking/tracknetx/utils.py
index ff82876..8d9e855 100644
--- a/src/ball_tracking/tracknetx/utils.py
+++ b/src/ball_tracking/tracknetx/utils.py
@@ -7,7 +7,8 @@
 import pandas as pd
 import matplotlib.pyplot as plt
 from tqdm import tqdm
-# from PIL import Image, ImageSequence
+from PIL import Image, ImageSequence
+import parse
 
 # HEIGHT = 333
 # WIDTH = 3250
@@ -769,3 +770,4 @@
         duration=1000,
         loop=0,
     )
+
\ No newline at end of file
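
Note on the visibility metrics in evaluate.py: a sequence counts as a positive detection when the peak of the predicted heatmap clears detection_threshold (the 0.5 cutoff is an assumed value, not taken from the original code), and precision/recall then compare that binary prediction against the ground-truth visibility labels. A minimal self-contained sketch of this logic on toy data, under that assumption:

    import numpy as np
    from sklearn.metrics import precision_score, recall_score

    detection_threshold = 0.5  # assumed cutoff; worth tuning on a validation split

    # Toy peak heatmap scores for four sequences and their ground-truth visibility.
    peak_scores = np.array([0.91, 0.12, 0.73, 0.40])
    gt_visible = np.array([1, 0, 1, 1])

    # A sequence counts as "ball detected" when its peak score clears the cutoff.
    pred_visible = (peak_scores >= detection_threshold).astype(int)  # -> [1, 0, 1, 0]

    print(precision_score(gt_visible, pred_visible))  # 1.0: no false detections
    print(recall_score(gt_visible, pred_visible))     # 0.667: one visible ball missed

Raising the cutoff generally trades recall (more missed balls) for precision (fewer false detections), so it should be chosen against validation data rather than fixed a priori.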