Is there a way to give pytree state to an NNX module? #4687
I want to include the dense solution of a diffrax ODE in an NNX module. Since the solution is a pytree containing arrays, nnx.jit complains about arrays as root leaves. Is there a way to tell NNX to simply ignore this?
Answered by cgarciae, Apr 5, 2025
Hi @wittenator, array leaves are already supported as of #4612. Please update Flax to the latest version.
Answer selected by wittenator