From b30a4b9e7363ce1a2e2b73d1171465ec8dc61bcc Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Sat, 5 Jul 2025 02:10:14 -0400 Subject: [PATCH 1/3] improve `log1pmx`, add `Float32 implimentation Through some clever use of Remez.jl (and some testing to better limit where the fallback is appropriate), we can remove almost all of the branches from the Float64 implementation and add a similar (but slightly faster) Float32 version. --- src/basicfuns.jl | 61 ++++++++++++++++++++++-------------------------- 1 file changed, 28 insertions(+), 33 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index e2774d9..258975e 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -288,31 +288,18 @@ $(SIGNATURES) Return `log(1 + x) - x`. Use naive calculation or range reduction outside kernel range. Accurate ~2ulps for all `x`. -This will fall back to the naive calculation for argument types different from `Float64`. +This will fall back to the naive calculation for argument types different from `Float32, Float64`. """ -function log1pmx(x::Float64) - if !(-0.7 < x < 0.9) +log1pmx(x::Real) = log1p(x) - x # Naive fallback + +function log1pmx(x::Union{Float32, Float64}) + if !(-0.425 < x < 0.4) # accurate within 2 ULPs when log2(abs(log1p(x))) > 1.5 return log1p(x) - x - elseif x > 0.315 - u = (x-0.5)/1.5 - return _log1pmx_ker(u) - 9.45348918918356180e-2 - 0.5*u - elseif x > -0.227 - return _log1pmx_ker(x) - elseif x > -0.4 - u = (x+0.25)/0.75 - return _log1pmx_ker(u) - 3.76820724517809274e-2 + 0.25*u - elseif x > -0.6 - u = (x+0.5)*2.0 - return _log1pmx_ker(u) - 1.93147180559945309e-1 + 0.5*u else - u = (x+0.625)/0.375 - return _log1pmx_ker(u) - 3.55829253011726237e-1 + 0.625*u + return _log1pmx_ker(x) end end -# Naive fallback -log1pmx(x::Real) = log1p(x) - x - """ $(SIGNATURES) @@ -345,21 +332,29 @@ function logmxp1(x::Real) end # The kernel of log1pmx -# Accuracy within ~2ulps for -0.227 < x < 0.315 -function _log1pmx_ker(x::Float64) - r = x/(x+2.0) +# Accuracy within ~2ulps -0.227 < x < 0.315 for Float64 +# Accuracy <2.18ulps -0.425 < x < 0.425 for Float32 +# parameters foudn via Remez.jl, specifically: +# g(x) = evalpoly(x, big(2)./ntuple(i->2i+1, 50)) +# p = T.(Tuple(ratfn_minimax(g, (1e-3, (.425/(.425+2))^2), 8, 0)[1])) +function _log1pmx_ker(x::T) where T <: Union{Float32, Float64} + r = x / (x+2) t = r*r - w = @horner(t, - 6.66666666666666667e-1, # 2/3 - 4.00000000000000000e-1, # 2/5 - 2.85714285714285714e-1, # 2/7 - 2.22222222222222222e-1, # 2/9 - 1.81818181818181818e-1, # 2/11 - 1.53846153846153846e-1, # 2/13 - 1.33333333333333333e-1, # 2/15 - 1.17647058823529412e-1) # 2/17 - hxsq = 0.5*x*x - r*(hxsq+w*t)-hxsq + if T == Float32 + p = (0.6666658f0, 0.40008822f0, 0.2827692f0, 0.26246136f0) + else + p = (0.6666666666666669, + 0.3999999999997768, + 0.2857142857784595, + 0.2222222142048249, + 0.18181870670924566, + 0.15382646727504887, + 0.1337701340211177, + 0.11201972567415432, + 0.143418239946679) + w = evalpoly(t, p) + hxsq = x*x/2 + muladd(r, muladd(w, t, hxsq), -hxsq) end From a5fe900634c82de3419215ad2b99ff65b1570448 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Sat, 5 Jul 2025 02:14:44 -0400 Subject: [PATCH 2/3] improve codegen --- src/basicfuns.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 258975e..599fe23 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -292,8 +292,8 @@ This will fall back to the naive calculation for argument types different from ` """ log1pmx(x::Real) = log1p(x) - x # Naive fallback -function log1pmx(x::Union{Float32, Float64}) - if !(-0.425 < x < 0.4) # accurate within 2 ULPs when log2(abs(log1p(x))) > 1.5 +function log1pmx(x::T) where T <: Union{Float32, Float64} + if !(T(-0.425) < x < T(0.4)) # accurate within 2 ULPs when log2(abs(log1p(x))) > 1.5 return log1p(x) - x else return _log1pmx_ker(x) From 3f424d3c9c08ae8f7a43e1a3270d67b9749c11d6 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Mon, 7 Jul 2025 13:01:24 -0400 Subject: [PATCH 3/3] improve tests --- src/basicfuns.jl | 1 + test/basicfuns.jl | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/basicfuns.jl b/src/basicfuns.jl index 599fe23..5cc0df2 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -352,6 +352,7 @@ function _log1pmx_ker(x::T) where T <: Union{Float32, Float64} 0.1337701340211177, 0.11201972567415432, 0.143418239946679) + end w = evalpoly(t, p) hxsq = x*x/2 muladd(r, muladd(w, t, hxsq), -hxsq) diff --git a/test/basicfuns.jl b/test/basicfuns.jl index 33c6365..6a84b54 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -216,7 +216,8 @@ end @test log1pmx(2f0) ≈ log(3f0) - 2f0 for x in -0.5:0.1:10 - @test log1pmx(Float32(x)) ≈ Float32(log1pmx(x)) + @test log1pmx(Float32(x)) ≈ Float32(log1pmx(x)) atol=3*eps(Float32(x)) + @test log1pmx(x) ≈ Float64(log1pmx(big(x))) atol=3*eps(x) end end