Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User/feiyang/stubfile #49

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
49 changes: 36 additions & 13 deletions cvxpygen/cpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,45 @@
"""

import os
import sys
import shutil
import pickle
import shutil
import sys
import warnings
from subprocess import call

from cvxpygen import utils
from cvxpygen.utils import write_file, read_write_file, write_example_def, write_module_prot, write_module_def, \
write_canon_cmake, write_method, replace_cmake_data, replace_setup_data, replace_html_data
from cvxpygen.mappings import Configuration, PrimalVariableInfo, DualVariableInfo, ConstraintInfo, \
ParameterCanon, ParameterInfo
from cvxpygen.solvers import get_interface_class
import cvxpy as cp
import numpy as np
from scipy import sparse
from subprocess import call
from cvxpy.problems.objective import Maximize
from cvxpy.cvxcore.python import canonInterface as cI
from cvxpy.expressions.variable import upper_tri_to_full
from cvxpy.problems.objective import Maximize
from scipy import sparse


def generate_code(problem, code_dir='CPG_code', solver=None, solver_opts=None,
from cvxpygen import utils
from cvxpygen.mappings import (
Configuration,
ConstraintInfo,
DualVariableInfo,
ParameterCanon,
ParameterInfo,
PrimalVariableInfo,
)
from cvxpygen.solvers import get_interface_class
from cvxpygen.utils import (
read_write_file,
replace_cmake_data,
replace_html_data,
replace_setup_data,
write_canon_cmake,
write_example_def,
write_file,
write_method,
write_interface,
write_module_def,
write_module_prot,
)


def generate_code(problem: cp.Problem, code_dir='CPG_code', solver=None, solver_opts=None,
enable_settings=[], unroll=False, prefix='', wrapper=True):
"""
Generate C code to solve a CVXPY problem
Expand Down Expand Up @@ -375,6 +393,11 @@ def write_c_code(problem: cp.Problem, configuration: Configuration, variable_inf
configuration, variable_info, dual_variable_info,
parameter_info, solver_interface)

write_file(os.path.join(configuration.code_dir, 'cpg_module.pyi'), 'w',
write_interface,
configuration, variable_info, dual_variable_info,
parameter_info, solver_interface)

write_file(os.path.join(configuration.code_dir, 'problem.pickle'), 'wb',
lambda x, y: pickle.dump(y, x),
cp.Problem(problem.objective, problem.constraints))
Expand Down
90 changes: 88 additions & 2 deletions cvxpygen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,18 @@
limitations under the License.
"""

from io import TextIOWrapper
import textwrap
from typing import TYPE_CHECKING, Iterable
import numpy as np
from datetime import datetime


if TYPE_CHECKING:
from cvxpygen.mappings import Configuration, DualVariableInfo, ParameterInfo, VariableInfo
from cvxpygen.solvers import SolverInterface


def write_file(path, mode, function, *args):
"""Write data to a file using a specific utility function."""
with open(path, mode) as file:
Expand Down Expand Up @@ -1248,6 +1256,77 @@ def write_module_prot(f, configuration, parameter_info, variable_info, dual_vari
f'struct {configuration.prefix}CPG_Params_cpp_t& CPG_Params_cpp);\n')


def write_interface(
f: TextIOWrapper,
configuration: "Configuration",
variable_info: "VariableInfo",
dual_variable_info: "DualVariableInfo",
parameter_info: "ParameterInfo",
solver_interface: "SolverInterface",
):
write_description(f, 'py', 'Python extension stub file.')
interface_content = ""

def define_struct(
cls_name: str,
properties: Iterable[str] = [],
methods: Iterable[str] = [],
):
decl_ = ["", f"class {configuration.prefix}{cls_name}:", ""]
for name in properties:
decl_ += [
" @property",
f" def {name}(self):",
" ...",
""
]
for name in methods:
decl_ += [
f" def {name}(self):"
" ...",
""
]

return "\n".join(decl_) + "\n"

interface_content += define_struct("cpg_params", parameter_info.name_to_size_usp.keys())
interface_content += define_struct("cpg_updated", parameter_info.name_to_size_usp.keys())
interface_content += define_struct("cpg_prim", variable_info.name_to_init.keys())

if len(dual_variable_info.name_to_init) > 0:
interface_content += define_struct("cpg_dual", dual_variable_info.name_to_init.keys())

interface_content += define_struct(
"cpg_info",
properties=[
"obj_val",
"iter",
"status",
"pri_res",
"dua_res",
"time",
]
)

interface_content += define_struct(
"cpg_result",
["cpg_prim", "cpg_info"] + (["cpg_dual"] if len(dual_variable_info.name_to_init) > 0 else [])
)

interface_content += "\ndef solve(arg0: cpg_updated, arg1: cpg_params):\n ...\n"

interface_content += "\ndef set_solver_default_settings():\n ...\n"
for name, type_ in solver_interface.stgs_names_to_type.items():
pytype = type_.removeprefix("cpg_")
match pytype:
case "const char*":
pytype = "str"
interface_content += f"\ndef set_solver_{name}(arg0: {pytype}):\n ...\n"

f.write(
interface_content
)

def replace_setup_data(text):
"""
Replace placeholder strings in setup.py file
Expand All @@ -1258,7 +1337,14 @@ def replace_setup_data(text):
return text.replace('%DATE', now.strftime("on %B %d, %Y at %H:%M:%S"))


def write_method(f, configuration, variable_info, dual_variable_info, parameter_info, solver_interface):
def write_method(
f: TextIOWrapper,
configuration: "Configuration",
variable_info: "VariableInfo",
dual_variable_info: "DualVariableInfo",
parameter_info: "ParameterInfo",
solver_interface: "SolverInterface",
):
"""
Write function to be registered as custom CVXPY solve method
"""
Expand Down Expand Up @@ -1289,7 +1375,7 @@ def write_method(f, configuration, variable_info, dual_variable_info, parameter_
f.write(' cpg_module.set_solver_default_settings()\n')
f.write(' for key, value in kwargs.items():\n')
f.write(' try:\n')
f.write(' eval(f\'cpg_module.set_solver_{standard_settings_names.get(key, key)}(value)\')\n')
f.write(' getattr(cpg_module, f\'set_solver_{standard_settings_names.get(key, key)}\')(value)\n')
f.write(' except AttributeError:\n')
f.write(' raise AttributeError(f\'Solver setting "{key}" not available.\')\n\n')

Expand Down
79 changes: 79 additions & 0 deletions tests/test_stub_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import ast
import itertools
import logging
from io import StringIO
from tempfile import TemporaryDirectory
from unittest import TestCase

import cvxpy as cp
import test_E2E_LP
import test_E2E_QP
import test_E2E_SOCP

from cvxpygen.cpg import (
get_configuration,
get_constraint_info,
get_dual_variable_info,
get_interface_class,
get_parameter_info,
get_variable_info,
handle_sparsity,
)
from cvxpygen.utils import write_interface


class test_stub_gen(TestCase):
def setUp(self) -> None:
self.all_problems = itertools.chain(
test_E2E_LP.name_to_prob.items(),
test_E2E_QP.name_to_prob.items(),
test_E2E_SOCP.name_to_prob.items(),
)
self.tempdir = TemporaryDirectory()
return super().setUp()

def tearDown(self) -> None:
self.tempdir.cleanup()
return super().tearDown()

def get_codegen_context(self, problem: cp.Problem):
# problem data
data, solving_chain, inverse_data = problem.get_problem_data(
solver=None,
)
param_prob = data['param_prob']
solver_name = solving_chain.solver.name()
interface_class, cvxpy_interface_class = get_interface_class(solver_name)

# configuration
configuration = get_configuration(self.tempdir, solver_name, False, "")

# cone problems check
if hasattr(param_prob, 'cone_dims'):
cone_dims = param_prob.cone_dims
interface_class.check_unsupported_cones(cone_dims)

handle_sparsity(param_prob)

solver_interface = interface_class(data, param_prob, []) # noqa
variable_info = get_variable_info(problem, inverse_data)
dual_variable_info = get_dual_variable_info(inverse_data, solver_interface, cvxpy_interface_class)
parameter_info = get_parameter_info(param_prob)
constraint_info = get_constraint_info(solver_interface)
return dict(
configuration=configuration,
solver_interface=solver_interface,
variable_info=variable_info,
dual_variable_info=dual_variable_info,
parameter_info=parameter_info,
)

def test_stub_valid(self):
for name, problem in self.all_problems:
with StringIO() as f:
write_interface(f=f, **self.get_codegen_context(problem))
try:
ast.parse(f.read())
except SyntaxError:
logging.exception(f"Generated stub file for problem {name} has incoorect syntax")
raise