Skip to content

Commit

Permalink
aggregate all relaxation output to Solution
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaab committed Apr 19, 2021
1 parent d784a70 commit 059ce32
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 23 deletions.
9 changes: 4 additions & 5 deletions l0bnb/node/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,22 @@ def __init__(self, parent, zlb: list, zub: list, **kwargs):
self.gs_xtr = np.copy(parent.gs_xtr)
self.gs_xb = np.copy(parent.gs_xb)


def lower_solve(self, l0, l2, m, solver, rel_tol, int_tol=1e-6,
tree_upper_bound=None, mio_gap=None):
if solver == 'l1cd':
sol, gs_xtr, gs_xb = cd_solve(x=self.x, y=self.y, l0=l0, l2=l2, m=m, zlb=self.zlb,
sol = cd_solve(x=self.x, y=self.y, l0=l0, l2=l2, m=m, zlb=self.zlb,
zub=self.zub, xi_norm=self.xi_norm, rel_tol=rel_tol,
warm_start=self.warm_start, r=self.r,
tree_upper_bound=tree_upper_bound, mio_gap=mio_gap,
gs_xtr=self.gs_xtr, gs_xb=self.gs_xb, return_gs=True)
gs_xtr=self.gs_xtr, gs_xb=self.gs_xb)
self.primal_value = sol.primal_value
self.dual_value = sol.dual_value
self.primal_beta = sol.primal_beta
self.z = sol.z
self.support = sol.support
self.r = sol.r
self.gs_xtr = gs_xtr
self.gs_xb = gs_xb
self.gs_xtr = sol.gs_xtr
self.gs_xb = sol.gs_xb
else:
full_zlb = np.zeros(self.x.shape[1])
full_zlb[self.zlb] = 1
Expand Down
33 changes: 15 additions & 18 deletions l0bnb/relaxation/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _initialize(x, y, l0, l2, m, fixed_lb, fixed_ub, xi_norm, warm_start, r):
@njit(cache=True, parallel=True)
def _above_threshold_indices(zub, r, x, threshold):
rx = r @ x
above_threshold = np.where(zub * np.abs(r @ x) - threshold > 0)[0]
above_threshold = np.where(zub * np.abs(rx) - threshold > 0)[0]
return above_threshold, rx


Expand Down Expand Up @@ -114,9 +114,10 @@ def _above_threshold_indices_gs(zub, r, x, y, threshold, gs_xtr, gs_xb, beta):
def solve(x, y, l0, l2, m, zlb, zub, gs_xtr, gs_xb, xi_norm=None,
warm_start=None, r=None,
rel_tol=1e-4, tree_upper_bound=None, mio_gap=0,
check_if_integral=True, return_gs=False):
check_if_integral=True):
st = time()
_sol_str = 'primal_value dual_value support primal_beta sol_time z r'
_sol_str = \
'primal_value dual_value support primal_beta sol_time z r gs_xtr gs_xb'
Solution = namedtuple('Solution', _sol_str)

beta, r, support, zub, zlb, xi_norm = \
Expand All @@ -128,8 +129,9 @@ def solve(x, y, l0, l2, m, zlb, zub, gs_xtr, gs_xb, xi_norm=None,
beta, cost, r = cd(x, beta, cost, l0, l2, m, xi_norm, zlb, zub,
support, r, cd_tol)
if GS_FLAG and gs_xtr is None:
above_threshold, rx, gs_xtr, gs_xb = _above_threshold_indices_root_first_call_gs(
zub, r, x, y, threshold)
above_threshold, rx, gs_xtr, gs_xb = \
_above_threshold_indices_root_first_call_gs(
zub, r, x, y, threshold)
elif GS_FLAG:
above_threshold, rx, gs_xtr, gs_xb = _above_threshold_indices_gs(
zub, r, x, y, threshold, gs_xtr, gs_xb, beta)
Expand Down Expand Up @@ -169,18 +171,13 @@ def solve(x, y, l0, l2, m, zlb, zub, gs_xtr, gs_xb, xi_norm=None,
if is_integral(z_active, 1e-4):
ws = {i: j for i, j in zip(active_set, beta_active)}
sol = solve(x=x, y=y, l0=l0, l2=l2, m=m, zlb=zlb_active,
zub=zub_active, xi_norm=xi_norm, warm_start=ws,
r=r, rel_tol=rel_tol,
tree_upper_bound=tree_upper_bound, mio_gap=1,
check_if_integral=False)
if return_gs:
return sol, gs_xtr, gs_xb
else:
return sol
zub=zub_active, gs_xtr=gs_xtr, gs_xb=gs_xb,
xi_norm=xi_norm, warm_start=ws, r=r,
rel_tol=rel_tol, tree_upper_bound=tree_upper_bound,
mio_gap=1, check_if_integral=False)
return sol
sol = Solution(primal_value=primal_cost, dual_value=dual_cost,
support=active_set, primal_beta=beta_active,
sol_time=time() - st, z=z_active, r=r)
if return_gs:
return sol, gs_xtr, gs_xb
else:
return sol
sol_time=time() - st, z=z_active, r=r, gs_xtr=gs_xtr,
gs_xb=gs_xb)
return sol

0 comments on commit 059ce32

Please sign in to comment.