From dfc7c1014e763fb3903533484066c2fcec0d855c Mon Sep 17 00:00:00 2001
From: samrat-rm
Date: Sat, 14 Mar 2026 02:45:45 +0530
Subject: [PATCH 1/3] feat: compute and print mean single-pass entropy uncertainty during model testing

---
 models/common_metrics.py | 13 +++++++++----
 models/model_test.py     | 21 +++++++++++++++++++--
 2 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/models/common_metrics.py b/models/common_metrics.py
index 6ef81a7..34ae6bf 100644
--- a/models/common_metrics.py
+++ b/models/common_metrics.py
@@ -126,7 +126,7 @@ def iou_per_class(preds, labels, num_classes=3):
         ious.append(iou)
     return ious
 
-def calculate_metrics(all_preds, all_targets, num_classes, total_pixels, correct_pixels):
+def calculate_metrics(all_preds, all_targets, num_classes, total_pixels, correct_pixels, mean_uncertainty=None):
     """
     Calculate validation metrics from predictions and targets.
 
@@ -175,6 +175,8 @@ def calculate_metrics(all_preds, all_targets, num_classes, total_pixels, correct
     metrics = {
         "pixel_accuracy": pixel_accuracy,
     }
+    if mean_uncertainty is not None:
+        metrics["mean_uncertainty"] = mean_uncertainty
     avg_iou = 0
     for cls in range(num_classes):
         metrics[f"iou_{cls}"] = iou[cls]
@@ -193,14 +195,17 @@ def calculate_metrics(all_preds, all_targets, num_classes, total_pixels, correct
     # metrics["confusion_matrix"] = cm.flatten().tolist()
 
     # Print results
-    print(f"\nPixel Accuracy: {pixel_accuracy:.4f}, mIoU: {avg_iou}")
+    if mean_uncertainty is not None:
+        print(f"\nPixel Accuracy: {pixel_accuracy:.4f}, mIoU: {avg_iou:.4f}, Mean Uncertainty: {mean_uncertainty:.4f}")
+    else:
+        print(f"\nPixel Accuracy: {pixel_accuracy:.4f}, mIoU: {avg_iou:.4f}")
     print(f"{'Class':<6} {'IoU':>6} {'Precision':>10} {'Recall':>8} {'F1':>6}")
     for cls in range(num_classes):
         print(f"{cls:<6} {iou[cls]:>6.3f} {precision[cls]:>10.3f} {recall[cls]:>8.3f} {f1[cls]:>6.3f}")
 
     return metrics
 
-def validate_all(model, val_loader, params_dict):
+def validate_all(model, val_loader, params_dict, mean_uncertainty=None):
     """
     Validation function for single-head U-Net models. Expects data loader to return (x, y) format.
 
@@ -257,7 +262,7 @@ def validate_all(model, val_loader, params_dict):
     avg_loss = total_loss / total_batches
 
     # Calculate metrics
-    metrics = calculate_metrics(all_preds, all_targets, params_dict["num_classes"], total_pixels, correct_pixels)
+    metrics = calculate_metrics(all_preds, all_targets, params_dict["num_classes"], total_pixels, correct_pixels,mean_uncertainty)
     if "loss" in params_dict:
         metrics["val_loss"] = avg_loss
         metrics["neg_val_loss"] = -avg_loss
diff --git a/models/model_test.py b/models/model_test.py
index 6a06510..fd4457f 100644
--- a/models/model_test.py
+++ b/models/model_test.py
@@ -17,6 +17,10 @@
 import argparse
 from tqdm import tqdm
 
+def single_pass_uncertainty(logits):
+    probs = torch.softmax(logits, dim=1)
+    entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=1)
+    return entropy
 
 def save_inference_images(ibatch, save_inference_dir, results, inputs, outputs, preds, targets, batch_size, test_df, save_logits, num_classes):
     if isinstance(inputs, list):
@@ -105,6 +109,8 @@ def evaluate_on_test_set(
     print(f"Save Inference to: {save_inference_dir}")
 
     with torch.no_grad():
+        total_uncertainty = 0.0
+        total_pixels = 0
        results=[]
        for i, (inputs, targets) in enumerate(tqdm(test_loader, desc="Inference Progress")):

@@ -113,11 +119,20 @@ def evaluate_on_test_set(
                inputs = inputs.to(device)
                outputs = model(inputs)
            '''
+            if isinstance(outputs, tuple):
+                logits = outputs[0]
+            else:
+                logits = outputs
+
+            uncertainty = single_pass_uncertainty(logits).cpu()
+            total_uncertainty += uncertainty.sum().item()
+            total_pixels += uncertainty.numel()
            if isinstance(outputs, tuple):
                preds = torch.argmax(outputs[0], dim=1).cpu()
            else:
                preds = torch.argmax(outputs, dim=1).cpu()
-
+            uncertainty_scores = outputs.var(dim=0).mean(dim=1)
+            print("Uncertainity using py-trco framework : ",uncertainty_scores )
            targets = targets.cpu()
            all_preds.append(preds)
            all_targets.append(targets)
@@ -134,7 +149,9 @@ def evaluate_on_test_set(

    print(f"Saved inference results to {csv_path}")

-    metrics = validate_all(model, test_loader, params_dict)
+    if total_pixels > 0:
+        mean_uncertainty = total_uncertainty / total_pixels
+    metrics = validate_all(model, test_loader, params_dict,mean_uncertainty)

    if wandbrun:
        wandbrun.log(metrics)

From dc5e924e1b3b832276c3fb46f0b2df180151d478 Mon Sep 17 00:00:00 2001
From: samrat-rm
Date: Sat, 14 Mar 2026 02:55:49 +0530
Subject: [PATCH 2/3] remove temporary debug print used during development

---
 models/model_test.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/models/model_test.py b/models/model_test.py
index fd4457f..c7e5697 100644
--- a/models/model_test.py
+++ b/models/model_test.py
@@ -132,7 +132,6 @@ def evaluate_on_test_set(
            else:
                preds = torch.argmax(outputs, dim=1).cpu()
            uncertainty_scores = outputs.var(dim=0).mean(dim=1)
-            print("Uncertainity using py-trco framework : ",uncertainty_scores )
            targets = targets.cpu()
            all_preds.append(preds)
            all_targets.append(targets)

From 3e1d91786b26e537e405eb78c8644e7656198898 Mon Sep 17 00:00:00 2001
From: samrat-rm
Date: Sat, 14 Mar 2026 03:41:29 +0530
Subject: [PATCH 3/3] remove unused variable uncertainty_scores

---
 models/model_test.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/models/model_test.py b/models/model_test.py
index c7e5697..78ffca2 100644
--- a/models/model_test.py
+++ b/models/model_test.py
@@ -131,7 +131,6 @@ def evaluate_on_test_set(
                preds = torch.argmax(outputs[0], dim=1).cpu()
            else:
                preds = torch.argmax(outputs, dim=1).cpu()
-            uncertainty_scores = outputs.var(dim=0).mean(dim=1)
            targets = targets.cpu()
            all_preds.append(preds)
            all_targets.append(targets)