[Performance] evaluate_on_test_set() iterates over the test set twice #3

@NikitasKrh

Description

Hi Orion Lab team,

I was reading through the evaluation flow (model_training.py -> evaluate_on_test_set() -> validate_all()) to understand how the benchmark metrics are computed, and I noticed that evaluate_on_test_set ends up running the model on the full test set twice. The first pass is the loop inside the function itself (model_test.py):

all_preds = []
all_targets = []
...
with torch.no_grad():
    results = []
    for i, (inputs, targets) in enumerate(tqdm(test_loader, desc="Inference Progress")):

        outputs = get_preds_multi_encoders(model, inputs, device)
        '''
        inputs = inputs.to(device)
        outputs = model(inputs)
        '''
        if isinstance(outputs, tuple):
            preds = torch.argmax(outputs[0], dim=1).cpu()
        else:
            preds = torch.argmax(outputs, dim=1).cpu()

        targets = targets.cpu()
        all_preds.append(preds)
        all_targets.append(targets)

        if save_inference:
            save_inference_images(i, save_inference_dir, results, inputs, outputs, preds, targets, 
                                  batch_size, test_df, save_logits, num_classes)

and the second pass right after (line 136):

metrics = validate_all(model, test_loader, params_dict)

validate_all() loops over the same test_loader from scratch, runs the model on every batch again, and computes all the final metrics (IoU, pixel accuracy, precision, recall, F1). The all_preds and all_targets lists from the first loop are never actually used after that; the metrics that get returned and later written to CSV all come from the second pass. So when save_inference is false, the first loop runs every forward pass for nothing. When save_inference is true, the first loop does save the inference images, but the metrics still require a full second pass. Since all three evaluation entry points (model_test.py, model_training.py, fine_tune_models.py) go through this function, every model evaluation pays this cost.

My suggestions:
a) The simplest fix is to wrap the first loop in if save_inference: so it only runs when inference images are actually needed. The validate_all() call stays as-is. This avoids the redundant pass in the default case while barely changing the existing behavior.
b) A better option would be to merge the metrics computation into the first loop (accumulate correct_pixels, total_pixels, etc. while iterating) and call calculate_metrics() directly rather than going through validate_all(), so the model touches each batch only once regardless of the save_inference flag.
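To make suggestion (b) concrete, here is a rough sketch of the single-pass accumulation idea. It deliberately uses plain Python over flattened label sequences rather than the repo's tensors, and the names (evaluate_single_pass, batches) are placeholders, not the project's real API; the actual PR would keep the existing tqdm loop and tensor ops and just fold the counting into it.

```python
def evaluate_single_pass(batches, num_classes):
    """One pass over the data: accumulate per-class TP/FP/FN plus
    pixel counts while iterating, then derive the metrics at the end.

    batches: iterable of (preds, targets) pairs, each a flat sequence
    of integer class labels (stand-ins for the argmax'd model outputs
    and ground-truth masks in the real loop).
    """
    correct_pixels = 0
    total_pixels = 0
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes

    for preds, targets in batches:
        # In the real code this is where the forward pass, argmax,
        # and optional save_inference_images() call already happen.
        for p, t in zip(preds, targets):
            total_pixels += 1
            if p == t:
                correct_pixels += 1
                tp[p] += 1
            else:
                fp[p] += 1
                fn[t] += 1

    pixel_acc = correct_pixels / total_pixels if total_pixels else 0.0
    precision = [tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0
                 for c in range(num_classes)]
    recall = [tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0
              for c in range(num_classes)]
    f1 = [2 * p * r / (p + r) if p + r else 0.0
          for p, r in zip(precision, recall)]
    return {"pixel_acc": pixel_acc, "precision": precision,
            "recall": recall, "f1": f1}
```

The key point is that everything the final metrics need reduces to a handful of running counters, so there is no reason the second full pass over test_loader has to exist.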

I'd be happy to learn if I'm missing something about the design here. If not, I could follow up with a PR using whichever approach you prefer. Thanks!
