I'm trying to assess Flax and JAX for a project, to steer whether or not we want to use them.
I have a doc2vec corpus of about 150 million words, and I discovered that it converges much better when the batch sizes are very small (batch size <= 100). However, with torch (GPU) and such small batch sizes, most of the time is spent in the Python interpreter on flow control, even when the data and model live entirely on the GPU.
In other words, model training (gradient computation and updates) is fast on GPU in torch, and I've eliminated almost all CPU <-> GPU data transfer, but my wall-clock times are way too slow because the flow control runs in the Python interpreter.
My question is whether Flax/JAX and the underlying primitives can help with this kind of speedup across batches (not just within a batch). A quick "maybe"/"definitely not" would help, as would pointers to which features of these tools address this problem.
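For concreteness, here is a minimal sketch of the pattern I have in mind: using `jax.jit` together with `jax.lax.scan` to compile the whole loop over many small batches into a single device call, so Python is re-entered once per epoch rather than once per batch. The loss function and shapes below are hypothetical stand-ins (not my actual doc2vec objective), and `optax` is assumed for the optimizer.

```python
import jax
import jax.numpy as jnp
import optax  # optimizer library commonly paired with flax/jax

# Hypothetical toy objective standing in for a doc2vec loss.
def loss_fn(params, batch):
    inputs, targets = batch
    preds = inputs @ params["w"]
    return jnp.mean((preds - targets) ** 2)

optimizer = optax.sgd(1e-2)

def train_step(carry, batch):
    params, opt_state = carry
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return (params, opt_state), None

@jax.jit
def train_epoch(params, opt_state, batches):
    # lax.scan compiles the loop over the leading "step" axis of `batches`,
    # so there is no per-batch Python dispatch inside the epoch.
    (params, opt_state), _ = jax.lax.scan(
        train_step, (params, opt_state), batches
    )
    return params, opt_state

# Usage: stack many small batches into arrays with a leading step axis.
key = jax.random.PRNGKey(0)
params = {"w": jax.random.normal(key, (64, 8))}
opt_state = optimizer.init(params)
batches = (
    jnp.zeros((1000, 100, 64)),  # 1000 steps of batch size 100 (inputs)
    jnp.zeros((1000, 100, 8)),   # matching targets
)
params, opt_state = train_epoch(params, opt_state, batches)
```

The obvious trade-off is that the data for each scanned chunk must already sit on the device as stacked arrays, but if that's acceptable, is this roughly the intended way to eliminate per-batch interpreter overhead?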