Description
Currently, the function passed to `jax.tree.map` is applied to both `Variable` and `VariableState` leaves, even when `is_leaf` is set to select only `VariableState`s.
Expected (the function is not applied to `Variable`s when `is_leaf` selects `VariableState`):
```python
import jax
from flax import nnx

state = nnx.State({"param1": nnx.Param(10.0)})
print(state)
# State({
#   'param1': Param(
#     value=10.0
#   )
# })
print(isinstance(state.param1, nnx.VariableState))
# False

# is_leaf selects only VariableState, so the Param (a Variable) should be left unchanged
mapped = jax.tree.map(
    lambda x: x.value, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)
)
print(mapped)
# State({
#   'param1': Param(
#     value=10.0
#   )
# })
```
Actual (the function is applied to `Variable`s when `is_leaf` selects `VariableState`):
```python
import jax
from flax import nnx

state = nnx.State({"param1": nnx.Param(10.0)})
print(state)
# State({
#   'param1': Param(
#     value=10.0
#   )
# })

# Same call as above, but the lambda is still applied to the Param
mapped = jax.tree.map(
    lambda x: x.value, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)
)
print(mapped)
# State({
#   'param1': 10.0
# })
```
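A possible explanation, assuming `nnx.Variable` is not registered as a pytree node while `nnx.VariableState` is: in `jax.tree.map`, `is_leaf` can only stop traversal early and mark extra nodes as leaves; it cannot exclude an object that JAX already treats as a leaf. The Actual output is consistent with this assumption, since if the `Param` had been traversed, the lambda would have received the float `10.0` and `x.value` would have failed. The sketch below reproduces the same pattern with two hypothetical stand-in classes, `Opaque` (not a pytree) and `Box` (a registered pytree), rather than the real Flax types.

```python
import jax


class Opaque:
    """Hypothetical stand-in for a class JAX does not know about (not a pytree)."""

    def __init__(self, value):
        self.value = value


@jax.tree_util.register_pytree_node_class
class Box:
    """Hypothetical stand-in for a class registered as a pytree node."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # One child (the value), no static auxiliary data.
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


tree = {"a": Opaque(1.0), "b": Box(2.0)}

# Without is_leaf, Box is traversed into its child, while the Opaque
# instance itself is already one of the leaves.
print(jax.tree.leaves(tree))

# is_leaf can only stop traversal early (turning Box into a leaf); it cannot
# exclude Opaque, which is a leaf regardless, so the lambda reaches both.
mapped = jax.tree.map(lambda x: x.value, tree, is_leaf=lambda x: isinstance(x, Box))
print(mapped)  # {'a': 1.0, 'b': 2.0}
```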
Is this behaviour intended?
Context: I found this by accident while doing the following. I need to apply a set of transforms to the leaves of my `State`. These transforms modify or set various metadata of the leaves. `Variable.replace` allows you to modify metadata; however, `VariableState.replace` only allows modification of the value. Because the leaves of a `State` can be either `Variable`s or `VariableState`s, I need to distinguish between these two cases and apply a different update in each: for a `Variable`, I can do `x.replace(transformed_value, **transformed_metadata)`, but for a `VariableState`, I need to instantiate a new object, `nnx.VariableState(variable_type, transformed_value, **transformed_metadata)`.
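One way to handle both cases in a single pass, sketched here with hypothetical stand-in classes rather than the real Flax types, is to let `is_leaf` stop traversal at both kinds of object and do the type dispatch inside the mapped function. `FakeVariable`, `FakeVariableState`, `transform`, `is_var_like`, and the `tag` metadata key are all illustrative assumptions; the two branches mirror the `x.replace(...)` and `nnx.VariableState(...)` calls described above.

```python
import jax


class FakeVariable:
    """Hypothetical stand-in for nnx.Variable: not a pytree, so JAX sees it as a leaf."""

    def __init__(self, value, **metadata):
        self.value = value
        self.metadata = metadata

    def replace(self, value, **metadata):
        # Mirrors the described Variable.replace: new value plus updated metadata.
        return FakeVariable(value, **{**self.metadata, **metadata})


@jax.tree_util.register_pytree_node_class
class FakeVariableState:
    """Hypothetical stand-in for nnx.VariableState: a registered pytree node."""

    def __init__(self, type_, value, **metadata):
        self.type = type_
        self.value = value
        self.metadata = metadata

    def tree_flatten(self):
        # The value is the only child; type and metadata are static aux data.
        return (self.value,), (self.type, tuple(sorted(self.metadata.items())))

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        type_, items = aux_data
        return cls(type_, children[0], **dict(items))


def transform(x):
    """Hypothetical transform: doubles the value and sets a `tag` metadata entry."""
    if isinstance(x, FakeVariableState):
        # Per the description above, VariableState.replace cannot change
        # metadata, so build a new object instead.
        return FakeVariableState(x.type, x.value * 2, **{**x.metadata, "tag": "scaled"})
    if isinstance(x, FakeVariable):
        # replace() can update value and metadata together.
        return x.replace(x.value * 2, tag="scaled")
    return x  # any other leaf passes through unchanged


def is_var_like(x):
    # Stop traversal at both kinds of object; dispatch happens inside transform.
    return isinstance(x, (FakeVariable, FakeVariableState))


tree = {"a": FakeVariable(1.0), "b": FakeVariableState("Param", 2.0), "c": 3.0}
out = jax.tree.map(transform, tree, is_leaf=is_var_like)
print(out["a"].value, out["a"].metadata)  # 2.0 {'tag': 'scaled'}
print(out["b"].type, out["b"].value, out["b"].metadata)  # Param 4.0 {'tag': 'scaled'}
print(out["c"])  # 3.0
```

In this sketch, `is_leaf` only decides where traversal stops; which objects actually get transformed, and how, is decided by the `isinstance` checks inside `transform`.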