Skip to content

Commit

Permalink
Fix calculate_derivatives for PETSc (#1342)
Browse files Browse the repository at this point in the history
* fix calculate derivative

* run Black

* test central difference scheme

* simplify conditionals

* attempt to get around pylint error

* Disable Pylint false positive

* pull outside for fix

* remove redundant pylint check

---------

Co-authored-by: Ludovico Bianchi <lbianchi@lbl.gov>
  • Loading branch information
dallan-keylogic and lbianchi-lbl authored Mar 1, 2024
1 parent e52a1ce commit 9de79c6
Show file tree
Hide file tree
Showing 2 changed files with 338 additions and 21 deletions.
79 changes: 59 additions & 20 deletions idaes/core/solvers/petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from pyomo.core.expr.visitor import identify_variables
import pyomo.dae as pyodae
from pyomo.common import Executable
from pyomo.dae.flatten import flatten_dae_components
from pyomo.dae.flatten import flatten_dae_components, slice_component_along_sets
from pyomo.util.subsystems import (
TemporarySubsystemManager,
create_subsystem_block,
Expand Down Expand Up @@ -428,7 +428,7 @@ def petsc_dae_by_time_element(
symbolic_solver_labels=True,
between=None,
interpolate=True,
calculate_derivatives=True,
calculate_derivatives=False,
previous_trajectory=None,
representative_time=None,
snes_options=None,
Expand Down Expand Up @@ -711,17 +711,17 @@ def petsc_dae_by_time_element(
# May not have trajectory from fixed variables and they
# shouldn't change anyway, so only set not fixed vars
var[t].value = vec[i]
if calculate_derivatives:
# the petsc solver interface does not currently return time
# derivatives, and if it did, they would be estimated based on a
# smaller time step. This option uses Pyomo.DAE's discretization
# equations to calculate the time derivative values
calculate_time_derivatives(m, time)
# return the solver results and trajectory if available
if calculate_derivatives:
# the petsc solver interface does not currently return time
# derivatives, and if it did, they would be estimated based on a
# smaller time step. This option uses Pyomo.DAE's discretization
# equations to calculate the time derivative values
calculate_time_derivatives(m, time, between=between)
# return the solver results and trajectory if available
return PetscDAEResults(results=res_list, trajectory=tj)


def calculate_time_derivatives(m, time):
def calculate_time_derivatives(m, time, between=None):
"""Calculate the derivative values from the discretization equations.
Args:
Expand All @@ -731,19 +731,58 @@ def calculate_time_derivatives(m, time):
Returns:
None
"""
# Leave between an optional argument for backwards compatibility
if between is None:
between = time
for var in m.component_objects(pyo.Var):
if isinstance(var, pyodae.DerivativeVar):
if time in ComponentSet(var.get_continuousset_list()):
parent = var.parent_block()
name = var.local_name + "_disc_eq"
disc_eq = getattr(parent, name)
for i, v in var.items():
try:
if disc_eq[i].active:
v.value = 0 # Make sure there is a value
calculate_variable_from_constraint(v, disc_eq[i])
except KeyError:
pass # discretization equation may not exist at first time
parent_block = var.parent_block()
disc_eq = getattr(parent_block, var.local_name + "_disc_eq")

deriv_dict = dict(
(key, pyo.Reference(slc))
for key, slc in slice_component_along_sets(var, (time,))
)
disc_dict = dict(
(key, pyo.Reference(slc))
for key, slc in slice_component_along_sets(disc_eq, (time,))
)

for key, deriv in deriv_dict.items():
# state = state_dict[key]
disc_eq = disc_dict[key]
for t in time:
if t < between.first() or t > between.last():
# Outside of integration range, skip calculation
continue
old_value = deriv[t].value
try:
# TODO This calculates the value of the derivative even
# if one of the state var values is from outside the
# integration range, so long as it's initialized. Is
# this the desired behavior?
if disc_eq[t].active and not deriv[t].fixed:
deriv[t].value = 0 # Make sure there is a value
calculate_variable_from_constraint(deriv[t], disc_eq[t])
except KeyError as err:
# Discretization and continuity equations may or may not exist at the first or last time
# points depending on the method. Backwards skips first, forwards skips last, central skips
# both (which means the user needs to provide additional equations)
if t == time.first() or t == time.last():
pass
else:
raise err
except ValueError as err:
# At edges of between, it's unclear which adjacent
# values of state variables have been populated.
# Therefore we might get hit with value errors.
if t == between.first() or t == between.last():
# Reset deriv value to old value
if disc_eq[t].active and not deriv[t].fixed:
deriv[t].value = old_value
else:
raise err


class PetscTrajectory(object):
Expand Down
Loading

0 comments on commit 9de79c6

Please sign in to comment.