Hi - thanks for the library! I'm currently running into some issues when trying to train a "linearised" NN model. In my setup I'm taking a standard NN with weights $\theta^*$, which I'll denote $f(\cdot;\theta^*)$, and then forming a "linearised" version of it by taking a first-order Taylor expansion about $\theta^*$. Mathematically:

$$f_{\text{lin}}(\cdot;\theta) = f(\cdot;\theta^*) + \nabla_\theta f(\cdot;\theta^*)\,(\theta - \theta^*)$$
I then want to train this "linearised" model, keeping the original weights $\theta^*$ fixed and only updating weights $\theta$ (or alternatively updating the "weight update" $\tau = \theta - \theta^*$). For freezing the weights $\theta^*$ I have tried to adopt the approach mentioned in #4167. I'm able to perform a forward pass with my model fine, but as soon as I try to train the model, I run into issues. An example is as follows:
The error I'm getting is:
TypeError: Argument 'Param(
value=Traced<ConcreteArray([[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
primal = Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.]], dtype=float32)
tangent = Traced<ShapedArray(float32[1,20])>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float32[1,20]), None)
recipe = LambdaBinding()
)' of type <class 'flax.nnx.variablelib.Param'> is not a valid JAX type.
From the stack trace, it seems to stem from the jax.jvp() call.
Can you see what I'm doing wrong? I'm happy to provide more details if necessary. A PyTorch version of what I'm trying to achieve can be found here. Thanks in advance!
Hey! The issue is that jax.jvp doesn't handle nnx.Variable instances (like Param) correctly. To fix it, call nnx.state on them before passing them to jax.jvp. I tested that this runs on your code:
Great, thanks, that worked! And yeah, I think having nnx.jvp would probably be a bit more intuitive - I actually searched to see if this existed whilst debugging.
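For readers following along, here is a minimal sketch of the linearised-model idea in plain JAX. It is not the code from this thread. It assumes the parameters are already a plain pytree of arrays, which is what the reply above suggests obtaining via nnx.state before calling jax.jvp. The names `mlp`, `init_params`, `linearised`, `loss`, `train_step` and `LR`, as well as the toy data, are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

LR = 0.01  # hypothetical learning rate for this toy example


def mlp(params, x):
    """Tiny MLP standing in for f(x; theta); params is a plain dict of arrays."""
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


def init_params(key, din=2, dhidden=20, dout=1):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (din, dhidden)) / jnp.sqrt(din),
        "b1": jnp.zeros((dhidden,)),
        "w2": jax.random.normal(k2, (dhidden, dout)) / jnp.sqrt(dhidden),
        "b2": jnp.zeros((dout,)),
    }


def linearised(theta_star, tau, x):
    """f(x; theta*) + J(theta*) @ tau, computed with one jax.jvp call.

    Both the primal (theta_star) and the tangent (tau) must be plain
    pytrees of arrays with matching structure.
    """
    out, jvp_out = jax.jvp(lambda p: mlp(p, x), (theta_star,), (tau,))
    return out + jvp_out


def loss(tau, theta_star, x, y):
    return jnp.mean((linearised(theta_star, tau, x) - y) ** 2)


@jax.jit
def train_step(tau, theta_star, x, y):
    # Differentiate w.r.t. tau only; theta_star stays frozen.
    grads = jax.grad(loss)(tau, theta_star, x, y)
    return jax.tree_util.tree_map(lambda t, g: t - LR * g, tau, grads)


theta_star = init_params(jax.random.PRNGKey(0))
tau = jax.tree_util.tree_map(jnp.zeros_like, theta_star)  # tau = theta - theta*

# Toy regression data, purely illustrative.
x = jnp.linspace(-1.0, 1.0, 16).reshape(8, 2)
y = jnp.sum(x, axis=1, keepdims=True)

initial_loss = loss(tau, theta_star, x, y)
for _ in range(200):
    tau = train_step(tau, theta_star, x, y)
final_loss = loss(tau, theta_star, x, y)
```

Parameterising the model by the update $\tau = \theta - \theta^*$ means that only `tau` is ever passed to `jax.grad`, so $\theta^*$ is frozen by construction and needs no separate freezing mechanism.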