Hello, while playing with the following code:

```python
import jax
import flax.linen as nn


class StatefulMLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(3, name='hidden')(x)
        self.sow("test", "key", x)  # store the input in the "test" collection
        return h


key = jax.random.PRNGKey(0)
init_key, input_key, key = jax.random.split(key, num=3)
rand_input = jax.random.normal(input_key, (6, 2))
model = StatefulMLP()
params = model.init(init_key, rand_input)
out = model.apply(params, rand_input, mutable=['test'])[1]
print(out)
```

I get the sown input twice:

```
FrozenDict({
    test: {
        key: (Array([[-0.39216983, -0.22660121],
                     [-0.366704  , -0.9057831 ],
                     [ 0.8396763 , -0.6405287 ],
                     [-0.93676966, -1.8203913 ],
                     [ 1.3435824 ,  0.61339337],
                     [-1.8785557 , -0.07508122]], dtype=float32), Array([[-0.39216983, -0.22660121],
                     [-0.366704  , -0.9057831 ],
                     [ 0.8396763 , -0.6405287 ],
                     [-0.93676966, -1.8203913 ],
                     [ 1.3435824 ,  0.61339337],
                     [-1.8785557 , -0.07508122]], dtype=float32)),
    },
})
```
It is because the value is already sown during initialization:

```python
>>> import jax.numpy as jnp
>>> model = StatefulMLP()
>>> variables = model.init(key, rand_input)
>>> jax.tree_util.tree_map(jnp.shape, variables)
FrozenDict({
    params: {
        hidden: {
            bias: (3,),
            kernel: (2, 3),
        },
    },
    test: {
        key: ((6, 2),),
    },
})
>>> out1 = model.apply(variables, rand_input, mutable=['test'])[1]
>>> jax.tree_util.tree_map(jnp.shape, out1)
FrozenDict({
    test: {
        key: ((6, 2), (6, 2)),
    },
})
```

Note that `model.init()` returns all variable collections, not only `"params"`, so you might want to write:

```python
>>> params = variables['params']
>>> out2 = model.apply({'params': params}, rand_input, mutable=['test'])[1]
>>> jax.tree_util.tree_map(jnp.shape, out2)
FrozenDict({
    test: {
        key: ((6, 2),),
    },
})
```

Alternatively, you could specify to only initialize `"params"`:

```python
>>> variables = model.init(key, rand_input, mutable=['params'])
>>> jax.tree_util.tree_map(jnp.shape, variables)
FrozenDict({
    params: {
        hidden: {
            bias: (3,),
            kernel: (2, 3),
        },
    },
})
```

BTW,