
flax.linen.sow (called once) returns multiple values #3042

Answered by andsteing
jdeschena asked this question in Q&A

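This happens because the value was already sown once during initialization. `model.init()` runs the module's `__call__`, so the `test` collection it returns already holds one value. By default `sow` appends to a tuple, so passing those variables back to `model.apply()` adds a second entry: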

>>> model = StatefulMLP()
>>> variables = model.init(key, rand_input)
>>> jax.tree_util.tree_map(jnp.shape, variables)
FrozenDict({
    params: {
        hidden: {
            bias: (3,),
            kernel: (2, 3),
        },
    },
    test: {
        key: ((6, 2),),
    },
})
>>> out1 = model.apply(variables, rand_input, mutable=['test'])[1]
>>> jax.tree_util.tree_map(jnp.shape, out1)
FrozenDict({
    test: {
        key: ((6, 2), (6, 2)),
    },
})

Note that model.init() returns all variable collections, not only "params", so you might want to write:

>>> params = variables['params']
>>> out2 = model.apply({'params': params}, ra…


Answer selected by cgarciae