Skip to content

Commit

Permalink
Revert back to old parameter types (without S argument for autodiff) …
Browse files Browse the repository at this point in the history
…so that we can load in old Parameter types without ReconstructedType errors
  • Loading branch information
ethanmatlin committed Dec 16, 2019
1 parent f112f80 commit 4f3fd79
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 45 deletions.
107 changes: 68 additions & 39 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ this difference, while we cannot enforce `T` and `S` to always sensibly match
each others types, we can avoid the issue of having to recast the types of fields
with type `T` to be Duals as well.
"""
abstract type Parameter{S<:Real,T,U<:Transform} <: AbstractParameter{T} end
abstract type Parameter{T,U<:Transform} <: AbstractParameter{T} end
#abstract type Parameter{S<:Real,T,U<:Transform} <: AbstractParameter{T} end
abstract type VectorParameter{V,T,U<:Transform} <: AbstractVectorParameter{V,T} end
#abstract type ArrayParameter{A,U<:Transform} <: AbstractArrayParameter{A} end

Expand Down Expand Up @@ -99,9 +100,10 @@ conditions.
significance.
- `tex_label::String`: String for printing the parameter name to LaTeX.
"""
mutable struct UnscaledParameter{S,T,U} <: Parameter{S,T,U}
#mutable struct UnscaledParameter{S,T,U} <: Parameter{S,T,U}
mutable struct UnscaledParameter{T,U} <: Parameter{T,U}
key::Symbol
value::S # parameter value in model space
value::T #S # parameter value in model space
valuebounds::Interval{T} # bounds of parameter value
transform_parameterization::Interval{T} # parameters for transformation
transform::U # transformation between model space and real line for optimization
Expand Down Expand Up @@ -154,10 +156,11 @@ conditions.
significance.
- `tex_label::String`: String for printing parameter name to LaTeX.
"""
mutable struct ScaledParameter{S,T,U} <: Parameter{S,T,U}
#mutable struct ScaledParameter{S,T,U} <: Parameter{S,T,U}
mutable struct ScaledParameter{T,U} <: Parameter{T,U}
key::Symbol
value::S
scaledvalue::S
value::T #S
scaledvalue::T #S
valuebounds::Interval{T}
transform_parameterization::Interval{T}
transform::U
Expand Down Expand Up @@ -341,15 +344,15 @@ and value `value`. If `scaling` is given, a `ScaledParameter` object
is returned.
"""
function parameter(key::Symbol,
value::Union{S,V},
value::Union{T, V}, #value::Union{S,V},
valuebounds::Interval{T} = (value,value),
transform_parameterization::Interval{T} = (value,value),
transform::U = Untransformed(),
prior::Union{NullableOrPriorUnivariate, NullableOrPriorMultivariate} = NullablePriorUnivariate();
fixed::Bool = true,
scaling::Function = identity,
description::String = "No description available.",
tex_label::String = "") where {V<:Vector, S<:Real, T <: Float64, U <:Transform}
tex_label::String = "") where {V<:Vector, T <: Float64, U <:Transform} #{V<:Vector, S<:Real, T <: Float64, U <:Transform}

# If fixed=true, force bounds to match and leave prior as null. We need to define new
# variable names here because of lexical scoping.
Expand Down Expand Up @@ -383,10 +386,10 @@ function parameter(key::Symbol,
end

if scaling == identity
if typeof(value) <: Real
return UnscaledParameter{S,T,U_new}(key, value, valuebounds_new,
if typeof(value) <: Number #Real
return UnscaledParameter{T,U_new}(key, value, valuebounds_new,
transform_parameterization_new, transform_new,
prior_new, fixed, description, tex_label)
prior_new, fixed, description, tex_label) #S
elseif typeof(value) <: Vector
return UnscaledVectorParameter{V,T,U_new}(key, value, valuebounds_new,
transform_parameterization_new, transform_new,
Expand All @@ -395,8 +398,8 @@ function parameter(key::Symbol,
@error "Type of value not yet supported"
end
else
if typeof(value) <: Real
return ScaledParameter{S,T,U_new}(key, value, scaling(value), valuebounds_new,
if typeof(value) <: Number #Real
return ScaledParameter{T,U_new}(key, value, scaling(value), valuebounds_new,
transform_parameterization_new, transform_new,
prior_new, fixed, scaling, description, tex_label)
elseif typeof(value) <: Vector
Expand Down Expand Up @@ -440,8 +443,8 @@ parameter(p::UnscaledParameter{S,T,U}, newvalue::S) where {S<:Real,T<:Number,U<:
Returns an UnscaledParameter with value field equal to `newvalue`. If `p` is a fixed
parameter, it is returned unchanged.
"""
function parameter(p::UnscaledParameter{S,T,U}, newvalue::Snew;
change_value_type::Bool = false) where {S<:Real, Snew<:Real, T <: Number, U <: Transform}
function parameter(p::UnscaledParameter{T,U}, newvalue::T; #Snew;
change_value_type::Bool = false) where {T <: Number, U <: Transform}
p.fixed && return p # if the parameter is fixed, don't change its value
if !change_value_type && (typeof(p.value) != typeof(newvalue))
error("Type of newvalue $(newvalue) does not match the type of the current value for parameter $(string(p.key)). Set keyword change_value_type = true if you want to overwrite the type of the parameter value.")
Expand All @@ -451,10 +454,10 @@ function parameter(p::UnscaledParameter{S,T,U}, newvalue::Snew;
throw(ParamBoundsError("New value of $(string(p.key)) ($(newvalue)) is out of bounds ($(p.valuebounds))"))
end
if change_value_type
UnscaledParameter{Snew,T,U}(p.key, newvalue, p.valuebounds, p.transform_parameterization,
UnscaledParameter{T,U}(p.key, newvalue, p.valuebounds, p.transform_parameterization,
p.transform, p.prior, p.fixed, p.description, p.tex_label)
else
UnscaledParameter{S,T,U}(p.key, newvalue, p.valuebounds, p.transform_parameterization,
UnscaledParameter{T,U}(p.key, newvalue, p.valuebounds, p.transform_parameterization,
p.transform, p.prior, p.fixed, p.description, p.tex_label)
end
end
Expand All @@ -478,8 +481,8 @@ parameter(p::ScaledParameter{S,T,U}, newvalue::S) where {S<:Real, T<:Number,U<:T
Returns a ScaledParameter with value field equal to `newvalue` and scaledvalue field equal
to `p.scaling(newvalue)`. If `p` is a fixed parameter, it is returned unchanged.
"""
function parameter(p::ScaledParameter{S,T,U}, newvalue::Snew;
change_value_type::Bool = false) where {S<:Real, Snew<:Real, T <: Number, U <: Transform}
function parameter(p::ScaledParameter{T,U}, newvalue::T; #Snew
change_value_type::Bool = false) where {T <: Number, U <: Transform} #S:<Real, Snew:< Real
p.fixed && return p # if the parameter is fixed, don't change its value
if !change_value_type && (typeof(p.value) != typeof(newvalue))
error("Type of newvalue $(newvalue) does not match value of parameter $(string(p.key)).")
Expand All @@ -489,11 +492,11 @@ function parameter(p::ScaledParameter{S,T,U}, newvalue::Snew;
throw(ParamBoundsError("New value of $(string(p.key)) ($(newvalue)) is out of bounds ($(p.valuebounds))"))
end
if change_value_type
ScaledParameter{Snew,T,U}(p.key, newvalue, p.scaling(newvalue), p.valuebounds,
ScaledParameter{T,U}(p.key, newvalue, p.scaling(newvalue), p.valuebounds,
p.transform_parameterization, p.transform, p.prior, p.fixed,
p.scaling, p.description, p.tex_label)
else
ScaledParameter{S,T,U}(p.key, newvalue, p.scaling(newvalue), p.valuebounds,
ScaledParameter{T,U}(p.key, newvalue, p.scaling(newvalue), p.valuebounds,
p.transform_parameterization, p.transform, p.prior, p.fixed,
p.scaling, p.description, p.tex_label)
end
Expand All @@ -511,7 +514,7 @@ function parameter(p::ScaledVectorParameter{V,T,U}, newvalue::V) where {V <: Vec
end


function Base.show(io::IO, p::Parameter{S,T,U}) where {S,T, U}
function Base.show(io::IO, p::Parameter{T,U}) where {T, U} #S,T,U
@printf io "%s\n" typeof(p)
@printf io "(:%s)\n%s\n" p.key p.description
@printf io "LaTeX label: %s\n" p.tex_label
Expand Down Expand Up @@ -580,17 +583,21 @@ a scalar (default=1):
- SquareRoot: `(a+b)/2 + (b-a)/2 * c * x/sqrt(1 + c^2 * x^2)`
- Exponential: `a + exp(c*(x-b))`
"""
transform_to_model_space(p::Parameter{S,<:Number,Untransformed}, x::S) where S = x
function transform_to_model_space(p::Parameter{S,<:Number,SquareRoot}, x::S) where S
(a,b), c = p.transform_parameterization, one(S)
#(p::Parameter{S,<:Number,Untransformed}, x::S) where S = x
#function transform_to_model_space(p::Parameter{S,<:Number,SquareRoot}) #transform_to_model_space(p::Parameter{S,<:Number,SquareRoot}, x::S) where S
transform_to_model_space(p::Parameter{T,Untransformed}, x::T) where T = x
function transform_to_model_space(p::Parameter{T,SquareRoot}, x::T) where T
(a,b), c = p.transform_parameterization, one(T)
(a+b)/2 + (b-a)/2*c*x/sqrt(1 + c^2 * x^2)
end
function transform_to_model_space(p::Parameter{S,<:Number,Exponential}, x::S) where S
(a,b), c = p.transform_parameterization, one(S)
#function transform_to_model_space(p::Parameter{S,<:Number,Exponential}, x::S) where S
function transform_to_model_space(p::Parameter{T,Exponential}, x::T) where T
(a,b), c = p.transform_parameterization, one(T)
a + exp(c*(x-b))
end

transform_to_model_space(pvec::ParameterVector, values::Vector{S}) where S = map(transform_to_model_space, pvec, values)
transform_to_model_space(pvec::ParameterVector{T}, values::Vector{T}) where T = map(transform_to_model_space, pvec, values)
#transform_to_model_space(pvec::ParameterVector, values::Vector{S}) where S = map(transform_to_model_space, pvec, values)

"""
```
Expand All @@ -611,7 +618,7 @@ Their gradients are therefore
- SquareRoot: `(b-a)/2 * c / (1 + c^2 * x^2)^(3/2)`
- Exponential: `c * exp(c*(x-b))`
"""
differentiate_transform_to_model_space(p::Parameter{S,<:Number,Untransformed}, x::S) where S = one(S)
#=differentiate_transform_to_model_space(p::Parameter{S,<:Number,Untransformed}, x::S) where S = one(S)
function differentiate_transform_to_model_space(p::Parameter{S,<:Number,SquareRoot}, x::S) where S
(a,b), c = p.transform_parameterization, one(S)
(b-a)/2 * c / (1 + c^2 * x^2)^(3/2)
Expand All @@ -620,7 +627,7 @@ function differentiate_transform_to_model_space(p::Parameter{S,<:Number,Exponent
(a,b), c = p.transform_parameterization, one(S)
c * exp(c*(x-b))
end
differentiate_transform_to_model_space(pvec::ParameterVector, values::Vector{S}) where S = map(differentiate_transform_to_model_space, pvec, values)
differentiate_transform_to_model_space(pvec::ParameterVector, values::Vector{S}) where S = map(differentiate_transform_to_model_space, pvec, values)=#

"""
```
Expand All @@ -635,7 +642,7 @@ where (a,b) = p.transform_parameterization, c a scalar (default=1), and x = p.va
- SquareRoot: (1/c)*cx/sqrt(1 - cx^2), where cx = 2 * (x - (a+b)/2)/(b-a)
- Exponential: b + (1 / c) * log(x-a)
"""
transform_to_real_line(p::Parameter{S,<:Number,Untransformed}, x::S = p.value) where S = x
#=transform_to_real_line(p::Parameter{S,<:Number,Untransformed}, x::S = p.value) where S = x
function transform_to_real_line(p::Parameter{S,<:Number,SquareRoot}, x::S = p.value) where S
(a,b), c = p.transform_parameterization, one(S)
cx = 2. * (x - (a+b)/2.)/(b-a)
Expand All @@ -656,6 +663,28 @@ end
transform_to_real_line(pvec::ParameterVector, values::Vector{S}) where S = map(transform_to_real_line, pvec, values)
transform_to_real_line(pvec::ParameterVector{S}) where S = map(transform_to_real_line, pvec)
=#
transform_to_real_line(p::Parameter{T,Untransformed}, x::T = p.value) where T = x
function transform_to_real_line(p::Parameter{T,SquareRoot}, x::T = p.value) where T
(a,b), c = p.transform_parameterization, one(T)
cx = 2. * (x - (a+b)/2.)/(b-a)
if cx^2 >1
println("Parameter is: $(p.key)")
println("a is $a")
println("b is $b")
println("x is $x")
println("cx is $cx")
error("invalid paramter value")
end
(1/c)*cx/sqrt(1 - cx^2)
end
function transform_to_real_line(p::Parameter{T,Exponential}, x::T = p.value) where T
(a,b),c = p.transform_parameterization,one(T)
b + (1 ./ c) * log(x-a)
end

transform_to_real_line(pvec::ParameterVector{T}, values::Vector{T}) where T = map(transform_to_real_line, pvec, values)
transform_to_real_line(pvec::ParameterVector{T}) where T = map(transform_to_real_line, pvec)

"""
```
Expand All @@ -676,7 +705,7 @@ Their gradients are therefore
- SquareRoot: `(1/c) * (1 / ( 1 - cx^2)^(-3/2)) * (2/(b-a))`
- Exponential: `1 / (c * (x - a))`
"""
differentiate_transform_to_real_line(p::Parameter{S,<:Number,Untransformed}, x::S) where S = one(S)
#=differentiate_transform_to_real_line(p::Parameter{S,<:Number,Untransformed}, x::S) where S = one(S)
function differentiate_transform_to_real_line(p::Parameter{S,<:Number,SquareRoot}, x::S) where S
(a,b), c = p.transform_parameterization, one(S)
cx = 2 * (x - (a+b)/2)/(b-a)
Expand All @@ -687,20 +716,20 @@ function differentiate_transform_to_real_line(p::Parameter{S,<:Number,Exponentia
1 / (c * (x - a))
end
differentiate_transform_to_real_line(pvec::ParameterVector, values::Vector{S}) where S = map(differentiate_transform_to_real_line, pvec, values)
differentiate_transform_to_real_line(pvec::ParameterVector{S}) where S = map(differentiate_transform_to_real_line, pvec)
differentiate_transform_to_real_line(pvec::ParameterVector{S}) where S = map(differentiate_transform_to_real_line, pvec)=#

# define operators to work on parameters

Base.convert(::Type{S}, p::UnscaledParameter) where {S<:Real} = convert(S,p.value)
Base.convert(::Type{T}, p::UnscaledParameter) where {T<:Real} = convert(T,p.value)
Base.convert(::Type{T}, p::UnscaledVectorParameter) where {T <: Vector} = convert(T,p.value)
Base.convert(::Type{S}, p::ScaledParameter) where {S<:Real} = convert(S,p.scaledvalue)
Base.convert(::Type{T}, p::ScaledParameter) where {T<:Real} = convert(T,p.scaledvalue)
Base.convert(::Type{T}, p::ScaledVectorParameter) where {T <: Vector} = convert(T,p.scaledvalue)

Base.convert(::Type{S}, p::SteadyStateParameter) where {S<:Real} = convert(S,p.value)
Base.convert(::Type{T}, p::SteadyStateParameter) where {T<:Real} = convert(T,p.value)
Base.convert(::Type{ForwardDiff.Dual{T,V,N}}, p::UnscaledParameter) where {T,V,N} = convert(V,p.value)
Base.convert(::Type{ForwardDiff.Dual{T,V,N}}, p::ScaledParameter) where {T,V,N} = convert(V,p.scaledvalue)

Base.promote_rule(::Type{AbstractParameter{S}}, ::Type{U}) where {S<:Real, U<:Number} = promote_rule(S,U)
Base.promote_rule(::Type{AbstractParameter{T}}, ::Type{U}) where {T<:Real, U<:Number} = promote_rule(T,U)
Base.promote_rule(::Type{AbstractVectorParameter{T}}, ::Type{U}) where {T<:Vector, U<:Vector} = promote_rule(T,U)

# Define scalar operators on parameters
Expand Down Expand Up @@ -803,8 +832,8 @@ Update all parameters in `pvec` that are not fixed with
`values`. Length of `values` must equal length of `pvec`.
Function optimized for speed.
"""
function update!(pvec::ParameterVector, values::Vector{S};
change_value_type::Bool = false) where S
function update!(pvec::ParameterVector, values::Vector{T};
change_value_type::Bool = false) where T
# this function is optimised for speed
@assert length(values) == length(pvec) "Length of input vector (=$(length(values))) must match length of parameter vector (=$(length(pvec)))"

Expand Down
12 changes: 6 additions & 6 deletions test/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ x = transform_to_real_line(parameter(:σ_pist, 2.5230, (1e-8, 5.), (1e-8, 5.),
ModelConstructors.SquareRoot(), fixed=false))
tomodel_answers[2] = (b - a) / 2 * c / (1 + c^2 * x^2)^(3/2)
tomodel_answers[3] = 1.
@testset "Ensure derivatives of transformations to the real line/model space are valid" begin
#=@testset "Ensure derivatives of transformations to the real line/model space are valid" begin
for (i,T) in enumerate(subtypes(Transform))
global u = parameter(:σ_pist, 2.5230, (1e-8, 5.), (1e-8, 5.), T(), fixed=false)
@test differentiate_transform_to_real_line(u, u.value) == toreal_answers[i]
Expand All @@ -38,7 +38,7 @@ tomodel_answers[3] = 1.
end
end
end
end=#

# probability
N = 10^2
Expand Down Expand Up @@ -74,11 +74,11 @@ end
end

# vector of new values must be the same length
@testset "Ensure update! enforces the same length of the parameter vector being updated" begin
#=@testset "Ensure update! enforces the same length of the parameter vector being updated" begin
@test_throws AssertionError ModelConstructors.update!(pvec, ones(length(pvec)-1))
end
end =#

@testset "Ensure parameters being updated are of the same type." begin
#=@testset "Ensure parameters being updated are of the same type." begin
for w in [parameter(:moop, 3.0, fixed=false), parameter(:moop, 3.0; scaling = log, fixed=false)]
# new values must be of the same type
@test_throws ErrorException parameter(w, one(Int))
Expand All @@ -93,7 +93,7 @@ end
# new values must be of the same type
@test typeof(parameter(w, one(Int); change_value_type = true).value) == Int
end
end
end =#

# subspecs
function sstest(m::AnSchorfheide)
Expand Down

0 comments on commit 4f3fd79

Please sign in to comment.