Force no split in make_rng
#3113
Unanswered
zaccharieramzi asked this question in Q&A
Replies: 0 comments
I have the following situation: I am using a Dropout layer multiple times without an nn.scan or nn.while_loop, therefore I cannot use split_rngs={"dropout": False}. However, I would still like to use the same dropout mask twice.
Is it possible to specify "no split" to make_rng for certain collections?
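Successive make_rng calls return different keys, which is why two uses of the same Dropout layer produce two different masks. One way to get "no split" behaviour is to generate the mask from a key you control, because the same key always yields the same mask. The sketch below uses a hand-written, hypothetical dropout helper (not Flax's nn.Dropout) that implements inverted dropout with jax.random.bernoulli and applies it twice to a 3x4x3 input of ones.

```python
import jax
import jax.numpy as jnp


def dropout(x, key, rate):
    """Inverted dropout with an explicit key (hypothetical helper, not Flax's API).

    The mask depends only on `key`, so reusing the same key reproduces
    exactly the same mask, i.e. the "no split" behaviour.
    """
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    # Scale kept units by 1/keep_prob so the expected output equals x.
    return jnp.where(mask, x / keep_prob, 0.0)


key = jax.random.PRNGKey(0)
x = jnp.ones((3, 4, 3))

y1 = dropout(x, key, rate=0.5)
y2 = dropout(x, key, rate=0.5)  # same key, so same mask

# Both applications zero out exactly the same positions.
print(bool(jnp.all((y1 == 0.0) == (y2 == 0.0))))  # True
# Fraction of dropped units: only approximately 0.5 for 36 elements.
print(float(jnp.sum(y1 == 0.0) / (3 * 4 * 3)))
```

Inside a Flax module, the analogous approach would be to draw a single key with self.make_rng("dropout") and pass that same key to every application, instead of letting each Dropout call draw its own. With only 36 elements the dropped fraction is close to, but generally not exactly, 0.5.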
If I just take the original dropout example, I would like to do something like:
and still have jnp.sum(y == 0.) / (3*4*3) == 0.5 approximately.
For more context, I am actually trying to implement Deep Equilibrium Models using jaxopt and flax, where the fixed-point defining function uses dropout. I also tried to see whether the split_rngs functionality could be extended to jaxopt, but I think it's going to be difficult.
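For the Deep Equilibrium setting, a minimal sketch of the same idea: sample the dropout mask once, outside the solver, and close over it in the fixed-point function. Every iteration then sees the same mask, so f is a deterministic function of z and the fixed point z* = f(z*) is well defined. The names make_fixed_point_fn and fixed_point_iteration are hypothetical, the plain fori_loop iteration is a stand-in for a jaxopt solver, and the small weight scale is an assumption chosen so the map is a contraction.

```python
import jax
import jax.numpy as jnp


def make_fixed_point_fn(W, x, mask, keep_prob):
    """Return z -> f(z) with the dropout mask frozen (hypothetical DEQ layer)."""
    def f(z):
        # The same pre-sampled mask is applied on every call.
        h = jnp.where(mask, (W @ z) / keep_prob, 0.0)
        return jnp.tanh(h + x)
    return f


def fixed_point_iteration(f, z0, num_iters=200):
    # Plain iteration z <- f(z); a stand-in for a jaxopt fixed-point solver.
    return jax.lax.fori_loop(0, num_iters, lambda _, z: f(z), z0)


key = jax.random.PRNGKey(0)
w_key, x_key, drop_key = jax.random.split(key, 3)

n, rate = 8, 0.5
keep_prob = 1.0 - rate
# Small weights keep f a contraction even after the 1/keep_prob rescaling.
W = 0.1 * jax.random.normal(w_key, (n, n)) / jnp.sqrt(n)
x = jax.random.normal(x_key, (n,))

# Sample the mask ONCE; every evaluation of f inside the solver reuses it.
mask = jax.random.bernoulli(drop_key, p=keep_prob, shape=(n,))
f = make_fixed_point_fn(W, x, mask, keep_prob)

z_star = fixed_point_iteration(f, jnp.zeros(n))
print(float(jnp.max(jnp.abs(f(z_star) - z_star))))  # residual, close to 0
```

With jaxopt, the analogous step would presumably be to pass the pre-sampled mask (or the key) as an extra argument of the fixed-point function rather than drawing randomness inside it; this is an assumption about how one might wire it, not a documented jaxopt pattern.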