Skip to content

Commit

Permalink
refactor: ColorMatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
MistEO committed Mar 27, 2024
1 parent e675a5d commit a153b3f
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 49 deletions.
47 changes: 24 additions & 23 deletions source/MaaFramework/Vision/ColorMatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,18 @@

MAA_VISION_NS_BEGIN

std::pair<ColorMatcher::ResultsVec, size_t> ColorMatcher::analyze() const
ColorMatcher::ColorMatcher(ColorMatcherParam param, cv::Mat image, std::string name)
: VisionBase(std::move(image), std::move(name))
, param_(std::move(param))
{
ResultsVec all_results;
}

void ColorMatcher::analyze()
{
if (analyzed_) {
return;
}
analyzed_ = true;

for (const auto& range : param_.range) {
auto start_time = std::chrono::steady_clock::now();
Expand All @@ -21,6 +30,7 @@ std::pair<ColorMatcher::ResultsVec, size_t> ColorMatcher::analyze() const
auto cost = duration_since(start_time);
LogTrace << name_ << "Raw:" << VAR(results) << VAR(range.first) << VAR(range.second)
<< VAR(connected) << VAR(cost);
raw_results_.insert(raw_results_.end(), results.begin(), results.end());

int count = param_.count;
filter(results, count);
Expand All @@ -29,20 +39,18 @@ std::pair<ColorMatcher::ResultsVec, size_t> ColorMatcher::analyze() const
LogTrace << name_ << "Filter:" << VAR(results) << VAR(range.first) << VAR(range.second)
<< VAR(count) << VAR(connected) << VAR(cost);

all_results.insert(
all_results.end(),
filtered_results_.insert(
filtered_results_.end(),
std::make_move_iterator(results.begin()),
std::make_move_iterator(results.end()));
}

sort(all_results);
size_t index = preferred_index(all_results);

return { all_results, index };
sort(filtered_results_);
handle_index(filtered_results_.size(), param_.result_index);
}

ColorMatcher::ResultsVec
ColorMatcher::foreach_rois(const ColorMatcherParam::Range& range, bool connected) const
ColorMatcher::foreach_rois(const ColorMatcherParam::Range& range, bool connected)
{
if (param_.roi.empty()) {
return { color_match(cv::Rect(0, 0, image_.cols, image_.rows), range, connected) };
Expand All @@ -63,7 +71,7 @@ ColorMatcher::ResultsVec
ColorMatcher::ResultsVec ColorMatcher::color_match(
const cv::Rect& roi,
const ColorMatcherParam::Range& range,
bool connected) const
bool connected)
{
cv::Mat image = image_with_roi(roi);
cv::Mat color;
Expand All @@ -74,7 +82,10 @@ ColorMatcher::ResultsVec ColorMatcher::color_match(
ResultsVec results =
connected ? count_non_zero_with_connected(bin, roi.tl()) : count_non_zero(bin, roi.tl());

draw_result(roi, color, bin, results);
if (debug_draw_) {
auto draw = draw_result(roi, color, bin, results);
handle_draw(draw);
}
return results;
}

Expand Down Expand Up @@ -114,7 +125,7 @@ ColorMatcher::ResultsVec
return NMS_for_count(std::move(results), 0.7);
}

void ColorMatcher::draw_result(
cv::Mat ColorMatcher::draw_result(
const cv::Rect& roi,
const cv::Mat& color,
const cv::Mat& bin,
Expand Down Expand Up @@ -174,7 +185,7 @@ void ColorMatcher::draw_result(

// cv::line(image_draw, cv::Point(raw_width + color.cols, 0), res.box.tl(), color_draw, 1);

handle_draw(image_draw);
return image_draw;
}

void ColorMatcher::filter(ResultsVec& results, int count) const
Expand Down Expand Up @@ -206,14 +217,4 @@ void ColorMatcher::sort(ResultsVec& results) const
}
}

size_t ColorMatcher::preferred_index(const ResultsVec& results) const
{
auto index_opt = pythonic_index(results.size(), param_.result_index);
if (!index_opt) {
return SIZE_MAX;
}

return *index_opt;
}

MAA_VISION_NS_END
26 changes: 16 additions & 10 deletions source/MaaFramework/Vision/ColorMatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,35 @@ class ColorMatcher : public VisionBase
using ResultsVec = std::vector<Result>;

public:
void set_param(ColorMatcherParam param) { param_ = std::move(param); }
ColorMatcher(ColorMatcherParam param, cv::Mat image, std::string name);

std::pair<ResultsVec, size_t> analyze() const;
void analyze();

const ResultsVec& raw_results() const { return raw_results_; }

const ResultsVec& filtered_results() const { return filtered_results_; }

private:
ResultsVec foreach_rois(const ColorMatcherParam::Range& range, bool connected) const;
ResultsVec color_match(
const cv::Rect& roi,
const ColorMatcherParam::Range& range,
bool connected) const;
ResultsVec foreach_rois(const ColorMatcherParam::Range& range, bool connected);
ResultsVec
color_match(const cv::Rect& roi, const ColorMatcherParam::Range& range, bool connected);
ResultsVec count_non_zero(const cv::Mat& bin, const cv::Point& tl) const;
ResultsVec count_non_zero_with_connected(const cv::Mat& bin, const cv::Point& tl) const;
void draw_result(
cv::Mat draw_result(
const cv::Rect& roi,
const cv::Mat& color,
const cv::Mat& bin,
const ResultsVec& results) const;

void filter(ResultsVec& results, int count) const;
void sort(ResultsVec& results) const;
size_t preferred_index(const ResultsVec& results) const;

ColorMatcherParam param_;
private:
const ColorMatcherParam param_;

bool analyzed_ = false;
ResultsVec raw_results_;
ResultsVec filtered_results_;
};

MAA_VISION_NS_END
30 changes: 20 additions & 10 deletions source/MaaFramework/Vision/VisionBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,13 @@

MAA_VISION_NS_BEGIN

void VisionBase::set_image(const cv::Mat& image)
VisionBase::VisionBase(cv::Mat image, std::string name)
: image_(std::move(image))
, name_(std::move(name))
{
image_ = image;
init_debug_draw();
}

void VisionBase::set_name(std::string name)
{
name_ = std::move(name);
}

cv::Mat VisionBase::image_with_roi(const cv::Rect& roi) const
{
cv::Rect roi_corrected = correct_roi(roi, image_);
Expand Down Expand Up @@ -56,19 +52,33 @@ cv::Mat VisionBase::draw_roi(const cv::Rect& roi, const cv::Mat& base) const
return image_draw;
}

void VisionBase::handle_draw(const cv::Mat& draw) const
void VisionBase::handle_draw(const cv::Mat& draw)
{
draws_.emplace_back(draw);

if (save_draw_) {
save_image(draw);
draw_paths_.emplace_back(save_image(draw));
}
}

void VisionBase::save_image(const cv::Mat& image) const
void VisionBase::handle_index(size_t total, int index)
{
auto index_opt = pythonic_index(total, index);
if (!index_opt) {
preferred_index_ = SIZE_MAX;
return;
}

preferred_index_ = *index_opt;
}

std::filesystem::path VisionBase::save_image(const cv::Mat& image) const
{
std::string filename = std::format("{}_{}.png", name_, format_now_for_filename());
auto filepath = GlobalOptionMgr::get_instance().log_dir() / "vision" / path(filename);
MAA_NS::imwrite(filepath, image);
LogDebug << "save image to" << filepath;
return filepath;
}

void VisionBase::init_debug_draw()
Expand Down
23 changes: 17 additions & 6 deletions source/MaaFramework/Vision/VisionBase.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <filesystem>

#include "Conf/Conf.h"
#include "Utils/JsonExt.hpp"
#include "Utils/NoWarningCVMat.hpp"
Expand All @@ -9,28 +11,37 @@ MAA_VISION_NS_BEGIN
class VisionBase
{
public:
void set_image(const cv::Mat& image);
void set_name(std::string name);
VisionBase(cv::Mat image, std::string name);

const std::vector<cv::Mat>& draws() const { return draws_; }

const std::vector<std::filesystem::path>& draw_paths() const { return draw_paths_; }

size_t preferred_index() const { return preferred_index_; }

protected:
cv::Mat image_with_roi(const cv::Rect& roi) const;

protected:
cv::Mat draw_roi(const cv::Rect& roi, const cv::Mat& base = cv::Mat()) const;
void handle_draw(const cv::Mat& draw) const;
void handle_draw(const cv::Mat& draw);
void handle_index(size_t total, int index);

protected:
cv::Mat image_ {};
std::string name_;
const cv::Mat image_;
const std::string name_;

bool debug_draw_ = false;

private:
void init_debug_draw();
void save_image(const cv::Mat& image) const;
std::filesystem::path save_image(const cv::Mat& image) const;

private:
bool save_draw_ = false;
std::vector<cv::Mat> draws_;
std::vector<std::filesystem::path> draw_paths_;
size_t preferred_index_ = SIZE_MAX;
};

MAA_VISION_NS_END

0 comments on commit a153b3f

Please sign in to comment.