Skip to content

Commit

Permalink
Merge pull request #18845 from jakevdp:jaxpr-repr
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588522842
  • Loading branch information
jax authors committed Dec 6, 2023
2 parents fe6e195 + c2a0530 commit 1dd68c5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
5 changes: 4 additions & 1 deletion jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3086,7 +3086,10 @@ def _pp_eqn(eqn, context, settings) -> pp.Doc:
rhs = [pp.text(eqn.primitive.name, annotation=name_stack_annotation),
pp_kv_pairs(sorted(eqn.params.items()), context, settings),
pp.text(" ") + pp_vars(eqn.invars, context)]
return pp.concat([lhs, pp.text(" = ", annotation=annotation), *rhs])
if lhs.format():
return pp.concat([lhs, pp.text(" = ", annotation=annotation), *rhs])
else:
return pp.concat(rhs)
CustomPpEqnRule = Callable[[JaxprEqn, JaxprPpContext, JaxprPpSettings], pp.Doc]
pp_eqn_rules: dict[Primitive, CustomPpEqnRule] = {}

Expand Down
8 changes: 8 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6197,6 +6197,14 @@ def test_convert_element_type_literal_constant_folding(self):
jaxpr = api.make_jaxpr(lambda: cet(3.))()
self.assertLen(jaxpr.eqns, 0)

def test_eqn_repr_with_no_lhs(self):
def f(x):
jax.debug.print("{}", x)
return x
jaxpr = jax.make_jaxpr(f)(np.int32(0))
self.assertEqual(jaxpr.eqns[0].primitive, jax._src.debugging.debug_callback_p)
self.assertStartsWith(str(jaxpr.eqns[0]), "debug_callback[", )


class DCETest(jtu.JaxTestCase):

Expand Down

0 comments on commit 1dd68c5

Please sign in to comment.