From 65a0b13bb8924e894bb5c84f372b5b38dc7ef4a9 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Fri, 18 Oct 2024 15:17:49 -0700 Subject: [PATCH] Resolve linter errors --- jax/_src/api.py | 1 - jax/_src/interpreters/jaxpr_passes.py | 5 ++--- jax/_src/lax/slicing.py | 2 +- jax/_src/pjit.py | 2 +- jax/_src/state/primitives.py | 2 +- tests/lax_test.py | 2 +- tests/resolve_edtypes_test.py | 30 +++++++++++++-------------- 7 files changed, 20 insertions(+), 24 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 0e3706050c5b..d2ac5465eded 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -77,7 +77,6 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching -from jax._src.interpreters import jaxpr_passes from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla from jax._src.interpreters import xla diff --git a/jax/_src/interpreters/jaxpr_passes.py b/jax/_src/interpreters/jaxpr_passes.py index 959aa877ece7..094e4b22cf33 100644 --- a/jax/_src/interpreters/jaxpr_passes.py +++ b/jax/_src/interpreters/jaxpr_passes.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable, Iterable, Iterator, Sequence +from collections.abc import Callable, Sequence import dataclasses import functools from functools import partial @@ -70,7 +70,7 @@ def _rule(ctx: ResolveEdtypesContext, *args, **params): jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun) else: jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) - phys_jaxpr = resolve_edtypes_jaxpr(core.ClosedJaxpr(jaxpr, consts)) + phys_jaxpr = resolve_edtypes_jaxpr(core.ClosedJaxpr(jaxpr, consts)) result = core.eval_jaxpr(phys_jaxpr.jaxpr, phys_jaxpr.consts, *args) if multiple_results: return result @@ -155,4 +155,3 @@ def write_env(var: core.Var, val: Any): core.clean_up_dead_vars(eqn, env, last_used) return map(read_env, jaxpr.outvars) - diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index b07a4fbe83bc..67cabe6240bd 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -2103,7 +2103,7 @@ def _gather_edtype_rule(ctx, operand, indices, *, offset_dims=(*dimension_numbers.offset_dims, *trailing_offset_dims)) slice_sizes = (*slice_sizes, *elt_shape) return gather(operand, - indices, + indices, dimension_numbers=dimension_numbers, slice_sizes=slice_sizes, unique_indices=unique_indices, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index b2f88f712b89..a12c474546f6 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2612,7 +2612,7 @@ def _sharding_constraint_batcher( _sharding_constraint_batcher, None) def _sharding_constraint_edtype_rule(ctx: jaxpr_passes.ResolveEdtypesContext, - x, *, + x, *, sharding, layout, resource_env, unconstrained_dims): aval_in, = ctx.avals_in diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index ad822f2883c4..57ab6b6b8a81 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -703,4 +703,4 @@ def _broadcast_to_edtype_rule(ctx: jaxpr_passes.ResolveEdtypesContext, a, *, shape): raise NotImplementedError() -jaxpr_passes.register_edtype_rule(broadcast_to_p, _broadcast_to_edtype_rule) \ No newline at end of file +jaxpr_passes.register_edtype_rule(broadcast_to_p, _broadcast_to_edtype_rule) diff --git a/tests/lax_test.py b/tests/lax_test.py index 1ec1b34b2792..a84b870b9f00 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3729,7 +3729,7 @@ def handler(_, buf): buf.aval = core.ShapedArray(buf.shape, buf.dtype) return FooArray(aval.shape, buf) return handler - + @staticmethod def physical_const(val): return val.data diff --git a/tests/resolve_edtypes_test.py b/tests/resolve_edtypes_test.py index 201404902a8f..0c451a70ea65 100644 --- a/tests/resolve_edtypes_test.py +++ b/tests/resolve_edtypes_test.py @@ -39,7 +39,7 @@ Shape = Sequence[int] -def find_primitive(jaxpr: core.Jaxpr, +def find_primitive(jaxpr: core.Jaxpr, primitive: core.Primitive): for eqn in jaxpr.eqns: if eqn.primitive == primitive: @@ -86,7 +86,7 @@ def fun(k): phys_aval = phys_jaxpr.jaxpr.invars[0].aval self.assertEqual(phys_aval.shape, (2, 2)) self.assertEqual(phys_aval.dtype, jnp.uint32) - self.assert_jaxpr(phys_jaxpr, + self.assert_jaxpr(phys_jaxpr, lax.broadcast_in_dim_p, expected_in_shapes=[(2, 2)], expected_in_dtypes=[jnp.uint32], @@ -102,7 +102,7 @@ def fun(k): self.assertEqual(result.shape, (4, 1)) traced = jax.jit(fun).trace(k) phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr) - self.assert_jaxpr(phys_jaxpr, + self.assert_jaxpr(phys_jaxpr, lax.slice_p, expected_in_shapes=[(4, 4, 2)], expected_in_dtypes=[jnp.uint32], @@ -118,7 +118,7 @@ def fun(k, starts): self.assertEqual(result.shape, (2, 3)) traced = jax.jit(fun).trace(k, (0, 1)) phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr) - self.assert_jaxpr(phys_jaxpr, + self.assert_jaxpr(phys_jaxpr, lax.dynamic_slice_p, expected_in_shapes=[(4, 4, 2)], expected_in_dtypes=[jnp.uint32], @@ -143,7 +143,7 @@ def fun(k, updates, starts): self.assertEqual(result.shape, (4, 4)) traced = jax.jit(fun).trace(k, updates, (0, 1)) phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr) - self.assert_jaxpr(phys_jaxpr, + self.assert_jaxpr(phys_jaxpr, lax.dynamic_update_slice_p, expected_in_shapes=[(4, 4, 2), (2, 3, 2)], expected_in_dtypes=[jnp.uint32, jnp.uint32], @@ -173,7 +173,7 @@ def fun(x, y): self.assertEqual(result, 5) traced = jax.jit(fun).trace(2, 3) phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr) - self.assert_jaxpr(phys_jaxpr, + self.assert_jaxpr(phys_jaxpr, lax.convert_element_type_p, expected_in_shapes=[()], expected_in_dtypes=[jnp.int32], @@ -201,7 +201,7 @@ def fun(k, starts): jax.random.key_data(k[0:2, 1:4])) traced = jax.jit(fun).trace(k, starts) phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr) - self.assert_jaxpr(phys_jaxpr, + self.assert_jaxpr(phys_jaxpr, lax.gather_p, expected_in_shapes=[(4, 4, 2)], expected_in_dtypes=[jnp.uint32], @@ -224,7 +224,7 @@ def fun(k, indices, updates): jax.random.key_data(updates))) traced = jax.jit(fun).trace(k, indices, updates) phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr) - self.assert_jaxpr(phys_jaxpr, + self.assert_jaxpr(phys_jaxpr, lax.scatter_p, expected_in_shapes=[(4, 4, 2), (2,), (2, 3, 2)], expected_in_dtypes=[jnp.uint32, jnp.int32, jnp.uint32], @@ -242,7 +242,7 @@ def fun(x, y): jnp.zeros((2, 3))) traced = jax.jit(fun).trace(x, y) phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr) - self.assert_jaxpr(phys_jaxpr, + self.assert_jaxpr(phys_jaxpr, lax.eq_p, expected_in_shapes=[(2, 3, 2), (2, 3, 2)], expected_in_dtypes=[jnp.uint32, jnp.uint32], @@ -262,7 +262,7 @@ def fun(which, x, y): jnp.stack([x[0], y[1], x[2]])) traced = jax.jit(fun).trace(which, x, y) phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr) - self.assert_jaxpr(phys_jaxpr, + self.assert_jaxpr(phys_jaxpr, lax.select_n_p, expected_in_shapes=[(3, 2), (3, 2), (3, 2)], expected_in_dtypes=[jnp.bool_, jnp.uint32, jnp.uint32], @@ -278,7 +278,7 @@ def fun(x): self.assertEqual(result.shape, (3, 7, 2)) traced = jax.jit(fun).trace(x) phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr) - self.assert_jaxpr(phys_jaxpr, + self.assert_jaxpr(phys_jaxpr, lax.transpose_p, expected_in_shapes=[(2, 3, 7, 2)], expected_in_dtypes=[jnp.uint32], @@ -294,7 +294,7 @@ def fun(x): self.assertEqual(result.shape, (6, 7)) traced = jax.jit(fun).trace(x) phys_jaxpr = jaxpr_passes.resolve_edtypes_jaxpr(traced.jaxpr) - self.assert_jaxpr(phys_jaxpr, + self.assert_jaxpr(phys_jaxpr, lax.reshape_p, expected_in_shapes=[(2, 3, 7, 2)], expected_in_dtypes=[jnp.uint32], @@ -308,7 +308,6 @@ def fun(x): k1, k2 = random.split(k) k1 = random.fold_in(k1, 0) x1 = random.uniform(k1, shape=(4, 8)) - k2 = random.key_data(k2) k2 = random.wrap_key_data(k2, impl=k.dtype._impl) x2 = random.uniform(k2, shape=(4, 8)) @@ -366,7 +365,7 @@ def fun(k, y): mapped_fun = jax.jit(mapped_fun) result = mapped_fun(keys, y) self.assertEqual(result.shape, (x_dim * 5, 8)) - + def test_pjit_sharding(self): if jax.device_count() < 2: self.skipTest('sharding test requires at least 2 devices.') @@ -436,7 +435,7 @@ def fun(k): def test_batched_while_loop(self): def _cond_fun(val): return val[0] < 10 - + def _body_fun(val): return (val[0] + 1, val[1]) @@ -495,4 +494,3 @@ def fun_with_cond(key, pred): if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) -