Skip to content

Commit

Permalink
Fix wraparound bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Lazersmoke committed Mar 8, 2024
1 parent 5e826c3 commit 045f533
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 21 deletions.
33 changes: 30 additions & 3 deletions src/SampledCorrelations/CorrelationSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ end

function new_sample!(sc::SampledCorrelations, sys::System)
(; dt, samplebuf, measperiod, observables, processtraj!) = sc
nsnaps = size(samplebuf, 6)

# Only fill the sample buffer half way; the rest is zero-padding
buf_size = size(samplebuf, 6)
nsnaps = (buf_size÷2) + 1
samplebuf[:,:,:,:,:,(nsnaps+1):end] .= 0

@assert size(sys.dipoles) == size(samplebuf)[2:5] "`System` size not compatible with given `SampledCorrelations`"

trajectory!(samplebuf, sys, dt, nsnaps, observables.observables; measperiod)
Expand Down Expand Up @@ -82,7 +87,27 @@ function accum_sample!(sc::SampledCorrelations)
(; data, M, observables, samplebuf, nsamples, fft!) = sc
natoms = size(samplebuf)[5]

fft! * samplebuf # Apply pre-planned and pre-normalized FFT
num_time_offsets = size(samplebuf,6)
T = (num_time_offsets÷2) + 1 # Duration that each signal was recorded for

# Time offsets (in samples) Δt = [0,1,...,(T-1),-(T-1),...,-1] produced by
# the cross-correlation between two length-T signals
time_offsets = FFTW.fftfreq(num_time_offsets,num_time_offsets)

# In samplebuf, the original signal is from 1:T, and the rest
# is zero-padding, from (T+1):num_time_offsets.
fft! * samplebuf

# Number of contributions to the DFT sum (non-constant due to zero-padding).
# Equivalently, this is the number of estimates of the correlation with
# each offset Δt that need to be averaged over.
n_contrib = reshape(T .- abs.(time_offsets),1,1,1,num_time_offsets)

# As long as `num_time_offsets` is odd, there will be a non-zero number of
# contributions, so we don't need this line
@assert isodd(num_time_offsets)
#n_contrib[n_contrib .== 0] .= Inf

count = nsamples[1] += 1

# Note that iterating over the `correlations` (a SortedDict) causes
Expand All @@ -96,10 +121,12 @@ function accum_sample!(sc::SampledCorrelations)
sample_β = @view samplebuf[β,:,:,:,j,:]
databuf = @view data[c,i,j,:,:,:,:]

corr = FFTW.fft(FFTW.ifft(sample_α .* conj.(sample_β),4) ./ n_contrib,4)

if isnothing(M)
for k in eachindex(databuf)
# Store the diff for one complex number on the stack.
diff = sample_α[k] * conj(sample_β[k]) - databuf[k]
diff = corr[k] - databuf[k]

# Accumulate into running average
databuf[k] += diff * (1/count)
Expand Down
10 changes: 5 additions & 5 deletions src/SampledCorrelations/CorrelationUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ values.
function available_energies(sc::SampledCorrelations; negative_energies=false)
isnan(sc.Δω) && (return NaN)

= size(sc.data, 7)
= div(, 2) + 1
ωvals = FFTW.fftfreq(nω, nω * sc.Δω)
return negative_energies ? ωvals : ωvals[1:]
end
n_all_ω = size(sc.data, 7)
n_non_neg_ω = div(n_all_ω, 2) + 1
ωvals = FFTW.fftfreq(n_all_ω, n_all_ω * sc.Δω)
return negative_energies ? ωvals : ωvals[1:n_non_neg_ω]
end
39 changes: 26 additions & 13 deletions src/SampledCorrelations/SampledCorrelations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Base.getproperty(sc::SampledCorrelations, sym::Symbol) = sym == :latsize ? size(
function clone_correlations(sc::SampledCorrelations{N}) where N
dims = size(sc.data)[2:4]
= size(sc.data, 7)
normalization_factor = 1/(nω * (prod(dims)))
normalization_factor = 1/√(prod(dims))
fft! = normalization_factor * FFTW.plan_fft!(sc.samplebuf, (2,3,4,6)) # Avoid copies/deep copies of C-generated data structures
M = isnothing(sc.M) ? nothing : copy(sc.M)
return SampledCorrelations{N}(copy(sc.data), M, sc.crystal, sc.origin_crystal, sc.Δω,
Expand Down Expand Up @@ -135,15 +135,21 @@ function dynamical_correlations(sys::System{N}; dt=nothing, Δt=nothing, nω, ω
observables = parse_observables(N; observables, correlations, g = apply_g ? sys.gs : nothing)

# Determine trajectory measurement parameters
= Int64(nω)
if!= 1
@assert π/dt > ωmax "Desired `ωmax` not possible with specified `dt`. Choose smaller `dt` value."
measperiod = floor(Int, π/(dt * ωmax))
= 2-1 # Ensure there are nω _non-negative_ energies.
Δω = 2π / (dt*measperiod*nω)
else
if isnan(nω) # instant_correlations case
measperiod = 1
dt = Δω = NaN
n_all_ω = 1
else
# Determine how many time steps to skip between saving samples
# by skipping the largest number possible while still resolving ωmax
@assert π/dt > ωmax "Desired `ωmax` not possible with specified `dt`. Choose smaller `dt` value."
measperiod = floor(Int, π/(dt * ωmax))

# The user specifies the number of _non-negative_ energies
n_non_neg_ω = Int64(nω)
@assert n_non_neg_ω > 0 "nω must be at least 1"
n_all_ω = 2n_non_neg_ω-1
Δω = 2π / (dt*measperiod*n_all_ω)
end

# Set up trajectory processing function (e.g., symmetrize)
Expand All @@ -159,12 +165,19 @@ function dynamical_correlations(sys::System{N}; dt=nothing, Δt=nothing, nω, ω

# Preallocation
na = natoms(sys.crystal)
samplebuf = zeros(ComplexF64, num_observables(observables), sys.latsize..., na, nω)
data = zeros(ComplexF64, num_correlations(observables), na, na, sys.latsize..., nω)

# The sample buffer holds n_non_neg_ω measurements, and the rest is a zero buffer
samplebuf = zeros(ComplexF64, num_observables(observables), sys.latsize..., na, n_all_ω)

# The output data has n_all_ω many (positive and negative and zero) frequencies
data = zeros(ComplexF64, num_correlations(observables), na, na, sys.latsize..., n_all_ω)
M = calculate_errors ? zeros(Float64, size(data)...) : nothing

# Normalize FFT according to physical convention
normalizationFactor = 1/(nω * (prod(sys.latsize)))
# Specially Normalized FFT.
# This is designed so that when it enters ifft(fft * fft) later (in squared fashion)
# it will result in the 1/N factor needed to average over the
# N-many independent estimates of the correlation.
normalizationFactor = 1/√(prod(sys.latsize))
fft! = normalizationFactor * FFTW.plan_fft!(samplebuf, (2,3,4,6))

# Other initialization
Expand Down Expand Up @@ -212,5 +225,5 @@ The following optional keywords are available:
`correlations=[(1,1),(1,2)]`.
"""
function instant_correlations(sys::System; kwargs...)
dynamical_correlations(sys; dt=NaN, nω=1, ωmax=NaN, kwargs...)
dynamical_correlations(sys; dt=NaN, nω=NaN, ωmax=NaN, kwargs...)
end

0 comments on commit 045f533

Please sign in to comment.