Skip to content

Commit

Permalink
Improve order of calculations with topological_sort
Browse files Browse the repository at this point in the history
  • Loading branch information
mvdh7 committed Dec 17, 2024
1 parent d16a37f commit 6f6c61b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 52 deletions.
97 changes: 47 additions & 50 deletions PyCO2SYS/engine/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ def __init__(self, values=None, opts=None):
core_known = np.array([v in values for v in parameters_core])
icase_all = np.arange(1, len(parameters_core) + 1)
icase = icase_all[core_known]
assert len(icase) < 3, "You may not provide more than 2 known core parameters."
assert len(icase) < 3, "You cannot provide more than 2 known core parameters."
if len(icase) == 0:
icase = np.array(0)
elif len(icase) == 2:
Expand Down Expand Up @@ -923,57 +923,48 @@ def _assemble(self, icase, values):
self.nodes_original = list(k for k, v in values.items() if v is not None)
return graph, funcs, values

def _get(self, parameters, graph, funcs, values, save_steps, verbose):
def _get(self, parameters, values, save_steps, verbose):
def printv(*args, **kwargs):
if verbose:
print(*args, **kwargs)

# needs: which intermediate parameters we need to get the requested parameters
graph_unknown = graph.copy()
graph_unknown = self.graph.copy()
graph_unknown.remove_nodes_from([v for v in values if v not in parameters])
self_values = values.copy() # what is already known
results = {} # values for the requested parameters will go in here
needs = parameters.copy()
for p in parameters:
needs = needs | nx.ancestors(graph_unknown, p)
# The got counter increments each time we successfully get a value, either from
# the arguments, already-calculated values, or by calculating it.
# The loop stops once got reaches the number of parameters in `needs`, because
# then we're done.
got = 0
# We will cycle through the set of needed parameters
needs_cycle = itertools.cycle(needs)
self_values = values.copy() # what is already known
results = {} # values for the requested parameters will go in here
while got < len(needs):
p = next(needs_cycle)
needs = [p for p in nx.topological_sort(self.graph) if p in needs]
for p in needs:
printv("")
printv(p)
if p in self_values:
if p not in results:
results[p] = self_values[p]
got += 1
printv("{} is available!".format(p))
results[p] = self_values[p]
printv("{} is already available!".format(p))
else:
priors = graph.pred[p]
priors = self.graph.pred[p]
if len(priors) == 0 or all([r in self_values for r in priors]):
printv("Calculating {}".format(p))
self_values[p] = funcs[p](
printv("Calculating {}...".format(p))
self_values[p] = self.funcs[p](
*[
self_values[r]
for r in funcs[p].__code__.co_varnames[
: funcs[p].__code__.co_argcount
for r in self.funcs[p].__code__.co_varnames[
: self.funcs[p].__code__.co_argcount
]
]
)
# state 2 means that the value was calculated internally
if save_steps:
nx.set_node_attributes(graph, {p: 2}, name="state")
for f in funcs[p].__code__.co_varnames[
: funcs[p].__code__.co_argcount
nx.set_node_attributes(self.graph, {p: 2}, name="state")
for f in self.funcs[p].__code__.co_varnames[
: self.funcs[p].__code__.co_argcount
]:
nx.set_edge_attributes(graph, {(f, p): 2}, name="state")
nx.set_edge_attributes(
self.graph, {(f, p): 2}, name="state"
)
results[p] = self_values[p]
got += 1
printv("Got", got, "of", len(set(needs)))
# Get rid of jax overhead on results
for k, v in results.items():
try:
Expand All @@ -995,7 +986,7 @@ def printv(*args, **kwargs):
except AttributeError:
pass
values.update(self_values)
return results, graph, values
return results, values

def solve(self, parameters=None, save_steps=True, verbose=False):
"""Calculate and return parameter(s) and (optionally) save them internally.
Expand All @@ -1022,9 +1013,7 @@ def solve(self, parameters=None, save_steps=True, verbose=False):
parameters = [parameters]
parameters = set(parameters) # get rid of duplicates
# Solve the system
results, self.graph, self.values = self._get(
parameters, self.graph, self.funcs, self.values, save_steps, verbose
)
results, self.values = self._get(parameters, self.values, save_steps, verbose)
results = {k: v for k, v in results.items() if k in parameters}
return results

Expand Down Expand Up @@ -1156,45 +1145,53 @@ def get_func_of(self, var_of):
The created function has the signature
value_of = get_value_of(**values_original)
value_of = get_value_of(**values)
where the ``values_original`` are the originally known values:
where the ``values`` are the originally user-defined values, obtained with
either of the following:
values_original = {k: sys.values[k] for k in sys.nodes_original}
values_original = sys.get_values_original()
"""
# We get a sub-graph of the node of interest and all its ancestors, excluding
# originally fixed / user-defined values
nodes_vo = nx.ancestors(self.graph, var_of)
nodes_vo.add(var_of)
nodes_vo = onp.array([n for n in nodes_vo if n not in self.nodes_original])
nodes_vo_all = nx.ancestors(self.graph, var_of)
nodes_vo_all.add(var_of)
# nodes_vo = onp.array([n for n in nodes_vo if n not in self.nodes_original])
nodes_vo = [n for n in nodes_vo_all if n not in self.nodes_original]
graph_vo = self.graph.subgraph(nodes_vo)
# We need to know what order to run the functions in. The approach below is a
# bit of a bodge --- as far as I can see, it does what I want, but not sure why
# and can't be certain it will always do that; haven't found another way yet
pos = nx.nx_agraph.graphviz_layout(graph_vo, prog="dot")
rank = onp.argsort([pos[n][1] for n in nodes_vo])
nodes_vo = nodes_vo[rank][::-1]

def get_value_of(**values_original):
values_original = values_original.copy()
def get_value_of(**values):
values = values.copy()
# This loops through the functions in the correct order determined above so
# we end up calculating the value of interest, which is returned
for n in nodes_vo:
values_original.update(
for n in nx.topological_sort(graph_vo):
values.update(
{
n: self.funcs[n](
*[
values_original[v]
values[v]
for v in self.funcs[n].__code__.co_varnames[
: self.funcs[n].__code__.co_argcount
]
]
)
}
)
return values_original[var_of]
return values[var_of]

# Generate docstring
get_value_of.__doc__ = (
"Calculate ``{}``.".format(var_of)
+ "\n\nParameters\n----------"
+ "\nvalues : dict"
+ "\n Key-value pairs for the following parameters:"
)
for p in nodes_vo_all:
if p in self.nodes_original:
get_value_of.__doc__ += "\n {}".format(p)
get_value_of.__doc__ += "\n\nReturns\n-------"
get_value_of.__doc__ += "\n{}".format(var_of)
return get_value_of

def get_func_of_from_wrt(self, get_value_of, var_wrt):
Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ dependencies = [
"jax",
"networkx",
"matplotlib",
"graphviz",
"pygraphviz",
]
classifiers = [
"Development Status :: 5 - Production/Stable",
Expand Down
7 changes: 7 additions & 0 deletions tests/_test_topsort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# %%
# Scratch script exercising CO2System.solve after the switch to ordering the
# calculations with a topological sort.
from PyCO2SYS import CO2System

# Set up a system from two user-provided values: dic and alkalinity.
# NOTE(review): the variable name `sys` is also the name of a standard-library
# module; that module is not imported in this script, so nothing is shadowed here.
sys = CO2System(values=dict(dic=2100, alkalinity=2300))
# Solve for pH first with verbose output switched off...
sys.solve("pH", verbose=False)
# ...then solve for fCO2 with verbose output switched on, so the parameters are
# printed as the solver works through them.
sys.solve("fCO2", verbose=True)
# TOPOLOGICAL SORT!

0 comments on commit 6f6c61b

Please sign in to comment.