Replies: 5 comments 2 replies
-
Another idea I had is to leverage JAX's abstract evaluation to extract all the parameter shapes without doing any actual computation. This sounds like a much less hacky approach, but I'm not sure how feasible it is.
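For what it's worth, `jax.eval_shape` already does this kind of abstract evaluation: it traces a function and returns only the shapes/dtypes of its outputs, performing no FLOPs and no actual RNG work. A minimal sketch (the two-layer initializer here is a made-up stand-in, not anyone's actual model):

```python
import jax


def init_params(rng):
    # Hypothetical initializer: weights for two dense layers.
    k1, k2 = jax.random.split(rng)
    return {
        "dense1": jax.random.normal(k1, (784, 512)),
        "dense2": jax.random.normal(k2, (512, 10)),
    }


# eval_shape traces init_params abstractly: no computation happens,
# we only get back ShapeDtypeStructs describing what it would return.
shapes = jax.eval_shape(init_params, jax.random.PRNGKey(0))
print(shapes["dense1"].shape)  # (784, 512)
```

Since nothing is compiled or executed, this sidesteps the XLA compile entirely, though it only gives you shapes, not actual parameter values.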
-
This problem is caused by compiling the random init routines over and over. XLA inlines everything, so no code is shared between initializer calls that create the same shapes. A lazy init is indeed an option. There is a util in Flax that can do this: https://flax.readthedocs.io/en/latest/flax.jax_utils.html#flax.jax_utils.partial_eval_by_shape Feel free to try it and see if it speeds up the call.
-
@jheek Thanks for pointing out that function.
-
Thanks for creating this issue @thisiscam! I agree the issues you mentioned do not explain it sufficiently, and I personally would like to understand this better as well. @jheek your solution is to use `partial_eval_by_shape`, which, according to the comment on `Module.init` (which I think we added together), initializes the model lazily using only the shape of the arguments. So will XLA inline everything in this case as well, or will compiled code blocks be shared? (Which is what we want.)
-
Maybe I should ask a rather more basic question: why should we use `jit` for initialization at all? It seems like, from the linked issues, `jit` is there so that XLA's dead code elimination can drop unnecessary computation. However, if the `jit` compilation itself is this slow, the tradeoff is unclear to me.
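To make the dead-code-elimination motivation concrete, here is a small sketch (the initializer is hypothetical): op-by-op execution materializes every intermediate, while under `jit` XLA can drop values that never reach an output.

```python
import jax


def init(rng):
    k1, k2 = jax.random.split(rng)
    w = jax.random.normal(k1, (8, 8))
    _unused = jax.random.normal(k2, (8192, 8192))  # dead: never returned
    return w


# Op-by-op, _unused (a 256 MB array) would actually be computed; under
# jit, XLA's dead code elimination removes it from the compiled program.
w = jax.jit(init)(jax.random.PRNGKey(0))
```

So the benefit of `jit`-ing init is avoiding wasted work like `_unused`, at the cost of paying for XLA compilation up front.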
-
I'm experiencing very slow `jit` compilation for a very small 3-layer Transformer model on TPU. Without `jit` (using op-by-op mode), I can initialize my parameters within a minute. I wonder why that's the case? Here's my guess: with `jit`, all the Dead Code Elimination is done by the XLA compiler, which is known to be slow. I see several related issues/PRs: #879, #910, #1277, but it seems like none of them really documents how to resolve this slowness.
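One way to confirm where the time goes is JAX's ahead-of-time API, which separates tracing/lowering from the XLA compile so the compile step can be timed on its own. A sketch, where `init_fn` is a hypothetical stand-in for the model's parameter init:

```python
import time

import jax


def init_fn(rng):
    # Hypothetical stand-in for a model's parameter init.
    return [jax.random.normal(k, (256, 256)) for k in jax.random.split(rng, 8)]


# Trace and lower first, then time only the XLA compilation step.
lowered = jax.jit(init_fn).lower(jax.random.PRNGKey(0))
t0 = time.perf_counter()
compiled = lowered.compile()
print(f"XLA compile took {time.perf_counter() - t0:.3f}s")
params = compiled(jax.random.PRNGKey(0))
```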
Here's my hacky solution: `jit` the initialization on CPU, then transfer the parameters to TPU. My guess is that the XLA CPU backend doesn't attempt the aggressive optimizations the TPU backend does, so the compilation is much, much faster (< 3 mins).
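Assuming the approach above, a minimal sketch using `jax.default_device` to pin the init to CPU and `jax.device_put` for the transfer (`init_fn` is a hypothetical stand-in for `model.init`):

```python
import jax


def init_fn(rng):
    # Hypothetical stand-in for model.init.
    return {"w": jax.random.normal(rng, (512, 512))}


# Compile and run the init on CPU, where XLA compilation is cheaper...
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    params = jax.jit(init_fn)(jax.random.PRNGKey(0))

# ...then move the finished params to the default accelerator
# (TPU/GPU if present, otherwise this is a no-op on CPU).
params = jax.device_put(params, jax.devices()[0])
```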
It is a cumbersome solution, and the reasoning behind it is mostly based on my guesses. Could someone help me understand if there are any drawbacks to this approach?
Thanks!