Skip to content

Commit

Permalink
refactoring handle_device_shifting to support function without input … (
Browse files Browse the repository at this point in the history
  • Loading branch information
abdulasiraj authored Aug 23, 2023
1 parent 8d1bd2a commit 27636cc
Show file tree
Hide file tree
Showing 19 changed files with 284 additions and 338 deletions.
20 changes: 16 additions & 4 deletions ivy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,15 +818,27 @@ def _handle_device_shifting(*args, **kwargs):
-------
The return of the function.
"""
dev = None
if "device" in kwargs and kwargs["device"] is not None:
dev = ivy.as_native_dev(kwargs["device"])
if ivy.soft_device_mode:
return ivy.handle_soft_device_variable(*args, fn=fn, **kwargs)
return ivy.handle_soft_device_variable(
*args, fn=fn, device_shifting_dev=dev, **kwargs
)
inputs = args + tuple(kwargs.values())
devices = tuple(ivy.dev(x) for x in inputs if ivy.is_native_array(x))
unique_devices = set(devices)
# check if arrays are on the same device
if len(unique_devices) == 1:
with ivy.DefaultDevice(next(iter(unique_devices))):
return ivy.handle_soft_device_variable(*args, fn=fn, **kwargs)
if len(unique_devices) <= 1:
# len(unique_devices) == 0 when there are no arrays
dst_dev = (
dev
if dev is not None
else None if len(unique_devices) == 0 else next(iter(unique_devices))
)
return ivy.handle_soft_device_variable(
*args, fn=fn, device_shifting_dev=dst_dev, **kwargs
)
# raise when arrays are on different devices
elif len(unique_devices) > 1:
raise ivy.utils.exceptions.IvyException(
Expand Down
45 changes: 16 additions & 29 deletions ivy/functional/backends/jax/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import ivy
from ivy import as_native_dtype
from ivy.functional.backends.jax import JaxArray
from ivy.functional.backends.jax.device import _to_device
from ivy.functional.ivy.creation import (
asarray_to_native_arrays_and_back,
asarray_infer_device,
Expand Down Expand Up @@ -41,7 +40,7 @@ def arange(
if dtype:
dtype = as_native_dtype(dtype)
ivy.utils.assertions._check_jax_x64_flag(dtype.name)
res = _to_device(jnp.arange(start, stop, step, dtype=dtype), device=device)
res = jnp.arange(start, stop, step, dtype=dtype)
if not dtype:
if res.dtype == jnp.float64:
return res.astype(jnp.float32)
Expand Down Expand Up @@ -75,9 +74,9 @@ def asarray(
) -> JaxArray:
ivy.utils.assertions._check_jax_x64_flag(dtype)
if copy is True:
return _to_device(jnp.array(obj, dtype=dtype, copy=True), device=device)
return jnp.array(obj, dtype=dtype, copy=True)
else:
return _to_device(jnp.asarray(obj, dtype=dtype), device=device)
return jnp.asarray(obj, dtype=dtype)


def empty(
Expand All @@ -87,7 +86,7 @@ def empty(
device: jaxlib.xla_extension.Device,
out: Optional[JaxArray] = None,
) -> JaxArray:
return _to_device(jnp.empty(shape, dtype), device=device)
return jnp.empty(shape, dtype)


def empty_like(
Expand All @@ -98,7 +97,7 @@ def empty_like(
device: jaxlib.xla_extension.Device,
out: Optional[JaxArray] = None,
) -> JaxArray:
return _to_device(jnp.empty_like(x, dtype=dtype), device=device)
return jnp.empty_like(x, dtype=dtype)


def eye(
Expand All @@ -116,11 +115,11 @@ def eye(
n_cols = n_rows
i = jnp.eye(n_rows, n_cols, k, dtype)
if batch_shape is None:
return _to_device(i, device=device)
return i
reshape_dims = [1] * len(batch_shape) + [n_rows, n_cols]
tile_dims = list(batch_shape) + [1, 1]
return_mat = jnp.tile(jnp.reshape(i, reshape_dims), tile_dims)
return _to_device(return_mat, device=device)
return return_mat


def from_dlpack(x, /, *, out: Optional[JaxArray] = None) -> JaxArray:
Expand All @@ -138,10 +137,7 @@ def full(
) -> JaxArray:
dtype = ivy.default_dtype(dtype=dtype, item=fill_value, as_native=True)
ivy.utils.assertions.check_fill_value_and_dtype_are_compatible(fill_value, dtype)
return _to_device(
jnp.full(shape, fill_value, dtype),
device=device,
)
return jnp.full(shape, fill_value, dtype)


def full_like(
Expand All @@ -154,10 +150,7 @@ def full_like(
out: Optional[JaxArray] = None,
) -> JaxArray:
ivy.utils.assertions.check_fill_value_and_dtype_are_compatible(fill_value, dtype)
return _to_device(
jnp.full_like(x, fill_value, dtype=dtype),
device=device,
)
return jnp.full_like(x, fill_value, dtype=dtype)


# https://github.com/google/jax/blob/8b2e4f975c8c830502f5cc749b7253b02e78c9e8/jax/_src/numpy/lax_numpy.py#L2164
Expand Down Expand Up @@ -229,7 +222,7 @@ def linspace(

ans = jax.lax.convert_element_type(out, dtype)

return _to_device(ans, device=device)
return ans


def meshgrid(
Expand All @@ -248,7 +241,7 @@ def ones(
device: jaxlib.xla_extension.Device,
out: Optional[JaxArray] = None,
) -> JaxArray:
return _to_device(jnp.ones(shape, dtype), device=device)
return jnp.ones(shape, dtype)


def ones_like(
Expand All @@ -259,7 +252,7 @@ def ones_like(
device: jaxlib.xla_extension.Device,
out: Optional[JaxArray] = None,
) -> JaxArray:
return _to_device(jnp.ones_like(x, dtype=dtype), device=device)
return jnp.ones_like(x, dtype=dtype)


def tril(x: JaxArray, /, *, k: int = 0, out: Optional[JaxArray] = None) -> JaxArray:
Expand All @@ -277,10 +270,7 @@ def zeros(
device: jaxlib.xla_extension.Device,
out: Optional[JaxArray] = None,
) -> JaxArray:
return _to_device(
jnp.zeros(shape, dtype),
device=device,
)
return jnp.zeros(shape, dtype)


def zeros_like(
Expand All @@ -291,7 +281,7 @@ def zeros_like(
device: jaxlib.xla_extension.Device,
out: Optional[JaxArray] = None,
) -> JaxArray:
return _to_device(jnp.zeros_like(x, dtype=dtype), device=device)
return jnp.zeros_like(x, dtype=dtype)


# Extra #
Expand Down Expand Up @@ -347,7 +337,7 @@ def one_hot(
if axis is not None:
res = jnp.moveaxis(res, -1, axis)

return _to_device(res, device)
return res


def frombuffer(
Expand All @@ -367,7 +357,4 @@ def triu_indices(
*,
device: jaxlib.xla_extension.Device,
) -> Tuple[JaxArray]:
return _to_device(
jnp.triu_indices(n=n_rows, k=k, m=n_cols),
device=device,
)
return jnp.triu_indices(n=n_rows, k=k, m=n_cols)
13 changes: 12 additions & 1 deletion ivy/functional/backends/jax/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
# local
import ivy
from ivy.functional.backends.jax import JaxArray
from ivy.functional.ivy.device import Profiler as BaseProfiler
from ivy.functional.ivy.device import (
_shift_native_arrays_on_default_device,
Profiler as BaseProfiler,
)


# Helpers #
Expand Down Expand Up @@ -96,6 +99,14 @@ def as_native_dev(device, /):
return jax.devices(device)[idx]


def handle_soft_device_variable(*args, fn, device_shifting_dev=None, **kwargs):
    """
    Run ``fn`` inside a ``jax.default_device`` context.

    The positional and keyword arguments, together with the requested device,
    are first passed through ``_shift_native_arrays_on_default_device``; the
    triple it returns (arguments, keyword arguments, device) is what ``fn`` is
    ultimately invoked with.

    Parameters
    ----------
    args
        Positional arguments destined for ``fn``.
    fn
        The callable to execute.
    device_shifting_dev
        Device forwarded to the shifting helper. Default is ``None``.
    kwargs
        Keyword arguments destined for ``fn``.

    Returns
    -------
    The return of ``fn``.
    """
    # NOTE(review): the helper presumably relocates native arrays onto the
    # target device and resolves the device to use -- confirm against
    # ivy.functional.ivy.device.
    shifted = _shift_native_arrays_on_default_device(
        *args, device_shifting_dev=device_shifting_dev, **kwargs
    )
    shifted_args, shifted_kwargs, target_dev = shifted
    with jax.default_device(target_dev):
        return fn(*shifted_args, **shifted_kwargs)


def clear_cached_mem_on_dev(device: str, /):
    """
    Accept a device specifier and do nothing.

    Parameters
    ----------
    device
        Positional-only device specifier; it is not used by this
        implementation.

    Returns
    -------
    Always ``None``.
    """
    return

Expand Down
6 changes: 1 addition & 5 deletions ivy/functional/backends/jax/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

# local
from ivy.functional.backends.jax import JaxArray
from ivy.functional.backends.jax.device import _to_device
import ivy

# Array API Standard #
Expand Down Expand Up @@ -75,10 +74,7 @@ def tril_indices(
*,
device: jaxlib.xla_extension.Device,
) -> Tuple[JaxArray, ...]:
return _to_device(
jnp.tril_indices(n=n_rows, k=k, m=n_cols),
device=device,
)
return jnp.tril_indices(n=n_rows, k=k, m=n_cols)


def unsorted_segment_min(
Expand Down
12 changes: 4 additions & 8 deletions ivy/functional/backends/jax/experimental/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
_check_bounds_and_get_shape,
_check_shapes_broadcastable,
)
from ivy.functional.backends.jax.device import to_device
from ivy.func_wrapper import with_unsupported_dtypes
from .. import backend_version

Expand Down Expand Up @@ -54,7 +53,7 @@ def beta(
_setRNG(RNG_)
if seed is not None:
jax.random.PRNGKey(seed)
return to_device(jax.random.beta(rng_input, a, b, shape, dtype), device)
return jax.random.beta(rng_input, a, b, shape, dtype)


@with_unsupported_dtypes({"0.4.14 and below": ("bfloat16",)}, backend_version)
Expand All @@ -74,7 +73,7 @@ def gamma(
_setRNG(RNG_)
if seed is not None:
jax.random.PRNGKey(seed)
return to_device(jax.random.gamma(rng_input, alpha, shape, dtype) / beta, device)
return jax.random.gamma(rng_input, alpha, shape, dtype) / beta


def poisson(
Expand Down Expand Up @@ -105,10 +104,7 @@ def poisson(
ret = jnp.where(lam < 0, fill_value, ret)
else:
ret = jax.random.poisson(rng_input, lam, shape=list_shape).astype(dtype)
return to_device(
ret,
device,
)
return ret


def bernoulli(
Expand All @@ -130,4 +126,4 @@ def bernoulli(
probs = jax.nn.softmax(logits, axis=-1)
if not _check_shapes_broadcastable(shape, probs.shape):
shape = probs.shape
return to_device(jax.random.bernoulli(rng_input, probs, shape=shape), device)
return jax.random.bernoulli(rng_input, probs, shape=shape)
10 changes: 5 additions & 5 deletions ivy/functional/backends/jax/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# local
import ivy
from ivy.func_wrapper import with_unsupported_dtypes
from ivy.functional.backends.jax.device import _to_device, _to_array
from ivy.functional.backends.jax.device import _to_array
from ivy.functional.ivy.general import _broadcast_to
from ivy.functional.backends.jax import JaxArray, NativeArray
from . import backend_version
Expand Down Expand Up @@ -148,7 +148,7 @@ def gather(
result.append(r)
result = jnp.array(result)
result = result.reshape([*params.shape[0:batch_dims], *result.shape[1:]])
return _to_device(result)
return result


def gather_nd_helper(params, indices):
Expand Down Expand Up @@ -209,7 +209,7 @@ def gather_nd(
result.append(r)
result = jnp.array(result)
result = result.reshape([*params.shape[0:batch_dims], *result.shape[1:]])
return _to_device(result)
return result


def get_num_dims(x: JaxArray, /, *, as_array: bool = False) -> Union[JaxArray, int]:
Expand Down Expand Up @@ -390,8 +390,8 @@ def scatter_nd(
'"sum", "min", "max" or "replace"'.format(reduction)
)
if ivy.exists(out):
return ivy.inplace_update(out, _to_device(target))
return _to_device(target)
return ivy.inplace_update(out, target)
return target


scatter_nd.support_native_out = True
Expand Down
24 changes: 5 additions & 19 deletions ivy/functional/backends/jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
_check_valid_scale,
)
from ivy.functional.backends.jax import JaxArray
from ivy.functional.backends.jax.device import to_device
from ivy.func_wrapper import with_unsupported_dtypes
from . import backend_version

Expand Down Expand Up @@ -57,11 +56,8 @@ def random_uniform(
else:
RNG_, rng_input = jax.random.split(_getRNG())
_setRNG(RNG_)
return to_device(
jax.random.uniform(
rng_input, shape, minval=low, maxval=high, dtype=jnp.float32
),
device,
return jax.random.uniform(
rng_input, shape, minval=low, maxval=high, dtype=jnp.float32
).astype(dtype)


Expand All @@ -83,14 +79,7 @@ def random_normal(
else:
RNG_, rng_input = jax.random.split(_getRNG())
_setRNG(RNG_)
return (
to_device(
jax.random.normal(rng_input, shape, dtype=dtype),
device,
)
* std
+ mean
)
return jax.random.normal(rng_input, shape, dtype=dtype) * std + mean


@with_unsupported_dtypes({"0.4.14 and below": ("bfloat16",)}, backend_version)
Expand Down Expand Up @@ -134,10 +123,7 @@ def multinomial(
for prob in probs_stack
]
samples_flat = jnp.stack(samples_stack)
return to_device(
jnp.reshape(samples_flat, orig_probs_shape[:-1] + [num_samples]),
device,
)
return jnp.reshape(samples_flat, orig_probs_shape[:-1] + [num_samples])


def randint(
Expand All @@ -163,7 +149,7 @@ def randint(
RNG_, rng_input = jax.random.split(_getRNG())
_setRNG(RNG_)

return to_device(jax.random.randint(rng_input, shape, low, high, dtype), device)
return jax.random.randint(rng_input, shape, low, high, dtype)


def seed(*, seed_value: int = 0) -> None:
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/numpy/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def to_device(
return x


def handle_soft_device_variable(*args, fn, **kwargs):
def handle_soft_device_variable(*args, fn, device_shifting_dev=None, **kwargs):
    """
    Call ``fn`` with the supplied arguments, unchanged.

    Parameters
    ----------
    args
        Positional arguments forwarded to ``fn``.
    fn
        The callable to execute.
    device_shifting_dev
        Accepted but not used by this implementation. Default is ``None``.
    kwargs
        Keyword arguments forwarded to ``fn``.

    Returns
    -------
    The return of ``fn``.
    """
    result = fn(*args, **kwargs)
    return result


Expand Down
Loading

0 comments on commit 27636cc

Please sign in to comment.