Skip to content

Commit f99dd1c

Browse files
install jax jaxlib at runtime
1 parent aad3ad2 commit f99dd1c

18 files changed

+1737
-1635
lines changed

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,5 +96,8 @@ install: clean ## install the package to the active Python's site-packages
9696
run-examples: ## run all examples with one command
9797
find examples -maxdepth 2 -name "*.py" -exec python3 {} \;
9898

99+
run-booster: ## run all boosting estimators examples with one command
100+
find examples -maxdepth 2 -name "*boost_*.py" -exec python3 {} \;
101+
99102
run-lazy: ## run all lazy estimators examples with one command
100103
find examples -maxdepth 2 -name "*lazy*.py" -exec python3 {} \;

mlsauce.egg-info/PKG-INFO

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Metadata-Version: 2.1
22
Name: mlsauce
3-
Version: 0.20.1
3+
Version: 0.20.2
44
Summary: Miscellaneous Statistical/Machine Learning tools
55
Maintainer: T. Moudiki
66
Maintainer-email: thierry.moudiki@gmail.com
@@ -29,8 +29,6 @@ Requires-Dist: requests
2929
Requires-Dist: scikit-learn
3030
Requires-Dist: scipy
3131
Requires-Dist: tqdm
32-
Requires-Dist: jax
33-
Requires-Dist: jaxlib
3432
Provides-Extra: alldeps
3533
Requires-Dist: numpy>=1.13.0; extra == "alldeps"
3634
Requires-Dist: scipy>=0.19.0; extra == "alldeps"

mlsauce.egg-info/requires.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ requests
88
scikit-learn
99
scipy
1010
tqdm
11-
jax
12-
jaxlib
1311

1412
[alldeps]
1513
numpy>=1.13.0

mlsauce/booster/_booster_classifier.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from . import _boosterc as boosterc
1111
except ImportError:
1212
import _boosterc as boosterc
13-
from ..utils import cluster
13+
from ..utils import cluster, check_and_install
1414

1515

1616
class LSBoostClassifier(BaseEstimator, ClassifierMixin):
@@ -167,6 +167,9 @@ def __init__(
167167
self.degree = degree
168168
self.poly_ = None
169169
self.weights_distr = weights_distr
170+
if self.backend in ("gpu", "tpu"):
171+
check_and_install("jax")
172+
check_and_install("jaxlib")
170173

171174
def fit(self, X, y, **kwargs):
172175
"""Fit Booster (classifier) to training data (X, y)

mlsauce/booster/_booster_regressor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
except ImportError:
1212
import _boosterc as boosterc
1313
from ..predictioninterval import PredictionInterval
14-
from ..utils import cluster
14+
from ..utils import cluster, check_and_install
1515

1616

1717
class LSBoostRegressor(BaseEstimator, RegressorMixin):
@@ -183,6 +183,9 @@ def __init__(
183183
self.degree = degree
184184
self.poly_ = None
185185
self.weights_distr = weights_distr
186+
if self.backend in ("gpu", "tpu"):
187+
check_and_install("jax")
188+
check_and_install("jaxlib")
186189

187190
def fit(self, X, y, **kwargs):
188191
"""Fit Booster (regressor) to training data (X, y)

mlsauce/elasticnet/enet.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
from sklearn.base import BaseEstimator
55
from sklearn.base import RegressorMixin
66
from numpy.linalg import inv
7-
from ..utils import get_beta
7+
from ..utils import get_beta, check_and_install
88
from ._enet import fit_elasticnet, predict_elasticnet
99

10-
if platform.system() in ("Linux", "Darwin"):
10+
try:
1111
import jax.numpy as jnp
1212
from jax import device_put
1313
from jax.numpy.linalg import inv as jinv
14+
except ImportError:
15+
pass
1416

1517

1618
class ElasticNetRegressor(BaseEstimator, RegressorMixin):
@@ -48,6 +50,9 @@ def __init__(self, reg_lambda=0.1, alpha=0.5, backend="cpu"):
4850
self.reg_lambda = reg_lambda
4951
self.alpha = alpha
5052
self.backend = backend
53+
if self.backend in ("gpu", "tpu"):
54+
check_and_install("jax")
55+
check_and_install("jaxlib")
5156

5257
def fit(self, X, y, **kwargs):
5358
"""Fit matrixops (classifier) to training data (X, y)

mlsauce/lasso/_lasso.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
from . import _lassoc as mo
1111
except ImportError:
1212
import _lassoc as mo
13-
from ..utils import get_beta
13+
from ..utils import get_beta, check_and_install
1414

15-
if platform.system() in ("Linux", "Darwin"):
15+
try:
1616
import jax.numpy as jnp
1717
from jax import device_put
1818
from jax.numpy.linalg import inv as jinv
19+
except ImportError:
20+
pass
1921

2022

2123
class LassoRegressor(BaseEstimator, RegressorMixin):
@@ -56,6 +58,9 @@ def __init__(self, reg_lambda=0.1, max_iter=10, tol=1e-3, backend="cpu"):
5658
self.max_iter = max_iter
5759
self.tol = tol
5860
self.backend = backend
61+
if self.backend in ("gpu", "tpu"):
62+
check_and_install("jax")
63+
check_and_install("jaxlib")
5964

6065
def fit(self, X, y, **kwargs):
6166
"""Fit matrixops (classifier) to training data (X, y)

0 commit comments

Comments
 (0)