Handling dropouts #11

Open
timholy opened this issue Nov 16, 2023 · 4 comments

timholy (Contributor) commented Nov 16, 2023

In cases of poor initialization, some components of the mixture may drop out. For example, let's create a 2-component mixture that is very poorly initialized:

julia> X = randn(10);

julia> mix = MixtureModel([Normal(100, 0.001), Normal(200, 0.001)], [0.5, 0.5]);

julia> logpdf.(components(mix), X')
2×10 Matrix{Float64}:
 -4.92479e9   -4.97741e9   -5.02964e9   -5.15501e9   -5.05792e9   …  -5.16391e9   -4.88617e9   -4.93348e9   -5.09162e9
 -1.98493e10  -1.99548e10  -2.00592e10  -2.03088e10  -2.01157e10     -2.03265e10  -1.97717e10  -1.98667e10  -2.01828e10

You can see that both have poor likelihood, but one of the two always loses by a very large margin. Then when we go to optimize,

julia> fit_mle(mix, X)
ERROR: DomainError with NaN:
Normal: the condition σ >= zero(σ) is not satisfied.
Stacktrace:
  [1] #371
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:37 [inlined]
  [2] check_args
    @ ~/.julia/dev/Distributions/src/utils.jl:89 [inlined]
  [3] #Normal#370
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:37 [inlined]
  [4] Normal
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:36 [inlined]
  [5] fit_mle
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:229 [inlined]
  [6] fit_mle(::Type{Normal{Float64}}, x::Vector{Float64}, w::Vector{Float64}; mu::Float64, sigma::Float64)
    @ Distributions ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:256
  [7] fit_mle
    @ ~/.julia/dev/Distributions/src/univariate/continuous/normal.jl:253 [inlined]
  [8] fit_mle
    @ ~/.julia/dev/ExpectationMaximization/src/that_should_be_in_Distributions.jl:17 [inlined]
  [9] (::ExpectationMaximization.var"#2#3"{Vector{Normal{Float64}}, Vector{Float64}, Matrix{Float64}})(k::Int64)
    @ ExpectationMaximization ./none:0
 [10] iterate(::Base.Generator{Vector{Any}, DualNumbers.var"#1#3"})
    @ Base ./generator.jl:47 [inlined]
 [11] collect_to!(dest::AbstractArray{T}, itr::Any, offs::Any, st::Any) where T
    @ Base ./array.jl:890 [inlined]
 [12] collect_to_with_first!(dest::AbstractArray, v1::Any, itr::Any, st::Any)
    @ Base ./array.jl:868 [inlined]
 [13] collect(itr::Base.Generator{UnitRange{Int64}, ExpectationMaximization.var"#2#3"{Vector{…}, Vector{…}, Matrix{…}}})
    @ Base ./array.jl:842
 [14] fit_mle!(α::Vector{…}, dists::Vector{…}, y::Vector{…}, method::ClassicEM; display::Symbol, maxiter::Int64, atol::Float64, robust::Bool)
    @ ExpectationMaximization ~/.julia/dev/ExpectationMaximization/src/classic_em.jl:48
 [15] fit_mle!
    @ ~/.julia/dev/ExpectationMaximization/src/classic_em.jl:14 [inlined]
 [16] fit_mle(::MixtureModel{…}, ::Vector{…}; method::ClassicEM, display::Symbol, maxiter::Int64, atol::Float64, robust::Bool, infos::Bool)
    @ ExpectationMaximization ~/.julia/dev/ExpectationMaximization/src/fit_em.jl:30
 [17] fit_mle(::MixtureModel{Univariate, Continuous, Normal{Float64}, Categorical{Float64, Vector{Float64}}}, ::Vector{Float64})
    @ ExpectationMaximization ~/.julia/dev/ExpectationMaximization/src/fit_em.jl:12
 [18] top-level scope
    @ REPL[8]:1
Some type information was truncated. Use `show(err)` to see complete types.

This arises because α[:] = mean(γ, dims = 1) returns α = [1.0, 0.0]. In other words, component 2 of the mixture "drops out."
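
To see the mechanism concretely, here is a rough sketch of the standard E-step written out by hand (not the package's internal code; it reuses the X and mix defined above and assumes an observations-by-components orientation for γ):

using Distributions, Statistics

LL = logpdf.(components(mix), X')'         # 10×2: log-likelihood of each point under each component
logw = LL .+ log.(probs(mix))'             # add the log mixture weights
γ = exp.(logw .- maximum(logw, dims = 2))  # unnormalized responsibilities, stabilized per observation
γ ./= sum(γ, dims = 2)                     # normalize so each row sums to 1
α = vec(mean(γ, dims = 1))                 # returns [1.0, 0.0]: component 2 gets exactly zero weight

Component 1 wins by roughly 1.5e10 in log space for every observation, so the exponential underflows to exactly 0.0 for component 2, and both its weight and its weighted sufficient statistics vanish.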

I've found errors like these, as well as positive-definiteness errors in a multivariate context, to be pretty ubiquitous when fitting complicated distributions and point clouds. It seems to me we'd need some kind of guard against this behavior, but I'm not sure what the state-of-the-art approach is, or I'd implement it myself.

dmetivie (Owner) commented

Yes, I noticed that also.
The robust = true keyword prevents some of this behavior but by no means catches everything.

I think in some sense this is really inherent to the EM algorithm: if it starts near a local minimum that has a dropped-out component, it will move toward it until numerical precision produces an error.
I don't think there is much we can do, aside from implementing a different version of EM that escapes these holes.

That said, maybe something like LogarithmicNumbers.jl, or ExponentialFamily.jl for the exponential family, could help?

In practice, I also added this fit_mle to test over multiple initial conditions and return the best fitted model, avoiding errors with try/catch.
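
For reference, a minimal sketch of that multi-start strategy (a hypothetical helper, not the package's actual method; fit_best and its argument names are made up for illustration):

using Distributions, ExpectationMaximization

function fit_best(initial_mixtures, y; kwargs...)
    best, best_ll = nothing, -Inf
    for mix0 in initial_mixtures
        fitted = try
            fit_mle(mix0, y; kwargs...)
        catch err
            isa(err, InterruptException) && rethrow(err)
            continue                   # this initialization dropped out or errored; skip it
        end
        ll = loglikelihood(fitted, y)
        if ll > best_ll
            best, best_ll = fitted, ll
        end
    end
    return best                        # nothing if every initialization failed
end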

timholy commented Nov 18, 2023

If returning "empty" components is OK, one easy option might be simply to add N*α[i] < thresh && continue so that components assigned fewer than thresh points just don't get updated. One could make thresh = 1 perhaps by default, but there would also be arguments for either thresh = 1e-6 or thresh = d^2/2 + d + 1 (the latter basically saying we want enough data to determine the amplitude, mean, and covariance matrix).
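
A minimal sketch of what that guard could look like (illustrative only: the function name, the N×K orientation of γ, and the default thresh are assumptions, not the package's current M-step):

using Distributions, Statistics

function guarded_mstep!(dists, α, y, γ; thresh = 1)
    N = size(γ, 1)                    # γ[n, k]: responsibility of component k for observation n
    α[:] = vec(mean(γ, dims = 1))     # usual weight update
    for k in eachindex(dists)
        N * α[k] < thresh && continue # dropout guard: too few assigned points, keep the old component
        dists[k] = fit_mle(typeof(dists[k]), y, γ[:, k])
    end
    return dists, α
end

Passing thresh = d^2/2 + d + 1, as suggested above, would apply the same guard to the multivariate case.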

timholy commented Nov 19, 2023

To get a sense of how common this is, I wrote a quick script to generate random test cases and then report back cases that exhibited various classes of errors:

using ExpectationMaximization
using Distributions
using Random

nwanted = 3
nmax = 10000

# For DomainError
domerrX = Matrix{Float64}[]
domerridxs = Vector{Int}[]   # indices of the centers in corresponding X

# For posdef errors
pderrX = Matrix{Float64}[]
pderridxs = Vector{Int}[]

function init_mixture(X, centeridxs)
    dist = [MvNormal(X[:, idx], 1) for idx in centeridxs]
    αs = ones(length(centeridxs)) / length(centeridxs)
    return MixtureModel(dist, αs)
end

for i = 1:nmax
    (length(domerrX) >= nwanted && length(pderrX) >= nwanted) && (@show i; break)
    ctrue = [randn(2) for _ = 1:3]
    X = reduce(hcat, [randn(length(c), 20) .+ c for c in ctrue])
    X = round.(X; digits=2)    # to make it easy to write to a text file
    startidx = randperm(60)[1:3]
    mix = init_mixture(X, startidx)
    try
        fit_mle(mix, X)
    catch err
        isa(err, InterruptException) && rethrow(err)
        if isa(err, DomainError)
            if length(domerrX) < nwanted
                push!(domerrX, X)
                push!(domerridxs, startidx)
            end
        else
            if length(pderrX) < nwanted
                push!(pderrX, X)
                push!(pderridxs, startidx)
            end
        end
    end
end

This didn't generate any of the positive-definite errors I've seen in different circumstances (maybe that requires higher dimensionality?), but somewhere between 5% and 10% of all cases resulted in a dropout. There doesn't appear to be anything particularly bizarre about these cases; here's a typical one:

[figure: scatter plot of one failing test case; the three initial centers are marked as red dots]

The red dots are both data points and the starting positions of the clusters. If there is a pattern, it seems to be that at least one of the red dots lies fairly near a cluster edge.

timholy added a commit to timholy/ExpectationMaximization.jl that referenced this issue Nov 19, 2023: "This eliminates the common failures observed in dmetivie#11 (comment)"

timholy commented Nov 19, 2023

So, what ends up happening is that Σ → 0 because only a single point gets associated with a component. The existing robust=true fails to catch this because it produces NaN rather than Inf: exp(-mahalanobis^2)/sqrt(det(Σ)) → 0/0. It's likely that some kind of shrinkage would be the best solution, but I pushed a bandaid in #12.
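
A minimal sketch of the kind of shrinkage alluded to here (illustrative only: the helper name, λ, and the blend toward the identity are assumptions, not what #12 does). It computes the weighted Gaussian MLE by hand and regularizes Σ before constructing the MvNormal, so the constructor cannot fail even when the weights concentrate on a single point:

using Distributions, LinearAlgebra

function shrunk_mvnormal(x::AbstractMatrix, w::AbstractVector; λ = 1e-3)
    W = sum(w)
    μ = vec(sum(x .* w', dims = 2)) / W    # weighted mean (x is d×n, one column per observation)
    z = (x .- μ) .* sqrt.(w')              # centered, weight-scaled data
    Σ = (1 - λ) * (z * z') / W + λ * I     # weighted MLE covariance, shrunk toward the identity
    return MvNormal(μ, Symmetric(Σ))
end

Even when a component is effectively assigned a single point, the shrunk covariance is approximately λ·I rather than a singular matrix, so the density stays finite and robust = true has something sensible to work with.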
