From 27a055e63c4c67859044421ac2f800369902b7cb Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Wed, 2 Oct 2024 15:34:24 +0000 Subject: [PATCH 1/4] ocr.dataset.data: provide predefined element_length_fn for pipeline.bucket_boundaries --- calamari_ocr/ocr/dataset/data.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/calamari_ocr/ocr/dataset/data.py b/calamari_ocr/ocr/dataset/data.py index 53fcab62..2f46fdce 100644 --- a/calamari_ocr/ocr/dataset/data.py +++ b/calamari_ocr/ocr/dataset/data.py @@ -1,8 +1,9 @@ import logging import os -from typing import Type, Optional +from typing import Callable, Dict, Type, Optional import tensorflow as tf +from tfaip.util.tftyping import AnyTensor from tfaip.data.data import DataBase from tfaip.data.databaseparams import DataPipelineParams from tfaip.data.pipeline.datapipeline import DataPipeline @@ -84,6 +85,11 @@ def _target_layer_specs(self): "gt_len": tf.TensorSpec([1], dtype=tf.int32), } + def element_length_fn(self) -> Callable[[Dict[str, AnyTensor]], AnyTensor]: + def img_len(x): + return x["img_len"] + return img_len + def create_pipeline( self, pipeline_params: DataPipelineParams, From febb2c629f86696cfc7501e3c1eb26afa5d44cb5 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Wed, 2 Oct 2024 15:55:39 +0000 Subject: [PATCH 2/4] let black have its friggin empty line --- calamari_ocr/ocr/dataset/data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/calamari_ocr/ocr/dataset/data.py b/calamari_ocr/ocr/dataset/data.py index 2f46fdce..6e16e2f3 100644 --- a/calamari_ocr/ocr/dataset/data.py +++ b/calamari_ocr/ocr/dataset/data.py @@ -88,6 +88,7 @@ def _target_layer_specs(self): def element_length_fn(self) -> Callable[[Dict[str, AnyTensor]], AnyTensor]: def img_len(x): return x["img_len"] + return img_len def create_pipeline( From b2dab6c0c7c5a46ea6d17c535d44bf0e800d79f1 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky <38561704+bertsky@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:02:40 +0200 Subject: [PATCH 3/4] test_prediction.py: add test for bucket_boundaries --- calamari_ocr/test/test_prediction.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/calamari_ocr/test/test_prediction.py b/calamari_ocr/test/test_prediction.py index 16a3366f..cf91df47 100644 --- a/calamari_ocr/test/test_prediction.py +++ b/calamari_ocr/test/test_prediction.py @@ -206,6 +206,12 @@ def test_dataset_prediction_voted(self): predictor.benchmark_results.pretty_print() + def test_batch_bucketing(self): + args = predict_args() + args.predictor.pipeline.bucket_boundaries = [20, 50, 100, 200, 400, 800] + run(args) + + if __name__ == "__main__": unittest.main() From 6d63500a7f2aac9e1890c223eeb0058edfef44aa Mon Sep 17 00:00:00 2001 From: Robert Sachunsky <38561704+bertsky@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:09:14 +0200 Subject: [PATCH 4/4] satisfy the stupid linter --- calamari_ocr/test/test_prediction.py | 1 - 1 file changed, 1 deletion(-) diff --git a/calamari_ocr/test/test_prediction.py b/calamari_ocr/test/test_prediction.py index cf91df47..9aed874b 100644 --- a/calamari_ocr/test/test_prediction.py +++ b/calamari_ocr/test/test_prediction.py @@ -212,6 +212,5 @@ def test_batch_bucketing(self): run(args) - if __name__ == "__main__": unittest.main()