diff --git a/NEWS.md b/NEWS.md index 78d1d9b4bc7b8..3c893566c7219 100644 --- a/NEWS.md +++ b/NEWS.md @@ -110,6 +110,7 @@ New library features * `invoke` now supports passing a Method instead of a type signature making this interface somewhat more flexible for certain uncommon use cases ([#56692]). * `invoke` now supports passing a CodeInstance instead of a type, which can enable certain compiler plugin workflows ([#56660]). +* `sort` now supports `NTuple`s ([#54494]) Standard library changes ------------------------ diff --git a/base/sort.jl b/base/sort.jl index 29e67a3eb8d8c..f9d74f626fe58 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -4,7 +4,7 @@ module Sort using Base.Order -using Base: copymutable, midpoint, require_one_based_indexing, uinttype, +using Base: copymutable, midpoint, require_one_based_indexing, uinttype, tail, sub_with_overflow, add_with_overflow, OneTo, BitSigned, BitIntegerType, top_set_bit import Base: @@ -1482,8 +1482,9 @@ InitialOptimizations(next) = SubArrayOptimization( `DefaultStable` is an algorithm which indicates that a fast, general purpose sorting algorithm should be used, but does not specify exactly which algorithm. -Currently, it is composed of two parts: the [`InitialOptimizations`](@ref) and a hybrid of -Radix, Insertion, Counting, Quick sorts. +Currently, when sorting short NTuples, this is an unrolled mergesort, and otherwise it is +composed of two parts: the [`InitialOptimizations`](@ref) and a hybrid of Radix, Insertion, +Counting, Quick sorts. We begin with MissingOptimization because it has no runtime cost when it is not triggered and can enable other optimizations to be applied later. For example, @@ -1619,6 +1620,7 @@ defalg(v::AbstractArray) = DEFAULT_STABLE defalg(v::AbstractArray{<:Union{Number, Missing}}) = DEFAULT_UNSTABLE defalg(v::AbstractArray{Missing}) = DEFAULT_UNSTABLE # for method disambiguation defalg(v::AbstractArray{Union{}}) = DEFAULT_UNSTABLE # for method disambiguation +defalg(v::NTuple) = DEFAULT_STABLE """ sort!(v; alg::Base.Sort.Algorithm=Base.Sort.defalg(v), lt=isless, by=identity, rev::Bool=false, order::Base.Order.Ordering=Base.Order.Forward) @@ -1757,6 +1759,41 @@ julia> v """ sort(v::AbstractVector; kws...) = sort!(copymutable(v); kws...) +function sort(x::NTuple; + alg::Algorithm=defalg(x), + lt=isless, + by=identity, + rev::Union{Bool,Nothing}=nothing, + order::Ordering=Forward, + scratch::Union{Vector, Nothing}=nothing) + # Can't do this check with type parameters because of https://github.com/JuliaLang/julia/issues/56698 + scratch === nothing || eltype(x) == eltype(scratch) || throw(ArgumentError("scratch has the wrong eltype")) + _sort(x, alg, ord(lt,by,rev,order), (;scratch))::typeof(x) +end +# Folks who want to hack internals can define a new _sort(x::NTuple, ::TheirAlg, o::Ordering) +# or _sort(x::NTuple{N, TheirType}, ::DefaultStable, o::Ordering) where N +function _sort(x::NTuple, a::Union{DefaultStable, DefaultUnstable}, o::Ordering, kw) + # The unrolled tuple sort is prohibitively slow to compile for length > 9. + # See https://github.com/JuliaLang/julia/pull/46104#issuecomment-1435688502 for benchmarks + if length(x) > 9 + v = copymutable(x) + _sort!(v, a, o, kw) + typeof(x)(v) + else + _mergesort(x, o) + end +end +_mergesort(x::Union{NTuple{0}, NTuple{1}}, o::Ordering) = x +function _mergesort(x::NTuple, o::Ordering) + a, b = Base.IteratorsMD.split(x, Val(length(x)>>1)) + merge(_mergesort(a, o), _mergesort(b, o), o) +end +merge(x::NTuple, y::NTuple{0}, o::Ordering) = x +merge(x::NTuple{0}, y::NTuple, o::Ordering) = y +merge(x::NTuple{0}, y::NTuple{0}, o::Ordering) = x # Method ambiguity +merge(x::NTuple, y::NTuple, o::Ordering) = + (lt(o, y[1], x[1]) ? (y[1], merge(x, tail(y), o)...) : (x[1], merge(tail(x), y, o)...)) + ## partialsortperm: the permutation to sort the first k elements of an array ## """ diff --git a/test/sorting.jl b/test/sorting.jl index 71af50429027a..e16b30de5bfc8 100644 --- a/test/sorting.jl +++ b/test/sorting.jl @@ -94,6 +94,22 @@ end vcat(2000, (x:x+99 for x in 1900:-100:100)..., 1:99) end +function tuple_sort_test(x) + @test issorted(sort(x)) + length(x) > 9 && return # length > 9 uses a vector fallback + @test 0 == @allocated sort(x) +end +@testset "sort(::NTuple)" begin + @test sort(()) == () + @test sort((9,8,3,3,6,2,0,8)) == (0,2,3,3,6,8,8,9) + @test sort((9,8,3,3,6,2,0,8), by=x->x÷3) == (2,0,3,3,8,6,8,9) + for i in 1:40 + tuple_sort_test(rand(NTuple{i, Float64})) + end + @test_throws MethodError sort((1,2,3.0)) + @test Base.infer_return_type(sort, Tuple{Tuple{Vararg{Int}}}) == Tuple{Vararg{Int}} +end + @testset "partialsort" begin @test partialsort([3,6,30,1,9],3) == 6 @test partialsort([3,6,30,1,9],3:4) == [6,9] @@ -913,6 +929,20 @@ end end @test sort([1,2,3], alg=MySecondAlg()) == [9,9,9] @test all(sort(v, alg=Base.Sort.InitialOptimizations(MySecondAlg())) .=== vcat(fill(9, 100), fill(missing, 10))) + + # Tuple extensions (custom alg) + @test_throws MethodError sort((1,2,3), alg=MyFirstAlg()) + Base.Sort._sort(v::NTuple, ::MyFirstAlg, o::Base.Order.Ordering, kw) = (17,2,9) + @test sort((1,2,3), alg=MyFirstAlg()) == (17,2,9) + + struct TupleFoo + x::Int + end + + # Tuple extensions (custom type) + @test_throws MethodError sort(TupleFoo.((3,1,2))) + Base.Sort._sort(v::NTuple{N, TupleFoo}, ::Base.Sort.DefaultStable, o::Base.Order.Ordering, kw) where N = v + @test sort(TupleFoo.((3,1,2))) === TupleFoo.((3,1,2)) end @testset "sort!(v, lo, hi, alg, order)" begin