Skip to content

Commit

Permalink
Update time FFT in correlation calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
ddahlbom committed Jun 6, 2024
1 parent 5f46c8c commit 4617a99
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
17 changes: 9 additions & 8 deletions src/SampledCorrelations/CorrelationSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ function no_processing(::SampledCorrelations)
end

function accum_sample!(sc::SampledCorrelations; window)
(; data, M, observables, samplebuf, nsamples, space_fft!, time_fft!) = sc
(; data, M, observables, samplebuf, corrbuf, nsamples, space_fft!, time_fft!, corr_fft!, corr_ifft!) = sc
natoms = size(samplebuf)[5]

num_time_offsets = size(samplebuf,6)
Expand Down Expand Up @@ -122,22 +122,24 @@ function accum_sample!(sc::SampledCorrelations; window)
# According to Sunny convention, the correlation is between
# α† and β. This conjugation implements both the dagger on the α
# as well as the appropriate spacetime offsets of the correlation.
corr = FFTW.ifft(conj.(sample_α) .* sample_β,4) ./ n_contrib
@. corrbuf = conj(sample_α) * sample_β
corr_ifft! * corrbuf
corrbuf ./= n_contrib

if window == :cosine
# Apply a cosine windowing to force the correlation at Δt=±(T-1) to be zero
# to force periodicity. In terms of the spectrum S(ω), this applys a smoothing
# with a characteristic lengthscale of O(1) frequency bins.
window_func = cos.(range(0,π,length = num_time_offsets + 1)[1:end-1]).^2
corr .*= reshape(window_func,1,1,1,num_time_offsets)
corrbuf .*= reshape(window_func,1,1,1,num_time_offsets)
end

FFTW.fft!(corr,4)
corr_fft! * corrbuf

if isnothing(M)
for k in eachindex(databuf)
# Store the diff for one complex number on the stack.
diff = corr[k] - databuf[k]
diff = corrbuf[k] - databuf[k]

# Accumulate into running average
databuf[k] += diff * (1/count)
Expand All @@ -149,15 +151,14 @@ function accum_sample!(sc::SampledCorrelations; window)
μ_old = databuf[k]

# Update running mean.
matrixelem = sample_α[k] * conj(sample_β[k])
databuf[k] += (matrixelem - databuf[k]) * (1/count)
databuf[k] += (corrbuf[k] - databuf[k]) / count
μ = databuf[k]

# Update variance estimate.
# Note that the first term of `diff` is real by construction
# (despite appearances), but `real` is explicitly called to
# avoid automatic typecasting errors caused by roundoff.
Mbuf[k] += real((matrixelem - μ_old)*conj(matrixelem - μ))
Mbuf[k] += real((corrbuf[k] - μ_old)*conj(corrbuf[k] - μ))
end
end
end
Expand Down
12 changes: 10 additions & 2 deletions src/SampledCorrelations/SampledCorrelations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ struct SampledCorrelations{N}

# Specs for sample generation and accumulation
samplebuf :: Array{ComplexF64, 6} # New sample buffer
corrbuf :: Array{ComplexF64, 4} # Buffer for real-time correlations
space_fft! :: FFTW.AbstractFFTs.Plan # Pre-planned FFT
time_fft! :: FFTW.AbstractFFTs.Plan # Pre-planned FFT
corr_fft! :: FFTW.AbstractFFTs.Plan # Pre-planned FFT
corr_ifft! :: FFTW.AbstractFFTs.Plan # Pre-planned FFT
measperiod :: Int # Steps to skip between saving observables (downsampling for dynamical calcs)
apply_g :: Bool # Whether to apply the g-factor
dt :: Float64 # Step size for trajectory integration
Expand Down Expand Up @@ -59,9 +62,11 @@ function clone_correlations(sc::SampledCorrelations{N}) where N
# Avoid copies/deep copies of C-generated data structures
space_fft! = 1/√(prod(dims)) * FFTW.plan_fft!(sc.samplebuf, (2,3,4))
time_fft! = 1/√(n_all_ω) * FFTW.plan_fft!(sc.samplebuf, 6)
corr_fft! = FFTW.plan_fft!(sc.corrbuf, 4)
corr_ifft! = FFTW.plan_ifft!(sc.corrbuf, 4)
M = isnothing(sc.M) ? nothing : copy(sc.M)
return SampledCorrelations{N}(copy(sc.data), M, sc.crystal, sc.origin_crystal, sc.Δω,
deepcopy(sc.observables), copy(sc.samplebuf), space_fft!, time_fft!, sc.measperiod, sc.apply_g, sc.dt,
deepcopy(sc.observables), copy(sc.samplebuf), copy(sc.corrbuf), space_fft!, time_fft!, corr_fft!, corr_ifft!, sc.measperiod, sc.apply_g, sc.dt,
copy(sc.nsamples), sc.processtraj!)
end

Expand Down Expand Up @@ -166,6 +171,7 @@ function dynamical_correlations(sys::System{N}; dt=nothing, Δt=nothing, nω, ω

# The sample buffer holds n_non_neg_ω measurements, and the rest is a zero buffer
samplebuf = zeros(ComplexF64, num_observables(observables), sys.latsize..., na, n_all_ω)
corrbuf = zeros(ComplexF64, sys.latsize..., n_all_ω)

# The output data has n_all_ω many (positive and negative and zero) frequencies
data = zeros(ComplexF64, num_correlations(observables), na, na, sys.latsize..., n_all_ω)
Expand All @@ -176,14 +182,16 @@ function dynamical_correlations(sys::System{N}; dt=nothing, Δt=nothing, nω, ω
# https://github.com/SunnySuite/Sunny.jl/issues/264 (subject to change).
space_fft! = 1/√(prod(sys.latsize)) * FFTW.plan_fft!(samplebuf, (2,3,4))
time_fft! = 1/√(n_all_ω) * FFTW.plan_fft!(samplebuf, 6)
corr_fft! = FFTW.plan_fft!(corrbuf, 4)
corr_ifft! = FFTW.plan_ifft!(corrbuf, 4)

# Other initialization
nsamples = Int64[0]

# Make Structure factor and add an initial sample
origin_crystal = isnothing(sys.origin) ? nothing : sys.origin.crystal
sc = SampledCorrelations{N}(data, M, sys.crystal, origin_crystal, Δω, observables,
samplebuf, space_fft!, time_fft!, measperiod, apply_g, dt, nsamples, process_trajectory)
samplebuf, corrbuf, space_fft!, time_fft!, corr_fft!, corr_ifft!, measperiod, apply_g, dt, nsamples, process_trajectory)

return sc
end
Expand Down

0 comments on commit 4617a99

Please sign in to comment.