From 0cc797b17a68345a0287d440d3326feb03d7b9fa Mon Sep 17 00:00:00 2001 From: Jeffrey Sarnoff Date: Sun, 25 Sep 2022 02:04:48 -0400 Subject: [PATCH 1/7] atan(2) hypot clamp --- src/bfloat16.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/bfloat16.jl b/src/bfloat16.jl index a331337..18ceb93 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -13,6 +13,7 @@ import Base: isfinite, isnan, precision, iszero, eps, asin, acos, atan, acsc, asec, acot, sinh, cosh, tanh, csch, sech, coth, asinh, acosh, atanh, acsch, asech, acoth, + clamp, hypot, bitstring primitive type BFloat16 <: AbstractFloat 16 end @@ -262,3 +263,8 @@ for F in (:abs, :abs2, :sqrt, :cbrt, end end +Base.atan(y::BFloat16, x::BFloat16) = BFloat16(atan(Float32(y), Float32(x))) +Base.hypot(x::BFloat16, y::BFloat16) = BFloat16(hypot(Float32(x), Float32(y))) +Base.hypot(x::BFloat16, y::BFloat16, z::BFloat16) = BFloat16(hypot(Float32(x), Float32(y), Float32(z))) +Base.clamp(x::BFloat16, lo::BFloat16, hi::BFloat16) = BFloat16(clamp(Float32(x), Float32(lo), Float32(hi))) + From 143d7d8f8572285ad1a6dd22d7ab9dc310fc4db6 Mon Sep 17 00:00:00 2001 From: Jeffrey Sarnoff Date: Sun, 25 Sep 2022 02:10:54 -0400 Subject: [PATCH 2/7] bitstring BFloat16(Irrational) --- src/bfloat16.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/bfloat16.jl b/src/bfloat16.jl index 18ceb93..146334d 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -138,6 +138,9 @@ function Base.Float64(x::BFloat16) Float64(Float32(x)) end +# accept Irrational +BFloat16s.BFloat16(x::Irrational) = BFloat16(Float32(x)) + # Truncation to integer types Base.unsafe_trunc(T::Type{<:Integer}, x::BFloat16) = unsafe_trunc(T, Float32(x)) Base.trunc(::Type{T}, x::BFloat16) where {T<:Integer} = trunc(T, Float32(x)) @@ -268,3 +271,4 @@ Base.hypot(x::BFloat16, y::BFloat16) = BFloat16(hypot(Float32(x), Float32(y))) Base.hypot(x::BFloat16, y::BFloat16, z::BFloat16) = BFloat16(hypot(Float32(x), Float32(y), Float32(z))) Base.clamp(x::BFloat16, lo::BFloat16, hi::BFloat16) = BFloat16(clamp(Float32(x), Float32(lo), Float32(hi))) +Base.bitstring(x::BFloat16) = bitstring(reinterpret(UInt16, x)) From 66642ed32962d80347e28ad4b3c66ff8debf404c Mon Sep 17 00:00:00 2001 From: Jeffrey Sarnoff Date: Sun, 25 Sep 2022 02:12:04 -0400 Subject: [PATCH 3/7] Update Project.toml --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index bc49d45..7996586 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BFloat16s" uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -authors = ["Keno Fischer "] -version = "0.4.0" +authors = ["Keno Fischer ", "Jeffrey Sarnoff "] +version = "0.4.1" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From 36826020da121858111acae45400a1d0d35bcac5 Mon Sep 17 00:00:00 2001 From: Jeffrey Sarnoff Date: Sun, 25 Sep 2022 02:23:35 -0400 Subject: [PATCH 4/7] Update README.md --- README.md | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index afbcb11..146b269 100644 --- a/README.md +++ b/README.md @@ -13,12 +13,24 @@ experiments. This package exports the BFloat16 data type. This datatype should behave just like any builtin floating point type (e.g. you can construct it from -other floating point types - e.g. `BFloat16(1.0)`). In addition, this package -provides the `LowPrecArray` type. This array is supposed to emulate the kind -of matmul operation that TPUs do well (BFloat16 multiply with Float32 -accumulate). 
Broadcasts and scalar operations are peformed in Float32 (as -they would be on a TPU) while matrix multiplies are performed in BFloat16 with -Float32 accumulates, e.g. +other floating point types - e.g. `BFloat16(1.0)`). Many predicates, +conversion, structural and mathematical functions are supported: +``` + Int16, Int32, Int64, Float16, Float32, Float64, +, -, *, /, ^, ==, <, <=, >=, >, !=, inv, + isfinite, isnan, precision, iszero, eps, typemin, typemax, floatmin, floatmax, + sign_mask, exponent_mask, significand_mask, exponent_bits, significand_bits, exponent_bias, + signbit, exponent, significand, frexp, ldexp, exponent_one, exponent_half, + exp, exp2, exp10, expm1, log, log2, log10, log1p, + sin, cos, tan, csc, sec, cot, asin, acos, atan, acsc, asec, acot, + sinh, cosh, tanh, csch, sech, coth, asinh, acosh, atanh, acsch, asech, acoth, + round, trunc, floor, ceil, abs, abs2, sqrt, cbrt, clamp, hypot, bitstring +``` + +In addition, this package provides the `LowPrecArray` type. This array is +supposed to emulate the kind of matmul operation that TPUs do well +(BFloat16 multiply with Float32 accumulate). Broadcasts and scalar operations +are peformed in Float32 (as they would be on a TPU) while matrix multiplies +are performed in BFloat16 with Float32 accumulates, e.g. ``` julia> A = LowPrecArray(rand(Float32, 5, 5)) From d6f43f76e0a08fa49bf2b753e6247803a22fa4c7 Mon Sep 17 00:00:00 2001 From: Jeffrey Sarnoff Date: Wed, 31 Jan 2024 10:34:39 -0500 Subject: [PATCH 5/7] fix conflicts --- .vscode/settings.json | 1 + Project.toml | 4 +- README.md | 48 +++++++++-- src/bfloat16.jl | 183 +++++++++++++++++++++++++++++++++--------- 4 files changed, 192 insertions(+), 44 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/Project.toml b/Project.toml index 7996586..009099d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BFloat16s" uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -authors = ["Keno Fischer ", "Jeffrey Sarnoff "] -version = "0.4.1" +authors = ["Keno Fischer "] +version = "0.5.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/README.md b/README.md index 146b269..98480df 100644 --- a/README.md +++ b/README.md @@ -11,11 +11,20 @@ experiments. # Usage -This package exports the BFloat16 data type. This datatype should behave -just like any builtin floating point type (e.g. you can construct it from -other floating point types - e.g. `BFloat16(1.0)`). Many predicates, -conversion, structural and mathematical functions are supported: +This package exports the BFloat16 data type. 
This datatype behaves just like any built-in floating point type + +```julia +julia> using BFloat16s + +julia> a = BFloat16(2) +BFloat16(2.0) + +julia> sqrt(a) +BFloat16(1.4140625) ``` + +Many predicates, conversion, structural and mathematical functions are supported: +```julia Int16, Int32, Int64, Float16, Float32, Float64, +, -, *, /, ^, ==, <, <=, >=, >, !=, inv, isfinite, isnan, precision, iszero, eps, typemin, typemax, floatmin, floatmax, sign_mask, exponent_mask, significand_mask, exponent_bits, significand_bits, exponent_bias, @@ -26,6 +35,34 @@ conversion, structural and mathematical functions are supported: round, trunc, floor, ceil, abs, abs2, sqrt, cbrt, clamp, hypot, bitstring ``` +However, in practice you may hit a `MethodError` indicating that this package does not implement +some method for `BFloat16` although it should. Please raise an issue so that we can +close that gap in support. + +### solving a linear equation system + +```julia +julia> A = randn(BFloat16,3,3) +3×3 Matrix{BFloat16}: + 1.46875 -1.20312 -1.0 + 0.257812 -0.671875 -0.929688 + -0.410156 -1.75 -0.0162354 + +julia> b = randn(BFloat16,3) +3-element Vector{BFloat16}: + -0.26367188 + -0.14160156 + 0.77734375 + +julia> A\b +3-element Vector{BFloat16}: + -0.24902344 + -0.38671875 + 0.36328125 +``` + +## `LowPrecArray` for mixed-precision Float32/BFloat16 matrix multiplications + In addition, this package provides the `LowPrecArray` type. This array is supposed to emulate the kind of matmul operation that TPUs do well (BFloat16 multiply with Float32 accumulate). Broadcasts and scalar operations @@ -66,6 +103,5 @@ julia> Float64.(A.storage)^2 1.22742 1.90498 1.70653 1.63928 2.18076 ``` -Note that the low precision result differs from (is less precise than) the -result computed in Float32 arithmetic (which matches the result in Float64 +Note that the low precision result differs from (is less precise than) the result computed in Float32 arithmetic (which matches the result in Float64 precision). diff --git a/src/bfloat16.jl b/src/bfloat16.jl index 146334d..e435068 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -14,9 +14,44 @@ import Base: isfinite, isnan, precision, iszero, eps, sinh, cosh, tanh, csch, sech, coth, asinh, acosh, atanh, acsch, asech, acoth, clamp, hypot, - bitstring + bitstring, isinteger + + +import Printf + +# LLVM 11 added support for BFloat16 in the IR; Julia 1.11 added support for generating +# code that uses the `bfloat` IR type, together with the necessary runtime functions. +# However, not all LLVM targets support `bfloat`. If the target can store/load BFloat16s +# (and supports synthesizing constants) we can use the `bfloat` IR type, otherwise we fall +# back to defining a primitive type that will be represented as an `i16`. If, in addition, +# the target supports BFloat16 arithmetic, we can use LLVM intrinsics. 
+# - x86: storage and arithmetic support in LLVM 15 +# - aarch64: storage support in LLVM 17 +const llvm_storage = if isdefined(Core, :BFloat16) + if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15" + true + elseif Sys.ARCH == :aarch64 && Base.libllvm_version >= v"17" + true + else + false + end +else + false +end +const llvm_arithmetic = if llvm_storage + using Core: BFloat16 + if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15" + true + else + false + end +else + primitive type BFloat16 <: AbstractFloat 16 end + false +end -primitive type BFloat16 <: AbstractFloat 16 end +Base.reinterpret(::Type{Unsigned}, x::BFloat16) = reinterpret(UInt16, x) +Base.reinterpret(::Type{Signed}, x::BFloat16) = reinterpret(Int16, x) # Floating point property queries for f in (:sign_mask, :exponent_mask, :exponent_one, @@ -141,16 +176,22 @@ end # accept Irrational BFloat16s.BFloat16(x::Irrational) = BFloat16(Float32(x)) -# Truncation to integer types -Base.unsafe_trunc(T::Type{<:Integer}, x::BFloat16) = unsafe_trunc(T, Float32(x)) -Base.trunc(::Type{T}, x::BFloat16) where {T<:Integer} = trunc(T, Float32(x)) - # Basic arithmetic -for f in (:+, :-, :*, :/, :^) - @eval ($f)(x::BFloat16, y::BFloat16) = BFloat16($(f)(Float32(x), Float32(y))) +# Basic arithmetic +if llvm_arithmetic + +(x::T, y::T) where {T<:BFloat16} = Base.add_float(x, y) + -(x::T, y::T) where {T<:BFloat16} = Base.sub_float(x, y) + *(x::T, y::T) where {T<:BFloat16} = Base.mul_float(x, y) + /(x::T, y::T) where {T<:BFloat16} = Base.div_float(x, y) + -(x::BFloat16) = Base.neg_float(x) + ^(x::BFloat16, y::BFloat16) = BFloat16(Float32(x)^Float32(y)) +else + for f in (:+, :-, :*, :/, :^) + @eval ($f)(x::BFloat16, y::BFloat16) = BFloat16($(f)(Float32(x), Float32(y))) + end + -(x::BFloat16) = reinterpret(BFloat16, reinterpret(Unsigned, x) ⊻ sign_mask(BFloat16)) end --(x::BFloat16) = reinterpret(BFloat16, reinterpret(UInt16, x) ⊻ sign_mask(BFloat16)) -^(x::BFloat16, y::Integer) = BFloat16(^(Float32(x), y)) +^(x::BFloat16, y::Integer) = BFloat16(Float32(x)^y) for F in (:abs, :sqrt, :exp, :log, :log2, :log10, :sin, :cos, :tan, :asin, :acos, :atan, @@ -196,6 +237,67 @@ end # Wide multiplication Base.widemul(x::BFloat16, y::BFloat16) = Float32(x) * Float32(y) +# Truncation to integer types +if llvm_arithmetic + for Ti in (Int8, Int16, Int32, Int64) + @eval begin + Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptosi($Ti, x) + end + end + for Ti in (UInt8, UInt16, UInt32, UInt64) + @eval begin + Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptoui($Ti, x) + end + end +else + Base.unsafe_trunc(T::Type{<:Integer}, x::BFloat16) = unsafe_trunc(T, Float32(x)) +end +for Ti in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128) + if Ti <: Unsigned || sizeof(Ti) < 2 + # Here `BFloat16(typemin(Ti))-1` is exact, so we can compare the lower-bound + # directly. `BFloat16(typemax(Ti))+1` is either always exactly representable, or + # rounded to `Inf` (e.g. when `Ti==UInt128 && BFloat16==Float32`). 
+ @eval begin + function Base.trunc(::Type{$Ti}, x::BFloat16) + if $(BFloat16(typemin(Ti)) - one(BFloat16)) < x < $(BFloat16(typemax(Ti)) + one(BFloat16)) + return Base.unsafe_trunc($Ti, x) + else + throw(InexactError(:trunc, $Ti, x)) + end + end + function (::Type{$Ti})(x::BFloat16) + if ($(BFloat16(typemin(Ti))) <= x <= $(BFloat16(typemax(Ti)))) && isinteger(x) + return Base.unsafe_trunc($Ti, x) + else + throw(InexactError($(Expr(:quote, Ti.name.name)), $Ti, x)) + end + end + end + else + # Here `eps(BFloat16(typemin(Ti))) > 1`, so the only value which can be + # truncated to `BFloat16(typemin(Ti)` is itself. Similarly, + # `BFloat16(typemax(Ti))` is inexact and will be rounded up. This assumes that + # `BFloat16(typemin(Ti)) > -Inf`, which is true for these types, but not for + # `Float16` or larger integer types. + @eval begin + function Base.trunc(::Type{$Ti}, x::BFloat16) + if $(BFloat16(typemin(Ti))) <= x < $(BFloat16(typemax(Ti))) + return unsafe_trunc($Ti, x) + else + throw(InexactError(:trunc, $Ti, x)) + end + end + function (::Type{$Ti})(x::BFloat16) + if ($(BFloat16(typemin(Ti))) <= x < $(BFloat16(typemax(Ti)))) && isinteger(x) + return unsafe_trunc($Ti, x) + else + throw(InexactError($(Expr(:quote, Ti.name.name)), $Ti, x)) + end + end + end + end +end + # Showing function Base.show(io::IO, x::BFloat16) hastypeinfo = BFloat16 === get(io, :typeinfo, Any) @@ -205,10 +307,11 @@ function Base.show(io::IO, x::BFloat16) print(io, "NaNB16") else hastypeinfo || print(io, "BFloat16(") - show(IOContext(io, :typeinfo=>Float32), Float32(x)) + show(IOContext(io, :typeinfo => Float32), Float32(x)) hastypeinfo || print(io, ")") end end +Printf.tofloat(x::BFloat16) = Float32(x) # Random import Random: rand, randn, randexp, AbstractRNG, Sampler @@ -223,34 +326,42 @@ exponent(x::BFloat16) = exponent(Float32(x)) bitstring(x::BFloat16) = bitstring(reinterpret(UInt16, x)) # next/prevfloat -function Base.nextfloat(x::BFloat16) - if isfinite(x) - ui = reinterpret(UInt16,x) - if ui < 0x8000 # positive numbers - return reinterpret(BFloat16,ui+0x0001) - elseif ui == 0x8000 # =-zero(T) - return reinterpret(BFloat16,0x0001) - else # negative numbers - return reinterpret(BFloat16,ui-0x0001) +function Base.nextfloat(f::BFloat16, d::Integer) + F = typeof(f) + fumax = reinterpret(Unsigned, F(Inf)) + U = typeof(fumax) + + isnan(f) && return f + fi = reinterpret(Signed, f) + fneg = fi < 0 + fu = unsigned(fi & typemax(fi)) + + dneg = d < 0 + da = uabs(d) + if da > typemax(U) + fneg = dneg + fu = fumax + else + du = da % U + if fneg ⊻ dneg + if du > fu + fu = min(fumax, du - fu) + fneg = !fneg + else + fu = fu - du + end + else + if fumax - fu < du + fu = fumax + else + fu = fu + du + end end - else # NaN / Inf case - return x end -end - -function Base.prevfloat(x::BFloat16) - if isfinite(x) - ui = reinterpret(UInt16,x) - if ui == 0x0000 # =zero(T) - return reinterpret(BFloat16,0x8001) - elseif ui < 0x8000 # positive numbers - return reinterpret(BFloat16,ui-0x0001) - else # negative numbers - return reinterpret(BFloat16,ui+0x0001) - end - else # NaN / Inf case - return x + if fneg + fu |= sign_mask(F) end + reinterpret(F, fu) end # math functions From 75db320ca20d3ee586c346697daebaadd45c6ed6 Mon Sep 17 00:00:00 2001 From: Jeffrey Sarnoff Date: Wed, 31 Jan 2024 10:35:48 -0500 Subject: [PATCH 6/7] Revert "fix conflicts" This reverts commit d6f43f76e0a08fa49bf2b753e6247803a22fa4c7. 
--- .vscode/settings.json | 1 - Project.toml | 4 +- README.md | 48 ++--------- src/bfloat16.jl | 183 +++++++++--------------------------------- 4 files changed, 44 insertions(+), 192 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 9e26dfe..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1 +0,0 @@ -{} \ No newline at end of file diff --git a/Project.toml b/Project.toml index 009099d..7996586 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BFloat16s" uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -authors = ["Keno Fischer "] -version = "0.5.0" +authors = ["Keno Fischer ", "Jeffrey Sarnoff "] +version = "0.4.1" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/README.md b/README.md index 98480df..146b269 100644 --- a/README.md +++ b/README.md @@ -11,20 +11,11 @@ experiments. # Usage -This package exports the BFloat16 data type. This datatype behaves just like any built-in floating point type - -```julia -julia> using BFloat16s - -julia> a = BFloat16(2) -BFloat16(2.0) - -julia> sqrt(a) -BFloat16(1.4140625) +This package exports the BFloat16 data type. This datatype should behave +just like any builtin floating point type (e.g. you can construct it from +other floating point types - e.g. `BFloat16(1.0)`). Many predicates, +conversion, structural and mathematical functions are supported: ``` - -Many predicates, conversion, structural and mathematical functions are supported: -```julia Int16, Int32, Int64, Float16, Float32, Float64, +, -, *, /, ^, ==, <, <=, >=, >, !=, inv, isfinite, isnan, precision, iszero, eps, typemin, typemax, floatmin, floatmax, sign_mask, exponent_mask, significand_mask, exponent_bits, significand_bits, exponent_bias, @@ -35,34 +26,6 @@ Many predicates, conversion, structural and mathematical functions are supporte round, trunc, floor, ceil, abs, abs2, sqrt, cbrt, clamp, hypot, bitstring ``` -However, in practice you may hit a `MethodError` indicating that this package does not implement -some method for `BFloat16` although it should. Please raise an issue so that we can -close that gap in support. - -### solving a linear equation system - -```julia -julia> A = randn(BFloat16,3,3) -3×3 Matrix{BFloat16}: - 1.46875 -1.20312 -1.0 - 0.257812 -0.671875 -0.929688 - -0.410156 -1.75 -0.0162354 - -julia> b = randn(BFloat16,3) -3-element Vector{BFloat16}: - -0.26367188 - -0.14160156 - 0.77734375 - -julia> A\b -3-element Vector{BFloat16}: - -0.24902344 - -0.38671875 - 0.36328125 -``` - -## `LowPrecArray` for mixed-precision Float32/BFloat16 matrix multiplications - In addition, this package provides the `LowPrecArray` type. This array is supposed to emulate the kind of matmul operation that TPUs do well (BFloat16 multiply with Float32 accumulate). Broadcasts and scalar operations @@ -103,5 +66,6 @@ julia> Float64.(A.storage)^2 1.22742 1.90498 1.70653 1.63928 2.18076 ``` -Note that the low precision result differs from (is less precise than) the result computed in Float32 arithmetic (which matches the result in Float64 +Note that the low precision result differs from (is less precise than) the +result computed in Float32 arithmetic (which matches the result in Float64 precision). 
diff --git a/src/bfloat16.jl b/src/bfloat16.jl index e435068..146334d 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -14,44 +14,9 @@ import Base: isfinite, isnan, precision, iszero, eps, sinh, cosh, tanh, csch, sech, coth, asinh, acosh, atanh, acsch, asech, acoth, clamp, hypot, - bitstring, isinteger - - -import Printf - -# LLVM 11 added support for BFloat16 in the IR; Julia 1.11 added support for generating -# code that uses the `bfloat` IR type, together with the necessary runtime functions. -# However, not all LLVM targets support `bfloat`. If the target can store/load BFloat16s -# (and supports synthesizing constants) we can use the `bfloat` IR type, otherwise we fall -# back to defining a primitive type that will be represented as an `i16`. If, in addition, -# the target supports BFloat16 arithmetic, we can use LLVM intrinsics. -# - x86: storage and arithmetic support in LLVM 15 -# - aarch64: storage support in LLVM 17 -const llvm_storage = if isdefined(Core, :BFloat16) - if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15" - true - elseif Sys.ARCH == :aarch64 && Base.libllvm_version >= v"17" - true - else - false - end -else - false -end -const llvm_arithmetic = if llvm_storage - using Core: BFloat16 - if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15" - true - else - false - end -else - primitive type BFloat16 <: AbstractFloat 16 end - false -end + bitstring -Base.reinterpret(::Type{Unsigned}, x::BFloat16) = reinterpret(UInt16, x) -Base.reinterpret(::Type{Signed}, x::BFloat16) = reinterpret(Int16, x) +primitive type BFloat16 <: AbstractFloat 16 end # Floating point property queries for f in (:sign_mask, :exponent_mask, :exponent_one, @@ -176,22 +141,16 @@ end # accept Irrational BFloat16s.BFloat16(x::Irrational) = BFloat16(Float32(x)) +# Truncation to integer types +Base.unsafe_trunc(T::Type{<:Integer}, x::BFloat16) = unsafe_trunc(T, Float32(x)) +Base.trunc(::Type{T}, x::BFloat16) where {T<:Integer} = trunc(T, Float32(x)) + # Basic arithmetic -# Basic arithmetic -if llvm_arithmetic - +(x::T, y::T) where {T<:BFloat16} = Base.add_float(x, y) - -(x::T, y::T) where {T<:BFloat16} = Base.sub_float(x, y) - *(x::T, y::T) where {T<:BFloat16} = Base.mul_float(x, y) - /(x::T, y::T) where {T<:BFloat16} = Base.div_float(x, y) - -(x::BFloat16) = Base.neg_float(x) - ^(x::BFloat16, y::BFloat16) = BFloat16(Float32(x)^Float32(y)) -else - for f in (:+, :-, :*, :/, :^) - @eval ($f)(x::BFloat16, y::BFloat16) = BFloat16($(f)(Float32(x), Float32(y))) - end - -(x::BFloat16) = reinterpret(BFloat16, reinterpret(Unsigned, x) ⊻ sign_mask(BFloat16)) +for f in (:+, :-, :*, :/, :^) + @eval ($f)(x::BFloat16, y::BFloat16) = BFloat16($(f)(Float32(x), Float32(y))) end -^(x::BFloat16, y::Integer) = BFloat16(Float32(x)^y) +-(x::BFloat16) = reinterpret(BFloat16, reinterpret(UInt16, x) ⊻ sign_mask(BFloat16)) +^(x::BFloat16, y::Integer) = BFloat16(^(Float32(x), y)) for F in (:abs, :sqrt, :exp, :log, :log2, :log10, :sin, :cos, :tan, :asin, :acos, :atan, @@ -237,67 +196,6 @@ end # Wide multiplication Base.widemul(x::BFloat16, y::BFloat16) = Float32(x) * Float32(y) -# Truncation to integer types -if llvm_arithmetic - for Ti in (Int8, Int16, Int32, Int64) - @eval begin - Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptosi($Ti, x) - end - end - for Ti in (UInt8, UInt16, UInt32, UInt64) - @eval begin - Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptoui($Ti, x) - end - end -else - Base.unsafe_trunc(T::Type{<:Integer}, x::BFloat16) = unsafe_trunc(T, Float32(x)) -end -for Ti in (Int8, 
Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128) - if Ti <: Unsigned || sizeof(Ti) < 2 - # Here `BFloat16(typemin(Ti))-1` is exact, so we can compare the lower-bound - # directly. `BFloat16(typemax(Ti))+1` is either always exactly representable, or - # rounded to `Inf` (e.g. when `Ti==UInt128 && BFloat16==Float32`). - @eval begin - function Base.trunc(::Type{$Ti}, x::BFloat16) - if $(BFloat16(typemin(Ti)) - one(BFloat16)) < x < $(BFloat16(typemax(Ti)) + one(BFloat16)) - return Base.unsafe_trunc($Ti, x) - else - throw(InexactError(:trunc, $Ti, x)) - end - end - function (::Type{$Ti})(x::BFloat16) - if ($(BFloat16(typemin(Ti))) <= x <= $(BFloat16(typemax(Ti)))) && isinteger(x) - return Base.unsafe_trunc($Ti, x) - else - throw(InexactError($(Expr(:quote, Ti.name.name)), $Ti, x)) - end - end - end - else - # Here `eps(BFloat16(typemin(Ti))) > 1`, so the only value which can be - # truncated to `BFloat16(typemin(Ti)` is itself. Similarly, - # `BFloat16(typemax(Ti))` is inexact and will be rounded up. This assumes that - # `BFloat16(typemin(Ti)) > -Inf`, which is true for these types, but not for - # `Float16` or larger integer types. - @eval begin - function Base.trunc(::Type{$Ti}, x::BFloat16) - if $(BFloat16(typemin(Ti))) <= x < $(BFloat16(typemax(Ti))) - return unsafe_trunc($Ti, x) - else - throw(InexactError(:trunc, $Ti, x)) - end - end - function (::Type{$Ti})(x::BFloat16) - if ($(BFloat16(typemin(Ti))) <= x < $(BFloat16(typemax(Ti)))) && isinteger(x) - return unsafe_trunc($Ti, x) - else - throw(InexactError($(Expr(:quote, Ti.name.name)), $Ti, x)) - end - end - end - end -end - # Showing function Base.show(io::IO, x::BFloat16) hastypeinfo = BFloat16 === get(io, :typeinfo, Any) @@ -307,11 +205,10 @@ function Base.show(io::IO, x::BFloat16) print(io, "NaNB16") else hastypeinfo || print(io, "BFloat16(") - show(IOContext(io, :typeinfo => Float32), Float32(x)) + show(IOContext(io, :typeinfo=>Float32), Float32(x)) hastypeinfo || print(io, ")") end end -Printf.tofloat(x::BFloat16) = Float32(x) # Random import Random: rand, randn, randexp, AbstractRNG, Sampler @@ -326,42 +223,34 @@ exponent(x::BFloat16) = exponent(Float32(x)) bitstring(x::BFloat16) = bitstring(reinterpret(UInt16, x)) # next/prevfloat -function Base.nextfloat(f::BFloat16, d::Integer) - F = typeof(f) - fumax = reinterpret(Unsigned, F(Inf)) - U = typeof(fumax) - - isnan(f) && return f - fi = reinterpret(Signed, f) - fneg = fi < 0 - fu = unsigned(fi & typemax(fi)) - - dneg = d < 0 - da = uabs(d) - if da > typemax(U) - fneg = dneg - fu = fumax - else - du = da % U - if fneg ⊻ dneg - if du > fu - fu = min(fumax, du - fu) - fneg = !fneg - else - fu = fu - du - end - else - if fumax - fu < du - fu = fumax - else - fu = fu + du - end +function Base.nextfloat(x::BFloat16) + if isfinite(x) + ui = reinterpret(UInt16,x) + if ui < 0x8000 # positive numbers + return reinterpret(BFloat16,ui+0x0001) + elseif ui == 0x8000 # =-zero(T) + return reinterpret(BFloat16,0x0001) + else # negative numbers + return reinterpret(BFloat16,ui-0x0001) end + else # NaN / Inf case + return x end - if fneg - fu |= sign_mask(F) +end + +function Base.prevfloat(x::BFloat16) + if isfinite(x) + ui = reinterpret(UInt16,x) + if ui == 0x0000 # =zero(T) + return reinterpret(BFloat16,0x8001) + elseif ui < 0x8000 # positive numbers + return reinterpret(BFloat16,ui-0x0001) + else # negative numbers + return reinterpret(BFloat16,ui+0x0001) + end + else # NaN / Inf case + return x end - reinterpret(F, fu) end # math functions From 
663b207059da073c730565f79e025e59c42ab11d Mon Sep 17 00:00:00 2001 From: "Viral B. Shah" Date: Tue, 17 Sep 2024 13:59:02 -0400 Subject: [PATCH 7/7] Update bfloat16.jl --- src/bfloat16.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bfloat16.jl b/src/bfloat16.jl index 278854b..1b234a2 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -433,9 +433,9 @@ for F in (:abs, :abs2, :sqrt, :cbrt, end end - Base.atan(y::BFloat16, x::BFloat16) = BFloat16(atan(Float32(y), Float32(x))) +Base.atan(y::BFloat16, x::BFloat16) = BFloat16(atan(Float32(y), Float32(x))) Base.hypot(x::BFloat16, y::BFloat16) = BFloat16(hypot(Float32(x), Float32(y))) Base.hypot(x::BFloat16, y::BFloat16, z::BFloat16) = BFloat16(hypot(Float32(x), Float32(y), Float32(z))) Base.clamp(x::BFloat16, lo::BFloat16, hi::BFloat16) = BFloat16(clamp(Float32(x), Float32(lo), Float32(hi))) -Base.bitstring(x::BFloat16) = bitstring(reinterpret(UInt16, x)) \ No newline at end of file +Base.bitstring(x::BFloat16) = bitstring(reinterpret(UInt16, x))
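
---

Editor's note (not part of the patch series above): the sketch below illustrates the methods this series adds — two-argument `atan`, `hypot`, `clamp`, `bitstring`, and the `Irrational` constructor — as they are expected to behave once the patches are applied. The exact printed form of `BFloat16` values may differ between package versions.

```julia
using BFloat16s

x, y = BFloat16(3), BFloat16(4)

atan(y, x)                 # two-argument atan, computed via Float32 (patch 1)
hypot(x, y)                # BFloat16(5.0): 3-4-5 triangle, exactly representable
hypot(x, y, BFloat16(12))  # BFloat16(13.0): three-argument form
clamp(BFloat16(7), x, y)   # BFloat16(4.0): clamped to the upper bound
bitstring(BFloat16(1))     # "0011111110000000" (sign, 8 exponent bits, 7 fraction bits)
BFloat16(pi)               # Irrational constructor (patch 2); rounds pi to about 3.140625
```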