Skip to content

Commit 75ca98f

Browse files
committed
Implement support for *Templated* detection.
Co-authored By: Tim Walter <tim.michelbach@hotmail.com>
1 parent 77f6fad commit 75ca98f

25 files changed

+4240
-744
lines changed

README.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Features
4040
The following colour checker detection algorithms are implemented:
4141

4242
- Segmentation
43+
- Templated
4344
- Machine learning inference via `Ultralytics YOLOv8 <https://github.com/ultralytics/ultralytics>`__
4445

4546
- The model is published on `HuggingFace <https://huggingface.co/colour-science/colour-checker-detection-models>`__,

colour_checker_detection/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,17 @@
3131
SETTINGS_SEGMENTATION_COLORCHECKER_CLASSIC,
3232
SETTINGS_SEGMENTATION_COLORCHECKER_NANO,
3333
SETTINGS_SEGMENTATION_COLORCHECKER_SG,
34+
SETTINGS_TEMPLATED_COLORCHECKER_CLASSIC,
3435
detect_colour_checkers_inference,
3536
detect_colour_checkers_segmentation,
37+
detect_colour_checkers_templated,
38+
extractor_inference,
39+
extractor_segmentation,
40+
extractor_templated,
3641
inferencer_default,
42+
plot_detection_results,
3743
segmenter_default,
44+
segmenter_templated,
3845
)
3946

4047
__author__ = "Colour Developers"
@@ -50,10 +57,17 @@
5057
"SETTINGS_SEGMENTATION_COLORCHECKER_CLASSIC",
5158
"SETTINGS_SEGMENTATION_COLORCHECKER_NANO",
5259
"SETTINGS_SEGMENTATION_COLORCHECKER_SG",
60+
"SETTINGS_TEMPLATED_COLORCHECKER_CLASSIC",
5361
"detect_colour_checkers_inference",
5462
"detect_colour_checkers_segmentation",
63+
"detect_colour_checkers_templated",
64+
"extractor_inference",
65+
"extractor_segmentation",
66+
"extractor_templated",
5567
"inferencer_default",
68+
"plot_detection_results",
5669
"segmenter_default",
70+
"segmenter_templated",
5771
]
5872

5973
ROOT_RESOURCES: str = os.path.join(os.path.dirname(__file__), "resources")

colour_checker_detection/detection/__init__.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@
1616
SETTINGS_DETECTION_COLORCHECKER_CLASSIC,
1717
SETTINGS_DETECTION_COLORCHECKER_SG,
1818
DataDetectionColourChecker,
19+
DataSegmentationColourCheckers,
1920
approximate_contour,
2021
as_float32_array,
2122
as_int32_array,
23+
cluster_swatches,
2224
contour_centroid,
2325
detect_contours,
26+
filter_clusters,
2427
is_square,
2528
quadrilateralise_contours,
2629
reformat_image,
@@ -38,6 +41,7 @@
3841
SETTINGS_INFERENCE_COLORCHECKER_CLASSIC,
3942
SETTINGS_INFERENCE_COLORCHECKER_CLASSIC_MINI,
4043
detect_colour_checkers_inference,
44+
extractor_inference,
4145
inferencer_default,
4246
)
4347

@@ -48,21 +52,47 @@
4852
SETTINGS_SEGMENTATION_COLORCHECKER_NANO,
4953
SETTINGS_SEGMENTATION_COLORCHECKER_SG,
5054
detect_colour_checkers_segmentation,
55+
extractor_segmentation,
5156
segmenter_default,
5257
)
5358

59+
# isort: split
60+
61+
from .templated import (
62+
SETTINGS_TEMPLATED_COLORCHECKER_CLASSIC,
63+
WarpingData,
64+
detect_colour_checkers_templated,
65+
extractor_templated,
66+
segmenter_templated,
67+
)
68+
from .templates import (
69+
PATH_TEMPLATE_COLORCHECKER_CLASSIC,
70+
PATH_TEMPLATE_COLORCHECKER_CREATIVE_ENHANCEMENT,
71+
ROOT_TEMPLATES,
72+
Template,
73+
generate_template,
74+
load_template,
75+
)
76+
77+
# isort: split
78+
79+
from .plotting import plot_detection_results
80+
5481
__all__ = [
5582
"DTYPE_FLOAT_DEFAULT",
5683
"DTYPE_INT_DEFAULT",
5784
"SETTINGS_CONTOUR_DETECTION_DEFAULT",
5885
"SETTINGS_DETECTION_COLORCHECKER_CLASSIC",
5986
"SETTINGS_DETECTION_COLORCHECKER_SG",
6087
"DataDetectionColourChecker",
88+
"DataSegmentationColourCheckers",
6189
"approximate_contour",
6290
"as_float32_array",
6391
"as_int32_array",
92+
"cluster_swatches",
6493
"contour_centroid",
6594
"detect_contours",
95+
"filter_clusters",
6696
"is_square",
6797
"quadrilateralise_contours",
6898
"reformat_image",
@@ -77,12 +107,32 @@
77107
"SETTINGS_INFERENCE_COLORCHECKER_CLASSIC",
78108
"SETTINGS_INFERENCE_COLORCHECKER_CLASSIC_MINI",
79109
"detect_colour_checkers_inference",
110+
"extractor_inference",
80111
"inferencer_default",
81112
]
82113
__all__ += [
83114
"SETTINGS_SEGMENTATION_COLORCHECKER_CLASSIC",
84115
"SETTINGS_SEGMENTATION_COLORCHECKER_NANO",
85116
"SETTINGS_SEGMENTATION_COLORCHECKER_SG",
86117
"detect_colour_checkers_segmentation",
118+
"extractor_segmentation",
87119
"segmenter_default",
88120
]
121+
__all__ += [
122+
"detect_colour_checkers_templated",
123+
"extractor_templated",
124+
"segmenter_templated",
125+
"SETTINGS_TEMPLATED_COLORCHECKER_CLASSIC",
126+
"WarpingData",
127+
]
128+
__all__ += [
129+
"Template",
130+
"generate_template",
131+
"load_template",
132+
"ROOT_TEMPLATES",
133+
"PATH_TEMPLATE_COLORCHECKER_CLASSIC",
134+
"PATH_TEMPLATE_COLORCHECKER_CREATIVE_ENHANCEMENT",
135+
]
136+
__all__ += [
137+
"plot_detection_results",
138+
]

colour_checker_detection/detection/common.py

Lines changed: 202 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import typing
2323
from dataclasses import dataclass
24+
from itertools import combinations
2425

2526
import cv2
2627
import numpy as np
@@ -35,7 +36,6 @@
3536
DTypeFloat,
3637
DTypeInt,
3738
Literal,
38-
NDArrayReal,
3939
Type,
4040
)
4141

@@ -75,13 +75,17 @@
7575
"reformat_image",
7676
"transform_image",
7777
"detect_contours",
78+
"is_quadrilateral",
7879
"is_square",
7980
"contour_centroid",
8081
"scale_contour",
82+
"cluster_swatches",
83+
"filter_clusters",
8184
"approximate_contour",
8285
"quadrilateralise_contours",
8386
"remove_stacked_contours",
8487
"DataDetectionColourChecker",
88+
"DataSegmentationColourCheckers",
8589
"sample_colour_checker",
8690
]
8791

@@ -664,6 +668,44 @@ def detect_contours(
664668
return contours
665669

666670

671+
def is_quadrilateral(points: NDArrayFloat) -> bool:
672+
"""
673+
Check if points form a quadrilateral (no three points are collinear).
674+
675+
Parameters
676+
----------
677+
points
678+
Points to check (should be 4 points).
679+
680+
Returns
681+
-------
682+
:class:`bool`
683+
True if points form a quadrilateral (no three collinear), False otherwise.
684+
685+
Notes
686+
-----
687+
This function checks that no three points are collinear, which ensures
688+
the 4 points form a proper quadrilateral suitable for perspective transformation.
689+
690+
Examples
691+
--------
692+
>>> points = np.array([[0, 0], [10, 0], [10, 10], [0, 10]], dtype=float)
693+
>>> is_quadrilateral(points)
694+
True
695+
>>> points = np.array([[0, 0], [5, 0], [10, 0], [0, 10]], dtype=float)
696+
>>> is_quadrilateral(points) # Three points collinear
697+
False
698+
"""
699+
700+
for pts in combinations(points, 3):
701+
matrix = np.column_stack((pts, np.ones(len(pts))))
702+
703+
if np.linalg.matrix_rank(matrix) < 3:
704+
return False
705+
706+
return True
707+
708+
667709
def is_square(contour: ArrayLike, tolerance: float = 0.015) -> bool:
668710
"""
669711
Return if specified contour is a square.
@@ -775,6 +817,140 @@ def scale_contour(contour: ArrayLike, factor: ArrayLike) -> NDArrayFloat:
775817
return (contour - centroid) * factor + centroid
776818

777819

820+
def cluster_swatches(
821+
image: NDArrayFloat, swatches: NDArrayInt, swatch_contour_scale: float
822+
) -> NDArrayInt:
823+
"""
824+
Cluster swatches by expanding them and fitting rectangles to overlapping areas.
825+
826+
Parameters
827+
----------
828+
image
829+
Image containing the swatches. Only used for its shape.
830+
swatches
831+
The swatches to cluster.
832+
swatch_contour_scale
833+
The scale by which to expand the swatches.
834+
835+
Returns
836+
-------
837+
:class:`NDArrayInt`
838+
The clusters of swatches.
839+
840+
Examples
841+
--------
842+
>>> import numpy as np
843+
>>> image = np.zeros((600, 900, 3))
844+
>>> swatches = np.array(
845+
... [
846+
... [[100, 100], [200, 100], [200, 200], [100, 200]],
847+
... [[300, 100], [400, 100], [400, 200], [300, 200]],
848+
... ],
849+
... dtype=np.int32,
850+
... )
851+
>>> cluster_swatches(image, swatches, 1.5)
852+
array([[[275, 75],
853+
[425, 75],
854+
[425, 225],
855+
[275, 225]],
856+
<BLANKLINE>
857+
[[ 75, 75],
858+
[225, 75],
859+
[225, 225],
860+
[ 75, 225]]], dtype=int32)
861+
"""
862+
863+
scaled_swatches = [
864+
scale_contour(swatch, swatch_contour_scale) for swatch in swatches
865+
]
866+
image_c = np.zeros(image.shape[:2], dtype=np.uint8)
867+
868+
cv2.drawContours(
869+
image_c, [as_int32_array(s) for s in scaled_swatches], -1, (255,), -1
870+
)
871+
872+
contours, _ = cv2.findContours(image_c, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
873+
874+
return as_int32_array(
875+
[cv2.boxPoints(cv2.minAreaRect(contour)) for contour in contours]
876+
)
877+
878+
879+
def filter_clusters(
880+
clusters: NDArrayInt,
881+
swatches: NDArrayInt,
882+
swatches_count_minimum: int,
883+
swatches_count_maximum: int,
884+
) -> NDArrayInt:
885+
"""
886+
Filter clusters by the number of swatches they contain.
887+
888+
Parameters
889+
----------
890+
clusters
891+
The clusters to filter.
892+
swatches
893+
The swatches to count within each cluster.
894+
swatches_count_minimum
895+
Minimum number of swatches required in a cluster.
896+
swatches_count_maximum
897+
Maximum number of swatches allowed in a cluster.
898+
899+
Returns
900+
-------
901+
:class:`NDArrayInt`
902+
The filtered clusters that contain the expected number of swatches.
903+
904+
Examples
905+
--------
906+
>>> import numpy as np
907+
>>> clusters = np.array(
908+
... [
909+
... [[0, 0], [200, 0], [200, 200], [0, 200]],
910+
... [[300, 300], [400, 300], [400, 400], [300, 400]],
911+
... ],
912+
... dtype=np.int32,
913+
... )
914+
>>> swatches = np.array(
915+
... [
916+
... [[50, 50], [100, 50], [100, 100], [50, 100]],
917+
... [[350, 350], [380, 350], [380, 380], [350, 380]],
918+
... ],
919+
... dtype=np.int32,
920+
... )
921+
>>> filter_clusters(clusters, swatches, 1, 2)
922+
array([[[ 0, 0],
923+
[200, 0],
924+
[200, 200],
925+
[ 0, 200]],
926+
<BLANKLINE>
927+
[[300, 300],
928+
[400, 300],
929+
[400, 400],
930+
[300, 400]]], dtype=int32)
931+
"""
932+
933+
if len(clusters) == 0 or len(swatches) == 0:
934+
return as_int32_array([]).reshape(0, 4, 2)
935+
936+
filtered_clusters = []
937+
for cluster in clusters:
938+
count = 0
939+
for swatch in swatches:
940+
centroid = contour_centroid(swatch)
941+
if cv2.pointPolygonTest(cluster, centroid, False) >= 0:
942+
count += 1
943+
944+
if swatches_count_minimum <= count <= swatches_count_maximum:
945+
filtered_clusters.append(cluster)
946+
947+
return (
948+
as_int32_array(filtered_clusters)
949+
if len(filtered_clusters) > 0
950+
else as_int32_array([]).reshape(0, 4, 2)
951+
)
952+
953+
778954
def approximate_contour(
779955
contour: ArrayLike, points: int = 4, iterations: int = 100
780956
) -> NDArrayInt:
@@ -991,6 +1167,31 @@ class DataDetectionColourChecker(MixinDataclassIterable):
9911167
quadrilateral: NDArrayFloat
9921168

9931169

1170+
@dataclass
1171+
class DataSegmentationColourCheckers(MixinDataclassIterable):
1172+
"""
1173+
Colour checkers detection data used for plotting, debugging and further
1174+
analysis.
1175+
1176+
Parameters
1177+
----------
1178+
rectangles
1179+
Colour checker bounding boxes, i.e., the clusters that have the
1180+
relevant count of swatches.
1181+
clusters
1182+
Detected swatches clusters.
1183+
swatches
1184+
Detected swatches.
1185+
segmented_image
1186+
Segmented image.
1187+
"""
1188+
1189+
rectangles: NDArrayInt
1190+
clusters: NDArrayInt
1191+
swatches: NDArrayInt
1192+
segmented_image: NDArrayFloat
1193+
1194+
9941195
def sample_colour_checker(
9951196
image: ArrayLike,
9961197
quadrilateral: ArrayLike,

0 commit comments

Comments
 (0)