Commit 193cfc2: rework almost done, still has unbound warning

stefanklut committed Dec 6, 2023
1 parent 4c401ce commit 193cfc2
Showing 3 changed files with 156 additions and 149 deletions.
10 changes: 6 additions & 4 deletions core/preprocess.py
@@ -72,16 +72,17 @@ def preprocess_datasets(
train_output_dir = output_dir.joinpath("train")
process.set_input_paths(train)
process.set_output_dir(train_output_dir)
train_image_paths = get_file_paths(train, supported_image_formats)
process.run()

if save_image_locations:
if process.input_paths is None:
raise TypeError("Cannot run when the input path is None")
# Saving the images used to a txt file
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
train_image_output_path = Path(cfg.OUTPUT_DIR).joinpath("training_images.txt")

with train_image_output_path.open(mode="w") as f:
for path in train_image_paths:
for path in process.input_paths:
f.write(f"{path}\n")

val_output_dir = None
@@ -93,16 +94,17 @@ def preprocess_datasets(
val_output_dir = output_dir.joinpath("val")
process.set_input_paths(val)
process.set_output_dir(val_output_dir)
val_image_paths = get_file_paths(val, supported_image_formats)
process.run()

if save_image_locations:
if process.input_paths is None:
raise TypeError("Cannot run when the input path is None")
# Saving the images used to a txt file
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
val_image_output_path = Path(cfg.OUTPUT_DIR).joinpath("validation_images.txt")

with val_image_output_path.open(mode="w") as f:
for path in val_image_paths:
for path in process.input_paths:
f.write(f"{path}\n")

dataset.register_datasets(
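For context, both hunks above replace a separately scanned path list with the list the preprocessing object actually consumed, guarded against None before writing. A minimal standalone sketch of that pattern (this Process class is a hypothetical stand-in, reduced to what the hunks exercise):

from __future__ import annotations

from pathlib import Path


class Process:
    """Hypothetical stand-in for the real preprocessing class."""

    def __init__(self) -> None:
        self.input_paths: list[Path] | None = None

    def set_input_paths(self, paths: list[Path]) -> None:
        self.input_paths = paths


def save_image_locations(process: Process, output_path: Path) -> None:
    # Persist exactly the paths the process was configured with,
    # instead of re-deriving them with a second directory scan.
    if process.input_paths is None:
        raise TypeError("Cannot run when the input path is None")
    with output_path.open(mode="w") as f:
        for path in process.input_paths:
            f.write(f"{path}\n")


process = Process()
process.set_input_paths([Path("a.jpg"), Path("b.jpg")])
save_image_locations(process, Path("training_images.txt"))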
266 changes: 133 additions & 133 deletions eval.py
@@ -16,9 +16,13 @@

from core.preprocess import preprocess_datasets
from core.setup import setup_cfg
from datasets.dataset import metadata_from_classes
from page_xml.xml_converter import XMLConverter
from run import Predictor
from utils.image_utils import load_image_array_from_path, save_image_array_to_path
from utils.input_utils import get_file_paths, supported_image_formats
from utils.logging_utils import get_logger_name
from utils.path_utils import image_path_to_xml_path
from utils.tempdir import OptionalTemporaryDirectory

logger = logging.getLogger(get_logger_name())
@@ -90,40 +94,44 @@ def main(args) -> None:
cfg = setup_cfg(args)

with OptionalTemporaryDirectory(name=args.tmp_dir, cleanup=not args.keep_tmp_dir) as tmp_dir:
preprocess_datasets(cfg, None, args.input, tmp_dir, save_image_locations=False)
predictor = Predictor(cfg=cfg)
# preprocess_datasets(cfg, None, args.input, tmp_dir, save_image_locations=False)

xml_converter = XMLConverter(
mode=cfg.MODEL.MODE,
line_width=cfg.PREPROCESS.BASELINE.LINE_WIDTH,
regions=cfg.PREPROCESS.REGION.REGIONS,
merge_regions=cfg.PREPROCESS.REGION.MERGE_REGIONS,
region_type=cfg.PREPROCESS.REGION.REGION_TYPE,
)
metadata = metadata_from_classes(xml_converter.get_regions())

# train_loader = DatasetCatalog.get("train")
val_loader = DatasetCatalog.get("val")
metadata = MetadataCatalog.get("val")
# print(metadata)
image_paths = get_file_paths(args.input, supported_image_formats, cfg.PREPROCESS.DISABLE_CHECK)

predictor = Predictor(cfg=cfg)

@lru_cache(maxsize=10)
def load_image(filename):
image = load_image_array_from_path(filename, mode="color")
def load_image(path):
image = load_image_array_from_path(path, mode="color")
if image is None:
raise TypeError(f"Image {filename} is None, loading failed")
raise TypeError(f"Image {path} is None, loading failed")
return image

@lru_cache(maxsize=10)
def load_sem_seg(filename):
sem_seg_gt = load_image_array_from_path(filename, mode="grayscale")
if sem_seg_gt is None:
raise TypeError(f"Image {filename} is None, loading failed")
return sem_seg_gt
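load_image and load_sem_seg are wrapped in functools.lru_cache, so stepping back and forth between neighboring images does not hit the disk again; this works because the path argument is hashable and serves as the cache key. A tiny self-contained demo of that behavior (names hypothetical):

from functools import lru_cache


@lru_cache(maxsize=10)  # keep the 10 most recently used results
def expensive_load(path: str) -> str:
    print(f"miss: reading {path}")  # printed only on a cache miss
    return path.upper()


expensive_load("scan_001.jpg")      # miss: reading scan_001.jpg
expensive_load("scan_001.jpg")      # served from the cache, no print
print(expensive_load.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=10, currsize=1)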

@lru_cache(maxsize=10)
def create_gt_visualization(image_filename, sem_seg_filename):
image = load_image(image_filename)
sem_seg_gt = load_sem_seg(sem_seg_filename)
def create_gt_visualization(image_path):
image = load_image(image_path)
image = predictor.aug.get_transform(image).apply_image(image)
if image is None:
raise ValueError("image can not be None")
sem_seg_gt = xml_path = image_path_to_xml_path(image_path)
sem_seg_gt = xml_converter.to_sem_seg(xml_path, image_shape=(image.shape[0], image.shape[1]))
vis_im_gt = Visualizer(image.copy(), metadata=metadata, scale=1)
vis_im_gt = vis_im_gt.draw_sem_seg(sem_seg_gt, alpha=0.4)
return vis_im_gt.get_image()

@lru_cache(maxsize=10)
def create_pred_visualization(image_filename):
image = load_image(image_filename)
logger.info(f"Predict: {image_filename}")
def create_pred_visualization(image_path):
image = load_image(image_path)
logger.info(f"Predict: {image_path}")
outputs = predictor(image)
sem_seg = outputs[0]["sem_seg"]
sem_seg = torch.nn.functional.interpolate(
@@ -137,139 +145,131 @@ def create_pred_visualization(image_filename):
vis_im = vis_im.draw_sem_seg(sem_seg, alpha=0.4)
return vis_im.get_image()
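The prediction helper resizes the model's segmentation output back to the image resolution with torch.nn.functional.interpolate before drawing it; the exact arguments fall outside this hunk, so the following is a generic sketch of that resize step with illustrative shapes:

import torch
import torch.nn.functional as F

logits = torch.rand(1, 4, 256, 192)  # (N, classes, H, W) raw model output
# Upsample the logits to the display resolution before visualizing.
resized = F.interpolate(logits, size=(1024, 768), mode="bilinear", align_corners=False)
sem_seg = resized.argmax(dim=1)[0]   # (H, W) per-pixel class ids
print(sem_seg.shape)                 # torch.Size([1024, 768])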

fig, axes = plt.subplots(1, 2)
fig.tight_layout()
fig.canvas.mpl_connect("key_press_event", keypress)
fig.canvas.mpl_connect("close_event", on_close)
axes[0].axis("off")
axes[1].axis("off")

fig_manager = None
if not args.save:
fig_manager = plt.get_current_fig_manager()
fig_manager.window.showMaximized()

if args.save:
pbar = tqdm(total=len(val_loader), desc="Saving")

# for i, inputs in enumerate(np.random.choice(val_loader, 3)):
if args.sorted:
loader = os_sorted(val_loader, key=lambda x: x["file_name"])
loader = os_sorted(image_paths)
else:
loader = val_loader
random.shuffle(val_loader)
loader = image_paths
random.shuffle(image_paths)
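The --sorted branch uses natsort's os_sorted, which orders file names the way a file browser does (numerically aware) rather than lexicographically:

from natsort import os_sorted

paths = ["page_10.jpg", "page_2.jpg", "page_1.jpg"]
print(sorted(paths))     # ['page_1.jpg', 'page_10.jpg', 'page_2.jpg']
print(os_sorted(paths))  # ['page_1.jpg', 'page_2.jpg', 'page_10.jpg']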

bad_results = np.zeros(len(loader), dtype=bool)
delete_results = np.zeros(len(loader), dtype=bool)

i = 0
while 0 <= i < len(loader):
inputs = loader[i]

vis_gt = create_gt_visualization(inputs["file_name"], inputs["sem_seg_file_name"])
vis_pred = create_pred_visualization(inputs["original_file_name"])

# pano_gt = torch.IntTensor(rgb2id(cv2.imread(inputs["pan_seg_file_name"], cv2.IMREAD_COLOR)))
# print(inputs["segments_info"])

# vis_im = vis_im.draw_panoptic_seg(outputs["panoptic_seg"][0], outputs["panoptic_seg"][1])
# vis_im_gt = vis_im_gt.draw_panoptic_seg(pano_gt, [item | {"isthing": True} for item in inputs["segments_info"]])
if not args.save:
fig_manager.window.setWindowTitle(inputs["file_name"])

# HACK Just remove the previous axes, I can't find how to resize the image otherwise
axes[0].clear()
axes[1].clear()
axes[0].axis("off")
axes[1].axis("off")

axes[0].imshow(vis_pred)
axes[1].imshow(vis_gt)

if args.save is not None:
# TODO Move saving to separate function + Multiprocessing
pbar.update(1)
if args.save:
for image_path in tqdm(image_paths, desc="Saving Images"):
if args.save not in ["all", "both", "pred", "gt"]:
raise ValueError(f"{args.save} is not a valid save mode")
if args.save != "pred":
vis_gt = create_gt_visualization(image_path)
if args.save != "gt":
vis_pred = create_pred_visualization(image_path)

output_dir = Path(args.output)
if not output_dir.is_dir():
logger.info(f"Could not find output dir ({output_dir}), creating one at specified location")
output_dir.mkdir(parents=True)

if args.save in ["all", "both"]:
save_path = output_dir.joinpath(Path(inputs["file_name"]).stem + "_both.jpg")
# Save to 4K res
fig.set_size_inches(16, 9)
fig.savefig(str(save_path), dpi=240)
save_path = output_dir.joinpath(image_path.stem + "_both.jpg")

vis_gt = cv2.resize(vis_gt, (vis_pred.shape[1], vis_pred.shape[0]), interpolation=cv2.INTER_CUBIC)
save_image_array_to_path(save_path, np.hstack((vis_pred, vis_gt)))
if args.save in ["all", "pred"]:
save_path = output_dir.joinpath(Path(inputs["file_name"]).stem + "_pred.jpg")
save_path = output_dir.joinpath(image_path.stem + "_pred.jpg")
save_image_array_to_path(save_path, vis_pred)
if args.save in ["all", "gt"]:
save_path = output_dir.joinpath(Path(inputs["file_name"]).stem + "_gt.jpg")
save_path = output_dir.joinpath(image_path.stem + "_gt.jpg")
save_image_array_to_path(save_path, vis_gt)
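The "unbound warning" mentioned in the commit message most plausibly originates here: vis_gt and vis_pred are each assigned only under a condition, so a static checker cannot prove they are bound at the later uses, even though the save-mode checks keep the combinations consistent at runtime. One conventional way to silence it, sketched as an assumption rather than as part of this commit:

from __future__ import annotations

import numpy as np


def demo(save: str) -> None:
    # Pre-declare both arrays so a checker sees every later read as bound.
    vis_gt: np.ndarray | None = None
    vis_pred: np.ndarray | None = None
    if save != "pred":
        vis_gt = np.zeros((2, 2, 3), dtype=np.uint8)  # stands in for create_gt_visualization
    if save != "gt":
        vis_pred = np.zeros((2, 2, 3), dtype=np.uint8)  # stands in for create_pred_visualization
    if save in ["all", "both"]:
        assert vis_gt is not None and vis_pred is not None  # narrows Optional for the checker
        print(np.hstack((vis_pred, vis_gt)).shape)


demo("both")  # prints (2, 4, 3)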

i += 1
continue

if delete_results[i]:
fig.suptitle("Delete")
elif bad_results[i]:
fig.suptitle("Bad")
else:
fig.suptitle("")
# f.title(inputs["file_name"])
global _keypress_result
_keypress_result = None
fig.canvas.draw()
while _keypress_result is None:
plt.waitforbuttonpress()
if _keypress_result == "delete":
# print(i+1, f"{inputs['original_file_name']}: DELETE")
delete_results[i] = not delete_results[i]
bad_results[i] = False
elif _keypress_result == "bad":
# print(i+1, f"{inputs['original_file_name']}: BAD")
bad_results[i] = not bad_results[i]
delete_results[i] = False
elif _keypress_result == "forward":
# print(i+1, f"{inputs['original_file_name']}")
i += 1
elif _keypress_result == "back":
# print(i+1, f"{inputs['original_file_name']}: DELETE")
i -= 1
else:
fig, axes = plt.subplots(1, 2)
fig.tight_layout()
fig.canvas.mpl_connect("key_press_event", keypress)
fig.canvas.mpl_connect("close_event", on_close)
axes[0].axis("off")
axes[1].axis("off")
fig_manager = plt.get_current_fig_manager()
fig_manager.window.showMaximized()

if args.save:
pbar.close()

if args.output and (delete_results.any() or bad_results.any()):
output_dir = Path(args.output)
if not output_dir.is_dir():
logger.info(f"Could not find output dir ({output_dir}), creating one at specified location")
output_dir.mkdir(parents=True)
if delete_results.any():
output_delete = output_dir.joinpath("delete.txt")
with output_delete.open(mode="w") as f:
for i in delete_results.nonzero()[0]:
path = Path(loader[i]["original_file_name"])
line = path.relative_to(output_dir) if path.is_relative_to(output_dir) else path.resolve()
f.write(f"{line}\n")
if bad_results.any():
output_bad = output_dir.joinpath("bad.txt")
with output_bad.open(mode="w") as f:
for i in bad_results.nonzero()[0]:
path = Path(loader[i]["original_file_name"])
line = path.relative_to(output_dir) if path.is_relative_to(output_dir) else path.resolve()
f.write(f"{line}\n")

remaining_results = np.logical_not(np.logical_or(bad_results, delete_results))
if remaining_results.any():
output_remaining = output_dir.joinpath("correct.txt")
with output_remaining.open(mode="w") as f:
for i in remaining_results.nonzero()[0]:
path = Path(loader[i]["original_file_name"])
line = path.relative_to(output_dir) if path.is_relative_to(output_dir) else path.resolve()
f.write(f"{line}\n")
i = 0
while 0 <= i < len(loader):
image_path = loader[i]

vis_gt = create_gt_visualization(image_path)
vis_pred = create_pred_visualization(image_path)

# pano_gt = torch.IntTensor(rgb2id(cv2.imread(inputs["pan_seg_file_name"], cv2.IMREAD_COLOR)))
# print(inputs["segments_info"])

# vis_im = vis_im.draw_panoptic_seg(outputs["panoptic_seg"][0], outputs["panoptic_seg"][1])
# vis_im_gt = vis_im_gt.draw_panoptic_seg(pano_gt, [item | {"isthing": True} for item in inputs["segments_info"]])

fig_manager.window.setWindowTitle(str(image_path))

# HACK Just remove the previous axes, I can't find how to resize the image otherwise
axes[0].clear()
axes[1].clear()
axes[0].axis("off")
axes[1].axis("off")

axes[0].imshow(vis_pred)
axes[1].imshow(vis_gt)

if delete_results[i]:
fig.suptitle("Delete")
elif bad_results[i]:
fig.suptitle("Bad")
else:
fig.suptitle("")
# f.title(inputs["file_name"])
global _keypress_result
_keypress_result = None
fig.canvas.draw()
while _keypress_result is None:
plt.waitforbuttonpress()
if _keypress_result == "delete":
# print(i+1, f"{inputs['original_file_name']}: DELETE")
delete_results[i] = not delete_results[i]
bad_results[i] = False
elif _keypress_result == "bad":
# print(i+1, f"{inputs['original_file_name']}: BAD")
bad_results[i] = not bad_results[i]
delete_results[i] = False
elif _keypress_result == "forward":
# print(i+1, f"{inputs['original_file_name']}")
i += 1
elif _keypress_result == "back":
# print(i+1, f"{inputs['original_file_name']}: DELETE")
i -= 1
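The interactive branch blocks on plt.waitforbuttonpress until the key handler (defined earlier in eval.py, outside this diff) stores a command in the module-level _keypress_result. A reduced sketch of the pattern; the key-to-command mapping here is a guess, not the file's actual handler:

import matplotlib.pyplot as plt

_keypress_result = None


def keypress(event):
    # Translate a key press into a navigation command for the main loop.
    global _keypress_result
    _keypress_result = {"right": "forward", "left": "back",
                        "d": "delete", "b": "bad"}.get(event.key)


fig, ax = plt.subplots()
fig.canvas.mpl_connect("key_press_event", keypress)

# Main loop shape (commented out so the sketch does not block when run):
# _keypress_result = None
# fig.canvas.draw()
# while _keypress_result is None:
#     plt.waitforbuttonpress()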

if args.output and (delete_results.any() or bad_results.any()):
output_dir = Path(args.output)
if not output_dir.is_dir():
logger.info(f"Could not find output dir ({output_dir}), creating one at specified location")
output_dir.mkdir(parents=True)
if delete_results.any():
output_delete = output_dir.joinpath("delete.txt")
with output_delete.open(mode="w") as f:
for i in delete_results.nonzero()[0]:
path = Path(loader[i])
line = path.relative_to(output_dir) if path.is_relative_to(output_dir) else path.resolve()
f.write(f"{line}\n")
if bad_results.any():
output_bad = output_dir.joinpath("bad.txt")
with output_bad.open(mode="w") as f:
for i in bad_results.nonzero()[0]:
path = Path(loader[i])
line = path.relative_to(output_dir) if path.is_relative_to(output_dir) else path.resolve()
f.write(f"{line}\n")

remaining_results = np.logical_not(np.logical_or(bad_results, delete_results))
if remaining_results.any():
output_remaining = output_dir.joinpath("correct.txt")
with output_remaining.open(mode="w") as f:
for i in remaining_results.nonzero()[0]:
path = Path(loader[i])
line = path.relative_to(output_dir) if path.is_relative_to(output_dir) else path.resolve()
f.write(f"{line}\n")


if __name__ == "__main__":