Freezing filtered parameter collections #4714
Unanswered
Jacob-Worrell asked this question in Q&A
Replies: 1 comment
The missing link for me was nnx.DiffState. It is described in the documentation for nnx.grad() on the NNX transforms page: https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html Effectively, the only changes that need to be made for the code to work as intended are:
And that does it! I realized the precise nature of my problem upon returning to issue #4167. The following is a corrected version of the example code:
I'm trying to work out how to do transfer learning with NNX. Below is my attempt to freeze the kernel of an nnx.Linear instance and optimize only the bias. I suspect I'm not setting up the `wrt` argument to my optimizer correctly.
Possibly of note, the parameter filtering here is inspired by the response to issue #4167.
However, this results in the following error:
ValueError: Mismatch custom node data: ('bias', 'kernel') != ('bias',); value: State({ 'bias': VariableState( type=Param, value=Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)> ) }).