Skip to content

Commit

Permalink
Improve CasADi Function interface
Browse files Browse the repository at this point in the history
  • Loading branch information
tttapa committed Jul 25, 2024
1 parent cf070cd commit e65d848
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,26 @@ namespace casadi {
/// Designed to match (part of) the `casadi::Function` API.
class CASADI_LOADER_EXPORT Function {
public:
struct Functions {
fname_incref::signature_t *incref = nullptr;
fname_decref::signature_t *decref = nullptr;
fname_n_in::signature_t *n_in = nullptr;
fname_n_out::signature_t *n_out = nullptr;
fname_name_in::signature_t *name_in = nullptr;
fname_name_out::signature_t *name_out = nullptr;
fname_sparsity_in::signature_t *sparsity_in = nullptr;
fname_sparsity_out::signature_t *sparsity_out = nullptr;
fname_alloc_mem::signature_t *alloc_mem = nullptr;
fname_init_mem::signature_t *init_mem = nullptr;
fname_free_mem::signature_t *free_mem = nullptr;
fname_work::signature_t *work = nullptr;
fname::signature_t *call = nullptr;
};

public:
Function();
Function(std::shared_ptr<void> so_handle, const std::string &func_name);
Function(const Functions &functions);
Function(const Function &);
Function(Function &&) noexcept;
~Function();
Expand Down Expand Up @@ -80,29 +99,15 @@ class CASADI_LOADER_EXPORT Function {

private:
std::shared_ptr<void> so_handle;
struct Functions {
fname_incref::signature_t *incref = nullptr;
fname_decref::signature_t *decref = nullptr;
fname_n_in::signature_t *n_in = nullptr;
fname_n_out::signature_t *n_out = nullptr;
fname_name_in::signature_t *name_in = nullptr;
fname_name_out::signature_t *name_out = nullptr;
fname_sparsity_in::signature_t *sparsity_in = nullptr;
fname_sparsity_out::signature_t *sparsity_out = nullptr;
fname_alloc_mem::signature_t *alloc_mem = nullptr;
fname_init_mem::signature_t *init_mem = nullptr;
fname_free_mem::signature_t *free_mem = nullptr;
fname_work::signature_t *work = nullptr;
fname::signature_t *call = nullptr;
} functions;
Functions functions;
struct Work {
std::vector<const casadi_real *> arg;
std::vector<casadi_real *> res;
std::vector<casadi_int> iw;
std::vector<casadi_real> w;
};
std::optional<Work> work;
void *mem = nullptr;
int mem = 0;
};

inline std::pair<casadi_int, casadi_int> Function::Sparsity::size() const {
Expand Down
8 changes: 4 additions & 4 deletions src/interop/casadi/include/alpaqa/casadi/casadi-functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ using fname_name_in = ExternalFunction<"_name_in", const char *(casadi_int
using fname_name_out = ExternalFunction<"_name_out", const char *(casadi_int ind)>;
using fname_sparsity_in = ExternalFunction<"_sparsity_in", const casadi_int *(casadi_int ind)>;
using fname_sparsity_out = ExternalFunction<"_sparsity_out", const casadi_int *(casadi_int ind)>;
using fname_alloc_mem = ExternalFunction<"_alloc_mem", void *(void)>;
using fname_init_mem = ExternalFunction<"_init_mem", int(void *mem)>;
using fname_free_mem = ExternalFunction<"_free_mem", int(void *mem)>;
using fname_alloc_mem = ExternalFunction<"_alloc_mem", int(void)>;
using fname_init_mem = ExternalFunction<"_init_mem", int(int mem)>;
using fname_free_mem = ExternalFunction<"_free_mem", void(int mem)>;
using fname_work = ExternalFunction<"_work", int(casadi_int *sz_arg, casadi_int *sz_res, casadi_int *sz_iw, casadi_int *sz_w)>;
using fname = ExternalFunction<"", int(const casadi_real **arg, casadi_real **res, casadi_int *iw, casadi_real *w, void *mem)>;
using fname = ExternalFunction<"", int(const casadi_real **arg, casadi_real **res, casadi_int *iw, casadi_real *w, int mem)>;
// clang-format on

template <Name Nm, class Sgn>
Expand Down
10 changes: 8 additions & 2 deletions src/interop/casadi/src/casadi-external-function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,27 @@ void Function::init_work() {
w.w.resize(static_cast<size_t>(sz_w));
}

Function::Function() = default;
Function::Function(std::shared_ptr<void> so_handle,
const std::string &func_name)
: so_handle{std::move(so_handle)} {
load(this->so_handle.get(), func_name);
}
static char no_handle;
Function::Function(const Functions &functions)
: so_handle{&no_handle, [](void *) {}}, functions{functions} {
functions.incref();
}
Function::Function(const Function &o)
: so_handle{o.so_handle}, functions{o.functions} {
functions.incref();
}
Function::Function(Function &&o) noexcept
: so_handle{std::move(o.so_handle)}, functions{o.functions},
work{std::move(o.work)}, mem{std::exchange(o.mem, nullptr)} {}
work{std::move(o.work)}, mem{std::exchange(o.mem, 0)} {}
Function::~Function() {
if (so_handle) {
if (mem)
if (work)
functions.free_mem(mem);
functions.decref();
}
Expand Down

0 comments on commit e65d848

Please sign in to comment.