diff --git a/src/DynamicGrids.jl b/src/DynamicGrids.jl index 07e857a..cac3f5d 100644 --- a/src/DynamicGrids.jl +++ b/src/DynamicGrids.jl @@ -104,6 +104,21 @@ import Stencils: BoundaryCondition, Padding @deprecate positions Stencils.indices +# Here so it can be used in @generated +# _asiterable +# Return some iterable value from a +# Symbol, Tuple or tuple type +@inline _asiterable(x) = (x,) +@inline _asiterable(x::Symbol) = (x,) +@inline _asiterable(x::Type{<:Tuple}) = x.parameters +@inline _asiterable(x::Tuple) = x +@inline _asiterable(x::AbstractArray) = x + +# Unwrap a Val or Val type to its internal value +_unwrap(x) = x +_unwrap(::Val{X}) where X = X +_unwrap(::Type{<:Val{X}}) where X = X + include("interface.jl") include("flags.jl") include("rules.jl") @@ -138,10 +153,10 @@ include("generated.jl") include("maprules.jl") include("boundaries.jl") include("sparseopt.jl") -include("utils.jl") include("copyto.jl") include("life.jl") include("adapt.jl") +include("utils.jl") include("show.jl") function __init__() diff --git a/src/atomic.jl b/src/atomic.jl index 82a9339..e0ec924 100644 --- a/src/atomic.jl +++ b/src/atomic.jl @@ -20,7 +20,6 @@ for (f, op) in ATOMIC_OPS I1 = add_halo(d, _maybe_complete_indices(d, I)) @boundscheck checkbounds(dest(d), I1...) @inbounds _setoptindex!(d, x, I1...) - # @show I I1 @inbounds dest(d)[I1...] = ($op)(dest(d)[I1...], x) end end diff --git a/src/extent.jl b/src/extent.jl index 57d3f0b..0eb256e 100644 --- a/src/extent.jl +++ b/src/extent.jl @@ -87,7 +87,13 @@ mutable struct Extent{I<:Union{AbstractArray,NamedTuple}, gridsize = size(init) end if (mask !== nothing) && (size(mask) != gridsize) - throw(ArgumentError("`mask` size do not match `init`")) + if mask isa AbstractDimArray && hasdim(mask, Ti()) + last(dims(mask)) isa Ti || throw(ArgumentError("Time dimension mus be the last dimension. Use `permutedims` on the mask first.")) + m1 = view(mask, Ti(1)) + size(m1) == gridsize || _masksize_error(size(m1), gridsize) + else + _masksize_error(size(mask), gridsize) + end end new{I,M,A,typeof(padval),R}(init, mask, aux, padval, replicates, tspan) end @@ -96,6 +102,10 @@ Extent(; init, mask=nothing, aux=nothing, padval=_padval(init), replicates=nothi Extent(init, mask, aux, padval, replicates, tspan) Extent(init::Union{AbstractArray,NamedTuple}; kw...) = Extent(; init, kw...) +_masksize_error(masksize, gridsize) = + throw(ArgumentError("`mask` size $masksize do not match `init` size $gridsize")) + + settspan!(e::Extent, tspan) = e.tspan = tspan _padval(init::NamedTuple) = map(_padval, init) diff --git a/src/generated.jl b/src/generated.jl index e9d254c..bc88e86 100644 --- a/src/generated.jl +++ b/src/generated.jl @@ -1,3 +1,8 @@ +# _vals2syms => Union{Symbol,Tuple} +# Must be at the top of the file for world age problems +# Get symbols from a Val or Tuple type +@inline _vals2syms(x::Type{<:Tuple}) = map(v -> _vals2syms(v), x.parameters) +@inline _vals2syms(::Type{<:Val{X}}) where X = X # Low-level generated functions for working with grids @@ -232,8 +237,3 @@ end NamedTuple{$allkeys,typeof(vals)}(vals) end end - -# _vals2syms => Union{Symbol,Tuple} -# Get symbols from a Val or Tuple type -@inline _vals2syms(x::Type{<:Tuple}) = map(v -> _vals2syms(v), x.parameters) -@inline _vals2syms(::Type{<:Val{X}}) where X = X diff --git a/src/maprules.jl b/src/maprules.jl index 8aa92bd..a12c2f7 100644 --- a/src/maprules.jl +++ b/src/maprules.jl @@ -213,10 +213,7 @@ end # dimension we hide the extra dimension from rules. I1 = _strip_replicates(data, I) # We skip the cell if there is a mask layer - m = mask(data) - if !isnothing(m) - m[I1...] || return nothing - end + ismasked(data, I1...) && return nothing # We read a value from the grid readval = _readcell(data, rkeys, I...) # Update the data object @@ -230,10 +227,7 @@ end end @inline function cell_kernel!(data::RuleData, ::Val{<:SetRule}, rule, rkeys, wkeys, I...) I1 = _strip_replicates(data, I) - m = mask(data) - if !isnothing(m) - m[I1...] || return nothing - end + m = ismasked(data, I...) && return nothing readval = _readcell(data, rkeys, I...) data1 = ConstructionBase.setproperties(data, (value=readval, indices = I)) # Rules will manually write to grids in `applyrule!` diff --git a/src/parametersources.jl b/src/parametersources.jl index 9b7484c..9964e1c 100644 --- a/src/parametersources.jl +++ b/src/parametersources.jl @@ -94,7 +94,13 @@ end ) where {N1,N2} if hasdim(A, TimeDim) last(dims(A)) isa TimeDim || throw(ArgumentError("The time dimensions in aux data must be the last dimension")) - A[ntuple(i -> I[i], Val{N1-1}())..., auxframe(data, key)] + af = auxframe(data) + if !isnothing(af) && haskey(af, _unwrap(key)) + A[ntuple(i -> I[i], Val{N1-1}())..., af[_unwrap(key)]] + else + # Catch static situations with no known auxframe + A[ntuple(i -> I[i], Val{N1-1}())..., 1] + end else A[ntuple(i -> I[i], Val{N1-1}())...] end @@ -127,25 +133,24 @@ function boundscheck_aux(data::AbstractSimData, A::AbstractDimArray, key::Aux{Ke end end -# _calc_auxframe -# Calculate the frame to use in the aux data for this timestep. +# _calc_frame +# Calculate the frame to use in the aux or mask data for this timestep. # This uses the index of any AbstractDimArray, which must be a # matching type to the simulation tspan. # This is called from _updatetime in simulationdata.jl -_calc_auxframe(data::AbstractSimData) = _calc_auxframe(aux(data), data) -function _calc_auxframe(aux::NamedTuple{K}, data::AbstractSimData) where K - map((A, k) -> _calc_auxframe(A, data, k), aux, NamedTuple{K}(K)) +function _calc_frame(aux::NamedTuple{K}, data::AbstractSimData) where K + map((A, k) -> _calc_frame(A, data, k), aux, NamedTuple{K}(K)) end -function _calc_auxframe(A::AbstractDimArray, data, key) +function _calc_frame(A::AbstractDimArray, data, key) hasdim(A, TimeDim) || return nothing timedim = dims(A, TimeDim) curtime = currenttime(data) if !hasselection(timedim, Contains(curtime)) if lookup(timedim) isa Cyclic if sampling(timedim) isa Points - throw(ArgumentError("$(_no_valid_time(timedim,key, curtime)) Did you mean to use `Intervals` for the time dimension `sampling`? `Contains` on `Points` defaults to `At`, and must be exact.")) + throw(ArgumentError("$(_no_valid_time(timedim, key, curtime)) Did you mean to use `Intervals` for the time dimension `sampling`? `Contains` on `Points` defaults to `At`, and must be exact.")) else - throw(ArgumentError("$(_no_valid_time(timedim,key, curtime))")) + throw(ArgumentError("$(_no_valid_time(timedim, key, curtime))")) end elseif sampling(timedim) isa Points throw(ArgumentError("$(_no_valid_time(timedim,key, curtime)) Did you mean to use `Intervals` for the time dimension `sampling`? `Contains` on `Points` defaults to `At`, and must be exact.")) @@ -155,7 +160,7 @@ function _calc_auxframe(A::AbstractDimArray, data, key) end return DimensionalData.selectindices(timedim, Contains(curtime)) end -_calc_auxframe(args...) = nothing +_calc_frame(args...) = nothing _no_valid_time(timedim, key, curtime) = "Time dimension over $(bounds(timedim)) of aux `$key` has no valid selection for `Contains($curtime)`." diff --git a/src/simulationdata.jl b/src/simulationdata.jl index a2bf3e1..6f06b94 100644 --- a/src/simulationdata.jl +++ b/src/simulationdata.jl @@ -2,8 +2,8 @@ """ AbstractSimData -Supertype for simulation data objects. Thes hold [`GridData`](@ref), -[`SimSettings`](@ref) and other objects needed to run the simulation, +Supertype for simulation data objects. Thes hold [`GridData`](@ref), +[`SimSettings`](@ref) and other objects needed to run the simulation, and potentially required from within rules. An `AbstractSimData` object is accessable in [`applyrule`](@ref) as the first parameter. @@ -19,7 +19,7 @@ funciton applyrule(data::AbstractSimData, rule::SomeRule{Tuple{A,B}},W}, (a, b), end ``` -In single-grid simulations `AbstractSimData` objects can be indexed directly as +In single-grid simulations `AbstractSimData` objects can be indexed directly as if they are a `Matrix`. ## Methods @@ -34,7 +34,7 @@ if they are a `Matrix`. - `padding(data)` : returns the value to use as grid border padding. These are also available, but you probably shouldn't use them and their behaviour -is not guaranteed in furture versions. Using them will also mean a rule is useful +is not guaranteed in furture versions. Using them will also mean a rule is useful only in specific contexts, which is discouraged. - `settings(data)`: get the simulaitons [`SimSettings`](@ref) object. @@ -53,6 +53,8 @@ extent(d::AbstractSimData) = d.extent frames(d::AbstractSimData) = d.frames grids(d::AbstractSimData) = d.grids auxframe(d::AbstractSimData) = d.auxframe +auxframe(d::AbstractSimData, key) = auxframe(d)[_unwrap(key)] +maskframe(d::AbstractSimData) = d.maskframe currentframe(d::AbstractSimData) = d.currentframe # Forwarded to the Extent object @@ -61,11 +63,25 @@ padval(d::AbstractSimData) = padval(extent(d)) init(d::AbstractSimData) = init(extent(d)) mask(d::AbstractSimData) = mask(extent(d)) aux(d::AbstractSimData, args...) = aux(extent(d), args...) -auxframe(d::AbstractSimData, key) = auxframe(d)[_unwrap(key)] tspan(d::AbstractSimData) = tspan(extent(d)) timestep(d::AbstractSimData) = step(tspan(d)) radius(d::AbstractSimData) = max(map(radius, grids(d))...) + +Base.@propagate_inbounds ismasked(sd::AbstractSimData, I...) = + ismasked(sd, mask(sd), I...) +Base.@propagate_inbounds ismasked(::AbstractSimData, ::Nothing, I...) = + false +Base.@propagate_inbounds ismasked(::AbstractSimData, mask::AbstractArray, I...) = + !mask[I...] +Base.@propagate_inbounds function ismasked(sd::AbstractSimData, mask::AbstractDimArray, I...) + if hasdim(mask, Ti()) + mask[I..., maskframe(sd)] + else + mask[I...] + end +end + # Calculated: # Get the current time for this frame currenttime(d::AbstractSimData) = tspan(d)[currentframe(d)] @@ -88,9 +104,11 @@ Base.size(d::AbstractSimData{S}) where S = Tuple(StaticArrays.Size(S)) @propagate_inbounds Base.getindex(d::AbstractSimData, I...) = getindex(first(grids(d)), I...) # Uptate timestamp -function _updatetime(simdata::AbstractSimData, f::Integer) +function _updatetime(simdata::AbstractSimData, f::Integer) @set! simdata.currentframe = f - @set simdata.auxframe = _calc_auxframe(simdata) + @set! simdata.auxframe = _calc_frame(aux(simdata)) + @set! simdata.maskframe = _calc_frame(mask(simdata)) + simdata end """ @@ -106,18 +124,19 @@ Additional methods not found in [`AbstractSimData`](@ref): - `rules(d::SimData)` : get the simulation rules. - `ruleset(d::SimData)` : get the simulation [`AbstractRuleset`](@ref). """ -struct SimData{S<:Tuple,N,G<:NamedTuple{<:Any,<:Tuple{<:GridData,Vararg{GridData}}},E,RS,F,CF,AF} <: AbstractSimData{S,N,G} +struct SimData{S<:Tuple,N,G<:NamedTuple{<:Any,<:Tuple{<:GridData,Vararg{GridData}}},E,RS,F,CF,AF,MF} <: AbstractSimData{S,N,G} grids::G extent::E ruleset::RS frames::F currentframe::CF auxframe::AF + maskframe::MF end function SimData{S,N}( - grids::G, extent::E, ruleset::RS, frames::F, currentframe::CF, auxframe::AF -) where {S,N,G,E,RS,F,CF,AF} - SimData{S,N,G,E,RS,F,CF,AF}(grids, extent, ruleset, frames, currentframe, auxframe) + grids::G, extent::E, ruleset::RS, frames::F, currentframe::CF, auxframe::AF, maskframe::MF +) where {S,N,G,E,RS,F,CF,AF,MF} + SimData{S,N,G,E,RS,F,CF,AF,MF}(grids, extent, ruleset, frames, currentframe, auxframe, maskframe) end SimData(o, ruleset::AbstractRuleset) = SimData(o, extent(o), ruleset) SimData(o, r1::Rule, rs::Rule...) = SimData(o, extent(o), Ruleset(r1, rs...)) @@ -130,24 +149,24 @@ end function SimData( simdata::SimData, output, extent::AbstractExtent, ruleset::AbstractRuleset ) - (replicates(simdata) == replicates(output) == replicates(extent)) || + (replicates(simdata) == replicates(output) == replicates(extent)) || throw(ArgumentError("`simdata` must have same numver of replicates as `output`")) @assert simdata.extent == StaticExtent(extent) @set! simdata.ruleset = StaticRuleset(ruleset) - if hasdelay(rules(ruleset)) + if hasdelay(rules(ruleset)) isstored(output) || _not_stored_delay_error() - @set! simdata.frames = frames(output) + @set! simdata.frames = frames(output) end return simdata end function SimData(o, extent::AbstractExtent, ruleset::AbstractRuleset) - frames_ = if hasdelay(rules(ruleset)) + frames_ = if hasdelay(rules(ruleset)) isstored(o) || _notstorederror() - frames(o) + frames(o) else - nothing + nothing end return SimData(extent, ruleset, frames_) end @@ -157,7 +176,7 @@ SimData(extent::AbstractExtent, rs::Tuple{<:Rule,Vararg}) = SimData(extent, Rule # Convert grids in extent to NamedTuple function SimData(extent::AbstractExtent, ruleset::AbstractRuleset, frames=nothing) nt_extent = _asnamedtuple(extent) - SimData(nt_extent, ruleset, frames) + SimData(nt_extent, ruleset, frames) end function SimData(extent::AbstractExtent{<:NamedTuple}, ruleset::AbstractRuleset, frames=nothing) # Calculate the stencil array for each grid @@ -168,13 +187,14 @@ end function SimData( grids::G, extent::AbstractExtent, ruleset::AbstractRuleset, frames ) where {G<:Union{<:NamedTuple{<:Any,<:Tuple{<:GridData,Vararg}},<:GridData}} - currentframe = 1; auxframe = nothing + currentframe = 1 + auxframe = maskframe = nothing S = Tuple{size(extent)...} N = ndims(extent) # SimData is isbits-only, so use Static versions s_extent = StaticExtent(extent) s_ruleset = StaticRuleset(ruleset) - SimData{S,N}(grids, s_extent, s_ruleset, frames, currentframe, auxframe) + SimData{S,N}(grids, s_extent, s_ruleset, frames, currentframe, auxframe, maskframe) end # Build the grids for the simulation from the extent, ruleset, init and padding @@ -190,7 +210,7 @@ function _buildgrids(extent, ruleset, s::Val, radii::NamedTuple) end end function _buildgrids(extent, ruleset, ::Val{S}, ::Val{R}, init, padval) where {S,R} - stencil = Window{R}() + stencil = Window{R}() padding = Halo{:out}() # We always pad out in DynamicGrids - it should pay back for multiple time steps bc = _update_padval(boundary(ruleset), padval) data = _replicate_init(init, replicates(extent)) @@ -225,42 +245,44 @@ replicates(d::SimData) = replicates(extent(d)) RuleData(extent::AbstractExtent, settings::SimSettings) -[`AbstractSimData`](@ref) object that is passed to rules. +[`AbstractSimData`](@ref) object that is passed to rules. Basically a trimmed-down version of [`SimData`](@ref). The simplified object actually passed to rules with the current design. Passing a smaller object than `SimData` to rules leads to faster GPU compilation. """ -struct RuleData{S<:Tuple,N,G<:NamedTuple,E,Se,F,CF,AF,R,V,I} <: AbstractSimData{S,N,G} +struct RuleData{S<:Tuple,N,G<:NamedTuple,E,Se,F,CF,AF,MF,R,V,I} <: AbstractSimData{S,N,G} grids::G extent::E settings::Se frames::F currentframe::CF auxframe::AF + maskframe::MF replicates::R value::V indices::I end function RuleData{S,N}( - grids::G, extent::E, settings::Se, frames::F, currentframe::CF, auxframe::AF, replicates::Re, value::V, indices::I -) where {S,N,G,E,Se,F,CF,AF,Re,V,I} - RuleData{S,N,G,E,Se,F,CF,AF,Re,V,I}(grids, extent, settings, frames, currentframe, auxframe, replicates, value, indices) + grids::G, extent::E, settings::Se, frames::F, currentframe::CF, auxframe::AF, maskframe::MF, replicates::Re, value::V, indices::I +) where {S,N,G,E,Se,F,CF,AF,MF,Re,V,I} + RuleData{S,N,G,E,Se,F,CF,AF,MF,Re,V,I}(grids, extent, settings, frames, currentframe, auxframe, maskframe, replicates, value, indices) end function RuleData(d::AbstractSimData{S,N}; - grids=grids(d), - extent=extent(d), + grids=grids(d), + extent=extent(d), settings=settings(d), - frames=frames(d), - currentframe=currentframe(d), - auxframe=auxframe(d), + frames=frames(d), + currentframe=currentframe(d), + auxframe=auxframe(d), + maskframe=maskframe(d), replicates=replicates(d), - value=nothing, + value=nothing, indices=nothing, ) where {S,N} RuleData{S,N}( - grids, extent, settings, frames, currentframe, auxframe, replicates, value, indices + grids, extent, settings, frames, currentframe, auxframe, maskframe, replicates, value, indices ) end # Thin down the aux data for just this rule. @@ -283,7 +305,7 @@ function RuleData(d::AbstractSimData, rule::Rule) return RuleData(d; extent=rule_extent) end -function Base.getindex(d::RuleData, key::Symbol) +function Base.getindex(d::RuleData, key::Symbol) grid = getindex(grids(d), key) @set grid.indices = d.indices end diff --git a/src/utils.jl b/src/utils.jl index 63439ce..efff56d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -47,7 +47,7 @@ end if rule isa SetGridRule return true else - _demo(m, RuleData(simdata), rule, Val{ruletype(rule)}()) + _demo(m, RuleData(simdata, rule), rule, Val{ruletype(rule)}()) end end |> all end @@ -143,12 +143,6 @@ end @inline _samenamedtuple(init::NamedTuple{K}, x::Tuple) where K = NamedTuple{K}(x) @inline _samenamedtuple(init::NamedTuple, x) = map(_ -> x, init) - -# Unwrap a Val or Val type to its internal value -_unwrap(x) = x -_unwrap(::Val{X}) where X = X -_unwrap(::Type{<:Val{X}}) where X = X - @inline _firstgrid(simdata, ::Val{K}) where K = simdata[K] @inline _firstgrid(simdata, ::Tuple{Val{K},Vararg}) where K = simdata[K]