Interface for specifying adjoints of FEM models #2303
Naive question, but isn't this solved by SciMLSensitivity, with the adjoint method typically applied automatically when ADing through the solve?
@devmotion You're right. There's the
To illustrate the math of the adjoint problem, let's consider a simple ODE in time

$$\frac{\mathrm{d}x}{\mathrm{d}t} = f(x, \theta, t),$$

with given initial condition $x(0) = x_0$, where $x(t)$ is the state and $\theta$ are the parameters. The adjoint is used when we try to optimise the loss function (i.e. the log-density in Turing)

$$\mathcal{L}(\theta) = \int_0^T g(x, \theta, t)\,\mathrm{d}t,$$

which requires the gradient $\mathrm{d}\mathcal{L}/\mathrm{d}\theta$, computed using the Lagrange multiplier $\lambda(t)$ satisfying the adjoint equation

$$\frac{\mathrm{d}\lambda}{\mathrm{d}t} = -\left(\frac{\partial f}{\partial x}\right)^{\!\top} \lambda - \left(\frac{\partial g}{\partial x}\right)^{\!\top}, \qquad \lambda(T) = 0,$$

so that

$$\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\theta} = \int_0^T \left(\frac{\partial g}{\partial \theta} + \lambda^\top \frac{\partial f}{\partial \theta}\right)\mathrm{d}t.$$

Therefore, to use the adjoint method, we need access to the Jacobians $\partial f/\partial x$ and $\partial f/\partial \theta$ of the solver's right-hand side (or the corresponding vector-Jacobian products).
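As a language-agnostic sanity check of the math above, here is a minimal discrete-adjoint sketch in plain Python for a scalar ODE. Everything here (the toy right-hand side, the quadratic loss, the function names) is illustrative and not part of any Turing or SciML API:

```python
# Discrete adjoint for dx/dt = f(x, theta) = -theta * x, x(0) = 1,
# with loss L = x(T)^2 (a stand-in for the log-density term).
# Illustrative sketch only; names do not come from any real API.

def forward(theta, x0=1.0, T=1.0, n=1000):
    """Forward Euler solve; returns the whole trajectory and the step size."""
    h = T / n
    xs = [x0]
    for _ in range(n):
        xs.append(xs[-1] + h * (-theta * xs[-1]))
    return xs, h

def adjoint_gradient(theta):
    """dL/dtheta via the discrete adjoint: one backward sweep over the
    stored trajectory, using df/dx = -theta and df/dtheta = -x."""
    xs, h = forward(theta)
    lam = 2.0 * xs[-1]                 # dL/dx at the final time
    grad = 0.0
    for k in range(len(xs) - 2, -1, -1):
        grad += lam * h * (-xs[k])     # accumulate lam * h * df/dtheta
        lam *= 1.0 + h * (-theta)      # propagate lam through df/dx
    return grad

def fd_gradient(theta, eps=1e-6):
    """Central finite difference on the same discrete loss, for checking."""
    def loss(t):
        xs, _ = forward(t)
        return xs[-1] ** 2
    return (loss(theta + eps) - loss(theta - eps)) / (2 * eps)

g_adj = adjoint_gradient(0.5)
g_fd = fd_gradient(0.5)
print(g_adj, g_fd)  # the two should agree closely
```

The backward sweep costs one extra pass regardless of how many parameters there are, which is the whole appeal of the adjoint over forward sensitivities.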
Did a bit of digging around [SciMLSensitivity](https://docs.sciml.ai/SciMLSensitivity/):
For the purpose of our project, we'll likely use an FEM solver outside the SciML ecosystem. The FEM solver will provide the equivalent of the model right-hand side and its Jacobians. Therefore, there are two ways to implement the adjoint acceleration with our custom FEM solver.
Method 1 is more complete and works better with the whole SciML ecosystem, but it will require close collaboration with the SciMLSensitivity project and will be a harder engineering challenge. Method 2 is more specific to what we need and easier to implement, but slightly detached from the SciML ecosystem.
You don't have to do any of that. You're overthinking it. If you use forward/reverse mode AD on a `solve` call, the appropriate adjoint is already applied automatically.

Now the next thing is extending to PDEs. PDE discretizations like FEM are simply transformations of PDEs into computable forms. These computable forms are the SciMLBase interfaces, such as LinearProblem, NonlinearProblem, ODEProblem, etc. For example, a method-of-lines discretization which leaves time intact will give you an ODEProblem, while a finite element collocation in time gives you a NonlinearProblem. It does not matter how you do your PDE discretization: you end up with one of the canonical mathematical problems to solve for the coefficients of the representation.

In that case, applying the adjoint method follows automatically. You perform your discretization, build a LinearProblem/NonlinearProblem/ODEProblem, and solve it; when Turing's automatic differentiation applies, it will automatically know (given the size of the problem and other heuristics) to apply adjoint differentiation to the LinearProblem/NonlinearProblem/ODEProblem solve.

That means that the only thing you have to do to make this work is to ensure that your FEM discretization, i.e. the matrix assembly, is compatible with automatic differentiation, preferably reverse mode. If you handle that, all of the other adjoint rules follow, and all of the implicit differentiation tricks are applied automatically behind the scenes via other rule definitions.
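To make the LinearProblem case concrete: for $A(\theta)x = b(\theta)$ and a downstream scalar loss $L$ with gradient $\bar v = \partial L/\partial x$, the adjoint is one transposed solve $A^\top \lambda = \bar v$, giving $\mathrm{d}L/\mathrm{d}\theta = \lambda^\top\!\left(\partial b/\partial\theta - (\partial A/\partial\theta)\,x\right)$. A hand-rolled Python sketch on a 2×2 system (the matrices and names are made up for illustration):

```python
# Adjoint of a linear solve A(theta) x = b(theta), worked by hand on 2x2.
# Downstream loss L = c . x, so vbar = dL/dx = c. Illustrative example only.

def solve2(A, b):
    """Solve a 2x2 linear system by Cramer's rule."""
    det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
    return [(A[1][1] * b[0] - A[0][1] * b[1]) / det,
            (A[0][0] * b[1] - A[1][0] * b[0]) / det]

def transpose2(A):
    return [[A[0][0], A[1][0]], [A[0][1], A[1][1]]]

def A_of(theta):
    return [[2.0 + theta, 1.0], [0.5, 3.0]]

def b_of(theta):
    return [1.0, theta]

dA_dtheta = [[1.0, 0.0], [0.0, 0.0]]   # elementwise dA/dtheta
db_dtheta = [0.0, 1.0]                 # db/dtheta
c = [1.0, 1.0]                         # dL/dx for L = x1 + x2

def loss(theta):
    x = solve2(A_of(theta), b_of(theta))
    return c[0] * x[0] + c[1] * x[1]

def linear_adjoint_grad(theta):
    """One forward solve plus one transposed solve; no forward sensitivities."""
    x = solve2(A_of(theta), b_of(theta))
    lam = solve2(transpose2(A_of(theta)), c)   # A^T lam = vbar
    # dL/dtheta = lam^T (db/dtheta - (dA/dtheta) x)
    r = [db_dtheta[i] - sum(dA_dtheta[i][j] * x[j] for j in range(2))
         for i in range(2)]
    return lam[0] * r[0] + lam[1] * r[1]

def linear_fd_grad(theta, eps=1e-6):
    """Central finite-difference check of the same gradient."""
    return (loss(theta + eps) - loss(theta - eps)) / (2 * eps)

print(linear_adjoint_grad(0.7), linear_fd_grad(0.7))  # should agree closely
```

The key design point is that the adjoint solve reuses $A^\top$, so the cost is independent of the number of parameters; only the cheap residual term $\partial b/\partial\theta - (\partial A/\partial\theta)x$ touches $\theta$, and that is exactly the piece the AD-compatible matrix assembly provides.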
And for reference, we have some projects using Ferrite.jl, which seems to be Enzyme-compatible, so if Turing can use Enzyme these days then using Ferrite for the semi-discretization would be a good optimization. We should probably make a tutorial along these lines.
Thanks @ChrisRackauckas.
Only the semi-discretization would have to be AD compatible, then it would flow. That just takes a few adjoint overloads. Someone tested that back in like 2020 for FEniCS.jl. Ferrite.jl should be directly compatible with Enzyme. |
@llfung SciML/SciMLSensitivity.jl#1105 might provide additional information on this. Can we treat custom PDE solvers as a "special" autodiff package, and then hook into the existing adjoint machinery?
@llfung talked to me today about his work on adjoint-accelerated programmable inference for large PDEs, and what would be needed on Turing's part to support that. As I understand it (and I know very little about FEM, so bear with me if my explanation/understanding here is poor), using the adjoint allows you to efficiently compute the gradient with respect to the parameters of the FEM problem. What they would need is some sort of mechanism in Turing to specify the function computing the gradient of the FEM part, and to have MLE/MAP, and maybe also other things like sampling, make use of that.
My first question for @llfung was whether we could hijack the AD mechanisms and specify an AD rule for the FEM solver, and have the implementation of that rule use the adjoint. Lloyd tells me that the adjoint requires knowledge of the surrounding objective function (log density). I don't understand this well enough, but maybe in reverse-mode AD we actually have what we need when computing a vector-Jacobian product?
If that doesn't pan out, maybe we need some sort of new context, for which we can say "when computing a gradient, once you hit the point of calling the FEM solver, use this adjoint thing instead".
I'll leave @llfung to expand on this and to correct all the bits I got wrong.
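On the vector-Jacobian product point: for a solver that returns the root $x^\*$ of $F(x, \theta) = 0$, the implicit function theorem gives exactly the vjp that reverse mode needs, $\bar\theta = -\bar v^\top (\partial F/\partial x)^{-1}\, \partial F/\partial\theta$, without the rule knowing anything about the surrounding log density. A scalar Python sketch (the root-finding problem and all names are made up for illustration):

```python
# Custom vjp for an implicit solver: x solves F(x, theta) = x^3 + theta*x - 1 = 0.
# Reverse mode only needs vbar -> thetabar; the surrounding loss stays opaque.

def solve_root(theta, x=1.0, iters=50):
    """Newton's method on F(x, theta) = x^3 + theta*x - 1."""
    for _ in range(iters):
        x -= (x**3 + theta * x - 1.0) / (3.0 * x * x + theta)
    return x

def solve_root_vjp(theta, vbar):
    """vjp of solve_root: thetabar = -vbar * dF/dtheta / dF/dx, evaluated
    at the solution (implicit function theorem). No forward sensitivities."""
    x = solve_root(theta)
    dF_dx = 3.0 * x * x + theta
    dF_dtheta = x
    return -vbar * dF_dtheta / dF_dx

def fd_vjp(theta, vbar, eps=1e-6):
    """Finite-difference check of the same cotangent."""
    return vbar * (solve_root(theta + eps) - solve_root(theta - eps)) / (2 * eps)

print(solve_root_vjp(2.0, 1.0), fd_vjp(2.0, 1.0))  # should agree closely
```

This is the shape of a ChainRules `rrule` or Enzyme custom rule: the forward pass stores the solution, and the reverse pass performs one linearized solve against the incoming cotangent, so the FEM solver never needs to see the log density itself.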