Fundamentally, JAX cannot do this at the moment because XLA is a static-shape compiler: the shape of every array must be known when the function is compiled. There are two options to work around this limitation:
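One common way to live within this constraint is to return a fixed-size buffer plus a count. This is a minimal sketch, not necessarily either of the options referred to above. It assumes a hypothetical module that keeps a random subset of a 1-D input. Returning `x[mask]` would give a data-dependent length, so this version instead compacts the kept entries to the front of a length-`n` buffer using `jnp.nonzero(..., size=n)` and also returns how many entries are valid. The function name `stochastic_select` and the `keep_prob` parameter are illustrative.

```python
import jax
import jax.numpy as jnp


def stochastic_select(key, x, keep_prob=0.5):
    """Hypothetical example: keep a random subset of the entries of a 1-D `x`.

    `x[mask]` would have a data-dependent length, which XLA cannot compile.
    Instead, the kept entries are moved to the front of a fixed-length
    buffer, the remaining slots are zero-filled, and the number of kept
    entries is returned alongside the buffer.
    """
    n = x.shape[0]  # static: known at trace time
    mask = jax.random.bernoulli(key, p=keep_prob, shape=(n,))
    # `size=n` fixes the output length; unused slots get index 0.
    (idx,) = jnp.nonzero(mask, size=n, fill_value=0)
    count = mask.sum()
    valid = jnp.arange(n) < count
    return jnp.where(valid, x[idx], 0.0), count


x = jnp.arange(1.0, 6.0)
out, count = jax.jit(stochastic_select)(jax.random.PRNGKey(0), x)
# Outside jit, the variable-length result can be recovered by slicing.
kept = out[: int(count)]
```

Any variable-length slicing is pushed outside the compiled function, where shapes may be concrete. Inside `jit` (or `pmap`), downstream code works on the fixed-size buffer and uses `count` (or a mask) to ignore the padding.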
Hi,
TL;DR: How can I implement a stochastic Flax Module whose output shape is not deterministic?
I am trying to implement a stochastic Flax Module. For the sake of simplicity, I want to implement a module that performs the following:
I first implemented something like this:
To initialize and apply the module, I did the following:
To call the model:
This works fine when calling the Module with different RNGs.
However, jitting the module didn't work:
(1) My first approach was to jit the module directly, as suggested in this issue.
But I got the following error:
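Errors of this kind typically come from Python control flow that depends on a traced value. The following is a minimal, hypothetical reproduction, not the original module. It assumes the module picks its output shape with a Python `if` on a random draw. Eagerly the draw is a concrete array, so `if` works. Under `jax.jit` it is an abstract tracer, so converting it to a Python `bool` fails.

```python
import jax
import jax.numpy as jnp


def coin_flip_module(key, x):
    """Hypothetical stand-in: return x or concat(x, x) depending on a coin flip."""
    coin = jax.random.bernoulli(key)
    # A Python `if` needs a concrete bool. Eagerly `coin` is a concrete array,
    # but under jit it is an abstract tracer, so this line raises.
    if coin:
        return jnp.concatenate([x, x])
    return x


x = jnp.ones(3)
eager_out = coin_flip_module(jax.random.PRNGKey(0), x)  # shape (3,) or (6,)

try:
    jax.jit(coin_flip_module)(jax.random.PRNGKey(0), x)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    # Recent JAX versions raise the subclass TracerBoolConversionError here.
    jit_failed = True
```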
(2) From the error, I understand that I need to use jax.lax control-flow operations. But since the output shape differs between the two cases, I cannot simply use jax.lax.cond. See my attempt below:
But this doesn't work, because jax.lax.cond requires both branches to return outputs of the same shape. I get the following error:
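One way to satisfy `jax.lax.cond` is to make both branches return the same shape. You can pad the shorter result up to a static maximum length and return the true length alongside it. This is a sketch that assumes the possible output lengths are known at trace time (here a hypothetical module that returns either `x`, of length `n`, or `concat(x, x)`, of length `2 * n`).

```python
import jax
import jax.numpy as jnp


def coin_flip_padded(key, x):
    """Hypothetical example: x or concat(x, x), always returned at length 2 * n."""
    n = x.shape[0]  # static under jit
    coin = jax.random.bernoulli(key)

    def repeat(x):
        return jnp.concatenate([x, x]), jnp.int32(2 * n)

    def keep(x):
        # Zero-pad so this branch also returns shape (2 * n,).
        return jnp.pad(x, (0, n)), jnp.int32(n)

    # Both branches return (float array of shape (2n,), int32 scalar),
    # so lax.cond accepts them.
    return jax.lax.cond(coin, repeat, keep, x)


x = jnp.arange(3.0)
padded, length = jax.jit(coin_flip_padded)(jax.random.PRNGKey(0), x)
result = padded[: int(length)]  # slice eagerly, outside jit
```

The price is that every call carries the worst-case shape. If the set of possible lengths is unbounded, a fixed maximum has to be chosen instead.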
NOTE: In my case I want to use pmap, but I figured that since the error is due to compilation, resolving it for jit would also resolve my issue.
Thanks in advance for any help.
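On the pmap point: `jax.pmap` also compiles with XLA, so the same static-shape requirement applies. A function with a fixed output shape can be pmapped like any other. This is a minimal sketch that assumes one independent PRNG key per device, split from a single key, and a leading device axis on the input. The `random_mask` function is a hypothetical fixed-shape stochastic op that zeroes entries at random and returns the mask.

```python
import jax
import jax.numpy as jnp


def random_mask(key, x, keep_prob=0.5):
    """Fixed-shape stochastic op: zero out entries at random, return the mask too."""
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    return jnp.where(mask, x, 0.0), mask


n_dev = jax.local_device_count()
# One independent key per device; the leading axis of every input is the device axis.
keys = jax.random.split(jax.random.PRNGKey(0), n_dev)
xs = jnp.ones((n_dev, 4))

outs, masks = jax.pmap(random_mask)(keys, xs)
```

Inside a Flax Module, the per-call key would come from `self.make_rng(...)` with the RNG stream passed through `apply(..., rngs=...)`. The shape-fixing logic stays the same.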