-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Complete removal of "force" and optimizer fixes
- Loading branch information
Showing
10 changed files
with
195 additions
and
221 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,160 +1,123 @@ | ||
# The following four helper functions allow for more code reuse since the | ||
# projective parameterization is formally the same for both dipoles and coherent | ||
# states (the element types are just real in the first case, complex in the | ||
# second). | ||
# `System{N}` method: first calls `Sunny.set_energy_grad_dipoles!` on the | ||
# `∇H_dip` buffer, then hands both buffers to | ||
# `Sunny.set_energy_grad_coherents!`, which receives `∇H` as its first argument. | ||
# NOTE(review): presumably `∇H_dip` is scratch space and `∇H` holds the final | ||
# gradient -- confirm against the Sunny helpers. | ||
function set_forces_optim!(∇H, ∇H_dip, sys::System{N}) where {N} | ||
Sunny.set_energy_grad_dipoles!(∇H_dip, sys.dipoles, sys) | ||
Sunny.set_energy_grad_coherents!(∇H, ∇H_dip, sys.coherents, sys) | ||
end | ||
|
||
# `System{0}` method: the second positional argument is ignored (bound to `_`) | ||
# and `Sunny.set_energy_grad_dipoles!` is called directly on `∇H`. | ||
function set_forces_optim!(∇H, _, sys::System{0}) | ||
Sunny.set_energy_grad_dipoles!(∇H, sys.dipoles, sys) | ||
end | ||
|
||
# `System{N}` method: converts the pair `(α, z)` with | ||
# `projective_to_conventional` and stores the result at `site` through | ||
# `set_coherent_state!`. | ||
@inline function set_spin_optim!(sys::System{N}, α, z, site) where N | ||
set_coherent_state!(sys, projective_to_conventional(α, z), site) | ||
end | ||
|
||
@inline function set_spin_optim!(sys::System{0}, α, z, site) | ||
polarize_spin!(sys, projective_to_conventional(α, z), site) | ||
# Returns the stereographic projection u(α) = (2v + (1-v²)n)/(1+v²), which | ||
# involves the orthographic projection v = (1-nn̄)α. The input `n` must be | ||
# normalized. When `α=0`, the output is `u=n`, and when `|α|→ ∞` the output is | ||
# `u=-n`. In all cases, `|u|=1`. | ||
# | ||
# Arguments: | ||
#   α -- unconstrained coordinate vector; only its part orthogonal to `n` is used | ||
#   n -- projection direction; checked with `@assert` to satisfy n'*n ≈ 1 | ||
# Returns the vector `u`. | ||
function stereographic_projection(α, n) | ||
@assert n'*n ≈ 1 | ||
v = α - n*(n'*α) # project out component parallel to `n` | ||
# `real` discards the imaginary part of the inner product `v'*v` | ||
v² = real(v'*v) | ||
u = (2v + (1-v²)*n) / (1+v²) # stereographic projection | ||
return u | ||
end | ||
|
||
# Converts unnormalized representation of coherent state, α, to standard dipole | ||
# or normalized complex vector representation. | ||
function projective_to_conventional(α, z) | ||
v = (I - z*z')*α | ||
v2 = v'*v | ||
return (2v + (1-v2)*z) / (1+v2) # Guaranteed to be normalized | ||
# Calculate the vector-Jacobian-product x̄ du(α)/dα, where | ||
# u(v) = (2v + (1-v²)n)/(1+v²), v(α,ᾱ)=Pα, and P=1-nn̄. | ||
# | ||
# From the chain rule for Wirtinger derivatives, | ||
# du/dα = (du/dv) (dv/dα) + (du/dv̄) (dv̄/dα) = du/dv P. | ||
# | ||
# In the second step, we used | ||
# dv/dα = P | ||
# dv̄/dα = conj(dv/dᾱ) = 0. | ||
# | ||
# The remaining Jacobian matrix is | ||
# du/dv = (2-2nv̄)/(1+v²) - 2(2v+(1-v²)n)/(1+v²)² v̄ | ||
# = c - c[(1+cb)n + cv]v̄, | ||
# where b = (1-v²)/2 and c = 2/(1+v²). | ||
# | ||
# Using the above definitions, return: | ||
# x̄ du/dα = x̄ du/dv P | ||
# | ||
# Arguments: | ||
#   x -- vector contracted from the left; it enters only as `x'` | ||
#   α -- coordinate vector at which the Jacobian is evaluated | ||
#   n -- projection direction, used in the same way as in `v = α - n*(n'*α)` | ||
# The result is assembled from `x'`, `v'` and `n'`, i.e. it is an adjoint (row) | ||
# object rather than a column vector. | ||
@inline function vjp_stereographic_projection(x, α, n) | ||
# Same orthogonal component `v` and squared norm `v²` as in the formula for u(v) | ||
v = α - n*(n'*α) | ||
v² = real(v'*v) | ||
b = (1-v²)/2 | ||
c = 2/(1+v²) | ||
# Perform dot products first to avoid constructing outer-product | ||
x̄_dudv = c*x' - c * (x' * ((1+c*b)*n + c*v)) * v' | ||
# Apply projection P=1-nn̄ on right | ||
return x̄_dudv - (x̄_dudv * n) * n' | ||
end | ||
|
||
# Calculate du(α)/dα and apply to `vec`. u(α) = (2v + (1-v²)z)/(1+v²) with v = | ||
# (1-zz†)α. Won't allocate if all inputs are StaticArrays. | ||
@inline function apply_projective_jacobian(vec, α, z) | ||
P = (I - z*z') | ||
v = P*α | ||
v2 = v'*v | ||
|
||
dv_dα = P | ||
dv2_dα = 2*(α' - 2*(α'*z)*z') | ||
jac = (1/(1+v2)) * ((2dv_dα - (z*dv2_dα)') - (1/(1+v2)) * dv2_dα' * ((2v + (1-v2)*z)')) | ||
# function variance(αs) | ||
# ncomponents = length(αs) * length(first(αs)) | ||
# return sum(real(α'*α) for α in αs) / ncomponents | ||
# end | ||
|
||
return jac * vec | ||
end | ||
|
||
# Calculate H(u(α)) | ||
function optim_energy(αs, zs, sys::System{N})::Float64 where N | ||
T = N == 0 ? Vec3 : CVec{N} | ||
αs = reinterpret(reshape, T, αs) | ||
function optim_set_spins!(sys::System{0}, αs, ns) | ||
αs = reinterpret(reshape, Vec3, αs) | ||
for site in all_sites(sys) | ||
set_spin_optim!(sys, αs[site], zs[site], site) | ||
s = stereographic_projection(αs[site], ns[site]) | ||
polarize_spin!(sys, s, site) | ||
end | ||
return energy(sys) # Note: `Sunny.energy` seems to allocate and is type-unstable (7/20/2023) | ||
end | ||
|
||
# Non-allocating check for largest unnormalized coordinate. | ||
# Returns the largest `norm(α)` over the elements of `αs`. The running maximum | ||
# starts at `0.0`, so `0.0` is returned when `αs` has no elements. | ||
function maxnorm(αs) | ||
# Local accumulator; note that this name shadows `Base.max` inside the function | ||
max = 0.0 | ||
for α in αs | ||
magnitude = norm(α) | ||
max = magnitude > max ? magnitude : max | ||
end | ||
return max | ||
end | ||
|
||
# Calculate dH(u(α))/dα | ||
function optim_gradient!(buf, αs, zs, B, sys::System{N}, halted, quickmode=false) where N | ||
T = N == 0 ? Vec3 : CVec{N} | ||
αs = reinterpret(reshape, T, αs) | ||
Hgrad = reinterpret(reshape, T, buf) | ||
|
||
# Check if any coordinate has gone adrift and signal need to reset if necessary | ||
if !quickmode | ||
maxdist = maxnorm(αs) | ||
if maxdist > 1.5 # 1.5 found empirically, works well for both dipole and SU(3) | ||
Hgrad .*= 0 # Trick Optim.jl to stop by setting gradient to 0 (triggers convergence tests) | ||
halted[] = true # Let main loop know we haven't really converged | ||
return | ||
end | ||
end | ||
|
||
# Calculate gradient of energy with respect to α | ||
function optim_set_spins!(sys::System{N}, αs, ns) where N | ||
αs = reinterpret(reshape, CVec{N}, αs) | ||
for site in all_sites(sys) | ||
set_spin_optim!(sys, αs[site], zs[site], site) | ||
end | ||
|
||
set_forces_optim!(Hgrad, B, sys) | ||
|
||
for site in all_sites(sys) | ||
Hgrad[site] = apply_projective_jacobian(Hgrad[site], αs[site], zs[site]) | ||
Z = stereographic_projection(αs[site], ns[site]) | ||
set_coherent_state!(sys, Z, site) | ||
end | ||
end | ||
|
||
# Quick "touchup" optimization that assumes the system is already near the | ||
# ground state. Never changes the parameterization of coherent states or | ||
# dipoles. For internal use when setting up a spin wave calculation. | ||
function minimize_energy_touchup!(sys::System{N}; method=Optim.LBFGS, maxiters=50, kwargs...) where N | ||
numbertype = N == 0 ? Float64 : ComplexF64 | ||
buffer = N == 0 ? sys.dipoles : sys.coherents | ||
B = N == 0 ? nothing : get_dipole_buffers(sys, 1) |> only | ||
|
||
zs = copy(buffer) | ||
αs = zeros(numbertype, (N == 0 ? 3 : N, size(buffer)...)) | ||
halted = Ref(false) | ||
|
||
f(proj_coords) = optim_energy(proj_coords, zs, sys) | ||
g!(G, proj_coords) = optim_gradient!(G, proj_coords, zs, B, sys, halted, true) # true skips coordinate drifting test | ||
|
||
options = Optim.Options(iterations=maxiters, kwargs...) | ||
output = Optim.optimize(f, g!, αs, method(), options) | ||
success = Optim.converged(output) | ||
if !success | ||
@warn "Optimization failed to converge within $(output.iterations) iterations. `System` not in ground state and spin wave calculations may fail." | ||
end | ||
|
||
return success | ||
# `System{0}` method. Writes into `G` in three steps: | ||
#   1. reinterpret `αs` and `G` as arrays with element type `Vec3`; | ||
#   2. fill `G` by calling `set_energy_grad_dipoles!`; | ||
#   3. replace each element of `G` by the adjoint of | ||
#      `vjp_stereographic_projection(G, αs, ns)`, divided by a norm. | ||
# NOTE(review): presumably `ns` has the same site layout as `sys.dipoles` -- | ||
# confirm against the caller. | ||
function optim_set_gradient!(G, sys::System{0}, αs, ns) | ||
(αs, G) = reinterpret.(reshape, Vec3, (αs, G)) | ||
set_energy_grad_dipoles!(G, sys.dipoles, sys) | ||
# `@.` dots every call on this line, so `norm` is applied to each element of | ||
# `sys.dipoles` individually rather than to the array as a whole | ||
@. G = adjoint(vjp_stereographic_projection(G, αs, ns)) / norm(sys.dipoles) | ||
end | ||
# `System{N}` method. Same sequence as the `System{0}` method, but the buffers | ||
# are reinterpreted with element type `CVec{N}`, the gradient is filled by | ||
# `set_energy_grad_coherents!`, and the divisor is taken from `sys.coherents`. | ||
function optim_set_gradient!(G, sys::System{N}, αs, ns) where N | ||
(αs, G) = reinterpret.(reshape, CVec{N}, (αs, G)) | ||
set_energy_grad_coherents!(G, sys.coherents, sys) | ||
# `@.` dots every call on this line, so `norm` is applied to each element of | ||
# `sys.coherents` individually rather than to the array as a whole | ||
@. G = adjoint(vjp_stereographic_projection(G, αs, ns)) / norm(sys.coherents) | ||
end | ||
|
||
|
||
""" | ||
minimize_energy!(sys::System{N}; method=Optim.LBFGS, maxiters = 1000, kwargs...) where N | ||
minimize_energy!(sys::System{N}; method=Optim.LBFGS, maxiters = 100, subiters=20, kwargs...) where N | ||
Optimizes the spin configuration in `sys` to minimize energy. Any method from | ||
Optim.jl that accepts only a gradient may be used by setting the `method` | ||
keyword. Defaults to LBFGS. | ||
""" | ||
function minimize_energy!(sys::System{N}; method=Optim.LBFGS, maxiters = 1000, kwargs...) where N | ||
function minimize_energy!(sys::System{N}; method=Optim.LBFGS(), maxiters=100, subiters=20, kwargs...) where N | ||
# Allocate buffers for optimization: | ||
# - Each `ns[site]` defines a direction for stereographic projection. | ||
# - Each `αs[:,site]` will be optimized in the space orthogonal to `ns[site]`. | ||
if iszero(N) | ||
ns = normalize.(sys.dipoles) | ||
αs = zeros(Float64, 3, size(sys.dipoles)...) | ||
else | ||
ns = normalize.(sys.coherents) | ||
αs = zeros(ComplexF64, N, size(sys.coherents)...) | ||
end | ||
|
||
# Set up type and buffer information depending on system type | ||
numbertype = N == 0 ? Float64 : ComplexF64 | ||
buffer = N == 0 ? sys.dipoles : sys.coherents | ||
B = N == 0 ? nothing : get_dipole_buffers(sys, 1) |> only | ||
# Functions to calculate energy and gradient for the state `αs` | ||
function f(αs) | ||
optim_set_spins!(sys, αs, ns) | ||
return energy(sys) # TODO: `Sunny.energy` seems to allocate and is type-unstable (7/20/2023) | ||
end | ||
function g!(G, αs) | ||
optim_set_spins!(sys, αs, ns) | ||
optim_set_gradient!(G, sys, αs, ns) | ||
end | ||
|
||
# Allocate buffers for optimization | ||
zs = copy(buffer) | ||
αs = zeros(numbertype, (N == 0 ? 3 : N, size(buffer)...)) | ||
halted = Ref(false) | ||
# Perform optimization, resetting parameterization of coherent states as necessary | ||
options = Optim.Options(iterations=subiters, show_trace=true, kwargs...) | ||
|
||
# Set up energy and gradient functions using closures to pass "constants" | ||
f(proj_coords) = optim_energy(proj_coords, zs, sys) | ||
g!(G, proj_coords) = optim_gradient!(G, proj_coords, zs, B, sys, halted, false) | ||
for iter in 1 : div(maxiters, subiters, RoundUp) | ||
output = Optim.optimize(f, g!, αs, method, options) | ||
|
||
# Perform optimization, resetting parameterization of coherent states as necessary | ||
options = Optim.Options(iterations=maxiters, kwargs...) | ||
totaliters = 0 | ||
success = false | ||
while totaliters < maxiters | ||
output = Optim.optimize(f, g!, αs, method(), options) | ||
if halted[] | ||
zs .= buffer # Reset parameterization based on current state | ||
αs .*= 0 # Reset unnormalized coordinates | ||
halted[] = false | ||
else | ||
if Optim.converged(output) # Convergence report only meaningful if not halted | ||
success = true | ||
break | ||
end | ||
if Optim.converged(output) | ||
cnt = (iter-1)*subiters + output.iterations | ||
return cnt | ||
end | ||
totaliters += output.iterations | ||
|
||
# Reset parameterization based on current state | ||
ns .= normalize.(iszero(N) ? sys.dipoles : sys.coherents) | ||
αs .*= 0 | ||
end | ||
|
||
return success | ||
end | ||
@warn "Optimization failed to converge within $maxiters iterations." | ||
return -1 | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.