Skip to content

Commit 664689f

Browse files
committed
data setter redesign
1 parent 1959d3b commit 664689f

File tree

9 files changed

+1598
-1845
lines changed

9 files changed

+1598
-1845
lines changed

cpp/memilio/geography/regions.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,59 @@ get_holidays(StateId state);
103103
Range<std::pair<std::vector<std::pair<Date, Date>>::const_iterator, std::vector<std::pair<Date, Date>>::const_iterator>>
104104
get_holidays(StateId state, Date start_date, Date end_date);
105105

106+
namespace de
107+
{
108+
struct EpidataFilenames
109+
{
110+
private:
111+
112+
EpidataFilenames(std::string& pydata) :
113+
population_data_path(mio::path_join(pydata, "county_current_population.json"))
114+
{
115+
}
116+
117+
static EpidataFilenames county(std::string& pydata)
118+
{
119+
EpidataFilenames s(pydata);
120+
121+
s.case_data_path = mio::path_join(pydata, "cases_all_county_age_ma7.json");
122+
s.divi_data_path = mio::path_join(pydata, "county_divi_ma7.json");
123+
s.vaccination_data_path = mio::path_join(pydata, "vacc_county_ageinf_ma7.json");
124+
125+
return s;
126+
}
127+
128+
static EpidataFilenames states(std::string& pydata)
129+
{
130+
EpidataFilenames s(pydata);
131+
132+
s.case_data_path = mio::path_join(pydata, "cases_all_state_age_ma7.json");
133+
s.divi_data_path = mio::path_join(pydata, "state_divi_ma7.json");
134+
s.vaccination_data_path = mio::path_join(pydata, "vacc_state_ageinf_ma7.json");
135+
136+
return s;
137+
}
138+
139+
static EpidataFilenames country(std::string& pydata)
140+
{
141+
EpidataFilenames s(pydata);
142+
143+
s.case_data_path = mio::path_join(pydata, "cases_all_age_ma7.json");
144+
s.divi_data_path = mio::path_join(pydata, "germany_divi_ma7.json");
145+
s.vaccination_data_path = mio::path_join(pydata, "vacc_ageinf_ma7.json");
146+
147+
return s;
148+
}
149+
150+
std::string population_data_path;
151+
std::string case_data_path;
152+
std::string divi_data_path;
153+
std::string vaccination_data_path;
154+
};
155+
} // namespace de
156+
106157
} // namespace regions
158+
107159
} // namespace mio
108160

109161
#endif //MIO_EPI_REGIONS_H

cpp/memilio/io/parameters_io.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ IOResult<std::vector<ScalarType>> read_divi_data(const std::string& path, const
9292
* @return An IOResult indicating success or failure.
9393
*/
9494
template <typename FP, class Model>
95-
IOResult<void> set_divi_data(std::vector<Model>& model, const std::vector<double>& num_icu, const std::vector<int>& vregion,
95+
IOResult<void> set_divi_data(mio::VectorRange<Model>& model, const std::vector<double>& num_icu, const std::vector<int>& vregion,
9696
Date date, FP scaling_factor_icu)
9797
{
9898
std::vector<FP> sum_mu_I_U(vregion.size(), 0);
@@ -127,7 +127,7 @@ IOResult<void> set_divi_data(std::vector<Model>& model, const std::vector<double
127127
* @param[in] scaling_factor_icu factor by which to scale the icu cases of divi data
128128
*/
129129
template <class Model>
130-
IOResult<void> set_divi_data(std::vector<Model>& model, const std::string& path, const std::vector<int>& vregion,
130+
IOResult<void> set_divi_data(mio::VectorRange<Model>& model, const std::string& path, const std::vector<int>& vregion,
131131
Date date, double scaling_factor_icu)
132132
{
133133
// DIVI dataset will no longer be updated from CW29 2024 on.
@@ -138,7 +138,7 @@ IOResult<void> set_divi_data(std::vector<Model>& model, const std::string& path,
138138
return success();
139139
}
140140
BOOST_OUTCOME_TRY(auto&& num_icu, read_divi_data(path, vregion, date));
141-
BOOST_OUTCOME_TRY(set_divi_data(model, num_icu, rki_data, vregion, date, scaling_factor_icu));
141+
BOOST_OUTCOME_TRY(set_divi_data(model, num_icu, vregion, date, scaling_factor_icu));
142142
return success();
143143
}
144144

cpp/memilio/mobility/graph.h

Lines changed: 102 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,21 @@ class Graph
152152
using NodeProperty = NodePropertyT;
153153
using EdgeProperty = EdgePropertyT;
154154

155+
Graph(std::vector<NodePropertyT> nodes, std::vector<EdgePropertyT> edges)
156+
: m_nodes(nodes)
157+
, m_edges(edges)
158+
{
159+
}
160+
161+
template <class... Args>
162+
Graph(std::vector<int>& node_ids, Args&&... args)
163+
{
164+
for (int id : node_ids) {
165+
add_node(id, std::forward<Args>(args)...);
166+
}
167+
}
168+
169+
155170
/**
156171
* @brief add a node to the graph. property of the node is constructed from arguments.
157172
*/
@@ -240,6 +255,75 @@ class Graph
240255
std::vector<Edge<EdgePropertyT>> m_edges;
241256
}; // namespace mio
242257

258+
template <class FP, class Model, class ContactPattern>
259+
void set_german_holidays(Model& node, const int node_id,
260+
const mio::Date& start_date, const mio::Date& end_date)
261+
{
262+
auto state_id = regions::get_state_id(node_id);
263+
auto holiday_periods = regions::get_holidays(state_id, start_date, end_date);
264+
265+
auto& contacts = node.parameters.template get<ContactPattern>();
266+
contacts.get_school_holidays() =
267+
std::vector<std::pair<mio::SimulationTime<FP>, mio::SimulationTime<FP>>>(holiday_periods.size());
268+
std::transform(
269+
holiday_periods.begin(), holiday_periods.end(), contacts.get_school_holidays().begin(), [=](auto& period) {
270+
return std::make_pair(mio::SimulationTime<FP>(mio::get_offset_in_days(period.first, start_date)),
271+
mio::SimulationTime<FP>(mio::get_offset_in_days(period.second, start_date)));
272+
});
273+
}
274+
275+
/**
276+
* @brief Sets the graph nodes for counties or districts.
277+
* Reads the node ids which could refer to districts or counties and the epidemiological
278+
* data from json files and creates one node for each id. Every node contains a model.
279+
* @param[in] params Model Parameters that are used for every node.
280+
* @param[in] start_date Start date for which the data should be read.
281+
* @param[in] end_data End date for which the data should be read.
282+
* @param[in] data_dir Directory that contains the data files.
283+
* @param[in, out] params_graph Graph whose nodes are set by the function.
284+
* @param[in] read_func Function that reads input data for german counties and sets Model compartments.
285+
* @param[in] node_func Function that returns the county ids.
286+
* @param[in] scaling_factor_inf Factor of confirmed cases to account for undetected cases in each county.
287+
* @param[in] scaling_factor_icu Factor of ICU cases to account for underreporting.
288+
* @param[in] tnt_capacity_factor Factor for test and trace capacity.
289+
*/
290+
template <typename FP, class ContactPattern, class Model, class MobilityParams, class Parameters,
291+
class ReadFunction>
292+
IOResult<void> set_nodes(const Parameters& params, Date start_date, Date end_date, const fs::path& data_dir,
293+
Graph<Model, MobilityParams>& params_graph, ReadFunction&& read_func,
294+
const std::vector<int>& node_ids, const std::vector<FP>& scaling_factor_inf, FP scaling_factor_icu,
295+
bool add_uncertainty_to_population = true)
296+
297+
{
298+
std::vector<Model> nodes(node_ids.size(), Model(int(size_t(params.get_num_groups()))));
299+
for (auto& node : nodes) {
300+
node.parameters = params;
301+
}
302+
303+
BOOST_OUTCOME_TRY(read_func(nodes, start_date, node_ids, scaling_factor_inf, scaling_factor_icu, data_dir.string()));
304+
305+
for (size_t node_idx = 0; node_idx < nodes.size(); ++node_idx) {
306+
307+
set_german_holidays<FP, Model, ContactPattern>(nodes[node_idx], node_ids[node_idx], start_date, end_date);
308+
if (add_uncertainty_to_population)
309+
{
310+
//uncertainty in populations
311+
for (auto i = mio::AgeGroup(0); i < params.get_num_groups(); i++) {
312+
for (auto j = Index<typename Model::Compartments>(0); j < Model::Compartments::Count; ++j) {
313+
auto& compartment_value = nodes[node_idx].populations[{i, j}];
314+
compartment_value =
315+
UncertainValue<FP>(compartment_value.value());
316+
compartment_value.set_distribution(mio::ParameterDistributionUniform(0.9 * compartment_value.value(),
317+
1.1 * compartment_value.value()));
318+
}
319+
}
320+
}
321+
322+
params_graph.add_node(node_ids[node_idx], nodes[node_idx]);
323+
}
324+
return success();
325+
}
326+
243327
/**
244328
* @brief Sets the graph nodes for counties or districts.
245329
* Reads the node ids which could refer to districts or counties and the epidemiological
@@ -248,72 +332,49 @@ class Graph
248332
* @param[in] start_date Start date for which the data should be read.
249333
* @param[in] end_data End date for which the data should be read.
250334
* @param[in] data_dir Directory that contains the data files.
251-
* @param[in] population_data_path Path to json file containing the population data.
252-
* @param[in] is_node_for_county Specifies whether the node ids should be county ids (true) or district ids (false).
253335
* @param[in, out] params_graph Graph whose nodes are set by the function.
254336
* @param[in] read_func Function that reads input data for german counties and sets Model compartments.
255337
* @param[in] node_func Function that returns the county ids.
256338
* @param[in] scaling_factor_inf Factor of confirmed cases to account for undetected cases in each county.
257339
* @param[in] scaling_factor_icu Factor of ICU cases to account for underreporting.
258340
* @param[in] tnt_capacity_factor Factor for test and trace capacity.
259-
* @param[in] num_days Number of days to be simulated; required to load data for vaccinations during the simulation.
260-
* @param[in] export_time_series If true, reads data for each day of simulation and writes it in the same directory as the input files.
261-
* @param[in] rki_age_groups Specifies whether rki-age_groups should be used.
262341
*/
263342
template <typename FP, class TestAndTrace, class ContactPattern, class Model, class MobilityParams, class Parameters,
264-
class ReadFunction, class NodeIdFunction>
343+
class ReadFunction>
265344
IOResult<void> set_nodes(const Parameters& params, Date start_date, Date end_date, const fs::path& data_dir,
266-
const std::string& population_data_path, bool is_node_for_county,
267345
Graph<Model, MobilityParams>& params_graph, ReadFunction&& read_func,
268-
NodeIdFunction&& node_func, const std::vector<FP>& scaling_factor_inf, FP scaling_factor_icu,
269-
FP tnt_capacity_factor, int num_days = 0, bool export_time_series = false,
270-
bool rki_age_groups = true)
346+
const std::vector<int>& node_ids, const std::vector<FP>& scaling_factor_inf, FP scaling_factor_icu,
347+
FP tnt_capacity_factor, bool add_uncertainty_to_population = true)
271348

272349
{
273-
BOOST_OUTCOME_TRY(auto&& node_ids, node_func(population_data_path, is_node_for_county, rki_age_groups));
274350
std::vector<Model> nodes(node_ids.size(), Model(int(size_t(params.get_num_groups()))));
275351
for (auto& node : nodes) {
276352
node.parameters = params;
277353
}
278354

279-
BOOST_OUTCOME_TRY(read_func(nodes, start_date, node_ids, scaling_factor_inf, scaling_factor_icu, data_dir.string(),
280-
num_days, export_time_series));
355+
BOOST_OUTCOME_TRY(read_func(nodes, start_date, node_ids, scaling_factor_inf, scaling_factor_icu, data_dir.string()));
281356

282357
for (size_t node_idx = 0; node_idx < nodes.size(); ++node_idx) {
283358

284359
auto tnt_capacity = nodes[node_idx].populations.get_total() * tnt_capacity_factor;
285360

286-
//local parameters
361+
// local parameters
287362
auto& tnt_value = nodes[node_idx].parameters.template get<TestAndTrace>();
288-
tnt_value = UncertainValue<FP>(0.5 * (1.2 * tnt_capacity + 0.8 * tnt_capacity));
363+
tnt_value = UncertainValue<FP>(tnt_capacity);
289364
tnt_value.set_distribution(mio::ParameterDistributionUniform(0.8 * tnt_capacity, 1.2 * tnt_capacity));
290365

291-
auto id = 0;
292-
if (is_node_for_county) {
293-
id = int(regions::CountyId(node_ids[node_idx]));
294-
}
295-
else {
296-
id = int(regions::DistrictId(node_ids[node_idx]));
297-
}
298-
//holiday periods
299-
auto holiday_periods = regions::get_holidays(regions::get_state_id(id), start_date, end_date);
300-
auto& contacts = nodes[node_idx].parameters.template get<ContactPattern>();
301-
contacts.get_school_holidays() =
302-
std::vector<std::pair<mio::SimulationTime<FP>, mio::SimulationTime<FP>>>(holiday_periods.size());
303-
std::transform(
304-
holiday_periods.begin(), holiday_periods.end(), contacts.get_school_holidays().begin(), [=](auto& period) {
305-
return std::make_pair(mio::SimulationTime<FP>(mio::get_offset_in_days(period.first, start_date)),
306-
mio::SimulationTime<FP>(mio::get_offset_in_days(period.second, start_date)));
307-
});
308-
309-
//uncertainty in populations
310-
for (auto i = mio::AgeGroup(0); i < params.get_num_groups(); i++) {
311-
for (auto j = Index<typename Model::Compartments>(0); j < Model::Compartments::Count; ++j) {
312-
auto& compartment_value = nodes[node_idx].populations[{i, j}];
313-
compartment_value =
314-
UncertainValue<FP>(0.5 * (1.1 * compartment_value.value() + 0.9 * compartment_value.value()));
315-
compartment_value.set_distribution(mio::ParameterDistributionUniform(0.9 * compartment_value.value(),
316-
1.1 * compartment_value.value()));
366+
set_german_holidays<FP, Model, ContactPattern>(nodes[node_idx], node_ids[node_idx], start_date, end_date);
367+
if (add_uncertainty_to_population)
368+
{
369+
//uncertainty in populations
370+
for (auto i = mio::AgeGroup(0); i < params.get_num_groups(); i++) {
371+
for (auto j = Index<typename Model::Compartments>(0); j < Model::Compartments::Count; ++j) {
372+
auto& compartment_value = nodes[node_idx].populations[{i, j}];
373+
compartment_value =
374+
UncertainValue<FP>(compartment_value.value());
375+
compartment_value.set_distribution(mio::ParameterDistributionUniform(0.9 * compartment_value.value(),
376+
1.1 * compartment_value.value()));
377+
}
317378
}
318379
}
319380

cpp/memilio/utils/stl_util.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,12 @@ constexpr std::array<T, size_t(T::Count)> enum_members()
315315
return enum_members;
316316
}
317317

318+
template<class T>
319+
using VectorRange = mio::Range<std::pair<typename std::vector<T>::iterator, typename std::vector<T>::iterator>>;
320+
321+
template<class T>
322+
using ConstVectorRange = mio::Range<std::pair<typename std::vector<T>::const_iterator, typename std::vector<T>::const_iterator>>;
323+
318324
} // namespace mio
319325

320326
#endif //STL_UTIL_H

0 commit comments

Comments
 (0)