
How to use nn.scan and nn.BatchNorm in Flax? #2067

Answered by cgarciae
denisyarats asked this question in Q&A

@denisyarats For now this works:

```python
from flax.core import Scope, Array, init, unfreeze, lift
from flax import linen as nn
import jax
from jax import random, numpy as jnp


class MLP(nn.Module):
    @nn.compact
    def __call__(self, c, x):
        h = nn.Dense(features=10)(x)
        h = nn.BatchNorm(use_running_average=False)(h)
        y = nn.Dense(features=1)(h)
        return c, y

class ScanMLP(nn.Module):
    @nn.compact
    def __call__(self, c, xs):
        scan = nn.scan(
            MLP,
            variable_carry="batch_stats",
            variable_broadcast="params",
            split_rngs={"params": False},
        )

        is_initializing = "batch_stats" not in self.variables
```

Replies: 5 comments

Answer selected by marcvanzee
5 participants

This discussion was converted from issue #2052 on April 25, 2022 09:06.