Skip to content

Commit c0dc835

Browse files
Merge pull request #38 from cdcent/epoch-transitions
Epoch transitions
2 parents cf7e04b + ac55d03 commit c0dc835

File tree

5 files changed

+401
-34
lines changed

5 files changed

+401
-34
lines changed

config/config_base.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ def __init__(self, **kwargs) -> None:
3232
self.HOSP_PATH = "data/hospital_220213_220108.csv"
3333
# model initialization date DO NOT CHANGE
3434
self.INIT_DATE = datetime.date(2022, 2, 11)
35+
# if running epochs, this value will be number of days after the INIT_DATE the current epoch begins
36+
# 0 if you are initializing.
37+
self.DAYS_AFTER_INIT_DATE = 0
3538
self.MINIMUM_AGE = 0
3639
# age limits for each age bin in the model, begining with minimum age
3740
# values are exclusive in upper bound. so [0,18) means 0-17, 18+
@@ -128,6 +131,15 @@ def __init__(self, **kwargs) -> None:
128131
self.MCMC_PROGRESS_BAR = True
129132
self.MODEL_RAND_SEED = 8675309
130133

134+
# this are all the strains currently supported, historical and future
135+
self.all_strains_supported = [
136+
"wildtype",
137+
"alpha",
138+
"delta",
139+
"omicron",
140+
"BA2/BA5",
141+
]
142+
131143
# now update all parameters from kwargs, overriding the defaults if they are explicitly set
132144
self.__dict__.update(kwargs)
133145
self.GIT_HASH = (
@@ -154,17 +166,8 @@ def __init__(self, **kwargs) -> None:
154166
["W" + str(idx) for idx in range(self.NUM_WANING_COMPARTMENTS)],
155167
start=0,
156168
)
157-
158-
# this are all the strains currently supported, historical and future
159-
all_strains = [
160-
"wildtype",
161-
"alpha",
162-
"delta",
163-
"omicron",
164-
"BA2/BA5",
165-
]
166169
# it often does not make sense to differentiate between wildtype and alpha, so combine strains here
167-
self.STRAIN_NAMES = all_strains[5 - self.NUM_STRAINS :]
170+
self.STRAIN_NAMES = self.all_strains_supported[-self.NUM_STRAINS :]
168171
self.STRAIN_NAMES[0] = "pre-" + self.STRAIN_NAMES[1]
169172
# in each compartment that is strain stratified we use strain indexes to improve readability.
170173
# omicron will always be index=2 if num_strains >= 3. In a two strain model we must combine alpha and delta together.

config/config_epoch_2.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import jax.numpy as jnp
2+
3+
from config.config_base import ConfigBase
4+
5+
6+
class ConfigEpoch(ConfigBase):
7+
"""
8+
This is an example Config file for a particular scenario,
9+
in which we want to test a 2 strain model with inital R0 of 1.5 for eachs train, and no vaccination.
10+
Through inheritance this class will inherit all non-listed parameters from ConfigBase, and can even add its own!
11+
"""
12+
13+
def __init__(self) -> None:
14+
self.SCENARIO_NAME = "Epoch 2, omicron, BA2, XBB"
15+
# set scenario parameters here
16+
self.STRAIN_SPECIFIC_R0 = jnp.array([1.8, 3.0, 3.0]) # R0s
17+
self.STRAIN_INTERACTIONS = jnp.array(
18+
[
19+
[1.0, 0.7, 0.49], # omicron
20+
[0.7, 1.0, 0.7], # BA2
21+
[0.49, 0.7, 1.0], # XBB
22+
]
23+
)
24+
self.VAX_EFF_MATRIX = jnp.array(
25+
[
26+
[0, 0.34, 0.68], # omicron
27+
[0, 0.24, 0.48], # BA2
28+
[0, 0.14, 0.28], # XBB
29+
]
30+
)
31+
self.all_strains_supported = [
32+
"wildtype",
33+
"alpha",
34+
"delta",
35+
"omicron",
36+
"BA2/BA5",
37+
"XBB1.5",
38+
]
39+
# specifies the number of days after the model INIT date this epoch occurs
40+
self.DAYS_AFTER_INIT_DATE = 250
41+
# DO NOT CHANGE THE FOLLOWING TWO LINES
42+
super().__init__(**self.__dict__)
43+
# Do not add any scenario parameters below, may create inconsistent state
44+
45+
def assert_valid_values(self):
46+
"""
47+
a function designed to be called after all parameters are initalized, does a series of reasonable checks
48+
to ensure values are within expected ranges and no parameters directly contradict eachother.
49+
50+
Raises
51+
----------
52+
Assert Error:
53+
if user supplies invalid parameters, short description will be provided as to why the parameter is wrong.
54+
"""
55+
super().assert_valid_values()
56+
assert True, "any new parameters should be tested here"

mechanistic_compartments.py

Lines changed: 86 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(self, **kwargs):
6565

6666
# GENERATE CROSS IMMUNITY MATRIX with protection from STRAIN_INTERACTIONS most recent infected strain.
6767
if self.CROSSIMMUNITY_MATRIX is None:
68-
self.build_cross_immunity_matrix()
68+
self.load_cross_immunity_matrix()
6969
# if not given, load population fractions based on observed census data into self
7070
if not self.INITIAL_POPULATION_FRACTIONS:
7171
self.load_initial_population_fractions()
@@ -303,7 +303,7 @@ def vaccination_rate(self, t):
303303
"""
304304
return jnp.exp(
305305
utils.VAX_FUNCTION(
306-
t,
306+
t + self.DAYS_AFTER_INIT_DATE,
307307
self.VAX_MODEL_KNOT_LOCATIONS,
308308
self.VAX_MODEL_BASE_EQUATIONS,
309309
self.VAX_MODEL_KNOTS,
@@ -528,6 +528,7 @@ def run(
528528
max_steps=int(1e6),
529529
)
530530
self.solution = solution
531+
self.solution_final_state = tuple(y[-1] for y in solution.ys)
531532
save_path = (
532533
save_path if save else None
533534
) # dont set a save path if we dont want to save
@@ -610,6 +611,7 @@ def plot_diffrax_solution(
610611
plot_labels: list[str] = None,
611612
save_path: str = None,
612613
log_scale: bool = None,
614+
start_date: datetime.date = None,
613615
fig: plt.figure = None,
614616
ax: plt.axis = None,
615617
):
@@ -635,6 +637,8 @@ def plot_diffrax_solution(
635637
log_scale : bool, optional
636638
whether or not to exclusively show the log or unlogged version of the plot, by default include both
637639
in a stacked subplot.
640+
start_date : date, optional
641+
the start date of the x axis of the plot. Defaults to model.INIT_DATE + model.DAYS_AFTER_INIT_DATE
638642
fig: matplotlib.pyplot.figure
639643
if this plot is part of a larger subplots, pass the figure object here, otherwise one is created
640644
ax: matplotlib.pyplot.axis
@@ -645,6 +649,11 @@ def plot_diffrax_solution(
645649
fig, ax : matplotlib.Figure/axis object
646650
objects containing the matplotlib figure and axis for further modifications if needed.
647651
"""
652+
# default start date is based on the model INIT date and in the case of epochs, days after initialization
653+
if start_date is None:
654+
start_date = self.INIT_DATE + datetime.timedelta(
655+
days=self.DAYS_AFTER_INIT_DATE
656+
)
648657
plot_commands = [x.strip() for x in plot_commands]
649658
if fig is None or ax is None:
650659
fig, ax = plt.subplots(
@@ -680,11 +689,8 @@ def plot_diffrax_solution(
680689
# if we explicitly set plot_labels, override the default ones.
681690
label = plot_labels[idx] if plot_labels is not None else label
682691
days = list(range(len(timeline)))
683-
# incidence is aggregated weekly, so our array increases 7 days at a time
684-
# if command == "incidence":
685-
# days = [day * 7 for day in days]
686692
x_axis = [
687-
self.INIT_DATE + datetime.timedelta(days=day) for day in days
693+
start_date + datetime.timedelta(days=day) for day in days
688694
]
689695
if command == "incidence":
690696
# plot both logged and unlogged version by default
@@ -1162,12 +1168,13 @@ def load_init_infection_infected_and_exposed_dist_via_abm(self):
11621168
if self.INITIAL_INFECTIONS is None:
11631169
self.INITIAL_INFECTIONS = self.POP_SIZE * proportion_infected
11641170

1165-
def build_cross_immunity_matrix(self):
1171+
def load_cross_immunity_matrix(self):
11661172
"""
11671173
Loads the Crossimmunity matrix given the strain interactions matrix.
11681174
Strain interactions matrix is a matrix of shape (num_strains, num_strains) representing the relative immune escape risk
11691175
of those who are being challenged by a strain in dim 0 but have recovered from a strain in dim 1.
11701176
Neither the strain interactions matrix nor the crossimmunity matrix take into account waning.
1177+
11711178
Updates
11721179
----------
11731180
self.CROSSIMMUNITY_MATRIX:
@@ -1368,6 +1375,78 @@ def default(self, obj):
13681375
else: # if given empty file, just return JSON string
13691376
return json.dumps(self.config_file, indent=4, cls=CustomEncoder)
13701377

1378+
def collapse_strains(
1379+
self,
1380+
from_strain: str,
1381+
to_strain: str,
1382+
new_config: config,
1383+
):
1384+
"""
1385+
Modifies `self` such that all infections, infection histories, and enums that refer to the strain in `from_strain`
1386+
now point to `to_strain`. Number of strains are preserved, shifting all strain indexes left by 1
1387+
to make space for this new most-recent strain. New config is loaded to update strain specific values and indexes.
1388+
1389+
Example
1390+
----------
1391+
self.STRAIN_IDX["delta"] -> 0
1392+
self.STRAIN_IDX["omicron"] -> 1
1393+
self.STRAIN_IDX["BA2/BA5"] -> 2
1394+
self.collapse_strains("omicron", "delta") #collapses omicron and delta strains
1395+
self.STRAIN_IDX["delta"] -> 0
1396+
self.STRAIN_IDX["omicron"] -> *0*
1397+
self.STRAIN_IDX["BA2/BA5"] -> *1*
1398+
self.STRAIN_IDX[_] -> *2*
1399+
1400+
Parameters
1401+
----------
1402+
from_strain: str
1403+
the strain name of the strain being collapsed, whos references will be rerouted.
1404+
to_strain: str
1405+
the strain name of the strain being joined with from_strain, typically the oldest strain.
1406+
1407+
Modifies
1408+
----------
1409+
self.INITIAL_STATE
1410+
all compartments within initial state will be modified with new initial states
1411+
whos infection histories line up with the collapse strains and the new state.
1412+
1413+
all parameters within `new_config` will be used to override parameters within `self`
1414+
"""
1415+
from_strain_idx = self.STRAIN_IDX[from_strain]
1416+
to_strain_idx = self.STRAIN_IDX[to_strain]
1417+
(
1418+
immune_state_converter,
1419+
strain_converter,
1420+
) = utils.combined_strains_mapping(
1421+
from_strain_idx,
1422+
to_strain_idx,
1423+
self.NUM_STRAINS,
1424+
)
1425+
return_state = []
1426+
for idx, compartment in enumerate(self.INITIAL_STATE):
1427+
# we dont have a strain axis if are in the S compartment, otherwise we do
1428+
strain_axis = idx != self.IDX.S
1429+
strain_combined_compartment = utils.combine_strains(
1430+
compartment,
1431+
immune_state_converter,
1432+
strain_converter,
1433+
self.NUM_STRAINS,
1434+
strain_axis=strain_axis,
1435+
)
1436+
return_state.append(strain_combined_compartment)
1437+
1438+
# people who are actively infected with `from_strain` need to be combined together as well
1439+
self.INITIAL_STATE = tuple(return_state)
1440+
self.config_file = new_config
1441+
# use the new config to update things like STRAIN_IDX enum and strain_interactions matrix.
1442+
self.__dict__.update(**new_config.__dict__)
1443+
# end with some minor update tasks because our init_date likely changed
1444+
# along with our strain_interactions matrix
1445+
self.load_cross_immunity_matrix()
1446+
self.load_vaccination_model()
1447+
self.load_external_i_distributions()
1448+
self.load_contact_matrix()
1449+
13711450

13721451
def build_basic_mechanistic_model(config: config):
13731452
"""

model_odes/seip_model.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -126,22 +126,6 @@ def seip_ode(state, t, parameters):
126126
# slice across age, strain, and wane. vaccination updates the vax column and also moves all to w0.
127127
# ex: diagonal movement from 1 shot in 4th waning compartment to 2 shots 0 waning compartment s[:, 0, 1, 3] -> s[:, 0, 2, 0]
128128
vax_counts = s * p.VACCINATION_RATES(t)[:, jnp.newaxis, :, jnp.newaxis]
129-
130-
# for vaccine_count in range(p.MAX_VAX_COUNT + 1):
131-
# # num of people who had vaccine_count shots and then are getting 1 more
132-
# s_vax_count = vax_counts[:, :, vaccine_count, :]
133-
# # people who just got vaccinated/recovered wont get another shot for at least 1 waning compartment time.
134-
# s_vax_count = s_vax_count.at[:, :, 0].set(0)
135-
# # sum all the people getting vaccines, across waning bins since they will be put in w0
136-
# vax_gained = jnp.sum(s_vax_count, axis=(-1))
137-
# # if people already at the max counted vaccinations, dont move them, only update waning
138-
# if vaccine_count == p.MAX_VAX_COUNT:
139-
# ds = ds.at[:, :, vaccine_count, 0].add(vax_gained)
140-
# else: # increment num_vaccines by 1, waning reset
141-
# ds = ds.at[:, :, vaccine_count + 1, 0].add(vax_gained)
142-
# # we moved everyone into their correct compartment, now remove them from their starting position
143-
# ds = ds.at[:, :, vaccine_count, :].add(-s_vax_count)
144-
145129
vax_counts = vax_counts.at[:, :, :, 0].set(0)
146130
vax_gained = jnp.sum(vax_counts, axis=(-1))
147131
ds = ds.at[:, :, p.MAX_VAX_COUNT, 0].add(vax_gained[:, :, p.MAX_VAX_COUNT])

0 commit comments

Comments
 (0)