Performance degradation of map_coordinates with the new CPU runtime #26152

Open
rifu4 opened this issue Jan 28, 2025 · 2 comments
Labels
bug Something isn't working

rifu4 commented Jan 28, 2025

Description

With the new CPU runtime, I see a performance regression in jax.scipy.ndimage.map_coordinates when I combine it with jax.value_and_grad (with or without jax.jit). Below is a reproducer. Setting old_cpu_runtime = True switches back to the old CPU runtime via --xla_cpu_use_thunk_runtime=false:

# switch between the new (thunk) and old CPU runtime
old_cpu_runtime = False

if old_cpu_runtime:
    import os
    # XLA_FLAGS must be set before jax is imported so that XLA picks it up
    XLA_flag = "--xla_cpu_use_thunk_runtime=false "
    print(f'set: {XLA_flag}')
    os.environ["XLA_FLAGS"] = XLA_flag

import jax
import jax.numpy as jnp
import timeit
from jax.scipy.ndimage import map_coordinates

def timeit_like(stmt, number=100, globals=None):
    execution_time = timeit.timeit(stmt, number=number, globals=globals)
    print(f"{stmt} executed in: {execution_time:.5f} seconds (for {number} runs)")

shape1 = (128,128)
shape2 = (256,256)

def fun(x):
    y = map_coordinates(x, jnp.indices(shape2), order=1)
    return jnp.max(y)

fun_jit = jax.jit(fun)
fun_grad = jax.value_and_grad(fun)
fun_jit_grad = jax.jit(fun_grad)

# warmup
inp = jnp.ones(shape1)
fun_jit(inp)
fun_grad(inp)
fun_jit_grad(inp)

# benchmarks
timeit_like("jax.block_until_ready(fun(inp))", globals=globals())
timeit_like("jax.block_until_ready(fun_jit(inp))", globals=globals())
timeit_like("jax.block_until_ready(fun_grad(inp))", globals=globals())
timeit_like("jax.block_until_ready(fun_jit_grad(inp))", globals=globals())

For the old CPU runtime (old_cpu_runtime = True) I obtain:

set: --xla_cpu_use_thunk_runtime=false 
jax.block_until_ready(fun(inp)) executed in: 0.17072 seconds (for 100 runs)
jax.block_until_ready(fun_jit(inp)) executed in: 0.01677 seconds (for 100 runs)
jax.block_until_ready(fun_grad(inp)) executed in: 0.54201 seconds (for 100 runs)
jax.block_until_ready(fun_jit_grad(inp)) executed in: 0.18899 seconds (for 100 runs)

With the new CPU runtime (old_cpu_runtime = False) I get:

jax.block_until_ready(fun(inp)) executed in: 0.16992 seconds (for 100 runs)
jax.block_until_ready(fun_jit(inp)) executed in: 0.01715 seconds (for 100 runs)
jax.block_until_ready(fun_grad(inp)) executed in: 2.18482 seconds (for 100 runs)
jax.block_until_ready(fun_jit_grad(inp)) executed in: 1.80749 seconds (for 100 runs)

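Thus, with the new CPU runtime the forward pass of map_coordinates is unaffected, but its gradient slows down substantially: jax.jit(jax.value_and_grad(fun)) is roughly 10x slower (1.80749 s vs. 0.18899 s for 100 runs), and un-jitted jax.value_and_grad(fun) is about 4x slower (2.18482 s vs. 0.54201 s).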
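The slowdown factors can be read off the two result tables above. As a worked check, with the timings copied verbatim from the output (seconds per 100 runs):

```python
# Timings reported above, in seconds per 100 runs: (old runtime, new runtime)
timings = {
    "fun": (0.17072, 0.16992),
    "fun_jit": (0.01677, 0.01715),
    "fun_grad": (0.54201, 2.18482),
    "fun_jit_grad": (0.18899, 1.80749),
}

# Slowdown of the new (thunk) runtime relative to the old one; > 1 means slower.
slowdown = {name: new / old for name, (old, new) in timings.items()}

for name, factor in slowdown.items():
    print(f"{name:>13}: {factor:5.2f}x")
```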

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.38
jaxlib: 0.4.38
numpy: 2.2.2
python: 3.12.7 (main, Nov 5 2024, 15:03:07) [Clang 16.0.0 (clang-1600.0.26.3)]
device info: cpu, 1 local devices
process_count: 8
uname_result(system='Darwin', node='MacBook-Pro.local', release='24.2.0', version='Darwin Kernel Version 24.2.0: Fri Dec 6 18:41:43 PST 2024, machine='x86_64')
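The system information above appears to follow the layout printed by `jax.print_environment_info()`. Regenerating it on another machine, for example to compare results across platforms, should be a one-liner:

```python
import jax

# Prints the jax/jaxlib/numpy/Python versions plus device and platform details.
jax.print_environment_info()
```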

rifu4 added the bug (Something isn't working) label Jan 28, 2025

dfm (Collaborator) commented Jan 28, 2025

Ping @ezhulenev, @penpornk re: CPU thunk runtime performance

This looks to me like the usual small-while-loop performance issue, but the lowered HLO doesn't contain any loops (they are only introduced during compilation), so there might be something else to look into. See the HLO below:
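The distinction drawn here, between the lowered StableHLO (no loops) and the compiled HLO (where loops appear), can be inspected with JAX's ahead-of-time API: `jax.jit(f).lower(...)` returns the module before XLA optimization, and `.compile()` returns the optimized one. A minimal sketch, assuming a JAX release that provides `Lowered.as_text()` and `Compiled.as_text()` (recent 0.4.x releases do); it uses `jax.grad` to match the `grad(fun)` dumps:

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

shape1 = (128, 128)
shape2 = (256, 256)

def fun(x):
    y = map_coordinates(x, jnp.indices(shape2), order=1)
    return jnp.max(y)

inp = jnp.ones(shape1)
grad_jit = jax.jit(jax.grad(fun))

# Module as emitted by JAX, before any XLA optimization passes.
lowered = grad_jit.lower(inp)
lowered_text = lowered.as_text()

# Optimized HLO for the current backend, after XLA compilation.
compiled_text = lowered.compile().as_text()

for label, text in [("lowered", lowered_text), ("compiled", compiled_text)]:
    # Crude substring checks: do scatter ops / while loops appear in the text?
    print(f"{label:>8}: scatter={'scatter' in text}, while={'while' in text}")
```

Whether and how the scatters are expanded into loops depends on the backend and the XLA version, so the compiled output on another setup may differ from the dump below.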

Input HLO for grad(fun)
module @jit_fun attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<128x128xf32>) -> (tensor<128x128xf32> {jax.result_info = ""}) {
    %0 = stablehlo.iota dim = 0 : tensor<256xi32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<256xi32>) -> tensor<256x256xi32>
    %2 = stablehlo.iota dim = 0 : tensor<256xi32>
    %3 = stablehlo.broadcast_in_dim %2, dims = [1] : (tensor<256xi32>) -> tensor<256x256xi32>
    %4 = stablehlo.broadcast_in_dim %1, dims = [1, 2] : (tensor<256x256xi32>) -> tensor<1x256x256xi32>
    %5 = stablehlo.broadcast_in_dim %3, dims = [1, 2] : (tensor<256x256xi32>) -> tensor<1x256x256xi32>
    %6 = stablehlo.concatenate %4, %5, dim = 0 : (tensor<1x256x256xi32>, tensor<1x256x256xi32>) -> tensor<2x256x256xi32>
    %7:13 = call @_map_coordinates(%arg0, %6) : (tensor<128x128xf32>, tensor<2x256x256xi32>) -> (tensor<256x256xf32>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>)
    %cst = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %8 = stablehlo.reduce(%7#0 init: %cst) applies stablehlo.maximum across dimensions = [0, 1] : (tensor<256x256xf32>, tensor<f32>) -> tensor<f32>
    %9 = stablehlo.reshape %8 : (tensor<f32>) -> tensor<1x1xf32>
    %10 = stablehlo.broadcast_in_dim %9, dims = [0, 1] : (tensor<1x1xf32>) -> tensor<256x256xf32>
    %11 = stablehlo.compare  EQ, %7#0, %10,  FLOAT : (tensor<256x256xf32>, tensor<256x256xf32>) -> tensor<256x256xi1>
    %12 = stablehlo.convert %11 : (tensor<256x256xi1>) -> tensor<256x256xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %13 = stablehlo.reduce(%12 init: %cst_0) applies stablehlo.add across dimensions = [0, 1] : (tensor<256x256xf32>, tensor<f32>) -> tensor<f32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %14 = stablehlo.divide %cst_1, %13 : tensor<f32>
    %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor<f32>) -> tensor<256x256xf32>
    %16 = stablehlo.multiply %15, %12 : tensor<256x256xf32>
    %17 = call @_map_coordinates_0(%7#1, %7#2, %7#3, %7#4, %7#5, %7#6, %7#7, %7#8, %7#9, %7#10, %7#11, %7#12, %16) : (tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>, tensor<256x256xf32>) -> tensor<128x128xf32>
    return %17 : tensor<128x128xf32>
  }
  func.func private @_map_coordinates(%arg0: tensor<128x128xf32>, %arg1: tensor<2x256x256xi32>) -> (tensor<256x256xf32>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>) {
    %0 = stablehlo.slice %arg1 [0:1, 0:256, 0:256] : (tensor<2x256x256xi32>) -> tensor<1x256x256xi32>
    %1 = stablehlo.reshape %0 : (tensor<1x256x256xi32>) -> tensor<256x256xi32>
    %2 = stablehlo.slice %arg1 [1:2, 0:256, 0:256] : (tensor<2x256x256xi32>) -> tensor<1x256x256xi32>
    %3 = stablehlo.reshape %2 : (tensor<1x256x256xi32>) -> tensor<256x256xi32>
    %4 = stablehlo.subtract %1, %1 : tensor<256x256xi32>
    %c = stablehlo.constant dense<1> : tensor<i32>
    %5 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %6 = stablehlo.subtract %5, %4 : tensor<256x256xi32>
    %7 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %8 = stablehlo.add %1, %7 : tensor<256x256xi32>
    %c_0 = stablehlo.constant dense<0> : tensor<i32>
    %9 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %10 = stablehlo.compare  GE, %1, %9,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %c_1 = stablehlo.constant dense<128> : tensor<i32>
    %11 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %12 = stablehlo.compare  LT, %1, %11,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %13 = stablehlo.and %10, %12 : tensor<256x256xi1>
    %14 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %15 = stablehlo.compare  GE, %8, %14,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %16 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %17 = stablehlo.compare  LT, %8, %16,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %18 = stablehlo.and %15, %17 : tensor<256x256xi1>
    %19 = stablehlo.subtract %3, %3 : tensor<256x256xi32>
    %20 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %21 = stablehlo.subtract %20, %19 : tensor<256x256xi32>
    %22 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %23 = stablehlo.add %3, %22 : tensor<256x256xi32>
    %24 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %25 = stablehlo.compare  GE, %3, %24,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %26 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %27 = stablehlo.compare  LT, %3, %26,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %28 = stablehlo.and %25, %27 : tensor<256x256xi1>
    %29 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %30 = stablehlo.compare  GE, %23, %29,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %31 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %32 = stablehlo.compare  LT, %23, %31,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %33 = stablehlo.and %30, %32 : tensor<256x256xi1>
    %34 = stablehlo.and %13, %28 : tensor<256x256xi1>
    %35 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %36 = stablehlo.compare  LT, %1, %35,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %37 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %38 = stablehlo.add %1, %37 : tensor<256x256xi32>
    %39 = stablehlo.select %36, %38, %1 : tensor<256x256xi1>, tensor<256x256xi32>
    %40 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %41 = stablehlo.compare  LT, %3, %40,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %42 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %43 = stablehlo.add %3, %42 : tensor<256x256xi32>
    %44 = stablehlo.select %41, %43, %3 : tensor<256x256xi1>, tensor<256x256xi32>
    %45 = stablehlo.broadcast_in_dim %39, dims = [0, 1] : (tensor<256x256xi32>) -> tensor<256x256x1xi32>
    %46 = stablehlo.broadcast_in_dim %44, dims = [0, 1] : (tensor<256x256xi32>) -> tensor<256x256x1xi32>
    %47 = stablehlo.concatenate %45, %46, dim = 2 : (tensor<256x256x1xi32>, tensor<256x256x1xi32>) -> tensor<256x256x2xi32>
    %48 = "stablehlo.gather"(%arg0, %47) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<128x128xf32>, tensor<256x256x2xi32>) -> tensor<256x256xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %49 = call @_where(%34, %48, %cst) : (tensor<256x256xi1>, tensor<256x256xf32>, tensor<f32>) -> tensor<256x256xf32>
    %50 = stablehlo.multiply %6, %21 : tensor<256x256xi32>
    %51 = stablehlo.convert %50 : (tensor<256x256xi32>) -> tensor<256x256xf32>
    %52 = stablehlo.multiply %51, %49 : tensor<256x256xf32>
    %53 = stablehlo.and %13, %33 : tensor<256x256xi1>
    %54 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %55 = stablehlo.compare  LT, %1, %54,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %56 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %57 = stablehlo.add %1, %56 : tensor<256x256xi32>
    %58 = stablehlo.select %55, %57, %1 : tensor<256x256xi1>, tensor<256x256xi32>
    %59 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %60 = stablehlo.compare  LT, %23, %59,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %61 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %62 = stablehlo.add %23, %61 : tensor<256x256xi32>
    %63 = stablehlo.select %60, %62, %23 : tensor<256x256xi1>, tensor<256x256xi32>
    %64 = stablehlo.broadcast_in_dim %58, dims = [0, 1] : (tensor<256x256xi32>) -> tensor<256x256x1xi32>
    %65 = stablehlo.broadcast_in_dim %63, dims = [0, 1] : (tensor<256x256xi32>) -> tensor<256x256x1xi32>
    %66 = stablehlo.concatenate %64, %65, dim = 2 : (tensor<256x256x1xi32>, tensor<256x256x1xi32>) -> tensor<256x256x2xi32>
    %67 = "stablehlo.gather"(%arg0, %66) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<128x128xf32>, tensor<256x256x2xi32>) -> tensor<256x256xf32>
    %68 = call @_where(%53, %67, %cst) : (tensor<256x256xi1>, tensor<256x256xf32>, tensor<f32>) -> tensor<256x256xf32>
    %69 = stablehlo.multiply %6, %19 : tensor<256x256xi32>
    %70 = stablehlo.convert %69 : (tensor<256x256xi32>) -> tensor<256x256xf32>
    %71 = stablehlo.multiply %70, %68 : tensor<256x256xf32>
    %72 = stablehlo.and %18, %28 : tensor<256x256xi1>
    %73 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %74 = stablehlo.compare  LT, %8, %73,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %75 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %76 = stablehlo.add %8, %75 : tensor<256x256xi32>
    %77 = stablehlo.select %74, %76, %8 : tensor<256x256xi1>, tensor<256x256xi32>
    %78 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %79 = stablehlo.compare  LT, %3, %78,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %80 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %81 = stablehlo.add %3, %80 : tensor<256x256xi32>
    %82 = stablehlo.select %79, %81, %3 : tensor<256x256xi1>, tensor<256x256xi32>
    %83 = stablehlo.broadcast_in_dim %77, dims = [0, 1] : (tensor<256x256xi32>) -> tensor<256x256x1xi32>
    %84 = stablehlo.broadcast_in_dim %82, dims = [0, 1] : (tensor<256x256xi32>) -> tensor<256x256x1xi32>
    %85 = stablehlo.concatenate %83, %84, dim = 2 : (tensor<256x256x1xi32>, tensor<256x256x1xi32>) -> tensor<256x256x2xi32>
    %86 = "stablehlo.gather"(%arg0, %85) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<128x128xf32>, tensor<256x256x2xi32>) -> tensor<256x256xf32>
    %87 = call @_where(%72, %86, %cst) : (tensor<256x256xi1>, tensor<256x256xf32>, tensor<f32>) -> tensor<256x256xf32>
    %88 = stablehlo.multiply %4, %21 : tensor<256x256xi32>
    %89 = stablehlo.convert %88 : (tensor<256x256xi32>) -> tensor<256x256xf32>
    %90 = stablehlo.multiply %89, %87 : tensor<256x256xf32>
    %91 = stablehlo.and %18, %33 : tensor<256x256xi1>
    %92 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %93 = stablehlo.compare  LT, %8, %92,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %94 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %95 = stablehlo.add %8, %94 : tensor<256x256xi32>
    %96 = stablehlo.select %93, %95, %8 : tensor<256x256xi1>, tensor<256x256xi32>
    %97 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %98 = stablehlo.compare  LT, %23, %97,  SIGNED : (tensor<256x256xi32>, tensor<256x256xi32>) -> tensor<256x256xi1>
    %99 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<256x256xi32>
    %100 = stablehlo.add %23, %99 : tensor<256x256xi32>
    %101 = stablehlo.select %98, %100, %23 : tensor<256x256xi1>, tensor<256x256xi32>
    %102 = stablehlo.broadcast_in_dim %96, dims = [0, 1] : (tensor<256x256xi32>) -> tensor<256x256x1xi32>
    %103 = stablehlo.broadcast_in_dim %101, dims = [0, 1] : (tensor<256x256xi32>) -> tensor<256x256x1xi32>
    %104 = stablehlo.concatenate %102, %103, dim = 2 : (tensor<256x256x1xi32>, tensor<256x256x1xi32>) -> tensor<256x256x2xi32>
    %105 = "stablehlo.gather"(%arg0, %104) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<128x128xf32>, tensor<256x256x2xi32>) -> tensor<256x256xf32>
    %106 = call @_where(%91, %105, %cst) : (tensor<256x256xi1>, tensor<256x256xf32>, tensor<f32>) -> tensor<256x256xf32>
    %107 = stablehlo.multiply %4, %19 : tensor<256x256xi32>
    %108 = stablehlo.convert %107 : (tensor<256x256xi32>) -> tensor<256x256xf32>
    %109 = stablehlo.multiply %108, %106 : tensor<256x256xf32>
    %110 = stablehlo.add %52, %71 : tensor<256x256xf32>
    %111 = stablehlo.add %110, %90 : tensor<256x256xf32>
    %112 = stablehlo.add %111, %109 : tensor<256x256xf32>
    return %112, %51, %47, %34, %70, %66, %53, %89, %85, %72, %108, %104, %91 : tensor<256x256xf32>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>, tensor<256x256xf32>, tensor<256x256x2xi32>, tensor<256x256xi1>
  }
  func.func private @_where(%arg0: tensor<256x256xi1>, %arg1: tensor<256x256xf32>, %arg2: tensor<f32>) -> tensor<256x256xf32> {
    %0 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor<f32>) -> tensor<256x256xf32>
    %1 = stablehlo.select %arg0, %arg1, %0 : tensor<256x256xi1>, tensor<256x256xf32>
    return %1 : tensor<256x256xf32>
  }
  func.func private @_map_coordinates_0(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256x2xi32>, %arg2: tensor<256x256xi1>, %arg3: tensor<256x256xf32>, %arg4: tensor<256x256x2xi32>, %arg5: tensor<256x256xi1>, %arg6: tensor<256x256xf32>, %arg7: tensor<256x256x2xi32>, %arg8: tensor<256x256xi1>, %arg9: tensor<256x256xf32>, %arg10: tensor<256x256x2xi32>, %arg11: tensor<256x256xi1>, %arg12: tensor<256x256xf32>) -> tensor<128x128xf32> {
    %0 = stablehlo.multiply %arg9, %arg12 : tensor<256x256xf32>
    %1 = call @_where_1(%arg11, %0) : (tensor<256x256xi1>, tensor<256x256xf32>) -> tensor<256x256xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %2 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<128x128xf32>
    %3 = "stablehlo.scatter"(%2, %arg10, %1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false}> ({
    ^bb0(%arg13: tensor<f32>, %arg14: tensor<f32>):
      %19 = stablehlo.add %arg13, %arg14 : tensor<f32>
      stablehlo.return %19 : tensor<f32>
    }) : (tensor<128x128xf32>, tensor<256x256x2xi32>, tensor<256x256xf32>) -> tensor<128x128xf32>
    %4 = stablehlo.multiply %arg6, %arg12 : tensor<256x256xf32>
    %5 = call @_where_1(%arg8, %4) : (tensor<256x256xi1>, tensor<256x256xf32>) -> tensor<256x256xf32>
    %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<128x128xf32>
    %7 = "stablehlo.scatter"(%6, %arg7, %5) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false}> ({
    ^bb0(%arg13: tensor<f32>, %arg14: tensor<f32>):
      %19 = stablehlo.add %arg13, %arg14 : tensor<f32>
      stablehlo.return %19 : tensor<f32>
    }) : (tensor<128x128xf32>, tensor<256x256x2xi32>, tensor<256x256xf32>) -> tensor<128x128xf32>
    %8 = stablehlo.add %3, %7 : tensor<128x128xf32>
    %9 = stablehlo.multiply %arg3, %arg12 : tensor<256x256xf32>
    %10 = call @_where_1(%arg5, %9) : (tensor<256x256xi1>, tensor<256x256xf32>) -> tensor<256x256xf32>
    %11 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<128x128xf32>
    %12 = "stablehlo.scatter"(%11, %arg4, %10) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false}> ({
    ^bb0(%arg13: tensor<f32>, %arg14: tensor<f32>):
      %19 = stablehlo.add %arg13, %arg14 : tensor<f32>
      stablehlo.return %19 : tensor<f32>
    }) : (tensor<128x128xf32>, tensor<256x256x2xi32>, tensor<256x256xf32>) -> tensor<128x128xf32>
    %13 = stablehlo.add %8, %12 : tensor<128x128xf32>
    %14 = stablehlo.multiply %arg0, %arg12 : tensor<256x256xf32>
    %15 = call @_where_1(%arg2, %14) : (tensor<256x256xi1>, tensor<256x256xf32>) -> tensor<256x256xf32>
    %16 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<128x128xf32>
    %17 = "stablehlo.scatter"(%16, %arg1, %15) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 2>, unique_indices = false}> ({
    ^bb0(%arg13: tensor<f32>, %arg14: tensor<f32>):
      %19 = stablehlo.add %arg13, %arg14 : tensor<f32>
      stablehlo.return %19 : tensor<f32>
    }) : (tensor<128x128xf32>, tensor<256x256x2xi32>, tensor<256x256xf32>) -> tensor<128x128xf32>
    %18 = stablehlo.add %13, %17 : tensor<128x128xf32>
    return %18 : tensor<128x128xf32>
  }
  func.func private @_where_1(%arg0: tensor<256x256xi1>, %arg1: tensor<256x256xf32>) -> tensor<256x256xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<256x256xf32>
    %1 = stablehlo.select %arg0, %arg1, %0 : tensor<256x256xi1>, tensor<256x256xf32>
    return %1 : tensor<256x256xf32>
  }
}
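The backward pass above reduces to four scatter-adds in @_map_coordinates_0, each writing 65536 updates with possibly duplicate indices into a 128x128 buffer. The compiled HLO that follows appears to turn each scatter into a while loop that updates one element per iteration (dynamic-slice, add, select, dynamic-update-slice). The sketch below reproduces one such scatter in isolation and pairs it with a hand-written `lax.fori_loop` that approximates that loop body. The inputs are hypothetical, chosen so that every index is in bounds and each output cell is hit exactly four times:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical inputs: 256x256 = 65536 updates into a 128x128 buffer.
# Halving the grid indices makes each output cell receive 4 updates.
rows, cols = jnp.indices((256, 256)) // 2
idx = jnp.stack([rows.ravel(), cols.ravel()], axis=-1)  # int32[65536, 2]
updates = jnp.ones(idx.shape[0], dtype=jnp.float32)      # float32[65536]

@jax.jit
def scatter_add(idx, updates):
    # Lowers to a single stablehlo.scatter with an add combiner,
    # like the scatters in @_map_coordinates_0 above.
    out = jnp.zeros((128, 128), jnp.float32)
    return out.at[idx[:, 0], idx[:, 1]].add(updates)

@jax.jit
def scatter_add_loop(idx, updates):
    # Element-at-a-time loop approximating %while_body in the compiled HLO:
    # read one element, add the update if in bounds, write it back.
    def body(i, out):
        r, c = idx[i, 0], idx[i, 1]
        in_bounds = (r >= 0) & (r < 128) & (c >= 0) & (c < 128)
        cur = lax.dynamic_slice(out, (r, c), (1, 1))
        new = jnp.where(in_bounds, cur + updates[i], cur)
        return lax.dynamic_update_slice(out, new, (r, c))

    init = jnp.zeros((128, 128), jnp.float32)
    return lax.fori_loop(0, idx.shape[0], body, init)
```

Timing `scatter_add` on its own under both runtimes, for example with the `timeit_like` helper from the reproducer, would be one way to test whether per-iteration loop overhead accounts for the slowdown.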
Compiled HLO for grad(fun)
HloModule jit_fun, is_scheduled=true, entry_computation_layout={(f32[128,128]{1,0})->f32[128,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}

%and.reduce_sub_computation (lhs: pred[], rhs: pred[]) -> pred[] {
  %lhs = pred[] parameter(0)
  %rhs = pred[] parameter(1)
  ROOT %and.8 = pred[] and(pred[] %lhs, pred[] %rhs)
}

%fused_computation (param_0: f32[128,128], param_1.3: f32[65536], param_2.4: s32[], param_3.6: s32[2], param_4.8: s32[65536,2]) -> f32[128,128] {
  %param_0 = f32[128,128]{1,0} parameter(0)
  %constant.116 = s32[] constant(0)
  %broadcast.105 = s32[2]{0} broadcast(s32[] %constant.116), dimensions={}
  %param_3.6 = s32[2]{0} parameter(3)
  %compare.29 = pred[2]{0} compare(s32[2]{0} %broadcast.105, s32[2]{0} %param_3.6), direction=LE
  %constant.115 = s32[] constant(127)
  %broadcast.104 = s32[2]{0} broadcast(s32[] %constant.115), dimensions={}
  %compare.28 = pred[2]{0} compare(s32[2]{0} %broadcast.104, s32[2]{0} %param_3.6), direction=GE
  %and.16 = pred[2]{0} and(pred[2]{0} %compare.29, pred[2]{0} %compare.28)
  %constant.114 = pred[] constant(true)
  %reduce.6 = pred[] reduce(pred[2]{0} %and.16, pred[] %constant.114), dimensions={0}, to_apply=%and.reduce_sub_computation
  %bitcast.37 = pred[1,1]{1,0} bitcast(pred[] %reduce.6)
  %param_4.8 = s32[65536,2]{1,0} parameter(4)
  %param_2.4 = s32[] parameter(2)
  %dynamic-slice.18 = s32[1,2]{1,0} dynamic-slice(s32[65536,2]{1,0} %param_4.8, s32[] %param_2.4, s32[] %constant.116), dynamic_slice_sizes={1,2}
  %slice.43 = s32[1,1]{1,0} slice(s32[1,2]{1,0} %dynamic-slice.18), slice={[0:1], [0:1]}
  %bitcast.36 = s32[] bitcast(s32[1,1]{1,0} %slice.43)
  %bitcast.38 = s32[2]{0} bitcast(s32[1,2]{1,0} %dynamic-slice.18)
  %slice.42 = s32[1]{0} slice(s32[2]{0} %bitcast.38), slice={[1:2]}
  %bitcast.35 = s32[] bitcast(s32[1]{0} %slice.42)
  %dynamic-slice.17 = f32[1,1]{1,0} dynamic-slice(f32[128,128]{1,0} %param_0, s32[] %bitcast.36, s32[] %bitcast.35), dynamic_slice_sizes={1,1}
  %param_1.3 = f32[65536]{0} parameter(1)
  %dynamic-slice.16 = f32[1]{0} dynamic-slice(f32[65536]{0} %param_1.3, s32[] %param_2.4), dynamic_slice_sizes={1}
  %bitcast.34 = f32[1,1]{1,0} bitcast(f32[1]{0} %dynamic-slice.16)
  %add.32 = f32[1,1]{1,0} add(f32[1,1]{1,0} %dynamic-slice.17, f32[1,1]{1,0} %bitcast.34), metadata={op_name="/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.32 = f32[1,1]{1,0} select(pred[1,1]{1,0} %bitcast.37, f32[1,1]{1,0} %add.32, f32[1,1]{1,0} %dynamic-slice.17)
  ROOT %dynamic-update-slice.4 = f32[128,128]{1,0} dynamic-update-slice(f32[128,128]{1,0} %param_0, f32[1,1]{1,0} %select.32, s32[] %bitcast.36, s32[] %bitcast.35)
}

%fused_computation.1 (param_0.3: s32[65536,2], param_1.8: s32[]) -> s32[2] {
  %param_0.3 = s32[65536,2]{1,0} parameter(0)
  %param_1.8 = s32[] parameter(1)
  %constant.117 = s32[] constant(0)
  %dynamic-slice.19 = s32[1,2]{1,0} dynamic-slice(s32[65536,2]{1,0} %param_0.3, s32[] %param_1.8, s32[] %constant.117), dynamic_slice_sizes={1,2}
  %slice.45 = s32[1,1]{1,0} slice(s32[1,2]{1,0} %dynamic-slice.19), slice={[0:1], [0:1]}
  %bitcast.40 = s32[1]{0} bitcast(s32[1,1]{1,0} %slice.45)
  %bitcast.39 = s32[2]{0} bitcast(s32[1,2]{1,0} %dynamic-slice.19)
  %slice.44 = s32[1]{0} slice(s32[2]{0} %bitcast.39), slice={[1:2]}
  ROOT %concatenate.8 = s32[2]{0} concatenate(s32[1]{0} %bitcast.40, s32[1]{0} %slice.44), dimensions={0}
}

%while_body (param.1: (s32[], f32[128,128], s32[65536,2], f32[65536])) -> (s32[], f32[128,128], s32[65536,2], f32[65536]) {
  %param.1 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) parameter(0)
  %get-tuple-element.60 = s32[] get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.1), index=0
  %copy.18 = s32[] copy(s32[] %get-tuple-element.60)
  %get-tuple-element.66 = s32[65536,2]{1,0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.1), index=2
  %slice_concatenate_fusion = s32[2]{0} fusion(s32[65536,2]{1,0} %get-tuple-element.66, s32[] %copy.18), kind=kLoop, calls=%fused_computation.1
  %get-tuple-element.67 = f32[65536]{0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.1), index=3
  %get-tuple-element.61 = f32[128,128]{1,0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.1), index=1
  %select_dynamic-update-slice_fusion = f32[128,128]{1,0} fusion(f32[128,128]{1,0} %get-tuple-element.61, f32[65536]{0} %get-tuple-element.67, s32[] %copy.18, s32[2]{0} %slice_concatenate_fusion, s32[65536,2]{1,0} %get-tuple-element.66), kind=kLoop, calls=%fused_computation
  %constant.22 = s32[] constant(1)
  %add.16 = s32[] add(s32[] %copy.18, s32[] %constant.22)
  ROOT %tuple.21 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) tuple(s32[] %add.16, f32[128,128]{1,0} %select_dynamic-update-slice_fusion, s32[65536,2]{1,0} %get-tuple-element.66, f32[65536]{0} %get-tuple-element.67)
}

%while_cond (param.0: (s32[], f32[128,128], s32[65536,2], f32[65536])) -> pred[] {
  %constant.17 = s32[] constant(65536)
  %param.0 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.0), index=0
  ROOT %compare.16 = pred[] compare(s32[] %get-tuple-element, s32[] %constant.17), direction=LT
}

%and.reduce_sub_computation.1 (lhs.1: pred[], rhs.1: pred[]) -> pred[] {
  %lhs.1 = pred[] parameter(0)
  %rhs.1 = pred[] parameter(1)
  ROOT %and.10 = pred[] and(pred[] %lhs.1, pred[] %rhs.1)
}

%fused_computation.2 (param_0.4: f32[128,128], param_1.12: f32[65536], param_2.10: s32[], param_3.13: s32[2], param_4.17: s32[65536,2]) -> f32[128,128] {
  %param_0.4 = f32[128,128]{1,0} parameter(0)
  %constant.120 = s32[] constant(0)
  %broadcast.107 = s32[2]{0} broadcast(s32[] %constant.120), dimensions={}
  %param_3.13 = s32[2]{0} parameter(3)
  %compare.31 = pred[2]{0} compare(s32[2]{0} %broadcast.107, s32[2]{0} %param_3.13), direction=LE
  %constant.119 = s32[] constant(127)
  %broadcast.106 = s32[2]{0} broadcast(s32[] %constant.119), dimensions={}
  %compare.30 = pred[2]{0} compare(s32[2]{0} %broadcast.106, s32[2]{0} %param_3.13), direction=GE
  %and.17 = pred[2]{0} and(pred[2]{0} %compare.31, pred[2]{0} %compare.30)
  %constant.118 = pred[] constant(true)
  %reduce.7 = pred[] reduce(pred[2]{0} %and.17, pred[] %constant.118), dimensions={0}, to_apply=%and.reduce_sub_computation.1
  %bitcast.44 = pred[1,1]{1,0} bitcast(pred[] %reduce.7)
  %param_4.17 = s32[65536,2]{1,0} parameter(4)
  %param_2.10 = s32[] parameter(2)
  %dynamic-slice.22 = s32[1,2]{1,0} dynamic-slice(s32[65536,2]{1,0} %param_4.17, s32[] %param_2.10, s32[] %constant.120), dynamic_slice_sizes={1,2}
  %slice.49 = s32[1,1]{1,0} slice(s32[1,2]{1,0} %dynamic-slice.22), slice={[0:1], [0:1]}
  %bitcast.43 = s32[] bitcast(s32[1,1]{1,0} %slice.49)
  %bitcast.45 = s32[2]{0} bitcast(s32[1,2]{1,0} %dynamic-slice.22)
  %slice.47 = s32[1]{0} slice(s32[2]{0} %bitcast.45), slice={[1:2]}
  %bitcast.42 = s32[] bitcast(s32[1]{0} %slice.47)
  %dynamic-slice.21 = f32[1,1]{1,0} dynamic-slice(f32[128,128]{1,0} %param_0.4, s32[] %bitcast.43, s32[] %bitcast.42), dynamic_slice_sizes={1,1}
  %param_1.12 = f32[65536]{0} parameter(1)
  %dynamic-slice.20 = f32[1]{0} dynamic-slice(f32[65536]{0} %param_1.12, s32[] %param_2.10), dynamic_slice_sizes={1}
  %bitcast.41 = f32[1,1]{1,0} bitcast(f32[1]{0} %dynamic-slice.20)
  %add.33 = f32[1,1]{1,0} add(f32[1,1]{1,0} %dynamic-slice.21, f32[1,1]{1,0} %bitcast.41), metadata={op_name="/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.33 = f32[1,1]{1,0} select(pred[1,1]{1,0} %bitcast.44, f32[1,1]{1,0} %add.33, f32[1,1]{1,0} %dynamic-slice.21)
  ROOT %dynamic-update-slice.5 = f32[128,128]{1,0} dynamic-update-slice(f32[128,128]{1,0} %param_0.4, f32[1,1]{1,0} %select.33, s32[] %bitcast.43, s32[] %bitcast.42)
}

%fused_computation.3 (param_0.7: s32[65536,2], param_1.17: s32[]) -> s32[2] {
  %param_0.7 = s32[65536,2]{1,0} parameter(0)
  %param_1.17 = s32[] parameter(1)
  %constant.121 = s32[] constant(0)
  %dynamic-slice.23 = s32[1,2]{1,0} dynamic-slice(s32[65536,2]{1,0} %param_0.7, s32[] %param_1.17, s32[] %constant.121), dynamic_slice_sizes={1,2}
  %slice.51 = s32[1,1]{1,0} slice(s32[1,2]{1,0} %dynamic-slice.23), slice={[0:1], [0:1]}
  %bitcast.47 = s32[1]{0} bitcast(s32[1,1]{1,0} %slice.51)
  %bitcast.46 = s32[2]{0} bitcast(s32[1,2]{1,0} %dynamic-slice.23)
  %slice.50 = s32[1]{0} slice(s32[2]{0} %bitcast.46), slice={[1:2]}
  ROOT %concatenate.9 = s32[2]{0} concatenate(s32[1]{0} %bitcast.47, s32[1]{0} %slice.50), dimensions={0}
}

%while_body.1 (param.3: (s32[], f32[128,128], s32[65536,2], f32[65536])) -> (s32[], f32[128,128], s32[65536,2], f32[65536]) {
  %param.3 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) parameter(0)
  %get-tuple-element.72 = s32[] get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.3), index=0
  %copy.24 = s32[] copy(s32[] %get-tuple-element.72)
  %get-tuple-element.78 = s32[65536,2]{1,0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.3), index=2
  %slice_concatenate_fusion.1 = s32[2]{0} fusion(s32[65536,2]{1,0} %get-tuple-element.78, s32[] %copy.24), kind=kLoop, calls=%fused_computation.3
  %get-tuple-element.79 = f32[65536]{0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.3), index=3
  %get-tuple-element.73 = f32[128,128]{1,0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.3), index=1
  %select_dynamic-update-slice_fusion.1 = f32[128,128]{1,0} fusion(f32[128,128]{1,0} %get-tuple-element.73, f32[65536]{0} %get-tuple-element.79, s32[] %copy.24, s32[2]{0} %slice_concatenate_fusion.1, s32[65536,2]{1,0} %get-tuple-element.78), kind=kLoop, calls=%fused_computation.2
  %constant.36 = s32[] constant(1)
  %add.17 = s32[] add(s32[] %copy.24, s32[] %constant.36)
  ROOT %tuple.24 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) tuple(s32[] %add.17, f32[128,128]{1,0} %select_dynamic-update-slice_fusion.1, s32[65536,2]{1,0} %get-tuple-element.78, f32[65536]{0} %get-tuple-element.79)
}

%while_cond.1 (param.2: (s32[], f32[128,128], s32[65536,2], f32[65536])) -> pred[] {
  %constant.31 = s32[] constant(65536)
  %param.2 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) parameter(0)
  %get-tuple-element.8 = s32[] get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.2), index=0
  ROOT %compare.19 = pred[] compare(s32[] %get-tuple-element.8, s32[] %constant.31), direction=LT
}

%and.reduce_sub_computation.2 (lhs.2: pred[], rhs.2: pred[]) -> pred[] {
  %lhs.2 = pred[] parameter(0)
  %rhs.2 = pred[] parameter(1)
  ROOT %and.12 = pred[] and(pred[] %lhs.2, pred[] %rhs.2)
}

%fused_computation.4 (param_0.8: f32[128,128], param_1.21: f32[65536], param_2.16: s32[], param_3.20: s32[2], param_4.26: s32[65536,2]) -> f32[128,128] {
  %param_0.8 = f32[128,128]{1,0} parameter(0)
  %constant.124 = s32[] constant(0)
  %broadcast.109 = s32[2]{0} broadcast(s32[] %constant.124), dimensions={}
  %param_3.20 = s32[2]{0} parameter(3)
  %compare.33 = pred[2]{0} compare(s32[2]{0} %broadcast.109, s32[2]{0} %param_3.20), direction=LE
  %constant.123 = s32[] constant(127)
  %broadcast.108 = s32[2]{0} broadcast(s32[] %constant.123), dimensions={}
  %compare.32 = pred[2]{0} compare(s32[2]{0} %broadcast.108, s32[2]{0} %param_3.20), direction=GE
  %and.18 = pred[2]{0} and(pred[2]{0} %compare.33, pred[2]{0} %compare.32)
  %constant.122 = pred[] constant(true)
  %reduce.8 = pred[] reduce(pred[2]{0} %and.18, pred[] %constant.122), dimensions={0}, to_apply=%and.reduce_sub_computation.2
  %bitcast.51 = pred[1,1]{1,0} bitcast(pred[] %reduce.8)
  %param_4.26 = s32[65536,2]{1,0} parameter(4)
  %param_2.16 = s32[] parameter(2)
  %dynamic-slice.26 = s32[1,2]{1,0} dynamic-slice(s32[65536,2]{1,0} %param_4.26, s32[] %param_2.16, s32[] %constant.124), dynamic_slice_sizes={1,2}
  %slice.53 = s32[1,1]{1,0} slice(s32[1,2]{1,0} %dynamic-slice.26), slice={[0:1], [0:1]}
  %bitcast.50 = s32[] bitcast(s32[1,1]{1,0} %slice.53)
  %bitcast.52 = s32[2]{0} bitcast(s32[1,2]{1,0} %dynamic-slice.26)
  %slice.52 = s32[1]{0} slice(s32[2]{0} %bitcast.52), slice={[1:2]}
  %bitcast.49 = s32[] bitcast(s32[1]{0} %slice.52)
  %dynamic-slice.25 = f32[1,1]{1,0} dynamic-slice(f32[128,128]{1,0} %param_0.8, s32[] %bitcast.50, s32[] %bitcast.49), dynamic_slice_sizes={1,1}
  %param_1.21 = f32[65536]{0} parameter(1)
  %dynamic-slice.24 = f32[1]{0} dynamic-slice(f32[65536]{0} %param_1.21, s32[] %param_2.16), dynamic_slice_sizes={1}
  %bitcast.48 = f32[1,1]{1,0} bitcast(f32[1]{0} %dynamic-slice.24)
  %add.34 = f32[1,1]{1,0} add(f32[1,1]{1,0} %dynamic-slice.25, f32[1,1]{1,0} %bitcast.48), metadata={op_name="/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.34 = f32[1,1]{1,0} select(pred[1,1]{1,0} %bitcast.51, f32[1,1]{1,0} %add.34, f32[1,1]{1,0} %dynamic-slice.25)
  ROOT %dynamic-update-slice.6 = f32[128,128]{1,0} dynamic-update-slice(f32[128,128]{1,0} %param_0.8, f32[1,1]{1,0} %select.34, s32[] %bitcast.50, s32[] %bitcast.49)
}

%fused_computation.5 (param_0.11: s32[65536,2], param_1.26: s32[]) -> s32[2] {
  %param_0.11 = s32[65536,2]{1,0} parameter(0)
  %param_1.26 = s32[] parameter(1)
  %constant.125 = s32[] constant(0)
  %dynamic-slice.27 = s32[1,2]{1,0} dynamic-slice(s32[65536,2]{1,0} %param_0.11, s32[] %param_1.26, s32[] %constant.125), dynamic_slice_sizes={1,2}
  %slice.55 = s32[1,1]{1,0} slice(s32[1,2]{1,0} %dynamic-slice.27), slice={[0:1], [0:1]}
  %bitcast.54 = s32[1]{0} bitcast(s32[1,1]{1,0} %slice.55)
  %bitcast.53 = s32[2]{0} bitcast(s32[1,2]{1,0} %dynamic-slice.27)
  %slice.54 = s32[1]{0} slice(s32[2]{0} %bitcast.53), slice={[1:2]}
  ROOT %concatenate.10 = s32[2]{0} concatenate(s32[1]{0} %bitcast.54, s32[1]{0} %slice.54), dimensions={0}
}

%while_body.2 (param.5: (s32[], f32[128,128], s32[65536,2], f32[65536])) -> (s32[], f32[128,128], s32[65536,2], f32[65536]) {
  %param.5 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) parameter(0)
  %get-tuple-element.48 = s32[] get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.5), index=0
  %copy.12 = s32[] copy(s32[] %get-tuple-element.48)
  %get-tuple-element.54 = s32[65536,2]{1,0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.5), index=2
  %slice_concatenate_fusion.2 = s32[2]{0} fusion(s32[65536,2]{1,0} %get-tuple-element.54, s32[] %copy.12), kind=kLoop, calls=%fused_computation.5
  %get-tuple-element.55 = f32[65536]{0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.5), index=3
  %get-tuple-element.49 = f32[128,128]{1,0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.5), index=1
  %select_dynamic-update-slice_fusion.2 = f32[128,128]{1,0} fusion(f32[128,128]{1,0} %get-tuple-element.49, f32[65536]{0} %get-tuple-element.55, s32[] %copy.12, s32[2]{0} %slice_concatenate_fusion.2, s32[65536,2]{1,0} %get-tuple-element.54), kind=kLoop, calls=%fused_computation.4
  %constant.54 = s32[] constant(1)
  %add.18 = s32[] add(s32[] %copy.12, s32[] %constant.54)
  ROOT %tuple.18 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) tuple(s32[] %add.18, f32[128,128]{1,0} %select_dynamic-update-slice_fusion.2, s32[65536,2]{1,0} %get-tuple-element.54, f32[65536]{0} %get-tuple-element.55)
}

%while_cond.2 (param.4: (s32[], f32[128,128], s32[65536,2], f32[65536])) -> pred[] {
  %constant.49 = s32[] constant(65536)
  %param.4 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) parameter(0)
  %get-tuple-element.16 = s32[] get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.4), index=0
  ROOT %compare.22 = pred[] compare(s32[] %get-tuple-element.16, s32[] %constant.49), direction=LT
}

%and.reduce_sub_computation.3 (lhs.3: pred[], rhs.3: pred[]) -> pred[] {
  %lhs.3 = pred[] parameter(0)
  %rhs.3 = pred[] parameter(1)
  ROOT %and.14 = pred[] and(pred[] %lhs.3, pred[] %rhs.3)
}

%fused_computation.6 (param_0.12: f32[128,128], param_1.30: f32[65536], param_2.22: s32[], param_3.27: s32[2], param_4.35: s32[65536,2]) -> f32[128,128] {
  %param_0.12 = f32[128,128]{1,0} parameter(0)
  %constant.128 = s32[] constant(0)
  %broadcast.111 = s32[2]{0} broadcast(s32[] %constant.128), dimensions={}
  %param_3.27 = s32[2]{0} parameter(3)
  %compare.35 = pred[2]{0} compare(s32[2]{0} %broadcast.111, s32[2]{0} %param_3.27), direction=LE
  %constant.127 = s32[] constant(127)
  %broadcast.110 = s32[2]{0} broadcast(s32[] %constant.127), dimensions={}
  %compare.34 = pred[2]{0} compare(s32[2]{0} %broadcast.110, s32[2]{0} %param_3.27), direction=GE
  %and.19 = pred[2]{0} and(pred[2]{0} %compare.35, pred[2]{0} %compare.34)
  %constant.126 = pred[] constant(true)
  %reduce.9 = pred[] reduce(pred[2]{0} %and.19, pred[] %constant.126), dimensions={0}, to_apply=%and.reduce_sub_computation.3
  %bitcast.58 = pred[1,1]{1,0} bitcast(pred[] %reduce.9)
  %param_4.35 = s32[65536,2]{1,0} parameter(4)
  %param_2.22 = s32[] parameter(2)
  %dynamic-slice.30 = s32[1,2]{1,0} dynamic-slice(s32[65536,2]{1,0} %param_4.35, s32[] %param_2.22, s32[] %constant.128), dynamic_slice_sizes={1,2}
  %slice.57 = s32[1,1]{1,0} slice(s32[1,2]{1,0} %dynamic-slice.30), slice={[0:1], [0:1]}
  %bitcast.57 = s32[] bitcast(s32[1,1]{1,0} %slice.57)
  %bitcast.59 = s32[2]{0} bitcast(s32[1,2]{1,0} %dynamic-slice.30)
  %slice.56 = s32[1]{0} slice(s32[2]{0} %bitcast.59), slice={[1:2]}
  %bitcast.56 = s32[] bitcast(s32[1]{0} %slice.56)
  %dynamic-slice.29 = f32[1,1]{1,0} dynamic-slice(f32[128,128]{1,0} %param_0.12, s32[] %bitcast.57, s32[] %bitcast.56), dynamic_slice_sizes={1,1}
  %param_1.30 = f32[65536]{0} parameter(1)
  %dynamic-slice.28 = f32[1]{0} dynamic-slice(f32[65536]{0} %param_1.30, s32[] %param_2.22), dynamic_slice_sizes={1}
  %bitcast.55 = f32[1,1]{1,0} bitcast(f32[1]{0} %dynamic-slice.28)
  %add.35 = f32[1,1]{1,0} add(f32[1,1]{1,0} %dynamic-slice.29, f32[1,1]{1,0} %bitcast.55), metadata={op_name="/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.36 = f32[1,1]{1,0} select(pred[1,1]{1,0} %bitcast.58, f32[1,1]{1,0} %add.35, f32[1,1]{1,0} %dynamic-slice.29)
  ROOT %dynamic-update-slice.7 = f32[128,128]{1,0} dynamic-update-slice(f32[128,128]{1,0} %param_0.12, f32[1,1]{1,0} %select.36, s32[] %bitcast.57, s32[] %bitcast.56)
}

%fused_computation.7 (param_0.15: s32[65536,2], param_1.35: s32[]) -> s32[2] {
  %param_0.15 = s32[65536,2]{1,0} parameter(0)
  %param_1.35 = s32[] parameter(1)
  %constant.129 = s32[] constant(0)
  %dynamic-slice.31 = s32[1,2]{1,0} dynamic-slice(s32[65536,2]{1,0} %param_0.15, s32[] %param_1.35, s32[] %constant.129), dynamic_slice_sizes={1,2}
  %slice.59 = s32[1,1]{1,0} slice(s32[1,2]{1,0} %dynamic-slice.31), slice={[0:1], [0:1]}
  %bitcast.61 = s32[1]{0} bitcast(s32[1,1]{1,0} %slice.59)
  %bitcast.60 = s32[2]{0} bitcast(s32[1,2]{1,0} %dynamic-slice.31)
  %slice.58 = s32[1]{0} slice(s32[2]{0} %bitcast.60), slice={[1:2]}
  ROOT %concatenate.12 = s32[2]{0} concatenate(s32[1]{0} %bitcast.61, s32[1]{0} %slice.58), dimensions={0}
}

%while_body.3 (param.7: (s32[], f32[128,128], s32[65536,2], f32[65536])) -> (s32[], f32[128,128], s32[65536,2], f32[65536]) {
  %param.7 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) parameter(0)
  %get-tuple-element.36 = s32[] get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.7), index=0
  %copy.6 = s32[] copy(s32[] %get-tuple-element.36)
  %get-tuple-element.42 = s32[65536,2]{1,0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.7), index=2
  %slice_concatenate_fusion.3 = s32[2]{0} fusion(s32[65536,2]{1,0} %get-tuple-element.42, s32[] %copy.6), kind=kLoop, calls=%fused_computation.7
  %get-tuple-element.43 = f32[65536]{0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.7), index=3
  %get-tuple-element.37 = f32[128,128]{1,0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.7), index=1
  %select_dynamic-update-slice_fusion.3 = f32[128,128]{1,0} fusion(f32[128,128]{1,0} %get-tuple-element.37, f32[65536]{0} %get-tuple-element.43, s32[] %copy.6, s32[2]{0} %slice_concatenate_fusion.3, s32[65536,2]{1,0} %get-tuple-element.42), kind=kLoop, calls=%fused_computation.6
  %constant.68 = s32[] constant(1)
  %add.19 = s32[] add(s32[] %copy.6, s32[] %constant.68)
  ROOT %tuple.15 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) tuple(s32[] %add.19, f32[128,128]{1,0} %select_dynamic-update-slice_fusion.3, s32[65536,2]{1,0} %get-tuple-element.42, f32[65536]{0} %get-tuple-element.43)
}

%while_cond.3 (param.6: (s32[], f32[128,128], s32[65536,2], f32[65536])) -> pred[] {
  %constant.63 = s32[] constant(65536)
  %param.6 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) parameter(0)
  %get-tuple-element.24 = s32[] get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %param.6), index=0
  ROOT %compare.25 = pred[] compare(s32[] %get-tuple-element.24, s32[] %constant.63), direction=LT
}

%region_0.146 (Arg_0.147: f32[], Arg_1.148: f32[]) -> f32[] {
  %Arg_0.147 = f32[] parameter(0), metadata={op_name="jit(fun)/jit(main)/reduce_max"}
  %Arg_1.148 = f32[] parameter(1), metadata={op_name="jit(fun)/jit(main)/reduce_max"}
  ROOT %maximum.149 = f32[] maximum(f32[] %Arg_0.147, f32[] %Arg_1.148), metadata={op_name="jit(fun)/jit(main)/reduce_max" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
}

%region_1.157 (Arg_0.158: f32[], Arg_1.159: f32[]) -> f32[] {
  %Arg_0.158 = f32[] parameter(0), metadata={op_name="jit(fun)/jit(main)/reduce_sum"}
  %Arg_1.159 = f32[] parameter(1), metadata={op_name="jit(fun)/jit(main)/reduce_sum"}
  ROOT %add.160 = f32[] add(f32[] %Arg_0.158, f32[] %Arg_1.159), metadata={op_name="jit(fun)/jit(main)/reduce_sum" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
}

%fused_computation.8 (param_0.17: f32[128,128], param_1.38: f32[128,128], param_2.25: f32[128,128], param_3.28: f32[128,128]) -> f32[128,128] {
  %param_2.25 = f32[128,128]{1,0} parameter(2)
  %param_3.28 = f32[128,128]{1,0} parameter(3)
  %add.38 = f32[128,128]{1,0} add(f32[128,128]{1,0} %param_2.25, f32[128,128]{1,0} %param_3.28), metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/add_any" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %param_1.38 = f32[128,128]{1,0} parameter(1)
  %add.37 = f32[128,128]{1,0} add(f32[128,128]{1,0} %add.38, f32[128,128]{1,0} %param_1.38), metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/add_any" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %param_0.17 = f32[128,128]{1,0} parameter(0)
  ROOT %add.36 = f32[128,128]{1,0} add(f32[128,128]{1,0} %add.37, f32[128,128]{1,0} %param_0.17), metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/add_any" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%fused_computation.16 (param_0.47: f32[8,8]) -> f32[] {
  %constant.140 = f32[] constant(1)
  %param_0.47 = f32[8,8]{1,0} parameter(0)
  %constant.141 = f32[] constant(0)
  %reduce.10 = f32[] reduce(f32[8,8]{1,0} %param_0.47, f32[] %constant.141), dimensions={0,1}, to_apply=%region_1.157, metadata={op_name="jit(fun)/jit(main)/reduce_sum" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  ROOT %divide.0 = f32[] divide(f32[] %constant.140, f32[] %reduce.10), metadata={op_name="jit(fun)/jit(main)/div" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
}

%fused_computation.9.clone (param_0.83: pred[256,256], param_1.125: f32[], param_2.112: f32[256,256], param_3.87: f32[]) -> f32[65536] {
  %param_0.83 = pred[256,256]{1,0} parameter(0)
  %param_2.112 = f32[256,256]{1,0} parameter(2)
  %param_3.87 = f32[] parameter(3)
  %broadcast.153 = f32[256,256]{1,0} broadcast(f32[] %param_3.87), dimensions={}, metadata={op_name="jit(fun)/jit(main)/eq" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %compare.70 = pred[256,256]{1,0} compare(f32[256,256]{1,0} %param_2.112, f32[256,256]{1,0} %broadcast.153), direction=EQ, metadata={op_name="jit(fun)/jit(main)/eq" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %param_1.125 = f32[] parameter(1)
  %broadcast.151 = f32[256,256]{1,0} broadcast(f32[] %param_1.125), dimensions={}, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %constant.162 = f32[] constant(0)
  %broadcast.155 = f32[256,256]{1,0} broadcast(f32[] %constant.162), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/sub" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.58 = f32[256,256]{1,0} select(pred[256,256]{1,0} %compare.70, f32[256,256]{1,0} %broadcast.151, f32[256,256]{1,0} %broadcast.155), metadata={op_name="jit(fun)/jit(main)/mul" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %select.57 = f32[256,256]{1,0} select(pred[256,256]{1,0} %param_0.83, f32[256,256]{1,0} %select.58, f32[256,256]{1,0} %broadcast.155), metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %bitcast.78 = f32[65536]{0} bitcast(f32[256,256]{1,0} %select.57), metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%parallel_select_bitcast_fusion (p: pred[256,256], p.1: f32[], p.2: f32[256,256], p.3: f32[]) -> f32[65536] {
  %p = pred[256,256]{1,0} parameter(0)
  %p.1 = f32[] parameter(1)
  %p.2 = f32[256,256]{1,0} parameter(2)
  %p.3 = f32[] parameter(3)
  ROOT %select_bitcast_fusion.clone = f32[65536]{0} fusion(pred[256,256]{1,0} %p, f32[] %p.1, f32[256,256]{1,0} %p.2, f32[] %p.3), kind=kLoop, calls=%fused_computation.9.clone, metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"outer_dimension_partitions":["2"]}
}

%fused_computation.10.clone () -> s32[65536,2] {
  %iota.37 = s32[256,256,1]{2,1,0} iota(), iota_dimension=0, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %iota.36 = s32[256,256,1]{2,1,0} iota(), iota_dimension=1, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %concatenate.21 = s32[256,256,2]{2,1,0} concatenate(s32[256,256,1]{2,1,0} %iota.37, s32[256,256,1]{2,1,0} %iota.36), dimensions={2}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %bitcast.79 = s32[65536,2]{1,0} bitcast(s32[256,256,2]{2,1,0} %concatenate.21), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%parallel_concatenate_bitcast_fusion () -> s32[65536,2] {
  ROOT %concatenate_bitcast_fusion.clone = s32[65536,2]{1,0} fusion(), kind=kLoop, calls=%fused_computation.10.clone, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"outer_dimension_partitions":["2"]}
}

%fused_computation.11.clone (param_0.84: pred[256,256], param_1.126: f32[], param_2.113: f32[256,256], param_3.88: f32[]) -> f32[65536] {
  %param_0.84 = pred[256,256]{1,0} parameter(0)
  %constant.163 = f32[] constant(0)
  %broadcast.158 = f32[256,256]{1,0} broadcast(f32[] %constant.163), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/sub" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %param_2.113 = f32[256,256]{1,0} parameter(2)
  %param_3.88 = f32[] parameter(3)
  %broadcast.157 = f32[256,256]{1,0} broadcast(f32[] %param_3.88), dimensions={}, metadata={op_name="jit(fun)/jit(main)/eq" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %compare.71 = pred[256,256]{1,0} compare(f32[256,256]{1,0} %param_2.113, f32[256,256]{1,0} %broadcast.157), direction=EQ, metadata={op_name="jit(fun)/jit(main)/eq" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %param_1.126 = f32[] parameter(1)
  %broadcast.156 = f32[256,256]{1,0} broadcast(f32[] %param_1.126), dimensions={}, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %select.60 = f32[256,256]{1,0} select(pred[256,256]{1,0} %compare.71, f32[256,256]{1,0} %broadcast.156, f32[256,256]{1,0} %broadcast.158), metadata={op_name="jit(fun)/jit(main)/mul" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %multiply.21 = f32[256,256]{1,0} multiply(f32[256,256]{1,0} %broadcast.158, f32[256,256]{1,0} %select.60), metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/mul" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.59 = f32[256,256]{1,0} select(pred[256,256]{1,0} %param_0.84, f32[256,256]{1,0} %multiply.21, f32[256,256]{1,0} %broadcast.158), metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %bitcast.80 = f32[65536]{0} bitcast(f32[256,256]{1,0} %select.59), metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%parallel_select_bitcast_fusion.1 (p.4: pred[256,256], p.5: f32[], p.6: f32[256,256], p.7: f32[]) -> f32[65536] {
  %p.4 = pred[256,256]{1,0} parameter(0)
  %p.5 = f32[] parameter(1)
  %p.6 = f32[256,256]{1,0} parameter(2)
  %p.7 = f32[] parameter(3)
  ROOT %select_bitcast_fusion.1.clone = f32[65536]{0} fusion(pred[256,256]{1,0} %p.4, f32[] %p.5, f32[256,256]{1,0} %p.6, f32[] %p.7), kind=kLoop, calls=%fused_computation.11.clone, metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"outer_dimension_partitions":["2"]}
}

%fused_computation.12.clone () -> s32[65536,2] {
  %iota.38 = s32[256,256,1]{2,1,0} iota(), iota_dimension=0, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %iota.39 = s32[256,256]{1,0} iota(), iota_dimension=1, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.165 = s32[] constant(1)
  %broadcast.161 = s32[256,256]{1,0} broadcast(s32[] %constant.165), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/sub" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.65 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.39, s32[256,256]{1,0} %broadcast.161), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.166 = s32[] constant(0)
  %broadcast.160 = s32[256,256]{1,0} broadcast(s32[] %constant.166), dimensions={}
  %compare.73 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.65, s32[256,256]{1,0} %broadcast.160), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.164 = s32[] constant(129)
  %broadcast.159 = s32[256,256]{1,0} broadcast(s32[] %constant.164), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.64 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.39, s32[256,256]{1,0} %broadcast.159), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.61 = s32[256,256]{1,0} select(pred[256,256]{1,0} %compare.73, s32[256,256]{1,0} %add.64, s32[256,256]{1,0} %add.65), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %bitcast.82 = s32[256,256,1]{2,1,0} bitcast(s32[256,256]{1,0} %select.61), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %concatenate.22 = s32[256,256,2]{2,1,0} concatenate(s32[256,256,1]{2,1,0} %iota.38, s32[256,256,1]{2,1,0} %bitcast.82), dimensions={2}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %bitcast.81 = s32[65536,2]{1,0} bitcast(s32[256,256,2]{2,1,0} %concatenate.22), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%parallel_concatenate_bitcast_fusion.1 () -> s32[65536,2] {
  ROOT %concatenate_bitcast_fusion.1.clone = s32[65536,2]{1,0} fusion(), kind=kLoop, calls=%fused_computation.12.clone, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"outer_dimension_partitions":["2"]}
}

%fused_computation.13.clone (param_0.85: pred[256,256], param_1.127: f32[], param_2.114: f32[256,256], param_3.89: f32[]) -> f32[65536] {
  %param_0.85 = pred[256,256]{1,0} parameter(0)
  %constant.167 = f32[] constant(0)
  %broadcast.165 = f32[256,256]{1,0} broadcast(f32[] %constant.167), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/sub" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %param_2.114 = f32[256,256]{1,0} parameter(2)
  %param_3.89 = f32[] parameter(3)
  %broadcast.164 = f32[256,256]{1,0} broadcast(f32[] %param_3.89), dimensions={}, metadata={op_name="jit(fun)/jit(main)/eq" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %compare.74 = pred[256,256]{1,0} compare(f32[256,256]{1,0} %param_2.114, f32[256,256]{1,0} %broadcast.164), direction=EQ, metadata={op_name="jit(fun)/jit(main)/eq" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %param_1.127 = f32[] parameter(1)
  %broadcast.162 = f32[256,256]{1,0} broadcast(f32[] %param_1.127), dimensions={}, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %select.63 = f32[256,256]{1,0} select(pred[256,256]{1,0} %compare.74, f32[256,256]{1,0} %broadcast.162, f32[256,256]{1,0} %broadcast.165), metadata={op_name="jit(fun)/jit(main)/mul" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %multiply.22 = f32[256,256]{1,0} multiply(f32[256,256]{1,0} %broadcast.165, f32[256,256]{1,0} %select.63), metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/mul" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.62 = f32[256,256]{1,0} select(pred[256,256]{1,0} %param_0.85, f32[256,256]{1,0} %multiply.22, f32[256,256]{1,0} %broadcast.165), metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %bitcast.83 = f32[65536]{0} bitcast(f32[256,256]{1,0} %select.62), metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%parallel_select_bitcast_fusion.2 (p.8: pred[256,256], p.9: f32[], p.10: f32[256,256], p.11: f32[]) -> f32[65536] {
  %p.10 = f32[256,256]{1,0} parameter(2)
  %p.11 = f32[] parameter(3)
  %p.8 = pred[256,256]{1,0} parameter(0)
  %p.9 = f32[] parameter(1)
  ROOT %select_bitcast_fusion.2.clone = f32[65536]{0} fusion(pred[256,256]{1,0} %p.8, f32[] %p.9, f32[256,256]{1,0} %p.10, f32[] %p.11), kind=kLoop, calls=%fused_computation.13.clone, metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"outer_dimension_partitions":["2"]}
}

%fused_computation.14.clone () -> s32[65536,2] {
  %iota.41 = s32[256,256]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.170 = s32[] constant(1)
  %broadcast.168 = s32[256,256]{1,0} broadcast(s32[] %constant.170), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/sub" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.67 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.41, s32[256,256]{1,0} %broadcast.168), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.171 = s32[] constant(0)
  %broadcast.167 = s32[256,256]{1,0} broadcast(s32[] %constant.171), dimensions={}
  %compare.75 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.67, s32[256,256]{1,0} %broadcast.167), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.169 = s32[] constant(129)
  %broadcast.166 = s32[256,256]{1,0} broadcast(s32[] %constant.169), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.66 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.41, s32[256,256]{1,0} %broadcast.166), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.64 = s32[256,256]{1,0} select(pred[256,256]{1,0} %compare.75, s32[256,256]{1,0} %add.66, s32[256,256]{1,0} %add.67), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %bitcast.85 = s32[256,256,1]{2,1,0} bitcast(s32[256,256]{1,0} %select.64), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %iota.40 = s32[256,256,1]{2,1,0} iota(), iota_dimension=1, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %concatenate.23 = s32[256,256,2]{2,1,0} concatenate(s32[256,256,1]{2,1,0} %bitcast.85, s32[256,256,1]{2,1,0} %iota.40), dimensions={2}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %bitcast.84 = s32[65536,2]{1,0} bitcast(s32[256,256,2]{2,1,0} %concatenate.23), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%parallel_concatenate_bitcast_fusion.2 () -> s32[65536,2] {
  ROOT %concatenate_bitcast_fusion.2.clone = s32[65536,2]{1,0} fusion(), kind=kLoop, calls=%fused_computation.14.clone, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"outer_dimension_partitions":["2"]}
}

%fused_computation.15.clone (param_0.86: pred[256,256], param_1.128: f32[], param_2.115: f32[256,256], param_3.90: f32[]) -> f32[65536] {
  %param_0.86 = pred[256,256]{1,0} parameter(0)
  %constant.172 = f32[] constant(0)
  %broadcast.172 = f32[256,256]{1,0} broadcast(f32[] %constant.172), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/sub" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %param_2.115 = f32[256,256]{1,0} parameter(2)
  %param_3.90 = f32[] parameter(3)
  %broadcast.171 = f32[256,256]{1,0} broadcast(f32[] %param_3.90), dimensions={}, metadata={op_name="jit(fun)/jit(main)/eq" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %compare.76 = pred[256,256]{1,0} compare(f32[256,256]{1,0} %param_2.115, f32[256,256]{1,0} %broadcast.171), direction=EQ, metadata={op_name="jit(fun)/jit(main)/eq" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %param_1.128 = f32[] parameter(1)
  %broadcast.170 = f32[256,256]{1,0} broadcast(f32[] %param_1.128), dimensions={}, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %select.66 = f32[256,256]{1,0} select(pred[256,256]{1,0} %compare.76, f32[256,256]{1,0} %broadcast.170, f32[256,256]{1,0} %broadcast.172), metadata={op_name="jit(fun)/jit(main)/mul" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %multiply.23 = f32[256,256]{1,0} multiply(f32[256,256]{1,0} %broadcast.172, f32[256,256]{1,0} %select.66), metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/mul" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.65 = f32[256,256]{1,0} select(pred[256,256]{1,0} %param_0.86, f32[256,256]{1,0} %multiply.23, f32[256,256]{1,0} %broadcast.172), metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %bitcast.86 = f32[65536]{0} bitcast(f32[256,256]{1,0} %select.65), metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%parallel_select_bitcast_fusion.3 (p.12: pred[256,256], p.13: f32[], p.14: f32[256,256], p.15: f32[]) -> f32[65536] {
  %p.12 = pred[256,256]{1,0} parameter(0)
  %p.13 = f32[] parameter(1)
  %p.14 = f32[256,256]{1,0} parameter(2)
  %p.15 = f32[] parameter(3)
  ROOT %select_bitcast_fusion.3.clone = f32[65536]{0} fusion(pred[256,256]{1,0} %p.12, f32[] %p.13, f32[256,256]{1,0} %p.14, f32[] %p.15), kind=kLoop, calls=%fused_computation.15.clone, metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"outer_dimension_partitions":["2"]}
}

%fused_computation.17.clone (param_0.87: f32[256,256], param_1.129: f32[]) -> f32[256,256] {
  %param_0.87 = f32[256,256]{1,0} parameter(0)
  %param_1.129 = f32[] parameter(1)
  %broadcast.173 = f32[256,256]{1,0} broadcast(f32[] %param_1.129), dimensions={}, metadata={op_name="jit(fun)/jit(main)/eq" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %compare.77 = pred[256,256]{1,0} compare(f32[256,256]{1,0} %param_0.87, f32[256,256]{1,0} %broadcast.173), direction=EQ, metadata={op_name="jit(fun)/jit(main)/eq" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  ROOT %convert.9 = f32[256,256]{1,0} convert(pred[256,256]{1,0} %compare.77), metadata={op_name="jit(fun)/jit(main)/convert_element_type" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
}

%parallel_compare_convert_fusion (p.16: f32[256,256], p.17: f32[]) -> f32[256,256] {
  %p.16 = f32[256,256]{1,0} parameter(0)
  %p.17 = f32[] parameter(1)
  ROOT %compare_convert_fusion.clone = f32[256,256]{1,0} fusion(f32[256,256]{1,0} %p.16, f32[] %p.17), kind=kLoop, calls=%fused_computation.17.clone, metadata={op_name="jit(fun)/jit(main)/convert_element_type" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}, backend_config={"outer_dimension_partitions":["2"]}
}

%fused_computation.18.clone (param_0.88: pred[256,256], param_1.130: f32[128,128], param_2.116: s32[256,256,2], param_3.91: pred[256,256], param_4.68: pred[256,256], param_5.72: pred[256,256]) -> f32[256,256] {
  %param_5.72 = pred[256,256]{1,0} parameter(5)
  %param_1.130 = f32[128,128]{1,0} parameter(1)
  %iota.45 = s32[256,256,1]{2,1,0} iota(), iota_dimension=0, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %iota.44 = s32[256,256,1]{2,1,0} iota(), iota_dimension=1, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %concatenate.26 = s32[256,256,2]{2,1,0} concatenate(s32[256,256,1]{2,1,0} %iota.45, s32[256,256,1]{2,1,0} %iota.44), dimensions={2}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %gather.11 = f32[256,256]{1,0} gather(f32[128,128]{1,0} %param_1.130, s32[256,256,2]{2,1,0} %concatenate.26), offset_dims={}, collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=2, slice_sizes={1,1}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/gather" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.175 = f32[] constant(0)
  %broadcast.174 = f32[256,256]{1,0} broadcast(f32[] %constant.175), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/sub" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.70 = f32[256,256]{1,0} select(pred[256,256]{1,0} %param_5.72, f32[256,256]{1,0} %gather.11, f32[256,256]{1,0} %broadcast.174), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %param_4.68 = pred[256,256]{1,0} parameter(4)
  %iota.43 = s32[256,256,1]{2,1,0} iota(), iota_dimension=0, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %iota.46 = s32[256,256]{1,0} iota(), iota_dimension=1, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.174 = s32[] constant(1)
  %broadcast.177 = s32[256,256]{1,0} broadcast(s32[] %constant.174), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/sub" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.74 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.46, s32[256,256]{1,0} %broadcast.177), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.176 = s32[] constant(0)
  %broadcast.176 = s32[256,256]{1,0} broadcast(s32[] %constant.176), dimensions={}
  %compare.78 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.74, s32[256,256]{1,0} %broadcast.176), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.173 = s32[] constant(129)
  %broadcast.175 = s32[256,256]{1,0} broadcast(s32[] %constant.173), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.72 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.46, s32[256,256]{1,0} %broadcast.175), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.72 = s32[256,256]{1,0} select(pred[256,256]{1,0} %compare.78, s32[256,256]{1,0} %add.72, s32[256,256]{1,0} %add.74), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %bitcast.87 = s32[256,256,1]{2,1,0} bitcast(s32[256,256]{1,0} %select.72), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %concatenate.25 = s32[256,256,2]{2,1,0} concatenate(s32[256,256,1]{2,1,0} %iota.43, s32[256,256,1]{2,1,0} %bitcast.87), dimensions={2}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %gather.10 = f32[256,256]{1,0} gather(f32[128,128]{1,0} %param_1.130, s32[256,256,2]{2,1,0} %concatenate.25), offset_dims={}, collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=2, slice_sizes={1,1}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/gather" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.69 = f32[256,256]{1,0} select(pred[256,256]{1,0} %param_4.68, f32[256,256]{1,0} %gather.10, f32[256,256]{1,0} %broadcast.174), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %multiply.26 = f32[256,256]{1,0} multiply(f32[256,256]{1,0} %broadcast.174, f32[256,256]{1,0} %select.69), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/mul" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.71 = f32[256,256]{1,0} add(f32[256,256]{1,0} %select.70, f32[256,256]{1,0} %multiply.26), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %param_3.91 = pred[256,256]{1,0} parameter(3)
  %iota.47 = s32[256,256]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.76 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.47, s32[256,256]{1,0} %broadcast.177), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %compare.79 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.76, s32[256,256]{1,0} %broadcast.176), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.75 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.47, s32[256,256]{1,0} %broadcast.175), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.73 = s32[256,256]{1,0} select(pred[256,256]{1,0} %compare.79, s32[256,256]{1,0} %add.75, s32[256,256]{1,0} %add.76), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %bitcast.88 = s32[256,256,1]{2,1,0} bitcast(s32[256,256]{1,0} %select.73), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %iota.42 = s32[256,256,1]{2,1,0} iota(), iota_dimension=1, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %concatenate.24 = s32[256,256,2]{2,1,0} concatenate(s32[256,256,1]{2,1,0} %bitcast.88, s32[256,256,1]{2,1,0} %iota.42), dimensions={2}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %gather.9 = f32[256,256]{1,0} gather(f32[128,128]{1,0} %param_1.130, s32[256,256,2]{2,1,0} %concatenate.24), offset_dims={}, collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=2, slice_sizes={1,1}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/gather" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.68 = f32[256,256]{1,0} select(pred[256,256]{1,0} %param_3.91, f32[256,256]{1,0} %gather.9, f32[256,256]{1,0} %broadcast.174), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %multiply.25 = f32[256,256]{1,0} multiply(f32[256,256]{1,0} %broadcast.174, f32[256,256]{1,0} %select.68), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/mul" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.69 = f32[256,256]{1,0} add(f32[256,256]{1,0} %add.71, f32[256,256]{1,0} %multiply.25), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %param_0.88 = pred[256,256]{1,0} parameter(0)
  %param_2.116 = s32[256,256,2]{2,1,0} parameter(2)
  %gather.8 = f32[256,256]{1,0} gather(f32[128,128]{1,0} %param_1.130, s32[256,256,2]{2,1,0} %param_2.116), offset_dims={}, collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=2, slice_sizes={1,1}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/gather" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.67 = f32[256,256]{1,0} select(pred[256,256]{1,0} %param_0.88, f32[256,256]{1,0} %gather.8, f32[256,256]{1,0} %broadcast.174), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/jit(_where)/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %multiply.24 = f32[256,256]{1,0} multiply(f32[256,256]{1,0} %broadcast.174, f32[256,256]{1,0} %select.67), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/mul" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %add.68 = f32[256,256]{1,0} add(f32[256,256]{1,0} %add.69, f32[256,256]{1,0} %multiply.24), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%parallel_multiply_add_fusion (p.18: pred[256,256], p.19: f32[128,128], p.20: s32[256,256,2], p.21: pred[256,256], p.22: pred[256,256], p.23: pred[256,256]) -> f32[256,256] {
  %p.18 = pred[256,256]{1,0} parameter(0)
  %p.19 = f32[128,128]{1,0} parameter(1)
  %p.20 = s32[256,256,2]{2,1,0} parameter(2)
  %p.21 = pred[256,256]{1,0} parameter(3)
  %p.22 = pred[256,256]{1,0} parameter(4)
  %p.23 = pred[256,256]{1,0} parameter(5)
  ROOT %multiply_add_fusion.clone = f32[256,256]{1,0} fusion(pred[256,256]{1,0} %p.18, f32[128,128]{1,0} %p.19, s32[256,256,2]{2,1,0} %p.20, pred[256,256]{1,0} %p.21, pred[256,256]{1,0} %p.22, /*index=5*/pred[256,256]{1,0} %p.23), kind=kLoop, calls=%fused_computation.18.clone, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"outer_dimension_partitions":["2"]}
}

%fused_computation.19.clone () -> pred[256,256] {
  %iota.49 = s32[256,256]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.179 = s32[] constant(1)
  %broadcast.181 = s32[256,256]{1,0} broadcast(s32[] %constant.179), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/sub" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.77 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.49, s32[256,256]{1,0} %broadcast.181), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.180 = s32[] constant(0)
  %broadcast.180 = s32[256,256]{1,0} broadcast(s32[] %constant.180), dimensions={}
  %compare.82 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.77, s32[256,256]{1,0} %broadcast.180), direction=GE, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/ge" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.177 = s32[] constant(128)
  %broadcast.178 = s32[256,256]{1,0} broadcast(s32[] %constant.177), dimensions={}
  %compare.81 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.77, s32[256,256]{1,0} %broadcast.178), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %and.29 = pred[256,256]{1,0} and(pred[256,256]{1,0} %compare.82, pred[256,256]{1,0} %compare.81), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/and" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %iota.48 = s32[256,256]{1,0} iota(), iota_dimension=1, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %compare.80 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %iota.48, s32[256,256]{1,0} %broadcast.178), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %and.28 = pred[256,256]{1,0} and(pred[256,256]{1,0} %and.29, pred[256,256]{1,0} %compare.80), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/and" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%parallel_compare_and_fusion () -> pred[256,256] {
  ROOT %compare_and_fusion.clone = pred[256,256]{1,0} fusion(), kind=kLoop, calls=%fused_computation.19.clone, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/and" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"outer_dimension_partitions":["2"]}
}

%fused_computation.20.clone () -> pred[256,256] {
  %iota.51 = s32[256,256]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.181 = s32[] constant(128)
  %broadcast.182 = s32[256,256]{1,0} broadcast(s32[] %constant.181), dimensions={}
  %compare.83 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %iota.51, s32[256,256]{1,0} %broadcast.182), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %iota.50 = s32[256,256]{1,0} iota(), iota_dimension=1, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.182 = s32[] constant(1)
  %broadcast.184 = s32[256,256]{1,0} broadcast(s32[] %constant.182), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/sub" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.78 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.50, s32[256,256]{1,0} %broadcast.184), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.183 = s32[] constant(0)
  %broadcast.183 = s32[256,256]{1,0} broadcast(s32[] %constant.183), dimensions={}
  %compare.86 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.78, s32[256,256]{1,0} %broadcast.183), direction=GE, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/ge" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %compare.85 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.78, s32[256,256]{1,0} %broadcast.182), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %and.31 = pred[256,256]{1,0} and(pred[256,256]{1,0} %compare.86, pred[256,256]{1,0} %compare.85), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/and" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %and.30 = pred[256,256]{1,0} and(pred[256,256]{1,0} %compare.83, pred[256,256]{1,0} %and.31), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/and" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%parallel_compare_and_fusion.1 () -> pred[256,256] {
  ROOT %compare_and_fusion.1.clone = pred[256,256]{1,0} fusion(), kind=kLoop, calls=%fused_computation.20.clone, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/and" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"outer_dimension_partitions":["2"]}
}

%fused_computation.21.clone () -> pred[256,256] {
  %iota.53 = s32[256,256]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.184 = s32[] constant(128)
  %broadcast.185 = s32[256,256]{1,0} broadcast(s32[] %constant.184), dimensions={}
  %compare.89 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %iota.53, s32[256,256]{1,0} %broadcast.185), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %iota.52 = s32[256,256]{1,0} iota(), iota_dimension=1, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %compare.88 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %iota.52, s32[256,256]{1,0} %broadcast.185), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %and.32 = pred[256,256]{1,0} and(pred[256,256]{1,0} %compare.89, pred[256,256]{1,0} %compare.88), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/and" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%parallel_compare_and_fusion.2 () -> pred[256,256] {
  ROOT %compare_and_fusion.2.clone = pred[256,256]{1,0} fusion(), kind=kLoop, calls=%fused_computation.21.clone, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/and" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"outer_dimension_partitions":["2"]}
}

%fused_computation.22.clone () -> pred[256,256] {
  %iota.55 = s32[256,256]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.186 = s32[] constant(1)
  %broadcast.188 = s32[256,256]{1,0} broadcast(s32[] %constant.186), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/sub" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.80 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.55, s32[256,256]{1,0} %broadcast.188), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.187 = s32[] constant(0)
  %broadcast.187 = s32[256,256]{1,0} broadcast(s32[] %constant.187), dimensions={}
  %compare.93 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.80, s32[256,256]{1,0} %broadcast.187), direction=GE, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/ge" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.185 = s32[] constant(128)
  %broadcast.186 = s32[256,256]{1,0} broadcast(s32[] %constant.185), dimensions={}
  %compare.92 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.80, s32[256,256]{1,0} %broadcast.186), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %and.35 = pred[256,256]{1,0} and(pred[256,256]{1,0} %compare.93, pred[256,256]{1,0} %compare.92), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/and" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %iota.54 = s32[256,256]{1,0} iota(), iota_dimension=1, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.79 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.54, s32[256,256]{1,0} %broadcast.188), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %compare.91 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.79, s32[256,256]{1,0} %broadcast.187), direction=GE, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/ge" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %compare.90 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.79, s32[256,256]{1,0} %broadcast.186), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %and.34 = pred[256,256]{1,0} and(pred[256,256]{1,0} %compare.91, pred[256,256]{1,0} %compare.90), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/and" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %and.33 = pred[256,256]{1,0} and(pred[256,256]{1,0} %and.35, pred[256,256]{1,0} %and.34), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/and" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%parallel_and_and_fusion () -> pred[256,256] {
  ROOT %and_and_fusion.clone = pred[256,256]{1,0} fusion(), kind=kLoop, calls=%fused_computation.22.clone, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/and" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"outer_dimension_partitions":["2"]}
}

%fused_computation.23.clone () -> s32[65536,2] {
  %iota.57 = s32[256,256]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.190 = s32[] constant(1)
  %broadcast.192 = s32[256,256]{1,0} broadcast(s32[] %constant.190), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/sub" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.84 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.57, s32[256,256]{1,0} %broadcast.192), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.191 = s32[] constant(0)
  %broadcast.191 = s32[256,256]{1,0} broadcast(s32[] %constant.191), dimensions={}
  %compare.95 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.84, s32[256,256]{1,0} %broadcast.191), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.189 = s32[] constant(129)
  %broadcast.190 = s32[256,256]{1,0} broadcast(s32[] %constant.189), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.83 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.57, s32[256,256]{1,0} %broadcast.190), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.76 = s32[256,256]{1,0} select(pred[256,256]{1,0} %compare.95, s32[256,256]{1,0} %add.83, s32[256,256]{1,0} %add.84), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %bitcast.91 = s32[256,256,1]{2,1,0} bitcast(s32[256,256]{1,0} %select.76), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %iota.56 = s32[256,256]{1,0} iota(), iota_dimension=1, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.82 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.56, s32[256,256]{1,0} %broadcast.192), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %compare.94 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.82, s32[256,256]{1,0} %broadcast.191), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.81 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.56, s32[256,256]{1,0} %broadcast.190), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.75 = s32[256,256]{1,0} select(pred[256,256]{1,0} %compare.94, s32[256,256]{1,0} %add.81, s32[256,256]{1,0} %add.82), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %bitcast.90 = s32[256,256,1]{2,1,0} bitcast(s32[256,256]{1,0} %select.75), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %concatenate.27 = s32[256,256,2]{2,1,0} concatenate(s32[256,256,1]{2,1,0} %bitcast.91, s32[256,256,1]{2,1,0} %bitcast.90), dimensions={2}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %bitcast.89 = s32[65536,2]{1,0} bitcast(s32[256,256,2]{2,1,0} %concatenate.27), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%parallel_concatenate_bitcast_fusion.3 () -> s32[65536,2] {
  ROOT %concatenate_bitcast_fusion.3.clone = s32[65536,2]{1,0} fusion(), kind=kLoop, calls=%fused_computation.23.clone, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"outer_dimension_partitions":["2"]}
}

%fused_computation.24.clone () -> s32[256,256,2] {
  %iota.59 = s32[256,256]{1,0} iota(), iota_dimension=0, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.193 = s32[] constant(1)
  %broadcast.195 = s32[256,256]{1,0} broadcast(s32[] %constant.193), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/sub" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.90 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.59, s32[256,256]{1,0} %broadcast.195), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.194 = s32[] constant(0)
  %broadcast.194 = s32[256,256]{1,0} broadcast(s32[] %constant.194), dimensions={}
  %compare.97 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.90, s32[256,256]{1,0} %broadcast.194), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %constant.192 = s32[] constant(129)
  %broadcast.193 = s32[256,256]{1,0} broadcast(s32[] %constant.192), dimensions={}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.89 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.59, s32[256,256]{1,0} %broadcast.193), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.78 = s32[256,256]{1,0} select(pred[256,256]{1,0} %compare.97, s32[256,256]{1,0} %add.89, s32[256,256]{1,0} %add.90), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %bitcast.93 = s32[256,256,1]{2,1,0} bitcast(s32[256,256]{1,0} %select.78), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %iota.58 = s32[256,256]{1,0} iota(), iota_dimension=1, metadata={op_name="jit(fun)/jit(main)/broadcast_in_dim" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.87 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.58, s32[256,256]{1,0} %broadcast.195), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %compare.96 = pred[256,256]{1,0} compare(s32[256,256]{1,0} %add.87, s32[256,256]{1,0} %broadcast.194), direction=LT, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/lt" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %add.86 = s32[256,256]{1,0} add(s32[256,256]{1,0} %iota.58, s32[256,256]{1,0} %broadcast.193), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %select.77 = s32[256,256]{1,0} select(pred[256,256]{1,0} %compare.96, s32[256,256]{1,0} %add.86, s32[256,256]{1,0} %add.87), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %bitcast.92 = s32[256,256,1]{2,1,0} bitcast(s32[256,256]{1,0} %select.77), metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/select_n" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %concatenate.28 = s32[256,256,2]{2,1,0} concatenate(s32[256,256,1]{2,1,0} %bitcast.93, s32[256,256,1]{2,1,0} %bitcast.92), dimensions={2}, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}

%parallel_bitcast_concatenate_fusion () -> s32[256,256,2] {
  ROOT %bitcast_concatenate_fusion.clone = s32[256,256,2]{2,1,0} fusion(), kind=kLoop, calls=%fused_computation.24.clone, metadata={op_name="jit(fun)/jit(main)/jvp(jit(_map_coordinates))/concatenate" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"outer_dimension_partitions":["2"]}
}

ENTRY %main.237 (Arg_0.1: f32[128,128]) -> f32[128,128] {
  %call.10 = pred[256,256]{1,0} call(), to_apply=%parallel_compare_and_fusion.1
  %call.11 = pred[256,256]{1,0} call(), to_apply=%parallel_compare_and_fusion.2
  %call.12 = pred[256,256]{1,0} call(), to_apply=%parallel_and_and_fusion
  %call.9 = pred[256,256]{1,0} call(), to_apply=%parallel_compare_and_fusion
  %call.14 = s32[256,256,2]{2,1,0} call(), to_apply=%parallel_bitcast_concatenate_fusion
  %Arg_0.1 = f32[128,128]{1,0} parameter(0), metadata={op_name="x"}
  %call.8 = f32[256,256]{1,0} call(pred[256,256]{1,0} %call.12, f32[128,128]{1,0} %Arg_0.1, s32[256,256,2]{2,1,0} %call.14, pred[256,256]{1,0} %call.9, pred[256,256]{1,0} %call.10, /*index=5*/pred[256,256]{1,0} %call.11), to_apply=%parallel_multiply_add_fusion
  %constant.4 = f32[] constant(-inf)
  %reduce-window = f32[8,8]{1,0} reduce-window(f32[256,256]{1,0} %call.8, f32[] %constant.4), window={size=32x32 stride=32x32}, to_apply=%region_0.146
  %reduce.150 = f32[] reduce(f32[8,8]{1,0} %reduce-window, f32[] %constant.4), dimensions={0,1}, to_apply=%region_0.146, metadata={op_name="jit(fun)/jit(main)/reduce_max" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %call.7 = f32[256,256]{1,0} call(f32[256,256]{1,0} %call.8, f32[] %reduce.150), to_apply=%parallel_compare_convert_fusion
  %constant.3 = f32[] constant(0)
  %reduce-window.1 = f32[8,8]{1,0} reduce-window(f32[256,256]{1,0} %call.7, f32[] %constant.3), window={size=32x32 stride=32x32}, to_apply=%region_1.157
  %reduce_divide_fusion = f32[] fusion(f32[8,8]{1,0} %reduce-window.1), kind=kLoop, calls=%fused_computation.16, metadata={op_name="jit(fun)/jit(main)/div" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=24}
  %call.4 = f32[65536]{0} call(pred[256,256]{1,0} %call.9, f32[] %reduce_divide_fusion, f32[256,256]{1,0} %call.8, f32[] %reduce.150), to_apply=%parallel_select_bitcast_fusion.2
  %broadcast.18 = f32[128,128]{1,0} broadcast(f32[] %constant.3), dimensions={}
  %copy.23 = f32[128,128]{1,0} copy(f32[128,128]{1,0} %broadcast.18)
  %call.5 = s32[65536,2]{1,0} call(), to_apply=%parallel_concatenate_bitcast_fusion.2
  %constant.8 = s32[] constant(0)
  %copy.22 = s32[] copy(s32[] %constant.8)
  %tuple.22 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) tuple(s32[] %copy.22, f32[128,128]{1,0} %copy.23, s32[65536,2]{1,0} %call.5, f32[65536]{0} %call.4)
  %while.1 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) while((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %tuple.22), condition=%while_cond.1, body=%while_body.1, metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/scatter-add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"known_trip_count":{"n":"65536"}}
  %get-tuple-element.13 = f32[128,128]{1,0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %while.1), index=1, metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/scatter-add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %call.6 = f32[65536]{0} call(pred[256,256]{1,0} %call.12, f32[] %reduce_divide_fusion, f32[256,256]{1,0} %call.8, f32[] %reduce.150), to_apply=%parallel_select_bitcast_fusion.3
  %copy.17 = f32[128,128]{1,0} copy(f32[128,128]{1,0} %broadcast.18)
  %call.13 = s32[65536,2]{1,0} call(), to_apply=%parallel_concatenate_bitcast_fusion.3
  %copy.16 = s32[] copy(s32[] %constant.8)
  %tuple.19 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) tuple(s32[] %copy.16, f32[128,128]{1,0} %copy.17, s32[65536,2]{1,0} %call.13, f32[65536]{0} %call.6)
  %while = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) while((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %tuple.19), condition=%while_cond, body=%while_body, metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/scatter-add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"known_trip_count":{"n":"65536"}}
  %get-tuple-element.5 = f32[128,128]{1,0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %while), index=1, metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/scatter-add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %call.2 = f32[65536]{0} call(pred[256,256]{1,0} %call.10, f32[] %reduce_divide_fusion, f32[256,256]{1,0} %call.8, f32[] %reduce.150), to_apply=%parallel_select_bitcast_fusion.1
  %copy.11 = f32[128,128]{1,0} copy(f32[128,128]{1,0} %broadcast.18)
  %call.3 = s32[65536,2]{1,0} call(), to_apply=%parallel_concatenate_bitcast_fusion.1
  %copy.10 = s32[] copy(s32[] %constant.8)
  %tuple.16 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) tuple(s32[] %copy.10, f32[128,128]{1,0} %copy.11, s32[65536,2]{1,0} %call.3, f32[65536]{0} %call.2)
  %while.2 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) while((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %tuple.16), condition=%while_cond.2, body=%while_body.2, metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/scatter-add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"known_trip_count":{"n":"65536"}}
  %get-tuple-element.21 = f32[128,128]{1,0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %while.2), index=1, metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/scatter-add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  %call = f32[65536]{0} call(pred[256,256]{1,0} %call.11, f32[] %reduce_divide_fusion, f32[256,256]{1,0} %call.8, f32[] %reduce.150), to_apply=%parallel_select_bitcast_fusion
  %copy.5 = f32[128,128]{1,0} copy(f32[128,128]{1,0} %broadcast.18)
  %call.1 = s32[65536,2]{1,0} call(), to_apply=%parallel_concatenate_bitcast_fusion
  %copy.4 = s32[] copy(s32[] %constant.8)
  %tuple.13 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) tuple(s32[] %copy.4, f32[128,128]{1,0} %copy.5, s32[65536,2]{1,0} %call.1, f32[65536]{0} %call)
  %while.3 = (s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) while((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %tuple.13), condition=%while_cond.3, body=%while_body.3, metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/scatter-add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}, backend_config={"known_trip_count":{"n":"65536"}}
  %get-tuple-element.29 = f32[128,128]{1,0} get-tuple-element((s32[], f32[128,128]{1,0}, s32[65536,2]{1,0}, f32[65536]{0}) %while.3), index=1, metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/scatter-add" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
  ROOT %add_add_fusion = f32[128,128]{1,0} fusion(f32[128,128]{1,0} %get-tuple-element.29, f32[128,128]{1,0} %get-tuple-element.21, f32[128,128]{1,0} %get-tuple-element.5, f32[128,128]{1,0} %get-tuple-element.13), kind=kLoop, calls=%fused_computation.8, metadata={op_name="jit(fun)/jit(main)/transpose(jvp(jit(_map_coordinates)))/add_any" source_file="<ipython-input-1-aaf6d48ac58c>" source_line=23}
}
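The ENTRY computation above ends the backward pass with four `while` loops. Each is tagged `transpose(jvp(jit(_map_coordinates)))/scatter-add` and has `known_trip_count` 65536 (= 256 × 256), and `%add_add_fusion` sums their results. As far as the dump shows, this is what one would expect for `order=1` in 2D. Linear interpolation reads 2² = 4 neighbouring corners, one per validity mask `%call.9` to `%call.12`, and the transpose of each corner's gather is a scatter-add.

The sketch below rebuilds that gradient by hand with four `.at[...].add` scatters and checks it against `jax.grad`. It is an illustration under stated assumptions, not JAX's actual implementation:

- It assumes `mode='constant'` with `cval=0`, the `map_coordinates` defaults.
- It uses `jnp.sum` instead of `jnp.max` to keep the comparison simple.
- `corner_terms` and `manual_grad` are hypothetical helpers, not part of JAX.

```python
import itertools

import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def corner_terms(coords, shape):
    """Hypothetical helper: yield (index_tuple, weight) for each of the
    2**ndim corners used by linear interpolation. Out-of-bounds corners get
    weight 0, which assumes mode='constant' with cval=0."""
    per_dim = []
    for c, size in zip(coords, shape):
        c = jnp.asarray(c, jnp.float32)
        lower = jnp.floor(c)
        frac = c - lower
        lower_i = lower.astype(jnp.int32)
        options = []
        for idx, w in ((lower_i, 1.0 - frac), (lower_i + 1, frac)):
            valid = (idx >= 0) & (idx < size)
            # Clip so the scatter never sees an out-of-range index; the
            # weight of those points is zeroed anyway.
            options.append((jnp.clip(idx, 0, size - 1), jnp.where(valid, w, 0.0)))
        per_dim.append(options)
    for combo in itertools.product(*per_dim):
        idxs = tuple(i for i, _ in combo)
        weight = jnp.prod(jnp.stack([w for _, w in combo]), axis=0)
        yield idxs, weight


def manual_grad(x, coords):
    """Hypothetical helper: gradient of sum(map_coordinates(x, coords, order=1))
    with respect to x, written as one scatter-add per interpolation corner."""
    g = jnp.zeros_like(x)
    for idxs, weight in corner_terms(coords, x.shape):
        # Duplicate indices accumulate, like the scatter-add loops in the HLO.
        g = g.at[idxs].add(weight)
    return g


x = jnp.ones((128, 128), jnp.float32)
coords = jnp.indices((256, 256))  # same coordinates as the reproducer
ref = jax.grad(lambda x: jnp.sum(map_coordinates(x, coords, order=1)))(x)
print(jnp.allclose(manual_grad(x, coords), ref))
```

In 3D the same construction would produce 2³ = 8 corners. If the scatter-adds really are lowered to loops as they are here, the cost of the backward pass would grow accordingly.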

@lockwo
Contributor

lockwo commented Jan 28, 2025

This looks to me like the usual small while loop performance issue

I'm not familiar with this small while loop issue (although I use a lot of while loops and have run into CPU performance issues with them). Is there any documentation I could reference on this?
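One way to check whether the regression sits in the scatter-add loops themselves, rather than in the rest of the gradient, is to time a standalone scatter-add of the same size under both runtimes and count the `while` ops XLA emits. The dump uses 65536 updates into a 128×128 array, matching its `s32[65536,2]` indices and `f32[65536]` updates.

This is a hedged micro-benchmark sketch, not a confirmed diagnosis:

- `scatter_add` is a hypothetical helper.
- The random indices stand in for the interpolation corners.
- Whether the compiled HLO contains `while` loops depends on the XLA version.
- The `--xla_cpu_use_thunk_runtime=false` flag is the one from the reproducer and may not be recognized by every XLA release.

As in the reproducer, `XLA_FLAGS` has to be set before `jax` is imported.

```python
import os

# Set to True to compare against the old CPU runtime. This must happen
# before jax is imported, as in the reproducer.
old_cpu_runtime = False
if old_cpu_runtime:
    os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import timeit

import jax
import jax.numpy as jnp

n_updates = 256 * 256  # one update per output pixel, as in the HLO dump


@jax.jit
def scatter_add(target, rows, cols, updates):
    # Lowers to an XLA scatter with an add combiner; duplicates accumulate.
    return target.at[rows, cols].add(updates)


key_r, key_c, key_u = jax.random.split(jax.random.PRNGKey(0), 3)
rows = jax.random.randint(key_r, (n_updates,), 0, 128)
cols = jax.random.randint(key_c, (n_updates,), 0, 128)
updates = jax.random.normal(key_u, (n_updates,))
target = jnp.zeros((128, 128))

# Warmup call, which also triggers compilation.
out = jax.block_until_ready(scatter_add(target, rows, cols, updates))

# Inspect the optimized HLO to see whether the scatter became a while loop.
hlo = scatter_add.lower(target, rows, cols, updates).compile().as_text()
print("while ops in compiled HLO:", hlo.count(" while("))

t = timeit.timeit(
    lambda: jax.block_until_ready(scatter_add(target, rows, cols, updates)),
    number=100,
)
print(f"scatter_add: {t:.5f} seconds (for 100 runs)")
```

Running it once with `old_cpu_runtime = False` and once with `True`, in separate processes, should show whether the standalone scatter already accounts for the slowdown seen in `fun_jit_grad`.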
