Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 141 additions & 4 deletions src/smith/differentiable_numerics/differentiable_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,18 +266,155 @@ std::vector<DifferentiableBlockSolver::FieldPtr> LinearDifferentiableBlockSolver
return u_duals;
}

NonlinearDifferentiableBlockSolver::NonlinearDifferentiableBlockSolver(std::unique_ptr<EquationSolver> s)
: nonlinear_solver_(std::move(s))
{
}

void NonlinearDifferentiableBlockSolver::completeSetup(const std::vector<FieldT>&)
{
// initializeSolver(&nonlinear_solver_->preconditioner(), u);
}

std::vector<DifferentiableBlockSolver::FieldPtr> NonlinearDifferentiableBlockSolver::solve(
const std::vector<FieldPtr>& u_guesses,
std::function<std::vector<mfem::Vector>(const std::vector<FieldPtr>&)> residual_funcs,
std::function<std::vector<std::vector<MatrixPtr>>(const std::vector<FieldPtr>&)> jacobian_funcs) const
{
SMITH_MARK_FUNCTION;

int num_rows = static_cast<int>(u_guesses.size());
SLIC_ERROR_IF(num_rows < 0, "Number of residual rows must be non-negative");

mfem::Array<int> block_offsets;
block_offsets.SetSize(num_rows + 1);
block_offsets[0] = 0;
for (int row_i = 0; row_i < num_rows; ++row_i) {
block_offsets[row_i + 1] = u_guesses[static_cast<size_t>(row_i)]->space().TrueVSize();
}
block_offsets.PartialSum();

auto block_u = std::make_unique<mfem::BlockVector>(block_offsets);
for (int row_i = 0; row_i < num_rows; ++row_i) {
block_u->GetBlock(row_i) = *u_guesses[static_cast<size_t>(row_i)];
}

auto block_r = std::make_unique<mfem::BlockVector>(block_offsets);

auto residual_op_ = std::make_unique<mfem_ext::StdFunctionOperator>(
block_u->Size(),
[&residual_funcs, num_rows, &u_guesses, &block_r](const mfem::Vector& u_, mfem::Vector& r_) {
const mfem::BlockVector* u = dynamic_cast<const mfem::BlockVector*>(&u_);
SLIC_ERROR_IF(!u, "Invalid u cast in block differentiable solver to a blocl vector");
for (int row_i = 0; row_i < num_rows; ++row_i) {
*u_guesses[static_cast<size_t>(row_i)] = u->GetBlock(row_i);
}
auto residuals = residual_funcs(u_guesses);
// auto block_r = std::make_unique<mfem::BlockVector>(block_offsets);
// auto block_r = dynamic_cast<mfem::BlockVector*>(&r_);
SLIC_ERROR_IF(!block_r, "Invalid r cast in block differentiable solver to a block vector");
for (int row_i = 0; row_i < num_rows; ++row_i) {
auto r = residuals[static_cast<size_t>(row_i)];
block_r->GetBlock(row_i) = r;
}
r_ = *block_r;
},
[this, &block_offsets, &u_guesses, jacobian_funcs, num_rows](const mfem::Vector& u_) -> mfem::Operator& {
const mfem::BlockVector* u = dynamic_cast<const mfem::BlockVector*>(&u_);
SLIC_ERROR_IF(!u, "Invalid u cast in block differentiable solver to a block vector");
for (int row_i = 0; row_i < num_rows; ++row_i) {
*u_guesses[static_cast<size_t>(row_i)] = u->GetBlock(row_i);
}
block_jac_ = std::make_unique<mfem::BlockOperator>(block_offsets);
matrix_of_jacs_ = jacobian_funcs(u_guesses);
for (int i = 0; i < num_rows; ++i) {
for (int j = 0; j < num_rows; ++j) {
auto& J = matrix_of_jacs_[static_cast<size_t>(i)][static_cast<size_t>(j)];
if (J) {
block_jac_->SetBlock(i, j, J.get());
}
}
}
return *block_jac_;
});
nonlinear_solver_->setOperator(*residual_op_);
nonlinear_solver_->solve(*block_u);

for (int row_i = 0; row_i < num_rows; ++row_i) {
*u_guesses[static_cast<size_t>(row_i)] = block_u->GetBlock(row_i);
}

return u_guesses;
}

std::vector<DifferentiableBlockSolver::FieldPtr> NonlinearDifferentiableBlockSolver::solveAdjoint(
const std::vector<DualPtr>& u_bars, std::vector<std::vector<MatrixPtr>>& jacobian_transposed) const
{
SMITH_MARK_FUNCTION;

int num_rows = static_cast<int>(u_bars.size());
SLIC_ERROR_IF(num_rows < 0, "Number of residual rows must be non-negative");

std::vector<DifferentiableBlockSolver::FieldPtr> u_duals(static_cast<size_t>(num_rows));
for (int row_i = 0; row_i < num_rows; ++row_i) {
u_duals[static_cast<size_t>(row_i)] = std::make_shared<DifferentiableBlockSolver::FieldT>(
u_bars[static_cast<size_t>(row_i)]->space(), "u_dual_" + std::to_string(row_i));
}

mfem::Array<int> block_offsets;
block_offsets.SetSize(num_rows + 1);
block_offsets[0] = 0;
for (int row_i = 0; row_i < num_rows; ++row_i) {
block_offsets[row_i + 1] = u_bars[static_cast<size_t>(row_i)]->space().TrueVSize();
}
block_offsets.PartialSum();

auto block_ds = std::make_unique<mfem::BlockVector>(block_offsets);
*block_ds = 0.0;

auto block_r = std::make_unique<mfem::BlockVector>(block_offsets);
for (int row_i = 0; row_i < num_rows; ++row_i) {
block_r->GetBlock(row_i) = *u_bars[static_cast<size_t>(row_i)];
}

auto block_jac = std::make_unique<mfem::BlockOperator>(block_offsets);
for (int i = 0; i < num_rows; ++i) {
for (int j = 0; j < num_rows; ++j) {
block_jac->SetBlock(i, j, jacobian_transposed[static_cast<size_t>(i)][static_cast<size_t>(j)].get());
}
}

auto& linear_solver = nonlinear_solver_->linearSolver();
linear_solver.SetOperator(*block_jac);
linear_solver.Mult(*block_r, *block_ds);

for (int row_i = 0; row_i < num_rows; ++row_i) {
*u_duals[static_cast<size_t>(row_i)] = block_ds->GetBlock(row_i);
}

return u_duals;
}

std::shared_ptr<LinearDifferentiableSolver> buildDifferentiableLinearSolver(LinearSolverOptions linear_opts,
const smith::Mesh& mesh)
{
auto [linear_solver, precond] = smith::buildLinearSolverAndPreconditioner(linear_opts, mesh.getComm());
return std::make_shared<smith::LinearDifferentiableSolver>(std::move(linear_solver), std::move(precond));
}

std::shared_ptr<NonlinearDifferentiableSolver> buildDifferentiableNonlinearSolver(
smith::NonlinearSolverOptions nonlinear_opts, LinearSolverOptions linear_opts, const smith::Mesh& mesh)
std::shared_ptr<NonlinearDifferentiableSolver> buildDifferentiableNonlinearSolver(NonlinearSolverOptions nonlinear_opts,
LinearSolverOptions linear_opts,
const smith::Mesh& mesh)
{
auto solid_solver = std::make_unique<EquationSolver>(nonlinear_opts, linear_opts, mesh.getComm());
return std::make_shared<NonlinearDifferentiableSolver>(std::move(solid_solver));
}

std::shared_ptr<NonlinearDifferentiableBlockSolver> buildDifferentiableNonlinearBlockSolver(
NonlinearSolverOptions nonlinear_opts, LinearSolverOptions linear_opts, const smith::Mesh& mesh)
{
auto solid_solver = std::make_unique<smith::EquationSolver>(nonlinear_opts, linear_opts, mesh.getComm());
return std::make_shared<smith::NonlinearDifferentiableSolver>(std::move(solid_solver));
auto solid_solver = std::make_unique<EquationSolver>(nonlinear_opts, linear_opts, mesh.getComm());
return std::make_shared<NonlinearDifferentiableBlockSolver>(std::move(solid_solver));
}

} // namespace smith
38 changes: 38 additions & 0 deletions src/smith/differentiable_numerics/differentiable_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace mfem {
class Solver;
class Vector;
class HypreParMatrix;
class BlockOperator;
} // namespace mfem

namespace smith {
Expand Down Expand Up @@ -175,6 +176,36 @@ class LinearDifferentiableBlockSolver : public DifferentiableBlockSolver {
mutable std::unique_ptr<mfem::Solver> mfem_preconditioner; ///< stored mfem block preconditioner
};

/// @brief Implementation of the DifferentiableBlockSolver interface for the special case of nonlinear solves with
/// linear adjoint solves
class NonlinearDifferentiableBlockSolver : public DifferentiableBlockSolver {
public:
/// @brief Construct from a linear solver and linear block precondition which may be used by the linear solver
NonlinearDifferentiableBlockSolver(std::unique_ptr<EquationSolver> s);

/// @overload
void completeSetup(const std::vector<FieldT>& us) override;

/// @overload
std::vector<FieldPtr> solve(
const std::vector<FieldPtr>& u_guesses,
std::function<std::vector<mfem::Vector>(const std::vector<FieldPtr>&)> residuals,
std::function<std::vector<std::vector<MatrixPtr>>(const std::vector<FieldPtr>&)> jacobians) const override;

/// @overload
std::vector<FieldPtr> solveAdjoint(const std::vector<DualPtr>& u_bars,
std::vector<std::vector<MatrixPtr>>& jacobian_transposed) const override;

mutable std::unique_ptr<mfem::BlockOperator>
block_jac_; ///< Need to hold an instance of a block operator to work with the mfem solver interface
mutable std::vector<std::vector<MatrixPtr>>
matrix_of_jacs_; ///< Holding vectors of block matrices to that do not going out of scope before the mfem solver
///< is done with using them in the block_jac_

mutable std::unique_ptr<EquationSolver>
nonlinear_solver_; ///< the nonlinear equation solver used for the forward pass
};

/// @brief Create a differentiable linear solver
/// @param linear_opts linear options struct
/// @param mesh mesh
Expand All @@ -189,4 +220,11 @@ std::shared_ptr<NonlinearDifferentiableSolver> buildDifferentiableNonlinearSolve
LinearSolverOptions linear_opts,
const smith::Mesh& mesh);

/// @brief Create a differentiable nonlinear block solver
/// @param nonlinear_opts nonlinear options struct
/// @param linear_opts linear options struct
/// @param mesh mesh
std::shared_ptr<NonlinearDifferentiableBlockSolver> buildDifferentiableNonlinearBlockSolver(
NonlinearSolverOptions nonlinear_opts, LinearSolverOptions linear_opts, const smith::Mesh& mesh);

} // namespace smith
Original file line number Diff line number Diff line change
Expand Up @@ -72,39 +72,22 @@ class DirichletBoundaryConditions {

/// @brief Specify time and space varying Dirichlet boundary conditions over a domain.
/// @param domain All dofs in this domain have boundary conditions applied to it.
/// @param component component direction to apply boundary condition to if the underlying field is a vector-field.
/// @param applied_displacement applied_displacement is a functor which takes time, and a
/// smith::tensor<double,spatial_dim> corresponding to the spatial coordinate. The functor must return a double. For
/// example: [](double t, smith::tensor<double, dim> X) { return 1.0; }
template <int spatial_dim, typename AppliedDisplacementFunction>
void setVectorBCs(const Domain& domain, int component, AppliedDisplacementFunction applied_displacement)
void setScalarBCs(const Domain& domain, AppliedDisplacementFunction applied_displacement)
{
const int field_dim = space_.GetVDim();
SLIC_ERROR_IF(component >= field_dim,
axom::fmt::format("Trying to set boundary conditions on a field with dim {}, using component {}",
field_dim, component));
auto mfem_coefficient_function = [applied_displacement](const mfem::Vector& X_mfem, double t) {
auto X = make_tensor<spatial_dim>([&X_mfem](int k) { return X_mfem[k]; });
return applied_displacement(t, X);
};

auto dof_list = domain.dof_list(&space_);
// scalar ldofs -> vector ldofs
space_.DofsToVDofs(component, dof_list);
space_.DofsToVDofs(static_cast<int>(0), dof_list);

auto component_disp_bdr_coef_ = std::make_shared<mfem::FunctionCoefficient>(mfem_coefficient_function);
bcs_.addEssential(dof_list, component_disp_bdr_coef_, space_, component);
}

/// @brief Specify time and space varying Dirichlet boundary conditions over a domain.
/// @param domain All dofs in this domain have boundary conditions applied to it.
/// @param applied_displacement applied_displacement is a functor which takes time, and a
/// smith::tensor<double,spatial_dim> corresponding to the spatial coordinate. The functor must return a double. For
/// example: [](double t, smith::tensor<double, dim> X) { return 1.0; }
template <int spatial_dim, typename AppliedDisplacementFunction>
void setScalarBCs(const Domain& domain, AppliedDisplacementFunction applied_displacement)
{
setScalarBCs<spatial_dim>(domain, 0, applied_displacement);
bcs_.addEssential(dof_list, component_disp_bdr_coef_, space_, 0);
}

/// @brief Constrain the dofs of a scalar field over a domain
Expand Down
10 changes: 8 additions & 2 deletions src/smith/differentiable_numerics/field_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,22 +212,28 @@ inline std::vector<const mfem::ParFiniteElementSpace*> spaces(const std::vector<
};

/// @brief Get a vector of FieldPtr or DualFieldPtr from a vector of FieldState
inline std::vector<FiniteElementState*> getFieldPointers(std::vector<FieldState>& states)
inline std::vector<FiniteElementState*> getFieldPointers(std::vector<FieldState>& states, std::vector<FieldState> params = {})
{
std::vector<FiniteElementState*> pointers;
for (auto& t : states) {
pointers.push_back(t.get().get());
}
for (auto& p : params) {
pointers.push_back(p.get().get());
}
return pointers;
}

/// @brief Get a vector of ConstFieldPtr or ConstDualFieldPtr from a vector of FieldState
inline std::vector<const FiniteElementState*> getConstFieldPointers(const std::vector<FieldState>& states)
inline std::vector<const FiniteElementState*> getConstFieldPointers(const std::vector<FieldState>& states, const std::vector<FieldState>& params = {})
{
std::vector<const FiniteElementState*> pointers;
for (auto& t : states) {
pointers.push_back(t.get().get());
}
for (auto& p : params) {
pointers.push_back(p.get().get());
}
return pointers;
}

Expand Down
Loading
Loading