
jax.tree.map applies to both Variable and VariableState, despite setting is_leaf #4762

Closed
@theo-brown

Description

Currently, jax.tree.map applies the mapped function to both Variables and VariableStates, even when is_leaf is set to select only VariableStates.

Expected (function is not applied to Variables when selecting based on VariableState):

```python
import jax
from flax import nnx

state = nnx.State({"param1": nnx.Param(10.0)})
print(state)
# State({
#  'param1': Param(
#    value=10.0
#  )
# })

print(isinstance(state.param1, nnx.VariableState))
# False

mapped = jax.tree.map(
    lambda x: x.value, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)
)
print(mapped)
# State({
#  'param1': Param(
#    value=10.0
#  )
# })
```

Actual (function is applied to Variables when selecting based on VariableState):

```python
import jax
from flax import nnx

state = nnx.State({"param1": nnx.Param(10.0)})
print(state)
# State({
#  'param1': Param(
#    value=10.0
#  )
# })

mapped = jax.tree.map(
    lambda x: x.value, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)
)
print(mapped)
# State({
#  'param1': 10.0
# })
```

Is this behaviour intended?
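For comparison, the same pattern can be reproduced in plain JAX. This is a minimal sketch that uses two hypothetical stand-in classes (FakeVariable, FakeVariableState) instead of the real nnx types, and it assumes, as the Actual output suggests, that the Param is treated as an ordinary pytree leaf. Under that assumption, is_leaf only controls where flattening stops; it does not exclude other leaves from the mapped function, so the lambda still runs on the Variable-like object.

```python
import jax


# Hypothetical stand-ins for nnx.Variable / nnx.VariableState. Neither class is
# registered as a pytree node, so JAX treats instances of both as leaves.
class FakeVariable:
    def __init__(self, value):
        self.value = value


class FakeVariableState:
    def __init__(self, value):
        self.value = value


tree = {"a": FakeVariable(1.0), "b": FakeVariableState(2.0)}

# is_leaf only decides where flattening stops; it does not filter leaves.
# "a" is still a leaf (unregistered type), so the lambda is applied to it too.
mapped = jax.tree.map(
    lambda x: x.value, tree, is_leaf=lambda x: isinstance(x, FakeVariableState)
)
print(mapped)  # {'a': 1.0, 'b': 2.0}

# To transform only the selected leaves, filter inside the mapped function.
selective = jax.tree.map(
    lambda x: x.value if isinstance(x, FakeVariableState) else x,
    tree,
    is_leaf=lambda x: isinstance(x, FakeVariableState),
)
print(type(selective["a"]).__name__, selective["b"])  # FakeVariable 2.0
```

With the type check moved into the function, the unselected leaf is returned unchanged instead of being transformed.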

Context: I found this by accident while doing the following. I need to apply a set of transforms to the leaves of my State, and these transforms modify or set various metadata on the leaves. Variable.replace lets you modify metadata, but VariableState.replace only allows modification of the value. Because the leaves of a State can be either Variables or VariableStates, I need to distinguish between the two cases and apply a different update to each: for a Variable I can call x.replace(transformed_value, **transformed_metadata), but for a VariableState I have to instantiate a new object with nnx.VariableState(variable_type, transformed_value, **transformed_metadata).
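A minimal sketch of that dispatch, again using hypothetical stand-in classes rather than the real nnx types (FakeVariable, FakeVariableState, transform and the "transformed" metadata key are all illustrative). The mapped function checks the type of each leaf itself and passes anything unrecognised through unchanged, so the selection does not rely on is_leaf alone. is_leaf is still passed so that, if the state type were a registered pytree node, traversal would stop at it instead of descending into it.

```python
import jax


class FakeVariable:
    """Hypothetical stand-in for a Variable: value plus metadata, with replace()."""

    def __init__(self, value, **metadata):
        self.value = value
        self.metadata = metadata

    def replace(self, value, **metadata):
        # New object with the new value and the old metadata updated by the new.
        return FakeVariable(value, **{**self.metadata, **metadata})


class FakeVariableState:
    """Hypothetical stand-in for a VariableState: no metadata-aware replace()."""

    def __init__(self, var_type, value, **metadata):
        self.var_type = var_type
        self.value = value
        self.metadata = metadata


def is_var_or_state(x):
    return isinstance(x, (FakeVariable, FakeVariableState))


def transform(x):
    # Hypothetical transform: double the value and set a metadata flag.
    if isinstance(x, FakeVariable):
        return x.replace(x.value * 2, transformed=True)
    if isinstance(x, FakeVariableState):
        # No metadata-aware replace(), so build a new object instead.
        return FakeVariableState(
            x.var_type, x.value * 2, **{**x.metadata, "transformed": True}
        )
    return x  # Any other leaf is passed through unchanged.


tree = {"p": FakeVariable(1.0), "s": FakeVariableState(FakeVariable, 3.0)}
out = jax.tree.map(transform, tree, is_leaf=is_var_or_state)
print(out["p"].value, out["p"].metadata)  # 2.0 {'transformed': True}
print(out["s"].value, out["s"].metadata)  # 6.0 {'transformed': True}
```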
