Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sort for NTuples #54494

Merged
merged 20 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 75 additions & 21 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module Sort

using Base.Order

using Base: copymutable, midpoint, require_one_based_indexing, uinttype,
using Base: copymutable, midpoint, require_one_based_indexing, uinttype, tail,
sub_with_overflow, add_with_overflow, OneTo, BitSigned, BitIntegerType, top_set_bit

import Base:
Expand Down Expand Up @@ -1475,21 +1475,16 @@ InitialOptimizations(next) = SubArrayOptimization(
Small{10}(
IEEEFloatOptimization(
next)))))
"""
DEFAULT_STABLE

The default sorting algorithm.

This algorithm is guaranteed to be stable (i.e. it will not reorder elements that compare
equal). It makes an effort to be fast for most inputs.

The algorithms used by `DEFAULT_STABLE` are an implementation detail. See extended help
for the current dispatch system.
"""
struct DefaultStable <: Algorithm end

# Extended Help
`DefaultStable` is an algorithm which indicates that a fast, general purpose sorting
algorithm should be used, but does not specify exactly which algorithm.

`DEFAULT_STABLE` is composed of two parts: the [`InitialOptimizations`](@ref) and a hybrid
of Radix, Insertion, Counting, Quick sorts.
Currently, when sorting short NTuples, this is an unrolled mergesort, and otherwise it is
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there be a compat notice indicating what versions of Julia support NTuple sorting?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, though it should live in ?sort, not ?Base.Sort.DEFAULT_STABLE

composed of two parts: the [`InitialOptimizations`](@ref) and a hybrid of Radix, Insertion,
Counting, Quick sorts.

We begin with MissingOptimization because it has no runtime cost when it is not
triggered and can enable other optimizations to be applied later. For example,
Expand Down Expand Up @@ -1549,7 +1544,39 @@ stage.
Finally, if the input has length less than 80, we dispatch to [`InsertionSort`](@ref) and
otherwise we dispatch to [`ScratchQuickSort`](@ref).
"""
const DEFAULT_STABLE = InitialOptimizations(
struct DefaultStable <: Algorithm end

"""
DEFAULT_STABLE

The default sorting algorithm.

This algorithm is guaranteed to be stable (i.e. it will not reorder elements that compare
equal). It makes an effort to be fast for most inputs.

The algorithms used by `DEFAULT_STABLE` are an implementation detail. See the extended help
of `Base.Sort.DefaultStable` for the current dispatch system.
"""
const DEFAULT_STABLE = DefaultStable()

"""
DefaultUnstable <: Algorithm

Like [`DefaultStable`](@ref), but does not guarantee stability.
"""
struct DefaultUnstable <: Algorithm end

"""
DEFAULT_UNSTABLE

An efficient sorting algorithm which may or may not be stable.

The algorithms used by `DEFAULT_UNSTABLE` are an implementation detail. They are currently
the same as those used by [`DEFAULT_STABLE`](@ref), but this is subject to change in future.
"""
const DEFAULT_UNSTABLE = DefaultUnstable()

const _DEFAULT_ALGORITHMS_FOR_VECTORS = InitialOptimizations(
IsUIntMappable(
Small{40}(
CheckSorted(
Expand All @@ -1560,15 +1587,10 @@ const DEFAULT_STABLE = InitialOptimizations(
ScratchQuickSort())))))),
StableCheckSorted(
ScratchQuickSort())))
"""
DEFAULT_UNSTABLE

An efficient sorting algorithm.
_sort!(v::AbstractVector, ::Union{DefaultStable, DefaultUnstable}, o::Ordering, kw) =
_sort!(v, _DEFAULT_ALGORITHMS_FOR_VECTORS, o, kw)

The algorithms used by `DEFAULT_UNSTABLE` are an implementation detail. They are currently
the same as those used by [`DEFAULT_STABLE`](@ref), but this is subject to change in future.
"""
const DEFAULT_UNSTABLE = DEFAULT_STABLE
const SMALL_THRESHOLD = 20

function Base.show(io::IO, alg::Algorithm)
Expand Down Expand Up @@ -1598,6 +1620,7 @@ defalg(v::AbstractArray) = DEFAULT_STABLE
defalg(v::AbstractArray{<:Union{Number, Missing}}) = DEFAULT_UNSTABLE
defalg(v::AbstractArray{Missing}) = DEFAULT_UNSTABLE # for method disambiguation
defalg(v::AbstractArray{Union{}}) = DEFAULT_UNSTABLE # for method disambiguation
defalg(v::NTuple) = DEFAULT_STABLE

"""
sort!(v; alg::Base.Sort.Algorithm=Base.Sort.defalg(v), lt=isless, by=identity, rev::Bool=false, order::Base.Order.Ordering=Base.Order.Forward)
Expand Down Expand Up @@ -1736,6 +1759,37 @@ julia> v
"""
sort(v::AbstractVector; kws...) = sort!(copymutable(v); kws...)

function sort(x::NTuple{N,T};
alg::Algorithm=defalg(x),
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
scratch::Union{Vector{T}, Nothing}=nothing) where {N,T}
_sort(x, alg, ord(lt,by,rev,order), (;scratch))
end
# Folks who want to hack internals can define a new _sort(x::NTuple, ::TheirAlg, o::Ordering)
# or _sort(x::NTuple{N, TheirType}, ::DefaultStable, o::Ordering) where N
function _sort(x::NTuple, a::Union{DefaultStable, DefaultUnstable}, o::Ordering, kw)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should be public if you want people to overload it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"hack internals" sounds like it shouldn't be public?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lilith mentioned earlier that a motivation for the design was allowing people to dispatch on the eltype and algorithm. I guess they changed their mind given the comment, but I still think it'd be nice if things like this were made to be dispatchable API when possible. It's a very annoying downside of kwargs in public APIs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _sort! method is designed to be extensible and will eventually be made public (maybe renamed to sort_impl! or something). However, as a complicated API, I wanted to give it a couple of years of usage within Base and packages that choose to hack internals before making it public. IMO it's better to have good features a while from now than have to choose between sub-optimal API and breaking changes.

if length(x) > 9
LilithHafner marked this conversation as resolved.
Show resolved Hide resolved
v = copymutable(x)
LilithHafner marked this conversation as resolved.
Show resolved Hide resolved
_sort!(v, a, o, kw)
typeof(x)(v)
else
_mergesort(x, o)
end
end
_mergesort(x::Union{NTuple{0}, NTuple{1}}, o::Ordering) = x
function _mergesort(x::NTuple, o::Ordering)
a, b = Base.IteratorsMD.split(x, Val(length(x)>>1))
merge(_mergesort(a, o), _mergesort(b, o), o)
end
merge(x::NTuple, y::NTuple{0}, o::Ordering) = x
merge(x::NTuple{0}, y::NTuple, o::Ordering) = y
merge(x::NTuple{0}, y::NTuple{0}, o::Ordering) = x # Method ambiguity
merge(x::NTuple, y::NTuple, o::Ordering) =
(lt(o, y[1], x[1]) ? (y[1], merge(x, tail(y), o)...) : (x[1], merge(tail(x), y, o)...))

## partialsortperm: the permutation to sort the first k elements of an array ##

"""
Expand Down
43 changes: 37 additions & 6 deletions test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ end
vcat(2000, (x:x+99 for x in 1900:-100:100)..., 1:99)
end

function tuple_sort_test(x)
@test issorted(sort(x))
length(x) > 9 && return # length > 9 uses a vector fallback
@test 0 == @allocated sort(x)
end
@testset "sort(::NTuple)" begin
@test sort((9,8,3,3,6,2,0,8)) == (0,2,3,3,6,8,8,9)
@test sort((9,8,3,3,6,2,0,8), by=x->x÷3) == (2,0,3,3,8,6,8,9)
for i in 1:40
tuple_sort_test(rand(NTuple{i, Float64}))
end
@test_throws MethodError sort((1,2,3.0))
end

@testset "partialsort" begin
@test partialsort([3,6,30,1,9],3) == 6
@test partialsort([3,6,30,1,9],3:4) == [6,9]
Expand Down Expand Up @@ -799,9 +813,9 @@ end
let
requires_uint_mappable = Union{Base.Sort.RadixSort, Base.Sort.ConsiderRadixSort,
Base.Sort.CountingSort, Base.Sort.ConsiderCountingSort,
typeof(Base.Sort.DEFAULT_STABLE.next.next.next.big.next.yes),
typeof(Base.Sort.DEFAULT_STABLE.next.next.next.big.next.yes.big),
typeof(Base.Sort.DEFAULT_STABLE.next.next.next.big.next.yes.big.next)}
typeof(Base.Sort._DEFAULT_ALGORITHMS_FOR_VECTORS.next.next.next.big.next.yes),
typeof(Base.Sort._DEFAULT_ALGORITHMS_FOR_VECTORS.next.next.next.big.next.yes.big),
typeof(Base.Sort._DEFAULT_ALGORITHMS_FOR_VECTORS.next.next.next.big.next.yes.big.next)}

function test_alg(kw, alg, float=true)
for order in [Base.Forward, Base.Reverse, Base.By(x -> x^2)]
Expand Down Expand Up @@ -841,15 +855,18 @@ end
end
end

test_alg_rec(Base.DEFAULT_STABLE)
test_alg_rec(Base.Sort._DEFAULT_ALGORITHMS_FOR_VECTORS)
end
end

@testset "show(::Algorithm)" begin
@test eval(Meta.parse(string(Base.DEFAULT_STABLE))) === Base.DEFAULT_STABLE
lines = split(string(Base.DEFAULT_STABLE), '\n')
@test eval(Meta.parse(string(Base.Sort._DEFAULT_ALGORITHMS_FOR_VECTORS))) === Base.Sort._DEFAULT_ALGORITHMS_FOR_VECTORS
lines = split(string(Base.Sort._DEFAULT_ALGORITHMS_FOR_VECTORS), '\n')
@test 10 < maximum(length, lines) < 100
@test 1 < length(lines) < 30

@test eval(Meta.parse(string(Base.DEFAULT_STABLE))) === Base.DEFAULT_STABLE
@test string(Base.DEFAULT_STABLE) == "Base.Sort.DefaultStable()"
end

@testset "Extensibility" begin
Expand Down Expand Up @@ -890,6 +907,20 @@ end
end
@test sort([1,2,3], alg=MySecondAlg()) == [9,9,9]
@test all(sort(v, alg=Base.Sort.InitialOptimizations(MySecondAlg())) .=== vcat(fill(9, 100), fill(missing, 10)))

# Tuple extensions (custom alg)
@test_throws MethodError sort((1,2,3), alg=MyFirstAlg())
Base.Sort._sort(v::NTuple, ::MyFirstAlg, o::Base.Order.Ordering, kw) = "hi!"
@test sort((1,2,3), alg=MyFirstAlg()) == "hi!"

struct TupleFoo
x::Int
end

# Tuple extensions (custom type)
@test_throws MethodError sort(TupleFoo.((3,1,2)))
Base.Sort._sort(v::NTuple{N, TupleFoo}, ::Base.Sort.DefaultStable, o::Base.Order.Ordering, kw) where N = v
@test sort(TupleFoo.((3,1,2))) === TupleFoo.((3,1,2))
end

@testset "sort!(v, lo, hi, alg, order)" begin
Expand Down