Replies: 2 comments
-
Just started using
-
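We don't have a plan to add them, but we could if needed. It's not so difficult to implement them in terms of the existing transformations (`mx.vjp`, `mx.jvp`, and `mx.vmap`):

```python
import mlx.core as mx

def jacrev(f):
    # Reverse mode: one VJP per output element, batched with vmap
    def jacfn(x):
        y = f(x)  # needed only for the size of the output

        def vjpfn(cotan):
            # mx.vjp returns (outputs, vjps); keep the VJP w.r.t. x
            return mx.vjp(f, [x], [cotan])[1][0]

        # Cotangent e_i yields row i, so the stack is the (len(y), len(x)) Jacobian
        return mx.vmap(vjpfn, in_axes=0)(mx.eye(len(y)))

    return jacfn

def jacfwd(f):
    # Forward mode: one JVP per input element, batched with vmap
    def jacfn(x):
        def jvpfn(tan):
            # mx.jvp returns (outputs, jvps); keep the JVP of the output
            return mx.jvp(f, [x], [tan])[1][0]

        # Tangent e_j yields column j; vmap stacks them as rows, so transpose
        return mx.vmap(jvpfn, in_axes=0)(mx.eye(len(x))).T

    return jacfn

def hessian(f):
    # Forward-over-reverse: the JVP of the gradient is a Hessian-vector product
    def hessfn(x):
        def hvp(tan):
            return mx.jvp(mx.grad(f), [x], [tan])[1][0]

        # H e_j is column j; stacking gives H.T, which equals H (H is symmetric)
        return mx.vmap(hvp, in_axes=0)(mx.eye(len(x)))

    return hessfn

print(jacrev(mx.sin)(mx.array([1.0, 2.0, 3.0])))
print(jacfwd(mx.sin)(mx.array([1.0, 2.0, 3.0])))

def fun(x):
    return mx.sin(x).sum()

print(hessian(fun)(mx.array([1.0, 2.0, 3.0])))
```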
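Which way each construction stacks its results is easiest to see on a non-square function, because `mx.sin` has a diagonal Jacobian that hides the layout. VJPs with basis cotangents give rows of the Jacobian, JVPs with basis tangents give columns, so the forward-mode stack has to be transposed to get the usual (m, n) layout. The sketch below is plain NumPy, not MLX: it assumes a hypothetical `f(x) = A @ sin(x)` from R^3 to R^2 with hand-written `jvp`/`vjp` rules standing in for `mx.jvp`/`mx.vjp`, and uses a Python loop where the MLX version uses `mx.vmap`.

```python
import numpy as np

# Hypothetical test function f: R^3 -> R^2, f(x) = A @ sin(x).
# Its hand-written JVP/VJP rules stand in for mx.jvp / mx.vjp.
A = np.array([[1.0, 2.0, 0.0],
              [0.0, -1.0, 3.0]])

def f(x):
    return A @ np.sin(x)

def jvp(x, tan):
    # J @ tan, where J = A @ diag(cos(x))
    return A @ (np.cos(x) * tan)

def vjp(x, cotan):
    # cotan @ J = (A.T @ cotan) * cos(x)
    return (A.T @ cotan) * np.cos(x)

def jacrev(x):
    # One VJP per output: cotangent e_i yields row i of J
    m = f(x).shape[0]
    return np.stack([vjp(x, e) for e in np.eye(m)])

def jacfwd(x):
    # One JVP per input: tangent e_j yields column j of J.
    # Stacking the columns as rows gives J.T, so transpose back to (m, n).
    n = x.shape[0]
    return np.stack([jvp(x, e) for e in np.eye(n)]).T

x = np.array([1.0, 2.0, 3.0])
# Closed form via broadcasting: J[i, j] = A[i, j] * cos(x[j])
J_exact = A * np.cos(x)

print(jacrev(x).shape, jacfwd(x).shape)  # (2, 3) (2, 3)
print(np.allclose(jacrev(x), J_exact), np.allclose(jacfwd(x), J_exact))  # True True
```

Reverse mode needs one pass per output and forward mode one pass per input, so `jacrev` is usually the cheaper choice when m < n and `jacfwd` when n < m.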
-
In JAX, there are `jacfwd`, `jacrev`, and `hessian` functions for transforming an objective function into functions that compute its first- or second-order derivatives. I'm curious whether MLX plans to add these three functions in the future.
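For reference, for $f:\mathbb{R}^n \to \mathbb{R}^m$ the first two transforms return the Jacobian $J$, and for scalar $f$ the third returns the Hessian $H$. Reverse mode recovers $J$ one row at a time from vector-Jacobian products (VJPs); forward mode recovers it one column at a time from Jacobian-vector products (JVPs):

```latex
J_{ij}(x) = \frac{\partial f_i}{\partial x_j}(x), \qquad J(x) \in \mathbb{R}^{m \times n}

\text{VJP: } c \mapsto c^\top J(x), \qquad e_i^\top J(x) = \text{row } i \text{ of } J \quad (\texttt{jacrev},\ m \text{ passes})

\text{JVP: } t \mapsto J(x)\, t, \qquad J(x)\, e_j = \text{column } j \text{ of } J \quad (\texttt{jacfwd},\ n \text{ passes})

H_{jk}(x) = \frac{\partial^2 f}{\partial x_j\, \partial x_k}(x), \qquad H(x)\, e_k = \text{JVP of } \nabla f \text{ along } e_k
```

JAX builds `jax.hessian(f)` as `jacfwd(jacrev(f))`, i.e. forward-over-reverse, which is the same composition as taking a JVP of `mx.grad(f)` in MLX.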