From b6c186ecb9b91560b68a70376b48917d84abe0d2 Mon Sep 17 00:00:00 2001 From: John Lapeyre Date: Wed, 2 Oct 2024 19:47:13 -0400 Subject: [PATCH] Change API regarding convergence information (#33) There is now a single function `lambertw` for computing the Lambert W function. It takes a keyword argument, `info`. If `info` is false, the default, then only the result of computation is returned. If it is `true` then a triple giving the result and info on convergence is returned. In neither case is a warning or error explicitly raised. --- README.md | 4 +-- src/LambertW.jl | 63 ++++++++++++++++++++----------------------- test/lambertw_test.jl | 13 +++++++-- 3 files changed, 42 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 09f1bc2..9c41a17 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ also called the omega function or product logarithm. ```julia lambertw(z,k) # Lambert W function for argument z and branch index k lambertw(z) # the same as lambertw(z,0) -lambertw_check_convergence(z, k=0) # The same as above but throw an error if the computation failed to converge +lambertw(z; info=true) # Return a 3-tuple that includes convergence information. ``` `z` may be Complex or Real. `k` must be an integer. For Real @@ -36,7 +36,7 @@ julia> lambertw(-pi/2 + 0im) / pi 4.6681174759251105e-18 + 0.5im ``` -#### Note on `lambertw_check_convergence` +#### Note on `info=true` You can use this for extra safety. But I have been unable to find any input for which the root finding fails to converge quickly. diff --git a/src/LambertW.jl b/src/LambertW.jl index c796775..43c7bf6 100644 --- a/src/LambertW.jl +++ b/src/LambertW.jl @@ -13,7 +13,7 @@ module LambertW import IrrationalConstants -export lambertw, lambertwbp, lambertw_check_convergence +export lambertw, lambertwbp const omega_const_bf_ = Ref{BigFloat}() @@ -44,7 +44,7 @@ julia> lambertw(LambertW.lambertwbranchpoint, -1) ### Lambert W function """ - lambertw(z, k::Integer=0, maxits::Integer=1000) + lambertw(z, k::Integer=0, maxits::Integer=1000; info::Bool=false) Compute the `k`th branch of the Lambert W function of `z`. @@ -52,12 +52,15 @@ If `z` is real, `k` must be either `0` or `-1`. For `Real` `z`, the domain of th `k = -1` is `[-1/e, 0]` and the domain of the branch `k = 0` is `[-1/e, Inf]`. For `Complex` `z`, and all `k`, the domain is the complex plane. -The result is computed via a root-finding loop. If the number of iterations exceeds -`maxits`, then the loop exits early, returning a result without warning about the failure -to converge. This will probably never happen. However, if you want to be more careful, -call `lambertw_check_convergence` instead. The latter function returns the result if -`maxits` was not reached, and otherwise throws an error. +If `info` is `false` then `lambertw` returns just the result of the computation. +If `info` is `true`, then it returns a 3-tuple. The first item is the the result of the +computation. The second item is `true` if the root-finding compution converged in fewer +than `maxits` iterations, and otherwise is `false`. The third item is the number of +iterations performed. + +I have been unable to find a value of `z` for which the root-finding fails to converge +within ten iterations. ```jldoctest julia> lambertw(-1/MathConstants.e, -1) -1.0 @@ -75,24 +78,11 @@ julia> lambertw(Complex(-10.0, 3.0), 4) -0.9274337508660128 + 26.37693445371142im ``` """ -lambertw(z, k::Integer=0, maxits::Integer=1000) = _lambertw(float(z), k, maxits)[1] - -""" - lambertw_check_convergence(z, k::Integer=0, maxits::Integer=1000) - -This is the same as `lambertw` except that if the root finding fails to converge in `maxits` iterations, -an error is thrown. -""" -function lambertw_check_convergence(z, k::Integer=0, maxits::Integer=1000) - (w, converged) = _lambertw(float(z), k, maxits) - if ! converged - error("lambertw failed to converge in $maxits iterations") - end - w +function lambertw(z, k::Integer=0, maxits::Integer=1000; info::Bool=false) + result = _lambertw(float(z), k, maxits) + info ? result : result[1] end -#lambertw(z, k::Integer=0, maxits::Integer=1000) = _lambertw(float(z), k, maxits) - # lambertw(e + 0im, k) is ok for all k ### Real z @@ -103,16 +93,19 @@ function _lambertw(x::Real, k, maxits) throw(DomainError(k, "lambertw: real x must have branch k == 0 or k == -1")) end +# If we don't run root finding at all, return `true` for success with zero iterations. +_no_loop(w) = (w, true, 0) + # Real x, k = 0 # This appears to be inferrable with T=Float64 and T=BigFloat, including if x=Inf. # There is a magic number here. It could be noted, or possibly removed. # In particular, the fancy initial condition selection does not seem to help speed. function lambertw_branch_zero(x::T, maxits) where T<:Real - isnan(x) && return(NaN) - x == Inf && return Inf # appears to return convert(BigFloat, Inf) for x == BigFloat(Inf) + isnan(x) && return _no_loop(NaN) + x == Inf && return _no_loop(Inf) # appears to return convert(BigFloat, Inf) for x == BigFloat(Inf) one_t = one(T) oneoe = -one_t / convert(T, MathConstants.e) # The branch point - x == oneoe && return -one_t + x == oneoe && return _no_loop(-one_t) oneoe <= x || throw(DomainError(x)) itwo_t = 1 / convert(T, 2) if x > one_t @@ -128,9 +121,9 @@ end # Real x, k = -1 function lambertw_branch_one(x::T, maxits) where T<:Real oneoe = -one(T) / convert(T, MathConstants.e) - x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above + x == oneoe && return _no_loop(-one(T)) # W approaches -1 as x -> -1/e from above oneoe <= x || throw(DomainError(x)) # branch domain exludes x < -1/e - x == zero(T) && return -convert(T, Inf) # W decreases w/o bound as x -> 0 from below + x == zero(T) && return _no_loop(-convert(T, Inf)) # W decreases w/o bound as x -> 0 from below x < zero(T) || throw(DomainError(x)) return lambertw_root_finding(x, log(-x), maxits) end @@ -143,8 +136,8 @@ function _lambertw(z::Complex{T}, k::Integer, maxits::Integer) where T<:Real pointseven = 7//10 if abs(z) <= one_t/convert(T, MathConstants.e) if z == 0 - k == 0 && return z - return complex(-convert(T, Inf), zero(T)) + k == 0 && return _no_loop(z) + return _no_loop(complex(-convert(T, Inf), zero(T))) end if k == 0 w = z @@ -158,10 +151,10 @@ function _lambertw(z::Complex{T}, k::Integer, maxits::Integer) where T<:Real w = abs(z+ 1//2) < 1//10 ? imag(z) > 0 ? complex(pointseven, pointseven) : complex(pointseven, -pointseven) : z else if real(z) == convert(T, Inf) - k == 0 && return z + k == 0 && return _no_loop(z) return z + complex(0, 2*k*pi) end - real(z) == -convert(T, Inf) && return -z + complex(0, (2*k+1)*pi) + real(z) == -convert(T, Inf) && return _no_loop(-z + complex(0, (2*k+1)*pi)) w = log(z) k != 0 ? w += complex(0, 2*k*pi) : nothing end @@ -178,7 +171,8 @@ function lambertw_root_finding(z::T, x0::T, maxits) where T <: Number lastx = x lastdiff = zero(T) converged::Bool = false - for _ in 1:maxits + num_iters = 0 + for iter_count in 1:maxits ex = exp(x) xexz = x * ex - z x1 = x + 1 @@ -186,12 +180,13 @@ function lambertw_root_finding(z::T, x0::T, maxits) where T <: Number xdiff = abs(lastx - x) if xdiff <= 3 * eps(abs(lastx)) || lastdiff == xdiff # second condition catches two-value cycle converged = true + num_iters = iter_count break end lastx = x lastdiff = xdiff end - return (x, converged) + return (x, converged, num_iters) end ### Inverse of Lambert W function diff --git a/test/lambertw_test.jl b/test/lambertw_test.jl index 6667fb0..64eb465 100644 --- a/test/lambertw_test.jl +++ b/test/lambertw_test.jl @@ -188,6 +188,15 @@ end @test string(LambertW.Omega()) == "ω" end -@testset "lambertw_check_convergence" begin - @test lambertw_check_convergence(1.0) == lambertw(1.0) +@testset "lambertw info" begin + result = lambertw(1.0; info=true) + @test result[1] == lambertw(1.0) + @test result[2] + @test result[3] > 1 && result[3] < 10 + + for z in (10., complex(10), lambertwbranchpoint) + res = lambertw(1.0; info=true) + @test res[2] + @test length(res) == 3 + end end