Skip to content

Commit

Permalink
Complete removal of "force" and optimizer fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
kbarros committed Jul 25, 2023
1 parent a952bf6 commit 0168d76
Show file tree
Hide file tree
Showing 10 changed files with 195 additions and 221 deletions.
14 changes: 3 additions & 11 deletions src/Integrators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,22 +143,17 @@ end

function step!(sys::System{N}, integrator::Langevin) where N
(Z′, ΔZ₁, ΔZ₂, ξ, HZ) = get_coherent_buffers(sys, 5)
∇E = get_dipole_buffers(sys, 1) |> only
Z = sys.coherents

randn!(sys.rng, ξ)

# Prediction
@. sys.dipoles = expected_spin(Z) # temporarily desyncs dipoles and coherents
set_energy_grad_dipoles!(∇E, sys.dipoles, sys)
set_energy_grad_coherents!(HZ, ∇E, Z, sys)
set_energy_grad_coherents!(HZ, Z, sys)
rhs_langevin!(ΔZ₁, Z, ξ, HZ, integrator, sys)
@. Z′ = normalize_ket(Z + ΔZ₁, sys.κs)

# Correction
@. sys.dipoles = expected_spin(Z′) # temporarily desyncs dipoles and coherents
set_energy_grad_dipoles!(∇E, sys.dipoles, sys)
set_energy_grad_coherents!(HZ, ∇E, Z′, sys)
set_energy_grad_coherents!(HZ, Z′, sys)
rhs_langevin!(ΔZ₂, Z′, ξ, HZ, integrator, sys)
@. Z = normalize_ket(Z + (ΔZ₁+ΔZ₂)/2, sys.κs)

Expand Down Expand Up @@ -186,7 +181,6 @@ end
function step!(sys::System{N}, integrator::ImplicitMidpoint; max_iters=100) where N
(; atol) = integrator
(ΔZ, Z̄, Z′, Z″, HZ) = get_coherent_buffers(sys, 5)
∇E = get_dipole_buffers(sys, 1) |> only
Z = sys.coherents

@. Z′ = Z
Expand All @@ -195,9 +189,7 @@ function step!(sys::System{N}, integrator::ImplicitMidpoint; max_iters=100) wher
for _ in 1:max_iters
@.= (Z + Z′)/2

@. sys.dipoles = expected_spin(Z̄) # temporarily desyncs dipoles and coherents
set_energy_grad_dipoles!(∇E, sys.dipoles, sys)
set_energy_grad_coherents!(HZ, ∇E, Z̄, sys)
set_energy_grad_coherents!(HZ, Z̄, sys)
rhs_ll!(ΔZ, HZ, integrator, sys)

@. Z″ = Z + ΔZ
Expand Down
225 changes: 94 additions & 131 deletions src/Optimization.jl
Original file line number Diff line number Diff line change
@@ -1,160 +1,123 @@
# The following four helper functions allow for more code resuse since the
# projective parameterization is formally the same for both dipoles and coherent
# states (the element types are just real in the first case, complex in the
# second).
function set_forces_optim!(∇H, ∇H_dip, sys::System{N}) where {N}
Sunny.set_energy_grad_dipoles!(∇H_dip, sys.dipoles, sys)
Sunny.set_energy_grad_coherents!(∇H, ∇H_dip, sys.coherents, sys)
end

function set_forces_optim!(∇H, _, sys::System{0})
Sunny.set_energy_grad_dipoles!(∇H, sys.dipoles, sys)
end

@inline function set_spin_optim!(sys::System{N}, α, z, site) where N
set_coherent_state!(sys, projective_to_conventional(α, z), site)
end

@inline function set_spin_optim!(sys::System{0}, α, z, site)
polarize_spin!(sys, projective_to_conventional(α, z), site)
# Returns the stereographic projection u(α) = (2v + (1-v²)n)/(1+v²), which
# involves the orthographic projection v = (1-nn̄)α. The input `n` must be
# normalized. When `α=0`, the output is `u=n`, and when `|α|→ ∞` the output is
# `u=-n`. In all cases, `|u|=1`.
function stereographic_projection(α, n)
@assert n'*n 1
v = α - n*(n'*α) # project out component parallel to `n`
= real(v'*v)
u = (2v + (1-v²)*n) / (1+v²) # stereographic projection
return u
end

# Converts unnormalized representation of coherent state, α, to standard dipole
# or normalized complex vector representation.
function projective_to_conventional(α, z)
v = (I - z*z')*α
v2 = v'*v
return (2v + (1-v2)*z) / (1+v2) # Guaranteed to be normalized
# Calculate the vector-Jacobian-product x̄ du(α)/dα, where
# u(v) = (2v + (1-v²)n)/(1+v²), v(α,ᾱ)=Pα, and P=1-nn̄.
#
# From the chain rule for Wirtinger derivatives,
# du/dα = (du/dv) (dv/dα) + (du/dv̄) (dv̄/dα) = du/dv P.
#
# In the second step, we used
# dv/dα = P
# dv̄/dα = conj(dv/dᾱ) = 0.
#
# The remaining Jacobian matrix is
# du/dv = (2-2nv̄)/(1+v²) - 2(2v+(1-v²)n)/(1+v²)² v̄
# = c - c[(1+cb)n + cv]v̄,
# where b = (1-v²)/2 and c = 2/(1+v²).
#
# Using the above definitions, return:
# x̄ du/dα = x̄ du/dv P
#
@inline function vjp_stereographic_projection(x, α, n)
v = α - n*(n'*α)
= real(v'*v)
b = (1-v²)/2
c = 2/(1+v²)
# Perform dot products first to avoid constructing outer-product
x̄_dudv = c*x' - c * (x' * ((1+c*b)*n + c*v)) * v'
# Apply projection P=1-nn̄ on right
return x̄_dudv - (x̄_dudv * n) * n'
end

# Calculate du(α)/dα and apply to `vec`. u(α) = (2v + (1-v²)z)/(1+v²) with v =
# (1-zz†)α. Won't allocate if all inputs are StaticArrays.
@inline function apply_projective_jacobian(vec, α, z)
P = (I - z*z')
v = P*α
v2 = v'*v

dv_dα = P
dv2_dα = 2*' - 2*'*z)*z')
jac = (1/(1+v2)) * ((2dv_dα - (z*dv2_dα)') - (1/(1+v2)) * dv2_dα' * ((2v + (1-v2)*z)'))
# function variance(αs)
# ncomponents = length(αs) * length(first(αs))
# return sum(real(α'*α) for α in αs) / ncomponents
# end

return jac * vec
end

# Calculate H(u(α))
function optim_energy(αs, zs, sys::System{N})::Float64 where N
T = N == 0 ? Vec3 : CVec{N}
αs = reinterpret(reshape, T, αs)
function optim_set_spins!(sys::System{0}, αs, ns)
αs = reinterpret(reshape, Vec3, αs)
for site in all_sites(sys)
set_spin_optim!(sys, αs[site], zs[site], site)
s = stereographic_projection(αs[site], ns[site])
polarize_spin!(sys, s, site)
end
return energy(sys) # Note: `Sunny.energy` seems to allocate and is type-unstable (7/20/2023)
end

# Non-allocating check for largest unnormalized coordinate.
function maxnorm(αs)
max = 0.0
for α in αs
magnitude = norm(α)
max = magnitude > max ? magnitude : max
end
return max
end

# Calculate dH(u(α))/dα
function optim_gradient!(buf, αs, zs, B, sys::System{N}, halted, quickmode=false) where N
T = N == 0 ? Vec3 : CVec{N}
αs = reinterpret(reshape, T, αs)
Hgrad = reinterpret(reshape, T, buf)

# Check if any coordinate has gone adrift and signal need to reset if necessary
if !quickmode
maxdist = maxnorm(αs)
if maxdist > 1.5 # 1.5 found empirically, works well for both dipole and SU(3)
Hgrad .*= 0 # Trick Optim.jl to stop by setting gradient to 0 (triggers convergence tests)
halted[] = true # Let main loop know we haven't really converged
return
end
end

# Calculate gradient of energy with respect to α
function optim_set_spins!(sys::System{N}, αs, ns) where N
αs = reinterpret(reshape, CVec{N}, αs)
for site in all_sites(sys)
set_spin_optim!(sys, αs[site], zs[site], site)
end

set_forces_optim!(Hgrad, B, sys)

for site in all_sites(sys)
Hgrad[site] = apply_projective_jacobian(Hgrad[site], αs[site], zs[site])
Z = stereographic_projection(αs[site], ns[site])
set_coherent_state!(sys, Z, site)
end
end

# Quick "touchup" optimization that assumes the system is already near the
# ground state. Never changes the parameterization of coherent states or
# dipoles. For internal use when setting up a spin wave calculation.
function minimize_energy_touchup!(sys::System{N}; method=Optim.LBFGS, maxiters=50, kwargs...) where N
numbertype = N == 0 ? Float64 : ComplexF64
buffer = N == 0 ? sys.dipoles : sys.coherents
B = N == 0 ? nothing : get_dipole_buffers(sys, 1) |> only

zs = copy(buffer)
αs = zeros(numbertype, (N == 0 ? 3 : N, size(buffer)...))
halted = Ref(false)

f(proj_coords) = optim_energy(proj_coords, zs, sys)
g!(G, proj_coords) = optim_gradient!(G, proj_coords, zs, B, sys, halted, true) # true skips coordinate drifting test

options = Optim.Options(iterations=maxiters, kwargs...)
output = Optim.optimize(f, g!, αs, method(), options)
success = Optim.converged(output)
if !success
@warn "Optimization failed to converge within $(output.iterations) iterations. `System` not in ground state and spin wave calculations may fail."
end

return success
function optim_set_gradient!(G, sys::System{0}, αs, ns)
(αs, G) = reinterpret.(reshape, Vec3, (αs, G))
set_energy_grad_dipoles!(G, sys.dipoles, sys)
@. G = adjoint(vjp_stereographic_projection(G, αs, ns)) / norm(sys.dipoles)
end
function optim_set_gradient!(G, sys::System{N}, αs, ns) where N
(αs, G) = reinterpret.(reshape, CVec{N}, (αs, G))
set_energy_grad_coherents!(G, sys.coherents, sys)
@. G = adjoint(vjp_stereographic_projection(G, αs, ns)) / norm(sys.coherents)
end


"""
minimize_energy!(sys::System{N}; method=Optim.LBFGS, maxiters = 1000, kwargs...) where N
minimize_energy!(sys::System{N}; method=Optim.LBFGS, maxiters = 100, subiters=20, kwargs...) where N
Optimizes the spin configuration in `sys` to minimize energy. Any method from
Optim.jl that accepts only a gradient may be used by setting the `method`
keyword. Defaults to LBFGS.
"""
function minimize_energy!(sys::System{N}; method=Optim.LBFGS, maxiters = 1000, kwargs...) where N
function minimize_energy!(sys::System{N}; method=Optim.LBFGS(), maxiters=100, subiters=20, kwargs...) where N
# Allocate buffers for optimization:
# - Each `ns[site]` defines a direction for stereographic projection.
# - Each `αs[:,site]` will be optimized in the space orthogonal to `ns[site]`.
if iszero(N)
ns = normalize.(sys.dipoles)
αs = zeros(Float64, 3, size(sys.dipoles)...)
else
ns = normalize.(sys.coherents)
αs = zeros(ComplexF64, N, size(sys.coherents)...)
end

# Set up type and buffer information depending on system type
numbertype = N == 0 ? Float64 : ComplexF64
buffer = N == 0 ? sys.dipoles : sys.coherents
B = N == 0 ? nothing : get_dipole_buffers(sys, 1) |> only
# Functions to calculate energy and gradient for the state `αs`
function f(αs)
optim_set_spins!(sys, αs, ns)
return energy(sys) # TODO: `Sunny.energy` seems to allocate and is type-unstable (7/20/2023)
end
function g!(G, αs)
optim_set_spins!(sys, αs, ns)
optim_set_gradient!(G, sys, αs, ns)
end

# Allocate buffers for optimization
zs = copy(buffer)
αs = zeros(numbertype, (N == 0 ? 3 : N, size(buffer)...))
halted = Ref(false)
# Perform optimization, resetting parameterization of coherent states as necessary
options = Optim.Options(iterations=subiters, show_trace=true, kwargs...)

# Set up energy and gradient functions using closures to pass "constants"
f(proj_coords) = optim_energy(proj_coords, zs, sys)
g!(G, proj_coords) = optim_gradient!(G, proj_coords, zs, B, sys, halted, false)
for iter in 1 : div(maxiters, subiters, RoundUp)
output = Optim.optimize(f, g!, αs, method, options)

# Perform optimization, resetting parameterization of coherent states as necessary
options = Optim.Options(iterations=maxiters, kwargs...)
totaliters = 0
success = false
while totaliters < maxiters
output = Optim.optimize(f, g!, αs, method(), options)
if halted[]
zs .= buffer # Reset parameterization based on current state
αs .*= 0 # Reset unnormalized coordinates
halted[] = false
else
if Optim.converged(output) # Convergence report only meaningful if not halted
success = true
break
end
if Optim.converged(output)
cnt = (iter-1)*subiters + output.iterations
return cnt
end
totaliters += output.iterations

# Reset parameterization based on current state
ns .= normalize.(iszero(N) ? sys.dipoles : sys.coherents)
αs .*= 0
end

return success
end
@warn "Optimization failed to converge within $maxiters iterations."
return -1
end
2 changes: 1 addition & 1 deletion src/Sunny.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ include("System/OnsiteCoupling.jl")
include("System/Ewald.jl")
include("System/Interactions.jl")
export SpinInfo, System, Site, all_sites, position_to_site,
global_position, magnetic_moment, polarize_spin!, polarize_spins!, randomize_spins!, energy, forces,
global_position, magnetic_moment, polarize_spin!, polarize_spins!, randomize_spins!, energy,
spin_operators, stevens_operators, set_external_field!, set_onsite_coupling!, set_exchange!,
dmvec, enable_dipole_dipole!, to_inhomogeneous, set_external_field_at!, set_vacancy_at!, set_onsite_coupling_at!,
symmetry_equivalent_bonds, set_exchange_at!, remove_periodicity!
Expand Down
18 changes: 9 additions & 9 deletions src/System/Ewald.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ function ewald_energy(sys::System{N}) where N
return E / prod(latsize)
end

# Use FFT to accumulate the entire field -dE/ds for long-range dipole-dipole
# Use FFT to accumulate the entire field dE/ds for long-range dipole-dipole
# interactions
function accum_ewald_grad!(∇E::Array{Vec3, 4}, dipoles::Array{Vec3, 4}, sys::System{N}) where N
(; ewald, units, gs) = sys
Expand Down Expand Up @@ -182,8 +182,8 @@ function accum_ewald_grad!(∇E::Array{Vec3, 4}, dipoles::Array{Vec3, 4}, sys::S
end
end

# Calculate the field -dE/ds at site1 generated by a dipole at site2.
function ewald_pairwise_force_at(sys::System{N}, site1, site2) where N
# Calculate the field dE/ds at site1 generated by a dipole at site2.
function ewald_pairwise_grad_at(sys::System{N}, site1, site2) where N
(; gs, ewald, units) = sys
latsize = size(ewald.ϕ)[1:3]
cell_offset = mod.(Tuple(to_cell(site2)-to_cell(site1)), latsize)
Expand All @@ -193,14 +193,14 @@ function ewald_pairwise_force_at(sys::System{N}, site1, site2) where N
# accounts for the quadratic dependence on the dipole. If site1 != site2, it
# accounts for energy contributions from both ordered pairs (site1, site2)
# and (site2, site1).
return - 2 * units.μB^2 * gs[site1]' * ewald.A[cell, to_atom(site1), to_atom(site2)] * gs[site2] * sys.dipoles[site2]
return 2 * units.μB^2 * gs[site1]' * ewald.A[cell, to_atom(site1), to_atom(site2)] * gs[site2] * sys.dipoles[site2]
end

# Calculate the field -dE/ds at `site` generated by all `dipoles`.
function ewald_force_at(sys::System{N}, site) where N
# Calculate the field dE/ds at `site` generated by all `dipoles`.
function ewald_grad_at(sys::System{N}, site) where N
acc = zero(Vec3)
for site2 in all_sites(sys)
acc += ewald_pairwise_force_at(sys, site, site2)
acc += ewald_pairwise_grad_at(sys, site, site2)
end
return acc
end
Expand All @@ -212,6 +212,6 @@ function ewald_energy_delta(sys::System{N}, site, s::Vec3) where N
Δs = s - dipoles[site]
Δμ = units.μB * (sys.gs[site] * Δs)
i = to_atom(site)
B = ewald_force_at(sys, site)
return - ΔsB + dot(Δμ, ewald.A[1, 1, 1, i, i], Δμ)
∇E = ewald_grad_at(sys, site)
return Δs∇E + dot(Δμ, ewald.A[1, 1, 1, i, i], Δμ)
end
Loading

0 comments on commit 0168d76

Please sign in to comment.