From 699c018a2e5a4d84a901b8480d83569c22cee2f4 Mon Sep 17 00:00:00 2001 From: Vaclav Petras Date: Fri, 20 Dec 2024 15:22:28 -0500 Subject: [PATCH] Consistency between total infected and total over mortality groups requires using mortality groups in calculations. Specifically, treatments use mortality groups directly to consistently apply the treatment function. When mortality is not active, we still need to behave the same, so we now consider all infected to be a part of one mortality group when asked for it. This allows treatments to use mortality groups to get infected even when mortality is disabled. The functions in host pool which accept mortality groups for a cell as an input are now ignoring mortality groups when mortality is not active (before they were relying on input mortality vector being empty). Alternative implementation, or just possible improvement in addtion to this new API, would be to allow treatments to pass a function to reduce the different values, but treatments uses different functions for different pools, so that would require at least two functions passed which is possible, but may be hard to read. --- include/pops/host_pool.hpp | 41 ++++++++++++++------ include/pops/treatments.hpp | 1 + tests/test_treatments.cpp | 77 +++++++++++++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 11 deletions(-) diff --git a/include/pops/host_pool.hpp b/include/pops/host_pool.hpp index 7be1dfd9..ac932700 100644 --- a/include/pops/host_pool.hpp +++ b/include/pops/host_pool.hpp @@ -516,6 +516,7 @@ class HostPool : public HostPoolInterface * @brief Completely remove any hosts * * Removes hosts completely (as opposed to moving them to another pool). + * If mortality is not active, the *mortality* parameter is ignored. * * @param row Row index of the cell * @param col Column index of the cell @@ -524,9 +525,6 @@ class HostPool : public HostPoolInterface * @param infected Number of infected hosts to remove. * @param mortality Number of infected hosts in each mortality cohort. * - * @note Counts are doubles, so that handling of floating point values is managed - * here in the same way as in the original treatment code. - * * @note This does not remove resistant just like the original implementation in * treatments. */ @@ -557,6 +555,11 @@ class HostPool : public HostPoolInterface // Possibly reuse in the I->S removal. if (infected <= 0) return; + if (!mortality_tracker_vector_.size()) { + infected_(row, col) -= infected; + reset_total_host(row, col); + return; + } if (mortality_tracker_vector_.size() != mortality.size()) { throw std::invalid_argument( "mortality is not the same size as the internal mortality tracker (" @@ -566,7 +569,7 @@ class HostPool : public HostPoolInterface } int mortality_total = 0; - for (size_t i = 0; i < mortality.size(); ++i) { + for (size_t i = 0; i < mortality_tracker_vector_.size(); ++i) { if (mortality_tracker_vector_[i](row, col) < mortality[i]) { throw std::invalid_argument( "Mortality value [" + std::to_string(i) + "] is too high (" @@ -594,13 +597,13 @@ class HostPool : public HostPoolInterface if (infected_(row, col) < mortality_total) { throw std::invalid_argument( "Total of removed mortality values is higher than current number " - "of infected hosts for cell (" - + std::to_string(row) + ", " + std::to_string(col) + ") is too high (" - + std::to_string(mortality_total) + " > " + std::to_string(infected) - + ")"); + "of infected hosts (" + + std::to_string(mortality_total) + " > " + std::to_string(infected) + + ") for cell (" + std::to_string(row) + ", " + std::to_string(col) + ")"); } infected_(row, col) -= infected; reset_total_host(row, col); + } /** @@ -701,6 +704,8 @@ class HostPool : public HostPoolInterface /** * @brief Make hosts resistant in a given cell * + * If mortality is not active, the *mortality* parameter is ignored. + * * @param row Row index of the cell * @param col Column index of the cell * @param susceptible Number of susceptible hosts to make resistant @@ -747,6 +752,12 @@ class HostPool : public HostPoolInterface total_resistant += exposed[i]; } infected_(row, col) -= infected; + total_resistant += infected; + resistant_(row, col) += total_resistant; + if (!mortality_tracker_vector_.size()) { + reset_total_host(row, col); + return; + } if (mortality_tracker_vector_.size() != mortality.size()) { throw std::invalid_argument( "mortality is not the same size as the internal mortality tracker (" @@ -756,7 +767,7 @@ class HostPool : public HostPoolInterface } int mortality_total = 0; // no simple zip in C++, falling back to indices - for (size_t i = 0; i < mortality.size(); ++i) { + for (size_t i = 0; i < mortality_tracker_vector_.size(); ++i) { mortality_tracker_vector_[i](row, col) -= mortality[i]; mortality_total += mortality[i]; } @@ -772,8 +783,7 @@ class HostPool : public HostPoolInterface + " for cell (" + std::to_string(row) + ", " + std::to_string(col) + "))"); } - total_resistant += infected; - resistant_(row, col) += total_resistant; + reset_total_host(row, col); } /** @@ -976,6 +986,9 @@ class HostPool : public HostPoolInterface /** * @brief Get infected hosts in each mortality cohort at a given cell * + * If mortality is not active, it returns number of all infected individuals + * in the first and only item of the vector. + * * @param row Row index of the cell * @param col Column index of the cell * @@ -984,6 +997,12 @@ class HostPool : public HostPoolInterface std::vector mortality_by_group_at(RasterIndex row, RasterIndex col) const { std::vector all; + + if (!mortality_tracker_vector_.size()) { + all.push_back(infected_at(row, col)); + return all; + } + all.reserve(mortality_tracker_vector_.size()); for (const auto& raster : mortality_tracker_vector_) all.push_back(raster(row, col)); diff --git a/include/pops/treatments.hpp b/include/pops/treatments.hpp index 7484c1e8..9494be97 100644 --- a/include/pops/treatments.hpp +++ b/include/pops/treatments.hpp @@ -183,6 +183,7 @@ class SimpleTreatment : public BaseTreatment remove_mortality.push_back(remove); remove_infected += remove; } + // Will need to use infected directly if not mortality. std::vector remove_exposed; for (int count : host_pool.exposed_by_group_at(i, j)) { diff --git a/tests/test_treatments.cpp b/tests/test_treatments.cpp index 3f85f23c..176c9800 100644 --- a/tests/test_treatments.cpp +++ b/tests/test_treatments.cpp @@ -99,6 +99,82 @@ int test_application_ratio() return num_errors; } +/** + * @brief Test treatment without active mortality + * @return Number of errors encountered + */ +int test_application_ratio_without_mortality() +{ + int num_errors = 0; + Scheduler scheduler(Date(2020, 1, 1), Date(2020, 12, 31), StepUnit::Month, 1); + + TestEnvironment environment; + + Raster tr1 = {{1, 0.5}, {0.75, 0}}; + Raster susceptible = {{10, 6}, {20, 42}}; + Raster resistant = {{0, 0}, {0, 0}}; + Raster infected = {{1, 4}, {16, 40}}; + Raster zeros(infected.rows(), infected.cols(), 0); + auto total_hosts = infected + susceptible + resistant; + std::vector> exposed; + std::vector> mortality_tracker; + + std::vector> suitable_cells = {{0, 0}, {0, 1}, {1, 0}, {1, 1}}; + + StandardSingleHostPool host_pool( + ModelType::SusceptibleInfected, + susceptible, + exposed, + 0, + infected, + zeros, + resistant, + mortality_tracker, + zeros, + total_hosts, + environment, + false, + 0, + false, + 0, + infected.rows(), + infected.cols(), + suitable_cells); + + // First, test that host pool works for the case without mortality. + for (int row = 0; row < infected.rows(); ++row) { + for (int col = 0; col < infected.cols(); ++col) { + auto mortality_groups = host_pool.mortality_by_group_at(row, col); + if (mortality_groups.size() != 1) { + std::cerr << "Expected a single mortality group from host pool but got " << + mortality_groups.size() << "\n"; + num_errors++; + } + if (mortality_groups[0] != host_pool.infected_at(row, col)) { + std::cerr << "Host pool does not work as expected: " << + "The single mortality group is diferent from total infected (" << + mortality_groups[0] << " != " << host_pool.infected_at(row, col) << ")\n"; + num_errors++; + } + } + } + + Treatments> treatments(scheduler); + treatments.add_treatment(tr1, Date(2020, 1, 1), 0, TreatmentApplication::Ratio); + treatments.manage(0, host_pool); + + Raster treated = {{0, 3}, {5, 42}}; + Raster inf_treated = {{0, 2}, {4, 40}}; + auto th_treated = treated + inf_treated + resistant; + if (!(susceptible == treated && infected == inf_treated + && total_hosts == th_treated)) { + std::cerr << "Treatment with ratio app does not work without mortality\n"; + std::cerr << susceptible << infected << total_hosts; + num_errors++; + } + return num_errors; +} + int test_application_all_inf() { int num_errors = 0; @@ -722,6 +798,7 @@ int main() int num_errors = 0; num_errors += test_application_ratio(); + num_errors += test_application_ratio_without_mortality(); num_errors += test_application_all_inf(); num_errors += test_application_ratio_pesticide(); num_errors += test_application_all_inf_pesticide();