Skip to content
This repository has been archived by the owner on Dec 18, 2021. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Roger-luo committed Apr 23, 2019
2 parents bb426cc + 3ef0624 commit b540df0
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 38 deletions.
71 changes: 34 additions & 37 deletions src/abstract_block.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ setiparams!(x::AbstractBlock, it::Symbol) = setiparams!(x, render_params(x, it))
Set parameters of `block` to the value in `collection` mapped by `f`.
"""
setiparams!(f::Function, x::AbstractBlock, it) = setiparams!(x, map(x->f(x...), zip(getiparams(x), it)))
setiparams!(f::Nothing, x::AbstractBlock, it) = setiparams!(x, it)

"""
setiparams(f, block, symbol)
Expand All @@ -152,7 +153,7 @@ setiparams!(f::Function, x::AbstractBlock, it::Symbol) = setiparams!(f, x, rende
Returns all the parameters contained in block tree with given root `block`.
"""
@interface parameters(x::AbstractBlock) = parameters!(allparams_eltype(x)[], x)
@interface parameters(x::AbstractBlock) = parameters!(parameters_eltype(x)[], x)

"""
parameters!(out, block)
Expand All @@ -178,67 +179,63 @@ Return number of parameters in `block`. See also [`nparameters`](@ref).
end

"""
params_eltype(block)
iparams_eltype(block)
Return the element type of [`getiparams`](@ref).
"""
@interface params_eltype(x::AbstractBlock) = eltype(getiparams(x))
@interface iparams_eltype(x::AbstractBlock) = eltype(getiparams(x))

"""
allparams_eltype(x)
parameters_eltype(x)
Return the element type of [`parameters`](@ref).
"""
@interface function allparams_eltype(x::AbstractBlock)
T = params_eltype(x)
@interface function parameters_eltype(x::AbstractBlock)
T = iparams_eltype(x)
for each in subblocks(x)
T = promote_type(T, allparams_eltype(each))
T = promote_type(T, parameters_eltype(each))
end
return T
end

mutable struct Dispatcher{VT}
params::VT
loc::Int
end

Dispatcher(params) = Dispatcher(params, 0)

function consume!(d::Dispatcher, n::Int)
d.loc += n
d.params[d.loc-n+1:d.loc]
end

function consume!(d::Dispatcher{<:Symbol}, n::Int)
d.loc += n
d.params
end

"""
dispatch!(x::AbstractBlock, collection)
Dispatch parameters in collection to block tree `x`.
"""
@interface function dispatch!(f::Function, x::AbstractBlock, it)
@assert length(it) == nparameters(x) "expect $(nparameters(x)) parameters, got $(length(it))"
setiparams!(f, x, Iterators.take(it, nparameters(x)))
it = Iterators.drop(it, nparameters(x))
for each in subblocks(x)
dispatch!(f, each, Iterators.take(it, niparams(each)))
it = Iterators.drop(it, niparams(each))
end
return x
end

function dispatch!(f::Function, x::AbstractBlock, it::Symbol)
setiparams!(f, x, it)
@interface function dispatch!(f::Union{Function, Nothing}, x::AbstractBlock, it::Dispatcher)
setiparams!(f, x, consume!(it, niparams(x)))
for each in subblocks(x)
dispatch!(f, each, it)
end
return x
end

@interface function dispatch!(x::AbstractBlock, it)
@assert length(it) == nparameters(x) "expect $(nparameters(x)) parameters, got $(length(it))"
setiparams!(x, Iterators.take(it, niparams(x)))
it = Iterators.drop(it, niparams(x))
for each in subblocks(x)
dispatch!(each, Iterators.take(it, niparams(each)))
it = Iterators.drop(it, niparams(each))
end
return x
@interface function dispatch!(f::Union{Function, Nothing}, x::AbstractBlock, it)
dp = Dispatcher(it)
res = dispatch!(f, x, dp)
@assert (it isa Symbol || length(it) == dp.loc) "expect $(nparameters(x)) parameters, got $(length(it))"
res
end

function dispatch!(x::AbstractBlock, it::Symbol)
setiparams!(x, it)
for each in subblocks(x)
dispatch!(each, it)
end
return x
end
dispatch!(x::AbstractBlock, it) = dispatch!(nothing, x, it)

"""
popdispatch!(f, block, list)
Expand Down Expand Up @@ -271,7 +268,7 @@ end
render_params(r::AbstractBlock, params) = params
render_params(r::AbstractBlock, params::Symbol) = render_params(r, Val(params))
render_params(r::AbstractBlock, ::Val{:random}) = (rand() for i=1:niparams(r))
render_params(r::AbstractBlock, ::Val{:zero}) = (zero(params_eltype(r)) for i in 1:niparams(r))
render_params(r::AbstractBlock, ::Val{:zero}) = (zero(iparams_eltype(r)) for i in 1:niparams(r))

"""
HasParameters{X} <: SimpleTraits.Trait
Expand Down
3 changes: 2 additions & 1 deletion src/composite/tag/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ end
CacheServers.iscached(c::CachedBlock) = iscached(c.server, c.content)
iscacheable(c::CachedBlock) = iscacheable(c.server, c.content)
chsubblocks(cb::CachedBlock, blk::AbstractBlock) = CachedBlock(cb.server, blk, cb.level)
occupied_locs(x::CachedBlock) = occupied_locs(parent(x))
occupied_locs(x::CachedBlock) = occupied_locs(content(x))
PreserveStyle(::CachedBlock) = PreserveAll()

function update_cache(c::CachedBlock)
Expand Down Expand Up @@ -94,6 +94,7 @@ function apply!(r::AbstractRegister, c::CachedBlock, signal)
end
return r
end

apply!(r::ArrayReg, c::CachedBlock) = (r.state .= mat(c) * r.state; r)

Base.similar(c::CachedBlock, level::Int) = CachedBlock(c.server, c.content, level)
Expand Down
4 changes: 4 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
# deprecations
@deprecate iparameters(args...) getiparams(args...)
@deprecate setiparameters!(args...) setiparams!(args...)
@deprecate niparameters(args...) niparams(args...)
@deprecate parameter_type(args...) parameters_eltype(args...)

0 comments on commit b540df0

Please sign in to comment.