Skip to content

Commit d5a24eb

Browse files
committed
PR feedback
1 parent a4f04ed commit d5a24eb

File tree

3 files changed

+27
-50
lines changed

3 files changed

+27
-50
lines changed

aviary/mission/flops_based/ode/mission_ODE.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -225,30 +225,6 @@ def setup(self):
225225

226226
print_level = 0 if analysis_scheme is AnalysisScheme.SHOOTING else 2
227227

228-
if analysis_scheme is AnalysisScheme.SHOOTING and False:
229-
from aviary.utils.functions import create_printcomp
230-
dummy_comp = create_printcomp(
231-
all_inputs=[
232-
't_curr',
233-
Mission.Design.RESERVE_FUEL,
234-
Dynamic.Mission.MASS,
235-
Dynamic.Mission.DISTANCE,
236-
Dynamic.Mission.ALTITUDE,
237-
Dynamic.Mission.FLIGHT_PATH_ANGLE,
238-
],
239-
input_units={
240-
't_curr': 's',
241-
Dynamic.Mission.FLIGHT_PATH_ANGLE: 'deg',
242-
Dynamic.Mission.DISTANCE: 'NM',
243-
})
244-
self.add_subsystem(
245-
"dummy_comp",
246-
dummy_comp(),
247-
promotes_inputs=["*"],)
248-
self.set_input_defaults(
249-
Dynamic.Mission.DISTANCE, val=0, units='NM')
250-
self.set_input_defaults('t_curr', val=0, units='s')
251-
252228
self.nonlinear_solver = om.NewtonSolver(solve_subsystems=True,
253229
atol=1.0e-10,
254230
rtol=1.0e-10,

aviary/mission/gasp_based/ode/time_integration_base_classes.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def __init__(
218218

219219
self.t_name = t_name
220220
self.states = states
221+
self.state_names = list(states.keys())
221222

222223
self.parameters = parameters
223224
self.outputs = outputs
@@ -578,7 +579,7 @@ def compute_traj_loop(self, first_problem, inputs, outputs, t0=0., state0=None):
578579
inputs[state_name+"_initial"].squeeze()
579580
if state_name in self.traj_initial_state_input
580581
else 0.
581-
for state_name in first_problem.states.keys()
582+
for state_name in first_problem.state_names
582583
]).squeeze()
583584

584585
while True:
@@ -641,7 +642,7 @@ def compute_traj_loop(self, first_problem, inputs, outputs, t0=0., state0=None):
641642

642643
outputs[output_name] = sim_results[-1].x[
643644
-1,
644-
list(sim_problems[-1].states.keys()).index(state_name)
645+
sim_problems[-1].state_names.index(state_name)
645646
]
646647

647648
for output in self.traj_promote_final_output:
@@ -698,13 +699,13 @@ def compute_partials(self, inputs, J):
698699
param_deriv = np.zeros(len(param_dict))
699700

700701
if output in self.traj_final_state_output:
701-
costate[list(next_prob.states.keys()).index(output)] = 1.
702+
costate[next_prob.state_names.index(output)] = 1.
702703
else: # in self.traj_promote_final_output
703704

704705
next_prob.state_equation_function(next_res.t[-1], next_res.x[-1, :])
705706
costate[:] = next_prob.compute_totals(
706707
output,
707-
list(next_prob.states.keys()),
708+
next_prob.state_names,
708709
return_format='array'
709710
).squeeze()
710711

@@ -745,12 +746,12 @@ def compute_partials(self, inputs, J):
745746
num_active_event_channels += 1
746747
dg_dx = np.zeros((1, prob.dim_state))
747748

748-
if channel_name in prob.states.keys():
749-
dg_dx[0, list(prob.states.keys()).index(channel_name)] = 1.
749+
if channel_name in prob.state_names:
750+
dg_dx[0, prob.state_names.index(channel_name)] = 1.
750751
else:
751752
dg_dx[0, :] = prob.compute_totals(
752753
[channel_name],
753-
list(prob.states.keys()),
754+
prob.state_names,
754755
return_format='array'
755756
)
756757

@@ -786,17 +787,17 @@ def compute_partials(self, inputs, J):
786787

787788
# here and co-state assume number of states is only decreasing
788789
# forward in time
789-
for state_name in next_prob.states.keys():
790-
state_idx = list(next_prob.states.keys()).index(state_name)
790+
for state_name in next_prob.state_names:
791+
state_idx = next_prob.state_names.index(state_name)
791792

792-
if state_name in prob.states.keys():
793+
if state_name in prob.state_names:
793794
f_plus[
794795
state_idx
795-
] = plus_rate[list(prob.states.keys()).index(state_name)]
796+
] = plus_rate[prob.state_names.index(state_name)]
796797

797798
# state_update[
798-
# list(next_prob.states.keys()).index(state_name)
799-
# ] = x[list(prob.states.keys()).index(state_name)]
799+
# next_prob.state_names.index(state_name)
800+
# ] = x[prob.state_names.index(state_name)]
800801

801802
# TODO: make sure index multiplying next_pronb costate
802803
# lines up -- since costate is pre-filled to next_prob's
@@ -811,7 +812,7 @@ def compute_partials(self, inputs, J):
811812

812813
dh_j_dx = prob.compute_totals(
813814
[state_name],
814-
list(prob.states.keys()),
815+
prob.state_names,
815816
return_format='array').squeeze()
816817

817818
dh_dparam[state_idx, :] = prob.compute_totals(
@@ -820,15 +821,15 @@ def compute_partials(self, inputs, J):
820821
return_format='array'
821822
).squeeze()
822823

823-
for state_name_2 in prob.states.keys():
824+
for state_name_2 in prob.state_names:
824825
# I'm actually computing dh_dx.T
825826
# dh_dx rows are new state, columns are old state
826827
# now, dh_dx.T rows are old state, columns are new
827828
# so I think this is right
828829
dh_dx[
829-
list(next_prob.states.keys()).index(state_name_2),
830+
next_prob.state_names.index(state_name_2),
830831
state_idx,
831-
] = dh_j_dx[list(prob.states.keys()).index(state_name_2)]
832+
] = dh_j_dx[prob.state_names.index(state_name_2)]
832833

833834
else:
834835
state_update[
@@ -842,7 +843,7 @@ def compute_partials(self, inputs, J):
842843

843844
state_rate_names = [val['rate'] for _, val in prob.states.items()]
844845
df_dx_data[idx, :, :] = prob.compute_totals(state_rate_names,
845-
list(prob.states.keys()),
846+
prob.state_names,
846847
return_format='array').T
847848
if param_dict:
848849
df_dparam_data[idx, ...] = prob.compute_totals(
@@ -957,7 +958,7 @@ def compute_partials(self, inputs, J):
957958
# lamda_dot_plus = lamda_dot
958959
if self.verbosity is Verbosity.DEBUG:
959960
if np.any(state_disc):
960-
print("update is non-zero!", prob, prob.states.keys(),
961+
print("update is non-zero!", prob, prob.state_names,
961962
state_disc, costate, lamda_dot)
962963
print(
963964
"inner product becomes...",
@@ -966,7 +967,7 @@ def compute_partials(self, inputs, J):
966967
state_disc[None,
967968
:] @ dh_dx.T @ lamda_dot_plus[:, None]
968969
)
969-
print("dh_dx for", prob, prob.states.keys(), "\n", dh_dx)
970+
print("dh_dx for", prob, prob.state_names, "\n", dh_dx)
970971
print("costate", costate)
971972
costate_update_terms = [
972973
dh_dx.T @ costate[:, None],
@@ -1063,17 +1064,17 @@ def co_state_rate(t, costate, *args):
10631064

10641065
# TODO: do co-states need unit changes? probably not...
10651066
for state_name in prob.state_names:
1066-
costate[list(next_prob.states.keys()).index(
1067-
state_name)] = co_res.x[-1, list(prob.states.keys()).index(state_name)]
1067+
costate[next_prob.state_names.index(
1068+
state_name)] = co_res.x[-1, prob.state_names.index(state_name)]
10681069
lamda_dot_plus[
1069-
list(next_prob.states.keys()).index(state_name)
1070-
] = lamda_dot_plus_rate[list(prob.states.keys()).index(state_name)]
1070+
next_prob.state_names.index(state_name)
1071+
] = lamda_dot_plus_rate[prob.state_names.index(state_name)]
10711072

10721073
for state_to_deriv, metadata in self.traj_initial_state_input.items():
10731074
param_name = metadata["name"]
10741075
J[output_name, param_name] = costate_reses[output][-1].x[
10751076
-1,
1076-
list(prob.states.keys()).index(state_to_deriv)
1077+
prob.state_names.index(state_to_deriv)
10771078
]
10781079
for param_deriv_val, param_deriv_name in zip(param_deriv, param_dict):
10791080
J[output_name, param_deriv_name] = param_deriv_val

aviary/validation_cases/benchmark_tests/test_bench_GwGm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_bench_GwGm_SNOPT(self):
6161
assert_near_equal(prob.get_val("traj.desc2.timeseries.distance")[-1],
6262
3675.0, tolerance=rtol)
6363

64-
@require_pyoptsparse(optimizer="SNOPT")
64+
@require_pyoptsparse(optimizer="IPOPT")
6565
def test_bench_GwGm_shooting(self):
6666
local_phase_info = deepcopy(phase_info)
6767
prob = run_aviary('models/test_aircraft/aircraft_for_bench_GwGm.csv',

0 commit comments

Comments
 (0)