feat: ToCupy operator #622

Merged · 6 commits · Nov 24, 2024
Binary file added docs/source/_static/cupy_diagram.png
Binary file added docs/source/_static/numpy_cupy_bd_diagram.png
Binary file added docs/source/_static/numpy_cupy_vs_diagram.png
2 changes: 2 additions & 0 deletions docs/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ Basic operators
Real
Imag
Conj
ToCupy


Smoothing and derivatives
~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
226 changes: 169 additions & 57 deletions docs/source/gpu.rst
Expand Up @@ -29,6 +29,172 @@ provide data vectors to the solvers, e.g., when using
For JAX, apart from following the same procedure described for CuPy, the PyLops operator must
be also wrapped into a :class:`pylops.JaxOperator`.

See below for a comprehensive list of supported operators and additional functionalities for both the
``cupy`` and ``jax`` backends.


Examples
--------

Let's now briefly look at some use cases.

End-to-end GPU powered inverse problems
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

First we consider the most common scenario when both the model and data
vectors fit onto the GPU memory. We can therefore simply replace all our
``numpy`` arrays with ``cupy`` arrays and solve the inverse problem of
interest end-to-end on the GPU.

.. image:: _static/cupy_diagram.png
:width: 600
:align: center

Let's first write a code snippet using ``numpy`` arrays, which PyLops
will run on your CPU:

.. code-block:: python

ny, nx = 400, 400
G = np.random.normal(0, 1, (ny, nx)).astype(np.float32)
x = np.ones(nx, dtype=np.float32)

# Create operator
Gop = MatrixMult(G, dtype='float32')

# Create data and invert
y = Gop @ x
xest = Gop / y

Now we write a code snippet using ``cupy`` arrays, which PyLops will run on
your GPU:

.. code-block:: python

ny, nx = 400, 400
G = cp.random.normal(0, 1, (ny, nx)).astype(np.float32)
x = cp.ones(nx, dtype=np.float32)

# Create operator
Gop = MatrixMult(G, dtype='float32')

# Create data and invert
y = Gop @ x
xest = Gop / y

The code is almost unchanged, apart from the fact that we now use ``cupy`` arrays:
PyLops will automatically infer the backend from the input arrays and run the computation on the GPU.
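How is the backend inferred? Internally, dispatch of this kind keys off the module that created the input array. The snippet below is a simplified, hypothetical sketch of the idea (``infer_backend`` is not a PyLops function; the library uses ``cupy.get_array_module``-style helpers), runnable with numpy alone:

```python
import numpy as np

def infer_backend(x):
    # Hypothetical helper: the top-level module of the array's type
    # identifies the backend ("numpy", "cupy", ...) that created it
    return type(x).__module__.split(".")[0]

x = np.ones(4, dtype=np.float32)
backend = infer_backend(x)  # "numpy" for a host array
```

On a machine with CuPy installed, passing a ``cupy`` array through the same helper would return ``"cupy"``, which is the hook that lets the operator pick the matching compute backend.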

Similarly, we write a code snippet using ``jax`` arrays which PyLops will run on
your GPU/TPU:

.. code-block:: python

ny, nx = 400, 400
G = jnp.array(np.random.normal(0, 1, (ny, nx)).astype(np.float32))
x = jnp.ones(nx, dtype=np.float32)

# Create operator
Gop = JaxOperator(MatrixMult(G, dtype='float32'))

# Create data and invert
y = Gop @ x
xest = Gop / y

# Adjoint via AD
xadj = Gop.rmatvecad(x, y)

Again, the code is almost unchanged apart from the fact that we now use ``jax`` arrays.
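For a linear operator, the adjoint that ``rmatvecad`` obtains via automatic differentiation coincides with the analytical adjoint :math:`G^T y`. That identity can be checked with plain numpy (no JAX required); the dot-product test below is the standard way to validate any forward/adjoint pair:

```python
import numpy as np

rng = np.random.default_rng(0)
ny, nx = 6, 4
G = rng.normal(0, 1, (ny, nx)).astype(np.float32)

x = np.ones(nx, dtype=np.float32)
y = G @ x        # forward
xadj = G.T @ y   # adjoint: what AD returns for a linear map

# dot-product test: <G x, y> must equal <x, G^T y>
assert np.isclose((G @ x) @ y, x @ xadj, rtol=1e-4)
```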


Mixed CPU-GPU powered inverse problems
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Let us now consider a more intricate scenario where we have access to
a GPU-powered operator; however, the model and/or data vectors are too large
to fit into GPU memory (VRAM).

For the sake of clarity, we consider a problem where
the operator can be written as a :class:`pylops.BlockDiag` of
PyLops operators. Note how, by simply sandwiching any of the GPU-powered
operators between two :class:`pylops.ToCupy` operators, we can
tell PyLops to transfer to the GPU only the part of the model vector
required by a given operator and to transfer the output back to the CPU before
forming the combined output vector (i.e., the output vector of the
:class:`pylops.BlockDiag`).

.. image:: _static/numpy_cupy_bd_diagram.png
:width: 1000
:align: center

.. code-block:: python

dtype = np.float32
nops, n = 5, 4
Ms = [np.diag((i + 1) * np.ones(n, dtype=dtype))
for i in range(nops)]
Ms = [M.T @ M for M in Ms]

# Create operator
Mops = []
for iop in range(nops):
Mop = MatrixMult(cp.asarray(Ms[iop], dtype=dtype))
Top = ToCupy(Mop.dims, dtype=dtype)
Top1 = ToCupy(Mop.dimsd, dtype=dtype)
Mop = Top1.H @ Mop @ Top
Mops.append(Mop)
Mops = BlockDiag(Mops, forceflat=True)

# Create data and invert
x = np.ones(n * nops, dtype=dtype)
y = Mops @ x.ravel()
xest = Mops / y
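The structural property exploited above can be verified on the CPU alone with plain numpy: in a block-diagonal operator, each block only ever sees its own slice of the model vector, which is exactly what makes the per-block CPU-GPU transfers possible. ``blockdiag_matvec`` below is a hypothetical helper, not part of PyLops:

```python
import numpy as np

nops, n = 5, 4
dtype = np.float32
Ms = [np.diag((i + 1) * np.ones(n, dtype=dtype)) for i in range(nops)]
Ms = [M.T @ M for M in Ms]

def blockdiag_matvec(blocks, x, n):
    # apply each block to its own slice of x and stack the results;
    # this is the action PyLops distributes block by block
    return np.concatenate([M @ x[i * n:(i + 1) * n]
                           for i, M in enumerate(blocks)])

x = np.ones(n * nops, dtype=dtype)
y = blockdiag_matvec(Ms, x, n)
# block i is (i + 1)^2 * I, so the i-th output slice is constant (i + 1)^2
```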


Finally, let us consider a problem where
the operator can be written as a :class:`pylops.VStack` of
PyLops operators and the model vector can be fully transferred to the GPU.
We can again use the :class:`pylops.ToCupy` operator; however, this
time we will only use it to move the output of each operator back to the CPU.
Since we are now in a special scenario, where the input of the overall
operator sits on the GPU and the output sits on the
CPU, we need to inform the :class:`pylops.VStack` operator about this.
This can be easily done using the additional ``inoutengine`` parameter. Let's
see this with an example.

.. image:: _static/numpy_cupy_vs_diagram.png
:width: 1000
:align: center

.. code-block:: python

dtype = np.float32
nops, n, m = 3, 4, 5
Ms = [np.random.normal(0, 1, (n, m)).astype(dtype) for _ in range(nops)]

# Create operator
Mops = []
for iop in range(nops):
Mop = MatrixMult(cp.asarray(Ms[iop]), dtype=dtype)
Top1 = ToCupy(Mop.dimsd, dtype=dtype)
Mop = Top1.H @ Mop
Mops.append(Mop)
Mops = VStack(Mops, inoutengine=("numpy", "cupy"))

# Create data and invert
x = cp.ones(m, dtype=dtype)
y = Mops @ x.ravel()
xest = pylops_cgls(Mops, y, x0=cp.zeros_like(x))[0]
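In pure numpy terms, a :class:`pylops.VStack` concatenates the outputs of its operators. The sketch below (same shapes as above, CPU only) shows why, unlike in the block-diagonal case, every operator consumes the whole model vector, so the model must live wherever all the operators compute:

```python
import numpy as np

nops, n, m = 3, 4, 5
rng = np.random.default_rng(0)
Ms = [rng.normal(0, 1, (n, m)) for _ in range(nops)]

x = np.ones(m)
# every stacked operator is applied to the *entire* model vector...
y = np.concatenate([M @ x for M in Ms])
# ...and the individual outputs are stacked into a vector of length nops * n
```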

These features are currently not available for ``jax`` arrays.


.. note::

More examples for the CuPy and JAX backends can be found at `link1 <https://github.com/PyLops/pylops_notebooks/tree/master/developement-cupy>`_
and `link2 <https://github.com/PyLops/pylops_notebooks/tree/master/developement/Basic_JAX.ipynb>`_.


Supported Operators
-------------------

In the following, we provide a list of methods in :class:`pylops.LinearOperator` with their current status (available on CPU,
GPU with CuPy, and GPU with JAX):
Expand Down Expand Up @@ -195,6 +361,7 @@ Smoothing and derivatives:
- |:white_check_mark:|
- |:white_check_mark:|


Signal processing:

.. list-table::
Expand Down Expand Up @@ -322,6 +489,7 @@ Signal processing:
- |:white_check_mark:|
- |:white_check_mark:|


Wave-Equation processing

.. list-table::
Expand Down Expand Up @@ -369,6 +537,7 @@ Wave-Equation processing
- |:red_circle:|
- |:red_circle:|


Geophysical subsurface characterization:

.. list-table::
Expand Down Expand Up @@ -407,60 +576,3 @@ Geophysical subsurface characterization:
operator currently works only with ``explicit=True`` due to the same issue as
in point 1 for the :class:`pylops.signalprocessing.Convolve1D` operator employed
when ``explicit=False``.


4 changes: 4 additions & 0 deletions pylops/basicoperators/__init__.py
Expand Up @@ -38,6 +38,7 @@
Gradient Gradient.
FirstDirectionalDerivative First Directional derivative.
SecondDirectionalDerivative Second Directional derivative.
ToCupy Convert to CuPy.
"""

from .functionoperator import *
Expand Down Expand Up @@ -72,6 +73,8 @@
from .laplacian import *
from .gradient import *
from .directionalderivative import *
from .tocupy import *


__all__ = [
"FunctionOperator",
Expand Down Expand Up @@ -107,4 +110,5 @@
"Gradient",
"FirstDirectionalDerivative",
"SecondDirectionalDerivative",
"ToCupy",
]
22 changes: 19 additions & 3 deletions pylops/basicoperators/blockdiag.py
Expand Up @@ -21,7 +21,7 @@

from pylops import LinearOperator
from pylops.basicoperators import MatrixMult
from pylops.utils.backend import get_array_module, inplace_set
from pylops.utils.backend import get_array_module, get_module, inplace_set
from pylops.utils.typing import DTypeLike, NDArray


Expand All @@ -48,6 +48,12 @@ class BlockDiag(LinearOperator):
.. versionadded:: 2.2.0

Force an array to be flattened after matvec and rmatvec.
inoutengine : :obj:`tuple`, optional
.. versionadded:: 2.4.0

Type of output vectors of ``matvec`` and ``rmatvec``. If ``None``, this is
inferred directly from the input vectors. Note that this is ignored
if ``nproc>1``.
dtype : :obj:`str`, optional
Type of elements in input array.

Expand Down Expand Up @@ -113,6 +119,7 @@ def __init__(
ops: Sequence[LinearOperator],
nproc: int = 1,
forceflat: bool = None,
inoutengine: Optional[tuple] = None,
dtype: Optional[DTypeLike] = None,
) -> None:
self.ops = ops
Expand Down Expand Up @@ -149,6 +156,7 @@ def __init__(
if self.nproc > 1:
self.pool = mp.Pool(processes=nproc)

self.inoutengine = inoutengine
dtype = _get_dtype(ops) if dtype is None else np.dtype(dtype)
clinear = all([getattr(oper, "clinear", True) for oper in self.ops])
super().__init__(
Expand All @@ -172,7 +180,11 @@ def nproc(self, nprocnew: int) -> None:
self._nproc = nprocnew

def _matvec_serial(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
ncp = (
get_array_module(x)
if self.inoutengine is None
else get_module(self.inoutengine[0])
)
y = ncp.zeros(self.nops, dtype=self.dtype)
for iop, oper in enumerate(self.ops):
y = inplace_set(
Expand All @@ -183,7 +195,11 @@ def _matvec_serial(self, x: NDArray) -> NDArray:
return y

def _rmatvec_serial(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
ncp = (
get_array_module(x)
if self.inoutengine is None
else get_module(self.inoutengine[1])
)
y = ncp.zeros(self.mops, dtype=self.dtype)
for iop, oper in enumerate(self.ops):
y = inplace_set(
Expand Down
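The ``inoutengine`` dispatch added above can be sketched in isolation. The helper below is a simplified stand-in for ``pylops.utils.backend.get_module`` (the real helper can also return ``cupy`` when that engine is requested and available); it shows how an explicit engine name overrides inference from the input vector:

```python
import numpy as np

def get_module(engine):
    # simplified stand-in for pylops.utils.backend.get_module:
    # map an engine name to its array module (numpy only here;
    # the real helper can also return cupy)
    if engine == "numpy":
        return np
    raise NotImplementedError(f"engine '{engine}' not available in this sketch")

# with inoutengine set, the module used to allocate the matvec output
# is chosen by name, regardless of where the input vector lives
inoutengine = ("numpy", "cupy")
ncp = get_module(inoutengine[0])
y = ncp.zeros(4, dtype=np.float32)
```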