Skip to content

Commit

Permalink
Resolve linter errors
Browse files Browse the repository at this point in the history
  • Loading branch information
justinjfu committed Oct 18, 2024
1 parent dbe5cae commit 65a0b13
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 24 deletions.
1 change: 0 additions & 1 deletion jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@

from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import jaxpr_passes
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
Expand Down
5 changes: 2 additions & 3 deletions jax/_src/interpreters/jaxpr_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable, Iterable, Iterator, Sequence
from collections.abc import Callable, Sequence
import dataclasses
import functools
from functools import partial
Expand Down Expand Up @@ -70,7 +70,7 @@ def _rule(ctx: ResolveEdtypesContext, *args, **params):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
else:
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
phys_jaxpr = resolve_edtypes_jaxpr(core.ClosedJaxpr(jaxpr, consts))
phys_jaxpr = resolve_edtypes_jaxpr(core.ClosedJaxpr(jaxpr, consts))
result = core.eval_jaxpr(phys_jaxpr.jaxpr, phys_jaxpr.consts, *args)
if multiple_results:
return result
Expand Down Expand Up @@ -155,4 +155,3 @@ def write_env(var: core.Var, val: Any):
core.clean_up_dead_vars(eqn, env, last_used)

return map(read_env, jaxpr.outvars)

2 changes: 1 addition & 1 deletion jax/_src/lax/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2103,7 +2103,7 @@ def _gather_edtype_rule(ctx, operand, indices, *,
offset_dims=(*dimension_numbers.offset_dims, *trailing_offset_dims))
slice_sizes = (*slice_sizes, *elt_shape)
return gather(operand,
indices,
indices,
dimension_numbers=dimension_numbers,
slice_sizes=slice_sizes,
unique_indices=unique_indices,
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2612,7 +2612,7 @@ def _sharding_constraint_batcher(
_sharding_constraint_batcher, None)

def _sharding_constraint_edtype_rule(ctx: jaxpr_passes.ResolveEdtypesContext,
x, *,
x, *,
sharding, layout,
resource_env, unconstrained_dims):
aval_in, = ctx.avals_in
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/state/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,4 +703,4 @@ def _broadcast_to_edtype_rule(ctx: jaxpr_passes.ResolveEdtypesContext,
a, *, shape):
raise NotImplementedError()

jaxpr_passes.register_edtype_rule(broadcast_to_p, _broadcast_to_edtype_rule)
jaxpr_passes.register_edtype_rule(broadcast_to_p, _broadcast_to_edtype_rule)
2 changes: 1 addition & 1 deletion tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3729,7 +3729,7 @@ def handler(_, buf):
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
return FooArray(aval.shape, buf)
return handler

@staticmethod
def physical_const(val):
return val.data
Expand Down
30 changes: 14 additions & 16 deletions tests/resolve_edtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
Shape = Sequence[int]


def find_primitive(jaxpr: core.Jaxpr,
def find_primitive(jaxpr: core.Jaxpr,
primitive: core.Primitive):
for eqn in jaxpr.eqns:
if eqn.primitive == primitive:
Expand Down Expand Up @@ -86,7 +86,7 @@ def fun(k):
phys_aval = phys_jaxpr.jaxpr.invars[0].aval
self.assertEqual(phys_aval.shape, (2, 2))
self.assertEqual(phys_aval.dtype, jnp.uint32)
self.assert_jaxpr(phys_jaxpr,
self.assert_jaxpr(phys_jaxpr,
lax.broadcast_in_dim_p,
expected_in_shapes=[(2, 2)],
expected_in_dtypes=[jnp.uint32],
Expand All @@ -102,7 +102,7 @@ def fun(k):
self.assertEqual(result.shape, (4, 1))
traced = jax.jit(fun).trace(k)
phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr)
self.assert_jaxpr(phys_jaxpr,
self.assert_jaxpr(phys_jaxpr,
lax.slice_p,
expected_in_shapes=[(4, 4, 2)],
expected_in_dtypes=[jnp.uint32],
Expand All @@ -118,7 +118,7 @@ def fun(k, starts):
self.assertEqual(result.shape, (2, 3))
traced = jax.jit(fun).trace(k, (0, 1))
phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr)
self.assert_jaxpr(phys_jaxpr,
self.assert_jaxpr(phys_jaxpr,
lax.dynamic_slice_p,
expected_in_shapes=[(4, 4, 2)],
expected_in_dtypes=[jnp.uint32],
Expand All @@ -143,7 +143,7 @@ def fun(k, updates, starts):
self.assertEqual(result.shape, (4, 4))
traced = jax.jit(fun).trace(k, updates, (0, 1))
phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr)
self.assert_jaxpr(phys_jaxpr,
self.assert_jaxpr(phys_jaxpr,
lax.dynamic_update_slice_p,
expected_in_shapes=[(4, 4, 2), (2, 3, 2)],
expected_in_dtypes=[jnp.uint32, jnp.uint32],
Expand Down Expand Up @@ -173,7 +173,7 @@ def fun(x, y):
self.assertEqual(result, 5)
traced = jax.jit(fun).trace(2, 3)
phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr)
self.assert_jaxpr(phys_jaxpr,
self.assert_jaxpr(phys_jaxpr,
lax.convert_element_type_p,
expected_in_shapes=[()],
expected_in_dtypes=[jnp.int32],
Expand Down Expand Up @@ -201,7 +201,7 @@ def fun(k, starts):
jax.random.key_data(k[0:2, 1:4]))
traced = jax.jit(fun).trace(k, starts)
phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr)
self.assert_jaxpr(phys_jaxpr,
self.assert_jaxpr(phys_jaxpr,
lax.gather_p,
expected_in_shapes=[(4, 4, 2)],
expected_in_dtypes=[jnp.uint32],
Expand All @@ -224,7 +224,7 @@ def fun(k, indices, updates):
jax.random.key_data(updates)))
traced = jax.jit(fun).trace(k, indices, updates)
phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr)
self.assert_jaxpr(phys_jaxpr,
self.assert_jaxpr(phys_jaxpr,
lax.scatter_p,
expected_in_shapes=[(4, 4, 2), (2,), (2, 3, 2)],
expected_in_dtypes=[jnp.uint32, jnp.int32, jnp.uint32],
Expand All @@ -242,7 +242,7 @@ def fun(x, y):
jnp.zeros((2, 3)))
traced = jax.jit(fun).trace(x, y)
phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr)
self.assert_jaxpr(phys_jaxpr,
self.assert_jaxpr(phys_jaxpr,
lax.eq_p,
expected_in_shapes=[(2, 3, 2), (2, 3, 2)],
expected_in_dtypes=[jnp.uint32, jnp.uint32],
Expand All @@ -262,7 +262,7 @@ def fun(which, x, y):
jnp.stack([x[0], y[1], x[2]]))
traced = jax.jit(fun).trace(which, x, y)
phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr)
self.assert_jaxpr(phys_jaxpr,
self.assert_jaxpr(phys_jaxpr,
lax.select_n_p,
expected_in_shapes=[(3, 2), (3, 2), (3, 2)],
expected_in_dtypes=[jnp.bool_, jnp.uint32, jnp.uint32],
Expand All @@ -278,7 +278,7 @@ def fun(x):
self.assertEqual(result.shape, (3, 7, 2))
traced = jax.jit(fun).trace(x)
phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr)
self.assert_jaxpr(phys_jaxpr,
self.assert_jaxpr(phys_jaxpr,
lax.transpose_p,
expected_in_shapes=[(2, 3, 7, 2)],
expected_in_dtypes=[jnp.uint32],
Expand All @@ -294,7 +294,7 @@ def fun(x):
self.assertEqual(result.shape, (6, 7))
traced = jax.jit(fun).trace(x)
phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr)
self.assert_jaxpr(phys_jaxpr,
self.assert_jaxpr(phys_jaxpr,
lax.reshape_p,
expected_in_shapes=[(2, 3, 7, 2)],
expected_in_dtypes=[jnp.uint32],
Expand All @@ -308,7 +308,6 @@ def fun(x):
k1, k2 = random.split(k)
k1 = random.fold_in(k1, 0)
x1 = random.uniform(k1, shape=(4, 8))

k2 = random.key_data(k2)
k2 = random.wrap_key_data(k2, impl=k.dtype._impl)
x2 = random.uniform(k2, shape=(4, 8))
Expand Down Expand Up @@ -366,7 +365,7 @@ def fun(k, y):
mapped_fun = jax.jit(mapped_fun)
result = mapped_fun(keys, y)
self.assertEqual(result.shape, (x_dim * 5, 8))

def test_pjit_sharding(self):
if jax.device_count() < 2:
self.skipTest('sharding test requires at least 2 devices.')
Expand Down Expand Up @@ -436,7 +435,7 @@ def fun(k):
def test_batched_while_loop(self):
def _cond_fun(val):
return val[0] < 10

def _body_fun(val):
return (val[0] + 1, val[1])

Expand Down Expand Up @@ -495,4 +494,3 @@ def fun_with_cond(key, pred):

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 65a0b13

Please sign in to comment.