@@ -65,7 +65,7 @@ def __init__(self, **kwargs):
65
65
66
66
# GENERATE CROSS IMMUNITY MATRIX with protection from STRAIN_INTERACTIONS most recent infected strain.
67
67
if self .CROSSIMMUNITY_MATRIX is None :
68
- self .build_cross_immunity_matrix ()
68
+ self .load_cross_immunity_matrix ()
69
69
# if not given, load population fractions based on observed census data into self
70
70
if not self .INITIAL_POPULATION_FRACTIONS :
71
71
self .load_initial_population_fractions ()
@@ -303,7 +303,7 @@ def vaccination_rate(self, t):
303
303
"""
304
304
return jnp .exp (
305
305
utils .VAX_FUNCTION (
306
- t ,
306
+ t + self . DAYS_AFTER_INIT_DATE ,
307
307
self .VAX_MODEL_KNOT_LOCATIONS ,
308
308
self .VAX_MODEL_BASE_EQUATIONS ,
309
309
self .VAX_MODEL_KNOTS ,
@@ -528,6 +528,7 @@ def run(
528
528
max_steps = int (1e6 ),
529
529
)
530
530
self .solution = solution
531
+ self .solution_final_state = tuple (y [- 1 ] for y in solution .ys )
531
532
save_path = (
532
533
save_path if save else None
533
534
) # dont set a save path if we dont want to save
@@ -610,6 +611,7 @@ def plot_diffrax_solution(
610
611
plot_labels : list [str ] = None ,
611
612
save_path : str = None ,
612
613
log_scale : bool = None ,
614
+ start_date : datetime .date = None ,
613
615
fig : plt .figure = None ,
614
616
ax : plt .axis = None ,
615
617
):
@@ -635,6 +637,8 @@ def plot_diffrax_solution(
635
637
log_scale : bool, optional
636
638
whether or not to exclusively show the log or unlogged version of the plot, by default include both
637
639
in a stacked subplot.
640
+ start_date : date, optional
641
+ the start date of the x axis of the plot. Defaults to model.INIT_DATE + model.DAYS_AFTER_INIT_DATE
638
642
fig: matplotlib.pyplot.figure
639
643
if this plot is part of a larger subplots, pass the figure object here, otherwise one is created
640
644
ax: matplotlib.pyplot.axis
@@ -645,6 +649,11 @@ def plot_diffrax_solution(
645
649
fig, ax : matplotlib.Figure/axis object
646
650
objects containing the matplotlib figure and axis for further modifications if needed.
647
651
"""
652
+ # default start date is based on the model INIT date and in the case of epochs, days after initialization
653
+ if start_date is None :
654
+ start_date = self .INIT_DATE + datetime .timedelta (
655
+ days = self .DAYS_AFTER_INIT_DATE
656
+ )
648
657
plot_commands = [x .strip () for x in plot_commands ]
649
658
if fig is None or ax is None :
650
659
fig , ax = plt .subplots (
@@ -680,11 +689,8 @@ def plot_diffrax_solution(
680
689
# if we explicitly set plot_labels, override the default ones.
681
690
label = plot_labels [idx ] if plot_labels is not None else label
682
691
days = list (range (len (timeline )))
683
- # incidence is aggregated weekly, so our array increases 7 days at a time
684
- # if command == "incidence":
685
- # days = [day * 7 for day in days]
686
692
x_axis = [
687
- self . INIT_DATE + datetime .timedelta (days = day ) for day in days
693
+ start_date + datetime .timedelta (days = day ) for day in days
688
694
]
689
695
if command == "incidence" :
690
696
# plot both logged and unlogged version by default
@@ -1162,12 +1168,13 @@ def load_init_infection_infected_and_exposed_dist_via_abm(self):
1162
1168
if self .INITIAL_INFECTIONS is None :
1163
1169
self .INITIAL_INFECTIONS = self .POP_SIZE * proportion_infected
1164
1170
1165
- def build_cross_immunity_matrix (self ):
1171
+ def load_cross_immunity_matrix (self ):
1166
1172
"""
1167
1173
Loads the Crossimmunity matrix given the strain interactions matrix.
1168
1174
Strain interactions matrix is a matrix of shape (num_strains, num_strains) representing the relative immune escape risk
1169
1175
of those who are being challenged by a strain in dim 0 but have recovered from a strain in dim 1.
1170
1176
Neither the strain interactions matrix nor the crossimmunity matrix take into account waning.
1177
+
1171
1178
Updates
1172
1179
----------
1173
1180
self.CROSSIMMUNITY_MATRIX:
@@ -1368,6 +1375,78 @@ def default(self, obj):
1368
1375
else : # if given empty file, just return JSON string
1369
1376
return json .dumps (self .config_file , indent = 4 , cls = CustomEncoder )
1370
1377
1378
+ def collapse_strains (
1379
+ self ,
1380
+ from_strain : str ,
1381
+ to_strain : str ,
1382
+ new_config : config ,
1383
+ ):
1384
+ """
1385
+ Modifies `self` such that all infections, infection histories, and enums that refer to the strain in `from_strain`
1386
+ now point to `to_strain`. Number of strains are preserved, shifting all strain indexes left by 1
1387
+ to make space for this new most-recent strain. New config is loaded to update strain specific values and indexes.
1388
+
1389
+ Example
1390
+ ----------
1391
+ self.STRAIN_IDX["delta"] -> 0
1392
+ self.STRAIN_IDX["omicron"] -> 1
1393
+ self.STRAIN_IDX["BA2/BA5"] -> 2
1394
+ self.collapse_strains("omicron", "delta") #collapses omicron and delta strains
1395
+ self.STRAIN_IDX["delta"] -> 0
1396
+ self.STRAIN_IDX["omicron"] -> *0*
1397
+ self.STRAIN_IDX["BA2/BA5"] -> *1*
1398
+ self.STRAIN_IDX[_] -> *2*
1399
+
1400
+ Parameters
1401
+ ----------
1402
+ from_strain: str
1403
+ the strain name of the strain being collapsed, whos references will be rerouted.
1404
+ to_strain: str
1405
+ the strain name of the strain being joined with from_strain, typically the oldest strain.
1406
+
1407
+ Modifies
1408
+ ----------
1409
+ self.INITIAL_STATE
1410
+ all compartments within initial state will be modified with new initial states
1411
+ whos infection histories line up with the collapse strains and the new state.
1412
+
1413
+ all parameters within `new_config` will be used to override parameters within `self`
1414
+ """
1415
+ from_strain_idx = self .STRAIN_IDX [from_strain ]
1416
+ to_strain_idx = self .STRAIN_IDX [to_strain ]
1417
+ (
1418
+ immune_state_converter ,
1419
+ strain_converter ,
1420
+ ) = utils .combined_strains_mapping (
1421
+ from_strain_idx ,
1422
+ to_strain_idx ,
1423
+ self .NUM_STRAINS ,
1424
+ )
1425
+ return_state = []
1426
+ for idx , compartment in enumerate (self .INITIAL_STATE ):
1427
+ # we dont have a strain axis if are in the S compartment, otherwise we do
1428
+ strain_axis = idx != self .IDX .S
1429
+ strain_combined_compartment = utils .combine_strains (
1430
+ compartment ,
1431
+ immune_state_converter ,
1432
+ strain_converter ,
1433
+ self .NUM_STRAINS ,
1434
+ strain_axis = strain_axis ,
1435
+ )
1436
+ return_state .append (strain_combined_compartment )
1437
+
1438
+ # people who are actively infected with `from_strain` need to be combined together as well
1439
+ self .INITIAL_STATE = tuple (return_state )
1440
+ self .config_file = new_config
1441
+ # use the new config to update things like STRAIN_IDX enum and strain_interactions matrix.
1442
+ self .__dict__ .update (** new_config .__dict__ )
1443
+ # end with some minor update tasks because our init_date likely changed
1444
+ # along with our strain_interactions matrix
1445
+ self .load_cross_immunity_matrix ()
1446
+ self .load_vaccination_model ()
1447
+ self .load_external_i_distributions ()
1448
+ self .load_contact_matrix ()
1449
+
1371
1450
1372
1451
def build_basic_mechanistic_model (config : config ):
1373
1452
"""
0 commit comments