Skip to content

Commit 0dbe3f8

Browse files
committed
Fix masked contrast enhancement
1 parent 59495e3 commit 0dbe3f8

File tree

2 files changed

+25
-32
lines changed

2 files changed

+25
-32
lines changed

optimap/_cpp/contrast_enhancement.h

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,45 +13,30 @@
1313
#include "definitions.h"
1414
#include "vectors.h"
1515

16-
17-
template <typename Derived, typename Derived2>
18-
auto minmax_kernel(const Derived &block, const Derived2 &kernel) {
19-
static_assert(std::is_same_v<typename Derived2::value_type, bool>);
20-
assert(kernel.shape() == block.shape());
21-
22-
using Scalar = typename Derived::value_type;
23-
Scalar max = std::numeric_limits<Scalar>::lowest();
24-
Scalar min = std::numeric_limits<Scalar>::max();
25-
26-
for (size_t row = 0; row < kernel.shape(0); row++) {
27-
for (size_t col = 0; col < kernel.shape(1); col++) {
28-
if (kernel(row, col)) {
29-
max = std::max(max, block(row, col));
30-
min = std::min(min, block(row, col));
31-
}
32-
}
33-
}
34-
return std::make_pair(min, max);
35-
}
36-
3716
template <typename Derived, typename Derived2, typename Derived3>
38-
auto minmax_kernel_masked(const Derived &block, const Derived2 &kernel, const Derived3 &mask) {
17+
auto minmax_kernel(const Derived &block, const Derived2 &kernel, const Derived3 &mask) {
3918
assert(kernel.shape() == block.shape());
4019
static_assert(std::is_same_v<typename Derived2::value_type, bool>);
41-
static_assert(std::is_same_v<typename Derived2::value_type, typename Derived3::value_type>);
20+
static_assert(std::is_same_v<typename Derived3::value_type, bool>);
4221

4322
using Scalar = typename Derived::value_type;
4423
Scalar max = std::numeric_limits<Scalar>::lowest();
4524
Scalar min = std::numeric_limits<Scalar>::max();
4625

4726
for (size_t row = 0; row < kernel.shape(0); row++) {
4827
for (size_t col = 0; col < kernel.shape(1); col++) {
49-
if (kernel(row, col) && mask(row, col)) {
28+
bool in_mask = mask.size() <= 1 || mask(row, col);
29+
if (kernel(row, col) && in_mask) {
5030
max = std::max(max, block(row, col));
5131
min = std::min(min, block(row, col));
5232
}
5333
}
5434
}
35+
36+
if (max == std::numeric_limits<Scalar>::lowest() || min == std::numeric_limits<Scalar>::max()) {
37+
max = 0;
38+
min = 0;
39+
}
5540
return std::make_pair(min, max);
5641
}
5742

@@ -158,15 +143,11 @@ void _contrast_enhancement_padded(T &out,
158143

159144
const auto block = xt::view(img, xt::range(startx, endx), xt::range(starty, endy));
160145
const auto kernel_block = xt::view(kernel, xt::range(mstartx, mendx), xt::range(mstarty, mendy));
146+
const auto mask_block = (mask.size() > 1)? xt::view(mask, xt::range(startx, endx), xt::range(starty, endy)) : Array2b();
161147

162-
float min, max;
163-
if (mask.size() > 1) {
164-
const auto mask_block = xt::view(mask, xt::range(startx, endx), xt::range(starty, endy));
165-
std::tie(min, max) = minmax_kernel_masked(block, kernel_block, mask_block);
166-
} else {
167-
std::tie(min, max) = minmax_kernel(block, kernel_block);
168-
}
169-
if (max != min) {
148+
const bool in_mask = mask.size() <= 1 || mask(row, col);
149+
auto [min, max] = minmax_kernel(block, kernel_block, mask_block);
150+
if (max != min && in_mask) {
170151
out(row, col) = (img(row, col) - min) / (max - min);
171152
} else {
172153
out(row, col) = 0;

tests/test_motion.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ def test_contrast_enhancement():
3232
out2 = om.motion.contrast_enhancement(vid, 3, mask)
3333
assert np.all(out2 == 0)
3434

35+
vid = np.random.random((10, 32, 32)).astype(np.float32)
36+
out = om.motion.contrast_enhancement(vid, 3)
37+
assert out.shape == vid.shape
38+
assert out.min() == 0
39+
assert out.max() == 1
40+
41+
mask = np.random.random((10, 32, 32)) > 0.5
42+
out = om.motion.contrast_enhancement(vid, 3, mask)
43+
assert out.shape == vid.shape
44+
assert out.min() == 0
45+
assert out.max() == 1
46+
3547
def test_flowestimator():
3648
estimator = om.motion.FlowEstimator()
3749
vid = np.ones((10, 128, 128), dtype=np.float32)

0 commit comments

Comments
 (0)