Rollback for: Implement initial vmap over pallas_call w/ ragged inputs (via jumbles)

It can cause issues in x32 mode when trying to get the aval for array dimension sizes larger than an i32 can represent.

Reverts 24394a1

PiperOrigin-RevId: 664742891
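For context, a minimal sketch (hypothetical, not part of this commit) of the failure mode described above: under x32, JAX canonicalizes Python ints to int32, so a dimension size past 2**31 - 1 cannot become a concrete i32 value/aval.

```python
# Hypothetical sketch of the x32 issue named in the commit message; not code
# from this commit. Under x32, Python ints are canonicalized to int32, so a
# dimension size beyond 2**31 - 1 has no valid i32 representation.
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", False)  # x32 mode (the default)

fits = jnp.asarray(2**31 - 1)   # largest value an int32 can hold
print(fits.dtype)               # int32

try:
    jnp.asarray(2**31)          # one past the int32 range
except OverflowError as err:    # canonicalization to i32 is expected to fail
    print("overflow:", err)
```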
apaszke authored and jax authors committed Aug 19, 2024
1 parent dad2f57 commit 66a3f87
Showing 7 changed files with 50 additions and 560 deletions.
24 changes: 3 additions & 21 deletions jax/_src/core.py
@@ -1954,7 +1954,6 @@ def __init__(self, aval, data):
assert data.shape == pad_shape
self._aval = aval
self._data = data

shape = property(lambda self: self._aval.shape)
dtype = property(lambda self: self._aval.dtype)
aval = property(lambda self: self._aval)
@@ -1965,38 +1964,21 @@ def __repr__(self) -> str:

dtypestr = _short_dtype_name(self._aval.dtype)
shapestr = ','.join(map(str, self.shape))
data = self.data
slices = tuple(slice(int(d._data)) if type(d) is DArray and
type(d.dtype) is bint else slice(None) for d in self.shape)
data = self._data[slices]
return f'{dtypestr}[{shapestr}] with value: {data}'

def __hash__(self) -> int:
if not self.shape:
return hash((self._aval, int(self._data)))
raise TypeError("unhashable type: DArray")

def __eq__(self, other):
if isinstance(other, DArray) and self._aval == other._aval:
return self._data == other._data
return False

def __len__(self):
return self.shape[0]

@property
def data(self):
if not self.shape and type(self.dtype) is bint:
# special-case scalar bints
return self._data

slices = tuple(
slice(int(d._data))
if type(d) is DArray and type(d.dtype) is bint
else slice(None)
for d in self.shape
)
data = self._data[slices]
return data


pytype_aval_mappings[DArray] = \
lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type,
x._data)
6 changes: 1 addition & 5 deletions jax/_src/interpreters/batching.py
@@ -88,7 +88,6 @@ def _jumble_flatten(jumble):
elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape))
aval = jumble.aval.replace(elt_ty=elt_ty)
return (lengths, jumble.data), aval

def _jumble_unflatten(aval, x):
lengths, data = x
new_shape = [d.replace(lengths=lengths[d.lengths - 1])
@@ -252,10 +251,7 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
return (BatchTracer(trace, x, spec, source_info_util.current())
if spec is not None else x)
else:
# TODO(mvoz): This is a terrible place to fall into if you pass
# a non jumble type in, make it clearer what went wrong.
assert False, f'Unexpected type in ELT? {type(x)}'

assert False
to_elt_handlers: dict[type, ToEltHandler] = {}

def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int,
53 changes: 12 additions & 41 deletions jax/_src/pallas/core.py
@@ -112,10 +112,7 @@ class AbstractMemoryRef(state.AbstractRef):

def __init__(self, inner_aval: jax_core.AbstractValue,
memory_space: Any):

assert isinstance(
inner_aval, jax_core.ShapedArray
), f"Illegal ref, got {type(inner_aval)}"
assert isinstance(inner_aval, jax_core.ShapedArray)
self.inner_aval = inner_aval
self.memory_space = memory_space

@@ -170,7 +167,9 @@ class PallasGridContext:
mapped_dims: tuple[int, ...]

def size(self, axis: int) -> int | DynamicGridDim:
valid_grid = tuple(self.grid)
valid_grid = tuple(
s for i, s in enumerate(self.grid) if i not in self.mapped_dims
)
try:
size = valid_grid[axis]
except IndexError as e:
@@ -339,10 +338,7 @@ def check_invariants(self) -> None:
)

assert not self.index_map_jaxpr.consts
assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals), (
self.block_shape,
self.index_map_jaxpr.out_avals,
)
assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals)
assert all(ov.shape == () and
(ov.dtype == jnp.int32 or ov.dtype == jnp.int64)
for ov in self.index_map_jaxpr.out_avals), (
@@ -426,8 +422,6 @@ class GridMapping:
num_inputs: int
num_outputs: int
num_scratch_operands: int
get_grid_indices: Callable | None = None
local_grid_env: Callable | None = None

def check_invariants(self) -> None:
if not config.enable_checks.value: return
@@ -448,8 +442,8 @@ def check_invariants(self) -> None:
assert len(index_map_args) >= len(self.grid)
for i in range(len(self.grid)):
index_map_arg = index_map_args[i]
assert index_map_arg.shape == (), f"index_map_arg: {index_map_arg}"
assert index_map_arg.dtype == jnp.int32, f"index_map_arg: {index_map_arg}"
assert index_map_arg.shape == ()
assert index_map_arg.dtype == jnp.int32

assert len(self.vmapped_dims) <= len(self.grid)
for i in self.vmapped_dims:
@@ -460,11 +454,8 @@ def check_invariants(self) -> None:

for bm in self.block_mappings:
bm.check_invariants()
assert tuple(self.index_map_avals) == tuple(
bm.index_map_jaxpr.in_avals
), (
assert tuple(self.index_map_avals) == tuple(bm.index_map_jaxpr.in_avals), (
self.index_map_avals,
"|",
bm.index_map_jaxpr.in_avals,
)

@@ -556,17 +547,6 @@ def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
return True
return jax_core.is_dim(dim)


def _max_shape_from_aval(array_aval: jax_core.ShapedArray):
array_aval_shape = list(array_aval.shape)
for i, s in enumerate(array_aval.shape):
aval = jax_core.get_aval(s)
if isinstance(aval, jax_core.DShapedArray):
array_aval_shape[i] = aval.dtype.bound

return tuple(array_aval_shape)


def _convert_block_spec_to_block_mapping(
block_spec: BlockSpec,
origin: OriginStr,
@@ -595,15 +575,8 @@ def _convert_block_spec_to_block_mapping(
f"array shape {array_aval.shape}.")

unmapped_block_shape = tuple(s for s in block_shape if s is not None)
block_array_aval = array_aval.update(shape=unmapped_block_shape)
if isinstance(array_aval, jax_core.DShapedArray):
# Get the "max" shape for the ragged array.
block_array_aval = jax_core.ShapedArray(
block_array_aval.shape,
block_array_aval.dtype,
block_array_aval.weak_type,
)
block_aval = AbstractMemoryRef(block_array_aval, block_spec.memory_space)
block_aval = AbstractMemoryRef(array_aval.update(shape=unmapped_block_shape),
block_spec.memory_space)

if not jax_core.is_constant_shape(block_aval.shape):
raise ValueError(
@@ -636,22 +609,20 @@ def _convert_block_spec_to_block_mapping(
f"{origin} must return integer scalars. Output[{i}] has type "
f"{ov}.")


if consts:
raise ValueError(
f"Index map function {index_map_src_info} for "
f"{origin} must not capture constants: {consts}")

array_aval_shape = _max_shape_from_aval(array_aval)

mapping = BlockMapping(
block_shape=mapped_block_shape,
block_aval=block_aval,
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
index_map_src_info=index_map_src_info,
indexing_mode=block_spec.indexing_mode,
array_shape_dtype=jax.ShapeDtypeStruct(
array_aval_shape, array_aval.dtype
),
array_shape_dtype=jax.ShapeDtypeStruct(array_aval.shape, array_aval.dtype),
origin=origin,
)
mapping.check_invariants()
28 changes: 7 additions & 21 deletions jax/_src/pallas/mosaic/lowering.py
@@ -298,7 +298,6 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping,
self.jaxpr = jaxpr
self.block_mappings = grid_mapping.block_mappings
self.mapped_dims = grid_mapping.vmapped_dims
# TODO(mvoz): Generalize to not need this
user_grid = tuple(
g for i, g in enumerate(self.grid) if i not in self.mapped_dims
)
@@ -346,19 +345,9 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping,
for _ in range(len(self.grid))
])
self._prepare_mesh_info(mesh)

if grid_mapping.get_grid_indices is None:

def _get_grid_indices(indices, maybe_include_mapped_dims: bool):
if maybe_include_mapped_dims:
return indices
return tuple(
idx for i, idx in enumerate(indices) if i not in self.mapped_dims
)

self.get_grid_indices = _get_grid_indices
else:
self.get_grid_indices = grid_mapping.get_grid_indices
def _get_grid_indices(indices):
return indices
self.get_grid_indices = _get_grid_indices

def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None):
if not self.has_communication:
@@ -606,9 +595,7 @@ def lower_jaxpr_to_transform_func(
]
def body_func(*args):
grid_indices, scalar_prefetch = split_list(args, [num_grid])
jaxpr_indices = mosaic_grid_mapping.get_grid_indices(
grid_indices, maybe_include_mapped_dims=True
)
jaxpr_indices = mosaic_grid_mapping.get_grid_indices(grid_indices)
arg_block_shapes = [
*[()] * len(jaxpr_indices),
*mosaic_grid_mapping.scalar_prefetch_block_shapes,
@@ -676,9 +663,9 @@ def lower_jaxpr_to_func(
def body_func(*args):
grid_indices, scalar_prefetch, operands_and_scratch = split_list(
args, [num_grid, num_scalar_prefetch])
jaxpr_indices = mosaic_grid_mapping.get_grid_indices(
grid_indices, maybe_include_mapped_dims=False
)
grid_indices = mosaic_grid_mapping.get_grid_indices(grid_indices)
jaxpr_indices = tuple(idx for i, idx in enumerate(grid_indices)
if i not in mosaic_grid_mapping.mapped_dims)
mesh_info = mosaic_grid_mapping.mesh_info
if mesh_info is not None:
mesh_context = MeshContext(
@@ -2378,7 +2365,6 @@ def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):


def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int):

if ctx.lowering_context.user_grid_indices is None:
raise ValueError(
f"program id: {axis} was passed, but user did not provide a grid."