-
Notifications
You must be signed in to change notification settings - Fork 117
Add automatic differentiation test suite and add ForwardDiff support for dynamic master solvers #455
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Krastanov
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On first pass, this looks awesome, thank you!
There are test failures though -- I have not looked into them, do you know how to address them?
|
Yes, I will do that soon, just got busy this week and didn't have time to deal with them :) |
|
I will mark this as a draft, just to organize my review queue a bit. Feel free to mark it back at any time. |
| function _promote_time_and_state(u0, H::AbstractOperator, J, rates, tspan) | ||
| # TODO: Find an alternative to promote_dual, which was moved to | ||
| # an extension in DiffEqBase 6.162.0 | ||
| ext = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt) | ||
| Ts = reduce(ext.promote_dual, (eltype(H), DiffEqBase.anyeltypedual(J), DiffEqBase.anyeltypedual(rates))) | ||
| Tt = real(Ts) | ||
| p = Vector{Tt}(undef,0) | ||
| u0_promote = DiffEqBase.promote_u0(u0, p, tspan[1]) | ||
| tspan_promote = DiffEqBase.promote_tspan(u0_promote.data, p, tspan, nothing, Dict{Symbol, Any}()) | ||
| return tspan_promote, u0_promote | ||
| end | ||
| _promote_time_and_state(u0, f::Function, tspan) = _promote_time_and_state(u0, f(first(tspan)..., u0), tspan) | ||
| function _promote_time_and_state(u0, f::Union{Tuple, Vector}, tspan) | ||
| # TODO: Find an alternative to promote_dual, which was moved to | ||
| # an extension in DiffEqBase 6.162.0 | ||
| ext = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt) | ||
| Ts = reduce(ext.promote_dual, (eltype(f[1]), DiffEqBase.anyeltypedual.(f[2:end])...)) | ||
| Tt = real(Ts) | ||
| p = Vector{Tt}(undef,0) | ||
| u0_promote = DiffEqBase.promote_u0(u0, p, tspan[1]) | ||
| tspan_promote = DiffEqBase.promote_tspan(u0_promote.data, p, tspan, nothing, Dict{Symbol, Any}()) | ||
| return tspan_promote, u0_promote | ||
| end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the piece of code linked below (that is based on something you provided) is currently the cleanest way to fix the issues stemming from promote_dual
|
unmarked it from draft status to run the buildkite CI |
|
You might be interested in DifferentiationInterfaceTest.jl |
I added a test suite that exhaustively checks automatic differentiation capabilities for each solver. With this PR, FiniteDiff.jl and ForwardDiff.jl are fully supported for schroedinger and master solvers (including their dynamic versions). I added DifferentiationInterface.jl as a test dependency to quickly test other autodiff libraries supported in Julia in the future (such as Zygote.jl and Enzyme.jl). In the future I will also add support for the stochastic, semi-classical, and Monte Carlo solvers.
One note: here I am simply testing whether or not each differentiation operation runs on each solver with random test cases. I'm open to testing for correctness within some numerical tolerance, but I'd imagine we'd have to be extremely careful that the hundreds of tests pass every time, particularly when we have a handful of autodiff libraries supported in the suite.