Skip to content

Commit

Permalink
Optimize and fix g
Browse files Browse the repository at this point in the history
  • Loading branch information
kbarros committed Aug 29, 2024
1 parent d833b15 commit 02138ad
Showing 1 changed file with 51 additions and 28 deletions.
79 changes: 51 additions & 28 deletions src/Spiral/SpinWaveTheorySpiral2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,36 @@ function dispersion(sswt::SpinWaveTheorySpiral, qpts)
return sort!(reshape(disp, 3L, size(qpts.qs)...); dims=1, rev=true)
end

# Observables must be dipole moments, with some choice of apply_g. Extract and
# return this parameter.
function is_apply_g(swt::SpinWaveTheory, measure::MeasureSpec)
obs1 = measure.observables
for apply_g in (true, false)
obs2 = all_dipole_observables(swt.sys; apply_g)
vec(obs1) vec(obs2) && return apply_g
end
error("General measurements not supported for spiral calculation")
end

function gs_as_scalar(swt::SpinWaveTheory, measure::MeasureSpec)
return if is_apply_g(swt, measure)
map(swt.sys.gs) do g
g = to_float_or_mat3(g; atol=1e-12)
g isa Float64 || error("Anisotropic g-tensor not supported for spiral calculation")
g
end
else
map(swt.sys.gs) do _
1.0
end
end
end

function intensities_bands(sswt::SpinWaveTheorySpiral, qpts; kT=0) # TODO: branch=nothing
(; swt,k, axis) = sswt
(; swt, k, axis) = sswt
(; sys, data, measure) = swt
(; local_rotations, observables_localized, sqrtS) = data

isempty(measure.observables) && error("No observables! Construct SpinWaveTheorySpiral with a `measure` argument.")
sys.mode == :SUN && error("SU(N) mode not supported for spiral calculation")
@assert sys.mode in (:dipole, :dipole_large_S)
Expand Down Expand Up @@ -264,17 +290,18 @@ function intensities_bands(sswt::SpinWaveTheorySpiral, qpts; kT=0) # TODO: branc
H = zeros(ComplexF64, 2L, 2L)
T0 = zeros(ComplexF64, 2L, 2L)
T = zeros(ComplexF64, 2L, 2L, 3)
Nobs = size(measure.observables, 1)
disp = zeros(Float64, L, 3, Nq)
intensity = zeros(eltype(measure), L, 3, Nq)
disp_flat = reshape(disp, 3L, Nq)
intensity_flat = reshape(intensity, 3L, Nq)
S = zeros(ComplexF64, 3, 3, L, 3)
S = zeros(CMat3, L, 3)
Avec_pref = zeros(ComplexF64, Na)

# If g-tensors are included in observables, they must be scalar. Precompute.
gs = gs_as_scalar(swt, measure)

for (iq, q) in enumerate(qpts.qs)
q_global = cryst.recipvecs * q
q_reshaped = sys.crystal.recipvecs \ q_global

for branch in 1:3 # (q-k, q, q+k) modes for propagation wavevector k
energies = excitations!(T0, H, sswt, q; branch)
Expand All @@ -283,50 +310,46 @@ function intensities_bands(sswt::SpinWaveTheorySpiral, qpts; kT=0) # TODO: branc
end

for i in 1:Na
r_global = global_position(sys, (1,1,1,i))
r_global = global_position(sys, (1, 1, 1, i))
ff = get_swt_formfactor(measure, 1, i)
Avec_pref[i] = exp(- im * dot(q_global, r_global))
Avec_pref[i] *= compute_form_factor(ff, norm2(q_global))
end

Avec = zeros(ComplexF64, Nobs)
for α in 1:3, β in 1:3
for branch in 1:3, band in 1:L
fill!(Avec, 0)
data = swt.data::SWTDataDipole
v = reshape(view(T,:,band,branch),Na,2)
for i in 1:Na, μ in 1:Nobs
O = data.observables_localized[μ, i]
displacement_local_frame = SA[v[i, 2] + v[i, 1], im * (v[i, 2] - v[i, 1]), 0.0]
Avec[μ] += Avec_pref[i] * (data.sqrtS[i]/sqrt(2)) * (O' * displacement_local_frame)[1]
end
S[α,β,band,branch] = Avec[α] * conj(Avec[β]) / Ncells
for branch in 1:3, band in 1:L
Avec = zero(CVec{3})
v = reshape(view(T, :, band, branch), Na, 2)
for i in 1:Na
R = local_rotations[i]
displacement = R * SA[v[i, 2] + v[i, 1], im * (v[i, 2] - v[i, 1]), 0.0]
Avec += Avec_pref[i] * (sqrtS[i]/sqrt(2)) * gs[i] * displacement
end
S[band, branch] = Avec * Avec' / Ncells
end

for band = 1:L
case = spiral_propagation_case(k)
if case == 1
# The three branches (q-k, q, q+k) are equivalent when k = 0
# (modulo 1). Spread 1/3 of the intensity among every branch.
S[:,:,band,1] .*= 1/3
S[:,:,band,2] .*= 1/3
S[:,:,band,3] .*= 1/3
@assert S[:,:,band,1] S[:,:,band, 2] S[:,:,band,3]
S[band, 1] *= 1/3
S[band, 2] *= 1/3
S[band, 3] *= 1/3
@assert S[band, 1] S[band, 2] S[band, 3]
elseif case == 2
S[:,:,band,1] = R1 * S[:,:,band,1] * R1 + conj(R1) * S[:,:,band,1] * R1
S[:,:,band,2] = R2 * S[:,:,band,2] * R2
S[:,:,band,3] = conj(R1) * S[:,:,band,3] * conj(R1) + R1 * S[:,:,band,3] * conj(R1)
S[band, 1] = R1 * S[band, 1] * R1 + conj(R1) * S[band, 1] * R1
S[band, 2] = R2 * S[band, 2] * R2
S[band, 3] = conj(R1) * S[band, 3] * conj(R1) + R1 * S[band, 3] * conj(R1)
else
S[:,:,band,1] = R1 * S[:,:,band,1] * R1
S[:,:,band,2] = R2 * S[:,:,band,2] * R2
S[:,:,band,3] = conj(R1) * S[:,:,band,3] * conj(R1)
S[band, 1] = R1 * S[band, 1] * R1
S[band, 2] = R2 * S[band, 2] * R2
S[band, 3] = conj(R1) * S[band, 3] * conj(R1)
end
end

for branch in 1:3, band in 1:L
corrbuf = map(measure.corr_pairs) do (α, β)
S[α, β, band, branch]
S[band, branch][α, β]
end
intensity[band, branch, iq] = thermal_prefactor(disp[band, branch, iq]; kT) * measure.combiner(q_global, corrbuf)
end
Expand Down

0 comments on commit 02138ad

Please sign in to comment.