diff --git a/src/emlink.jl b/src/emlink.jl index b6c75ee..05dc19d 100644 --- a/src/emlink.jl +++ b/src/emlink.jl @@ -150,7 +150,7 @@ function emlinkMARmov(patterns::MatchPatterns, dims::Tuple{Int,Int},varnames::Ve end for j in 1:length(uvals_gamma_jk[i])]),rev=true) end - delta = maximum(BigFloat.(abs.(recursive_flatten([p_m,p_u,p_gamma_km,p_gamma_ku]) - p_old))) + delta = maximum(BigFloat.(recursive_flatten([p_m,p_u,p_gamma_km,p_gamma_ku]) - p_old)) count += 1 if count > iter_max @@ -274,13 +274,13 @@ function emlinkMARmov(gamma_jk::Vector{Vector{UInt8}},n_j::Vector{Int64}, dims:: p_gamma_kjm[i,:] = [ismissing(j) ? missing : p_gamma_km[i][findfirst(uvals_gamma_jk[i] .== j)] for j in vals_gamma_jk[i]] p_gamma_kju[i,:] = [ismissing(j) ? missing : p_gamma_ku[i][findfirst(uvals_gamma_jk[i] .== j)] for j in vals_gamma_jk[i]] end - p_gamma_jm = sum.(skipmissing.(eachcol(log.(p_gamma_kjm)))) - p_gamma_ju = sum.(skipmissing.(eachcol(log.(p_gamma_kju)))) - log_prod_gamma_jm = p_gamma_jm .+ log(p_m) - log_prod_gamma_ju = p_gamma_ju .+ log(p_u) + p_gamma_jm = sum.(skipmissing.(eachcol(log.(abs.(p_gamma_kjm))))) + p_gamma_ju = sum.(skipmissing.(eachcol(log.(abs.(p_gamma_kju))))) + log_prod_gamma_jm = p_gamma_jm .+ log(abs(p_m)) + log_prod_gamma_ju = p_gamma_ju .+ log(abs(p_u)) zeta_j = exp.(log_prod_gamma_jm - logxpy(log_prod_gamma_jm,log_prod_gamma_ju)) num_prod = exp.(log.(n_j) + log.(zeta_j)) - p_m = exp(log(abs(sum(num_prod) + mu - 1)) - log(abs(psi - mu + sum(n_j)))) + p_m = exp(log(sum(num_prod) + mu - 1) - log(psi - mu + sum(n_j))) p_u = 1-p_m for i in 1:nfeatures