From cd7a6b71b70f3674d5a954eeb882f66bdfb0b4d8 Mon Sep 17 00:00:00 2001 From: phschiele Date: Thu, 7 Mar 2024 13:32:31 +0100 Subject: [PATCH 1/7] New branch --- .github/workflows/unit_tests.yml | 35 +- .pre-commit-config.yaml | 9 + dev/debug_sp_class.py | 4 +- example_rosenbrock.py | 45 ++- ncopt/tests/test_rosenbrock.py | 47 --- pyproject.toml | 103 ++++++ src/ncopt/__about__.py | 1 + {ncopt => src/ncopt}/__init__.py | 1 - {ncopt => src/ncopt}/funs.py | 76 +++-- {ncopt => src/ncopt}/sqpgs.py | 527 ++++++++++++++++++------------ {ncopt => src/ncopt}/torch_obj.py | 37 ++- tests/test_rosenbrock.py | 43 +++ train_max_fun.py | 128 ++++---- 13 files changed, 640 insertions(+), 416 deletions(-) create mode 100644 .pre-commit-config.yaml delete mode 100755 ncopt/tests/test_rosenbrock.py create mode 100644 pyproject.toml create mode 100644 src/ncopt/__about__.py rename {ncopt => src/ncopt}/__init__.py (50%) rename {ncopt => src/ncopt}/funs.py (56%) rename {ncopt => src/ncopt}/sqpgs.py (57%) rename {ncopt => src/ncopt}/torch_obj.py (61%) create mode 100755 tests/test_rosenbrock.py diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index e9e5879..9e39d0f 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -9,20 +9,25 @@ on: workflow_dispatch: jobs: - build: + lint: + name: Lint runs-on: ubuntu-latest steps: - - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.9' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.txt - - name: Test with pytest - run: | - pip install pytest - pytest ncopt/tests/ \ No newline at end of file + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.9" + - uses: pre-commit/action@v3.0.1 + + test: + name: Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.9" + - run: | + pipx install hatch + hatch env create + hatch run test diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a0eb12b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,9 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.3.0 + hooks: + # Run the linter. + - id: ruff + # Run the formatter. 
+ - id: ruff-format \ No newline at end of file diff --git a/dev/debug_sp_class.py b/dev/debug_sp_class.py index d8f4a76..dece296 100755 --- a/dev/debug_sp_class.py +++ b/dev/debug_sp_class.py @@ -1,7 +1,7 @@ # test Subproblem class -#%% +# %% # dim = 10 # nI = 3 # nE = 2 @@ -27,7 +27,7 @@ # inG = SP.inG -# SP.solve() +# SP.solve() # d_k=SP.d diff --git a/example_rosenbrock.py b/example_rosenbrock.py index aabcd2c..1ce70d1 100755 --- a/example_rosenbrock.py +++ b/example_rosenbrock.py @@ -1,52 +1,51 @@ """ author: Fabian Schaipp """ -import numpy as np + import matplotlib.pyplot as plt +import numpy as np +from ncopt.funs import f_rosenbrock, g_linear, g_max from ncopt.sqpgs import SQP_GS -from ncopt.funs import f_rosenbrock, g_max, g_linear -#from ncopt.torch_obj import Net -#%% +# from ncopt.torch_obj import Net + +# %% f = f_rosenbrock() g = g_max() -A = np.eye(2); b = np.ones(2)*5; g1 = g_linear(A, b) -#D = Net(model) +A = np.eye(2) +b = np.ones(2) * 5 +g1 = g_linear(A, b) +# D = Net(model) # inequality constraints (list of functions) gI = [g] # equality constraints (list of scalar functions) gE = [] -xstar = np.array([1/np.sqrt(2), 0.5]) +xstar = np.array([1 / np.sqrt(2), 0.5]) -#%% -X, Y = np.meshgrid(np.linspace(-2,2,100), np.linspace(-2,2,100)) +# %% +X, Y = np.meshgrid(np.linspace(-2, 2, 100), np.linspace(-2, 2, 100)) Z = np.zeros_like(X) for j in np.arange(100): for i in np.arange(100): - Z[i,j] = f.eval(np.array([X[i,j], Y[i,j]])) + Z[i, j] = f.eval(np.array([X[i, j], Y[i, j]])) fig, ax = plt.subplots() -ax.contourf(X,Y,Z, levels = 20) -ax.scatter(xstar[0], xstar[1], marker = "*", s = 200, c = "gold", alpha = 1, zorder = 200) +ax.contourf(X, Y, Z, levels=20) +ax.scatter(xstar[0], xstar[1], marker="*", s=200, c="gold", alpha=1, zorder=200) for i in range(20): - x0 = np.random.randn(2)# np.zeros(2) - x_k, x_hist, SP = SQP_GS(f, gI, gE, x0, tol = 1e-6, max_iter = 100, verbose = False) + x0 = np.random.randn(2) # np.zeros(2) + x_k, x_hist, SP = SQP_GS(f, gI, gE, x0, tol=1e-6, max_iter=100, verbose=False) print(x_k) - ax.plot(x_hist[:,0], x_hist[:,1], c = "silver", lw = 0.7, ls = '--', alpha = 0.5) - ax.scatter(x_k[0], x_k[1], marker = "+", s = 50, c = "k", alpha = 1, zorder = 210) - -ax.set_xlim(-2,2) -ax.set_ylim(-2,2) - - - - + ax.plot(x_hist[:, 0], x_hist[:, 1], c="silver", lw=0.7, ls="--", alpha=0.5) + ax.scatter(x_k[0], x_k[1], marker="+", s=50, c="k", alpha=1, zorder=210) +ax.set_xlim(-2, 2) +ax.set_ylim(-2, 2) diff --git a/ncopt/tests/test_rosenbrock.py b/ncopt/tests/test_rosenbrock.py deleted file mode 100755 index 9018e7d..0000000 --- a/ncopt/tests/test_rosenbrock.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -author: Fabian Schaipp -""" - -import numpy as np -import sys, os - -tests_path = os.path.dirname(os.path.abspath(__file__)) -sys.path.insert(0, tests_path + '/../..') - -#os.chdir('../..') - -from ncopt.sqpgs import SQP_GS -from ncopt.funs import f_rosenbrock, g_max, g_linear - -f = f_rosenbrock() -g = g_max() - -def test_rosenbrock_from_zero(): - gI = [g] - gE = [] - xstar = np.array([1/np.sqrt(2), 0.5]) - x_k, x_hist, SP = SQP_GS(f, gI, gE, tol = 1e-8, max_iter = 200, verbose = False) - np.testing.assert_array_almost_equal(x_k, xstar, decimal = 4) - - return - -def test_rosenbrock_from_rand(): - gI = [g] - gE = [] - xstar = np.array([1/np.sqrt(2), 0.5]) - x0 = np.random.rand(2) - x_k, x_hist, SP = SQP_GS(f, gI, gE, x0, tol = 1e-8, max_iter = 200, verbose = False) - np.testing.assert_array_almost_equal(x_k, xstar, decimal = 4) - - return - -def test_rosenbrock_with_eq(): - 
g1 = g_linear(A=np.eye(2), b=np.ones(2)) - gI = [] - gE = [g1] - xstar = np.ones(2) - x0 = np.zeros(2) - x_k, x_hist, SP = SQP_GS(f, gI, gE, x0, tol = 1e-8, max_iter = 200, verbose = False) - np.testing.assert_array_almost_equal(x_k, xstar, decimal = 4) - - return \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..98bfb34 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,103 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "ncopt" +dynamic = ["version"] +description = 'This repository contains a Python implementation of the SQP-GS (Sequential Quadratic Programming - Gradient Sampling) algorithm by Curtis and Overton.' +readme = "README.md" +requires-python = ">=3.9" +license = "BSD-3-Clause" +keywords = [] +authors = [ + { name = "fabian-sp", email = "fabian.schaipp@gmail.com" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "numpy", + "torch", + "cvxopt", +] + +[project.urls] +Documentation = "https://github.com/fabian-sp/ncOPT#readme" +Issues = "https://github.com/fabian-sp/ncOPT/issues" +Source = "https://github.com/fabian-sp/ncOPT" + +[tool.hatch.version] +path = "src/ncopt/__about__.py" + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", +] + +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +cov-report = [ + "- coverage combine", + "coverage report", +] +cov = [ + "test-cov", + "cov-report", +] + +[[tool.hatch.envs.all.matrix]] +python = ["3.9", "3.10", "3.11", "3.12"] + +[tool.hatch.envs.types] +dependencies = [ + "mypy>=1.0.0", +] +[tool.hatch.envs.types.scripts] +check = "mypy --install-types --non-interactive {args:src/ncopt tests}" + +[tool.coverage.run] +source_pkgs = ["ncopt", "tests"] +branch = true +parallel = true +omit = [ + "src/ncopt/__about__.py", +] + +[tool.coverage.paths] +ncopt = ["src/ncopt", "*/ncopt/src/ncopt"] +tests = ["tests", "*/ncopt/tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + +[tool.ruff] +lint.select = [ + "E", + "F", + "I", + "NPY201", + "W605", # Check for invalid escape sequences in docstrings (errors in py >= 3.11) +] +lint.ignore = [ + "E741", # ambiguous variable name +] +line-length = 100 + +# The minimum Python version that should be supported +target-version = "py39" + +src = ["src"] diff --git a/src/ncopt/__about__.py b/src/ncopt/__about__.py new file mode 100644 index 0000000..f102a9c --- /dev/null +++ b/src/ncopt/__about__.py @@ -0,0 +1 @@ +__version__ = "0.0.1" diff --git a/ncopt/__init__.py b/src/ncopt/__init__.py similarity index 50% rename from ncopt/__init__.py rename to src/ncopt/__init__.py index 139597f..8b13789 100644 --- a/ncopt/__init__.py +++ b/src/ncopt/__init__.py @@ -1,2 +1 @@ - diff --git a/ncopt/funs.py b/src/ncopt/funs.py similarity index 56% rename from ncopt/funs.py rename to src/ncopt/funs.py index 7aea406..206f939 100755 --- a/ncopt/funs.py +++ b/src/ncopt/funs.py @@ -4,60 +4,64 @@ import numpy as np + class f_rosenbrock: """ - Nonsmooth Rosenbrock function (see 5.1 in Curtis, 
Overton "SQP FOR NONSMOOTH CONSTRAINED OPTIMIZATION") - + Nonsmooth Rosenbrock function (see 5.1 in Curtis, Overton + "SQP FOR NONSMOOTH CONSTRAINED OPTIMIZATION") + x -> w|x_1^2 − x_2| + (1 − x_1)^2 """ - def __init__(self, w = 8): - self.name = 'rosenbrock' + + def __init__(self, w=8): + self.name = "rosenbrock" self.dim = 2 self.dimOut = 1 self.w = w - - def eval(self, x): - return self.w*np.abs(x[0]**2-x[1]) + (1-x[0])**2 - + + def eval(self, x): + return self.w * np.abs(x[0] ** 2 - x[1]) + (1 - x[0]) ** 2 + def differentiable(self, x): - return np.abs(x[0]**2 - x[1]) > 1e-10 - + return np.abs(x[0] ** 2 - x[1]) > 1e-10 + def grad(self, x): - a = np.array([-2+x[0], 0]) - sign = np.sign(x[0]**2 -x[1]) - + a = np.array([-2 + x[0], 0]) + sign = np.sign(x[0] ** 2 - x[1]) + if sign == 1: - b = np.array([2*x[0], -1]) + b = np.array([2 * x[0], -1]) elif sign == -1: - b = np.array([-2*x[0], 1]) + b = np.array([-2 * x[0], 1]) else: - b = np.array([-2*x[0], 1]) - - #b = np.sign(x[0]**2 -x[1]) * np.array([2*x[0], -1]) + b = np.array([-2 * x[0], 1]) + + # b = np.sign(x[0]**2 -x[1]) * np.array([2*x[0], -1]) return a + b - + + class g_max: """ maximum function (see 5.1 in Curtis, Overton "SQP FOR NONSMOOTH CONSTRAINED OPTIMIZATION") - + x -> max(c1*x_1, c2*x_2) - 1 """ - def __init__(self, c1 = np.sqrt(2), c2 = 2.): - self.name = 'max' + + def __init__(self, c1=np.sqrt(2), c2=2.0): + self.name = "max" self.c1 = c1 self.c2 = c2 self.dimOut = 1 return - + def eval(self, x): - return np.maximum(self.c1*x[0], self.c2*x[1]) - 1 - + return np.maximum(self.c1 * x[0], self.c2 * x[1]) - 1 + def differentiable(self, x): - return np.abs(self.c1*x[0] -self.c2*x[1]) > 1e-10 - + return np.abs(self.c1 * x[0] - self.c2 * x[1]) > 1e-10 + def grad(self, x): - - sign = np.sign(self.c1*x[0] - self.c2*x[1]) + sign = np.sign(self.c1 * x[0] - self.c2 * x[1]) if sign == 1: g = np.array([self.c1, 0]) elif sign == -1: @@ -66,27 +70,27 @@ def grad(self, x): g = np.array([0, self.c2]) return g + class g_linear: """ linear constraint: - + x -> Ax - b """ + def __init__(self, A, b): - self.name = 'linear' + self.name = "linear" self.A = A self.b = b self.dim = A.shape[0] self.dimOut = A.shape[1] return - + def eval(self, x): return self.A @ x - self.b - + def differentiable(self, x): return True - + def grad(self, x): return self.A - - diff --git a/ncopt/sqpgs.py b/src/ncopt/sqpgs.py similarity index 57% rename from ncopt/sqpgs.py rename to src/ncopt/sqpgs.py index 7829022..eae5ca7 100755 --- a/ncopt/sqpgs.py +++ b/src/ncopt/sqpgs.py @@ -1,42 +1,45 @@ """ author: Fabian Schaipp -Notation is (wherever possible) inspired by Curtis, Overton "SQP FOR NONSMOOTH CONSTRAINED OPTIMIZATION" +Notation is (wherever possible) inspired by Curtis, Overton +"SQP FOR NONSMOOTH CONSTRAINED OPTIMIZATION" """ -import numpy as np import cvxopt as cx - +import numpy as np + + def sample_points(x, eps, N): """ sample N points uniformly distributed in eps-ball around x """ dim = len(x) U = np.random.randn(N, dim) - norm_U = np.linalg.norm(U, axis = 1) - R = np.random.rand(N)**(1/dim) - Z = eps * (R/norm_U)[:,np.newaxis] * U - + norm_U = np.linalg.norm(U, axis=1) + R = np.random.rand(N) ** (1 / dim) + Z = eps * (R / norm_U)[:, np.newaxis] * U + return x + Z def q_rho(d, rho, H, f_k, gI_k, gE_k, D_f, D_gI, D_gE): - term1 = rho* (f_k + np.max(D_f @ d)) - + term1 = rho * (f_k + np.max(D_f @ d)) + term2 = 0 for j in np.arange(len(D_gI)): term2 += np.maximum(gI_k[j] + D_gI[j] @ d, 0).max() - + term3 = 0 for l in np.arange(len(D_gE)): term3 += np.abs(gE_k[l] 
+ D_gE[l] @ d).max() - - term4 = 0.5 * d.T@H@d - return term1+term2+term3+term4 + + term4 = 0.5 * d.T @ H @ d + return term1 + term2 + term3 + term4 + def phi_rho(x, f, gI, gE, rho): - term1 = rho*f.eval(x) - + term1 = rho * f.eval(x) + # inequalities if len(gI) > 0: term2 = np.sum(np.hstack([np.maximum(gI[j].eval(x), 0) for j in range(len(gI))])) @@ -47,81 +50,85 @@ def phi_rho(x, f, gI, gE, rho): term3 = np.sum(np.hstack([np.abs(gE[l].eval(x)) for l in range(len(gE))])) else: term3 = 0 - - return term1+term2+term3 + + return term1 + term2 + term3 + def stop_criterion(gI, gE, g_k, SP, gI_k, gE_k, B_gI, B_gE, nI_, nE_, pI, pE): """ computes E_k in the paper """ val1 = np.linalg.norm(g_k, np.inf) - + # as gI or gE could be empty, we need a max value for empty arrays --> initial argument - val2 = np.max(gI_k, initial = -np.inf) - val3 = np.max(np.abs(gE_k), initial = -np.inf) - + val2 = np.max(gI_k, initial=-np.inf) + val3 = np.max(np.abs(gE_k), initial=-np.inf) + gI_vals = list() for j in np.arange(nI_): gI_vals += eval_ineq(gI[j], B_gI[j]) - + val4 = -np.inf for j in np.arange(len(gI_vals)): val4 = np.maximum(val4, np.max(SP.lambda_gI[j] * gI_vals[j])) - + gE_vals = list() for j in np.arange(nE_): gE_vals += eval_ineq(gE[j], B_gE[j]) - + val5 = -np.inf for j in np.arange(len(gE_vals)): val5 = np.maximum(val5, np.max(SP.lambda_gE[j] * gE_vals[j])) - + return np.max(np.array([val1, val2, val3, val4, val5])) + def eval_ineq(fun, X): """ evaluate function at multiple inputs needed in stop_criterion - + Returns ------- - list of array, number of entries = fun.dimOut + list of array, number of entries = fun.dimOut """ (N, _) = X.shape D = np.zeros((N, fun.dimOut)) for i in np.arange(N): - D[i,:,] = fun.eval(X[i,:]) - - return [D[:,j] for j in range(fun.dimOut)] + D[ + i, + :, + ] = fun.eval(X[i, :]) + + return [D[:, j] for j in range(fun.dimOut)] def compute_gradients(fun, X): - """ + """ computes gradients of function object f at all rows of array X - + Returns ------- list of 2d-matrices, length of fun.dimOut """ (N, dim) = X.shape - + # fun.grad returns Jacobian, i.e. dimOut x dim D = np.zeros((N, fun.dimOut, dim)) for i in np.arange(N): - D[i,:,:] = fun.grad(X[i,:]) - + D[i, :, :] = fun.grad(X[i, :]) + # TODO: write this directly in return statement D_list = list() for j in np.arange(fun.dimOut): - D_list.append(D[:,j,:]) - - return D_list + D_list.append(D[:, j, :]) + return D_list -def SQP_GS(f, gI, gE, x0 = None, tol = 1e-8, max_iter = 100, verbose = True, assert_tol = 1e-5): +def SQP_GS(f, gI, gE, x0=None, tol=1e-8, max_iter=100, verbose=True, assert_tol=1e-5): """ - each element of gI, gE needs attribute g.dimOut + each element of gI, gE needs attribute g.dimOut Parameters ---------- @@ -150,29 +157,28 @@ def SQP_GS(f, gI, gE, x0 = None, tol = 1e-8, max_iter = 100, verbose = True, ass DESCRIPTION. 
""" - eps = 1e-1 # sampling radius + eps = 1e-1 # sampling radius rho = 1e-1 theta = 1e-1 - + # extract dimensions of constraints dim = f.dim - dimI = np.array([g.dimOut for g in gI], dtype = int) - dimE = np.array([g.dimOut for g in gE], dtype = int) - - - nI_ = len(gI) # number of inequality function objects - nE_ = len(gE) # number of equality function objects - - nI = sum(dimI) # number of inequality costraints - nE = sum(dimE) # number of equality costraints - - p0 = 2 # sample points for objective - pI_ = 3 * np.ones(nI_, dtype = int) # sample points for ineq constraint - pE_ = 4 * np.ones(nE_, dtype = int) # sample points for eq constraint - + dimI = np.array([g.dimOut for g in gI], dtype=int) + dimE = np.array([g.dimOut for g in gE], dtype=int) + + nI_ = len(gI) # number of inequality function objects + nE_ = len(gE) # number of equality function objects + + nI = sum(dimI) # number of inequality costraints + nE = sum(dimE) # number of equality costraints + + p0 = 2 # sample points for objective + pI_ = 3 * np.ones(nI_, dtype=int) # sample points for ineq constraint + pE_ = 4 * np.ones(nE_, dtype=int) # sample points for eq constraint + pI = np.repeat(pI_, dimI) pE = np.repeat(pE_, dimE) - + # parameters (set after recommendations in paper) # TODO: add params argument to set these eta = 1e-8 @@ -184,162 +190,172 @@ def SQP_GS(f, gI, gE, x0 = None, tol = 1e-8, max_iter = 100, verbose = True, ass xi_s = 1e3 xi_y = 1e3 xi_sy = 1e-6 - + # initialize subproblem object SP = Subproblem(dim, nI, nE, p0, pI, pE) - + if x0 is None: x_k = np.zeros(dim) else: x_k = x0.copy() - + iter_H = 10 E_k = np.inf - + x_hist = [x_k] - x_kmin1 = None; g_kmin1 = None; + x_kmin1 = None + g_kmin1 = None s_hist = np.zeros((dim, iter_H)) y_hist = np.zeros((dim, iter_H)) - + H = np.eye(dim) - - status = 'not optimal'; step = np.nan - + + status = "not optimal" + step = np.nan + hdr_fmt = "%4s\t%10s\t%5s\t%5s\t%10s\t%10s" out_fmt = "%4d\t%10.4g\t%10.4g\t%10.4g\t%10.4g\t%10s" if verbose: print(hdr_fmt % ("iter", "f(x_k)", "max(g_j(x_k))", "E_k", "step", "subproblem status")) - + ############################################## # START OF LOOP ############################################## - + for iter_k in range(max_iter): - if E_k <= tol: - status = 'optimal' + status = "optimal" break - + ############################################## # SAMPLING ############################################## B_f = sample_points(x_k, eps, p0) B_f = np.vstack((x_k, B_f)) - + B_gI = list() for j in np.arange(nI_): B_j = sample_points(x_k, eps, pI_[j]) B_j = np.vstack((x_k, B_j)) B_gI.append(B_j) - + B_gE = list() for j in np.arange(nE_): B_j = sample_points(x_k, eps, pE_[j]) B_j = np.vstack((x_k, B_j)) B_gE.append(B_j) - - + #################################### # COMPUTE GRADIENTS AND EVALUATE ################################### - D_f = compute_gradients(f, B_f)[0] # returns list, always has one element - + D_f = compute_gradients(f, B_f)[0] # returns list, always has one element + D_gI = list() for j in np.arange(nI_): D_gI += compute_gradients(gI[j], B_gI[j]) - + D_gE = list() for j in np.arange(nE_): D_gE += compute_gradients(gE[j], B_gE[j]) - + f_k = f.eval(x_k) # hstack cannot handle empty lists! 
if nI_ > 0: gI_k = np.hstack([gI[j].eval(x_k) for j in range(nI_)]) else: gI_k = np.array([]) - + if nE_ > 0: gE_k = np.hstack([gE[j].eval(x_k) for j in range(nE_)]) else: gE_k = np.array([]) - + ############################################## # SUBPROBLEM ############################################## - #print("H EIGVALS", np.linalg.eigh(H)[0]) + # print("H EIGVALS", np.linalg.eigh(H)[0]) SP.update(H, rho, D_f, D_gI, D_gE, f_k, gI_k, gE_k) SP.solve() - + d_k = SP.d.copy() - # compute g_k from paper - g_k = SP.lambda_f @ D_f + np.sum([SP.lambda_gI[j] @ D_gI[j] for j in range(nI)], axis = 0) \ - + np.sum([SP.lambda_gE[j] @ D_gE[j] for j in range(nE)], axis = 0) - + # compute g_k from paper + g_k = ( + SP.lambda_f @ D_f + + np.sum([SP.lambda_gI[j] @ D_gI[j] for j in range(nI)], axis=0) + + np.sum([SP.lambda_gE[j] @ D_gE[j] for j in range(nE)], axis=0) + ) + # evaluate v(x) at x=x_k v_k = np.maximum(gI_k, 0).sum() + np.sum(np.abs(gE_k)) - phi_k = rho*f_k + v_k - delta_q = phi_k - q_rho(d_k, rho, H, f_k, gI_k, gE_k, D_f, D_gI, D_gE) - + phi_k = rho * f_k + v_k + delta_q = phi_k - q_rho(d_k, rho, H, f_k, gI_k, gE_k, D_f, D_gI, D_gE) + assert delta_q >= -assert_tol assert np.abs(SP.lambda_f.sum() - rho) <= assert_tol, f"{np.abs(SP.lambda_f.sum() - rho)}" - - + if verbose: - print(out_fmt % (iter_k, f_k, np.max(np.hstack((gI_k,gE_k))), E_k, step, SP.status)) - + print(out_fmt % (iter_k, f_k, np.max(np.hstack((gI_k, gE_k))), E_k, step, SP.status)) + new_E_k = stop_criterion(gI, gE, g_k, SP, gI_k, gE_k, B_gI, B_gE, nI_, nE_, pI, pE) E_k = min(E_k, new_E_k) - + ############################################## # STEP ############################################## - - step = delta_q > nu*eps**2 + + step = delta_q > nu * eps**2 if step: - alpha = 1. - phi_new = phi_rho(x_k + alpha*d_k, f, gI, gE, rho) - + alpha = 1.0 + phi_new = phi_rho(x_k + alpha * d_k, f, gI, gE, rho) + # Armijo step size rule - while phi_new > phi_k - eta*alpha*delta_q: + while phi_new > phi_k - eta * alpha * delta_q: alpha *= gamma - phi_new = phi_rho(x_k + alpha*d_k, f, gI, gE, rho) - + phi_new = phi_rho(x_k + alpha * d_k, f, gI, gE, rho) + # update Hessian if x_kmin1 is not None: s_k = x_k - x_kmin1 - s_hist = np.roll(s_hist, 1, axis = 1) - s_hist[:,0] = s_k - + s_hist = np.roll(s_hist, 1, axis=1) + s_hist[:, 0] = s_k + y_k = g_k - g_kmin1 - y_hist = np.roll(y_hist, 1, axis = 1) - y_hist[:,0] = y_k - + y_hist = np.roll(y_hist, 1, axis=1) + y_hist[:, 0] = y_k + hH = np.eye(dim) for l in np.arange(iter_H): - sl = s_hist[:,l] - yl = y_hist[:,l] - - cond = (np.linalg.norm(sl) <= xi_s*eps) and (np.linalg.norm(yl) <= xi_y*eps) and (np.inner(sl,yl) >= xi_sy*eps**2) - + sl = s_hist[:, l] + yl = y_hist[:, l] + + cond = ( + (np.linalg.norm(sl) <= xi_s * eps) + and (np.linalg.norm(yl) <= xi_y * eps) + and (np.inner(sl, yl) >= xi_sy * eps**2) + ) + if cond: - Hs = hH@sl - hH = hH - np.outer(Hs,Hs)/(sl @ Hs + 1e-16) + np.outer(yl,yl)/(yl @ sl + 1e-16) + Hs = hH @ sl + hH = ( + hH + - np.outer(Hs, Hs) / (sl @ Hs + 1e-16) + + np.outer(yl, yl) / (yl @ sl + 1e-16) + ) - # TODO: is this really necessary? + # TODO: is this really necessary? 
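                # (the two rank-one updates above keep hH symmetric in exact
                # arithmetic, so the assert below only guards against floating-point drift)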
assert np.all(np.abs(hH - hH.T) <= 1e-8), f"{H}" - + H = hH.copy() - + #################################### # ACTUAL STEP ################################### x_kmin1 = x_k.copy() g_kmin1 = g_k.copy() - - x_k = x_k + alpha*d_k - + + x_k = x_k + alpha * d_k + ############################################## # NO STEP ############################################## @@ -348,25 +364,26 @@ def SQP_GS(f, gI, gE, x0 = None, tol = 1e-8, max_iter = 100, verbose = True, ass theta *= beta_theta else: rho *= beta_rho - + eps *= beta_eps - - + x_hist.append(x_k) - + ############################################## # END OF LOOP ############################################## x_hist = np.vstack(x_hist) - + if E_k > tol: - status = 'max iterations reached' - + status = "max iterations reached" + print(f"SQP-GS has terminated with status {status}") - + return x_k, x_hist, SP -#%% + +# %% + class Subproblem: def __init__(self, dim, nI, nE, p0, pI, pE): @@ -380,127 +397,160 @@ def __init__(self, dim, nI, nE, p0, pI, pE): """ assert len(pI) == nI assert len(pE) == nE - + self.dim = dim self.nI = nI self.nE = nE self.p0 = p0 self.pI = pI self.pE = pE - + self.P, self.q, self.inG, self.inh, self.nonnegG, self.nonnegh = self.initialize() - - + def solve(self): """ - This solves the quadratic program. In every iteration, you should call self.update() before solving in order to have the correct subproblem data. - + This solves the quadratic program. In every iteration, you should + call self.update() before solving in order to have the correct subproblem data. + self.d: array search direction - + self.lambda_f: array KKT multipier for objective. - + self.lambda_gE: list - KKT multipier for equality constraints. - + KKT multipier for equality constraints. + self.lambda_gI: list - KKT multipier for inequality constraints. + KKT multipier for inequality constraints. 
""" - cx.solvers.options['show_progress'] = False - + cx.solvers.options["show_progress"] = False + iG = np.vstack((self.inG, self.nonnegG)) ih = np.hstack((self.inh, self.nonnegh)) - - qp = cx.solvers.qp(P = cx.matrix(self.P), q = cx.matrix(self.q), G = cx.matrix(iG), h = cx.matrix(ih)) - + + qp = cx.solvers.qp( + P=cx.matrix(self.P), q=cx.matrix(self.q), G=cx.matrix(iG), h=cx.matrix(ih) + ) + self.status = qp["status"] - self.cvx_sol_x = np.array(qp['x']).squeeze() - - self.d = self.cvx_sol_x[:self.dim] + self.cvx_sol_x = np.array(qp["x"]).squeeze() + + self.d = self.cvx_sol_x[: self.dim] self.z = self.cvx_sol_x[self.dim] - self.rI = self.cvx_sol_x[self.dim +1 : self.dim +1 +self.nI] - self.rE = self.cvx_sol_x[self.dim +1 + self.nI : ] - + self.rI = self.cvx_sol_x[self.dim + 1 : self.dim + 1 + self.nI] + self.rE = self.cvx_sol_x[self.dim + 1 + self.nI :] + # TODO: pipe through assert_tol assert len(self.rE) == self.nE - assert np.all(self.rI >= -1e-5) , f"{self.rI}" + assert np.all(self.rI >= -1e-5), f"{self.rI}" assert np.all(self.rE >= -1e-5), f"{self.rE}" - + # extract dual variables = KKT multipliers - self.cvx_sol_z = np.array(qp['z']).squeeze() - lambda_f = self.cvx_sol_z[:self.p0+1] - + self.cvx_sol_z = np.array(qp["z"]).squeeze() + lambda_f = self.cvx_sol_z[: self.p0 + 1] + lambda_gI = list() for j in np.arange(self.nI): - start_ix = self.p0+1+(1+self.pI)[:j].sum() - lambda_gI.append( self.cvx_sol_z[start_ix : start_ix + 1+self.pI[j]] ) - + start_ix = self.p0 + 1 + (1 + self.pI)[:j].sum() + lambda_gI.append(self.cvx_sol_z[start_ix : start_ix + 1 + self.pI[j]]) + lambda_gE = list() for j in np.arange(self.nE): - start_ix = self.p0+1+(1+self.pI).sum()+(1+self.pE)[:j].sum() - + start_ix = self.p0 + 1 + (1 + self.pI).sum() + (1 + self.pE)[:j].sum() + # from ineq with + - vec1 = self.cvx_sol_z[start_ix : start_ix + 1+self.pE[j]] - + vec1 = self.cvx_sol_z[start_ix : start_ix + 1 + self.pE[j]] + # from ineq with - - vec2 = self.cvx_sol_z[start_ix+(1+self.pE).sum() : start_ix + (1+self.pE).sum() + 1+self.pE[j]] - + vec2 = self.cvx_sol_z[ + start_ix + (1 + self.pE).sum() : start_ix + (1 + self.pE).sum() + 1 + self.pE[j] + ] + # see Direction.m line 620 - lambda_gE.append(vec1-vec2) - + lambda_gE.append(vec1 - vec2) + self.lambda_f = lambda_f.copy() self.lambda_gI = lambda_gI.copy() self.lambda_gE = lambda_gE.copy() - - return - - + + return + def initialize(self): """ The quadratic subrpoblem we solve in every iteration is of the form: - + min_y 1/2* yPy + q*y subject to Gy <= h - + variable structure: y=(d,z,rI,rE) with d = search direction z = helper variable for objective rI = helper variable for inequality constraints rI = helper variable for equality constraints - - This function initializes the variables P,q,G,h. The entries which change in every iteration are then updated in self.update() - + + This function initializes the variables P,q,G,h. 
The entries which + change in every iteration are then updated in self.update() + G and h consist of two parts: 1) inG, inh: the inequalities from the paper 2) nonnegG, nonnegh: nonnegativity bounds rI >= 0, rE >= 0 """ - - dimQP = self.dim+1 + self.nI + self.nE - + + dimQP = self.dim + 1 + self.nI + self.nE + P = np.zeros((dimQP, dimQP)) q = np.zeros(dimQP) - - inG = np.zeros((1 + self.p0+np.sum(1+self.pI) + 2*np.sum(1+self.pE), dimQP)) - inh = np.zeros( 1 + self.p0+np.sum(1+self.pI) + 2*np.sum(1+self.pE)) - + + inG = np.zeros((1 + self.p0 + np.sum(1 + self.pI) + 2 * np.sum(1 + self.pE), dimQP)) + inh = np.zeros(1 + self.p0 + np.sum(1 + self.pI) + 2 * np.sum(1 + self.pE)) + # structure of inG (p0+1, sum(1+pI), sum(1+pE), sum(1+pE)) - inG[:self.p0+1, self.dim] = -1 - + inG[: self.p0 + 1, self.dim] = -1 + for j in range(self.nI): - inG[self.p0+1+(1+self.pI)[:j].sum() : self.p0+1+(1+self.pI)[:j].sum() + self.pI[j]+1, self.dim+1+j] = -1 - + inG[ + self.p0 + 1 + (1 + self.pI)[:j].sum() : self.p0 + + 1 + + (1 + self.pI)[:j].sum() + + self.pI[j] + + 1, + self.dim + 1 + j, + ] = -1 + for j in range(self.nE): - inG[self.p0+1+(1+self.pI).sum()+(1+self.pE)[:j].sum() : self.p0+1+(1+self.pI).sum()+(1+self.pE)[:j].sum() + self.pE[j]+1, self.dim+1+self.nI+j] = -1 - inG[self.p0+1+(1+self.pI).sum()+(1+self.pE).sum()+(1+self.pE)[:j].sum() : self.p0+1+(1+self.pI).sum()+(1+self.pE).sum()+(1+self.pE)[:j].sum() + self.pE[j]+1, self.dim+1+self.nI+j] = -1 - + inG[ + self.p0 + 1 + (1 + self.pI).sum() + (1 + self.pE)[:j].sum() : self.p0 + + 1 + + (1 + self.pI).sum() + + (1 + self.pE)[:j].sum() + + self.pE[j] + + 1, + self.dim + 1 + self.nI + j, + ] = -1 + inG[ + self.p0 + + 1 + + (1 + self.pI).sum() + + (1 + self.pE).sum() + + (1 + self.pE)[:j].sum() : self.p0 + + 1 + + (1 + self.pI).sum() + + (1 + self.pE).sum() + + (1 + self.pE)[:j].sum() + + self.pE[j] + + 1, + self.dim + 1 + self.nI + j, + ] = -1 + # we have nI+nE r-variables - nonnegG = np.hstack((np.zeros((self.nI + self.nE, self.dim + 1)), -np.eye(self.nI + self.nE))) + nonnegG = np.hstack( + (np.zeros((self.nI + self.nE, self.dim + 1)), -np.eye(self.nI + self.nE)) + ) nonnegh = np.zeros(self.nI + self.nE) - - return P,q,inG,inh,nonnegG,nonnegh + return P, q, inG, inh, nonnegG, nonnegh def update(self, H, rho, D_f, D_gI, D_gE, f_k, gI_k, gE_k): """ @@ -523,31 +573,80 @@ def update(self, H, rho, D_f, D_gI, D_gE, f_k, gI_k, gE_k): evaluation of inequality constraints at x_k. gE_k : array evaluation of equality constraints at x_k. - + Returns ------- None. 
""" - self.P[:self.dim, :self.dim] = H - self.q = np.hstack((np.zeros(self.dim), rho, np.ones(self.nI), np.ones(self.nE))) - - self.inG[:self.p0+1, :self.dim] = D_f - self.inh[:self.p0+1] = -f_k - + self.P[: self.dim, : self.dim] = H + self.q = np.hstack((np.zeros(self.dim), rho, np.ones(self.nI), np.ones(self.nE))) + + self.inG[: self.p0 + 1, : self.dim] = D_f + self.inh[: self.p0 + 1] = -f_k + for j in range(self.nI): - self.inG[self.p0+1+(1+self.pI)[:j].sum() : self.p0+1+(1+self.pI)[:j].sum() + self.pI[j]+1, :self.dim] = D_gI[j] - self.inh[self.p0+1+(1+self.pI)[:j].sum() : self.p0+1+(1+self.pI)[:j].sum() + self.pI[j]+1] = -gI_k[j] - + self.inG[ + self.p0 + 1 + (1 + self.pI)[:j].sum() : self.p0 + + 1 + + (1 + self.pI)[:j].sum() + + self.pI[j] + + 1, + : self.dim, + ] = D_gI[j] + self.inh[ + self.p0 + 1 + (1 + self.pI)[:j].sum() : self.p0 + + 1 + + (1 + self.pI)[:j].sum() + + self.pI[j] + + 1 + ] = -gI_k[j] + for j in range(self.nE): - self.inG[self.p0+1+(1+self.pI).sum()+(1+self.pE)[:j].sum() : self.p0+1+(1+self.pI).sum()+(1+self.pE)[:j].sum() + self.pE[j]+1, :self.dim] = D_gE[j] - self.inG[self.p0+1+(1+self.pI).sum()+(1+self.pE).sum()+(1+self.pE)[:j].sum() : self.p0+1+(1+self.pI).sum()+(1+self.pE).sum()+(1+self.pE)[:j].sum() + self.pE[j]+1, :self.dim] = -D_gE[j] - - self.inh[self.p0+1+(1+self.pI).sum()+(1+self.pE)[:j].sum() : self.p0+1+(1+self.pI).sum()+(1+self.pE)[:j].sum() + self.pE[j]+1] = -gE_k[j] - self.inh[self.p0+1+(1+self.pI).sum()+(1+self.pE).sum()+(1+self.pE)[:j].sum() : self.p0+1+(1+self.pI).sum()+(1+self.pE).sum()+(1+self.pE)[:j].sum() + self.pE[j]+1] = gE_k[j] - - - return - + self.inG[ + self.p0 + 1 + (1 + self.pI).sum() + (1 + self.pE)[:j].sum() : self.p0 + + 1 + + (1 + self.pI).sum() + + (1 + self.pE)[:j].sum() + + self.pE[j] + + 1, + : self.dim, + ] = D_gE[j] + self.inG[ + self.p0 + + 1 + + (1 + self.pI).sum() + + (1 + self.pE).sum() + + (1 + self.pE)[:j].sum() : self.p0 + + 1 + + (1 + self.pI).sum() + + (1 + self.pE).sum() + + (1 + self.pE)[:j].sum() + + self.pE[j] + + 1, + : self.dim, + ] = -D_gE[j] + self.inh[ + self.p0 + 1 + (1 + self.pI).sum() + (1 + self.pE)[:j].sum() : self.p0 + + 1 + + (1 + self.pI).sum() + + (1 + self.pE)[:j].sum() + + self.pE[j] + + 1 + ] = -gE_k[j] + self.inh[ + self.p0 + + 1 + + (1 + self.pI).sum() + + (1 + self.pE).sum() + + (1 + self.pE)[:j].sum() : self.p0 + + 1 + + (1 + self.pI).sum() + + (1 + self.pE).sum() + + (1 + self.pE)[:j].sum() + + self.pE[j] + + 1 + ] = gE_k[j] + return diff --git a/ncopt/torch_obj.py b/src/ncopt/torch_obj.py similarity index 61% rename from ncopt/torch_obj.py rename to src/ncopt/torch_obj.py index 3934ef6..770b2b7 100755 --- a/ncopt/torch_obj.py +++ b/src/ncopt/torch_obj.py @@ -1,42 +1,45 @@ """ author: Fabian Schaipp """ -import numpy as np import torch + class Net: - def __init__(self, D, dimOut = None): - self.name = 'pytorch_Net' + def __init__(self, D, dimOut=None): + self.name = "pytorch_Net" self.D = D - + self.D.zero_grad() - + self.dimIn = self.D[0].weight.shape[1] - + # set mode to evaluation self.D.train(False) - - #if type(self.D[-1]) == torch.nn.ReLU: + + # if type(self.D[-1]) == torch.nn.ReLU: if dimOut is None: print("Caution: output dimension of Net is not specified and derived from last module!") self.dimOut = self.D[-1].weight.shape[0] else: self.dimOut = dimOut return - - def eval(self, x): - assert len(x) == self.dimIn, f"Input for Net has wrong dimension, required dimension is {self.dimIn}." 
- + + def eval(self, x): + assert ( + len(x) == self.dimIn + ), f"Input for Net has wrong dimension, required dimension is {self.dimIn}." + return self.D.forward(torch.tensor(x, dtype=torch.float32)).detach().numpy() - + def grad(self, x): - assert len(x) == self.dimIn, f"Input for Net has wrong dimension, required dimension is {self.dimIn}." - + assert ( + len(x) == self.dimIn + ), f"Input for Net has wrong dimension, required dimension is {self.dimIn}." + x_torch = torch.tensor(x, dtype=torch.float32) x_torch.requires_grad_(True) - + y_torch = self.D(x_torch) y_torch.backward() return x_torch.grad.data.numpy() - \ No newline at end of file diff --git a/tests/test_rosenbrock.py b/tests/test_rosenbrock.py new file mode 100755 index 0000000..0210dcb --- /dev/null +++ b/tests/test_rosenbrock.py @@ -0,0 +1,43 @@ +""" +author: Fabian Schaipp +""" +import numpy as np + +from ncopt.funs import f_rosenbrock, g_linear, g_max +from ncopt.sqpgs import SQP_GS + +f = f_rosenbrock() +g = g_max() + + +def test_rosenbrock_from_zero(): + gI = [g] + gE = [] + xstar = np.array([1 / np.sqrt(2), 0.5]) + x_k, x_hist, SP = SQP_GS(f, gI, gE, tol=1e-8, max_iter=200, verbose=False) + np.testing.assert_array_almost_equal(x_k, xstar, decimal=4) + + return + + +def test_rosenbrock_from_rand(): + gI = [g] + gE = [] + xstar = np.array([1 / np.sqrt(2), 0.5]) + x0 = np.random.rand(2) + x_k, x_hist, SP = SQP_GS(f, gI, gE, x0, tol=1e-8, max_iter=200, verbose=False) + np.testing.assert_array_almost_equal(x_k, xstar, decimal=4) + + return + + +def test_rosenbrock_with_eq(): + g1 = g_linear(A=np.eye(2), b=np.ones(2)) + gI = [] + gE = [g1] + xstar = np.ones(2) + x0 = np.zeros(2) + x_k, x_hist, SP = SQP_GS(f, gI, gE, x0, tol=1e-8, max_iter=200, verbose=False) + np.testing.assert_array_almost_equal(x_k, xstar, decimal=4) + + return diff --git a/train_max_fun.py b/train_max_fun.py index b18c56a..b3f89c3 100755 --- a/train_max_fun.py +++ b/train_max_fun.py @@ -1,4 +1,4 @@ -""" +r""" This is as rather experimental script for training a NN representing the function: x \mapsto max(c1*x[0], c2*x[1]) - 1 @@ -6,38 +6,43 @@ The idea is to use a neural network as a constraint for SQP-GS. """ -import numpy as np import matplotlib.pyplot as plt +import numpy as np import torch from torch.optim.lr_scheduler import StepLR -c1 = np.sqrt(2); c2 = 2. +c1 = np.sqrt(2) +c2 = 2.0 + @np.vectorize -def g(x0,x1): - return np.maximum(c1*x0, c2*x1) - 1 +def g(x0, x1): + return np.maximum(c1 * x0, c2 * x1) - 1 + def generate_data(N): - x0 = 2*np.random.randn(N)# * 10 - 5 - x1 = 2*np.random.randn(N)# * 10 - 5 - x0.sort();x1.sort() - X0,X1 = np.meshgrid(x0,x1) - return X0,X1 + x0 = 2 * np.random.randn(N) # * 10 - 5 + x1 = 2 * np.random.randn(N) # * 10 - 5 + x0.sort() + x1.sort() + X0, X1 = np.meshgrid(x0, x1) + return X0, X1 + X0, X1 = generate_data(200) -Z = g(X0,X1) +Z = g(X0, X1) -#%% +# %% -tmp = np.stack((X0.reshape(-1),X1.reshape(-1))).T +tmp = np.stack((X0.reshape(-1), X1.reshape(-1))).T # pytorch weights are in torch.float32, numpy data is float64! -tX = torch.tensor(tmp, dtype = torch.float32) -tZ = torch.tensor(Z.reshape(-1), dtype = torch.float32) +tX = torch.tensor(tmp, dtype=torch.float32) +tZ = torch.tensor(Z.reshape(-1), dtype=torch.float32) N = len(tX) -#%% +# %% # # D_in is input dimension; # # H is hidden dimension; D_out is output dimension. 
@@ -58,117 +63,121 @@ def generate_data(N): # loss_fn = torch.nn.MSELoss(reduction='mean') -#%% +# %% + class myNN(torch.nn.Module): def __init__(self): super().__init__() - self.l1 = torch.nn.Linear(2, 2) # layer 1 - #self.l2 = torch.nn.Linear(20, 2) # layer 2 - #self.relu = torch.nn.ReLU() + self.l1 = torch.nn.Linear(2, 2) # layer 1 + # self.l2 = torch.nn.Linear(20, 2) # layer 2 + # self.relu = torch.nn.ReLU() self.max = torch.max + def forward(self, x): x = self.l1(x) - x,_ = self.max(x, dim = -1) + x, _ = self.max(x, dim=-1) return x -loss_fn = torch.nn.MSELoss(reduction='mean') + +loss_fn = torch.nn.MSELoss(reduction="mean") model = myNN() # set weights manually -#model.state_dict()["l1.weight"][:] = torch.diag(torch.tensor([c1,c2])) -#model.state_dict()["l1.bias"][:] = -torch.ones(2) +# model.state_dict()["l1.weight"][:] = torch.diag(torch.tensor([c1,c2])) +# model.state_dict()["l1.bias"][:] = -torch.ones(2) print(model.l1.weight) print(model.l1.bias) -#testing -x = torch.tensor([1.,4.]) +# testing +x = torch.tensor([1.0, 4.0]) model(x) g(x[0], x[1]) -#%% +# %% learning_rate = 1e-3 N_EPOCHS = 11 b = 15 -#optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) -optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate, momentum=0.9, nesterov=True) +# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) +optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True) scheduler = StepLR(optimizer, step_size=1, gamma=0.5) -def sample_batch(N, b): - S = torch.randint(high = N, size = (b,)) + +def sample_batch(N, b): + S = torch.randint(high=N, size=(b,)) return S + for epoch in range(N_EPOCHS): print(f"..................EPOCH {epoch}..................") - - for t in range(int(N/b)): - + + for t in range(int(N / b)): S = sample_batch(N, b) - x_batch = tX[S]; z_batch = tZ[S] - + x_batch = tX[S] + z_batch = tZ[S] + # forward pass y_pred = model.forward(x_batch) - + # compute loss. loss = loss_fn(y_pred.squeeze(), z_batch) - + # zero gradients optimizer.zero_grad() - + # backward pass loss.backward() - + # iteration optimizer.step() - + print(model.l1.weight) print(model.l1.bias) print(loss.item()) scheduler.step() - #print(optimizer) - -optimizer.zero_grad() - -#%% plot results + # print(optimizer) + +optimizer.zero_grad() + +# %% plot results N_test = 200 -X0_test,X1_test = generate_data(N_test) +X0_test, X1_test = generate_data(N_test) tmp = np.stack((X0_test.reshape(-1), X1_test.reshape(-1))).T # pytorch weights are in torch.float32, numpy data is float64! -X_test = torch.tensor(tmp, dtype = torch.float32) +X_test = torch.tensor(tmp, dtype=torch.float32) Z_test = model.forward(X_test).detach().numpy().squeeze() Z_test_arr = Z_test.reshape(N_test, N_test) -Z_true = g(X0_test,X1_test).reshape(-1) +Z_true = g(X0_test, X1_test).reshape(-1) -fig, axs = plt.subplots(1,2) -axs[0].scatter(tmp[:,0], tmp[:,1], c = Z_test) -#axs[1].scatter(tmp[:,0], tmp[:,1], c = Z_true) -axs[1].scatter(tmp[:,0], tmp[:,1], c = Z_test-Z_true, vmin = -1e-1, vmax = 1e-1, cmap = "coolwarm") +fig, axs = plt.subplots(1, 2) +axs[0].scatter(tmp[:, 0], tmp[:, 1], c=Z_test) +# axs[1].scatter(tmp[:,0], tmp[:,1], c = Z_true) +axs[1].scatter(tmp[:, 0], tmp[:, 1], c=Z_test - Z_true, vmin=-1e-1, vmax=1e-1, cmap="coolwarm") -#%% -from mpl_toolkits.mplot3d import Axes3D +# %% fig = plt.figure() -ax = fig.gca(projection='3d') +ax = fig.gca(projection="3d") # Plot the surface. 
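# (note: fig.gca(projection="3d") above is deprecated in newer matplotlib releases;
#  fig.add_subplot(projection="3d") is the drop-in replacement)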
ax.plot_surface(X0_test, X1_test, Z_test_arr, cmap=plt.cm.coolwarm, linewidth=0, antialiased=False) -#%% test auto-diff gradient +# %% test auto-diff gradient -x0 = torch.tensor([np.sqrt(2),0.5], dtype = torch.float32) +x0 = torch.tensor([np.sqrt(2), 0.5], dtype=torch.float32) x0.requires_grad_(True) model.zero_grad() @@ -179,6 +188,3 @@ def sample_batch(N, b): x0.grad.data W = model[-3].weight.detach().numpy() - - - From e5f4acf367fceda2686f9b540320f9d06b4f7242 Mon Sep 17 00:00:00 2001 From: phschiele Date: Thu, 7 Mar 2024 13:51:20 +0100 Subject: [PATCH 2/7] Fix linting --- scripts/timing_rosenbrock.py | 1 + src/ncopt/sqpgs/main.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/timing_rosenbrock.py b/scripts/timing_rosenbrock.py index 4ae9936..5c0d19e 100644 --- a/scripts/timing_rosenbrock.py +++ b/scripts/timing_rosenbrock.py @@ -1,6 +1,7 @@ """ author: Fabian Schaipp """ + import timeit import numpy as np diff --git a/src/ncopt/sqpgs/main.py b/src/ncopt/sqpgs/main.py index a002815..a010779 100644 --- a/src/ncopt/sqpgs/main.py +++ b/src/ncopt/sqpgs/main.py @@ -3,7 +3,7 @@ Implements the SQP-GS algorithm from - Frank E. Curtis and Michael L. Overton, A sequential quadratic programming + Frank E. Curtis and Michael L. Overton, A sequential quadratic programming algorithm for nonconvex, nonsmooth constrained optimization, SIAM Journal on Optimization 2012 22:2, 474-500, https://doi.org/10.1137/090780201. @@ -17,8 +17,8 @@ import cvxopt as cx import numpy as np -from ncopt.utils import get_logger from ncopt.sqpgs.defaults import DEFAULT_ARG, DEFAULT_OPTION +from ncopt.utils import get_logger class SQPGS: From dc8a27ef7227d1d02ddecc90dae84dd369902df9 Mon Sep 17 00:00:00 2001 From: phschiele Date: Thu, 7 Mar 2024 13:55:10 +0100 Subject: [PATCH 3/7] Split action steps --- .github/workflows/unit_tests.yml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 9e39d0f..1d54272 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -27,7 +27,6 @@ jobs: - uses: actions/setup-python@v5 with: python-version: "3.9" - - run: | - pipx install hatch - hatch env create - hatch run test + - run: pipx install hatch + - run: hatch env create + - run: hatch run test From 23c7cfab3db9c1c76d9da7c9d81cf58479b958c6 Mon Sep 17 00:00:00 2001 From: phschiele Date: Thu, 7 Mar 2024 13:57:36 +0100 Subject: [PATCH 4/7] Adds matplotlib to dev dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 98bfb34..7644d97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ path = "src/ncopt/__about__.py" dependencies = [ "coverage[toml]>=6.5", "pytest", + "matplotlib", ] [tool.hatch.envs.default.scripts] From fb6ad92a370636a3f1edc528f5a065f5964c62b5 Mon Sep 17 00:00:00 2001 From: Fabian Schaipp Date: Thu, 7 Mar 2024 16:39:13 +0100 Subject: [PATCH 5/7] update readme --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 418fbdb..bc330e9 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,14 @@ This repository contains a `Python` implementation of the SQP-GS (*Sequential Qu **Note:** this implementation is a **prototype code**, it has been tested only for a simple problem and it is not performance-optimized. A Matlab implementation is available from the authors of the paper, see [2]. 
+## Installation
+
+If you want to install an editable version of this package in your Python environment, run the command
+
+```
+    python -m pip install --editable .
+```
+
 ## Mathematical description
 
 The algorithm can solve problems of the form

From 2a96948af1c8bd5c3277f3e836ed1a443e10861f Mon Sep 17 00:00:00 2001
From: Fabian Schaipp
Date: Thu, 7 Mar 2024 16:44:18 +0100
Subject: [PATCH 6/7] import in init for convenience

---
 src/ncopt/sqpgs/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/ncopt/sqpgs/__init__.py b/src/ncopt/sqpgs/__init__.py
index e69de29..fc01377 100644
--- a/src/ncopt/sqpgs/__init__.py
+++ b/src/ncopt/sqpgs/__init__.py
@@ -0,0 +1 @@
+from .main import SQPGS
\ No newline at end of file

From df0798b48737bbb71c889c0470abbc03e7f8db94 Mon Sep 17 00:00:00 2001
From: Fabian Schaipp
Date: Thu, 7 Mar 2024 17:09:24 +0100
Subject: [PATCH 7/7] fix formatting

---
 src/ncopt/sqpgs/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ncopt/sqpgs/__init__.py b/src/ncopt/sqpgs/__init__.py
index fc01377..dcab673 100644
--- a/src/ncopt/sqpgs/__init__.py
+++ b/src/ncopt/sqpgs/__init__.py
@@ -1 +1 @@
-from .main import SQPGS
\ No newline at end of file
+from .main import SQPGS  # noqa
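
A quick usage sketch, not part of the patch series itself: the tests added in PATCH 1 exercise the functional solver `SQP_GS`, while PATCH 2 onwards refactor toward a `SQPGS` class in `src/ncopt/sqpgs/main.py`, which PATCH 6 re-exports from the package `__init__`. The class constructor is not shown in these diffs, so the sketch below sticks to the functional API exercised in `tests/test_rosenbrock.py`.

```python
# Sketch: solving the nonsmooth Rosenbrock problem, mirroring
# tests/test_rosenbrock.py from PATCH 1 of this series. Assumes the
# PATCH-1 layout, where SQP_GS is importable from ncopt.sqpgs.
import numpy as np

from ncopt.funs import f_rosenbrock, g_max
from ncopt.sqpgs import SQP_GS

f = f_rosenbrock()  # objective: x -> w*|x_1^2 - x_2| + (1 - x_1)^2
gI = [g_max()]      # inequality constraint: max(c1*x_1, c2*x_2) - 1 <= 0
gE = []             # no equality constraints

x0 = np.zeros(2)
x_k, x_hist, SP = SQP_GS(f, gI, gE, x0, tol=1e-8, max_iter=200, verbose=False)

print(x_k)  # expected to approach [1/sqrt(2), 0.5], the solution asserted in the tests
```

Against the post-series layout, the import would instead be `from ncopt.sqpgs import SQPGS`, whose interface lives in `src/ncopt/sqpgs/main.py` and is not visible in this diff.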