Skip to content

Commit

Permalink
Fix masked contrast enhancement
Browse files Browse the repository at this point in the history
  • Loading branch information
sitic committed Mar 8, 2024
1 parent 59495e3 commit 0dbe3f8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 32 deletions.
45 changes: 13 additions & 32 deletions optimap/_cpp/contrast_enhancement.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,45 +13,30 @@
#include "definitions.h"
#include "vectors.h"


template <typename Derived, typename Derived2>
auto minmax_kernel(const Derived &block, const Derived2 &kernel) {
static_assert(std::is_same_v<typename Derived2::value_type, bool>);
assert(kernel.shape() == block.shape());

using Scalar = typename Derived::value_type;
Scalar max = std::numeric_limits<Scalar>::lowest();
Scalar min = std::numeric_limits<Scalar>::max();

for (size_t row = 0; row < kernel.shape(0); row++) {
for (size_t col = 0; col < kernel.shape(1); col++) {
if (kernel(row, col)) {
max = std::max(max, block(row, col));
min = std::min(min, block(row, col));
}
}
}
return std::make_pair(min, max);
}

template <typename Derived, typename Derived2, typename Derived3>
auto minmax_kernel_masked(const Derived &block, const Derived2 &kernel, const Derived3 &mask) {
auto minmax_kernel(const Derived &block, const Derived2 &kernel, const Derived3 &mask) {
assert(kernel.shape() == block.shape());
static_assert(std::is_same_v<typename Derived2::value_type, bool>);
static_assert(std::is_same_v<typename Derived2::value_type, typename Derived3::value_type>);
static_assert(std::is_same_v<typename Derived3::value_type, bool>);

using Scalar = typename Derived::value_type;
Scalar max = std::numeric_limits<Scalar>::lowest();
Scalar min = std::numeric_limits<Scalar>::max();

for (size_t row = 0; row < kernel.shape(0); row++) {
for (size_t col = 0; col < kernel.shape(1); col++) {
if (kernel(row, col) && mask(row, col)) {
bool in_mask = mask.size() <= 1 || mask(row, col);
if (kernel(row, col) && in_mask) {
max = std::max(max, block(row, col));
min = std::min(min, block(row, col));
}
}
}

if (max == std::numeric_limits<Scalar>::lowest() || min == std::numeric_limits<Scalar>::max()) {
max = 0;
min = 0;
}
return std::make_pair(min, max);
}

Expand Down Expand Up @@ -158,15 +143,11 @@ void _contrast_enhancement_padded(T &out,

const auto block = xt::view(img, xt::range(startx, endx), xt::range(starty, endy));
const auto kernel_block = xt::view(kernel, xt::range(mstartx, mendx), xt::range(mstarty, mendy));
const auto mask_block = (mask.size() > 1)? xt::view(mask, xt::range(startx, endx), xt::range(starty, endy)) : Array2b();

float min, max;
if (mask.size() > 1) {
const auto mask_block = xt::view(mask, xt::range(startx, endx), xt::range(starty, endy));
std::tie(min, max) = minmax_kernel_masked(block, kernel_block, mask_block);
} else {
std::tie(min, max) = minmax_kernel(block, kernel_block);
}
if (max != min) {
const bool in_mask = mask.size() <= 1 || mask(row, col);
auto [min, max] = minmax_kernel(block, kernel_block, mask_block);
if (max != min && in_mask) {
out(row, col) = (img(row, col) - min) / (max - min);
} else {
out(row, col) = 0;
Expand Down
12 changes: 12 additions & 0 deletions tests/test_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ def test_contrast_enhancement():
out2 = om.motion.contrast_enhancement(vid, 3, mask)
assert np.all(out2 == 0)

vid = np.random.random((10, 32, 32)).astype(np.float32)
out = om.motion.contrast_enhancement(vid, 3)
assert out.shape == vid.shape
assert out.min() == 0
assert out.max() == 1

mask = np.random.random((10, 32, 32)) > 0.5
out = om.motion.contrast_enhancement(vid, 3, mask)
assert out.shape == vid.shape
assert out.min() == 0
assert out.max() == 1

def test_flowestimator():
estimator = om.motion.FlowEstimator()
vid = np.ones((10, 128, 128), dtype=np.float32)
Expand Down

0 comments on commit 0dbe3f8

Please sign in to comment.