nnx.value_and_grad fails to update gradient #4678
Unanswered
LIANHAN3920 asked this question in Q&A
Replies: 1 comment, 1 reply
-
  I think the batch data should also be passed to `loss_fn`?

  ```python
  @jit
  def train_step(state: TrainState, batch):
      def loss_fn(params, batch):
          x, y = batch
          model = nnx.merge(state.graphdef, params)
          preds = model(x)
          loss = jnp.mean((preds - y) ** 2)
          return loss

      # grads = jax.grad(loss_fn)(state.params, batch)  # <-- working
      loss, grads = nnx.value_and_grad(loss_fn)(state.params, batch)  # <-- not working
      state = state.apply_gradients(grads=grads)
      return state
  ```
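As a sanity check outside of Flax NNX, the same calling pattern with plain `jax.value_and_grad` does forward extra positional arguments (the batch) to the loss function and returns `(loss, grads)`, with gradients identical to `jax.grad`. This is a minimal pure-JAX sketch, not the original model; the linear-model parameters and shapes here are made up for illustration:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Same structure as in the train step above: unpack the batch,
    # run a forward pass, and return a scalar MSE loss.
    x, y = batch
    preds = x @ params["w"] + params["b"]
    return jnp.mean((preds - y) ** 2)

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
batch = (jnp.ones((4, 3)), jnp.ones((4, 1)))

# value_and_grad returns both the loss and the gradients in one pass;
# the batch is simply forwarded to loss_fn as a second positional arg.
loss, grads = jax.value_and_grad(loss_fn)(params, batch)

# jax.grad alone returns only the gradients; they should match exactly.
grads_only = jax.grad(loss_fn)(params, batch)
```

If this pure-JAX version updates correctly but the `nnx.value_and_grad` version does not, the difference lies in how the NNX transform treats its arguments, not in the loss computation itself.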
-
  In this code, when I use nnx.value_and_grad to compute the gradient of the loss function, the parameters fail to update at every iteration, but jax.grad works.