From 19388929c43662783741959878b9136754f6201b Mon Sep 17 00:00:00 2001 From: Vaclav Petras Date: Thu, 14 Nov 2024 09:33:40 -0500 Subject: [PATCH 1/2] Use rounding for treatments Update to PoPS Core with rounding for treatments. --- inst/cpp/pops-core | 2 +- inst/include/actions.hpp | 19 +++---- inst/include/config.hpp | 9 +++- inst/include/deterministic_kernel.hpp | 6 ++- inst/include/host_pool.hpp | 54 +++++++++++-------- inst/include/network.hpp | 4 +- inst/include/pest_host_table.hpp | 2 +- inst/include/pest_pool.hpp | 37 ++++++------- inst/include/quarantine.hpp | 8 +-- inst/include/radial_kernel.hpp | 4 +- inst/include/raster.hpp | 77 ++++++++++++++++++++++++--- inst/include/simulation.hpp | 9 ++-- inst/include/soils.hpp | 2 +- inst/include/treatments.hpp | 33 ++++++------ 14 files changed, 171 insertions(+), 95 deletions(-) diff --git a/inst/cpp/pops-core b/inst/cpp/pops-core index 19eab166..e65eb932 160000 --- a/inst/cpp/pops-core +++ b/inst/cpp/pops-core @@ -1 +1 @@ -Subproject commit 19eab166224bb5f75a285ec2d96920ec8e2d3f4c +Subproject commit e65eb9323f744fb2a158acaf7d8a574449571f37 diff --git a/inst/include/actions.hpp b/inst/include/actions.hpp index e92bbc54..1b05188e 100644 --- a/inst/include/actions.hpp +++ b/inst/include/actions.hpp @@ -90,16 +90,14 @@ class SpreadAction // From all the generated dispersers, some go to the soil in the // same cell and don't participate in the kernel-driven dispersal. auto dispersers_to_soil = - std::round(to_soil_percentage_ * dispersers_from_cell); + std::lround(to_soil_percentage_ * dispersers_from_cell); soil_pool_->dispersers_to(dispersers_to_soil, i, j, generator); dispersers_from_cell -= dispersers_to_soil; } - pests.set_dispersers_at(i, j, dispersers_from_cell); - pests.set_established_dispersers_at(i, j, dispersers_from_cell); + pests.set_dispersers_at(i, j, dispersers_from_cell, 0); } else { - pests.set_dispersers_at(i, j, 0); - pests.set_established_dispersers_at(i, j, 0); + pests.set_dispersers_at(i, j, 0, 0); } } } @@ -123,17 +121,15 @@ class SpreadAction if (pests.dispersers_at(i, j) > 0) { for (int k = 0; k < pests.dispersers_at(i, j); k++) { std::tie(row, col) = dispersal_kernel_(generator, i, j); - // if (row < 0 || row >= rows_ || col < 0 || col >= cols_) { if (host_pool.is_outside(row, col)) { pests.add_outside_disperser_at(row, col); - pests.remove_established_dispersers_at(i, j, 1); continue; } // Put a disperser to the host pool. auto dispersed = host_pool.disperser_to(row, col, generator.establishment()); - if (!dispersed) { - pests.remove_established_dispersers_at(i, j, 1); + if (dispersed) { + pests.add_established_dispersers_at(i, j, 1); } } } @@ -370,7 +366,8 @@ class MoveOverpopulatedPests // for leaving_percentage == 0.5 // 2 infected -> 1 leaving // 3 infected -> 1 leaving - int leaving = original_count * leaving_percentage_; + int leaving = + static_cast(std::floor(original_count * leaving_percentage_)); leaving = hosts.pests_from(i, j, leaving, generator.overpopulation()); if (row < 0 || row >= rows_ || col < 0 || col >= cols_) { pests.add_outside_dispersers_at(row, col, leaving); @@ -510,7 +507,7 @@ class Mortality void action(Hosts& hosts) { for (auto indices : hosts.suitable_cells()) { - if (action_mortality_) { + if (static_cast(action_mortality_)) { hosts.apply_mortality_at( indices[0], indices[1], mortality_rate_, mortality_time_lag_); } diff --git a/inst/include/config.hpp b/inst/include/config.hpp index 79ea58b1..a5382a9e 100644 --- a/inst/include/config.hpp +++ b/inst/include/config.hpp @@ -569,12 +569,19 @@ class Config for (const auto& row : values) { if (row.size() < 3) { throw std::invalid_argument( - "3 values are required for each pest-host table row"); + "3 values are required for each pest-host table row " + "(but row size is " + + std::to_string(row.size()) + ")"); } PestHostTableDataRow resulting_row; resulting_row.susceptibility = row[0]; resulting_row.mortality_rate = row[1]; resulting_row.mortality_time_lag = row[2]; + if (resulting_row.susceptibility < 0 || resulting_row.susceptibility > 1) { + throw std::invalid_argument( + "Susceptibility needs to be >=0 and <=1, not " + + std::to_string(resulting_row.susceptibility)); + } pest_host_table_data_.push_back(std::move(resulting_row)); } } diff --git a/inst/include/deterministic_kernel.hpp b/inst/include/deterministic_kernel.hpp index a3df69eb..edcd02a9 100644 --- a/inst/include/deterministic_kernel.hpp +++ b/inst/include/deterministic_kernel.hpp @@ -172,8 +172,10 @@ class DeterministicDispersalKernel // The invalid state is checked later, in this case using the kernel type. return; } - number_of_columns = ceil(max_distance / east_west_resolution) * 2 + 1; - number_of_rows = ceil(max_distance / north_south_resolution) * 2 + 1; + number_of_columns = + static_cast(ceil(max_distance / east_west_resolution)) * 2 + 1; + number_of_rows = + static_cast(ceil(max_distance / north_south_resolution)) * 2 + 1; Raster prob_size(number_of_rows, number_of_columns, 0); probability = prob_size; probability_copy = prob_size; diff --git a/inst/include/host_pool.hpp b/inst/include/host_pool.hpp index 2c600b31..7be1dfd9 100644 --- a/inst/include/host_pool.hpp +++ b/inst/include/host_pool.hpp @@ -26,6 +26,7 @@ #include "environment_interface.hpp" #include "competency_table.hpp" #include "pest_host_table.hpp" +#include "utils.hpp" namespace pops { @@ -306,7 +307,8 @@ class HostPool : public HostPoolInterface } } else { - dispersers_from_cell = lambda * infected_at(row, col); + dispersers_from_cell = + static_cast(std::floor(lambda * infected_at(row, col))); } return dispersers_from_cell; } @@ -326,7 +328,17 @@ class HostPool : public HostPoolInterface if (pest_host_table_) { suitability *= pest_host_table_->susceptibility(this); } - return environment_.influence_suitability_at(row, col, suitability); + suitability = environment_.influence_suitability_at(row, col, suitability); + if (suitability < 0 || suitability > 1) { + throw std::invalid_argument( + "Suitability should be >=0 and <=1, not " + std::to_string(suitability) + + " (susceptible: " + std::to_string(susceptible_(row, col)) + + ", total population: " + + std::to_string(environment_.total_population_at(row, col)) + + ", susceptibility: " + + std::to_string(pest_host_table_->susceptibility(this)) + ")"); + } + return suitability; } /** @@ -477,16 +489,9 @@ class HostPool : public HostPoolInterface // Since suitable cells originally comes from the total hosts, check first total // hosts and proceed only if there was no host. if (total_hosts_(row_to, col_to) == 0) { - for (auto indices : suitable_cells_) { - int i = indices[0]; - int j = indices[1]; - // TODO: This looks like a bug. Flag is needed for found and push back - // should happen only after the loop. - if ((i == row_to) && (j == col_to)) { - std::vector added_index = {row_to, col_to}; - suitable_cells_.push_back(added_index); - break; - } + std::vector new_index = {row_to, col_to}; + if (!container_contains(suitable_cells_, new_index)) { + suitable_cells_.push_back(new_index); } } @@ -528,10 +533,10 @@ class HostPool : public HostPoolInterface void completely_remove_hosts_at( RasterIndex row, RasterIndex col, - double susceptible, - std::vector exposed, - double infected, - const std::vector& mortality) + int susceptible, + std::vector exposed, + int infected, + const std::vector& mortality) { if (susceptible > 0) susceptible_(row, col) = susceptible_(row, col) - susceptible; @@ -560,7 +565,7 @@ class HostPool : public HostPoolInterface + std::to_string(row) + ", " + std::to_string(col) + ")"); } - double mortality_total = 0; + int mortality_total = 0; for (size_t i = 0; i < mortality.size(); ++i) { if (mortality_tracker_vector_[i](row, col) < mortality[i]) { throw std::invalid_argument( @@ -578,15 +583,15 @@ class HostPool : public HostPoolInterface // and once we don't need to keep the exact same double to int results for // tests. First condition always fails the tests. The second one may potentially // fail. - if (false && infected != mortality_total) { + if (infected != mortality_total) { throw std::invalid_argument( "Total of removed mortality values differs from removed infected " "count (" + std::to_string(mortality_total) + " != " + std::to_string(infected) - + " for cell (" + std::to_string(row) + ", " + std::to_string(col) + + ") for cell (" + std::to_string(row) + ", " + std::to_string(col) + ")"); } - if (false && infected_(row, col) < mortality_total) { + if (infected_(row, col) < mortality_total) { throw std::invalid_argument( "Total of removed mortality values is higher than current number " "of infected hosts for cell (" @@ -795,6 +800,9 @@ class HostPool : public HostPoolInterface * individuals is multiplied by the mortality rate to calculate the number of hosts * that die that time step. * + * If mortality rate is zero (<=0), no mortality is applied and mortality tracker + * vector stays as is, i.e., no hosts die. + * * To be used together with step_forward_mortality(). * * @param row Row index of the cell @@ -805,6 +813,8 @@ class HostPool : public HostPoolInterface void apply_mortality_at( RasterIndex row, RasterIndex col, double mortality_rate, int mortality_time_lag) { + if (mortality_rate <= 0) + return; int max_index = mortality_tracker_vector_.size() - mortality_time_lag - 1; for (int index = 0; index <= max_index; index++) { int mortality_in_index = 0; @@ -815,8 +825,8 @@ class HostPool : public HostPoolInterface mortality_in_index = mortality_tracker_vector_[index](row, col); } else { - mortality_in_index = - mortality_rate * mortality_tracker_vector_[index](row, col); + mortality_in_index = static_cast(std::floor( + mortality_rate * mortality_tracker_vector_[index](row, col))); } mortality_tracker_vector_[index](row, col) -= mortality_in_index; died_(row, col) += mortality_in_index; diff --git a/inst/include/network.hpp b/inst/include/network.hpp index 56caf671..77da301d 100644 --- a/inst/include/network.hpp +++ b/inst/include/network.hpp @@ -58,7 +58,7 @@ class EdgeGeometry : public std::vector /** Get cost of the whole segment (edge). */ double cost() const { - if (total_cost_) + if (static_cast(total_cost_)) return total_cost_; // This is short for ((size - 2) + (2 * 1/2)) * cost per cell. return (this->size() - 1) * cost_per_cell_; @@ -72,7 +72,7 @@ class EdgeGeometry : public std::vector /** Get cost per cell for the segment (edge). */ double cost_per_cell() const { - if (total_cost_) + if (static_cast(total_cost_)) return total_cost_ / (this->size() - 1); return cost_per_cell_; } diff --git a/inst/include/pest_host_table.hpp b/inst/include/pest_host_table.hpp index 0d70ba3e..a0017e43 100644 --- a/inst/include/pest_host_table.hpp +++ b/inst/include/pest_host_table.hpp @@ -103,7 +103,7 @@ class PestHostTable * @param host Pointer to the host to get the information for * @return Mortality time lag value */ - double mortality_time_lag(const HostPool* host) const + int mortality_time_lag(const HostPool* host) const { auto host_index = environment_.host_index(host); return mortality_time_lags_.at(host_index); diff --git a/inst/include/pest_pool.hpp b/inst/include/pest_pool.hpp index 7eb3e6bb..e548cf3f 100644 --- a/inst/include/pest_pool.hpp +++ b/inst/include/pest_pool.hpp @@ -53,13 +53,17 @@ class PestPool {} /** * @brief Set number of dispersers + * * @param row Row number * @param col Column number - * @param value The new value + * @param dispersers Number of dispersers + * @param established_dispersers Number of established dispersers */ - void set_dispersers_at(RasterIndex row, RasterIndex col, int value) + void set_dispersers_at( + RasterIndex row, RasterIndex col, int dispersers, int established_dispersers) { - dispersers_(row, col) = value; + dispersers_(row, col) = dispersers; + established_dispersers_(row, col) = established_dispersers; } /** * @brief Return number of dispersers @@ -82,32 +86,21 @@ class PestPool { return dispersers_; } + /** - * @brief Set number of established dispersers + * @brief Add established dispersers * - * Established are dispersers which left cell (row, col) and established themselves - * elsewhere, i.e., origin of the established dispersers is tracked. + * Established dispersers are dispersers which left cell (row, col) and + * established themselves elsewhere, i.e., origin of the established dispersers + * is tracked. * * @param row Row number * @param col Column number - * @param value The new value - */ - void set_established_dispersers_at(RasterIndex row, RasterIndex col, int value) - { - established_dispersers_(row, col) = value; - } - // TODO: The following function should not be necessary because pests can't - // un-establish. It exists just because it mirrors how the raster was handled in the - // original Simulation code. - /** - * @brief Remove established dispersers - * @param row Row number - * @param col Column number - * @param count How many dispers to remove + * @param count How many dispersers to add */ - void remove_established_dispersers_at(RasterIndex row, RasterIndex col, int count) + void add_established_dispersers_at(RasterIndex row, RasterIndex col, int count) { - established_dispersers_(row, col) -= count; + established_dispersers_(row, col) += count; } /** * @brief Add a disperser which left the study area diff --git a/inst/include/quarantine.hpp b/inst/include/quarantine.hpp index 71da68b1..6720ab5c 100644 --- a/inst/include/quarantine.hpp +++ b/inst/include/quarantine.hpp @@ -147,20 +147,20 @@ class QuarantineEscapeAction DistDir closest; if (directions_.at(Direction::N) && (i - n) * north_south_resolution_ < mindist) { - mindist = (i - n) * north_south_resolution_; + mindist = static_cast(std::floor((i - n) * north_south_resolution_)); closest = std::make_tuple(mindist, Direction::N); } if (directions_.at(Direction::S) && (s - i) * north_south_resolution_ < mindist) { - mindist = (s - i) * north_south_resolution_; + mindist = static_cast(std::floor((s - i) * north_south_resolution_)); closest = std::make_tuple(mindist, Direction::S); } if (directions_.at(Direction::E) && (e - j) * west_east_resolution_ < mindist) { - mindist = (e - j) * west_east_resolution_; + mindist = static_cast(std::floor((e - j) * west_east_resolution_)); closest = std::make_tuple(mindist, Direction::E); } if (directions_.at(Direction::W) && (j - w) * west_east_resolution_ < mindist) { - mindist = (j - w) * west_east_resolution_; + mindist = static_cast(std::floor((j - w) * west_east_resolution_)); closest = std::make_tuple(mindist, Direction::W); } return closest; diff --git a/inst/include/radial_kernel.hpp b/inst/include/radial_kernel.hpp index b6e4722b..ce265cb9 100644 --- a/inst/include/radial_kernel.hpp +++ b/inst/include/radial_kernel.hpp @@ -198,8 +198,8 @@ class RadialDispersalKernel } theta = von_mises(generator); - row -= round(distance * cos(theta) / north_south_resolution); - col += round(distance * sin(theta) / east_west_resolution); + row -= lround(distance * cos(theta) / north_south_resolution); + col += lround(distance * sin(theta) / east_west_resolution); return std::make_tuple(row, col); } diff --git a/inst/include/raster.hpp b/inst/include/raster.hpp index 03ff60d8..54ccdc2c 100644 --- a/inst/include/raster.hpp +++ b/inst/include/raster.hpp @@ -275,7 +275,11 @@ class Raster } template - Raster& operator+=(OtherNumber value) + typename std::enable_if< + !(std::is_floating_point::value + && std::is_integral::value), + Raster&>::type + operator+=(OtherNumber value) { std::for_each( data_, data_ + (cols_ * rows_), [&value](Number& a) { a += value; }); @@ -283,7 +287,11 @@ class Raster } template - Raster& operator-=(OtherNumber value) + typename std::enable_if< + !(std::is_floating_point::value + && std::is_integral::value), + Raster&>::type + operator-=(OtherNumber value) { std::for_each( data_, data_ + (cols_ * rows_), [&value](Number& a) { a -= value; }); @@ -291,7 +299,11 @@ class Raster } template - Raster& operator*=(OtherNumber value) + typename std::enable_if< + !(std::is_floating_point::value + && std::is_integral::value), + Raster&>::type + operator*=(OtherNumber value) { std::for_each( data_, data_ + (cols_ * rows_), [&value](Number& a) { a *= value; }); @@ -299,13 +311,65 @@ class Raster } template - Raster& operator/=(OtherNumber value) + typename std::enable_if< + !(std::is_floating_point::value + && std::is_integral::value), + Raster&>::type + operator/=(OtherNumber value) { std::for_each( data_, data_ + (cols_ * rows_), [&value](Number& a) { a /= value; }); return *this; } + template + typename std::enable_if< + std::is_floating_point::value && std::is_integral::value, + Raster&>::type + operator+=(OtherNumber value) + { + std::for_each(data_, data_ + (cols_ * rows_), [&value](Number& a) { + a += static_cast(std::floor(value)); + }); + return *this; + } + + template + typename std::enable_if< + std::is_floating_point::value && std::is_integral::value, + Raster&>::type + operator-=(OtherNumber value) + { + std::for_each(data_, data_ + (cols_ * rows_), [&value](Number& a) { + a -= static_cast(std::floor(value)); + }); + return *this; + } + + template + typename std::enable_if< + std::is_floating_point::value && std::is_integral::value, + Raster&>::type + operator*=(OtherNumber value) + { + std::for_each(data_, data_ + (cols_ * rows_), [&value](Number& a) { + a *= static_cast(std::floor(value)); + }); + return *this; + } + + template + typename std::enable_if< + std::is_floating_point::value && std::is_integral::value, + Raster&>::type + operator/=(OtherNumber value) + { + std::for_each(data_, data_ + (cols_ * rows_), [&value](Number& a) { + a /= static_cast(std::floor(value)); + }); + return *this; + } + template typename std::enable_if< std::is_floating_point::value @@ -496,12 +560,13 @@ class Raster return out; } - friend inline Raster pow(Raster image, double value) + friend inline Raster pow(const Raster& image, double value) { image.for_each([value](Number& a) { a = std::pow(a, value); }); return image; } - friend inline Raster sqrt(Raster image) + + friend inline Raster sqrt(const Raster& image) { image.for_each([](Number& a) { a = std::sqrt(a); }); return image; diff --git a/inst/include/simulation.hpp b/inst/include/simulation.hpp index 582185bb..e4d5db8d 100644 --- a/inst/include/simulation.hpp +++ b/inst/include/simulation.hpp @@ -35,10 +35,11 @@ namespace pops { /*! A class to control the spread simulation. * - * \deprecated - * The class is deprecated in favor of individual action classes and a higher-level - * Model. The class corresponding to the original Simulation class before too much code - * accumulated in Simulation is SpreadAction. The class is now used only in tests. + * @note + * The class is deprecated for external use in favor of individual action classes and a + * higher-level Model. The class corresponding to the original Simulation class before + * too much code accumulated in Simulation is SpreadAction. The class is now used only + * in tests. * * The Simulation class handles the mechanics of the model, but the * timing of the events or steps should be handled outside of this diff --git a/inst/include/soils.hpp b/inst/include/soils.hpp index 8e23435d..3c4802d3 100644 --- a/inst/include/soils.hpp +++ b/inst/include/soils.hpp @@ -90,7 +90,7 @@ class SoilPool } } else { - dispersers = lambda * count; + dispersers = static_cast(std::floor(lambda * count)); } auto draw = draw_n_from_cohorts(*rasters_, dispersers, row, col, generator); size_t index = 0; diff --git a/inst/include/treatments.hpp b/inst/include/treatments.hpp index 39381a65..7484c1e8 100644 --- a/inst/include/treatments.hpp +++ b/inst/include/treatments.hpp @@ -125,18 +125,18 @@ class BaseTreatment : public AbstractTreatment } // returning double allows identical results with the previous version - double get_treated(int i, int j, int count) + int get_treated(int i, int j, int count) { return get_treated(i, j, count, this->application_); } - double get_treated(int i, int j, int count, TreatmentApplication application) + int get_treated(int i, int j, int count, TreatmentApplication application) { if (application == TreatmentApplication::Ratio) { - return count * this->map_(i, j); + return std::lround(count * this->map_(i, j)); } else if (application == TreatmentApplication::AllInfectedInCell) { - return this->map_(i, j) ? count : 0; + return static_cast(this->map_(i, j)) ? count : 0; } throw std::runtime_error( "BaseTreatment::get_treated: unknown TreatmentApplication"); @@ -173,16 +173,18 @@ class SimpleTreatment : public BaseTreatment for (auto indices : host_pool.suitable_cells()) { int i = indices[0]; int j = indices[1]; - double remove_susceptible = this->get_treated( + int remove_susceptible = this->get_treated( i, j, host_pool.susceptible_at(i, j), TreatmentApplication::Ratio); - double remove_infected = - this->get_treated(i, j, host_pool.infected_at(i, j)); - std::vector remove_mortality; + // Treated infected are computed as a sum of treated in mortality groups. + int remove_infected = 0; + std::vector remove_mortality; for (int count : host_pool.mortality_by_group_at(i, j)) { - remove_mortality.push_back(this->get_treated(i, j, count)); + int remove = this->get_treated(i, j, count); + remove_mortality.push_back(remove); + remove_infected += remove; } - std::vector remove_exposed; + std::vector remove_exposed; for (int count : host_pool.exposed_by_group_at(i, j)) { remove_exposed.push_back(this->get_treated(i, j, count)); } @@ -238,26 +240,25 @@ class PesticideTreatment : public BaseTreatment for (auto indices : host_pool.suitable_cells()) { int i = indices[0]; int j = indices[1]; - // Given how the original code was written (everything was first converted - // to ints and subtractions happened only afterwards), this needs ints, - // not doubles to pass the r.pops.spread test (unlike the other code which - // did substractions before converting to ints). int susceptible_resistant = this->get_treated( i, j, host_pool.susceptible_at(i, j), TreatmentApplication::Ratio); std::vector resistant_exposed_list; for (const auto& number : host_pool.exposed_by_group_at(i, j)) { resistant_exposed_list.push_back(this->get_treated(i, j, number)); } + int infected = 0; std::vector resistant_mortality_list; for (const auto& number : host_pool.mortality_by_group_at(i, j)) { - resistant_mortality_list.push_back(this->get_treated(i, j, number)); + int remove = this->get_treated(i, j, number); + resistant_mortality_list.push_back(remove); + infected += remove; } host_pool.make_resistant_at( i, j, susceptible_resistant, resistant_exposed_list, - this->get_treated(i, j, host_pool.infected_at(i, j)), + infected, resistant_mortality_list); } } From 5f25416ac7fcf9819a303aece0bbff961d208d1e Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 6 Jan 2025 07:48:58 -0500 Subject: [PATCH 2/2] update to latest treatments rounding branch --- inst/include/host_pool.hpp | 37 ++++++++++++++++++++++++++++--------- inst/include/treatments.hpp | 1 + 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/inst/include/host_pool.hpp b/inst/include/host_pool.hpp index 7be1dfd9..0a00af9b 100644 --- a/inst/include/host_pool.hpp +++ b/inst/include/host_pool.hpp @@ -516,6 +516,7 @@ class HostPool : public HostPoolInterface * @brief Completely remove any hosts * * Removes hosts completely (as opposed to moving them to another pool). + * If mortality is not active, the *mortality* parameter is ignored. * * @param row Row index of the cell * @param col Column index of the cell @@ -524,9 +525,6 @@ class HostPool : public HostPoolInterface * @param infected Number of infected hosts to remove. * @param mortality Number of infected hosts in each mortality cohort. * - * @note Counts are doubles, so that handling of floating point values is managed - * here in the same way as in the original treatment code. - * * @note This does not remove resistant just like the original implementation in * treatments. */ @@ -557,6 +555,11 @@ class HostPool : public HostPoolInterface // Possibly reuse in the I->S removal. if (infected <= 0) return; + if (!mortality_tracker_vector_.size()) { + infected_(row, col) -= infected; + reset_total_host(row, col); + return; + } if (mortality_tracker_vector_.size() != mortality.size()) { throw std::invalid_argument( "mortality is not the same size as the internal mortality tracker (" @@ -566,7 +569,7 @@ class HostPool : public HostPoolInterface } int mortality_total = 0; - for (size_t i = 0; i < mortality.size(); ++i) { + for (size_t i = 0; i < mortality_tracker_vector_.size(); ++i) { if (mortality_tracker_vector_[i](row, col) < mortality[i]) { throw std::invalid_argument( "Mortality value [" + std::to_string(i) + "] is too high (" @@ -594,9 +597,9 @@ class HostPool : public HostPoolInterface if (infected_(row, col) < mortality_total) { throw std::invalid_argument( "Total of removed mortality values is higher than current number " - "of infected hosts for cell (" - + std::to_string(row) + ", " + std::to_string(col) + ") is too high (" + "of infected hosts (" + std::to_string(mortality_total) + " > " + std::to_string(infected) + + ") for cell (" + std::to_string(row) + ", " + std::to_string(col) + ")"); } infected_(row, col) -= infected; @@ -701,6 +704,8 @@ class HostPool : public HostPoolInterface /** * @brief Make hosts resistant in a given cell * + * If mortality is not active, the *mortality* parameter is ignored. + * * @param row Row index of the cell * @param col Column index of the cell * @param susceptible Number of susceptible hosts to make resistant @@ -747,6 +752,12 @@ class HostPool : public HostPoolInterface total_resistant += exposed[i]; } infected_(row, col) -= infected; + total_resistant += infected; + resistant_(row, col) += total_resistant; + if (!mortality_tracker_vector_.size()) { + reset_total_host(row, col); + return; + } if (mortality_tracker_vector_.size() != mortality.size()) { throw std::invalid_argument( "mortality is not the same size as the internal mortality tracker (" @@ -756,7 +767,7 @@ class HostPool : public HostPoolInterface } int mortality_total = 0; // no simple zip in C++, falling back to indices - for (size_t i = 0; i < mortality.size(); ++i) { + for (size_t i = 0; i < mortality_tracker_vector_.size(); ++i) { mortality_tracker_vector_[i](row, col) -= mortality[i]; mortality_total += mortality[i]; } @@ -772,8 +783,7 @@ class HostPool : public HostPoolInterface + " for cell (" + std::to_string(row) + ", " + std::to_string(col) + "))"); } - total_resistant += infected; - resistant_(row, col) += total_resistant; + reset_total_host(row, col); } /** @@ -976,6 +986,9 @@ class HostPool : public HostPoolInterface /** * @brief Get infected hosts in each mortality cohort at a given cell * + * If mortality is not active, it returns number of all infected individuals + * in the first and only item of the vector. + * * @param row Row index of the cell * @param col Column index of the cell * @@ -984,6 +997,12 @@ class HostPool : public HostPoolInterface std::vector mortality_by_group_at(RasterIndex row, RasterIndex col) const { std::vector all; + + if (!mortality_tracker_vector_.size()) { + all.push_back(infected_at(row, col)); + return all; + } + all.reserve(mortality_tracker_vector_.size()); for (const auto& raster : mortality_tracker_vector_) all.push_back(raster(row, col)); diff --git a/inst/include/treatments.hpp b/inst/include/treatments.hpp index 7484c1e8..9494be97 100644 --- a/inst/include/treatments.hpp +++ b/inst/include/treatments.hpp @@ -183,6 +183,7 @@ class SimpleTreatment : public BaseTreatment remove_mortality.push_back(remove); remove_infected += remove; } + // Will need to use infected directly if not mortality. std::vector remove_exposed; for (int count : host_pool.exposed_by_group_at(i, j)) {