Replies: 1 comment
-
Hey @Tharinda-EV, it's a very nice benchmark!
-
Hi,
I am running the same ResNet50 model with the same weights in both PyTorch and Flax, but the Flax version performs noticeably worse. When running inference with the PyTorch model, GPU utilization is around 100%, while with the Flax model it is only around 35%. I also checked the arrays; they are stored on the GPU.
I have warmed up the jitted function, and I think I have placed both the model and the data on the device in both cases. What am I doing wrong here? Is there a more optimized way to use jitted functions when building Flax models?
Notebook: https://colab.research.google.com/drive/1b486yGovsLLuGawwhFmp6r6YUNw6Eupo?usp=sharing
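For reference, this is roughly the pattern I am using on the Flax side. The sketch below is self-contained with a small stand-in CNN (the notebook uses ResNet50) and illustrative names, so it is not the exact notebook code:

```python
import time

import jax
import jax.numpy as jnp
from flax import linen as nn

# Stand-in model so the snippet runs on its own; the notebook uses ResNet50.
class TinyCNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=64, kernel_size=(7, 7), strides=(2, 2))(x)
        x = nn.relu(x)
        x = x.mean(axis=(1, 2))  # global average pooling
        return nn.Dense(features=1000)(x)

model = TinyCNN()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 224, 224, 3)))

# jit the apply function once and reuse it; wrapping with jax.jit inside the
# benchmark loop would retrace and recompile on every call.
@jax.jit
def predict(variables, images):
    return model.apply(variables, images)

images = jnp.ones((32, 224, 224, 3))  # dummy batch on the default device

# Warm-up call: triggers tracing and XLA compilation for this input shape.
_ = predict(variables, images).block_until_ready()

# Timed calls: JAX dispatches work asynchronously, so without
# block_until_ready() the timer only measures dispatch, not computation.
start = time.perf_counter()
for _ in range(100):
    out = predict(variables, images)
out.block_until_ready()
print(f"avg step: {(time.perf_counter() - start) / 100 * 1e3:.2f} ms")
```

The `block_until_ready()` calls matter for the comparison: without them, the measured Flax times reflect only async dispatch rather than the actual GPU work.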
Environment info
pip install flax
pip install --upgrade jax==0.4.4 jaxlib==0.4.4+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
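A quick sanity check that this install actually targets the GPU (expected output is a CUDA device, not CpuDevice):

```python
import jax

print(jax.default_backend())  # should print 'gpu'
print(jax.devices())          # should list the CUDA device(s)
```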