Skip to content

Commit

Permalink
support running on more than 1 rank
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthelen committed Mar 22, 2024
1 parent 0fc312a commit 91c8546
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ def setup(self):

class TopLevelGroup(om.Group):
def setup(self):
if self.comm.size!=2:
raise SystemError('Please launch with 2 processors')

# IVCs that feed into both parallel groups
self.add_subsystem('ivc', om.IndepVarComp(), promotes=['*'])

Expand Down
69 changes: 35 additions & 34 deletions examples/aerostructural/supersonic_panel/as_opt_remote_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,40 +54,41 @@
prob.cleanup()

# write out data
cr = om.CaseReader("optimization_history.sql")
driver_cases = cr.list_cases('driver')

case = cr.get_case(0)
cons = case.get_constraints()
dvs = case.get_design_vars()
objs = case.get_objectives()

with open("optimization_history.dat","w+") as f:

for i, k in enumerate(objs.keys()):
f.write('objective: ' + k + '\n')
for j, case_id in enumerate(driver_cases):
f.write(str(j) + ' ' + str(cr.get_case(case_id).get_objectives(scaled=False)[k][0]) + '\n')
if prob.model.comm.rank==0:
cr = om.CaseReader("optimization_history.sql")
driver_cases = cr.list_cases('driver')

case = cr.get_case(0)
cons = case.get_constraints()
dvs = case.get_design_vars()
objs = case.get_objectives()

with open("optimization_history.dat","w+") as f:

for i, k in enumerate(objs.keys()):
f.write('objective: ' + k + '\n')
for j, case_id in enumerate(driver_cases):
f.write(str(j) + ' ' + str(cr.get_case(case_id).get_objectives(scaled=False)[k][0]) + '\n')
f.write(' ' + '\n')

for i, k in enumerate(cons.keys()):
f.write('constraint: ' + k + '\n')
for j, case_id in enumerate(driver_cases):
f.write(str(j) + ' ' + ' '.join(map(str,cr.get_case(case_id).get_constraints(scaled=False)[k])) + '\n')
f.write(' ' + '\n')

for i, k in enumerate(dvs.keys()):
f.write('DV: ' + k + '\n')
for j, case_id in enumerate(driver_cases):
f.write(str(j) + ' ' + ' '.join(map(str,cr.get_case(case_id).get_design_vars(scaled=False)[k])) + '\n')
f.write(' ' + '\n')

f.write('run times, function\n')
for i in range(len(prob.model.remote.times_function)):
f.write(f'{prob.model.remote.times_function[i]}\n')
f.write(' ' + '\n')

for i, k in enumerate(cons.keys()):
f.write('constraint: ' + k + '\n')
for j, case_id in enumerate(driver_cases):
f.write(str(j) + ' ' + ' '.join(map(str,cr.get_case(case_id).get_constraints(scaled=False)[k])) + '\n')
f.write('run times, gradient\n')
for i in range(len(prob.model.remote.times_gradient)):
f.write(f'{prob.model.remote.times_gradient[i]}\n')
f.write(' ' + '\n')

for i, k in enumerate(dvs.keys()):
f.write('DV: ' + k + '\n')
for j, case_id in enumerate(driver_cases):
f.write(str(j) + ' ' + ' '.join(map(str,cr.get_case(case_id).get_design_vars(scaled=False)[k])) + '\n')
f.write(' ' + '\n')

f.write('run times, function\n')
for i in range(len(prob.model.remote.times_function)):
f.write(f'{prob.model.remote.times_function[i]}\n')
f.write(' ' + '\n')

f.write('run times, gradient\n')
for i in range(len(prob.model.remote.times_gradient)):
f.write(f'{prob.model.remote.times_gradient[i]}\n')
f.write(' ' + '\n')
64 changes: 36 additions & 28 deletions mphys/network/remote_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,34 +32,36 @@ def initialize(self):
self.options.declare('use_derivative_coloring', default=False, types=bool, desc="assign derivative coloring to objective/constraints. Only for cases with parallel servers")

def setup(self):
if self.comm.size>1:
raise SystemError('Using Remote Component on more than 1 rank is not supported')
self.time_estimate_multiplier = self.options['time_estimate_multiplier']
self.time_estimate_buffer = self.options['time_estimate_buffer']
self.reboot_only_on_function_call = self.options['reboot_only_on_function_call']
self.dump_json = self.options['dump_json']
self.dump_separate_json = self.options['dump_separate_json']
self.var_naming_dot_replacement = self.options['var_naming_dot_replacement']
self.additional_remote_inputs = self.options['additional_remote_inputs']
self.additional_remote_outputs = self.options['additional_remote_outputs']
self.use_derivative_coloring = self.options['use_derivative_coloring']
self.derivative_coloring_num = 0
self.last_analysis_completed_time = time.time() # for tracking down time between function/gradient calls
if self.dump_separate_json:
self.dump_json = True

self._setup_server_manager()

# for tracking model times, and determining whether to relaunch servers
self.times_function = np.array([])
self.times_gradient = np.array([])

# get baseline model
print(f'CLIENT (subsystem {self.name}): Running model from setup to get design problem info', flush=True)
output_dict = self.evaluate_model(command='initialize',
remote_input_dict={'additional_inputs': self.additional_remote_inputs,
'additional_outputs': self.additional_remote_outputs,
'component_name': self.name})
output_dict = None
if self.comm.rank==0:
self.time_estimate_multiplier = self.options['time_estimate_multiplier']
self.time_estimate_buffer = self.options['time_estimate_buffer']
self.reboot_only_on_function_call = self.options['reboot_only_on_function_call']
self.dump_json = self.options['dump_json']
self.dump_separate_json = self.options['dump_separate_json']
self.additional_remote_inputs = self.options['additional_remote_inputs']
self.additional_remote_outputs = self.options['additional_remote_outputs']
self.last_analysis_completed_time = time.time() # for tracking down time between function/gradient calls
if self.dump_separate_json:
self.dump_json = True

self._setup_server_manager()

# for tracking model times, and determining whether to relaunch servers
self.times_function = np.array([])
self.times_gradient = np.array([])

# get baseline model
print(f'CLIENT (subsystem {self.name}): Running model from setup to get design problem info', flush=True)
output_dict = self.evaluate_model(command='initialize',
remote_input_dict={'additional_inputs': self.additional_remote_inputs,
'additional_outputs': self.additional_remote_outputs,
'component_name': self.name})
output_dict = self.comm.bcast(output_dict)

self._add_design_inputs_from_baseline_model(output_dict)
self._add_objectives_from_baseline_model(output_dict)
Expand All @@ -71,8 +73,11 @@ def setup(self):
self.declare_partials('*', '*')

def compute(self,inputs,outputs):
input_dict = self._create_input_dict_for_server(inputs)
remote_dict = self.evaluate_model(remote_input_dict=input_dict, command='evaluate')
remote_dict = None
if self.comm.rank==0:
input_dict = self._create_input_dict_for_server(inputs)
remote_dict = self.evaluate_model(remote_input_dict=input_dict, command='evaluate')
remote_dict = self.comm.bcast(remote_dict)

self._assign_objectives_from_remote_output(remote_dict, outputs)
self._assign_constraints_from_remote_output(remote_dict, outputs)
Expand All @@ -81,8 +86,11 @@ def compute(self,inputs,outputs):
def compute_partials(self, inputs, partials):
# NOTE: this does not make use of the `of` and `wrt` arguments, even if they are given to compute_totals/check_totals in the outer script

input_dict = self._create_input_dict_for_server(inputs)
remote_dict = self.evaluate_model(remote_input_dict=input_dict, command='evaluate derivatives')
remote_dict = None
if self.comm.rank==0:
input_dict = self._create_input_dict_for_server(inputs)
remote_dict = self.evaluate_model(remote_input_dict=input_dict, command='evaluate derivatives')
remote_dict = self.comm.bcast(remote_dict)

self._assign_objective_partials_from_remote_output(remote_dict, partials)
self._assign_constraint_partials_from_remote_output(remote_dict, partials)
Expand Down
6 changes: 3 additions & 3 deletions mphys/network/zmq_pbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,16 @@ def enough_time_is_remaining(self, estimated_model_time):

def job_has_expired(self):
self.job.update_job_state()
if self.job.state!='R':
if self.job.state=='R':
return False
else:
if self.job_expiration_max_restarts is not None:
if self.job_expiration_restarts+1 > self.job_expiration_max_restarts:
self.stop_server()
raise RuntimeError(f'CLIENT (subsystem {self.component_name}): Reached maximum number of job expiration restarts')
self.job_expiration_restarts += 1
print(f'CLIENT (subsystem {self.component_name}): Job no longer running; flagging for job restart')
return True
else:
return False

def _port_is_in_use(self, port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
Expand Down

0 comments on commit 91c8546

Please sign in to comment.