Skip to content

Commit

Permalink
new version: 0.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaab committed Apr 19, 2021
1 parent 059ce32 commit ec2a5ae
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 17 deletions.
21 changes: 11 additions & 10 deletions l0bnb/node/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ..relaxation import cd_solve, l0gurobi, l0mosek
from ._utils import upper_bound_solve

GS_FLAG = False

class Node:
def __init__(self, parent, zlb: list, zub: list, **kwargs):
Expand All @@ -24,14 +23,14 @@ def __init__(self, parent, zlb: list, zub: list, **kwargs):
Other Parameters
----------------
x: np.array
The data matrix (n x p). If not specified the data will be inherited
from the parent node
The data matrix (n x p). If not specified the data will be
inherited from the parent node
y: np.array
The data vector (n x 1). If not specified the data will be inherited
from the parent node
The data vector (n x 1). If not specified the data will be
inherited from the parent node
xi_xi: np.array
The norm of each column in x (p x 1). If not specified the data will
be inherited from the parent node
The norm of each column in x (p x 1). If not specified the data
will be inherited from the parent node
l0: float
The zeroth norm coefficient. If not specified the data will
be inherited from the parent node
Expand Down Expand Up @@ -74,9 +73,11 @@ def __init__(self, parent, zlb: list, zub: list, **kwargs):
# Gradient screening params.
self.gs_xtr = None
self.gs_xb = None
if GS_FLAG and parent:
self.gs_xtr = np.copy(parent.gs_xtr)
self.gs_xb = np.copy(parent.gs_xb)
if parent:
if parent.gs_xtr is not None:
self.gs_xtr = parent.gs_xtr.copy()
if parent.gs_xb is not None:
self.gs_xb = parent.gs_xb.copy()

def lower_solve(self, l0, l2, m, solver, rel_tol, int_tol=1e-6,
tree_upper_bound=None, mio_gap=None):
Expand Down
2 changes: 2 additions & 0 deletions l0bnb/relaxation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)
warnings.simplefilter('ignore', category=NumbaPerformanceWarning)

GS_FLAG = False

from .core import solve as cd_solve
from .mosek import l0mosek
from .gurobi import l0gurobi
10 changes: 5 additions & 5 deletions l0bnb/relaxation/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from ._coordinate_descent import cd_loop, cd
from ._cost import get_primal_cost, get_dual_cost
from ._utils import get_ratio_threshold, get_active_components

GS_FLAG = False
from . import GS_FLAG


def is_integral(solution, tol):
Expand Down Expand Up @@ -105,7 +104,7 @@ def _above_threshold_indices_gs(zub, r, x, y, threshold, gs_xtr, gs_xb, beta):
rx[beta_supp] = r @ x[:, beta_supp]

above_threshold_restricted = \
np.where(zub[v_hat] * np.abs(rx_restricted) - threshold > 0)[0]
np.where(zub[v_hat] * np.abs(rx_restricted) - threshold > 0)[0]
above_threshold = v_hat[above_threshold_restricted]

return above_threshold, rx, gs_xtr, gs_xb
Expand All @@ -115,6 +114,7 @@ def solve(x, y, l0, l2, m, zlb, zub, gs_xtr, gs_xb, xi_norm=None,
warm_start=None, r=None,
rel_tol=1e-4, tree_upper_bound=None, mio_gap=0,
check_if_integral=True):
zlb_main, zub_main = zlb.copy(), zub.copy()
st = time()
_sol_str = \
'primal_value dual_value support primal_beta sol_time z r gs_xtr gs_xb'
Expand Down Expand Up @@ -170,8 +170,8 @@ def solve(x, y, l0, l2, m, zlb, zub, gs_xtr, gs_xb, xi_norm=None,
if prim_dual_gap > rel_tol:
if is_integral(z_active, 1e-4):
ws = {i: j for i, j in zip(active_set, beta_active)}
sol = solve(x=x, y=y, l0=l0, l2=l2, m=m, zlb=zlb_active,
zub=zub_active, gs_xtr=gs_xtr, gs_xb=gs_xb,
sol = solve(x=x, y=y, l0=l0, l2=l2, m=m, zlb=zlb_main,
zub=zub_main, gs_xtr=gs_xtr, gs_xb=gs_xb,
xi_norm=xi_norm, warm_start=ws, r=r,
rel_tol=rel_tol, tree_upper_bound=tree_upper_bound,
mio_gap=1, check_if_integral=False)
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def readme():
author_email="alikassemsaab@gmail.com",
url='https://github.com/alisaab/l0bnb',
download_url="https://github.com/alisaab/l0bnb/archive/0.0.6.tar.gz",
install_requires=["numpy >= 1.18.1", "scipy >= 1.4.1", "numba >= 0.48.0"],
version="0.0.6",
install_requires=["numpy >= 1.18.1", "scipy >= 1.4.1", "numba >= 0.53.1"],
version="0.1.0",
packages=find_packages(),
include_package_data=True,
license=LICENSE,
Expand Down

0 comments on commit ec2a5ae

Please sign in to comment.