-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
118 lines (90 loc) · 3.63 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import glob
import pathlib
from typing import Union
import click
import numpy as np
from tqdm.contrib.concurrent import process_map
from src.augmentation.augment import augment
from src.classification.classifier import Classifier
from src.segmentation.character import CharacterSegmenter
from src.segmentation.line import LineSegmenter
from src.utils.font import text_to_font
from src.utils.images import get_name
from src.utils.logger import logger
from src.utils.paths import LINE_SEGMENT_PATH
from src.utils.zip import unzip_all
@click.group()
def cli() -> None:
    """Command-line entry point grouping this module's subcommands."""
@cli.command()
def prepare() -> None:
    """Unzip the raw image and character archives into data/unpacked."""
    logger.info("Unzipping data")
    # Both archives are extracted into the same target directory.
    for archive in ("data/raw/image-data.zip", "data/raw/characters.zip"):
        unzip_all(archive, "data/unpacked")
    logger.info("Done")
@cli.command()
@click.option("--debug/--no-debug", default=True)
@click.option("--file", default=None)
def linesegment(debug: bool, file: Union[str, None]) -> None:
    """Run line segmentation on one image or on all binarized images.

    With --file, only that image is segmented. Otherwise every file matching
    data/unpacked/image-data/*binarized.jpg is segmented, with the work
    spread over a pool of worker processes.
    """
    line_segmenter = LineSegmenter(debug=debug)

    if file:
        logger.info(f"Starting line segmentation on {file}")
        line_segmenter.segment_from_path(file)
        return

    logger.info("Starting line segmentation on all binarized images")
    binary_files = glob.glob("data/unpacked/image-data/*binarized.jpg")
    if not binary_files:
        # Mirror charactersegment: tell the user why nothing happened instead
        # of silently mapping over an empty list.
        logger.warning("No binarized images found, did you run prepare?")
        return

    # process_map fans the paths out over a process pool and shows progress;
    # the bound method (and therefore the segmenter) is sent to the workers.
    process_map(line_segmenter.segment_from_path, binary_files)
@cli.command()
@click.option("--file", default=None)
@click.option("--debug/--no-debug", default=True)
def charactersegment(file: Union[str, None], debug: bool = False) -> None:
    """Run character segmentation on one line image or on all of them.

    With --file, only that image is segmented. Otherwise every PNG found
    below the line-segmentation output directory is segmented in turn.
    """
    # NOTE: click always supplies ``debug`` (option default: True), so the
    # ``False`` default in the signature only matters if the callback is
    # invoked directly, bypassing click.
    logger.info("Starting character segmentation")
    character_segmenter = CharacterSegmenter(debug=debug)

    if file:
        character_segmenter.segment_from_path(file)
        return

    # "**" only spans nested directories when recursive=True; without it the
    # pattern silently degrades to a single "*" directory level.
    line_images = glob.glob(f"{LINE_SEGMENT_PATH}/**/*.png", recursive=True)
    if not line_images:
        logger.warning("No images with segmented lines, did you run linesegment?")
        return

    # A distinct loop name avoids shadowing the ``file`` parameter.
    for line_image in line_images:
        character_segmenter.segment_from_path(line_image)
@cli.command(name="augment")
def run_augment() -> None:
    """Run the augmentation step by calling augment()."""
    # Defined as run_augment because a function named ``augment`` would
    # shadow the imported augment(); the CLI still exposes it as "augment"
    # through the explicit command name.
    augment()
@cli.command()
@click.option("--train/--no-train", default=True)
@click.argument("name")
def train(train: bool, name: str) -> None:
    """Instantiate the classifier for model NAME.

    The --train/--no-train flag is forwarded as the classifier's train flag.
    """
    # The instance is discarded, so whatever work happens is done by the
    # constructor itself — presumably training or loading; confirm against
    # Classifier.
    Classifier(
        model_filename=name,
        train=train,
        debug=True,
    )
@cli.command()
@click.argument("directory", default="data/unpacked/image-data/")
@click.option("--o", "out_dir", default="results/")
@click.option("--suffix", default="")
def run(directory: str, out_dir: str, suffix: str) -> None:
    """Transcribe every matching image in DIRECTORY to a text file.

    Each image matching *<suffix>.jpg is split into lines and characters,
    the characters are classified, and the resulting text is passed through
    text_to_font and written to <out_dir>/<image name>.txt as UTF-8.
    """
    classifier = Classifier(train=False, model_filename="augmented_cnn", debug=True)
    line_segmenter = LineSegmenter()
    character_segmenter = CharacterSegmenter(min_distance=35)

    binary_files = glob.glob(f"{directory}/*{suffix}.jpg")
    if not binary_files:
        logger.warning(f"No images matching *{suffix}.jpg found in {directory}")
        return

    # Loop-invariant: create the output directory once, not once per image.
    out_path = pathlib.Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    for image_path in binary_files:
        out_name = get_name(image_path)
        logger.info(f"Processing {image_path}")

        # Collected per image, so no state can leak from one file to the next.
        transcribed_lines = []
        for line in line_segmenter.segment_from_path(image_path):
            characters = character_segmenter.segment(line)
            if not characters:
                # np.stack rejects an empty sequence; skip lines without
                # any segmented characters.
                continue

            batch = np.stack(characters, axis=0).astype(np.int64)
            proba = classifier.predict_batch(batch)
            predictions = classifier.decode_proba_batch(proba)

            # Same format as before: every character is followed by a space
            # and the line ends with a newline.
            # NOTE(review): the source's indentation was lost; the newline is
            # assumed to be emitted only for lines that produced characters
            # — confirm against the original.
            text = "".join(f"{character} " for character, _ in predictions)
            transcribed_lines.append(text + "\n")

        # Map the text before touching the file so a failure in text_to_font
        # cannot leave a truncated, empty output file behind.
        mapped_output = text_to_font("".join(transcribed_lines))

        # Explicit encoding: the platform default may be unable to encode
        # the mapped characters.
        (out_path / f"{out_name}.txt").write_text(mapped_output, encoding="utf-8")
# Hand control to the click command group only when run as a script, so
# importing this module has no side effects.
if __name__ == "__main__":
    cli()