
How to train an NNX network using scan and jit? #4722


Open
cdelv opened this issue Apr 20, 2025 · 8 comments

@cdelv

cdelv commented Apr 20, 2025

Hi Flax team,

I'm writing as I've been struggling the past week trying to implement PPO for reinforcement learning using NNX. I'm at the point where I need some help from the experts.

I've been following the implementations in purejaxrl and in Brax. The thing is that these two use the old Linen API. I wanted to move to the new NNX API but had lots of problems with JIT. Before I show what I have: I also tried to follow the performance guidelines in the docs and this reply by @puct9 #4045 (comment).

Ok then, first a simple example:

def rollout_trajectories(env: rl.Env, graphdef: nnx.GraphDef, gstate: nnx.GraphState, env_state: rl.EnvState, rng: jnp.ndarray, num_steps: int):
    model, _, _    = nnx.merge(graphdef, gstate)
    
    def _env_step(carry, _):
        env_state, rng = carry

        obs            = env.observation(env_state)
        pi, value      = model(obs)
        rng, subkey    = jax.random.split(rng)
        action         = pi.sample(seed=subkey)
        logp           = pi.log_prob(action)

        state_jdem, system_jdem, env_params_jdem = env.step(env_state, action)
        new_env_state  = rl.EnvState(state_jdem, system_jdem, env_params_jdem)
        reward         = env.reward(new_env_state)
        done           = env.done(new_env_state)
        info           = env.info(new_env_state)

        trans = TrajectoryData(done, action, value, reward, logp, obs, info)
        return (new_env_state, rng), trans

    return jax.lax.scan(_env_step, (env_state, rng), None, num_steps)

This code is very straightforward: it just uses scan to collect information about the trajectories. The problem comes when I try to JIT it. No matter what I do (defining the NNX parameters as static, using nnx.jit, capturing them in the enclosing scope, etc.), the function always fails with something similar to TypeError: Argument 'args[0]' of shape float32[1,2] of type <class 'flax.nnx.variablelib.Param'> is not a valid JAX type.

Is there a way to pass an NNX model into, or run inference inside, a JITed function? I know that it works in the MNIST tutorial, but I'm unable to replicate the same result.

Now, the more important problem: training using scan. As before, whether I use nnx.scan or jax.lax.scan, I have trouble passing the NNX objects through. This makes it impossible for me even to attempt to train the network. My code looks something like this:

    rng, sub  = jax.random.split(rng)
    network   = SharedActorCritic(env.observation_space, env.action_space, nnx.Rngs(sub))
    tx        = optax.chain(
                    optax.clip_by_global_norm(max_grad_norm),
                    optax.adam(learning_rate, eps=1e-5)
                )
    optimizer = nnx.Optimizer(network, tx)
    metrics   = nnx.MultiMetric(loss=nnx.metrics.Average())
    graphdef, init_gstate = nnx.split((network, optimizer, metrics))

    # reset envs
    rng, sub       = jax.random.split(rng)
    reset_rngs     = jax.random.split(sub, num_envs)
    env_state      = env.reset(reset_rngs, env_params)

    train_state = (init_gstate, env_state, rng)

    def train_step(train_state, _):
        gstate, env_state, rng = train_state
        m, opt, met = nnx.merge(graphdef, gstate)

        (env_state, rng), trajectory_data = rollout_trajectories(env, graphdef, gstate, env_state, rng, num_steps)
        advantages, targets = calculate_general_advantage(trajectory_data, gamma, gae_lambda)


        # PPO logic

        return (nnx.state((m, opt, met)), env_state, rng), 0

I can only make this function work if I don't JIT rollout_trajectories and don't return or pass any of the NNX objects. The problem with this is that I can't update them. I could use a regular for loop to call train_step and only use scan for the minibatches, but I was hoping to achieve high-performance code. I tried using train_state, and that helped me make some progress, but in the end, I still ran into the same problem.

If you want to play with this code, clone this repo and execute this script inside. I think everything should work (you will get the nnx error with the JIT though).

I see that with the Linen API, moving in and out of JIT land is not that hard. I would appreciate it if someone could explain to me how to do it with NNX. Additionally, if there is a recommended approach to these types of problems with NNX, I would appreciate it if you could point me to it.

Best regards.

@puct9

puct9 commented Apr 21, 2025

I currently lack the capacity to answer this question very specifically and in depth, but I happen to have recently done something extremely similar (train RL PPO in Flax NNX, also featuring the scan trick). The code is in a single notebook here.

Apologies for being unable to answer the question in a more targeted manner, but I hope this helps.

@cdelv
Author

cdelv commented Apr 21, 2025

Hi @puct9, your answer might be just what I needed. Thank you very much. I'll review the code and use it to make my own. I'll come back if I have any questions.

@cgarciae
Collaborator

Hey @cdelv, it would be great to have a reproducible example we can run to simplify debugging. Based on what you showed, it's a little bit tricky to figure out. The only thing I can say is that passing the model as a capture is generally not a good idea, as any mutation would raise an error; it's better to pass it as an explicit input to the nnx.scan-ed function.

@cdelv
Copy link
Author

cdelv commented Apr 22, 2025

Hi @cgarciae, thank you for your response.

I wrote this standalone script that should reproduce the problem I'm talking about. I haven't quite finished going through @puct9's script, but it looks like they use a regular for loop for the training loop and scan for the batches. What I'm trying to do is use a scan for everything.

I want to have the highest performance possible. I appreciate any suggestions on how to achieve this.

I'll come back when I finish adapting @puct9's script.

@cdelv
Author

cdelv commented Apr 22, 2025

Hi @cgarciae,

After trying different approaches, I found this example in the documentation using TrainState. I attempted to make the code work using TrainState, and it got me really far. However, the code fails at the end with TypeError: Argument 'Traced<ShapedArray(float32[1,2])>with<DynamicJaxprTrace>' of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'> is not a valid JAX type.

Here is the complete script: main.txt

There is an issue with passing the train state around, as it becomes a tracer that JAX cannot use. I am unable to get around it. I know that in linen this is possible. I don't understand why in nnx it does not work.

@puct9

puct9 commented Apr 23, 2025

I was unable to run the provided script due to a potentially unrelated (?) issue -- ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'> for field pos is not allowed: use default_factory (Python 3.11) so I don't think that I can directly address the problem this time, but I have a couple points anyway.
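That ValueError comes from Python's dataclasses: since Python 3.11, any unhashable default value is rejected as a "mutable default", and JAX arrays are unhashable. A minimal sketch of the usual fix, assuming a hypothetical ParticleState class standing in for the script's state class with a pos field, is to build the array in a default_factory:

```python
from dataclasses import dataclass, field

import jax.numpy as jnp

# Writing `pos: jnp.ndarray = jnp.zeros(3)` raises
# "ValueError: mutable default ... use default_factory" at class creation on
# Python 3.11+, because JAX arrays are unhashable.
@dataclass
class ParticleState:  # hypothetical stand-in for the script's state class
    pos: jnp.ndarray = field(default_factory=lambda: jnp.zeros(3))

s = ParticleState()
```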

> I haven't quite finished going through @puct9's script, but it looks like they use a regular for loop for the training loop and scan for the batches. What I'm trying to do is use a scan for everything.

I updated the train loop to perform any number of iterations at a time 😉

Now more on the actual question

> I want to have the highest performance possible. I appreciate any suggestions on how to achieve this.

I notice a small edit you've attempted to make from what I assume was the reference material:

-    def train(rng):
+    @jax.jit
+    def train(key):

And your wish for my original notebook's training loop to be in a scan.

Hence, I think the point of why, and when, we jit for performance has been somewhat missed. For example:

  1. In extremely quick train loops, it can be insufficient to simply jit the function that applies a single update, because we would spend more time traversing between the compiled function and Python than doing useful work. In this case, due to the more computationally intensive nature of needing to run the environment, compute losses, and update the network, it's unlikely we can achieve more than a few dozen iterations per second. This means that the theoretical performance difference between repeatedly calling a jitted function from Python and running the loop inside something like a scan shrinks to the point where it is negligible.
  2. JAX's performance certainly is solid, but it's not magic. The greatest gain from being able to describe the environment in JAX is getting good performance with a high level of convenience. The only other competitive alternative (actually superior in this case) is probably to handle the environment, and a lot of what JAX does, yourself in a low-level language like C++, at great inconvenience. Likewise, expecting that going beyond jitting a significant portion of the train loop (e.g., the train step function) to jitting the train loop itself will bring much performance benefit will likely lead to disappointment and a nonnegligible amount of wasted time. If you want to understand where you're gaining/losing time, consider profiling the program.
  3. Whether deliberate or by coincidence, I didn't find a single instance of nnx.scan or nnx.jit in your script. There is a difference between the nnx and jax versions of those functions, such as being able to pass nnx.Module and nnx.Rngs objects correctly. I will make no assumptions about your intent, but "don't use nnx.jit because it's slower" would certainly be the wrong lesson to take from Significant performance difference of NNX relative to equinox #4045. The source of the slowdown is the traversal between Python and a function compiled with nnx.jit, and even that cost is negligible as long as the operation is sufficiently expensive (which in this case, it is).
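A minimal sketch of the middle ground described above ("perform any number of iterations at a time"): keep a Python outer loop, but let each compiled call run k updates inside a lax.scan, so the Python round trip happens once per chunk rather than once per update. The quadratic loss and the step/run_k_steps names are hypothetical stand-ins for a real train step.

```python
import functools

import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum((w - 3.0) ** 2)

@jax.jit
def step(w):
    # one hypothetical "update": a plain gradient-descent step
    return w - 0.1 * jax.grad(loss)(w)

@functools.partial(jax.jit, static_argnums=1)
def run_k_steps(w, k):
    # k updates inside one compiled call: a single Python<->XLA round trip
    def body(w, _):
        return step(w), None
    w, _ = jax.lax.scan(body, w, None, length=k)
    return w

w0 = jnp.zeros(2)

w_loop = w0
for _ in range(10):      # 10 dispatches from Python
    w_loop = step(w_loop)

w_chunked = w0
for _ in range(2):       # 2 dispatches, 5 updates each
    w_chunked = run_k_steps(w_chunked, 5)
```

Both loops compute the same result; whether the chunked version is measurably faster depends on how expensive each update is, which is the point made above.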

This also isn't to say that if you remove jit things are much more likely to work. There are a number of other dubious practices some of which I'm not sure are meant to work at all:

  1. Constructing nnx.Rngs from a vanilla Jax key
            key, subkey = jax.random.split(key)
            network = SharedActorCritic(env.observation_space, env.action_space, nnx.Rngs(subkey))
  2. Use of TrainState. I'd recommend attempting the method shown in the Flax NNX vs JAX transformations (click on JAX transforms) or here as the ideal way of using objects originating from Flax or NNX in places like jax.jitted functions. Though, your hand may be forced -- Improve support for optax.GradientTransformationExtraArgs in NNX #4545 (comment), since optax.chain returns GradientTransformationExtraArgs.

Regarding points [1, 3], I've personally observed a complete lack of difference in performance in my script (attached previously) after lowering the train loop into a scan.

@cdelv
Author

cdelv commented Apr 23, 2025

Hi @puct9, thank you for getting back to me.

Indeed, the error you see is the main reason why I opened this issue. No matter what I did, I was unable to escape from it. That's why I tried the usual NNX way, the functional API, and the legacy TrainState before giving up and opening the issue. However, I just discovered that the source of the error was completely unrelated to the training function or to jax.lax.scan, jit or nnx.scan. Turns out that the culprit was my actor critic implementation:

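from typing import Sequence

import distrax
import jax
import jax.numpy as jnp
from flax import nnx

class SharedActorCritic(nnx.Module):
    def __init__(self,
        observation_space: int,
        action_space: int,
        key: nnx.Rngs,
        architecture: Sequence[int] = (64, 64),  # tuple instead of a mutable list default
        in_scale: float = jnp.sqrt(2),
        actor_scale: float = 1.0,
        critic_scale: float = 0.01,
        activation = nnx.relu
    ):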
        layers = []
        input_dim = observation_space

        for output_dim in architecture:
            layers.append(
                nnx.Linear(
                    in_features=input_dim,
                    out_features=output_dim,
                    kernel_init=nnx.initializers.orthogonal(in_scale),
                    bias_init=nnx.initializers.constant(0.0),
                    rngs=key
                )
            )
            layers.append(activation)  
            input_dim = output_dim

        self.network = nnx.Sequential(*layers)
        self.actor = nnx.Linear(
                    in_features=input_dim,
                    out_features=action_space,
                    kernel_init=nnx.initializers.orthogonal(actor_scale),
                    bias_init=nnx.initializers.constant(0.0),
                    rngs=key
                )
        self.critic = nnx.Linear(
                    in_features=input_dim,
                    out_features=1,
                    kernel_init=nnx.initializers.orthogonal(critic_scale),
                    bias_init=nnx.initializers.constant(0.0),
                    rngs=key
                )
        self.log_std = nnx.Param(jax.random.uniform(key.params(), (1, action_space,)))

    def __call__(self, x):
        x = self.network(x)
        pi = distrax.MultivariateNormalDiag(self.actor(x), jnp.exp(self.log_std.value))  # the error was here, .value is required
        return pi, self.critic(x)

When calling the probability distribution, I had to call .value on the log_std parameter. If you change this, the script that uses TrainState should work. I'm not entirely sure why, but that seems to resolve the issue. I discovered this by removing all the scans and jitted functions until the error became clear.

Thank you for your comments. I will implement them. I spent some time writing a fast, pure-JAX physics engine to create my environments, so using @jax.jit instead of nnx.jit is just a way to keep everything consistent. But I wasn't trying to avoid @nnx.jit; the script came out that way after many modifications.

About creating nnx.Rngs from a JAX key: I just did that so that I only had to define one key. I didn't know it was a dubious practice. What do you recommend for key handling, given that other things also need a random key?

Please let me know if you have more recommendations. I would love to hear them. Also, thanks for the update in the notebook. I'll be checking it out.

Best regards.

@cgarciae
Collaborator

> When calling the probability distribution, I had to call .value in the log_std parameter. If you change this, the script that uses train state should work. I'm not entirely sure why, but that seems to resolve the issue.

@cdelv the reason is most likely that nnx.Variable (e.g. nnx.Param) implements the __jax_array__ protocol, so JAX functions treat it as an Array. The main issue is that the protocol is not fully supported, so you might get errors from time to time; we're evaluating whether we should continue using it or have users explicitly access .value or use other syntax like [...].
