1
+ import ast
1
2
import itertools
3
+ import logging
4
+ from io import StringIO
5
+ from tempfile import TemporaryDirectory
2
6
from unittest import TestCase
7
+
8
+ import cvxpy as cp
3
9
import test_E2E_LP
4
10
import test_E2E_QP
5
11
import test_E2E_SOCP
6
12
13
+ from cvxpygen .cpg import (
14
+ get_configuration ,
15
+ get_constraint_info ,
16
+ get_dual_variable_info ,
17
+ get_interface_class ,
18
+ get_parameter_info ,
19
+ get_variable_info ,
20
+ handle_sparsity ,
21
+ )
22
+ from cvxpygen .utils import write_interface
23
+
24
+
7
25
class test_stub_gen (TestCase ):
8
26
def setUp (self ) -> None :
9
27
self .all_problems = itertools .chain (
10
28
test_E2E_LP .name_to_prob .items (),
11
29
test_E2E_QP .name_to_prob .items (),
12
30
test_E2E_SOCP .name_to_prob .items (),
13
31
)
32
+ self .tempdir = TemporaryDirectory ()
14
33
return super ().setUp ()
34
+
35
+ def tearDown (self ) -> None :
36
+ self .tempdir .cleanup ()
37
+ return super ().tearDown ()
15
38
39
+ def get_codegen_context (self , problem : cp .Problem ):
40
+ # problem data
41
+ data , solving_chain , inverse_data = problem .get_problem_data (
42
+ solver = None ,
43
+ )
44
+ param_prob = data ['param_prob' ]
45
+ solver_name = solving_chain .solver .name ()
46
+ interface_class , cvxpy_interface_class = get_interface_class (solver_name )
47
+
48
+ # configuration
49
+ configuration = get_configuration (self .tempdir , solver_name , False , "" )
50
+
51
+ # cone problems check
52
+ if hasattr (param_prob , 'cone_dims' ):
53
+ cone_dims = param_prob .cone_dims
54
+ interface_class .check_unsupported_cones (cone_dims )
55
+
56
+ handle_sparsity (param_prob )
57
+
58
+ solver_interface = interface_class (data , param_prob , []) # noqa
59
+ variable_info = get_variable_info (problem , inverse_data )
60
+ dual_variable_info = get_dual_variable_info (inverse_data , solver_interface , cvxpy_interface_class )
61
+ parameter_info = get_parameter_info (param_prob )
62
+ constraint_info = get_constraint_info (solver_interface )
63
+ return dict (
64
+ configuration = configuration ,
65
+ solver_interface = solver_interface ,
66
+ variable_info = variable_info ,
67
+ dual_variable_info = dual_variable_info ,
68
+ parameter_info = parameter_info ,
69
+ )
70
+
16
71
def test_stub_valid (self ):
17
- ...
72
+ for name , problem in self .all_problems :
73
+ with StringIO () as f :
74
+ write_interface (f = f , ** self .get_codegen_context (problem ))
75
+ try :
76
+ ast .parse (f .read ())
77
+ except SyntaxError :
78
+ logging .exception (f"Generated stub file for problem { name } has incoorect syntax" )
79
+ raise
0 commit comments