How to train an NNX network using scan and jit? #4722
I currently lack the capacity to answer this question very specifically and in depth, but I happen to have recently done something extremely similar (training RL PPO in Flax NNX). Apologies for being unable to answer the question in a more targeted manner, but I hope this helps.
Hi @puct9, your answer might be just what I needed. Thank you very much. I'll review the code and use it to make my own. I'll come back if I have any questions.
Hey @cdelv, it would be great to have a reproducible example we can run to simplify debugging. Based on what you showed, it's a little bit tricky to figure out. The only thing I can say is that passing …
Hi @cgarciae, thank you for your response. I wrote this standalone script that should reproduce the problem I'm talking about. I haven't quite finished going through @puct9's script, but it looks like they use a regular for loop for the training loop and scan for the batches. What I'm trying to do is use a scan for everything, because I want the highest performance possible. I appreciate any suggestions on how to achieve this. I'll come back when I finish adapting @puct9's script.
Hi @cgarciae, after trying different approaches, I found this example in the documentation using TrainState. I attempted to make the code work using TrainState, and it got me really far. However, the code fails at the end with … Here is the complete script: main.txt. There is an issue with passing the train state around: it becomes a tracer that JAX cannot use, and I am unable to get around it. I know that in Linen this is possible; I don't understand why it does not work in NNX.
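For readers hitting the same tracer error: one pattern that sidesteps it is to keep only pytrees in the scan carry. Split the NNX module once, store the resulting state in the TrainState, and merge it back inside the scanned step. This is a minimal sketch with a hypothetical toy model and dummy data, not the script from main.txt:

```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from flax.training.train_state import TrainState

# Hypothetical toy model standing in for the actor-critic.
model = nnx.Linear(2, 1, rngs=nnx.Rngs(0))
graphdef, params = nnx.split(model)      # graphdef is static, params is a pytree of arrays

def apply_fn(params, x):
    # Rebuild the module from the static graphdef and the (possibly traced) params.
    return nnx.merge(graphdef, params)(x)

state = TrainState.create(apply_fn=apply_fn, params=params, tx=optax.adam(1e-3))

def train_step(state, batch):
    x, y = batch
    def loss_fn(p):
        pred = state.apply_fn(p, x)
        return jnp.mean((pred - y) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss

# TrainState is a pytree, so it is a valid scan carry; only arrays cross the jit boundary.
xs = jnp.ones((10, 4, 2))                # 10 scan steps, batch size 4 (dummy data)
ys = jnp.zeros((10, 4, 1))
state, losses = jax.lax.scan(train_step, state, (xs, ys))

nnx.update(model, state.params)          # write the trained parameters back into the module
```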
I was unable to run the provided script due to a potentially unrelated (?) issue --
I updated the train loop to perform any number of iterations at a time 😉

Now, more on the actual question. I notice a small edit you've attempted to make from what I assume was the reference material:

```diff
- def train(rng):
+ @jax.jit
+ def train(key):
```

And your wish for my original notebook's training loop to be in a … Hence, I think the point of performance and why we want to sometimes …
This also isn't to say that if you remove …

Regarding points [1, 3], I've personally observed a complete lack of difference in performance in my script (attached previously) after lowering the train loop into a …
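For readers following along, the trade-off described above can be sketched like this (a hypothetical toy update, not the PPO code): a jitted step called from a Python loop dispatches once per iteration, while lowering the whole loop into jax.lax.scan compiles it into a single XLA program. For steps that already do a lot of work per call, the difference is often negligible, which matches the observation above.

```python
import jax
import jax.numpy as jnp

def step(carry, _):
    # Stand-in for one training iteration: a dummy gradient-like update.
    params, key = carry
    key, sub = jax.random.split(key)
    noise = jax.random.normal(sub, params.shape)
    return (params - 1e-3 * noise, key), None

params = jnp.zeros((8,))
key = jax.random.PRNGKey(0)

# Style 1: Python loop around a jitted step (one dispatch per iteration).
jit_step = jax.jit(step)
carry = (params, key)
for _ in range(100):
    carry, _ = jit_step(carry, None)

# Style 2: the whole loop lowered into one compiled program.
carry2, _ = jax.lax.scan(step, (params, key), xs=None, length=100)
```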
Hi @puct9, thank you for getting back to me. Indeed, the error you see is the main reason why I opened this issue: no matter what I did, I was unable to escape from it. That's why I tried the usual NNX way, the functional API, and the legacy TrainState before giving up and opening the issue. However, I just discovered that the source of the error was completely unrelated to the training function or to jax.lax.scan, jit, or nnx.scan. It turns out that the culprit was my actor-critic implementation:

```python
from typing import List

import distrax
import jax
import jax.numpy as jnp
from flax import nnx


class SharedActorCritic(nnx.Module):
    def __init__(self,
                 observation_space: int,
                 action_space: int,
                 key: nnx.Rngs,
                 architecture: List[int] = [64, 64],
                 in_scale: float = jnp.sqrt(2),
                 actor_scale: float = 1.0,
                 critic_scale: float = 0.01,
                 activation=nnx.relu
                 ):
        layers = []
        input_dim = observation_space
        for output_dim in architecture:
            layers.append(
                nnx.Linear(
                    in_features=input_dim,
                    out_features=output_dim,
                    kernel_init=nnx.initializers.orthogonal(in_scale),
                    bias_init=nnx.initializers.constant(0.0),
                    rngs=key
                )
            )
            layers.append(activation)
            input_dim = output_dim
        self.network = nnx.Sequential(*layers)

        self.actor = nnx.Linear(
            in_features=input_dim,
            out_features=action_space,
            kernel_init=nnx.initializers.orthogonal(actor_scale),
            bias_init=nnx.initializers.constant(0.0),
            rngs=key
        )
        self.critic = nnx.Linear(
            in_features=input_dim,
            out_features=1,
            kernel_init=nnx.initializers.orthogonal(critic_scale),
            bias_init=nnx.initializers.constant(0.0),
            rngs=key
        )
        self.log_std = nnx.Param(jax.random.uniform(key.params(), (1, action_space)))

    def __call__(self, x):
        x = self.network(x)
        # The error was here: .value is required to get the raw array out of the nnx.Param.
        pi = distrax.MultivariateNormalDiag(self.actor(x), jnp.exp(self.log_std.value))
        return pi, self.critic(x)
```

When constructing the probability distribution, I had to call .value on the log_std parameter. If you change this, the script that uses TrainState should work. I'm not entirely sure why, but that seems to resolve the issue. I discovered it by removing all the scans and jitted functions until the error became clear.

Thank you for your comments; I will implement them. I spent some time writing a pure JAX physics engine, which is quite fast, to create my environments, so using @jax.jit instead of nnx.jit is just a way to keep everything consistent. I wasn't trying to avoid @nnx.jit; the script came out that way after many modifications.

About creating the rng key from a jax.key: I just did that to define only one key. I didn't know it was a dubious practice. What do you recommend for key handling, as I have other things that require a random key? Please let me know if you have more recommendations; I would love to hear them. Also, thanks for the update in the notebook. I'll be checking it out.

Best regards.
@cdelv the reason most likely is that …
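(For readers: a tiny illustration, independent of the original script, of the distinction the fix above relies on. nnx.Param is a wrapper object, and .value exposes the underlying jax.Array, which is what libraries such as distrax and plain JAX functions expect.)

```python
import jax.numpy as jnp
from flax import nnx

log_std = nnx.Param(jnp.zeros((1, 3)))   # wrapper around an array

raw = log_std.value                      # plain jax.Array
scale = jnp.exp(raw)                     # safe to hand to distrax.MultivariateNormalDiag

print(type(log_std), type(raw))
```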
Hi Flax team,
I'm writing because I've been struggling for the past week trying to implement PPO for reinforcement learning using NNX, and I'm at the point where I need some help from the experts.
I've been following the implementations in purejaxrl and in Brax. The thing is that these two use the old Linen API; I wanted to move to the new NNX API but had lots of problems with JIT. Before I start showing what I have: I also tried to follow the performance guidelines in the docs and this reply by @puct9 in #4045 (comment).
Ok then, first a simple example:
This code is very straightforward: it just uses scan to collect information about the trajectories. The problem comes when I try to JIT it. No matter what I do (defining the nnx parameters as static, using nnx.jit, capturing them in the scope, etc.), the function always errors with something similar to:
TypeError: Argument 'args[0]' of shape float32[1,2] of type <class 'flax.nnx.variablelib.Param'> is not a valid JAX type.
Is there a way to pass a model to, or do inference inside, a JITed function? I know that in the MNIST tutorial it works, but I'm unable to replicate the same result.
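For context on the inference question, here is a minimal sketch (with a hypothetical stand-in model) of the two usual ways to run an NNX module inside a jitted function: let nnx.jit handle the module at the boundary, or split it yourself and pass only the state pytree through plain jax.jit.

```python
import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))   # hypothetical stand-in for the policy network

# Option 1: nnx.jit accepts the module directly.
@nnx.jit
def forward_nnx(model, x):
    return model(x)

# Option 2: split into (static graphdef, pytree state) and use plain jax.jit.
graphdef, state = nnx.split(model)

@jax.jit
def forward_jax(state, x):
    m = nnx.merge(graphdef, state)           # graphdef is captured by closure; it is static
    return m(x)

x = jnp.ones((4, 2))
y1 = forward_nnx(model, x)
y2 = forward_jax(state, x)
```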
Now, the important problem: training using scan. As with the problem above, no matter whether I use nnx.scan or jax.lax.scan, I'm having trouble passing the NNX object around, which makes it impossible for me to even attempt to train the network. My code looks something like this:
I can only make this function work if I don't JIT rollout_trajectories and don't return or pass any of the NNX objects, but then I am unable to update them. I could use a regular for loop to call train_step and only use scan for the minibatches, but I was hoping to achieve high-performance code. I tried using TrainState, and that helped me make some progress, but in the end I still encountered the same problem. If you want to play with this code, clone this repo and execute this script inside; I think everything should work (you will get the NNX error with the JIT, though).
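For readers, here is a minimal sketch (hypothetical model, dummy data, and a plain optax optimizer instead of the PPO loss) of one way to pass the NNX parameters through jax.lax.scan without TrainState: split the module into a static graphdef and a state pytree, carry only the pytree, and merge inside the scanned step.

```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(2, 1, rngs=nnx.Rngs(0))   # hypothetical stand-in for the actor-critic
graphdef, params = nnx.split(model)          # graphdef is static, params is a pytree

tx = optax.adam(1e-3)
opt_state = tx.init(params)

def loss_fn(params, x, y):
    m = nnx.merge(graphdef, params)          # rebuild the module from the (traced) params
    return jnp.mean((m(x) - y) ** 2)

def train_step(carry, batch):
    params, opt_state = carry
    x, y = batch
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return (params, opt_state), loss

xs = jnp.ones((20, 4, 2))                    # 20 scan steps of dummy data
ys = jnp.zeros((20, 4, 1))
(params, opt_state), losses = jax.lax.scan(train_step, (params, opt_state), (xs, ys))

nnx.update(model, params)                    # copy the trained parameters back into the module
```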
I see that with the Linen API this moving in and out of JIT land is not that hard; I would appreciate it if someone could explain how to do the same in NNX, or point me to a recommended approach for these kinds of problems.
Best regards.