Skip to content

Commit

Permalink
OptimalControlSolution from CTBase
Browse files Browse the repository at this point in the history
  • Loading branch information
Olivier Cots committed Mar 1, 2023
1 parent 50d4ad9 commit f6f56fb
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 158 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "CTDirectShooting"
uuid = "e0e6c04b-5022-4cd2-bea2-4a09fff39444"
authors = ["Olivier Cots <olivier.cots@enseeiht.fr>"]
version = "0.1.2"
version = "0.1.3"

[deps]
CTBase = "54762871-cc72-4466-b8e8-f6c8b58076cd"
Expand All @@ -12,7 +12,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[compat]
CTBase = "0.2"
CTBase = "0.4"
CTOptimization = "0.1"
HamiltonianFlows = "1.0"
LinearAlgebra = "1.8"
Expand Down
15 changes: 3 additions & 12 deletions src/CTDirectShooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,17 @@ module CTDirectShooting
# using
#
using CTBase
import CTBase: DirectShootingSolution

#
using LinearAlgebra # for the norm for instance
using Printf # to print iterations results for instance

# todo: use RecipesBase instead of plot
using Plots
import Plots: plot, plot! # import instead of using to overload the plot and plot! functions

# flows
using HamiltonianFlows

# nlp solvers
using CTOptimization
import CTOptimization: solve #todo: remove this
#import CTOptimization: solve #todo: remove this

# Other declarations
const nlp_constraints = CTBase.nlp_constraints
Expand All @@ -31,21 +26,17 @@ const __iterations = CTBase.__iterations
const __absoluteTolerance = CTBase.__absoluteTolerance
const __optimalityTolerance = CTBase.__optimalityTolerance
const __stagnationTolerance = CTBase.__stagnationTolerance
const ctgradient = CTBase.ctgradient
const ctjacobian = CTBase.ctjacobian
const expand = CTBase.expand
const vec2vec = CTBase.vec2vec

# includes
include("init.jl")
include("utils.jl")
include("init.jl")
include("problem.jl")
include("solve.jl")
include("solution.jl")
include("plot.jl")
include("solve.jl")

# export functions only for user
export solve
export plot, plot!

end
9 changes: 2 additions & 7 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,8 @@ function CTOptimizationInit(t0::Time, tf::Time, m::Dimension, init::Tuple{TimesD
end

#
function CTOptimizationInit(t0::Time, tf::Time, m::Dimension, S::DirectShootingSolution, grid, interp::Function)
return CTOptimizationInit(t0, tf, m, (time_steps(S), control(S)), grid, interp)
end

#
function CTOptimizationInit(t0::Time, tf::Time, m::Dimension, S::DirectSolution, grid, interp::Function)
return CTOptimizationInit(t0, tf, m, (time_steps(S), control(S)), grid, interp)
function CTOptimizationInit(t0::Time, tf::Time, m::Dimension, S::OptimalControlSolution, grid, interp::Function)
return CTOptimizationInit(t0, tf, m, control(S), grid, interp)
end

#
Expand Down
119 changes: 0 additions & 119 deletions src/plot.jl

This file was deleted.

43 changes: 42 additions & 1 deletion src/solution.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@

# ------------------------------------------------------------------------------------
# Direct shooting solution
#
struct DirectShootingSolution <: AbstractOptimalControlSolution
T::TimesDisc # the times
X::States # the states at the times T
U::Controls # the controls at T
P::Adjoints # the adjoint at T
objective::MyNumber
state_dimension::Dimension # the dimension of the state
control_dimension::Dimension # the dimension of the control
stopping::Symbol # the stopping criterion
message::String # the message corresponding to the stopping criterion
success::Bool # whether or not the method has finished successfully: CN1, stagnation vs iterations max
iterations::Integer # the number of iterations
end

function DirectShootingSolution(sol::CTOptimization.UnconstrainedSolution,
ocp::OptimalControlModel, grid::TimesDisc, penalty_constraint::Real)

Expand Down Expand Up @@ -42,7 +60,30 @@ function DirectShootingSolution(sol::CTOptimization.UnconstrainedSolution,
end
objective = J(U⁺)

return CTBase.DirectShootingSolution(T, X⁺, U⁺, P⁺, objective, n, m,
dssol = DirectShootingSolution(T, X⁺, U⁺, P⁺, objective, n, m,
sol.stopping, sol.message, sol.success, sol.iterations)

return _OptimalControlSolution(ocp, dssol)

end

function _OptimalControlSolution(ocp::OptimalControlModel, dssol::DirectShootingSolution)
x = ctinterpolate(dssol.T, dssol.X) # je ne peux pas donner directement la sortie de ctinterpolate car ce n'est pas une Function
u = ctinterpolate(dssol.T[1:end-1], dssol.U)
p = ctinterpolate(dssol.T, dssol.P)
sol = OptimalControlSolution()
sol.state_dimension = dssol.state_dimension
sol.control_dimension = dssol.control_dimension
sol.times = dssol.T
sol.state = t -> x(t)
sol.state_labels = ocp.state_labels # update CTBase to have a getter
sol.adjoint = t -> p(t)
sol.control = t -> u(t)
sol.control_labels = ocp.control_labels
sol.objective = dssol.objective
sol.iterations = dssol.iterations
sol.stopping = dssol.stopping
sol.message = dssol.message
sol.success = dssol.success
return sol
end
2 changes: 1 addition & 1 deletion src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ algorithmes = add(algorithmes, (:descent, :gradient, :fixedstep))
function CTDirectShooting.solve(
ocp::OptimalControlModel,
description...;
init::Union{Nothing,Controls,Tuple{TimesDisc,Controls},Function,DirectShootingSolution,DirectSolution}=nothing,
init::Union{Nothing,Controls,Tuple{TimesDisc,Controls},Function,OptimalControlSolution}=nothing,
grid::Union{Nothing,TimesDisc}=nothing,
penalty_constraint::Real=__penalty_constraint(),
display::Bool=__display(),
Expand Down
1 change: 0 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[deps]
CTBase = "54762871-cc72-4466-b8e8-f6c8b58076cd"
CTProblemLibrary = "0649932a-8c77-4f67-b1e4-c19ddd080280"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ using CTDirectShooting
using Test
using CTBase
using CTProblemLibrary
using Plots

# CTDirectShooting
const CTOptimizationInit = CTDirectShooting.CTOptimizationInit
Expand Down
25 changes: 11 additions & 14 deletions test/test_CTOptimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ prob = Problem(:integrator, :dim2, :energy)
ocp = prob.model

# solution
u_sol(t) = prob.solution.control_function(t)[1]
u_sol(t) = prob.solution.control(t)[1]

#
t0 = initial_time(ocp)
Expand Down Expand Up @@ -137,16 +137,13 @@ u_init(t) = [u_sol(t)-1.0]

# resolution with different init
common_args = (iterations=5, display=false)
sol = CTDirectShooting.solve(ocp, :descent, init=nothing, grid=nothing; common_args...); @test typeof(sol) == DirectShootingSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=nothing, grid=T; common_args...); @test typeof(sol) == DirectShootingSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=U, grid=nothing; common_args...); @test typeof(sol) == DirectShootingSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=U, grid=T; common_args...); @test typeof(sol) == DirectShootingSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=(T,U), grid=nothing; common_args...); @test typeof(sol) == DirectShootingSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=(T,U), grid=T_; common_args...); @test typeof(sol) == DirectShootingSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=sol, grid=nothing; common_args...); @test typeof(sol) == DirectShootingSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=u_init, grid=nothing; common_args...); @test typeof(sol) == DirectShootingSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=u_init, grid=T; common_args...); @test typeof(sol) == DirectShootingSolution;

# plots
@test typeof(plot(sol)) == Plots.Plot{Plots.GRBackend}
@test typeof(plot(sol, :time, (:control, 1))) == Plots.Plot{Plots.GRBackend}
sol = CTDirectShooting.solve(ocp, :descent, init=nothing, grid=nothing; common_args...); @test typeof(sol) == OptimalControlSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=nothing, grid=T; common_args...); @test typeof(sol) == OptimalControlSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=U, grid=nothing; common_args...); @test typeof(sol) == OptimalControlSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=U, grid=T; common_args...); @test typeof(sol) == OptimalControlSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=(T,U), grid=nothing; common_args...); @test typeof(sol) == OptimalControlSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=(T,U), grid=T_; common_args...); @test typeof(sol) == OptimalControlSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=sol, grid=nothing; common_args...); @test typeof(sol) == OptimalControlSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=sol, grid=T; common_args...); @test typeof(sol) == OptimalControlSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=u_init, grid=nothing; common_args...); @test typeof(sol) == OptimalControlSolution;
sol = CTDirectShooting.solve(ocp, :descent, init=u_init, grid=T; common_args...); @test typeof(sol) == OptimalControlSolution;
7 changes: 7 additions & 0 deletions test/test_plot_manual.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
using CTDirectShooting
using CTProblemLibrary
using CTBase

prob = Problem(:integrator, :dim2, :energy); ocp = prob.model
sol = solve(ocp) # print problem
plot(sol)

0 comments on commit f6f56fb

Please sign in to comment.