Skip to content

Commit

Permalink
Consistency between total infected and total over mortality groups re…
Browse files Browse the repository at this point in the history
…quires using mortality groups in calculations. Specifically, treatments use mortality groups directly to consistently apply the treatment function. When mortality is not active, we still need to behave the same, so we now consider all infected to be a part of one mortality group when asked for it. This allows treatments to use mortality groups to get infected even when mortality is disabled. The functions in host pool which accept mortality groups for a cell as an input are now ignoring mortality groups when mortality is not active (before they were relying on input mortality vector being empty). Alternative implementation, or just possible improvement in addtion to this new API, would be to allow treatments to pass a function to reduce the different values, but treatments uses different functions for different pools, so that would require at least two functions passed which is possible, but may be hard to read.
  • Loading branch information
wenzeslaus committed Dec 20, 2024
1 parent e65eb93 commit 699c018
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 11 deletions.
41 changes: 30 additions & 11 deletions include/pops/host_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ class HostPool : public HostPoolInterface<RasterIndex>
* @brief Completely remove any hosts
*
* Removes hosts completely (as opposed to moving them to another pool).
* If mortality is not active, the *mortality* parameter is ignored.
*
* @param row Row index of the cell
* @param col Column index of the cell
Expand All @@ -524,9 +525,6 @@ class HostPool : public HostPoolInterface<RasterIndex>
* @param infected Number of infected hosts to remove.
* @param mortality Number of infected hosts in each mortality cohort.
*
* @note Counts are doubles, so that handling of floating point values is managed
* here in the same way as in the original treatment code.
*
* @note This does not remove resistant just like the original implementation in
* treatments.
*/
Expand Down Expand Up @@ -557,6 +555,11 @@ class HostPool : public HostPoolInterface<RasterIndex>
// Possibly reuse in the I->S removal.
if (infected <= 0)
return;
if (!mortality_tracker_vector_.size()) {
infected_(row, col) -= infected;
reset_total_host(row, col);
return;
}
if (mortality_tracker_vector_.size() != mortality.size()) {
throw std::invalid_argument(
"mortality is not the same size as the internal mortality tracker ("
Expand All @@ -566,7 +569,7 @@ class HostPool : public HostPoolInterface<RasterIndex>
}

int mortality_total = 0;
for (size_t i = 0; i < mortality.size(); ++i) {
for (size_t i = 0; i < mortality_tracker_vector_.size(); ++i) {
if (mortality_tracker_vector_[i](row, col) < mortality[i]) {
throw std::invalid_argument(
"Mortality value [" + std::to_string(i) + "] is too high ("
Expand Down Expand Up @@ -594,13 +597,13 @@ class HostPool : public HostPoolInterface<RasterIndex>
if (infected_(row, col) < mortality_total) {
throw std::invalid_argument(
"Total of removed mortality values is higher than current number "
"of infected hosts for cell ("
+ std::to_string(row) + ", " + std::to_string(col) + ") is too high ("
+ std::to_string(mortality_total) + " > " + std::to_string(infected)
+ ")");
"of infected hosts ("
+ std::to_string(mortality_total) + " > " + std::to_string(infected)
+ ") for cell (" + std::to_string(row) + ", " + std::to_string(col) + ")");
}
infected_(row, col) -= infected;
reset_total_host(row, col);

}

/**
Expand Down Expand Up @@ -701,6 +704,8 @@ class HostPool : public HostPoolInterface<RasterIndex>
/**
* @brief Make hosts resistant in a given cell
*
* If mortality is not active, the *mortality* parameter is ignored.
*
* @param row Row index of the cell
* @param col Column index of the cell
* @param susceptible Number of susceptible hosts to make resistant
Expand Down Expand Up @@ -747,6 +752,12 @@ class HostPool : public HostPoolInterface<RasterIndex>
total_resistant += exposed[i];
}
infected_(row, col) -= infected;
total_resistant += infected;
resistant_(row, col) += total_resistant;
if (!mortality_tracker_vector_.size()) {
reset_total_host(row, col);
return;
}
if (mortality_tracker_vector_.size() != mortality.size()) {
throw std::invalid_argument(
"mortality is not the same size as the internal mortality tracker ("
Expand All @@ -756,7 +767,7 @@ class HostPool : public HostPoolInterface<RasterIndex>
}
int mortality_total = 0;
// no simple zip in C++, falling back to indices
for (size_t i = 0; i < mortality.size(); ++i) {
for (size_t i = 0; i < mortality_tracker_vector_.size(); ++i) {
mortality_tracker_vector_[i](row, col) -= mortality[i];
mortality_total += mortality[i];
}
Expand All @@ -772,8 +783,7 @@ class HostPool : public HostPoolInterface<RasterIndex>
+ " for cell (" + std::to_string(row) + ", " + std::to_string(col)
+ "))");
}
total_resistant += infected;
resistant_(row, col) += total_resistant;
reset_total_host(row, col);
}

/**
Expand Down Expand Up @@ -976,6 +986,9 @@ class HostPool : public HostPoolInterface<RasterIndex>
/**
* @brief Get infected hosts in each mortality cohort at a given cell
*
* If mortality is not active, it returns number of all infected individuals
* in the first and only item of the vector.
*
* @param row Row index of the cell
* @param col Column index of the cell
*
Expand All @@ -984,6 +997,12 @@ class HostPool : public HostPoolInterface<RasterIndex>
std::vector<int> mortality_by_group_at(RasterIndex row, RasterIndex col) const
{
std::vector<int> all;

if (!mortality_tracker_vector_.size()) {
all.push_back(infected_at(row, col));
return all;
}

all.reserve(mortality_tracker_vector_.size());
for (const auto& raster : mortality_tracker_vector_)
all.push_back(raster(row, col));
Expand Down
1 change: 1 addition & 0 deletions include/pops/treatments.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ class SimpleTreatment : public BaseTreatment<HostPool, FloatRaster>
remove_mortality.push_back(remove);
remove_infected += remove;
}
// Will need to use infected directly if not mortality.

std::vector<int> remove_exposed;
for (int count : host_pool.exposed_by_group_at(i, j)) {
Expand Down
77 changes: 77 additions & 0 deletions tests/test_treatments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,82 @@ int test_application_ratio()
return num_errors;
}

/**
* @brief Test treatment without active mortality
* @return Number of errors encountered
*/
int test_application_ratio_without_mortality()
{
int num_errors = 0;
Scheduler scheduler(Date(2020, 1, 1), Date(2020, 12, 31), StepUnit::Month, 1);

TestEnvironment environment;

Raster<double> tr1 = {{1, 0.5}, {0.75, 0}};
Raster<int> susceptible = {{10, 6}, {20, 42}};
Raster<int> resistant = {{0, 0}, {0, 0}};
Raster<int> infected = {{1, 4}, {16, 40}};
Raster<int> zeros(infected.rows(), infected.cols(), 0);
auto total_hosts = infected + susceptible + resistant;
std::vector<Raster<int>> exposed;
std::vector<Raster<int>> mortality_tracker;

std::vector<std::vector<int>> suitable_cells = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};

StandardSingleHostPool host_pool(
ModelType::SusceptibleInfected,
susceptible,
exposed,
0,
infected,
zeros,
resistant,
mortality_tracker,
zeros,
total_hosts,
environment,
false,
0,
false,
0,
infected.rows(),
infected.cols(),
suitable_cells);

// First, test that host pool works for the case without mortality.
for (int row = 0; row < infected.rows(); ++row) {
for (int col = 0; col < infected.cols(); ++col) {
auto mortality_groups = host_pool.mortality_by_group_at(row, col);
if (mortality_groups.size() != 1) {
std::cerr << "Expected a single mortality group from host pool but got " <<
mortality_groups.size() << "\n";
num_errors++;
}
if (mortality_groups[0] != host_pool.infected_at(row, col)) {
std::cerr << "Host pool does not work as expected: " <<
"The single mortality group is diferent from total infected (" <<
mortality_groups[0] << " != " << host_pool.infected_at(row, col) << ")\n";
num_errors++;
}
}
}

Treatments<StandardSingleHostPool, Raster<double>> treatments(scheduler);
treatments.add_treatment(tr1, Date(2020, 1, 1), 0, TreatmentApplication::Ratio);
treatments.manage(0, host_pool);

Raster<int> treated = {{0, 3}, {5, 42}};
Raster<int> inf_treated = {{0, 2}, {4, 40}};
auto th_treated = treated + inf_treated + resistant;
if (!(susceptible == treated && infected == inf_treated
&& total_hosts == th_treated)) {
std::cerr << "Treatment with ratio app does not work without mortality\n";
std::cerr << susceptible << infected << total_hosts;
num_errors++;
}
return num_errors;
}

int test_application_all_inf()
{
int num_errors = 0;
Expand Down Expand Up @@ -722,6 +798,7 @@ int main()
int num_errors = 0;

num_errors += test_application_ratio();
num_errors += test_application_ratio_without_mortality();
num_errors += test_application_all_inf();
num_errors += test_application_ratio_pesticide();
num_errors += test_application_all_inf_pesticide();
Expand Down

0 comments on commit 699c018

Please sign in to comment.