Installation

Stable version

The latest stable release of GPJax can be installed via pip:

pip install gpjax

:::{note} We recommend you check your installation version:

python -c 'import gpjax; print(gpjax.__version__)'

:::
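The version check above can also be done without importing the package, which helps when an import itself fails. The following is a minimal sketch using the standard library's `importlib.metadata`. The helper name `installed_versions` is hypothetical, and the sketch assumes the distributions are published under the names `gpjax`, `jax`, and `jaxlib`.

```python
from importlib import metadata


def installed_versions(packages=("gpjax", "jax", "jaxlib")):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name}: {version or 'not installed'}")
```

A `None` entry means the package is missing from the active environment, which is often a sign that the wrong interpreter or virtual environment is in use.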

GPU support

GPU support is enabled by configuring the underlying JAX installation. CPU-enabled builds of both packages (jax and jaxlib) are installed as part of the GPJax installation. For GPU support in JAX, run the following commands:

# Specify your installed CUDA version.
CUDA_VERSION=11.0
pip install jaxlib

Then, within a Python shell, run the following to check the installed jaxlib version:

import jaxlib
print(jaxlib.__version__)
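The jaxlib version alone does not show whether JAX can actually see a GPU. As a rough sketch, the snippet below reports the active backend and the visible devices using `jax.default_backend()` and `jax.devices()`. The helper name `describe_jax_backend` is hypothetical, and the sketch assumes a working JAX install; it returns `None` when JAX cannot be imported.

```python
def describe_jax_backend():
    """Return the active JAX backend and visible devices, or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        return None
    return {
        # Name of the default platform, e.g. "cpu" or "gpu".
        "backend": jax.default_backend(),
        # Human-readable description of every device JAX can use.
        "devices": [str(device) for device in jax.devices()],
    }


info = describe_jax_backend()
if info is None:
    print("JAX is not importable in this environment.")
else:
    print(f"backend: {info['backend']}")
    print(f"devices: {info['devices']}")
```

If the reported backend is `cpu` on a machine with a GPU, the installed jaxlib build is most likely the CPU-only one.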

Development version

:::{warning} This version is possibly unstable and may contain bugs. :::

The latest development version of GPJax can be installed by running the following:

git clone https://github.com/thomaspinder/GPJax.git
cd GPJax
python setup.py develop

:::{tip} We advise you to create a virtual environment before installing:

conda create -n gpjax_experimental python=3.10.0
conda activate gpjax_experimental

We also recommend checking that your installation passes the supplied unit tests:

python -m pytest tests/

:::