Skip to content

Commit

Permalink
Merge pull request #4472 from google:nnx-fix-fori
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712665956
  • Loading branch information
Flax Authors committed Jan 6, 2025
2 parents 53bde74 + f8164dd commit 595e711
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
16 changes: 15 additions & 1 deletion tests/jax_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"""Tests for flax.jax_utils."""

from functools import partial
import os
import re

from absl.testing import absltest
from absl.testing import parameterized
Expand All @@ -26,9 +28,21 @@

NDEV = 4

_xla_device_count_flag_regexp = (
r'[-]{0,2}xla_force_host_platform_device_count=(\d+)?(\s|$)'
)


def set_n_cpu_devices(n: int):
xla_flags = os.getenv('XLA_FLAGS', '')
xla_flags = re.sub(_xla_device_count_flag_regexp, '', xla_flags)
os.environ['XLA_FLAGS'] = ' '.join(
[f'--xla_force_host_platform_device_count={n}'] + xla_flags.split()
)


def setUpModule():
chex.set_n_cpu_devices(NDEV)
set_n_cpu_devices(NDEV)


class PadShardUnpadTest(chex.TestCase):
Expand Down
12 changes: 12 additions & 0 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2985,6 +2985,18 @@ def loop_fn(inputs):
nnx.while_loop(lambda input: input[-1] > 0, while_loop_fn, (a, b, 2))
nnx.fori_loop(0, 2, fori_loop_fn, (a, b))

def test_fori_output(self):
model = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(0)))
model2 = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(1)))

def f(i, x):
return x

model_out, model2_out = nnx.fori_loop(0, 10, f, (model, model2))

self.assertIs(model, model_out)
self.assertIs(model2, model2_out)


class TestSplitMergeInputs(absltest.TestCase):
def test_split_inputs(self):
Expand Down

0 comments on commit 595e711

Please sign in to comment.