Skip to content

Commit

Permalink
add timings plot
Browse files Browse the repository at this point in the history
  • Loading branch information
fabian-sp committed Apr 30, 2024
1 parent 25787e2 commit 5237127
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 1 deletion.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ This example is taken from Example 5.3 in [1]. We minimize the q-norm $||x||_q$

[Link to example script](examples/example_residual.py)

When $q\geq 1$ the problem is convex and rather easy to solve. For $q<1$, we observed that the number of sample points need to be increased a lot in order to make the subproblems solvable in reasonable time.

This problem has dimension 256, with one scalar constraint. To give a feeling for runtime, below is the runtime per iteration (for $q=1$), split into the main parts: sample points and compute gradients (`sample_and_grad`), solve the quadratic subproblem (`subproblem`), do the update step of iterate and approximate Hessian (`step`), and all other routines.

![Timings for SQP-GS with dim 256](data/img/timings_residual.png "Timings for SQP-GS with dim 256")

### Pretrained neural network constraint

This toy example illustrates how to use a pretrained neural network as constraint function in `ncOPT`. We train a simple model to learn the mapping $(x_1,x_2) \mapsto \max(\sqrt{2}x_1, 2x_2) -1 $. Then, we load the model checkpoint to use it as constraint.
Expand Down
Binary file added data/img/timings_residual.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions examples/example_residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
# %% Plotting

fig, ax = problem.plot_timings()
fig.savefig("../data/img/timings_residual.png")
fig, ax = problem.plot_metrics()


Expand Down
2 changes: 1 addition & 1 deletion src/ncopt/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def plot_timings(timings, ax=None):
if None not in val:
summed += np.array(val)

ax.plot(summed, lw=1, ls="--", color="grey", label="summed")
# ax.plot(summed, lw=1, ls="--", color="grey", label="summed")

ax.set_xlabel("Iteration")
ax.set_ylabel("Runtime [sec]")
Expand Down

0 comments on commit 5237127

Please sign in to comment.