Add validation functionality to Trainer class
stefanklut committed May 2, 2024
1 parent 81131ab commit 73ec04c
Showing 4 changed files with 73 additions and 13 deletions.
14 changes: 7 additions & 7 deletions README.md
@@ -366,11 +366,11 @@ For a small tutorial using some concrete examples see the [`tutorial`][tutorial_
## Evaluation
The Laypa repository also contains a few tools used to evaluate the results generated by the model.
- The first tool is a visual comparison between the predictions of the model and the ground truth. This is done as an overlay of the classes over the original image. The overlay class names and colors are taken from the dataset catalog. The tool to do this is [`eval.py`][eval_link]. The visualization has almost the same arguments as the training command ([`main.py`][main_link]).
+ The first tool is a visual comparison between the predictions of the model and the ground truth. This is done as an overlay of the classes over the original image. The overlay class names and colors are taken from the dataset catalog. The tool to do this is [`visualization.py`][eval_link]. The visualization has almost the same arguments as the training command ([`main.py`][main_link]).
Required arguments:
```sh
- python eval.py \
+ python visualization.py \
-c/--config CONFIG \
-i/--input INPUT [INPUT ...] \
```
@@ -380,7 +380,7 @@ python eval.py \
Optional arguments:
```sh
- python eval.py \
+ python visualization.py \
-c/--config CONFIG \
-i/--input INPUT [INPUT ...] \
[-o/--output OUTPUT] \
@@ -394,13 +394,13 @@ python eval.py \
The optional arguments are shown using square brackets. The `-o/--output` parameter specifies the output directory for the visualization masks. The `--tmp_dir` parameter specifies a folder in which to store temporary files. The `--keep_tmp_dir` parameter prevents the temporary files from being deleted after a run (mostly useful for debugging). The `--opts` parameter allows you to change values specified in the config files. For example, `--opts SOLVER.IMS_PER_BATCH 8` sets the batch size to 8. The `--sorted` parameter sorts the images based on the order in the operating system. The `--save` parameter specifies what type of file the visualization should be saved as. The options are "pred" for the prediction, "gt" for the ground truth, "both" for both the prediction and the ground truth, and "all" for all of the previous. If just `--save` is given, the default is "all".
</details>
- Example of running [`eval.py`][eval_link]:
+ Example of running [`visualization.py`][eval_link]:
```sh
- python eval.py -c config.yml -i input_dir
+ python visualization.py -c config.yml -i input_dir
```
- The [`eval.py`][eval_link] will then open a window with both the prediction and the ground truth side by side (if the ground truth exists). Allowing for easier comparison. The visualization masks are created in the same way the preprocessing converts pageXML to masks.
+ The [`visualization.py`][eval_link] will then open a window with both the prediction and the ground truth side by side (if the ground truth exists), allowing for easier comparison. The visualization masks are created in the same way the preprocessing converts pageXML to masks.
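
For reference, a sketch combining the optional flags listed above with the same basic call (all paths and the batch-size value are placeholders):

```sh
python visualization.py \
    -c config.yml \
    -i input_dir \
    -o output_dir \
    --save all \
    --opts SOLVER.IMS_PER_BATCH 8
```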
The second tool is a program to compare the similarity of two sets of pageXML. This can mean either comparing ground truth to predicted pageXML, or determining the similarity of two annotations made by different people. This tool is the [`xml_comparison.py`][xml_comparison_link] file. The comparison allows you to specify how regions and baselines should be drawn when creating the pixel masks. The pixel masks are then compared based on their Intersection over Union (IoU) and Accuracy (Acc) scores. For the Accuracy metric, one of the two sets needs to be specified as the ground truth set: that set is passed via the ground truth directory (`--gt`) argument and the other via the input directory (`--input`) argument.
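
A hedged sketch of an invocation using only the two flags named above (directory names are placeholders; the script may take additional arguments, see its `--help`):

```sh
python xml_comparison.py --gt ground_truth_dir --input prediction_dir
```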
@@ -493,7 +493,7 @@ If you discover a bug or missing feature that you would like to help with please
[tutorial_link]: tutorial/
[main_link]: main.py
[run_link]: run.py
- [eval_link]: eval.py
+ [eval_link]: visualization.py
[xml_comparison_link]: xml_comparison.py
[xml_viewer_link]: xml_viewer.py
[start_flask_link]: /api/start_flask.sh
7 changes: 5 additions & 2 deletions core/trainer.py
@@ -208,7 +208,7 @@ class Trainer(DefaultTrainer):
Trainer class
"""

- def __init__(self, cfg: CfgNode):
+ def __init__(self, cfg: CfgNode, validation: bool = False):
TrainerBase.__init__(self)

# logger = logging.getLogger("detectron2")
@@ -219,7 +219,7 @@ def __init__(self, cfg: CfgNode):
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
- data_loader = self.build_train_loader(cfg)
+ data_loader = self.build_train_loader(cfg) if not validation else None

model = create_ddp_model(model, broadcast_buffers=False)
self._trainer = (AMPTrainer if cfg.MODEL.AMP_TRAIN.ENABLED else SimpleTrainer)(model, data_loader, optimizer)
@@ -319,3 +319,6 @@ def build_optimizer(cls, cfg, model):
@classmethod
def build_lr_scheduler(cls, cfg, optimizer):
return build_lr_scheduler(cfg, optimizer)

+ def validate(self):
+     results = self.test(self.cfg, self.model)
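
For context, a minimal sketch of how the new flag and method are meant to fit together (this mirrors the `validation.py` script added below; `cfg` is assumed to be an already prepared `CfgNode`):

```python
# Sketch only: cfg is a detectron2 CfgNode set up elsewhere (e.g. via setup_cfg).
trainer = Trainer(cfg, validation=True)  # skip building the training data loader
trainer.resume_or_load(resume=False)     # load weights from cfg.MODEL.WEIGHTS
trainer.validate()                       # run the evaluation loop via Trainer.test
```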
60 changes: 60 additions & 0 deletions validation.py
@@ -0,0 +1,60 @@
import argparse
import logging

from detectron2.evaluation import SemSegEvaluator

from core.preprocess import preprocess_datasets
from core.setup import setup_cfg, setup_logging, setup_saving, setup_seed
from core.trainer import Trainer
from utils.logging_utils import get_logger_name
from utils.tempdir import OptionalTemporaryDirectory


def get_arguments() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Validation of model compared to ground truth")

    detectron2_args = parser.add_argument_group("detectron2")

    detectron2_args.add_argument("-c", "--config", help="config file", required=True)
    detectron2_args.add_argument("--opts", nargs="+", help="optional args to change", action="extend", default=[])

    io_args = parser.add_argument_group("IO")
    # io_args.add_argument("-t", "--train", help="Train input folder/file",
    # nargs="+", action="extend", type=str, default=None)
    io_args.add_argument("-i", "--input", help="Input folder/file", nargs="+", action="extend", type=str, default=None)
    io_args.add_argument("-o", "--output", help="Output folder", type=str)

    tmp_args = parser.add_argument_group("tmp files")
    tmp_args.add_argument("--tmp_dir", help="Temp files folder", type=str, default=None)
    tmp_args.add_argument("--keep_tmp_dir", action="store_true", help="Don't remove tmp dir after execution")

    parser.add_argument("--sorted", action="store_true", help="Sorted iteration")
    parser.add_argument("--save", nargs="?", const="all", default=None, help="Save images instead of displaying")

    args = parser.parse_args()

    return args


def main(args):

    cfg = setup_cfg(args)
    setup_logging(cfg)
    setup_seed(cfg)
    setup_saving(cfg)

    logger = logging.getLogger(get_logger_name())

    # Temp dir for preprocessing in case no temporary dir was specified
    with OptionalTemporaryDirectory(name=args.tmp_dir, cleanup=not (args.keep_tmp_dir)) as tmp_dir:
        preprocess_datasets(cfg, None, args.input, tmp_dir)

        trainer = Trainer(cfg, validation=True)
        trainer.resume_or_load(resume=False)

        trainer.validate()


if __name__ == "__main__":
    args = get_arguments()
    main(args)
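
A hedged example of invoking the new script (paths are placeholders; the config presumably needs to point at trained weights, e.g. via `MODEL.WEIGHTS`, since `resume_or_load(resume=False)` is used):

```sh
python validation.py -c config.yml -i validation_data_dir
```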
5 changes: 1 addition & 4 deletions eval.py → visualization.py
@@ -9,17 +9,14 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
- from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.visualizer import Visualizer
from natsort import os_sorted
from tqdm import tqdm

from core.preprocess import preprocess_datasets
from core.setup import setup_cfg
from datasets.dataset import metadata_from_classes
from datasets.mapper import AugInput
- from page_xml.xml_converter import XMLConverter
- from page_xml.xml_regions import XMLRegions
from run import Predictor
from utils.image_utils import load_image_array_from_path, save_image_array_to_path
from utils.input_utils import get_file_paths, supported_image_formats
@@ -31,7 +28,7 @@


def get_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Eval of prediction of model using visualizer")
parser = argparse.ArgumentParser(description="Visualization of prediction/GT of model")

detectron2_args = parser.add_argument_group("detectron2")

