Skip to content

Commit

Permalink
expanded select to also deal with ranges.
Browse files Browse the repository at this point in the history
  • Loading branch information
RainerHeintzmann committed Dec 16, 2024
1 parent 7a6e941 commit 61cbb4a
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
PaddedViews = "0.5.8, 0.5.9, 0.5.10, 0.5.11"
OffsetArrays = "1, 1.12"
julia = "1, 1.6, 1.7, 1.8"
PaddedViews = "0.5.8, 0.5.9, 0.5.10, 0.5.11, 0.5.12"
OffsetArrays = "1, 1.12, 1.13, 1.14"
julia = "1, 1.6, 1.7, 1.8, 1.9, 1.10, 1.11"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
13 changes: 11 additions & 2 deletions src/selection_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export slice
"""
slice(arr, dim, index)
Return a `N` dimensional slice (where one dimensions has size 1) of the N-dimensional `arr` at the index position
Return a `N` dimensional slice view (where one dimensions has size 1) of the N-dimensional `arr` at the index position (or range)
`index` in the `dim` dimension of the array.
It holds `size(out)[dim] == 1`.
Expand All @@ -22,7 +22,7 @@ julia> NDTools.slice(x, 1, 1)
1 2 3
```
"""
function slice(arr::AbstractArray{T, N}, dim::Integer, index::Integer) where {T, N}
function slice(arr::AbstractArray{T, N}, dim::Integer, index::Union{Integer, UnitRange}) where {T, N}
inds = slice_indices(axes(arr), dim, index)
return @view arr[inds...]
end
Expand All @@ -33,6 +33,9 @@ end
# Arguments:
`a` should be the axes obtained by `axes(arr)` of an array.
`dim` is the dimension to be selected and `index` the index of it.
`index` can be an integer or a range,but the dimensions is always kepts
# Returns: a tuple of ranges used for slicing
Examples
```jldoctest
Expand All @@ -47,6 +50,12 @@ function slice_indices(a::NTuple{N, T}, dim::Integer, index::Integer) where {T,
return inds
end

function slice_indices(a::NTuple{N, T}, dim::Integer, index::UnitRange) where {T, N}
inds = ntuple(i -> i == dim ? (a[i][index]) : (first(a[i]):last(a[i])),
Val(N))
return inds
end

"""
expand_dims(x, ::Val{N})
Expand Down
5 changes: 5 additions & 0 deletions test/selection_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
x = randn((5,2,3,4))
y = NDTools.slice(x, 1, 4)
@test x[4:4, :, :, :] == y
y = NDTools.slice(x, 1, 2:3)
@test x[2:3, :, :, :] == y
y = NDTools.slice(x, 4, 4:4)
@test x[:, :, :, 4:4] == y


x = randn((5))
y = NDTools.slice(x, 1, 5)
Expand Down

0 comments on commit 61cbb4a

Please sign in to comment.