Skip to content

Commit

Permalink
Added a casadi function interface with tests (example to come)
Browse files Browse the repository at this point in the history
  • Loading branch information
pariterre committed Dec 3, 2024
1 parent 30488fb commit f1902e9
Show file tree
Hide file tree
Showing 4 changed files with 265 additions and 24 deletions.
25 changes: 1 addition & 24 deletions bioptim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,48 +165,26 @@
"""

from .dynamics.configure_problem import ConfigureProblem, DynamicsFcn, DynamicsList, Dynamics
from .dynamics.configure_problem import ConfigureProblem, DynamicsFcn, DynamicsList, Dynamics
from .dynamics.dynamics_evaluation import DynamicsEvaluation
from .dynamics.dynamics_evaluation import DynamicsEvaluation
from .dynamics.dynamics_functions import DynamicsFunctions
from .dynamics.dynamics_functions import DynamicsFunctions
from .dynamics.fatigue.effort_perception import EffortPerception, TauEffortPerception
from .dynamics.fatigue.effort_perception import EffortPerception, TauEffortPerception
from .dynamics.fatigue.fatigue_dynamics import FatigueList
from .dynamics.fatigue.fatigue_dynamics import FatigueList
from .dynamics.fatigue.michaud_fatigue import MichaudFatigue, MichaudTauFatigue
from .dynamics.fatigue.michaud_fatigue import MichaudFatigue, MichaudTauFatigue
from .dynamics.fatigue.xia_fatigue import XiaFatigue, XiaTauFatigue, XiaFatigueStabilized
from .dynamics.fatigue.xia_fatigue import XiaFatigue, XiaTauFatigue, XiaFatigueStabilized
from .dynamics.ode_solver import OdeSolver, OdeSolverBase
from .dynamics.ode_solver import OdeSolver, OdeSolverBase
from .gui.online_callback_server import PlottingServer
from .gui.online_callback_server import PlottingServer
from .gui.plot import CustomPlot
from .gui.plot import CustomPlot
from .interfaces import Solver
from .interfaces import Solver
from .limits.constraints import ConstraintFcn, ConstraintList, Constraint, ParameterConstraintList
from .interfaces import Solver, CasadiFunctionInterface
from .limits.constraints import ConstraintFcn, ConstraintList, Constraint, ParameterConstraintList
from .limits.fatigue_path_conditions import FatigueBounds, FatigueInitialGuess
from .limits.fatigue_path_conditions import FatigueBounds, FatigueInitialGuess
from .limits.multinode_constraint import MultinodeConstraintFcn, MultinodeConstraintList, MultinodeConstraint
from .limits.multinode_constraint import MultinodeConstraintFcn, MultinodeConstraintList, MultinodeConstraint
from .limits.multinode_objective import MultinodeObjectiveFcn, MultinodeObjectiveList, MultinodeObjective
from .limits.multinode_objective import MultinodeObjectiveFcn, MultinodeObjectiveList, MultinodeObjective
from .limits.objective_functions import ObjectiveFcn, ObjectiveList, Objective, ParameterObjectiveList
from .limits.objective_functions import ObjectiveFcn, ObjectiveList, Objective, ParameterObjectiveList
from .limits.path_conditions import BoundsList, InitialGuessList, Bounds, InitialGuess
from .limits.path_conditions import BoundsList, InitialGuessList, Bounds, InitialGuess
from .limits.penalty_controller import PenaltyController
from .limits.penalty_controller import PenaltyController
from .limits.penalty_helpers import PenaltyHelpers
from .limits.penalty_helpers import PenaltyHelpers
from .limits.phase_transition import PhaseTransitionFcn, PhaseTransitionList, PhaseTransition
from .limits.phase_transition import PhaseTransitionFcn, PhaseTransitionList, PhaseTransition
from .misc.__version__ import __version__
from .misc.__version__ import __version__
from .misc.casadi_expand import lt, le, gt, ge, if_else, if_else_zero
from .misc.casadi_expand import lt, le, gt, ge, if_else, if_else_zero
from .misc.enums import (
Axis,
Expand All @@ -227,7 +205,6 @@
OnlineOptim,
)
from .misc.mapping import BiMappingList, BiMapping, Mapping, NodeMapping, NodeMappingList, SelectionMapping, Dependency
from .misc.mapping import BiMappingList, BiMapping, Mapping, NodeMapping, NodeMappingList, SelectionMapping, Dependency
from .models.biorbd.biorbd_model import BiorbdModel
from .models.biorbd.external_forces import ExternalForceSetTimeSeries
from .models.biorbd.holonomic_biorbd_model import HolonomicBiorbdModel
Expand Down
1 change: 1 addition & 0 deletions bioptim/interfaces/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .ipopt_options import IPOPT
from .acados_options import ACADOS
from .sqp_options import SQP_METHOD
from .casadi_function_interface import CasadiFunctionInterface


class Solver:
Expand Down
153 changes: 153 additions & 0 deletions bioptim/interfaces/casadi_function_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from abc import ABC, abstractmethod

from casadi import Callback, Function, Sparsity, DM, MX, SX
import numpy as np


class CasadiFunctionInterface(Callback, ABC):
def __init__(self, name: str, opts={}):
self.reverse_function = None

super(CasadiFunctionInterface, self).__init__()
self.construct(name, opts) # Defines the self.mx_in()
self._cached_mx_in = super().mx_in()

@abstractmethod
def inputs_len(self) -> list[int]:
"""
The len of the inputs of the function. This will help create the MX/SX vectors such that each element of the list
is the length of the input vector (i.e. the sparsity of the input vector).
Example:
def inputs_len(self) -> list[int]:
return [3, 4] # Assuming two inputs x and y of length 3 and 4 respectively
"""
pass

@abstractmethod
def outputs_len(self) -> list[int]:
"""
The len of the outputs of the function. This will help create the MX/SX vectors such that each element of the list
is the length of the output vector (i.e. the sparsity of the output vector).
Example:
def outputs_len(self) -> list[int]:
return [5] # Assuming the output is a 5x1 vector
"""
pass

@abstractmethod
def function(self, *args) -> np.ndarray | DM:
"""
The actual function to interface with casadi. The callable that returns should be callable by function(*mx_in).
If your function needs more parameters, they should be encapsulated in a partial.
Example:
def function(self, x, y):
x = np.array(x)[:, 0]
y = np.array(y)[:, 0]
return np.array(
[
x[0] * y[1] + x[0] * y[0] * y[0],
x[1] * x[1] + 2 * y[1],
x[0] * x[1] * x[2],
x[2] * x[1] * y[2] + 2 * y[3] * y[2],
y[0] * y[1] * y[2] * y[3],
]
)
"""
pass

@abstractmethod
def jacobians(self, *args) -> list[np.ndarray | DM]:
"""
All the jacobians evaluated at *args. Each of the jacobian should be of the shape (n_out, n_in), where n_out is
the length of the output vector (the same for all) and n_in is the length of the input element (specific to each
input element).
Example:
def jacobians(self, x, y):
x = np.array(x)[:, 0]
y = np.array(y)[:, 0]
jacobian_x = np.array(
[
[y[1] + y[0] * y[0], 0, 0],
[0, 2 * x[1], 0],
[x[1] * x[2], x[0] * x[2], x[0] * x[1]],
[0, x[2] * y[2], x[1] * y[2]],
[0, 0, 0],
]
)
jacobian_y = np.array(
[
[x[0] * 2 * y[0], x[0], 0, 0],
[0, 2, 0, 0],
[0, 0, 0, 0],
[0, 0, x[1] * x[2] + 2 * y[3], 2 * y[2]],
[y[1] * y[2] * y[3], y[0] * y[2] * y[3], y[0] * y[1] * y[3], y[0] * y[1] * y[2]],
]
)
return [jacobian_x, jacobian_y] # There are as many jacobians as there are inputs
"""
pass

def mx_in(self) -> MX:
"""
Get the MX in, but it is ensured that the MX are the same at each call
"""
return self._cached_mx_in

def get_n_in(self):
return len(self.inputs_len())

def get_n_out(self):
return len(self.outputs_len())

def get_sparsity_in(self, i):
return Sparsity.dense(self.inputs_len()[i], 1)

def get_sparsity_out(self, i):
return Sparsity.dense(self.outputs_len()[i], 1)

def eval(self, *args):
return [self.function(*args[0])]

def has_reverse(self, nadj):
return nadj == 1

def get_reverse(self, nadj, name, inames, onames, opts):
class Reverse(Callback):
def __init__(self, parent, jacobian_functions, opts={}):
self._sparsity_in = parent.mx_in() + parent.mx_out()
self._sparsity_out = parent.mx_in()

self.jacobian_functions = jacobian_functions
Callback.__init__(self)
self.construct("Reverse", opts)

def get_n_in(self):
return len(self._sparsity_in)

def get_n_out(self):
return len(self._sparsity_out)

def get_sparsity_in(self, i):
return Sparsity.dense(self._sparsity_in[i].shape)

def get_sparsity_out(self, i):
return Sparsity.dense(self._sparsity_out[i].shape)

def eval(self, arg):
# Find the index to evaluate from the last parameter which is a DM vector of 0s with one value being 1
index = arg[-1].toarray()[:, 0].tolist().index(1.0)
inputs = arg[:-1]
return [jaco[index, :].T for jaco in self.jacobian_functions(*inputs)]

# Package it in the [nominal_in + nominal_out + adj_seed] form that CasADi expects
if self.reverse_function is None:
self.reverse_function = Reverse(self, self.jacobians)

cx_in = self.mx_in()
nominal_out = self.mx_out()
adj_seed = self.mx_out()
return Function(name, cx_in + nominal_out + adj_seed, self.reverse_function(*cx_in, adj_seed[0]))
110 changes: 110 additions & 0 deletions tests/shard5/test_casadi_function_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from casadi import MX, vertcat, Function, jacobian
import numpy as np
import numpy.testing as npt
from bioptim import CasadiFunctionInterface


class CasadiFunctionInterfaceTest(CasadiFunctionInterface):
"""
This example implements a somewhat simple 5x1 function, with x and y inputs (x => 3x1; y => 4x1) of the form
f(x, y) = np.array(
[
x[0] * y[1] + y[0] * y[0],
x[1] * x[1] + 2 * y[1],
x[0] * x[1] * x[2],
x[2] * x[1] + 2 * y[3] * y[2],
y[0] * y[1] * y[2] * y[3],
]
)
It implements the equation (5x1) and the jacobians for the inputs x (5x3) and y (5x4).
"""

def __init__(self, opts={}):
super(CasadiFunctionInterfaceTest, self).__init__("CasadiFunctionInterfaceTest", opts)

def inputs_len(self) -> list[int]:
return [3, 4]

def outputs_len(self) -> list[int]:
return [5]

def function(self, *args):
x, y = args
x = np.array(x)[:, 0]
y = np.array(y)[:, 0]
return np.array(
[
x[0] * y[1] + x[0] * y[0] * y[0],
x[1] * x[1] + 2 * y[1],
x[0] * x[1] * x[2],
x[2] * x[1] * y[2] + 2 * y[3] * y[2],
y[0] * y[1] * y[2] * y[3],
]
)

def jacobians(self, *args):
x, y = args
x = np.array(x)[:, 0]
y = np.array(y)[:, 0]
jacobian_x = np.array(
[
[y[1] + y[0] * y[0], 0, 0],
[0, 2 * x[1], 0],
[x[1] * x[2], x[0] * x[2], x[0] * x[1]],
[0, x[2] * y[2], x[1] * y[2]],
[0, 0, 0],
]
)
jacobian_y = np.array(
[
[x[0] * 2 * y[0], x[0], 0, 0],
[0, 2, 0, 0],
[0, 0, 0, 0],
[0, 0, x[1] * x[2] + 2 * y[3], 2 * y[2]],
[y[1] * y[2] * y[3], y[0] * y[2] * y[3], y[0] * y[1] * y[3], y[0] * y[1] * y[2]],
]
)
return [jacobian_x, jacobian_y]


def test_penalty_minimize_time():
"""
These tests seem to test the interface, but actually all the internal methods are also called, which is what should
be tested.
"""

# Computing the example
interface_test = CasadiFunctionInterfaceTest()

# Testing the interface
npt.assert_equal(interface_test.inputs_len(), [3, 4])
npt.assert_equal(interface_test.outputs_len(), [5])
assert id(interface_test.mx_in()) == id(interface_test.mx_in()) # Calling twice returns the same object

# Test the class can be called with DM
x_num = np.array([1.1, 2.3, 3.5])
y_num = np.array([4.2, 5.4, 6.6, 7.7])
npt.assert_almost_equal(interface_test(x_num, y_num), np.array([[25.344, 16.09, 8.855, 154.77, 1152.5976]]).T)

# Test the jacobian is correct
x = MX.sym("x", interface_test.inputs_len()[0], 1)
y = MX.sym("y", interface_test.inputs_len()[1], 1)
jaco_x = Function("jaco_x", [x, y], [jacobian(interface_test(x, y), x)])
jaco_y = Function("jaco_y", [x, y], [jacobian(interface_test(x, y), y)])

# Computing the same equations (and derivative) by casadi
real = vertcat(
x[0] * y[1] + x[0] * y[0] * y[0],
x[1] * x[1] + 2 * y[1],
x[0] * x[1] * x[2],
x[2] * x[1] * y[2] + 2 * y[3] * y[2],
y[0] * y[1] * y[2] * y[3],
)
real_function = Function("real", [x, y], [real])
jaco_x_real = Function("jaco_x_real", [x, y], [jacobian(real, x)])
jaco_y_real = Function("jaco_y_real", [x, y], [jacobian(real, y)])

npt.assert_almost_equal(np.array(interface_test(x_num, y_num)), real_function(x_num, y_num))
npt.assert_almost_equal(np.array(jaco_x(x_num, y_num)), jaco_x_real(x_num, y_num))
npt.assert_almost_equal(np.array(jaco_y(x_num, y_num)), jaco_y_real(x_num, y_num))

0 comments on commit f1902e9

Please sign in to comment.