Commit 7234157

Authored and committed by Flax Authors
Merge pull request #4343 from IvyZX:conds
PiperOrigin-RevId: 691964939
2 parents: 8292d9c + 4f5d6fb

File tree

5 files changed: +332 −3 lines

docs_nnx/api_reference/flax.nnx/transforms.rst

Lines changed: 3 additions & 1 deletion
@@ -20,5 +20,7 @@ transforms
 .. autofunction:: value_and_grad
 .. autofunction:: vmap
 .. autofunction:: eval_shape
-.. autofunction:: cond
 .. autofunction:: custom_vjp
+.. autofunction:: cond
+.. autofunction:: switch
+.. autofunction:: while_loop

flax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -150,6 +150,8 @@
 from .transforms.iteration import pmap as pmap
 from .transforms.transforms import eval_shape as eval_shape
 from .transforms.transforms import cond as cond
+from .transforms.transforms import switch as switch
+from .transforms.iteration import while_loop as while_loop
 from .transforms.iteration import StateAxes as StateAxes
 from .variablelib import A as A
 from .variablelib import BatchStat as BatchStat
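With these re-exports, both transforms become available from the top-level `nnx` namespace. A minimal sketch of the two new entry points (the doubling/halving toy functions are ours, not part of this change):

  import jax.numpy as jnp
  from flax import nnx

  def double(x):
    return x * 2.0

  def halve(x):
    return x / 2.0

  # nnx.switch dispatches on an integer index, mirroring jax.lax.switch.
  y = nnx.switch(1, (double, halve), jnp.ones((3,)))  # index 1 selects `halve`

  # nnx.while_loop threads a carry of type T, mirroring jax.lax.while_loop.
  final_x, _ = nnx.while_loop(
    lambda val: val[1] > 0,                    # continue while counter > 0
    lambda val: (val[0] * 2.0, val[1] - 1.0),  # body returns a carry of the same type
    (jnp.ones((3,)), 3.0),
  )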

flax/nnx/transforms/iteration.py

Lines changed: 145 additions & 0 deletions
@@ -40,6 +40,7 @@
 M = tp.TypeVar('M', bound=Module)
 MA = tp.TypeVar('MA', bound=Module)
 N = tp.TypeVar('N', bound=Module)
+T = tp.TypeVar('T')
 StrInt = tp.TypeVar('StrInt', str, int)
 AxisName = tp.Hashable
 Leaves = tp.List[Leaf]
@@ -1304,3 +1305,147 @@ def scan_wrapper(*args, **kwargs):
     return out

   return scan_wrapper  # type: ignore
+
+
+# -------------------------------
+# while_loop
+# -------------------------------
+
+
+@dataclasses.dataclass(eq=False)
+class WhileLoopCondFn:
+  f: tp.Callable[..., tp.Any]
+
+  def __post_init__(self):
+    functools.update_wrapper(self, self.f)
+
+  def __call__(self, pure_val):
+    val = extract.from_tree(pure_val)
+    out = self.f(val)
+    return out
+
+
+def _add_fake_index_mapping(tree: tp.Any):
+  def per_node_state(ns: extract.NodeStates | tp.Any):
+    global_index_mapping = {}
+    if not isinstance(ns, extract.NodeStates) or not isinstance(
+      ns._graphdef, graph.NodeDef
+    ):
+      return ns
+
+    def per_node_def(nd: graph.NodeDef | tp.Any):
+      if nd.index >= 0:
+        global_index_mapping[nd.index] = nd.index
+      for sub_nd in nd.subgraphs.values():
+        per_node_def(sub_nd)
+      for l in nd.leaves.values():
+        if isinstance(l, graph.NodeRef) and l.index >= 0:
+          global_index_mapping[l.index] = l.index
+      return
+
+    per_node_def(ns._graphdef)
+    return dataclasses.replace(ns, _graphdef=dataclasses.replace(
+      ns._graphdef,
+      index_mapping=FrozenDict(global_index_mapping)
+    ))
+
+  return jax.tree.map(per_node_state, tree,
+                      is_leaf=lambda x: isinstance(x, extract.NodeStates))
+
+
+def _remove_index_mapping(tree: tp.Any):
+  '''Remove a fake index_mapping for the input to match that of the output.'''
+  def per_node_state(ns: extract.NodeStates | tp.Any):
+    if not isinstance(ns, extract.NodeStates) or not isinstance(
+      ns._graphdef, graph.NodeDef
+    ):
+      return ns
+    assert isinstance(ns._graphdef, graph.NodeDef)
+    return dataclasses.replace(ns, _graphdef=dataclasses.replace(
+      ns._graphdef, index_mapping=None
+    ))
+
+  return jax.tree.map(per_node_state, tree,
+                      is_leaf=lambda x: isinstance(x, extract.NodeStates))
+
+
+@dataclasses.dataclass(eq=False)
+class WhileLoopBodyFn:
+  f: tp.Callable[..., tp.Any]
+
+  def __post_init__(self):
+    functools.update_wrapper(self, self.f)
+
+  @graph.update_context('while_loop_body')
+  def __call__(self, pure_val):
+    # Remove the dummy index mapping that was added outside of the body function.
+    pure_val_in = _remove_index_mapping(pure_val)
+
+    val = extract.from_tree(pure_val_in, ctxtag='while_loop_body')
+    out = self.f(val)
+    pure_out = extract.to_tree(out, ctxtag='while_loop_body')
+
+    try:
+      jax.tree.map(lambda a, b: None, pure_val, pure_out)
+    except ValueError as e:
+      msg = ("nnx.while_loop requires body function's input and output to "
+             "have the same reference and pytree structure, but they differ. "
+             "If the mismatch comes from the `index_mapping` field, you might "
+             "have modified the reference structure within the body function, "
+             "which is not allowed. "
+             f"Detail of the mismatch: \n {str(e)}")
+      raise ValueError(msg)
+
+    return pure_out
+
+
+@graph.update_context('while_loop')
+def while_loop(cond_fun: tp.Callable[[T], tp.Any],
+               body_fun: tp.Callable[[T], T],
+               init_val: T) -> T:
+  """NNX transform of `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`_.
+
+  Caution: for the NNX internal reference tracing mechanism to work, you cannot
+  change the reference structure of `init_val` inside `body_fun`.
+
+  Example::
+
+    >>> import jax
+    >>> from flax import nnx
+    >>> def fwd_fn(input):
+    ...   module, x, count = input
+    ...   return module, module(x), count - 1.0
+
+    >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+    >>> x = jax.random.normal(jax.random.key(0), (10,))
+    >>> # `module` will be called three times
+    >>> _, y, _ = nnx.while_loop(
+    ...   lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
+
+  Args:
+    cond_fun: a function for the continue condition of the while loop, taking a
+      single input of type `T` and outputting a boolean.
+    body_fun: a function that takes an input of type `T` and outputs a `T`.
+      Note that both data and modules of `T` must have the same reference
+      structure between inputs and outputs.
+    init_val: the initial input for `cond_fun` and `body_fun`. Must be of type `T`.
+  """
+
+  pure_init_val = extract.to_tree(init_val, ctxtag='while_loop')
+
+  # Add the expected reference mapping to `pure_init_val` so it matches
+  # `body_fun`'s output pytree structure, to make JAX's while_loop happy.
+  pure_init_val = _add_fake_index_mapping(pure_init_val)
+
+  pure_out = jax.lax.while_loop(
+    WhileLoopCondFn(cond_fun),
+    WhileLoopBodyFn(body_fun),
+    pure_init_val,
+  )
+  out = extract.from_tree(pure_out, ctxtag='while_loop')
+  return out
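The `index_mapping` bookkeeping above exists because `jax.lax.while_loop` requires the carry to have identical pytree structure before and after the body, while the graphdef produced for the body's output carries a reference `index_mapping` that the initial carry lacks; hence the fake identity mapping added on entry and stripped again inside the body. In practice this means value updates inside `body_fun` are fine, but reference swaps are not. A sketch contrasting the two, patterned on `test_value_changed` and `test_ref_changed` in the test diff below (the names here are ours):

  import jax.numpy as jnp
  from flax import nnx

  module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
  x = jnp.ones((10,))

  def ok_body(val):
    m, x, c = val
    m.kernel.value = m.kernel.value * 0.5  # in-place value update: allowed
    return m, m(x), c - 1.0

  _, y, _ = nnx.while_loop(lambda val: val[-1] > 0, ok_body, (module, x, 3.0))

  def bad_body(val):
    m, x, c = val
    m.kernel = nnx.Param(jnp.zeros_like(m.kernel.value))  # new reference: not allowed
    return m, m(x), c - 1.0

  try:
    nnx.while_loop(lambda val: val[-1] > 0, bad_body, (module, x, 3.0))
  except ValueError:
    pass  # WhileLoopBodyFn rejects the changed reference structure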

flax/nnx/transforms/transforms.py

Lines changed: 15 additions & 1 deletion
@@ -141,7 +141,7 @@ def _eval_shape_fn(*args, **kwargs):


 # -------------------------------
-# cond
+# cond and switch
 # -------------------------------


@@ -160,3 +160,17 @@ def cond(
     *operands,
     **kwargs,
   )
+
+
+@general.split_inputs(ctxtag='switch')
+def switch(
+  index,
+  branches: tp.Sequence[tp.Callable[..., A]],
+  *operands,
+) -> A:
+  return jax.lax.switch(
+    index,
+    [general.merge_inputs(f, ctxtag='switch') for f in branches],
+    *operands,
+  )
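`switch` is a thin wrapper: `general.split_inputs`/`general.merge_inputs` lift NNX modules in `operands` into pytrees across `jax.lax.switch`, so branches may read and update module state. A usage sketch patterned on `TestSwitch.test_basic` below (the branch names are ours):

  import jax.numpy as jnp
  from flax import nnx

  model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))

  def plain(m, x):
    return m(x)

  def doubled(m, x):
    return m(x) * 2.0

  x = jnp.ones((4,))
  y = nnx.switch(1, (plain, doubled), model, x)  # index 1 selects `doubled`

As with `jax.lax.switch`, every branch must produce outputs of matching structure and dtypes, including any state updates.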

tests/nnx/transforms_test.py

Lines changed: 167 additions & 1 deletion
@@ -1817,7 +1817,6 @@ def unroll(cell: RNNCell, carry, x) -> tuple[jax.Array, jax.Array]:

     x = jnp.ones((16, 10, 20))
     y = rnn_forward(cell, x)
-    print(y.shape)


 class TestRemat(absltest.TestCase):
@@ -2756,6 +2755,173 @@ def no_nothing(env: Env):
     )


+class TestSwitch(absltest.TestCase):
+  def test_basic(self):
+    class RoundTable(nnx.Module):
+      def __init__(self):
+        self.next_index = 0
+        self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+        self.linear.kernel.value = jnp.identity(10)
+        self.rounds_count = nnx.Variable(jnp.array(0))
+
+      def __call__(self, x):
+        def fn0(m, x):
+          m.rounds_count += 1
+          return m.linear(x)
+        def fn1(m, x):
+          return m.linear(x) * 2
+        def fn2(m, x):
+          m.linear.kernel.value = jnp.zeros((10, 10))
+          return m.linear(x)
+
+        # y = nnx.cond(self.next_index.value == 0, fn0, fn1, self, x)
+        y = nnx.switch(self.next_index, (fn0, fn1, fn2), self, x)
+        self.next_index = (self.next_index + 1) % 3
+        return y
+
+    model = RoundTable()
+    x = jnp.ones((10,))
+    np.testing.assert_array_equal(model(x), x)
+    assert model.rounds_count.value == 1
+    assert model.next_index == 1
+    np.testing.assert_array_equal(model(x), x * 2)
+    assert model.rounds_count.value == 1
+    assert model.next_index == 2
+    np.testing.assert_array_equal(model(x), jnp.zeros((10,)))
+    assert model.rounds_count.value == 1
+    assert model.next_index == 0
+    np.testing.assert_array_equal(model(x), jnp.zeros((10,)))
+    assert model.rounds_count.value == 2
+    assert model.next_index == 1
+
+
+class TestWhileLoop(absltest.TestCase):
+  def test_basic(self):
+    def fwd_fn(input):
+      m, x, c = input
+      y = m(x)
+      return m, y, c - 1.0
+
+    module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+    module.kernel.value = jnp.identity(10) * 2
+    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
+
+    _, y, _ = nnx.while_loop(
+      lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
+    np.testing.assert_array_equal(y, x * 8)
+
+  def test_multiple_objects(self):
+    def fwd_fn(input):
+      m1, (w2,), x, c = input
+      y = m1(x) @ w2
+      return m1, (w2,), y, c - 1.0
+
+    m1 = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+    m1.kernel.value = jnp.identity(10) * 2
+    w2 = nnx.Variable(jnp.identity(10) * 0.5)
+    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
+
+    _, _, y, _ = nnx.while_loop(
+      lambda input: input[-1] > 0, fwd_fn, (m1, (w2,), x, 3.0))
+    np.testing.assert_allclose(y, x)
+
+  def test_nested_module(self):
+    def fwd_fn(input):
+      m, x, c = input
+      y = m(x)
+      return m, y, c - 1.0
+
+    module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+    module.kernel.value = jnp.identity(10) * 2
+    module = nnx.Sequential(module)
+    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
+
+    _, y, _ = nnx.while_loop(
+      lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
+    np.testing.assert_array_equal(y, x * 8)
+
+  def test_shared_module(self):
+    m1 = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+    m2 = nnx.Linear(10, 10, use_bias=False, rngs=nnx.Rngs(0))
+    m2.kernel = m1.kernel
+    module = nnx.Sequential(m1, m2)
+    self.assertLen(jax.tree.leaves(nnx.state(module)), 2)  # only m1 params
+
+    def fwd_fn(input):
+      m, x, c = input
+      y = m(x)
+      m.layers[0].kernel.value = jnp.zeros_like(m.layers[0].kernel.value)
+      return m, y, c - 1.0
+
+    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
+    _, y, _ = nnx.while_loop(
+      lambda input: input[-1] > 0, fwd_fn, (module, x, 2.0))
+    self.assertLen(jax.tree.leaves(nnx.state(module)), 2)  # only m1 params
+    np.testing.assert_array_equal(m1.kernel.value, jnp.zeros((10, 10,)))
+    np.testing.assert_array_equal(m2.kernel.value, jnp.zeros((10, 10,)))
+    np.testing.assert_array_equal(y, jnp.zeros((10,)))
+
+  def test_value_changed(self):
+    def fwd_fn(input):
+      m, x, c = input
+      m.kernel.value = jnp.zeros_like(m.kernel.value)
+      y = m(x)
+      return m, y, c - 1.0
+
+    module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
+
+    _, y, _ = nnx.while_loop(
+      lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
+    np.testing.assert_array_equal(module.kernel.value, jnp.zeros((10, 10,)))
+    np.testing.assert_array_equal(y, jnp.zeros((10,)))
+
+  def test_ref_changed(self):
+    def fwd_fn(input):
+      m, x, c = input
+      y = m(x)
+      m.kernel = nnx.Param(jnp.zeros_like(m.kernel.value))
+      return m, y, c - 1.0
+
+    module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
+
+    with self.assertRaises(ValueError):
+      _, y, _ = nnx.while_loop(
+        lambda input: input[-1] > 0, fwd_fn, (module, x, 2.0))
+
+  def test_structure_changed(self):
+    def fwd_fn(input):
+      m, x, c = input
+      m = nnx.Linear(10, 10, use_bias=False, rngs=nnx.Rngs(1))
+      m.kernel.value = jnp.identity(10) * 2
+      y = m(x)
+      return m, y, c - 1.0
+
+    module = nnx.Linear(10, 10, use_bias=True, rngs=nnx.Rngs(0))
+    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
+
+    with self.assertRaises(ValueError):
+      _, y, _ = nnx.while_loop(
+        lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
+
+  def test_repeated_object(self):
+    m = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+
+    def body_fn(val):
+      count, m, _ = val
+      return count + 1, m, m
+
+    count, m, _ = nnx.while_loop(
+      lambda val: val[0] < 2,
+      body_fn,
+      (0, m, m),
+    )
+
 class TestSplitMergeInputs(absltest.TestCase):
   def test_split_inputs(self):
     class StatefulLinear(nnx.Linear):
