
How to save a model and then reuse it in Flax? #2072

Answered by andsteing
varunsingh88 asked this question in Q&A

We usually use the flax.training.checkpoints module for this.

You can see example usage in our examples/imagenet/train.py:

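import jax
from flax.training import checkpoints

def restore_checkpoint(state, workdir):
    # Restore the latest checkpoint in `workdir` into the structure of `state`.
    return checkpoints.restore_checkpoint(workdir, state)

def save_checkpoint(state, workdir):
    # Only the first process writes, so multi-host runs don't clobber each other.
    if jax.process_index() == 0:
        # Get the train state from the first replica and copy it to host memory.
        state = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state))
        step = int(state.step)
        checkpoints.save_checkpoint(workdir, state, step, keep=3)  # keep the 3 newest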
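The `lambda x: x[0]` line in save_checkpoint exists because a state replicated for pmap carries a leading device axis on every leaf, and the checkpoint only needs one copy. Below is a small pure-JAX sketch of that step. The replicated dict is a hypothetical stand-in for a real train state, chosen only to show the shapes before and after.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a pmap-replicated train state: every leaf
# carries a leading axis with one copy per device (here 2 devices).
num_replicas = 2
replicated_state = {
    "step": jnp.full((num_replicas,), 10),
    "params": {"w": jnp.ones((num_replicas, 3))},
}

# Keep only replica 0 of every leaf, then copy the result to host memory.
single_state = jax.device_get(
    jax.tree_util.tree_map(lambda x: x[0], replicated_state)
)
step = int(single_state["step"])
print(step, single_state["params"]["w"].shape)  # prints: 10 (3,)
```

Because the replicas hold identical values after each synchronized update, saving replica 0 loses nothing.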

Replies: 1 comment, 5 replies

@mtavit
@chiamp (Collaborator), Mar 20, 2023
@mtavit
@chiamp (Collaborator), Mar 22, 2023
@mtavit

Answer selected by varunsingh88
Category: Q&A; Labels: none yet; 4 participants