How to correctly update parameters with respect to another model #2835
Unanswered
averageFlaxUser asked this question in Q&A
At least in this example, I think it should be:
grad_fn = jax.value_and_grad(outer_loss, has_aux=True)
(loss, model_B), grads = grad_fn(model_A.params)
... You need to reassign
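To make the return structure of `has_aux=True` concrete, here is a minimal sketch. The function `loss_with_aux` and its toy parameters are hypothetical and not taken from the question; the point is only that the differentiated function returns a `(scalar, aux)` pair and that `jax.value_and_grad` then returns `((value, aux), grads)`.

```python
import jax
import jax.numpy as jnp


def loss_with_aux(params, x):
    # Hypothetical toy function: returns (scalar_loss, aux).
    # Only the first element is differentiated; aux is passed through as-is.
    pred = params["w"] * x
    loss = jnp.mean(pred ** 2)
    return loss, pred


params = {"w": jnp.array(2.0)}
x = jnp.array([1.0, 2.0, 3.0])

grad_fn = jax.value_and_grad(loss_with_aux, has_aux=True)
# With has_aux=True the result is ((value, aux), grads).
(loss, aux), grads = grad_fn(params, x)
```

Here `grads` has the same pytree structure as `params`, and `aux` can carry anything the caller wants back out of the function, such as an updated model state.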
Hello, I am trying to validate a bi-level optimization implementation I have. My goal is the following:
"Some model_A that produces an output vector z, which is used as a regularization term in the loss_fn of some model_B that is trained on some data. I am looking to optimize the params of model_A with respect to the performance of model_B."
My pseudo-Jax/Flax implementation:
Train loop, model initialization and other items are done the same way as in the Flax Getting Started Guide.
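For readers following along, one common way to set up this kind of problem is to unroll a single differentiable gradient step of model_B inside the outer loss. The sketch below is not the original poster's code. It uses plain JAX arrays instead of Flax modules, and the linear model_B, the softplus model_A, the synthetic data, and the learning rates are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def model_A(params_A):
    # Hypothetical model_A: maps its parameters to a non-negative vector z.
    return jax.nn.softplus(params_A)


def loss_fn(params_B, z, x, y):
    # Hypothetical model_B: a linear model; z weights a per-coordinate L2 penalty.
    pred = x @ params_B
    return jnp.mean((pred - y) ** 2) + jnp.sum(z * params_B ** 2)


def outer_loss(params_A, params_B, x_train, y_train, x_val, y_val, inner_lr=0.1):
    z = model_A(params_A)
    # Inner problem, unrolled for one SGD step so gradients can flow back to params_A.
    grads_B = jax.grad(loss_fn)(params_B, z, x_train, y_train)
    new_params_B = params_B - inner_lr * grads_B
    # Outer objective: unregularized loss of the updated model_B on held-out data.
    val_loss = jnp.mean((x_val @ new_params_B - y_val) ** 2)
    return val_loss, new_params_B


# Synthetic data, purely for illustration.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
true_w = jnp.array([1.0, -2.0, 0.5, 0.0])
x_train = jax.random.normal(k1, (32, 4))
x_val = jax.random.normal(k2, (16, 4))
y_train, y_val = x_train @ true_w, x_val @ true_w

params_A = jnp.zeros(4)
params_B = jax.random.normal(k3, (4,))

grad_fn = jax.value_and_grad(outer_loss, has_aux=True)
(val_loss, params_B_new), grads_A = grad_fn(
    params_A, params_B, x_train, y_train, x_val, y_val
)

# Outer update for model_A; the updated model_B comes back as the aux output.
params_A = params_A - 0.01 * grads_A
params_B = params_B_new
```

A single unrolled step only gives a short-horizon estimate of how params_A affects model_B. Unrolling more inner steps is a possible extension, at the cost of more memory and compute.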