Skip to content

Commit

Permalink
grab bag of fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaqz committed Oct 30, 2024
1 parent 8f57af0 commit 3274432
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 67 deletions.
17 changes: 16 additions & 1 deletion src/DynamicGrids.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,21 @@ import Stencils: BoundaryCondition, Padding

@deprecate positions Stencils.indices

# Here so it can be used in @generated
# _asiterable
# Return some iterable value from a
# Symbol, Tuple or tuple type
@inline _asiterable(x) = (x,)
@inline _asiterable(x::Symbol) = (x,)
@inline _asiterable(x::Type{<:Tuple}) = x.parameters
@inline _asiterable(x::Tuple) = x
@inline _asiterable(x::AbstractArray) = x

# Unwrap a Val or Val type to its internal value
_unwrap(x) = x
_unwrap(::Val{X}) where X = X
_unwrap(::Type{<:Val{X}}) where X = X

include("interface.jl")
include("flags.jl")
include("rules.jl")
Expand Down Expand Up @@ -138,10 +153,10 @@ include("generated.jl")
include("maprules.jl")
include("boundaries.jl")
include("sparseopt.jl")
include("utils.jl")
include("copyto.jl")
include("life.jl")
include("adapt.jl")
include("utils.jl")
include("show.jl")

function __init__()
Expand Down
1 change: 0 additions & 1 deletion src/atomic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ for (f, op) in ATOMIC_OPS
I1 = add_halo(d, _maybe_complete_indices(d, I))
@boundscheck checkbounds(dest(d), I1...)
@inbounds _setoptindex!(d, x, I1...)
# @show I I1
@inbounds dest(d)[I1...] = ($op)(dest(d)[I1...], x)
end
end
Expand Down
12 changes: 11 additions & 1 deletion src/extent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,13 @@ mutable struct Extent{I<:Union{AbstractArray,NamedTuple},
gridsize = size(init)
end
if (mask !== nothing) && (size(mask) != gridsize)
throw(ArgumentError("`mask` size do not match `init`"))
if mask isa AbstractDimArray && hasdim(mask, Ti())
last(dims(mask)) isa Ti || throw(ArgumentError("Time dimension mus be the last dimension. Use `permutedims` on the mask first."))
m1 = view(mask, Ti(1))
size(m1) == gridsize || _masksize_error(size(m1), gridsize)
else
_masksize_error(size(mask), gridsize)
end
end
new{I,M,A,typeof(padval),R}(init, mask, aux, padval, replicates, tspan)
end
Expand All @@ -96,6 +102,10 @@ Extent(; init, mask=nothing, aux=nothing, padval=_padval(init), replicates=nothi
Extent(init, mask, aux, padval, replicates, tspan)
Extent(init::Union{AbstractArray,NamedTuple}; kw...) = Extent(; init, kw...)

_masksize_error(masksize, gridsize) =
throw(ArgumentError("`mask` size $masksize do not match `init` size $gridsize"))


settspan!(e::Extent, tspan) = e.tspan = tspan

_padval(init::NamedTuple) = map(_padval, init)
Expand Down
10 changes: 5 additions & 5 deletions src/generated.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# _vals2syms => Union{Symbol,Tuple}
# Must be at the top of the file for world age problems
# Get symbols from a Val or Tuple type
@inline _vals2syms(x::Type{<:Tuple}) = map(v -> _vals2syms(v), x.parameters)
@inline _vals2syms(::Type{<:Val{X}}) where X = X

# Low-level generated functions for working with grids

Expand Down Expand Up @@ -232,8 +237,3 @@ end
NamedTuple{$allkeys,typeof(vals)}(vals)
end
end

# _vals2syms => Union{Symbol,Tuple}
# Get symbols from a Val or Tuple type
@inline _vals2syms(x::Type{<:Tuple}) = map(v -> _vals2syms(v), x.parameters)
@inline _vals2syms(::Type{<:Val{X}}) where X = X
10 changes: 2 additions & 8 deletions src/maprules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,7 @@ end
# dimension we hide the extra dimension from rules.
I1 = _strip_replicates(data, I)
# We skip the cell if there is a mask layer
m = mask(data)
if !isnothing(m)
m[I1...] || return nothing
end
ismasked(data, I1...) && return nothing
# We read a value from the grid
readval = _readcell(data, rkeys, I...)
# Update the data object
Expand All @@ -230,10 +227,7 @@ end
end
@inline function cell_kernel!(data::RuleData, ::Val{<:SetRule}, rule, rkeys, wkeys, I...)
I1 = _strip_replicates(data, I)
m = mask(data)
if !isnothing(m)
m[I1...] || return nothing
end
m = ismasked(data, I...) && return nothing
readval = _readcell(data, rkeys, I...)
data1 = ConstructionBase.setproperties(data, (value=readval, indices = I))
# Rules will manually write to grids in `applyrule!`
Expand Down
25 changes: 15 additions & 10 deletions src/parametersources.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,13 @@ end
) where {N1,N2}
if hasdim(A, TimeDim)
last(dims(A)) isa TimeDim || throw(ArgumentError("The time dimensions in aux data must be the last dimension"))
A[ntuple(i -> I[i], Val{N1-1}())..., auxframe(data, key)]
af = auxframe(data)
if !isnothing(af) && haskey(af, _unwrap(key))
A[ntuple(i -> I[i], Val{N1-1}())..., af[_unwrap(key)]]
else
# Catch static situations with no known auxframe
A[ntuple(i -> I[i], Val{N1-1}())..., 1]
end
else
A[ntuple(i -> I[i], Val{N1-1}())...]
end
Expand Down Expand Up @@ -127,25 +133,24 @@ function boundscheck_aux(data::AbstractSimData, A::AbstractDimArray, key::Aux{Ke
end
end

# _calc_auxframe
# Calculate the frame to use in the aux data for this timestep.
# _calc_frame
# Calculate the frame to use in the aux or mask data for this timestep.
# This uses the index of any AbstractDimArray, which must be a
# matching type to the simulation tspan.
# This is called from _updatetime in simulationdata.jl
_calc_auxframe(data::AbstractSimData) = _calc_auxframe(aux(data), data)
function _calc_auxframe(aux::NamedTuple{K}, data::AbstractSimData) where K
map((A, k) -> _calc_auxframe(A, data, k), aux, NamedTuple{K}(K))
function _calc_frame(aux::NamedTuple{K}, data::AbstractSimData) where K
map((A, k) -> _calc_frame(A, data, k), aux, NamedTuple{K}(K))
end
function _calc_auxframe(A::AbstractDimArray, data, key)
function _calc_frame(A::AbstractDimArray, data, key)
hasdim(A, TimeDim) || return nothing
timedim = dims(A, TimeDim)
curtime = currenttime(data)
if !hasselection(timedim, Contains(curtime))
if lookup(timedim) isa Cyclic
if sampling(timedim) isa Points
throw(ArgumentError("$(_no_valid_time(timedim,key, curtime)) Did you mean to use `Intervals` for the time dimension `sampling`? `Contains` on `Points` defaults to `At`, and must be exact."))
throw(ArgumentError("$(_no_valid_time(timedim, key, curtime)) Did you mean to use `Intervals` for the time dimension `sampling`? `Contains` on `Points` defaults to `At`, and must be exact."))
else
throw(ArgumentError("$(_no_valid_time(timedim,key, curtime))"))
throw(ArgumentError("$(_no_valid_time(timedim, key, curtime))"))
end
elseif sampling(timedim) isa Points
throw(ArgumentError("$(_no_valid_time(timedim,key, curtime)) Did you mean to use `Intervals` for the time dimension `sampling`? `Contains` on `Points` defaults to `At`, and must be exact."))
Expand All @@ -155,7 +160,7 @@ function _calc_auxframe(A::AbstractDimArray, data, key)
end
return DimensionalData.selectindices(timedim, Contains(curtime))
end
_calc_auxframe(args...) = nothing
_calc_frame(args...) = nothing

_no_valid_time(timedim, key, curtime) = "Time dimension over $(bounds(timedim)) of aux `$key` has no valid selection for `Contains($curtime)`."

Expand Down
Loading

0 comments on commit 3274432

Please sign in to comment.