diff --git a/l0bnb/node/core.py b/l0bnb/node/core.py index 3936eb0..8efa25e 100644 --- a/l0bnb/node/core.py +++ b/l0bnb/node/core.py @@ -78,23 +78,22 @@ def __init__(self, parent, zlb: list, zub: list, **kwargs): self.gs_xtr = np.copy(parent.gs_xtr) self.gs_xb = np.copy(parent.gs_xb) - def lower_solve(self, l0, l2, m, solver, rel_tol, int_tol=1e-6, tree_upper_bound=None, mio_gap=None): if solver == 'l1cd': - sol, gs_xtr, gs_xb = cd_solve(x=self.x, y=self.y, l0=l0, l2=l2, m=m, zlb=self.zlb, + sol = cd_solve(x=self.x, y=self.y, l0=l0, l2=l2, m=m, zlb=self.zlb, zub=self.zub, xi_norm=self.xi_norm, rel_tol=rel_tol, warm_start=self.warm_start, r=self.r, tree_upper_bound=tree_upper_bound, mio_gap=mio_gap, - gs_xtr=self.gs_xtr, gs_xb=self.gs_xb, return_gs=True) + gs_xtr=self.gs_xtr, gs_xb=self.gs_xb) self.primal_value = sol.primal_value self.dual_value = sol.dual_value self.primal_beta = sol.primal_beta self.z = sol.z self.support = sol.support self.r = sol.r - self.gs_xtr = gs_xtr - self.gs_xb = gs_xb + self.gs_xtr = sol.gs_xtr + self.gs_xb = sol.gs_xb else: full_zlb = np.zeros(self.x.shape[1]) full_zlb[self.zlb] = 1 diff --git a/l0bnb/relaxation/core.py b/l0bnb/relaxation/core.py index 624eb9e..8fa91a0 100644 --- a/l0bnb/relaxation/core.py +++ b/l0bnb/relaxation/core.py @@ -70,7 +70,7 @@ def _initialize(x, y, l0, l2, m, fixed_lb, fixed_ub, xi_norm, warm_start, r): @njit(cache=True, parallel=True) def _above_threshold_indices(zub, r, x, threshold): rx = r @ x - above_threshold = np.where(zub * np.abs(r @ x) - threshold > 0)[0] + above_threshold = np.where(zub * np.abs(rx) - threshold > 0)[0] return above_threshold, rx @@ -114,9 +114,10 @@ def _above_threshold_indices_gs(zub, r, x, y, threshold, gs_xtr, gs_xb, beta): def solve(x, y, l0, l2, m, zlb, zub, gs_xtr, gs_xb, xi_norm=None, warm_start=None, r=None, rel_tol=1e-4, tree_upper_bound=None, mio_gap=0, - check_if_integral=True, return_gs=False): + check_if_integral=True): st = time() - _sol_str = 'primal_value dual_value support primal_beta sol_time z r' + _sol_str = \ + 'primal_value dual_value support primal_beta sol_time z r gs_xtr gs_xb' Solution = namedtuple('Solution', _sol_str) beta, r, support, zub, zlb, xi_norm = \ @@ -128,8 +129,9 @@ def solve(x, y, l0, l2, m, zlb, zub, gs_xtr, gs_xb, xi_norm=None, beta, cost, r = cd(x, beta, cost, l0, l2, m, xi_norm, zlb, zub, support, r, cd_tol) if GS_FLAG and gs_xtr is None: - above_threshold, rx, gs_xtr, gs_xb = _above_threshold_indices_root_first_call_gs( - zub, r, x, y, threshold) + above_threshold, rx, gs_xtr, gs_xb = \ + _above_threshold_indices_root_first_call_gs( + zub, r, x, y, threshold) elif GS_FLAG: above_threshold, rx, gs_xtr, gs_xb = _above_threshold_indices_gs( zub, r, x, y, threshold, gs_xtr, gs_xb, beta) @@ -169,18 +171,13 @@ def solve(x, y, l0, l2, m, zlb, zub, gs_xtr, gs_xb, xi_norm=None, if is_integral(z_active, 1e-4): ws = {i: j for i, j in zip(active_set, beta_active)} sol = solve(x=x, y=y, l0=l0, l2=l2, m=m, zlb=zlb_active, - zub=zub_active, xi_norm=xi_norm, warm_start=ws, - r=r, rel_tol=rel_tol, - tree_upper_bound=tree_upper_bound, mio_gap=1, - check_if_integral=False) - if return_gs: - return sol, gs_xtr, gs_xb - else: - return sol + zub=zub_active, gs_xtr=gs_xtr, gs_xb=gs_xb, + xi_norm=xi_norm, warm_start=ws, r=r, + rel_tol=rel_tol, tree_upper_bound=tree_upper_bound, + mio_gap=1, check_if_integral=False) + return sol sol = Solution(primal_value=primal_cost, dual_value=dual_cost, support=active_set, primal_beta=beta_active, - sol_time=time() - st, z=z_active, r=r) - if return_gs: - return sol, gs_xtr, gs_xb - else: - return sol + sol_time=time() - st, z=z_active, r=r, gs_xtr=gs_xtr, + gs_xb=gs_xb) + return sol