20
20
from psymple .ported_objects import (
21
21
ParameterAssignment ,
22
22
DifferentialAssignment ,
23
+ PortedObject ,
23
24
)
24
25
25
26
@@ -28,6 +29,182 @@ class PopulationSystemError(Exception):
28
29
29
30
30
31
class System :
32
+ def __init__ (self , ported_object ):
33
+ self .variables = {}
34
+ self .parameters = {}
35
+
36
+ assert isinstance (ported_object , PortedObject )
37
+ compiled = ported_object .compile ()
38
+
39
+ self .create_time_variable ()
40
+
41
+ variable_assignments , parameter_assignments = compiled .get_assignments ()
42
+
43
+ variables , parameters = self .get_symbol_containers (variable_assignments , parameter_assignments )
44
+ self .create_simulation_variables (variable_assignments , variables + self .time , parameters )
45
+ self .create_simulation_parameters (parameter_assignments , variables + self .time , parameters )
46
+ self .update_update_rules ()
47
+
48
+ def create_time_variable (self ):
49
+ # At the moment the only global variable is time
50
+ self .time = SimVariable (Variable (T , 0.0 , "system time" ))
51
+ self .time .set_update_rule (
52
+ SimUpdateRule (
53
+ #self.time,
54
+ equation = "1" ,
55
+ variables = Variables (),
56
+ parameters = Parameters (),
57
+ description = "system time" ,
58
+ )
59
+ )
60
+
61
+ def get_symbol_containers (self , variable_assignments , parameter_assignments ):
62
+ variables = [SimVariable (assg .variable ) for assg in variable_assignments ]
63
+ parameters = [SimParameter (assg .parameter ) for assg in parameter_assignments ]
64
+ return Variables (variables ), Parameters (parameters )
65
+
66
+ def create_simulation_variables (self , variable_assignments , variables , parameters ):
67
+ for assg in variable_assignments :
68
+ update_rule = assg .to_update_rule (variables , parameters )
69
+ sim_variable = SimVariable (assg .variable )
70
+ sim_variable .set_update_rule (update_rule )
71
+ self .variables [str (assg .variable .symbol )] = sim_variable
72
+
73
+ def create_simulation_parameters (self , parameter_assignments , variables , parameters ):
74
+ for assg in parameter_assignments :
75
+ sim_parameter = SimParameter (assg .parameter )
76
+ sim_parameter .initialize_update_rule (variables , parameters )
77
+ self .parameters [str (assg .parameter .symbol )] = sim_parameter
78
+
79
+ def update_update_rules (self ):
80
+ variables = Variables (list (self .variables .values ()))
81
+ parameters = Parameters (list (self .parameters .values ()))
82
+ for var in self .variables .values ():
83
+ new_update_rule = SimUpdateRule .from_update_rule (var .update_rule , variables + self .time , parameters )
84
+ var .set_update_rule (new_update_rule )
85
+ for par in self .parameters .values ():
86
+ par .initialize_update_rule (variables + self .time , parameters )
87
+
88
+ def _compute_parameter_update_order (self ):
89
+ variable_symbols = {v .symbol for v in self .variables .values ()} | {T }
90
+ # print("params")
91
+ # for par in self.parameters:
92
+ # print(type(par), par)
93
+ parameter_symbols = {p .symbol : p for p in self .parameters .values ()}
94
+ # print("param symbol")
95
+ # for symbol in parameter_symbols:
96
+ # print(type(symbol), symbol)
97
+ G = nx .DiGraph ()
98
+ G .add_nodes_from (parameter_symbols )
99
+ for parameter in self .parameters .values ():
100
+ parsym = parameter .symbol
101
+ for dependency in parameter .dependent_parameters ():
102
+ if dependency .symbol in parameter_symbols :
103
+ G .add_edge (dependency .symbol , parsym )
104
+ elif dependency .symbol not in variable_symbols :
105
+ raise PopulationSystemError (
106
+ f"Parameter { parsym } references undefined symbol { dependency } "
107
+ )
108
+ try :
109
+ nodes = nx .topological_sort (G )
110
+ except nx .exception .NetworkXUnfeasible :
111
+ raise PopulationSystemError (
112
+ f"System parameters contain cyclic dependencies"
113
+ )
114
+ return list (nodes )
115
+
116
+
117
+
118
+ class Simulation :
119
+ def __init__ (self , system , solver = "discrete_int" ):
120
+ self .system = system
121
+ self .variables = system .variables
122
+ self .parameters = system .parameters
123
+ self .time = system .time
124
+ self .solver = solver
125
+
126
+ def _compute_substitutions (self ):
127
+ update_order = [str (par ) for par in self .system ._compute_parameter_update_order ()]
128
+ print (update_order )
129
+ variables = Variables (list (self .variables .values ())) + self .time
130
+ for parameter in update_order :
131
+ self .parameters [parameter ].substitute_parameters (variables )
132
+ for variable in self .variables .values ():
133
+ variable .substitute_parameters (variables )
134
+
135
+ #TODO: Remove variable dependency from update_rule
136
+
137
+ def simulate (self , t_end , ** options ):
138
+ self ._compute_substitutions ()
139
+ if self .solver == "discrete_int" :
140
+ assert "n_steps" in options .keys ()
141
+ n_steps = options ["n_steps" ]
142
+ solver = DiscreteIntegrator (self , t_end , n_steps )
143
+ solver .run ()
144
+
145
+ def plot_solution (self , variables , t_range = None ):
146
+ t_series = self .time .time_series
147
+ if t_range is None :
148
+ sl = slice (None , None )
149
+ else :
150
+ lower = bisect (t_series , t_range [0 ])
151
+ upper = bisect (t_series , t_range [1 ])
152
+ sl = slice (lower , upper )
153
+ if isinstance (variables , set ):
154
+ variables = {v : {} for v in variables }
155
+ legend = []
156
+ for var_name , options in variables .items ():
157
+ variable = self .variables [var_name ]
158
+ if isinstance (options , str ):
159
+ plt .plot (t_series [sl ], variable .time_series [sl ], options )
160
+ else :
161
+ plt .plot (t_series [sl ], variable .time_series [sl ], ** options )
162
+ legend .append (variable .symbol .name )
163
+ plt .legend (legend , loc = "best" )
164
+ plt .xlabel ("time" )
165
+ plt .grid ()
166
+ plt .show ()
167
+
168
+
169
+ class Solver :
170
+ def __init__ (self , simulation , t_end ):
171
+ if t_end <= 0 or not isinstance (t_end , int ):
172
+ raise ValueError (
173
+ "Simulation time must terminate at a positive integer, "
174
+ f"not '{ t_end } '."
175
+ )
176
+ self .t_end = t_end
177
+ self .simulation = simulation
178
+
179
+ class DiscreteIntegrator (Solver ):
180
+ def __init__ (self , simulation , t_end , n_steps ):
181
+ super ().__init__ (simulation , t_end )
182
+ self .n_steps = n_steps
183
+
184
+ def run (self ):
185
+ for i in range (self .t_end ):
186
+ self ._advance_time_unit (self .n_steps )
187
+
188
+ def _advance_time (self , time_step ):
189
+ self .simulation .time .update_buffer ()
190
+ for variable in self .simulation .variables .values ():
191
+ variable .update_buffer ()
192
+ for variable in self .simulation .variables .values ():
193
+ variable .update_time_series (time_step )
194
+ self .simulation .time .update_time_series (time_step )
195
+
196
+ def _advance_time_unit (self , n_steps ):
197
+ if n_steps <= 0 or not isinstance (n_steps , int ):
198
+ raise ValueError (
199
+ "Number of time steps in a day must be a positive integer, "
200
+ f"not '{ n_steps } '."
201
+ )
202
+ for i in range (n_steps ):
203
+ self ._advance_time (1 / n_steps )
204
+
205
+
206
+ '''
207
+ class System_old:
31
208
def __init__(
32
209
self, population=None, variable_assignments=[], parameter_assignments=[]
33
210
):
@@ -230,3 +407,4 @@ def plot_solution(self, variables, t_range=None):
230
407
plt.xlabel("time")
231
408
plt.grid()
232
409
plt.show()
410
+ '''
0 commit comments