jax.Array must be fully replicated to be saved in aggregate file #3143
tatami-galaxy asked this question in Q&A
Hey @tatami-galaxy, can you give a minimal example?
I'm trying to save a checkpoint and am getting the error message in the title. Saving code:

    from flax.training import orbax_utils

    ckpt = {'state': state, 'config': model.config}
    save_args = orbax_utils.save_args_from_target(ckpt)
    checkpoint_manager.save(global_step + 1, ckpt, save_kwargs={'save_args': save_args})
state is a TrainState instance from flax.training.train_state. What could be causing this? I tried disabling jax.Array with jax.config.update('jax_array', False), but that does not work with jax and jaxlib 0.4.7.
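One way to narrow this down, offered as a sketch under assumptions rather than a known fix, is to list which jax.Array leaves in the checkpoint pytree are not fully replicated, since those are the leaves the error message refers to. The helper find_unreplicated_arrays and the toy ckpt dict are hypothetical and not part of Flax or Orbax; the sketch only relies on jax.tree_util.tree_flatten_with_path, jax.tree_util.keystr and the jax.Array.is_fully_replicated property, and assumes a JAX version recent enough to provide them.

```python
import jax
import jax.numpy as jnp


def find_unreplicated_arrays(tree):
    """Return key paths of jax.Array leaves that are not fully replicated.

    Hypothetical diagnostic helper: it only reports leaves; it does not
    change how anything is saved.
    """
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    unreplicated = []
    for path, leaf in leaves_with_paths:
        # Non-array leaves (Python ints, strings, config objects) are skipped.
        if isinstance(leaf, jax.Array) and not leaf.is_fully_replicated:
            unreplicated.append(jax.tree_util.keystr(path))
    return unreplicated


# Hypothetical stand-in for the real {'state': ..., 'config': ...} checkpoint.
ckpt = {"state": {"params": jnp.ones((4, 4)), "step": jnp.array(0)}, "config": 3}
print(find_unreplicated_arrays(ckpt))
```

On a single-device machine every array is trivially fully replicated, so the toy example should print an empty list. On a multi-device setup, arrays laid out across devices (for example by pmap or a sharded jit) may show up in the output, which would point to the leaves worth inspecting before saving.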