Skip to content

Commit

Permalink
Merge pull request #5 from stefanklut/git-hash
Browse files Browse the repository at this point in the history
Git hash
  • Loading branch information
stefanklut authored Nov 22, 2023
2 parents 482da62 + b0b339c commit 8c1e175
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 26 deletions.
27 changes: 19 additions & 8 deletions api/flask_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def setup_model(self, model_name: str, args: DummyArgs):
regions=cfg.PREPROCESS.REGION.REGIONS,
merge_regions=cfg.PREPROCESS.REGION.MERGE_REGIONS,
region_type=cfg.PREPROCESS.REGION.REGION_TYPE,
cfg=cfg,
whitelist={},
)

self.predictor = Predictor(cfg=cfg)
Expand All @@ -129,7 +131,13 @@ def setup_model(self, model_name: str, args: DummyArgs):
exception_predict_counter = Counter("exception_predict", "Exception thrown in predict() function")


def predict_image(image: np.ndarray | torch.Tensor, image_path: Path, identifier: str, model_name: str) -> dict[str, Any]:
def predict_image(
image: np.ndarray | torch.Tensor,
image_path: Path,
identifier: str,
model_name: str,
whitelist: list[str],
) -> dict[str, Any]:
"""
Run the prediction for the given image
Expand All @@ -156,15 +164,11 @@ def predict_image(image: np.ndarray | torch.Tensor, image_path: Path, identifier
raise TypeError("The current Predictor is not initialized")

predict_gen_page_wrapper.gen_page.set_output_dir(output_path.parent)
predict_gen_page_wrapper.gen_page.set_whitelist(whitelist)
if not output_path.parent.is_dir():
output_path.parent.mkdir()

if isinstance(image, np.ndarray):
outputs = predict_gen_page_wrapper.predictor.cpu_call(image)
elif isinstance(image, torch.Tensor):
outputs = predict_gen_page_wrapper.predictor.gpu_call(image)
else:
raise TypeError(f"Unknown image type: {type(image)}")
outputs = predict_gen_page_wrapper.predictor(image)

output_image = outputs[0]["sem_seg"]
# output_image = torch.argmax(outputs[0]["sem_seg"], dim=-3).cpu().numpy()
Expand All @@ -186,6 +190,7 @@ class ResponseInfo(TypedDict, total=False):
status_code: int
identifier: str
filename: str
whitelist: list[str]
added_queue_position: int
remaining_queue_size: int
added_time: str
Expand Down Expand Up @@ -255,6 +260,12 @@ def predict() -> tuple[Response, int]:
except KeyError as error:
abort_with_info(400, "Missing model in form", response_info)

try:
whitelist = request.form.getlist("whitelist")
response_info["whitelist"] = whitelist
except KeyError as error:
abort_with_info(400, "Missing whitelist in form", response_info)

try:
post_file = request.files["image"]
except KeyError as error:
Expand All @@ -279,7 +290,7 @@ def predict() -> tuple[Response, int]:
if image is None:
abort_with_info(500, "Corrupted image", response_info)

future = executor.submit(predict_image, image, image_name, identifier, model_name)
future = executor.submit(predict_image, image, image_name, identifier, model_name, whitelist)
future.add_done_callback(check_exception_callback)

response_info["status_code"] = 202
Expand Down
18 changes: 12 additions & 6 deletions core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,17 @@ def setup_logging(cfg: Optional[CfgNode] = None, save_log: bool = True) -> loggi
return logger


def get_git_hash() -> str:
    """
    Return the git commit hash that identifies the running version of the code

    A file named ``version_info`` in the current working directory takes precedence.
    When it does not exist, ``git rev-parse HEAD`` is run in the directory of this file.

    Returns:
        str: the commit hash without leading or trailing whitespace

    Raises:
        subprocess.CalledProcessError: if the git command exits with a non-zero status
    """
    version_path = Path("version_info")

    if version_path.is_file():
        with version_path.open(mode="r") as file:
            # Strip so a trailing newline in the file does not become part of the hash,
            # the same way the git output is stripped below
            return file.read().strip()

    git_output = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=Path(__file__).resolve().parent)
    return git_output.strip().decode()


# TODO Replace with LazyConfig
def setup_cfg(args, cfg: Optional[CfgNode] = None) -> CfgNode:
"""
Expand Down Expand Up @@ -89,12 +100,7 @@ def setup_cfg(args, cfg: Optional[CfgNode] = None) -> CfgNode:

version_path = Path("version_info")

if version_path.is_file():
with version_path.open(mode="r") as file:
git_hash = file.read()
else:
git_hash = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=Path(__file__).resolve().parent).strip().decode()
cfg.LAYPA_GIT_HASH = git_hash
cfg.LAYPA_GIT_HASH = get_git_hash()

cfg.CONFIG_PATH = str(Path(args.config).resolve())

Expand Down
16 changes: 15 additions & 1 deletion page_xml/output_pageXML.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
import uuid
from multiprocessing.pool import Pool
from pathlib import Path
from typing import Optional
from typing import Iterable, Optional

import cv2
import numpy as np
import torch
from detectron2.config import CfgNode
from tqdm import tqdm

from core.setup import get_git_hash

sys.path.append(str(Path(__file__).resolve().parent.joinpath("..")))
from page_xml.xml_regions import XMLRegions
from page_xml.xmlPAGE import PageData
Expand Down Expand Up @@ -51,6 +54,8 @@ def __init__(
regions: Optional[list[str]] = None,
merge_regions: Optional[list[str]] = None,
region_type: Optional[list[str]] = None,
cfg: Optional[CfgNode] = None,
whitelist: Optional[Iterable[str]] = None,
) -> None:
"""
Class for the generation of the pageXML from class predictions on images
Expand All @@ -75,6 +80,10 @@ def __init__(

self.regions = self.get_regions()

self.cfg = cfg

self.whitelist = set() if whitelist is None else set(whitelist)

def set_output_dir(self, output_dir: str | Path):
if isinstance(output_dir, str):
output_dir = Path(output_dir)
Expand All @@ -90,6 +99,9 @@ def set_output_dir(self, output_dir: str | Path):
page_dir.mkdir(parents=True)
self.page_dir = page_dir

def set_whitelist(self, whitelist: Iterable[str]):
    """
    Replace the stored whitelist

    Args:
        whitelist (Iterable[str]): entries to store, kept as a set so duplicates collapse
    """
    unique_entries = set()
    unique_entries.update(whitelist)
    self.whitelist = unique_entries

def link_image(self, image_path: Path):
"""
Symlink image to get the correct output structure
Expand Down Expand Up @@ -144,6 +156,8 @@ def generate_single_page(

page = PageData(xml_output_path)
page.new_page(image_path.name, str(old_height), str(old_width))
if self.cfg is not None:
page.add_processing_step(get_git_hash(), self.cfg.LAYPA_UUID, self.cfg, self.whitelist)

if self.mode == "region":
sem_seg = torch.argmax(sem_seg, dim=-3).cpu().numpy()
Expand Down
87 changes: 76 additions & 11 deletions page_xml/xmlPAGE.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,35 @@
import sys
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import TypedDict
from types import NoneType
from typing import Iterable, TypedDict

import numpy as np
from detectron2.config import CfgNode

sys.path.append(str(Path(__file__).resolve().parent.joinpath("..")))
from utils.logging_utils import get_logger_name
from utils.tempdir import AtomicFileName

_VALID_TYPES = {tuple, list, str, int, float, bool, NoneType}


def convert_to_dict(cfg_node, key_list=None):
    """
    Recursively convert a config node into plain nested dictionaries

    Args:
        cfg_node: a CfgNode, or a value stored inside one
        key_list (Optional[list[str]]): keys leading to cfg_node, only used in the
            message printed for values of an unexpected type. Defaults to None (the root).

    Returns:
        a dict with every nested CfgNode converted when cfg_node is a CfgNode,
        otherwise cfg_node itself, unchanged
    """
    # A None default avoids sharing one list object between separate calls
    key_list = [] if key_list is None else key_list

    if not isinstance(cfg_node, CfgNode):
        # Exact type comparison: instances of subclasses of the valid types are reported too
        if type(cfg_node) not in _VALID_TYPES:
            print(
                "Key {} with value {} is not a valid type; valid types: {}".format(
                    ".".join(key_list), type(cfg_node), _VALID_TYPES
                ),
            )
        return cfg_node

    return {key: convert_to_dict(value, key_list + [key]) for key, value in cfg_node.items()}


class PageData:
"""Class to process PAGE xml files"""
Expand Down Expand Up @@ -211,27 +232,71 @@ def new_page(self, name, rows, cols):
"""create a new PAGE xml"""
self.xml = ET.Element("PcGts")
self.xml.attrib = self.XMLNS
metadata = ET.SubElement(self.xml, "Metadata")
ET.SubElement(metadata, "Creator").text = self.creator
ET.SubElement(metadata, "Created").text = datetime.datetime.today().strftime("%Y-%m-%dT%X")
ET.SubElement(metadata, "LastChange").text = datetime.datetime.today().strftime("%Y-%m-%dT%X")
self.metadata = ET.SubElement(self.xml, "Metadata")
ET.SubElement(self.metadata, "Creator").text = self.creator
ET.SubElement(self.metadata, "Created").text = datetime.datetime.today().strftime("%Y-%m-%dT%X")
ET.SubElement(self.metadata, "LastChange").text = datetime.datetime.today().strftime("%Y-%m-%dT%X")
self.page = ET.SubElement(self.xml, "Page")
self.page.attrib = {
"imageFilename": name,
"imageWidth": cols,
"imageHeight": rows,
}

def add_element(self, r_class, r_id, r_type, r_coords, parent=None):
def add_processing_step(self, git_hash: str, uuid: str, cfg: CfgNode, whitelist: Iterable[str]):
    """
    Record this run as a processing step in the metadata of the page

    Adds a ``MetadataItem`` element to ``self.metadata`` holding a ``Labels`` element with
    one ``Label`` for the git hash, one for the uuid and one for every whitelisted config key.

    Args:
        git_hash (str): value of the ``githash`` label
        uuid (str): value of the ``uuid`` label
        cfg (CfgNode): config in which the whitelisted keys are looked up
        whitelist (Iterable[str]): config keys to store, with nested keys separated by dots

    Raises:
        TypeError: if one of the arguments or self.metadata is None, or if a sub key is
            looked up in a config value that can not be indexed with it
        KeyError: if a whitelisted key is not present in the config
    """
    if git_hash is None:
        raise TypeError("git_hash is None")
    if uuid is None:
        raise TypeError("uuid is None")
    if cfg is None:
        raise TypeError("cfg is None")
    if whitelist is None:
        raise TypeError("whitelist is None")
    if self.metadata is None:
        raise TypeError("self.metadata is None")

    processing_step = ET.SubElement(
        self.metadata,
        "MetadataItem",
        {
            "type": "processingStep",
            "name": "layout-analysis",
            "value": "laypa",
        },
    )
    labels = ET.SubElement(processing_step, "Labels")
    ET.SubElement(labels, "Label", {"type": "githash", "value": git_hash})
    ET.SubElement(labels, "Label", {"type": "uuid", "value": uuid})

    for key in whitelist:
        # Walk down the config one dot-separated part at a time
        sub_node = cfg
        for sub_key in key.split("."):
            try:
                sub_node = sub_node[sub_key]
            except (KeyError, TypeError):
                # TypeError: the previous part resolved to a leaf value instead of a node.
                # Log which part failed, then let the original exception propagate
                self.logger.error(f"No key {key} in config, missing sub key {sub_key}")
                raise
        ET.SubElement(labels, "Label", {"type": key, "value": str(convert_to_dict(sub_node))})

def add_element(self, region_class, region_id, region_type, region_coords, parent=None):
    """
    Add a region element with its coordinates to a parent node

    Args:
        region_class: tag name of the new element
        region_id: id of the new element, converted to a string
        region_type: type written into the ``custom`` attribute of the new element
        region_coords: value of the ``points`` attribute of the ``Coords`` child
        parent: element to attach the new element to. Defaults to None, in which case
            self.page is used

    Returns:
        the newly created element
    """
    parent = self.page if parent is None else parent
    t_reg = ET.SubElement(parent, region_class)
    t_reg.attrib = {
        "id": str(region_id),
        # No separator between the type and the closing ";", giving "structure {type:<type>;}"
        "custom": f"structure {{type:{region_type};}}",
    }
    ET.SubElement(t_reg, "Coords").attrib = {"points": region_coords}
    return t_reg

def remove_element(self, element, parent=None):
Expand Down
4 changes: 4 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def get_arguments() -> argparse.Namespace:
io_args.add_argument("-i", "--input", nargs="+", help="Input folder", type=str, action="extend", required=True)
io_args.add_argument("-o", "--output", help="Output folder", type=str, required=True)

parser.add_argument("-w", "--whitelist", nargs="+", help="Input folder", type=str, action="extend")

args = parser.parse_args()

return args
Expand Down Expand Up @@ -290,6 +292,8 @@ def main(args: argparse.Namespace) -> None:
regions=cfg.PREPROCESS.REGION.REGIONS,
merge_regions=cfg.PREPROCESS.REGION.MERGE_REGIONS,
region_type=cfg.PREPROCESS.REGION.REGION_TYPE,
cfg=cfg,
whitelist=args.whitelist,
)

predictor = SavePredictor(
Expand Down

0 comments on commit 8c1e175

Please sign in to comment.