diff --git a/cvxpygen/cpg.py b/cvxpygen/cpg.py index ab9561c..55add54 100644 --- a/cvxpygen/cpg.py +++ b/cvxpygen/cpg.py @@ -12,27 +12,45 @@ """ import os -import sys -import shutil import pickle +import shutil +import sys import warnings +from subprocess import call -from cvxpygen import utils -from cvxpygen.utils import write_file, read_write_file, write_example_def, write_module_prot, write_module_def, \ - write_canon_cmake, write_method, replace_cmake_data, replace_setup_data, replace_html_data -from cvxpygen.mappings import Configuration, PrimalVariableInfo, DualVariableInfo, ConstraintInfo, \ - ParameterCanon, ParameterInfo -from cvxpygen.solvers import get_interface_class import cvxpy as cp import numpy as np -from scipy import sparse -from subprocess import call -from cvxpy.problems.objective import Maximize from cvxpy.cvxcore.python import canonInterface as cI from cvxpy.expressions.variable import upper_tri_to_full +from cvxpy.problems.objective import Maximize +from scipy import sparse - -def generate_code(problem, code_dir='CPG_code', solver=None, solver_opts=None, +from cvxpygen import utils +from cvxpygen.mappings import ( + Configuration, + ConstraintInfo, + DualVariableInfo, + ParameterCanon, + ParameterInfo, + PrimalVariableInfo, +) +from cvxpygen.solvers import get_interface_class +from cvxpygen.utils import ( + read_write_file, + replace_cmake_data, + replace_html_data, + replace_setup_data, + write_canon_cmake, + write_example_def, + write_file, + write_method, + write_interface, + write_module_def, + write_module_prot, +) + + +def generate_code(problem: cp.Problem, code_dir='CPG_code', solver=None, solver_opts=None, enable_settings=[], unroll=False, prefix='', wrapper=True): """ Generate C code to solve a CVXPY problem @@ -375,6 +393,11 @@ def write_c_code(problem: cp.Problem, configuration: Configuration, variable_inf configuration, variable_info, dual_variable_info, parameter_info, solver_interface) + write_file(os.path.join(configuration.code_dir, 'cpg_module.pyi'), 'w', + write_interface, + configuration, variable_info, dual_variable_info, + parameter_info, solver_interface) + write_file(os.path.join(configuration.code_dir, 'problem.pickle'), 'wb', lambda x, y: pickle.dump(y, x), cp.Problem(problem.objective, problem.constraints)) diff --git a/cvxpygen/utils.py b/cvxpygen/utils.py index 3a9b931..5cb9518 100644 --- a/cvxpygen/utils.py +++ b/cvxpygen/utils.py @@ -11,10 +11,18 @@ limitations under the License. """ +from io import TextIOWrapper +import textwrap +from typing import TYPE_CHECKING, Iterable import numpy as np from datetime import datetime +if TYPE_CHECKING: + from cvxpygen.mappings import Configuration, DualVariableInfo, ParameterInfo, VariableInfo + from cvxpygen.solvers import SolverInterface + + def write_file(path, mode, function, *args): """Write data to a file using a specific utility function.""" with open(path, mode) as file: @@ -1248,6 +1256,77 @@ def write_module_prot(f, configuration, parameter_info, variable_info, dual_vari f'struct {configuration.prefix}CPG_Params_cpp_t& CPG_Params_cpp);\n') +def write_interface( + f: TextIOWrapper, + configuration: "Configuration", + variable_info: "VariableInfo", + dual_variable_info: "DualVariableInfo", + parameter_info: "ParameterInfo", + solver_interface: "SolverInterface", +): + write_description(f, 'py', 'Python extension stub file.') + interface_content = "" + + def define_struct( + cls_name: str, + properties: Iterable[str] = [], + methods: Iterable[str] = [], + ): + decl_ = ["", f"class {configuration.prefix}{cls_name}:", ""] + for name in properties: + decl_ += [ + " @property", + f" def {name}(self):", + " ...", + "" + ] + for name in methods: + decl_ += [ + f" def {name}(self):" + " ...", + "" + ] + + return "\n".join(decl_) + "\n" + + interface_content += define_struct("cpg_params", parameter_info.name_to_size_usp.keys()) + interface_content += define_struct("cpg_updated", parameter_info.name_to_size_usp.keys()) + interface_content += define_struct("cpg_prim", variable_info.name_to_init.keys()) + + if len(dual_variable_info.name_to_init) > 0: + interface_content += define_struct("cpg_dual", dual_variable_info.name_to_init.keys()) + + interface_content += define_struct( + "cpg_info", + properties=[ + "obj_val", + "iter", + "status", + "pri_res", + "dua_res", + "time", + ] + ) + + interface_content += define_struct( + "cpg_result", + ["cpg_prim", "cpg_info"] + (["cpg_dual"] if len(dual_variable_info.name_to_init) > 0 else []) + ) + + interface_content += "\ndef solve(arg0: cpg_updated, arg1: cpg_params):\n ...\n" + + interface_content += "\ndef set_solver_default_settings():\n ...\n" + for name, type_ in solver_interface.stgs_names_to_type.items(): + pytype = type_.removeprefix("cpg_") + match pytype: + case "const char*": + pytype = "str" + interface_content += f"\ndef set_solver_{name}(arg0: {pytype}):\n ...\n" + + f.write( + interface_content + ) + def replace_setup_data(text): """ Replace placeholder strings in setup.py file @@ -1258,7 +1337,14 @@ def replace_setup_data(text): return text.replace('%DATE', now.strftime("on %B %d, %Y at %H:%M:%S")) -def write_method(f, configuration, variable_info, dual_variable_info, parameter_info, solver_interface): +def write_method( + f: TextIOWrapper, + configuration: "Configuration", + variable_info: "VariableInfo", + dual_variable_info: "DualVariableInfo", + parameter_info: "ParameterInfo", + solver_interface: "SolverInterface", +): """ Write function to be registered as custom CVXPY solve method """ @@ -1289,7 +1375,7 @@ def write_method(f, configuration, variable_info, dual_variable_info, parameter_ f.write(' cpg_module.set_solver_default_settings()\n') f.write(' for key, value in kwargs.items():\n') f.write(' try:\n') - f.write(' eval(f\'cpg_module.set_solver_{standard_settings_names.get(key, key)}(value)\')\n') + f.write(' getattr(cpg_module, f\'set_solver_{standard_settings_names.get(key, key)}\')(value)\n') f.write(' except AttributeError:\n') f.write(' raise AttributeError(f\'Solver setting "{key}" not available.\')\n\n') diff --git a/tests/test_stub_gen.py b/tests/test_stub_gen.py new file mode 100644 index 0000000..89988d9 --- /dev/null +++ b/tests/test_stub_gen.py @@ -0,0 +1,79 @@ +import ast +import itertools +import logging +from io import StringIO +from tempfile import TemporaryDirectory +from unittest import TestCase + +import cvxpy as cp +import test_E2E_LP +import test_E2E_QP +import test_E2E_SOCP + +from cvxpygen.cpg import ( + get_configuration, + get_constraint_info, + get_dual_variable_info, + get_interface_class, + get_parameter_info, + get_variable_info, + handle_sparsity, +) +from cvxpygen.utils import write_interface + + +class test_stub_gen(TestCase): + def setUp(self) -> None: + self.all_problems = itertools.chain( + test_E2E_LP.name_to_prob.items(), + test_E2E_QP.name_to_prob.items(), + test_E2E_SOCP.name_to_prob.items(), + ) + self.tempdir = TemporaryDirectory() + return super().setUp() + + def tearDown(self) -> None: + self.tempdir.cleanup() + return super().tearDown() + + def get_codegen_context(self, problem: cp.Problem): + # problem data + data, solving_chain, inverse_data = problem.get_problem_data( + solver=None, + ) + param_prob = data['param_prob'] + solver_name = solving_chain.solver.name() + interface_class, cvxpy_interface_class = get_interface_class(solver_name) + + # configuration + configuration = get_configuration(self.tempdir, solver_name, False, "") + + # cone problems check + if hasattr(param_prob, 'cone_dims'): + cone_dims = param_prob.cone_dims + interface_class.check_unsupported_cones(cone_dims) + + handle_sparsity(param_prob) + + solver_interface = interface_class(data, param_prob, []) # noqa + variable_info = get_variable_info(problem, inverse_data) + dual_variable_info = get_dual_variable_info(inverse_data, solver_interface, cvxpy_interface_class) + parameter_info = get_parameter_info(param_prob) + constraint_info = get_constraint_info(solver_interface) + return dict( + configuration=configuration, + solver_interface=solver_interface, + variable_info=variable_info, + dual_variable_info=dual_variable_info, + parameter_info=parameter_info, + ) + + def test_stub_valid(self): + for name, problem in self.all_problems: + with StringIO() as f: + write_interface(f=f, **self.get_codegen_context(problem)) + try: + ast.parse(f.read()) + except SyntaxError: + logging.exception(f"Generated stub file for problem {name} has incoorect syntax") + raise \ No newline at end of file