Hello
I am working on a nested optimization script with autodiff and wanted to try out JAX as an alternative to Julia. Nested optimization is important here because, although fitting all three parameters together is faster, kf easily gets stuck in a local minimum when the initial guess is poor, and I want to be robust to that when fitting non-simulated data. I am new to JAX and want to see if anyone can identify any glaring reasons why this code runs slowly. It is significantly slower than expected, and I haven't been able to run it long enough yet to see whether the fit is correct. I have also experimented with the optimistix package, but with less success so far.
Thank you for the help!
import jax
import jax.numpy as jnp
import diffrax as dfx
import optax
import matplotlib.pyplot as plt
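# NOTE: the post never shows where optimizer_kf and optimizer_p are defined,
# even though the update steps below use them. Something along these lines is
# assumed (optax optimizers; the learning rates here are placeholders):
optimizer_kf = optax.adam(1e-2)
optimizer_p = optax.adam(1e-3)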
@jax.jit
def sys(t, u, args):
    p, p_fix, kf, plateau, kd, start_time, end_time = args
    forcing = forcing_function(t, plateau, kd, start_time, end_time)
    # ... (the rest of the ODE right-hand side was cut off in the post)

@jax.jit
def forcing_function(t, plateau, kd, start_time, end_time):
    # Piecewise forcing: linear ramp, constant plateau, then exponential decay.
    return jnp.where(
        t < start_time,
        plateau * t,
        jnp.where(
            t < end_time,
            plateau,
            plateau * jnp.exp(-kd * (t - end_time))
        )
    )
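# For example, with the values from main() below (plateau=0.5, kd=1.0,
# start_time=1.0, end_time=10.0): the forcing ramps as 0.5*t on [0, 1),
# holds at 0.5 on [1, 10), then decays as 0.5*exp(-(t - 10)) afterwards.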
#@jax.jit
def Steady_State(args):
    p, p_fix, kf, plateau, kd, start_time, end_time = args
    new_args = (p, p_fix, kf, 0.0, kd, start_time, end_time)  # Set plateau to 0 for SS
    # ... (the body was cut off in the post; presumably it integrates to steady
    # state with new_args and returns the result)

def solve_ode(args, u0, ts):
    term = dfx.ODETerm(sys)
    solver = dfx.Tsit5()
    saveat = dfx.SaveAt(ts=ts)
    t0, t1 = ts[0], ts[-1]
    dt0 = 0.1
    sol = dfx.diffeqsolve(term, solver, t0, t1, dt0, y0=u0, args=args,
                          saveat=saveat, adjoint=dfx.DirectAdjoint())
    return sol
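# A sketch of a likely performance lever (an assumption, not from the original
# post): diffrax documents DirectAdjoint as a relatively inefficient adjoint,
# and passing only dt0 with no step-size controller means fixed 0.1-sized steps.
# Swapping in the default RecursiveCheckpointAdjoint plus an adaptive controller
# is usually much faster both to solve and to backpropagate through:
def solve_ode_adaptive(args, u0, ts):  # hypothetical variant of solve_ode above
    sol = dfx.diffeqsolve(
        dfx.ODETerm(sys), dfx.Tsit5(), ts[0], ts[-1], 0.1, y0=u0, args=args,
        saveat=dfx.SaveAt(ts=ts),
        stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5),
        adjoint=dfx.RecursiveCheckpointAdjoint(),  # diffrax's default adjoint
    )
    return sol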
#@jax.jit
def set_kf(kf_guess, args):
    p, p_fix, plateau, kd, u02, start_time, end_time = args
    args_ss = (p, p_fix, kf_guess, 0.0, kd, start_time, end_time)
    sol = Steady_State(args_ss)
    return jnp.sum((sol[2] - u02)**2)

@jax.jit
def update_kf_step(kf, opt_state, args):
    loss_value, grads = jax.value_and_grad(set_kf)(kf, args)
    updates, opt_state = optimizer_kf.update(grads, opt_state)
    kf = optax.apply_updates(kf, updates)
    return kf, opt_state, loss_value
def optimize_kf(initial_kf, args, num_iterations=500):
    kf = initial_kf
    opt_state = optimizer_kf.init(kf)
    for i in range(num_iterations):
        kf, opt_state, loss_value = update_kf_step(kf, opt_state, args)
        if i % 100 == 0:
            print(f"KF Iteration {i}, Loss: {loss_value}")
    return kf
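# Sketch (an assumption, not part of the original script): keeping the whole
# training loop inside JAX with jax.lax.fori_loop avoids one Python->device
# dispatch per iteration. The print is dropped; jax.debug.print could replace it.
def optimize_kf_fori(initial_kf, args, num_iterations=500):  # hypothetical
    def body(i, carry):
        kf, opt_state = carry
        kf, opt_state, _ = update_kf_step(kf, opt_state, args)
        return (kf, opt_state)
    kf, _ = jax.lax.fori_loop(
        0, num_iterations, body, (initial_kf, optimizer_kf.init(initial_kf))
    )
    return kf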
#@jax.jit
def loss_fn(p, fixed_args):
    p_fix, kf_guess, plateau, kd, u02, start_time, end_time, Tmeas, Ymeas = fixed_args
    # ... (the body was cut off in the post)
@jax.jit
def update_p_step(p, opt_state, fixed_args):
    loss_value, grads = jax.value_and_grad(loss_fn)(p, fixed_args)
    updates, opt_state = optimizer_p.update(grads, opt_state)
    p = optax.apply_updates(p, updates)
    return p, opt_state, loss_value
def optimize_p(initial_p, fixed_args, num_iterations=1000):
    p = initial_p
    opt_state = optimizer_p.init(p)
    for i in range(num_iterations):
        p, opt_state, loss_value = update_p_step(p, opt_state, fixed_args)
        if i % 100 == 0:
            print(f"P Iteration {i}, Loss: {loss_value}")
    return p
def main():
    # Set fixed parameters
    kf0 = jnp.array([10.0])        # True value
    p_fix = jnp.array([0.25])      # Known parameter
    p_var = jnp.array([1.0, 0.5])  # Unknown parameters
    u0_guess = jnp.array([1000.0, 750.0, 20.0])
    u02 = 20.0
    plateau, kd = 0.5, 1.0
    start_time, end_time = 1.0, 10.0
    # ... (the rest of main was cut off in the post)
if __name__ == "__main__":
    main()
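Since the post cuts off the bodies of sys, Steady_State, loss_fn, and main, the exact nesting is not shown. Based on the description at the top (an inner fit of kf per outer candidate p), the intended wiring is presumably along these lines; this is a sketch under that assumption with hypothetical names, not the author's actual code:

def loss_fn_sketch(p, fixed_args, u0):  # hypothetical stand-in for the truncated loss_fn
    p_fix, kf_guess, plateau, kd, u02, start_time, end_time, Tmeas, Ymeas = fixed_args
    # Inner problem: re-fit kf for this candidate p via the steady-state condition.
    kf = optimize_kf(kf_guess, (p, p_fix, plateau, kd, u02, start_time, end_time))
    # Outer problem: solve the forced ODE and compare against the measurements.
    sol = solve_ode((p, p_fix, kf, plateau, kd, start_time, end_time), u0, Tmeas)
    return jnp.sum((sol.ys - Ymeas) ** 2)

If loss_fn really does call optimize_kf like this, note that jax.value_and_grad(loss_fn) must trace and differentiate through all 500 unrolled inner iterations, which by itself is a plausible explanation for the very long compile and run times.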