@@ -218,6 +218,7 @@ def __init__(
218
218
219
219
self .t_name = t_name
220
220
self .states = states
221
+ self .state_names = list (states .keys ())
221
222
222
223
self .parameters = parameters
223
224
self .outputs = outputs
@@ -578,7 +579,7 @@ def compute_traj_loop(self, first_problem, inputs, outputs, t0=0., state0=None):
578
579
inputs [state_name + "_initial" ].squeeze ()
579
580
if state_name in self .traj_initial_state_input
580
581
else 0.
581
- for state_name in first_problem .states . keys ()
582
+ for state_name in first_problem .state_names
582
583
]).squeeze ()
583
584
584
585
while True :
@@ -641,7 +642,7 @@ def compute_traj_loop(self, first_problem, inputs, outputs, t0=0., state0=None):
641
642
642
643
outputs [output_name ] = sim_results [- 1 ].x [
643
644
- 1 ,
644
- list ( sim_problems [- 1 ].states . keys ()) .index (state_name )
645
+ sim_problems [- 1 ].state_names .index (state_name )
645
646
]
646
647
647
648
for output in self .traj_promote_final_output :
@@ -698,13 +699,13 @@ def compute_partials(self, inputs, J):
698
699
param_deriv = np .zeros (len (param_dict ))
699
700
700
701
if output in self .traj_final_state_output :
701
- costate [list ( next_prob .states . keys ()) .index (output )] = 1.
702
+ costate [next_prob .state_names .index (output )] = 1.
702
703
else : # in self.traj_promote_final_output
703
704
704
705
next_prob .state_equation_function (next_res .t [- 1 ], next_res .x [- 1 , :])
705
706
costate [:] = next_prob .compute_totals (
706
707
output ,
707
- list ( next_prob .states . keys ()) ,
708
+ next_prob .state_names ,
708
709
return_format = 'array'
709
710
).squeeze ()
710
711
@@ -745,12 +746,12 @@ def compute_partials(self, inputs, J):
745
746
num_active_event_channels += 1
746
747
dg_dx = np .zeros ((1 , prob .dim_state ))
747
748
748
- if channel_name in prob .states . keys () :
749
- dg_dx [0 , list ( prob .states . keys ()) .index (channel_name )] = 1.
749
+ if channel_name in prob .state_names :
750
+ dg_dx [0 , prob .state_names .index (channel_name )] = 1.
750
751
else :
751
752
dg_dx [0 , :] = prob .compute_totals (
752
753
[channel_name ],
753
- list ( prob .states . keys ()) ,
754
+ prob .state_names ,
754
755
return_format = 'array'
755
756
)
756
757
@@ -786,17 +787,17 @@ def compute_partials(self, inputs, J):
786
787
787
788
# here and co-state assume number of states is only decreasing
788
789
# forward in time
789
- for state_name in next_prob .states . keys () :
790
- state_idx = list ( next_prob .states . keys ()) .index (state_name )
790
+ for state_name in next_prob .state_names :
791
+ state_idx = next_prob .state_names .index (state_name )
791
792
792
- if state_name in prob .states . keys () :
793
+ if state_name in prob .state_names :
793
794
f_plus [
794
795
state_idx
795
- ] = plus_rate [list ( prob .states . keys ()) .index (state_name )]
796
+ ] = plus_rate [prob .state_names .index (state_name )]
796
797
797
798
# state_update[
798
- # list( next_prob.states.keys()) .index(state_name)
799
- # ] = x[list( prob.states.keys()) .index(state_name)]
799
+ # next_prob.state_names .index(state_name)
800
+ # ] = x[prob.state_names .index(state_name)]
800
801
801
802
# TODO: make sure index multiplying next_pronb costate
802
803
# lines up -- since costate is pre-filled to next_prob's
@@ -811,7 +812,7 @@ def compute_partials(self, inputs, J):
811
812
812
813
dh_j_dx = prob .compute_totals (
813
814
[state_name ],
814
- list ( prob .states . keys ()) ,
815
+ prob .state_names ,
815
816
return_format = 'array' ).squeeze ()
816
817
817
818
dh_dparam [state_idx , :] = prob .compute_totals (
@@ -820,15 +821,15 @@ def compute_partials(self, inputs, J):
820
821
return_format = 'array'
821
822
).squeeze ()
822
823
823
- for state_name_2 in prob .states . keys () :
824
+ for state_name_2 in prob .state_names :
824
825
# I'm actually computing dh_dx.T
825
826
# dh_dx rows are new state, columns are old state
826
827
# now, dh_dx.T rows are old state, columns are new
827
828
# so I think this is right
828
829
dh_dx [
829
- list ( next_prob .states . keys ()) .index (state_name_2 ),
830
+ next_prob .state_names .index (state_name_2 ),
830
831
state_idx ,
831
- ] = dh_j_dx [list ( prob .states . keys ()) .index (state_name_2 )]
832
+ ] = dh_j_dx [prob .state_names .index (state_name_2 )]
832
833
833
834
else :
834
835
state_update [
@@ -842,7 +843,7 @@ def compute_partials(self, inputs, J):
842
843
843
844
state_rate_names = [val ['rate' ] for _ , val in prob .states .items ()]
844
845
df_dx_data [idx , :, :] = prob .compute_totals (state_rate_names ,
845
- list ( prob .states . keys ()) ,
846
+ prob .state_names ,
846
847
return_format = 'array' ).T
847
848
if param_dict :
848
849
df_dparam_data [idx , ...] = prob .compute_totals (
@@ -957,7 +958,7 @@ def compute_partials(self, inputs, J):
957
958
# lamda_dot_plus = lamda_dot
958
959
if self .verbosity is Verbosity .DEBUG :
959
960
if np .any (state_disc ):
960
- print ("update is non-zero!" , prob , prob .states . keys () ,
961
+ print ("update is non-zero!" , prob , prob .state_names ,
961
962
state_disc , costate , lamda_dot )
962
963
print (
963
964
"inner product becomes..." ,
@@ -966,7 +967,7 @@ def compute_partials(self, inputs, J):
966
967
state_disc [None ,
967
968
:] @ dh_dx .T @ lamda_dot_plus [:, None ]
968
969
)
969
- print ("dh_dx for" , prob , prob .states . keys () , "\n " , dh_dx )
970
+ print ("dh_dx for" , prob , prob .state_names , "\n " , dh_dx )
970
971
print ("costate" , costate )
971
972
costate_update_terms = [
972
973
dh_dx .T @ costate [:, None ],
@@ -1063,17 +1064,17 @@ def co_state_rate(t, costate, *args):
1063
1064
1064
1065
# TODO: do co-states need unit changes? probably not...
1065
1066
for state_name in prob .state_names :
1066
- costate [list ( next_prob .states . keys ()) .index (
1067
- state_name )] = co_res .x [- 1 , list ( prob .states . keys ()) .index (state_name )]
1067
+ costate [next_prob .state_names .index (
1068
+ state_name )] = co_res .x [- 1 , prob .state_names .index (state_name )]
1068
1069
lamda_dot_plus [
1069
- list ( next_prob .states . keys ()) .index (state_name )
1070
- ] = lamda_dot_plus_rate [list ( prob .states . keys ()) .index (state_name )]
1070
+ next_prob .state_names .index (state_name )
1071
+ ] = lamda_dot_plus_rate [prob .state_names .index (state_name )]
1071
1072
1072
1073
for state_to_deriv , metadata in self .traj_initial_state_input .items ():
1073
1074
param_name = metadata ["name" ]
1074
1075
J [output_name , param_name ] = costate_reses [output ][- 1 ].x [
1075
1076
- 1 ,
1076
- list ( prob .states . keys ()) .index (state_to_deriv )
1077
+ prob .state_names .index (state_to_deriv )
1077
1078
]
1078
1079
for param_deriv_val , param_deriv_name in zip (param_deriv , param_dict ):
1079
1080
J [output_name , param_deriv_name ] = param_deriv_val
0 commit comments