Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use rounding for treatments #225

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion .github/workflows/clang-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ jobs:
docker build -t doozyx/clang-format-lint-action "github.com/DoozyX/clang-format-lint-action"
- name: Run clang-format lint
run: |
docker run --rm --workdir /src -v $(pwd):/src doozyx/clang-format-lint-action --clang-format-executable /clang-format/clang-format10 -r --exclude .git include/*/*.hpp tests/*.cpp
docker run --rm --workdir /src -v $(pwd):/src doozyx/clang-format-lint-action --clang-format-executable /clang-format/clang-format18 -r --exclude .git include/*/*.hpp tests/*.cpp
45 changes: 32 additions & 13 deletions include/pops/host_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ class HostPool : public HostPoolInterface<RasterIndex>
* @brief Completely remove any hosts
*
* Removes hosts completely (as opposed to moving them to another pool).
* If mortality is not active, the *mortality* parameter is ignored.
*
* @param row Row index of the cell
* @param col Column index of the cell
Expand All @@ -524,9 +525,6 @@ class HostPool : public HostPoolInterface<RasterIndex>
* @param infected Number of infected hosts to remove.
* @param mortality Number of infected hosts in each mortality cohort.
*
* @note Counts are doubles, so that handling of floating point values is managed
* here in the same way as in the original treatment code.
*
* @note This does not remove resistant just like the original implementation in
* treatments.
*/
Expand Down Expand Up @@ -557,6 +555,11 @@ class HostPool : public HostPoolInterface<RasterIndex>
// Possibly reuse in the I->S removal.
if (infected <= 0)
return;
if (!mortality_tracker_vector_.size()) {
infected_(row, col) -= infected;
reset_total_host(row, col);
return;
}
if (mortality_tracker_vector_.size() != mortality.size()) {
throw std::invalid_argument(
"mortality is not the same size as the internal mortality tracker ("
Expand All @@ -565,8 +568,8 @@ class HostPool : public HostPoolInterface<RasterIndex>
+ std::to_string(row) + ", " + std::to_string(col) + ")");
}

double mortality_total = 0;
for (size_t i = 0; i < mortality.size(); ++i) {
int mortality_total = 0;
for (size_t i = 0; i < mortality_tracker_vector_.size(); ++i) {
if (mortality_tracker_vector_[i](row, col) < mortality[i]) {
throw std::invalid_argument(
"Mortality value [" + std::to_string(i) + "] is too high ("
Expand All @@ -583,20 +586,20 @@ class HostPool : public HostPoolInterface<RasterIndex>
// and once we don't need to keep the exact same double to int results for
// tests. First condition always fails the tests. The second one may potentially
// fail.
if (false && infected != mortality_total) {
if (infected != mortality_total) {
throw std::invalid_argument(
"Total of removed mortality values differs from removed infected "
"count ("
+ std::to_string(mortality_total) + " != " + std::to_string(infected)
+ " for cell (" + std::to_string(row) + ", " + std::to_string(col)
+ ") for cell (" + std::to_string(row) + ", " + std::to_string(col)
+ ")");
}
if (false && infected_(row, col) < mortality_total) {
if (infected_(row, col) < mortality_total) {
throw std::invalid_argument(
"Total of removed mortality values is higher than current number "
"of infected hosts for cell ("
+ std::to_string(row) + ", " + std::to_string(col) + ") is too high ("
"of infected hosts ("
+ std::to_string(mortality_total) + " > " + std::to_string(infected)
+ ") for cell (" + std::to_string(row) + ", " + std::to_string(col)
+ ")");
}
infected_(row, col) -= infected;
Expand Down Expand Up @@ -701,6 +704,8 @@ class HostPool : public HostPoolInterface<RasterIndex>
/**
* @brief Make hosts resistant in a given cell
*
* If mortality is not active, the *mortality* parameter is ignored.
*
* @param row Row index of the cell
* @param col Column index of the cell
* @param susceptible Number of susceptible hosts to make resistant
Expand Down Expand Up @@ -747,6 +752,12 @@ class HostPool : public HostPoolInterface<RasterIndex>
total_resistant += exposed[i];
}
infected_(row, col) -= infected;
total_resistant += infected;
resistant_(row, col) += total_resistant;
if (!mortality_tracker_vector_.size()) {
reset_total_host(row, col);
return;
}
if (mortality_tracker_vector_.size() != mortality.size()) {
throw std::invalid_argument(
"mortality is not the same size as the internal mortality tracker ("
Expand All @@ -756,7 +767,7 @@ class HostPool : public HostPoolInterface<RasterIndex>
}
int mortality_total = 0;
// no simple zip in C++, falling back to indices
for (size_t i = 0; i < mortality.size(); ++i) {
for (size_t i = 0; i < mortality_tracker_vector_.size(); ++i) {
mortality_tracker_vector_[i](row, col) -= mortality[i];
mortality_total += mortality[i];
}
Expand All @@ -772,8 +783,7 @@ class HostPool : public HostPoolInterface<RasterIndex>
+ " for cell (" + std::to_string(row) + ", " + std::to_string(col)
+ "))");
}
total_resistant += infected;
resistant_(row, col) += total_resistant;
reset_total_host(row, col);
}

/**
Expand Down Expand Up @@ -976,6 +986,9 @@ class HostPool : public HostPoolInterface<RasterIndex>
/**
* @brief Get infected hosts in each mortality cohort at a given cell
*
* If mortality is not active, it returns number of all infected individuals
* in the first and only item of the vector.
*
* @param row Row index of the cell
* @param col Column index of the cell
*
Expand All @@ -984,6 +997,12 @@ class HostPool : public HostPoolInterface<RasterIndex>
std::vector<int> mortality_by_group_at(RasterIndex row, RasterIndex col) const
{
std::vector<int> all;

if (!mortality_tracker_vector_.size()) {
all.push_back(infected_at(row, col));
return all;
}

all.reserve(mortality_tracker_vector_.size());
for (const auto& raster : mortality_tracker_vector_)
all.push_back(raster(row, col));
Expand Down
44 changes: 19 additions & 25 deletions include/pops/treatments.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,15 @@ class BaseTreatment : public AbstractTreatment<HostPool, FloatRaster>
}

// returning double allows identical results with the previous version
double get_treated(int i, int j, int count)
int get_treated(int i, int j, int count)
{
return get_treated(i, j, count, this->application_);
}

double get_treated(int i, int j, int count, TreatmentApplication application)
int get_treated(int i, int j, int count, TreatmentApplication application)
{
if (application == TreatmentApplication::Ratio) {
return count * this->map_(i, j);
return std::lround(count * this->map_(i, j));
}
else if (application == TreatmentApplication::AllInfectedInCell) {
return static_cast<bool>(this->map_(i, j)) ? count : 0;
Expand Down Expand Up @@ -173,20 +173,21 @@ class SimpleTreatment : public BaseTreatment<HostPool, FloatRaster>
for (auto indices : host_pool.suitable_cells()) {
int i = indices[0];
int j = indices[1];
int remove_susceptible = static_cast<int>(std::ceil(this->get_treated(
i, j, host_pool.susceptible_at(i, j), TreatmentApplication::Ratio)));
int remove_infected = static_cast<int>(
std::ceil(this->get_treated(i, j, host_pool.infected_at(i, j))));
int remove_susceptible = this->get_treated(
i, j, host_pool.susceptible_at(i, j), TreatmentApplication::Ratio);
// Treated infected are computed as a sum of treated in mortality groups.
int remove_infected = 0;
std::vector<int> remove_mortality;
for (int count : host_pool.mortality_by_group_at(i, j)) {
remove_mortality.push_back(
static_cast<int>(std::ceil(this->get_treated(i, j, count))));
int remove = this->get_treated(i, j, count);
remove_mortality.push_back(remove);
remove_infected += remove;
}
// Will need to use infected directly if not mortality.

std::vector<int> remove_exposed;
for (int count : host_pool.exposed_by_group_at(i, j)) {
remove_exposed.push_back(
static_cast<int>(std::ceil(this->get_treated(i, j, count))));
remove_exposed.push_back(this->get_treated(i, j, count));
}
host_pool.completely_remove_hosts_at(
i,
Expand Down Expand Up @@ -240,26 +241,19 @@ class PesticideTreatment : public BaseTreatment<HostPool, FloatRaster>
for (auto indices : host_pool.suitable_cells()) {
int i = indices[0];
int j = indices[1];
// Given how the original code was written (everything was first converted
// to ints and subtractions happened only afterwards), this needs ints,
// not doubles to pass the r.pops.spread test (unlike the other code which
// did substractions before converting to ints), so the conversion to ints
// happened only later. Now get_treated returns double and floor or ceil is
// applied to the result to get the same results as before.
int susceptible_resistant = static_cast<int>(std::floor(this->get_treated(
i, j, host_pool.susceptible_at(i, j), TreatmentApplication::Ratio)));
int susceptible_resistant = this->get_treated(
i, j, host_pool.susceptible_at(i, j), TreatmentApplication::Ratio);
std::vector<int> resistant_exposed_list;
for (const auto& number : host_pool.exposed_by_group_at(i, j)) {
resistant_exposed_list.push_back(
static_cast<int>(std::floor(this->get_treated(i, j, number))));
resistant_exposed_list.push_back(this->get_treated(i, j, number));
}
int infected = 0;
std::vector<int> resistant_mortality_list;
for (const auto& number : host_pool.mortality_by_group_at(i, j)) {
resistant_mortality_list.push_back(
static_cast<int>(std::floor(this->get_treated(i, j, number))));
int remove = this->get_treated(i, j, number);
resistant_mortality_list.push_back(remove);
infected += remove;
}
int infected = static_cast<int>(
std::floor(this->get_treated(i, j, host_pool.infected_at(i, j))));
host_pool.make_resistant_at(
i,
j,
Expand Down
5 changes: 2 additions & 3 deletions tests/test_generator_provider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,8 @@ int test_seed_config_vector_incomplete_list_recognized()
{
int ret = 0;
Config config;
bool thrown = throws_exception<std::invalid_argument>([&config] {
config.read_seeds({1, 2, 3});
});
bool thrown = throws_exception<std::invalid_argument>(
[&config] { config.read_seeds({1, 2, 3}); });
if (!thrown) {
std::cerr << "test_seed_config_vector_incomplete_list_recognized: "
"An incomplete list of seeds wrongly accepted\n";
Expand Down
10 changes: 6 additions & 4 deletions tests/test_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,10 +705,12 @@ int test_model_sei_deterministic_with_treatments()
for (int row = 0; row < expected_infected.rows(); ++row)
for (int col = 0; col < expected_infected.rows(); ++col)
if (pesticide_treatment(row, col) > 0)
expected_infected(row, col) = static_cast<int>(
std::floor(2 * pesticide_treatment(row, col) * infected(row, col)));
// Valus is based on the result which is considered correct.
Raster<int> expected_dispersers = {{0, 0, 0}, {0, 5, 0}, {0, 0, 2}};
expected_infected(row, col) = std::lround(
2 * pesticide_treatment(row, col) * expected_infected(row, col));
expected_infected(0, 0) += 5; // based on what is considered a correct result
expected_infected(1, 1) -= 5; // based on what is considered a correct result
// Values are based on the result which is considered correct.
Raster<int> expected_dispersers = {{5, 0, 0}, {0, 10, 0}, {0, 0, 2}};

for (unsigned int step = 0; step < config.scheduler().get_num_steps(); ++step) {
model.run_step(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_network_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ int test_model_with_network()
if (original_infected(coords.first, coords.second)
!= infected(coords.first, coords.second)) {
std::cerr << "Infected at: " << coords.first << ", " << coords.second
<< " is different but should be the same"
<< " (is " << original_infected(coords.first, coords.second)
<< " is different but should be the same" << " (is "
<< original_infected(coords.first, coords.second)
<< " but should be " << infected(coords.first, coords.second)
<< ").\n";
ret += 1;
Expand Down
Loading