Skip to content

Commit

Permalink
Fix autograd => jax issues
Browse files Browse the repository at this point in the history
  • Loading branch information
mvdh7 committed Dec 16, 2024
1 parent fa1e496 commit 368127c
Showing 2 changed files with 122 additions and 1,254 deletions.
246 changes: 122 additions & 124 deletions PyCO2SYS/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
from jax import numpy as np

from .. import convert, equilibria, salts, solve
from . import nd, system
from . import system


def condition(input_locals, npts=None):
@@ -84,129 +84,127 @@ def condition(input_locals, npts=None):


# Define all gradable outputs
gradables = np.array(
[
"TAlk",
"TCO2",
"pHin",
"pCO2in",
"fCO2in",
"HCO3in",
"CO3in",
"CO2in",
"BAlkin",
"OHin",
"PAlkin",
"SiAlkin",
"NH3Alkin",
"H2SAlkin",
"Hfreein",
"RFin",
"OmegaCAin",
"OmegaARin",
"xCO2in",
"pHout",
"pCO2out",
"fCO2out",
"HCO3out",
"CO3out",
"CO2out",
"BAlkout",
"OHout",
"PAlkout",
"SiAlkout",
"NH3Alkout",
"H2SAlkout",
"Hfreeout",
"RFout",
"OmegaCAout",
"OmegaARout",
"xCO2out",
"pHinTOTAL",
"pHinSWS",
"pHinFREE",
"pHinNBS",
"pHoutTOTAL",
"pHoutSWS",
"pHoutFREE",
"pHoutNBS",
"TEMPIN",
"TEMPOUT",
"PRESIN",
"PRESOUT",
"SAL",
"PO4",
"SI",
"NH3",
"H2S",
"K0input",
"K1input",
"K2input",
"pK1input",
"pK2input",
"KWinput",
"KBinput",
"KFinput",
"KSinput",
"KP1input",
"KP2input",
"KP3input",
"KSiinput",
"KNH3input",
"KH2Sinput",
"K0output",
"K1output",
"K2output",
"pK1output",
"pK2output",
"KWoutput",
"KBoutput",
"KFoutput",
"KSoutput",
"KP1output",
"KP2output",
"KP3output",
"KSioutput",
"KNH3output",
"KH2Soutput",
"TB",
"TF",
"TS",
"gammaTCin",
"betaTCin",
"omegaTCin",
"gammaTAin",
"betaTAin",
"omegaTAin",
"gammaTCout",
"betaTCout",
"omegaTCout",
"gammaTAout",
"betaTAout",
"omegaTAout",
"isoQin",
"isoQout",
"isoQapprox_in",
"isoQapprox_out",
"psi_in",
"psi_out",
"TCa",
"SIRin",
"SIRout",
"PAR1",
"PAR2",
"PengCorrection",
"FugFacinput",
"FugFacoutput",
"fHinput",
"fHoutput",
"RGas",
"KCainput",
"KCaoutput",
"KArinput",
"KAroutput",
]
)
gradables = [
"TAlk",
"TCO2",
"pHin",
"pCO2in",
"fCO2in",
"HCO3in",
"CO3in",
"CO2in",
"BAlkin",
"OHin",
"PAlkin",
"SiAlkin",
"NH3Alkin",
"H2SAlkin",
"Hfreein",
"RFin",
"OmegaCAin",
"OmegaARin",
"xCO2in",
"pHout",
"pCO2out",
"fCO2out",
"HCO3out",
"CO3out",
"CO2out",
"BAlkout",
"OHout",
"PAlkout",
"SiAlkout",
"NH3Alkout",
"H2SAlkout",
"Hfreeout",
"RFout",
"OmegaCAout",
"OmegaARout",
"xCO2out",
"pHinTOTAL",
"pHinSWS",
"pHinFREE",
"pHinNBS",
"pHoutTOTAL",
"pHoutSWS",
"pHoutFREE",
"pHoutNBS",
"TEMPIN",
"TEMPOUT",
"PRESIN",
"PRESOUT",
"SAL",
"PO4",
"SI",
"NH3",
"H2S",
"K0input",
"K1input",
"K2input",
"pK1input",
"pK2input",
"KWinput",
"KBinput",
"KFinput",
"KSinput",
"KP1input",
"KP2input",
"KP3input",
"KSiinput",
"KNH3input",
"KH2Sinput",
"K0output",
"K1output",
"K2output",
"pK1output",
"pK2output",
"KWoutput",
"KBoutput",
"KFoutput",
"KSoutput",
"KP1output",
"KP2output",
"KP3output",
"KSioutput",
"KNH3output",
"KH2Soutput",
"TB",
"TF",
"TS",
"gammaTCin",
"betaTCin",
"omegaTCin",
"gammaTAin",
"betaTAin",
"omegaTAin",
"gammaTCout",
"betaTCout",
"omegaTCout",
"gammaTAout",
"betaTAout",
"omegaTAout",
"isoQin",
"isoQout",
"isoQapprox_in",
"isoQapprox_out",
"psi_in",
"psi_out",
"TCa",
"SIRin",
"SIRout",
"PAR1",
"PAR2",
"PengCorrection",
"FugFacinput",
"FugFacoutput",
"fHinput",
"fHoutput",
"RGas",
"KCainput",
"KCaoutput",
"KArinput",
"KAroutput",
]


def _outputs_grad(args, core_in, core_out, others_in, others_out, totals, Kis, Kos):
Loading

0 comments on commit 368127c

Please sign in to comment.