Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Numerically stable accumulation of intensity variance #214

Merged
merged 2 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/SampledCorrelations/CorrelationSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function no_processing(::SampledCorrelations)
end

function accum_sample!(sc::SampledCorrelations)
(; data, variance, observables, samplebuf, nsamples, fft!) = sc
(; data, M, observables, samplebuf, nsamples, fft!) = sc
natoms = size(samplebuf)[5]

fft! * samplebuf # Apply pre-planned and pre-normalized FFT
Expand All @@ -98,16 +98,16 @@ function accum_sample!(sc::SampledCorrelations)
sample_β = @view samplebuf[β,:,:,:,j,:]
databuf = @view data[c,i,j,:,:,:,:]

if isnothing(variance)
if isnothing(M)
for k in eachindex(databuf)
# Store the diff for one complex number on the stack.
diff = sample_α[k] * conj(sample_β[k]) - databuf[k]

# Accumulate into running average
databuf[k] += diff * (1/count)
end
else
varbuf = @view variance[c,i,j,:,:,:,:]
else
Mbuf = @view M[c,i,j,:,:,:,:]
for k in eachindex(databuf)
# Store old (complex) mean on stack.
μ_old = databuf[k]
Expand All @@ -121,8 +121,7 @@ function accum_sample!(sc::SampledCorrelations)
# Note that the first term of `diff` is real by construction
# (despite appearances), but `real` is explicitly called to
# avoid automatic typecasting errors caused by roundoff.
diff = real((conj(matrixelem) - conj(μ_old))*(matrixelem - μ)) - varbuf[k]
varbuf[k] += diff * (1/count)
Mbuf[k] += real((matrixelem - μ_old)*conj(matrixelem - μ))
end
end
end
Expand Down
15 changes: 7 additions & 8 deletions src/SampledCorrelations/SampledCorrelations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ initialized by calling either [`dynamical_correlations`](@ref) or
struct SampledCorrelations{N}
# 𝒮^{αβ}(q,ω) data and metadata
data :: Array{ComplexF64, 7} # Raw SF data for 1st BZ (numcorrelations × natoms × natoms × latsize × energy)
variance :: Union{Nothing, Array{Float64, 7}} # Running variance calculation for Welford's algorithm
M :: Union{Nothing, Array{Float64, 7}} # Running estimate of (nsamples - 1)*σ² (where σ² is the variance of intensities)
crystal :: Crystal # Crystal for interpretation of q indices in `data`
origin_crystal :: Union{Nothing,Crystal} # Original user-specified crystal (if different from above) -- needed for FormFactor accounting
Δω :: Float64 # Energy step size (could make this a virtual property)
Expand Down Expand Up @@ -57,8 +57,8 @@ function clone_correlations(sc::SampledCorrelations{N}) where N
nω = size(sc.data, 7)
normalization_factor = 1/(nω * √(prod(dims)))
fft! = normalization_factor * FFTW.plan_fft!(sc.samplebuf, (2,3,4,6)) # Avoid copies/deep copies of C-generated data structures
variance = isnothing(sc.variance) ? nothing : copy(sc.variance)
return SampledCorrelations{N}(copy(sc.data), variance, sc.crystal, sc.origin_crystal, sc.Δω,
M = isnothing(sc.M) ? nothing : copy(sc.M)
return SampledCorrelations{N}(copy(sc.data), M, sc.crystal, sc.origin_crystal, sc.Δω,
deepcopy(sc.observables), copy(sc.samplebuf), fft!, sc.measperiod, sc.apply_g, sc.Δt,
copy(sc.nsamples), sc.processtraj!)
end
Expand All @@ -77,9 +77,8 @@ function merge_correlations(scs::Vector{SampledCorrelations{N}}) where N
n = sc_merged.nsamples[1]
m = sc.nsamples[1]
@. μ = (n/(n+m))*sc_merged.data + (m/(n+m))*sc.data
if !isnothing(sc_merged.variance)
@. sc_merged.variance = (n/(n+m))*(sc_merged.variance + abs(μ - sc_merged.data)^2) +
(m/(n+m))*(sc.variance + abs(μ - sc.data)^2)
if !isnothing(sc_merged.M)
@. sc_merged.M = (sc_merged.M + n*abs(μ - sc_merged.data)^2) + (sc.M + m*abs(μ - sc.data)^2)
end
sc_merged.data .= μ
sc_merged.nsamples[1] += m
Expand Down Expand Up @@ -154,7 +153,7 @@ function dynamical_correlations(sys::System{N}; Δt, nω, ωmax,
na = natoms(sys.crystal)
samplebuf = zeros(ComplexF64, num_observables(observables), sys.latsize..., na, nω)
data = zeros(ComplexF64, num_correlations(observables), na, na, sys.latsize..., nω)
variance = calculate_errors ? zeros(Float64, size(data)...) : nothing
M = calculate_errors ? zeros(Float64, size(data)...) : nothing

# Normalize FFT according to physical convention
normalizationFactor = 1/(nω * √(prod(sys.latsize)))
Expand All @@ -165,7 +164,7 @@ function dynamical_correlations(sys::System{N}; Δt, nω, ωmax,

# Make Structure factor and add an initial sample
origin_crystal = isnothing(sys.origin) ? nothing : sys.origin.crystal
sc = SampledCorrelations{N}(data, variance, sys.crystal, origin_crystal, Δω, observables,
sc = SampledCorrelations{N}(data, M, sys.crystal, origin_crystal, Δω, observables,
samplebuf, fft!, measperiod, apply_g, Δt, nsamples, processtraj!)

return sc
Expand Down
2 changes: 1 addition & 1 deletion test/test_correlation_sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ end
# Merge correlations and check if result equal to running calculation.
sc_merged = merge_correlations([sc1, sc2])
@test sc0.data ≈ sc_merged.data
@test sc0.variance ≈ sc_merged.variance
@test sc0.M ≈ sc_merged.M
end

@testitem "Sampled correlations reference" begin
Expand Down
Loading