diff --git a/l0bnb/node/core.py b/l0bnb/node/core.py index 8efa25e..4505a38 100644 --- a/l0bnb/node/core.py +++ b/l0bnb/node/core.py @@ -5,7 +5,6 @@ from ..relaxation import cd_solve, l0gurobi, l0mosek from ._utils import upper_bound_solve -GS_FLAG = False class Node: def __init__(self, parent, zlb: list, zub: list, **kwargs): @@ -24,14 +23,14 @@ def __init__(self, parent, zlb: list, zub: list, **kwargs): Other Parameters ---------------- x: np.array - The data matrix (n x p). If not specified the data will be inherited - from the parent node + The data matrix (n x p). If not specified the data will be + inherited from the parent node y: np.array - The data vector (n x 1). If not specified the data will be inherited - from the parent node + The data vector (n x 1). If not specified the data will be + inherited from the parent node xi_xi: np.array - The norm of each column in x (p x 1). If not specified the data will - be inherited from the parent node + The norm of each column in x (p x 1). If not specified the data + will be inherited from the parent node l0: float The zeroth norm coefficient. If not specified the data will be inherited from the parent node @@ -74,9 +73,11 @@ def __init__(self, parent, zlb: list, zub: list, **kwargs): # Gradient screening params. self.gs_xtr = None self.gs_xb = None - if GS_FLAG and parent: - self.gs_xtr = np.copy(parent.gs_xtr) - self.gs_xb = np.copy(parent.gs_xb) + if parent: + if parent.gs_xtr is not None: + self.gs_xtr = parent.gs_xtr.copy() + if parent.gs_xb is not None: + self.gs_xb = parent.gs_xb.copy() def lower_solve(self, l0, l2, m, solver, rel_tol, int_tol=1e-6, tree_upper_bound=None, mio_gap=None): diff --git a/l0bnb/relaxation/__init__.py b/l0bnb/relaxation/__init__.py index 66c5c65..b9fca04 100644 --- a/l0bnb/relaxation/__init__.py +++ b/l0bnb/relaxation/__init__.py @@ -7,6 +7,8 @@ warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) warnings.simplefilter('ignore', category=NumbaPerformanceWarning) +GS_FLAG = False + from .core import solve as cd_solve from .mosek import l0mosek from .gurobi import l0gurobi diff --git a/l0bnb/relaxation/core.py b/l0bnb/relaxation/core.py index 8fa91a0..195797c 100644 --- a/l0bnb/relaxation/core.py +++ b/l0bnb/relaxation/core.py @@ -9,8 +9,7 @@ from ._coordinate_descent import cd_loop, cd from ._cost import get_primal_cost, get_dual_cost from ._utils import get_ratio_threshold, get_active_components - -GS_FLAG = False +from . import GS_FLAG def is_integral(solution, tol): @@ -105,7 +104,7 @@ def _above_threshold_indices_gs(zub, r, x, y, threshold, gs_xtr, gs_xb, beta): rx[beta_supp] = r @ x[:, beta_supp] above_threshold_restricted = \ - np.where(zub[v_hat] * np.abs(rx_restricted) - threshold > 0)[0] + np.where(zub[v_hat] * np.abs(rx_restricted) - threshold > 0)[0] above_threshold = v_hat[above_threshold_restricted] return above_threshold, rx, gs_xtr, gs_xb @@ -115,6 +114,7 @@ def solve(x, y, l0, l2, m, zlb, zub, gs_xtr, gs_xb, xi_norm=None, warm_start=None, r=None, rel_tol=1e-4, tree_upper_bound=None, mio_gap=0, check_if_integral=True): + zlb_main, zub_main = zlb.copy(), zub.copy() st = time() _sol_str = \ 'primal_value dual_value support primal_beta sol_time z r gs_xtr gs_xb' @@ -170,8 +170,8 @@ def solve(x, y, l0, l2, m, zlb, zub, gs_xtr, gs_xb, xi_norm=None, if prim_dual_gap > rel_tol: if is_integral(z_active, 1e-4): ws = {i: j for i, j in zip(active_set, beta_active)} - sol = solve(x=x, y=y, l0=l0, l2=l2, m=m, zlb=zlb_active, - zub=zub_active, gs_xtr=gs_xtr, gs_xb=gs_xb, + sol = solve(x=x, y=y, l0=l0, l2=l2, m=m, zlb=zlb_main, + zub=zub_main, gs_xtr=gs_xtr, gs_xb=gs_xb, xi_norm=xi_norm, warm_start=ws, r=r, rel_tol=rel_tol, tree_upper_bound=tree_upper_bound, mio_gap=1, check_if_integral=False) diff --git a/setup.py b/setup.py index 53a261b..2a09093 100644 --- a/setup.py +++ b/setup.py @@ -15,8 +15,8 @@ def readme(): author_email="alikassemsaab@gmail.com", url='https://github.com/alisaab/l0bnb', download_url="https://github.com/alisaab/l0bnb/archive/0.0.6.tar.gz", - install_requires=["numpy >= 1.18.1", "scipy >= 1.4.1", "numba >= 0.48.0"], - version="0.0.6", + install_requires=["numpy >= 1.18.1", "scipy >= 1.4.1", "numba >= 0.53.1"], + version="0.1.0", packages=find_packages(), include_package_data=True, license=LICENSE,