In https://arxiv.org/abs/2510.19397 it is argued that the "correct" cost function for gradient descent is
It would be nice to implement this cost function, with gradients, in JAX. Note that JAX does not provide a differentiable matrix logarithm. However, one can get quite far using its Schur decomposition (`jax.scipy.linalg.schur`).
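As a starting point, here is a minimal sketch of a Schur-based matrix logarithm in JAX. It assumes the input matrix is normal (e.g. symmetric), so that the complex Schur factor is diagonal and its logarithm is just an elementwise log of the diagonal; a general upper-triangular factor would need something like the Parlett recurrence, and differentiating through `schur` may require registering a custom JVP rule, since JAX does not define one out of the box. The name `logm_schur` is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.linalg import schur, expm


def logm_schur(a):
    """Matrix log via Schur decomposition, assuming `a` is normal.

    For a normal matrix the complex Schur form a = q @ t @ q^H has a
    diagonal t, so log(t) reduces to an elementwise log of its diagonal.
    """
    t, q = schur(a, output="complex")
    log_t = jnp.diag(jnp.log(jnp.diag(t)))
    return q @ log_t @ q.conj().T


# Sanity check: expm should invert the log (symmetric positive-definite input).
a = jnp.array([[2.0, 1.0], [1.0, 2.0]])
print(jnp.allclose(expm(logm_schur(a)), a, atol=1e-4))
```

For non-normal inputs the diagonal shortcut above is wrong, and `schur` currently runs on CPU only, so this is a sketch of the idea rather than a general implementation.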