Implementing Weight Decay Masking with nnx.Optimizer and Filters #4737
Unanswered
cadazar asked this question in Show and tell
Replies: 0 comments
Hi, I have been using the NNX API to train a custom language model and needed to apply weight decay only to certain parameters (excluding biases, normalization parameters, and embeddings) while using `nnx.Optimizer`. I found that creating two `nnx.Optimizer` instances with different base optimizers (e.g. `optax.adamw` for the parameters with weight decay, `optax.adam` for those without) and applying each to a filtered subset of the parameters using `nnx.filterlib` works well, without any significant overhead.

Hope this can be of use to someone out there.