How to use nn.scan and nn.BatchNorm in Flax? #2067
-
Hi, I'm trying to use `nn.scan` together with `nn.BatchNorm`. This works as expected:

```python
from flax import linen as nn
import jax
from jax import random, numpy as jnp


class MLP(nn.Module):
    @nn.compact
    def __call__(self, c, x):
        h = nn.Dense(features=10)(x)
        h = nn.BatchNorm(use_running_average=False)(h)
        y = nn.Dense(features=1)(h)
        return c, y


xs = jnp.zeros((10, 2))
p = MLP().init(random.PRNGKey(1), (), xs[0])
scan_mlp = nn.scan(
    MLP,
    variable_carry='batch_stats',
    variable_broadcast='params',
    split_rngs={'params': False},
)()
(cs, ys), variables = scan_mlp.apply(p, (), xs, mutable='batch_stats')
```

But if I wrap the scan inside another module:

```python
from flax import linen as nn
import jax
from jax import random, numpy as jnp


class MLP(nn.Module):
    @nn.compact
    def __call__(self, c, x):
        h = nn.Dense(features=10)(x)
        h = nn.BatchNorm(use_running_average=False)(h)
        y = nn.Dense(features=1)(h)
        return c, y


class ScanMLP(nn.Module):
    @nn.compact
    def __call__(self, c, xs):
        scan = nn.scan(
            MLP,
            variable_carry='batch_stats',
            variable_broadcast='params',
            split_rngs={'params': False},
        )
        return scan()(c, xs)


xs = jnp.zeros((10, 2))
scan_mlp = ScanMLP()
p = scan_mlp.init(random.PRNGKey(1), (), xs)
(cs, ys), variables = scan_mlp.apply(p, (), xs, mutable='batch_stats')
```

the second version raises an error at init time. Any idea why these two similar pieces of code behave differently?
Replies: 5 comments
-
The issue here is that JAX doesn't support a carry that changes structure within the loop. For `variable_carry='batch_stats'`, the batch stats don't exist yet during `init`: they are created inside the scanned body, so the carry's pytree structure differs between the loop's input and its output, which `jax.lax.scan` rejects. In your first snippet you initialize the plain `MLP` outside the scan, so the variable structure is already fixed by the time the scan runs.
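A minimal sketch (not from the original reply) of the underlying JAX constraint: `jax.lax.scan` requires the carry returned by the body to have exactly the same pytree structure, shapes, and dtypes as the carry passed in.

```python
import jax
import jax.numpy as jnp

def body_ok(carry, x):
    # Carry keeps the same structure ({'sum': scalar}) every step: fine.
    return {'sum': carry['sum'] + x}, x * 2

carry, ys = jax.lax.scan(body_ok, {'sum': jnp.float32(0)}, jnp.arange(3.0))

def body_bad(carry, x):
    # Adding a new key changes the carry's pytree structure, so scan
    # raises a TypeError about mismatched carry structure.
    return {'sum': carry['sum'] + x, 'extra': x}, x * 2

# jax.lax.scan(body_bad, {'sum': jnp.float32(0)}, jnp.arange(3.0))  # TypeError
```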
-
@denisyarats For now this works:

```python
from flax import linen as nn
import jax
from jax import random, numpy as jnp


class MLP(nn.Module):
    @nn.compact
    def __call__(self, c, x):
        h = nn.Dense(features=10)(x)
        h = nn.BatchNorm(use_running_average=False)(h)
        y = nn.Dense(features=1)(h)
        return c, y


class ScanMLP(nn.Module):
    @nn.compact
    def __call__(self, c, xs):
        scan = nn.scan(
            MLP,
            variable_carry="batch_stats",
            variable_broadcast="params",
            split_rngs={"params": False},
        )
        # During init, run the unscanned module once so all variables get
        # created with a fixed structure; only scan during apply. Both
        # branches use the same name so the variable trees line up.
        is_initializing = "batch_stats" not in self.variables
        if is_initializing:
            return MLP(name="MLP")(c, xs)
        else:
            return scan(name="MLP")(c, xs)


xs = jnp.zeros((10, 2))
scan_mlp = ScanMLP()
p = scan_mlp.init(random.PRNGKey(1), (), xs)
print(jax.tree_map(lambda x: x.shape, p))
(cs, ys), updates = scan_mlp.apply(p, (), xs, mutable="batch_stats")
p = p.copy(updates)
print(jax.tree_map(lambda x: x.shape, p))
```

@jheek per #652 it would be nice to have an easier mechanism to know if we are inside `init`.
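As a follow-up sketch: newer Flax releases expose `self.is_initializing()` on `nn.Module`, which (assuming your version has it) makes the check above less ad hoc:

```python
# Sketch assuming a Flax version that provides Module.is_initializing();
# otherwise fall back to the "batch_stats" membership check above.
class ScanMLP(nn.Module):
    @nn.compact
    def __call__(self, c, xs):
        scan = nn.scan(
            MLP,
            variable_carry="batch_stats",
            variable_broadcast="params",
            split_rngs={"params": False},
        )
        if self.is_initializing():
            return MLP(name="MLP")(c, xs)
        return scan(name="MLP")(c, xs)
```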
-
Thanks @jheek and @cgarciae, verified this on my end as well. The trick is to not initialize through `nn.scan`, and to give the scanned and unscanned branches the same name (`name="MLP"` above) so the variable trees match.
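For reference, here is roughly what the `print(jax.tree_map(...))` calls in the workaround above should show for this example (input features 2, hidden 10, output 1); this is a sketch, and the exact nesting may vary across Flax versions:

```python
# Roughly, print(jax.tree_map(lambda x: x.shape, p)) should show:
# {
#     'params': {
#         'MLP': {
#             'Dense_0':     {'kernel': (2, 10), 'bias': (10,)},
#             'BatchNorm_0': {'scale': (10,),    'bias': (10,)},
#             'Dense_1':     {'kernel': (10, 1), 'bias': (1,)},
#         }
#     },
#     'batch_stats': {
#         'MLP': {'BatchNorm_0': {'mean': (10,), 'var': (10,)}},
#     },
# }
```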
-
Thanks @cgarciae for the workaround. I ran into this issue today, and while I understand why it happens from JAX's perspective, from a Flax user's perspective it feels like a bug. When the module being scanned is more complicated and depends on more inputs, it leads to a lot of code inside the module to handle whether it's being scanned or not. Looking forward to this feature from @jheek.
-
@cgarciae, for Flax, is this still the solution? I'm getting the same issue with the following scan code: